Files
TPI-LLM/examples/run_multiprocess.py
2024-09-17 18:04:07 +04:00

55 lines
2.5 KiB
Python

import argparse
import torch.multiprocessing as mp
from run_llama import main
def init_process(rank, fn, args):
    """Optionally join a torch.distributed process group, then run ``fn``.

    Args:
        rank: Index of this worker process, used as the distributed rank.
        fn: Callable invoked as ``fn(rank, args, dist=dist)``.
        args: Parsed command-line namespace. Always reads ``torch_dist``;
            when that is set, also reads ``master_ip``, ``master_port`` and
            ``world_size``.

    When ``args.torch_dist`` is false, ``fn`` receives ``dist=None``.
    Otherwise it receives the ``torch.distributed`` module with a "gloo"
    process group already initialized through the ``env://`` method, and
    that group is destroyed again once ``fn`` returns or raises.

    Raises:
        Whatever ``fn`` raises is propagated unchanged.
    """
    dist = None
    if args.torch_dist:
        # Imported lazily so the non-distributed path does not need them.
        import os
        import torch.distributed as dist

        # The "env://" init method reads the rendezvous address and port
        # from these environment variables.
        os.environ["MASTER_ADDR"] = args.master_ip
        os.environ["MASTER_PORT"] = str(args.master_port)
        dist.init_process_group(
            backend="gloo",
            init_method="env://",
            rank=rank,
            world_size=args.world_size,
        )
    try:
        fn(rank, args, dist=dist)
    finally:
        # Release the process group even when ``fn`` raises. Guarded with
        # is_initialized() because ``fn`` may already have destroyed it, and
        # destroying an absent default group raises.
        if dist is not None and dist.is_initialized():
            dist.destroy_process_group()
if __name__ == "__main__":
    # Script entry point: parse CLI options, launch one worker process per
    # rank running ``init_process(rank, main, args)``, wait for all of them,
    # and exit non-zero if any worker failed.
    import sys

    parser = argparse.ArgumentParser()
    # necessary arguments
    parser.add_argument("--model_type", default=None, type=str, required=True)
    parser.add_argument("--model_path", default=None, type=str, required=True)
    parser.add_argument("--world_size", default=None, type=int, required=True)
    parser.add_argument("--master_ip", type=str, default="127.0.0.1")
    parser.add_argument("--master_port", type=int, default=29500, help="Communication port.")
    # for weight file synchronization
    parser.add_argument("--file_port", type=int, default=29600, help="File server port.")
    parser.add_argument("--force_download", action="store_true", help="Force download sliced model files.")
    # for llm inference
    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
    parser.add_argument("--use_gpu", action="store_true", help="Whether to use gpu, default to use cpu.")
    parser.add_argument("--torch_dist", action="store_true", help="Whether to use torch distributed.")
    parser.add_argument("--split_bin", action="store_true", help="Whether to split the model file.")
    parser.add_argument("--save_dir", type=str, default="split", help="Directory to save split models.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--k", type=int, default=0)
    parser.add_argument("--p", type=float, default=0.9)
    # for memory schedule
    parser.add_argument("--disable_memory_schedule", action="store_true")
    parser.add_argument("--memory_window", type=int, default=2,
                        help="Memory window size, should be at least 2.")
    args = parser.parse_args()

    # Reject settings that cannot produce a working run before spawning any
    # worker. parser.error() prints usage and exits with status 2.
    if args.world_size < 1:
        parser.error("--world_size must be at least 1.")
    # The help text states the window must be at least 2. The check is only
    # enforced when the memory schedule is enabled, on the assumption that
    # the window is unused otherwise (TODO confirm in run_llama).
    if not args.disable_memory_schedule and args.memory_window < 2:
        parser.error("--memory_window must be at least 2.")

    # Each worker gets a fresh interpreter via "spawn". Presumably this was
    # chosen because workers may use CUDA (--use_gpu), which does not
    # support fork — confirm.
    mp.set_start_method("spawn")
    processes = []
    for rank in range(args.world_size):
        p = mp.Process(target=init_process, args=(rank, main, args))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # Propagate worker failures through the launcher's exit status instead of
    # always exiting 0. A negative exitcode means the worker was killed by a
    # signal.
    failed = [(rank, p.exitcode) for rank, p in enumerate(processes) if p.exitcode != 0]
    if failed:
        for rank, code in failed:
            print(f"rank {rank} exited with code {code}", file=sys.stderr)
        sys.exit(1)