Mirror of https://github.com/exo-explore/exo.git, synced 2025-10-23 02:57:14 +03:00

Commit: Formatting
@@ -34,6 +34,7 @@ import shutil
from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
from exo.apputil import create_animation_mp4


class Message:
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
self.role = role
@@ -47,7 +48,6 @@ class Message:
return data


class ChatCompletionRequest:
def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
self.model = model
@@ -138,11 +138,7 @@ def remap_messages(messages: List[Message]) -> List[Message]:

def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
messages = remap_messages(_messages)
chat_template_args = {
"conversation": [m.to_dict() for m in messages],
"tokenize": False,
"add_generation_prompt": True
}
chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True}
if tools: chat_template_args["tools"] = tools

prompt = tokenizer.apply_chat_template(**chat_template_args)
@@ -171,8 +167,17 @@ class PromptSession:
self.timestamp = timestamp
self.prompt = prompt


class ChatGPTAPI:
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
def __init__(
self,
node: Node,
inference_engine_classname: str,
response_timeout: int = 90,
on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None,
default_model: Optional[str] = None,
system_prompt: Optional[str] = None
):
self.node = node
self.inference_engine_classname = inference_engine_classname
self.response_timeout = response_timeout
@@ -208,7 +213,6 @@ class ChatGPTAPI:
cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})


if "__compiled__" not in globals():
self.static_dir = Path(__file__).parent.parent/"tinychat"
self.app.router.add_get("/", self.handle_root)
@@ -219,7 +223,7 @@ class ChatGPTAPI:
self.app.middlewares.append(self.log_request)

async def handle_quit(self, request):
if DEBUG>=1: print("Received quit signal")
if DEBUG >= 1: print("Received quit signal")
response = web.json_response({"detail": "Quit signal received"}, status=200)
await response.prepare(request)
await response.write_eof()
@@ -249,61 +253,48 @@ class ChatGPTAPI:

async def handle_model_support(self, request):
try:
response = web.StreamResponse(
status=200,
reason='OK',
headers={
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
}
)
await response.prepare(request)
response = web.StreamResponse(status=200, reason='OK', headers={
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
})
await response.prepare(request)

async def process_model(model_name, pretty):
if model_name in model_cards:
model_info = model_cards[model_name]
async def process_model(model_name, pretty):
if model_name in model_cards:
model_info = model_cards[model_name]

if self.inference_engine_classname in model_info.get("repo", {}):
shard = build_base_shard(model_name, self.inference_engine_classname)
if shard:
downloader = HFShardDownloader(quick_check=True)
downloader.current_shard = shard
downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
status = await downloader.get_shard_download_status()
if self.inference_engine_classname in model_info.get("repo", {}):
shard = build_base_shard(model_name, self.inference_engine_classname)
if shard:
downloader = HFShardDownloader(quick_check=True)
downloader.current_shard = shard
downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
status = await downloader.get_shard_download_status()

download_percentage = status.get("overall") if status else None
total_size = status.get("total_size") if status else None
total_downloaded = status.get("total_downloaded") if status else False
download_percentage = status.get("overall") if status else None
total_size = status.get("total_size") if status else None
total_downloaded = status.get("total_downloaded") if status else False

model_data = {
model_name: {
"name": pretty,
"downloaded": download_percentage == 100 if download_percentage is not None else False,
"download_percentage": download_percentage,
"total_size": total_size,
"total_downloaded": total_downloaded
}
}
model_data = {
model_name: {
"name": pretty, "downloaded": download_percentage == 100 if download_percentage is not None else False, "download_percentage": download_percentage, "total_size": total_size,
"total_downloaded": total_downloaded
}
}

await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())

# Process all models in parallel
await asyncio.gather(*[
process_model(model_name, pretty)
for model_name, pretty in pretty_name.items()
])
# Process all models in parallel
await asyncio.gather(*[process_model(model_name, pretty) for model_name, pretty in pretty_name.items()])

await response.write(b"data: [DONE]\n\n")
return response
await response.write(b"data: [DONE]\n\n")
return response

except Exception as e:
print(f"Error in handle_model_support: {str(e)}")
traceback.print_exc()
return web.json_response(
{"detail": f"Server error: {str(e)}"},
status=500
)
print(f"Error in handle_model_support: {str(e)}")
traceback.print_exc()
return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)

async def handle_get_models(self, request):
models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
@@ -472,7 +463,6 @@ class ChatGPTAPI:
deregistered_callback = self.node.on_token.deregister(callback_id)
if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")


async def handle_post_image_generations(self, request):
data = await request.json()
@@ -485,7 +475,7 @@ class ChatGPTAPI:
shard = build_base_shard(model, self.inference_engine_classname)
if DEBUG >= 2: print(f"shard: {shard}")
if not shard:
return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)

request_id = str(uuid.uuid4())
callback_id = f"chatgpt-api-wait-response-{request_id}"
@@ -497,77 +487,73 @@ class ChatGPTAPI:
img = None
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)


response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream',"Cache-Control": "no-cache",})
response = web.StreamResponse(status=200, reason='OK', headers={
'Content-Type': 'application/octet-stream',
"Cache-Control": "no-cache",
})
await response.prepare(request)

def get_progress_bar(current_step, total_steps, bar_length=50):
# Calculate the percentage of completion
percent = float(current_step) / total_steps
percent = float(current_step)/total_steps
# Calculate the number of hashes to display
arrow = '-' * int(round(percent * bar_length) - 1) + '>'
spaces = ' ' * (bar_length - len(arrow))

arrow = '-'*int(round(percent*bar_length) - 1) + '>'
spaces = ' '*(bar_length - len(arrow))

# Create the progress bar string
progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
return progress_bar

async def stream_image(_request_id: str, result, is_finished: bool):
if isinstance(result, list):
await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
if isinstance(result, list):
await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')

elif isinstance(result, np.ndarray):
im = Image.fromarray(np.array(result))
images_folder = get_exo_images_dir()
# Save the image to a file
image_filename = f"{_request_id}.png"
image_path = images_folder / image_filename
im.save(image_path)
image_url = request.app.router['static_images'].url_for(filename=image_filename)
base_url = f"{request.scheme}://{request.host}"
# Construct the full URL correctly
full_image_url = base_url + str(image_url)

await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
if is_finished:
await response.write_eof()

elif isinstance(result, np.ndarray):
im = Image.fromarray(np.array(result))
images_folder = get_exo_images_dir()
# Save the image to a file
image_filename = f"{_request_id}.png"
image_path = images_folder/image_filename
im.save(image_path)
image_url = request.app.router['static_images'].url_for(filename=image_filename)
base_url = f"{request.scheme}://{request.host}"
# Construct the full URL correctly
full_image_url = base_url + str(image_url)

await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
if is_finished:
await response.write_eof()

stream_task = None

def on_result(_request_id: str, result, is_finished: bool):
nonlocal stream_task
stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
return _request_id == request_id and is_finished
nonlocal stream_task
stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
return _request_id == request_id and is_finished

await callback.wait(on_result, timeout=self.response_timeout*10)


if stream_task:
# Wait for the stream task to complete before returning
await stream_task
# Wait for the stream task to complete before returning
await stream_task

return response

except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)

if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)

async def handle_delete_model(self, request):
try:
model_name = request.match_info.get('model_name')
if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")

if not model_name or model_name not in model_cards:
return web.json_response(
{"detail": f"Invalid model name: {model_name}"},
status=400
)
return web.json_response({"detail": f"Invalid model name: {model_name}"}, status=400)

shard = build_base_shard(model_name, self.inference_engine_classname)
if not shard:
return web.json_response(
{"detail": "Could not build shard for model"},
status=400
)
return web.json_response({"detail": "Could not build shard for model"}, status=400)

repo_id = get_repo(shard.model_id, self.inference_engine_classname)
if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
@@ -582,38 +568,28 @@ class ChatGPTAPI:
if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
try:
shutil.rmtree(cache_dir)
return web.json_response({
"status": "success",
"message": f"Model {model_name} deleted successfully",
"path": str(cache_dir)
})
return web.json_response({"status": "success", "message": f"Model {model_name} deleted successfully", "path": str(cache_dir)})
except Exception as e:
return web.json_response({
"detail": f"Failed to delete model files: {str(e)}"
}, status=500)
return web.json_response({"detail": f"Failed to delete model files: {str(e)}"}, status=500)
else:
return web.json_response({
"detail": f"Model files not found at {cache_dir}"
}, status=404)
return web.json_response({"detail": f"Model files not found at {cache_dir}"}, status=404)

except Exception as e:
print(f"Error in handle_delete_model: {str(e)}")
traceback.print_exc()
return web.json_response({
"detail": f"Server error: {str(e)}"
}, status=500)
print(f"Error in handle_delete_model: {str(e)}")
traceback.print_exc()
return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)

async def handle_get_initial_models(self, request):
model_data = {}
for model_name, pretty in pretty_name.items():
model_data[model_name] = {
"name": pretty,
"downloaded": None, # Initially unknown
"download_percentage": None, # Change from 0 to null
"total_size": None,
"total_downloaded": None,
"loading": True # Add loading state
}
model_data[model_name] = {
"name": pretty,
"downloaded": None, # Initially unknown
"download_percentage": None, # Change from 0 to null
"total_size": None,
"total_downloaded": None,
"loading": True # Add loading state
}
return web.json_response(model_data)

async def handle_create_animation(self, request):
@@ -639,17 +615,9 @@ class ChatGPTAPI:
if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}")

# Create the animation
create_animation_mp4(
replacement_image_path,
output_path,
device_name,
prompt_text
)
create_animation_mp4(replacement_image_path, output_path, device_name, prompt_text)

return web.json_response({
"status": "success",
"output_path": output_path
})
return web.json_response({"status": "success", "output_path": output_path})

except Exception as e:
if DEBUG >= 2: traceback.print_exc()
@@ -665,10 +633,7 @@ class ChatGPTAPI:
if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))

return web.json_response({
"status": "success",
"message": f"Download started for model: {model_name}"
})
return web.json_response({"status": "success", "message": f"Download started for model: {model_name}"})
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"error": str(e)}, status=500)
@@ -682,10 +647,7 @@ class ChatGPTAPI:
return web.json_response({})
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response(
{"detail": f"Error getting topology: {str(e)}"},
status=500
)
return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)

async def run(self, host: str = "0.0.0.0", port: int = 52415):
runner = web.AppRunner(self.app)
@@ -696,15 +658,14 @@ class ChatGPTAPI:
def base64_decode(self, base64_string):
#decode and reshape image
if base64_string.startswith('data:image'):
base64_string = base64_string.split(',')[1]
base64_string = base64_string.split(',')[1]
image_data = base64.b64decode(base64_string)
img = Image.open(BytesIO(image_data))
W, H = (dim - dim % 64 for dim in (img.width, img.height))
W, H = (dim - dim%64 for dim in (img.width, img.height))
if W != img.width or H != img.height:
if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter
if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter
img = mx.array(np.array(img))
img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
img = (img[:, :, :3].astype(mx.float32)/255)*2 - 1
img = img[None]
return img
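Note: both versions of base64_decode above round each image dimension down to the nearest multiple of 64 and then map uint8 pixels into [-1, 1]. A small NumPy-only check of those two steps (values arbitrary; MLX not required):

# Quick check of the rounding and normalisation used by base64_decode above.
import numpy as np


def floor_to_multiple_of_64(width: int, height: int) -> tuple:
  return tuple(dim - dim % 64 for dim in (width, height))


print(floor_to_multiple_of_64(1023, 768))  # (960, 768)
print(floor_to_multiple_of_64(512, 500))   # (512, 448)

pixels = np.array([0, 127, 255], dtype=np.float32)
print((pixels / 255) * 2 - 1)  # approx [-1, -0.004, 1]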
@@ -245,15 +245,13 @@ def get_all_ip_addresses_and_interfaces():
if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
return [("localhost", "lo")]


async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
try:
# Use the shared subprocess_pool
output = await asyncio.get_running_loop().run_in_executor(subprocess_pool, lambda: subprocess.run(
['system_profiler', 'SPNetworkDataType', '-json'],
capture_output=True,
text=True,
close_fds=True
).stdout)
output = await asyncio.get_running_loop().run_in_executor(
subprocess_pool, lambda: subprocess.run(['system_profiler', 'SPNetworkDataType', '-json'], capture_output=True, text=True, close_fds=True).stdout
)

data = json.loads(output)
@@ -279,6 +277,7 @@ async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:

return None


async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
# On macOS, try to get interface type using networksetup
if psutil.MACOS:
@@ -286,8 +285,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
if macos_type is not None: return macos_type

# Local container/virtual interfaces
if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or
'bridge' in ifname):
if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or 'bridge' in ifname):
return (7, "Container Virtual")

# Loopback interface
@@ -313,6 +311,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
# Other physical interfaces
return (2, "Other")


async def shutdown(signal, loop, server):
"""Gracefully shutdown the server and close the asyncio loop."""
print(f"Received exit signal {signal.name}...")
@@ -332,16 +331,16 @@ def is_frozen():


def get_exo_home() -> Path:
if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"]) / "Documents"
else: docs_folder = Path.home() / "Documents"
if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"])/"Documents"
else: docs_folder = Path.home()/"Documents"
if not docs_folder.exists(): docs_folder.mkdir(exist_ok=True)
exo_folder = docs_folder / "Exo"
exo_folder = docs_folder/"Exo"
if not exo_folder.exists(): exo_folder.mkdir(exist_ok=True)
return exo_folder


def get_exo_images_dir() -> Path:
exo_home = get_exo_home()
images_dir = exo_home / "Images"
images_dir = exo_home/"Images"
if not images_dir.exists(): images_dir.mkdir(exist_ok=True)
return images_dir
@@ -14,7 +14,7 @@ class InferenceEngine(ABC):
@abstractmethod
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
pass


@abstractmethod
async def sample(self, x: np.ndarray) -> np.ndarray:
pass
@@ -33,13 +33,13 @@ class InferenceEngine(ABC):

async def save_checkpoint(self, shard: Shard, path: str):
pass


async def save_session(self, key, value):
self.session[key] = value


async def clear_session(self):
self.session.empty()


async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
tokens = await self.encode(shard, prompt)
if shard.model_id != 'stable-diffusion-2-1-base':
@@ -50,12 +50,14 @@ class InferenceEngine(ABC):

return output_data, inference_state


inference_engine_classes = {
"mlx": "MLXDynamicShardInferenceEngine",
"tinygrad": "TinygradDynamicShardInferenceEngine",
"dummy": "DummyInferenceEngine",
}


def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader):
if DEBUG >= 2:
print(f"get_inference_engine called with: {inference_engine_name}")
@@ -19,6 +19,7 @@ if platform.system().lower() == "darwin" and platform.machine().lower() == "arm6
else:
import numpy as mx


class GRPCPeerHandle(PeerHandle):
def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
self._id = _id
@@ -42,11 +43,9 @@ class GRPCPeerHandle(PeerHandle):

async def connect(self):
if self.channel is None:
self.channel = grpc.aio.insecure_channel(self.address, options=[
("grpc.max_metadata_size", 32*1024*1024),
('grpc.max_receive_message_length', 32*1024*1024),
('grpc.max_send_message_length', 32*1024*1024)
])
self.channel = grpc.aio.insecure_channel(
self.address, options=[("grpc.max_metadata_size", 32*1024*1024), ('grpc.max_receive_message_length', 32*1024*1024), ('grpc.max_send_message_length', 32*1024*1024)]
)
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
await self.channel.channel_ready()
@@ -114,7 +113,7 @@ class GRPCPeerHandle(PeerHandle):
return None

return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)


async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.ExampleRequest(
shard=node_service_pb2.Shard(
@@ -136,7 +135,7 @@ class GRPCPeerHandle(PeerHandle):
return loss, grads
else:
return loss


async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.TensorRequest(
shard=node_service_pb2.Shard(
@@ -171,10 +170,7 @@ class GRPCPeerHandle(PeerHandle):
topology = Topology()
for node_id, capabilities in response.nodes.items():
device_capabilities = DeviceCapabilities(
model=capabilities.model,
chip=capabilities.chip,
memory=capabilities.memory,
flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
)
topology.update_node(node_id, device_capabilities)
for node_id, peer_connections in response.peer_graph.items():
@@ -198,28 +194,20 @@ class GRPCPeerHandle(PeerHandle):
proto_inference_state = node_service_pb2.InferenceState()
other_data = {}
for k, v in inference_state.items():
if isinstance(v, mx.array):
np_array = np.array(v)
tensor_data = node_service_pb2.Tensor(
tensor_data=np_array.tobytes(),
shape=list(np_array.shape),
dtype=str(np_array.dtype)
)
proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
tensor_list = node_service_pb2.TensorList()
for tensor in v:
np_array = np.array(tensor)
tensor_data = node_service_pb2.Tensor(
tensor_data=np_array.tobytes(),
shape=list(np_array.shape),
dtype=str(np_array.dtype)
)
tensor_list.tensors.append(tensor_data)
proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
else:
# For non-tensor data, we'll still use JSON
other_data[k] = v
if isinstance(v, mx.array):
np_array = np.array(v)
tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
tensor_list = node_service_pb2.TensorList()
for tensor in v:
np_array = np.array(tensor)
tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
tensor_list.tensors.append(tensor_data)
proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
else:
# For non-tensor data, we'll still use JSON
other_data[k] = v
if other_data:
proto_inference_state.other_data_json = json.dumps(other_data)
return proto_inference_state
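Note: the serialization above relies only on NumPy's raw-buffer round trip (tobytes / frombuffer / reshape); the protobuf Tensor message just carries the bytes plus shape and dtype strings. A protobuf-free sketch of that round trip:

# Minimal sketch of the bytes/shape/dtype round trip behind serialize_inference_state
# and deserialize_inference_state, without the node_service_pb2 types.
import numpy as np


def to_wire(arr: np.ndarray) -> dict:
  return {"tensor_data": arr.tobytes(), "shape": list(arr.shape), "dtype": str(arr.dtype)}


def from_wire(msg: dict) -> np.ndarray:
  return np.frombuffer(msg["tensor_data"], dtype=np.dtype(msg["dtype"])).reshape(msg["shape"])


original = np.arange(6, dtype=np.float32).reshape(2, 3)
assert np.array_equal(original, from_wire(to_wire(original)))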
@@ -80,7 +80,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()


async def SendExample(self, request, context):
shard = Shard(
model_id=request.shard.model_id,
@@ -102,7 +102,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
else:
loss = await self.node.process_example(shard, example, target, length, train, request_id)
return node_service_pb2.Loss(loss=loss, grads=None)


async def CollectTopology(self, request, context):
max_depth = request.max_depth
visited = set(request.visited)
@@ -118,12 +118,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
for node_id, cap in topology.nodes.items()
}
peer_graph = {
node_id: node_service_pb2.PeerConnections(
connections=[
node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description)
for conn in connections
]
)
node_id: node_service_pb2.PeerConnections(connections=[node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description) for conn in connections])
for node_id, connections in topology.peer_graph.items()
}
if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
@@ -137,7 +132,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
result = list(result)
if len(img.tensor_data) > 0:
result=np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
result = np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
self.node.on_token.trigger_all(request_id, result, is_finished)
return node_service_pb2.Empty()
@@ -151,21 +146,18 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
async def HealthCheck(self, request, context):
return node_service_pb2.HealthCheckResponse(is_healthy=True)

def deserialize_inference_state(self,inference_state_proto: node_service_pb2.InferenceState) -> dict:
def deserialize_inference_state(self, inference_state_proto: node_service_pb2.InferenceState) -> dict:
inference_state = {}

for k, tensor_data in inference_state_proto.tensor_data.items():
np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
inference_state[k] = mx.array(np_array)

np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
inference_state[k] = mx.array(np_array)

for k, tensor_list in inference_state_proto.tensor_list_data.items():
inference_state[k] = [
mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape))
for tensor in tensor_list.tensors
]

inference_state[k] = [mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape)) for tensor in tensor_list.tensors]

if inference_state_proto.other_data_json:
other_data = json.loads(inference_state_proto.other_data_json)
inference_state.update(other_data)

other_data = json.loads(inference_state_proto.other_data_json)
inference_state.update(other_data)

return inference_state
@@ -195,7 +195,7 @@ def linux_device_capabilities() -> DeviceCapabilities:
gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)

if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")


pynvml.nvmlShutdown()

return DeviceCapabilities(
@@ -207,22 +207,22 @@ def linux_device_capabilities() -> DeviceCapabilities:
elif Device.DEFAULT == "AMD":
# For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
from pyrsmi import rocml


rocml.smi_initialize()
gpu_name = rocml.smi_get_device_name(0).upper()
gpu_memory_info = rocml.smi_get_device_memory_total(0)


if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")


rocml.smi_shutdown()


return DeviceCapabilities(
model="Linux Box ({gpu_name})",
chip={gpu_name},
memory=gpu_memory_info.total // 2**20,
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
)


else:
return DeviceCapabilities(
model=f"Linux Box (Device: {Device.DEFAULT})",
@@ -234,30 +234,31 @@ def linux_device_capabilities() -> DeviceCapabilities:

def windows_device_capabilities() -> DeviceCapabilities:
import psutil

def get_gpu_info():
import win32com.client # install pywin32
import win32com.client # install pywin32

wmiObj = win32com.client.GetObject("winmgmts:\\\\.\\root\\cimv2")
gpus = wmiObj.ExecQuery("SELECT * FROM Win32_VideoController")

gpu_info = []
for gpu in gpus:
info = {
"Name": gpu.Name,
"AdapterRAM": gpu.AdapterRAM, # Bug in this property, returns -ve for VRAM > 4GB (uint32 overflow)
"DriverVersion": gpu.DriverVersion,
"VideoProcessor": gpu.VideoProcessor
}
gpu_info.append(info)

info = {
"Name": gpu.Name,
"AdapterRAM": gpu.AdapterRAM, # Bug in this property, returns -ve for VRAM > 4GB (uint32 overflow)
"DriverVersion": gpu.DriverVersion,
"VideoProcessor": gpu.VideoProcessor
}
gpu_info.append(info)

return gpu_info


gpus_info = get_gpu_info()
gpu_names = [gpu['Name'] for gpu in gpus_info]

contains_nvidia = any('nvidia' in gpu_name.lower()for gpu_name in gpu_names)

contains_nvidia = any('nvidia' in gpu_name.lower() for gpu_name in gpu_names)
contains_amd = any('amd' in gpu_name.lower() for gpu_name in gpu_names)


if contains_nvidia:
import pynvml
@@ -266,7 +267,7 @@ def windows_device_capabilities() -> DeviceCapabilities:
gpu_raw_name = pynvml.nvmlDeviceGetName(handle).upper()
gpu_name = gpu_raw_name.rsplit(" ", 1)[0] if gpu_raw_name.endswith("GB") else gpu_raw_name
gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)


if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")

return DeviceCapabilities(
@@ -278,15 +279,15 @@ def windows_device_capabilities() -> DeviceCapabilities:
elif contains_amd:
# For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
from pyrsmi import rocml


rocml.smi_initialize()
gpu_name = rocml.smi_get_device_name(0).upper()
gpu_memory_info = rocml.smi_get_device_memory_total(0)


if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")


rocml.smi_shutdown()


return DeviceCapabilities(
model="Windows Box ({gpu_name})",
chip={gpu_name},
@@ -299,4 +300,4 @@ def windows_device_capabilities() -> DeviceCapabilities:
chip=f"Unknown Chip (Device(s): {gpu_names})",
memory=psutil.virtual_memory().total // 2**20,
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
)
)