mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
clean up DEBUG=2 logs, a few fixes for token
This commit is contained in:
@@ -441,7 +441,7 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
|
||||
shard_specific_patterns.add(sorted_file_names[-1])
|
||||
else:
|
||||
shard_specific_patterns = set(["*.safetensors"])
|
||||
if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
|
||||
if DEBUG >= 3: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
|
||||
return list(default_patterns | shard_specific_patterns)
|
||||
|
||||
async def get_file_download_percentage(
|
||||
|
||||
@@ -159,13 +159,14 @@ class HFShardDownloader(ShardDownloader):
|
||||
print(f"Download calculation for {self.current_repo_id}:")
|
||||
print(f"Total bytes: {total_bytes}")
|
||||
print(f"Downloaded bytes: {downloaded_bytes}")
|
||||
if DEBUG >= 3:
|
||||
for file in relevant_files:
|
||||
print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
|
||||
|
||||
return status
|
||||
|
||||
except Exception as e:
|
||||
if DEBUG >= 2:
|
||||
if DEBUG >= 3:
|
||||
print(f"Error getting shard download status: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
10
exo/main.py
10
exo/main.py
@@ -187,10 +187,10 @@ api = ChatGPTAPI(
|
||||
system_prompt=args.system_prompt
|
||||
)
|
||||
buffered_token_output = {}
|
||||
def update_topology_viz(req_id, token, __):
|
||||
def update_topology_viz(req_id, tokens, __):
|
||||
if not topology_viz: return
|
||||
if req_id in buffered_token_output: buffered_token_output[req_id].append(token)
|
||||
else: buffered_token_output[req_id] = [token]
|
||||
if req_id in buffered_token_output: buffered_token_output[req_id].extend(tokens)
|
||||
else: buffered_token_output[req_id] = tokens
|
||||
|
||||
if inference_engine.shard.model_id != 'stable-diffusion-2-1-base':
|
||||
topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(buffered_token_output[req_id]))
|
||||
@@ -243,8 +243,8 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
|
||||
await node.process_prompt(shard, prompt, request_id=request_id)
|
||||
|
||||
tokens = []
|
||||
def on_token(_request_id, _token, _is_finished):
|
||||
tokens.append(_token)
|
||||
def on_token(_request_id, _tokens, _is_finished):
|
||||
tokens.extend(_tokens)
|
||||
return _request_id == request_id and _is_finished
|
||||
await callback.wait(on_token, timeout=300)
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ class Node:
|
||||
self.max_generate_tokens = max_generate_tokens
|
||||
self.topology_viz = topology_viz
|
||||
self.default_sample_temperature = default_sample_temperature
|
||||
self._on_token = AsyncCallbackSystem[str, Tuple[str, int, bool]]()
|
||||
self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
|
||||
self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
|
||||
self._on_opaque_status.register("node_status").on_next(self.on_node_status)
|
||||
self.node_download_progress: Dict[str, RepoProgressEvent] = {}
|
||||
@@ -130,9 +130,8 @@ class Node:
|
||||
self.buffered_token_output[request_id][0].append(token.item())
|
||||
is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
|
||||
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
|
||||
asyncio.create_task(self.broadcast_result(request_id, [self.buffered_token_output[request_id][0][-1]], is_finished))
|
||||
forward = token.reshape(1, -1)
|
||||
intermediate_result = self.buffered_token_output[request_id][0][-1]
|
||||
intermediate_result = [self.buffered_token_output[request_id][0][-1]]
|
||||
else:
|
||||
forward = result
|
||||
else:
|
||||
@@ -575,16 +574,16 @@ class Node:
|
||||
return self.topology
|
||||
|
||||
@property
|
||||
def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, int, bool]]:
|
||||
def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
|
||||
return self._on_token
|
||||
|
||||
@property
|
||||
def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
|
||||
return self._on_opaque_status
|
||||
|
||||
def trigger_on_token_callbacks(self, request_id: str, token: int, is_finished: bool) -> None:
|
||||
if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {token=} {is_finished=}")
|
||||
self.on_token.trigger_all(request_id, token, is_finished)
|
||||
def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
|
||||
if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {tokens=} {is_finished=}")
|
||||
self.on_token.trigger_all(request_id, tokens, is_finished)
|
||||
|
||||
async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
|
||||
async def send_result_to_peer(peer):
|
||||
|
||||
Reference in New Issue
Block a user