mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
undo diff
This commit is contained in:
@@ -51,6 +51,7 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
|
||||
weights = load(str(model_path), shard)
|
||||
weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
|
||||
weights = fix_bf16(weights)
|
||||
|
||||
with Context(BEAM=0):
|
||||
# replace weights in model
|
||||
load_state_dict(model, weights, strict=False, consume=False) # consume=True
|
||||
|
||||
Reference in New Issue
Block a user