48 lines
1.6 KiB
Python
48 lines
1.6 KiB
Python
"""
This is a utility for extracting the multimodal projector weights from a model
checkpoint, for use with quantized models.

It is NOT needed for training, inference, or serving demos.

Use this script ONLY if you fully understand its implications.
"""
|
|
|
|
|
|
import os
|
|
import argparse
|
|
import torch
|
|
import json
|
|
from collections import defaultdict
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the projector extraction script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with two string attributes: ``model_path``
        (the model folder) and ``output`` (the output file).

    Raises:
        SystemExit: Raised by argparse when a required option is missing
            or the arguments are otherwise invalid.
    """
    parser = argparse.ArgumentParser(description='Extract MMProjector weights')
    # Both options are mandatory: without them the script would only fail
    # later, when a None path reaches os.path.join / torch.save.
    parser.add_argument('--model-path', type=str, required=True, help='model folder')
    parser.add_argument('--output', type=str, required=True, help='output file')
    return parser.parse_args(argv)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: collect every weight whose key contains one of
    # `keys_to_match` from the checkpoint(s) under --model-path and save them
    # as a single dict to --output.
    args = parse_args()

    keys_to_match = ['mm_projector']
    # Checkpoint file name -> weight keys that must be read from that file.
    ckpt_to_key = defaultdict(list)
    # Checkpoints already deserialized during key discovery, kept so that the
    # extraction loop below does not have to load the same file a second time.
    preloaded = {}

    index_path = os.path.join(args.model_path, 'pytorch_model.bin.index.json')
    try:
        # Only the open/parse step is guarded; a context manager makes sure
        # the index file handle is closed.
        with open(index_path, encoding='utf-8') as index_file:
            model_indices = json.load(index_file)
    except FileNotFoundError:
        # Smaller models or model checkpoints saved by DeepSpeed.
        single_ckpt = 'pytorch_model.bin'
        # NOTE: torch.load unpickles the file; only run this on checkpoints
        # from a trusted source.
        state_dict = torch.load(os.path.join(args.model_path, single_ckpt), map_location='cpu')
        preloaded[single_ckpt] = state_dict
        for k in state_dict:
            if any(key_match in k for key_match in keys_to_match):
                ckpt_to_key[single_ckpt].append(k)
    else:
        # The index maps each weight key to the checkpoint file that holds it.
        for k, v in model_indices['weight_map'].items():
            if any(key_match in k for key_match in keys_to_match):
                ckpt_to_key[v].append(k)

    loaded_weights = {}

    for ckpt_name, weight_keys in ckpt_to_key.items():
        # Reuse the checkpoint loaded during discovery when there is one;
        # otherwise load it now (same trusted-source caveat as above).
        ckpt = preloaded.pop(ckpt_name, None)
        if ckpt is None:
            ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu')
        for k in weight_keys:
            loaded_weights[k] = ckpt[k]

    torch.save(loaded_weights, args.output)
|