Files
exo/example_user.py

76 lines
2.5 KiB
Python

# In this example, a user is running a home cluster of 2 nodes, each hosting one shard of the model.
# They are prompting the cluster to generate a response to a question.
# The cluster is given the question, and the user is given the response.
from inference.mlx.sharded_utils import get_model_path, load_tokenizer
from inference.shard import Shard
from networking.peer_handle import PeerHandle
from networking.grpc.grpc_peer_handle import GRPCPeerHandle
from typing import List
import asyncio
import argparse
# Model to run on the cluster. The tokenizer is also loaded here, locally, so
# this script can apply the chat template and decode generated token ids.
path_or_hf_repo = "mlx-community/Meta-Llama-3-8B-Instruct-4bit"
model_path = get_model_path(path_or_hf_repo)
tokenizer_config = {}
tokenizer = load_tokenizer(model_path, tokenizer_config)

# One gRPC handle per node, listed as (node id, address) pairs.
peers: List[PeerHandle] = [
    GRPCPeerHandle(node_id, address)
    for node_id, address in (
        ("node1", "localhost:8080"),
        ("node2", "localhost:8081"),
    )
]

# (start_layer, end_layer) for each shard of the 32-layer model; the i-th shard
# is assigned to the i-th entry of `peers`. An even split would instead be
# (0, 15) and (16, 31).
shards: List[Shard] = [
    Shard(model_id=path_or_hf_repo, start_layer=first, end_layer=last, n_layers=32)
    for first, last in ((0, 30), (31, 31))
]
async def run_prompt(prompt: str, max_tokens: int = 20):
    """Send *prompt* through the sharded pipeline and print the decoded reply.

    The prompt is first wrapped in the tokenizer's chat template when one is
    available. Every peer is then connected and assigned its shard. After
    that, up to *max_tokens* generation steps run. In each step the current
    value is passed through the peers in order: a ``str`` goes through
    ``send_prompt`` and anything else through ``send_tensor``. Each peer's
    output becomes the next peer's input, and the last peer's output is
    recorded as the next token id.

    Generation stops early if any peer returns ``None``.

    Args:
        prompt: The user's question.
        max_tokens: Maximum number of generation steps (defaults to 20, the
            previously hard-coded limit).
    """
    if tokenizer.chat_template is None:
        # ``default_chat_template`` no longer exists in newer ``transformers``
        # releases; fall back to None instead of raising AttributeError.
        tokenizer.chat_template = getattr(
            tokenizer, "default_chat_template", None
        )
    if (
        hasattr(tokenizer, "apply_chat_template")
        and tokenizer.chat_template is not None
    ):
        messages = [{"role": "user", "content": prompt}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    # NOTE(review): connections opened here are never explicitly closed;
    # consider disconnecting peers once generation finishes.
    for peer, shard in zip(peers, shards):
        await peer.connect()
        await peer.reset_shard(shard)

    tokens = []
    last_output = prompt
    for _ in range(max_tokens):
        for peer, shard in zip(peers, shards):
            if isinstance(last_output, str):
                last_output = await peer.send_prompt(shard, last_output)
                print("prompt output:", last_output)
            else:
                last_output = await peer.send_tensor(shard, last_output)
                print("tensor output:", last_output)
            # Nothing to forward: handing None to the next peer would only
            # fail inside send_tensor.
            if last_output is None:
                break

        # Explicit None check: truthiness of a one-element array is False for
        # token id 0 (which would end generation early) and raises for
        # multi-element arrays.
        if last_output is None:
            break
        tokens.append(last_output.item())

    print(tokenizer.decode(tokens))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run prompt")
    # Without required=True an omitted --prompt becomes None and is fed to the
    # chat template and the peers as though it were text.
    parser.add_argument(
        "--prompt", type=str, required=True, help="The prompt to run"
    )
    args = parser.parse_args()
    asyncio.run(run_prompt(args.prompt))