Create the model before reading the input prompt

This commit is contained in:
Lizonghang
2024-09-20 22:13:15 +04:00
parent ea234735a5
commit 7e3624538c

View File

@@ -107,6 +107,9 @@ def main(my_rank, args, dist=None):
except KeyError:
raise KeyError(f"Unsupported model type: {args.model_type}")
# load model
model = model_class.from_pretrained(args.model_path, comm, rank=my_rank, args=args)
# the master node initializes tokenizer and encodes user prompt
tokenizer, streamer = None, None
input_ids = ""
@@ -125,9 +128,6 @@ def main(my_rank, args, dist=None):
return_tensors="pt"
).to(args.device)
# load model
model = model_class.from_pretrained(args.model_path, comm, rank=my_rank, args=args)
# generate output with streaming output
model.generate(
input_ids=input_ids,