undo diff

This commit is contained in:
Rory Clear
2024-11-27 16:42:19 +00:00
parent 1d1fa8c608
commit 3c81845ab7

View File

@@ -51,6 +51,7 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
weights = load(str(model_path), shard)
weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
weights = fix_bf16(weights)
with Context(BEAM=0):
# replace weights in model
load_state_dict(model, weights, strict=False, consume=False) # consume=True