mirror of
https://github.com/Lizonghang/TPI-LLM.git
synced 2024-10-04 22:25:47 +03:00
create model before input prompt
This commit is contained in:
@@ -107,6 +107,9 @@ def main(my_rank, args, dist=None):
|
||||
except KeyError:
|
||||
raise KeyError(f"Unsupported model type: {args.model_type}")
|
||||
|
||||
# load model
|
||||
model = model_class.from_pretrained(args.model_path, comm, rank=my_rank, args=args)
|
||||
|
||||
# the master node initializes tokenizer and encodes user prompt
|
||||
tokenizer, streamer = None, None
|
||||
input_ids = ""
|
||||
@@ -125,9 +128,6 @@ def main(my_rank, args, dist=None):
|
||||
return_tensors="pt"
|
||||
).to(args.device)
|
||||
|
||||
# load model
|
||||
model = model_class.from_pretrained(args.model_path, comm, rank=my_rank, args=args)
|
||||
|
||||
# generate output with streaming output
|
||||
model.generate(
|
||||
input_ids=input_ids,
|
||||
|
||||
Reference in New Issue
Block a user