clean up DEBUG=2 logs, a few fixes for token

Alex Cheema
2025-01-22 22:27:02 +00:00
parent 9954ce8e4d
commit 55d1846f5e
4 changed files with 14 additions and 14 deletions


@@ -441,7 +441,7 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
shard_specific_patterns.add(sorted_file_names[-1])
else:
shard_specific_patterns = set(["*.safetensors"])
if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
if DEBUG >= 3: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
return list(default_patterns | shard_specific_patterns)
async def get_file_download_percentage(
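
Note: the only change in this hunk is the debug print's threshold. For context, a minimal sketch of the pattern union the function builds; names and default patterns below are illustrative, not exo's exact ones.

from typing import List, Set

DEBUG = 0  # exo reads this verbosity level from the DEBUG environment variable

def allow_patterns_sketch(shard_specific_patterns: Set[str]) -> List[str]:
  # Always-kept metadata files plus whichever weight files this shard needs.
  default_patterns: Set[str] = {"*.json", "*.txt", "tokenizer*"}  # illustrative defaults only
  if not shard_specific_patterns:
    shard_specific_patterns = {"*.safetensors"}  # fall back to fetching all weight files
  if DEBUG >= 3:  # bumped from >= 2 by this commit, so DEBUG=2 runs stay quieter
    print(f"allow_patterns_sketch {shard_specific_patterns=}")
  return list(default_patterns | shard_specific_patterns)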


@@ -159,13 +159,14 @@ class HFShardDownloader(ShardDownloader):
print(f"Download calculation for {self.current_repo_id}:")
print(f"Total bytes: {total_bytes}")
print(f"Downloaded bytes: {downloaded_bytes}")
+if DEBUG >= 3:
for file in relevant_files:
print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
return status
except Exception as e:
-if DEBUG >= 2:
+if DEBUG >= 3:
print(f"Error getting shard download status: {e}")
traceback.print_exc()
return None
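
Note: a hedged sketch of the resulting behaviour — the per-file loop gains its own DEBUG >= 3 guard and the exception log moves from DEBUG >= 2 to DEBUG >= 3, so both stay silent at DEBUG=2. Function and argument names here are illustrative, not exo's exact ones.

import traceback

def print_shard_status_sketch(relevant_files, status, DEBUG=0):
  try:
    if DEBUG >= 3:  # per-file detail now needs the higher verbosity level
      for file in relevant_files:
        print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
    return status
  except Exception as e:
    if DEBUG >= 3:  # error detail was gated at DEBUG >= 2 before this commit
      print(f"Error getting shard download status: {e}")
      traceback.print_exc()
    return None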


@@ -187,10 +187,10 @@ api = ChatGPTAPI(
system_prompt=args.system_prompt
)
buffered_token_output = {}
-def update_topology_viz(req_id, token, __):
+def update_topology_viz(req_id, tokens, __):
if not topology_viz: return
-if req_id in buffered_token_output: buffered_token_output[req_id].append(token)
-else: buffered_token_output[req_id] = [token]
+if req_id in buffered_token_output: buffered_token_output[req_id].extend(tokens)
+else: buffered_token_output[req_id] = tokens
if inference_engine.shard.model_id != 'stable-diffusion-2-1-base':
topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(buffered_token_output[req_id]))
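
Note: the visualiser callback now receives a batch of token ids per call, so the per-request buffer is extended rather than appended to. A minimal sketch of that batched buffering, with a stand-in for the tokenizer's decode:

from typing import Dict, List

buffered_token_output: Dict[str, List[int]] = {}

def decode_sketch(token_ids: List[int]) -> str:
  # Stand-in for inference_engine.tokenizer.decode in the real code.
  return " ".join(str(t) for t in token_ids)

def update_viz_sketch(req_id: str, tokens: List[int]) -> str:
  # Accumulate the whole batch, not one id at a time.
  if req_id in buffered_token_output:
    buffered_token_output[req_id].extend(tokens)
  else:
    buffered_token_output[req_id] = list(tokens)  # copy so we don't alias the caller's list
  # Decode the full buffer so partially generated words render correctly.
  return decode_sketch(buffered_token_output[req_id])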
@@ -243,8 +243,8 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
await node.process_prompt(shard, prompt, request_id=request_id)
tokens = []
-def on_token(_request_id, _token, _is_finished):
-tokens.append(_token)
+def on_token(_request_id, _tokens, _is_finished):
+tokens.extend(_tokens)
return _request_id == request_id and _is_finished
await callback.wait(on_token, timeout=300)
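
Note: run_model_cli gathers the streamed batches this way and stops once the finishing batch for its request arrives. A minimal sketch of that collect-and-wait pattern, assuming the callback's wait(check, timeout=...) resolves once the predicate returns True, which is how the hunk uses it:

from typing import List

async def collect_tokens_sketch(callback, request_id: str) -> List[int]:
  tokens: List[int] = []

  def on_token(_request_id: str, _tokens: List[int], _is_finished: bool) -> bool:
    tokens.extend(_tokens)  # each event now carries a batch of ids
    return _request_id == request_id and _is_finished  # stop on our request's final batch

  await callback.wait(on_token, timeout=300)
  return tokens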


@@ -47,7 +47,7 @@ class Node:
self.max_generate_tokens = max_generate_tokens
self.topology_viz = topology_viz
self.default_sample_temperature = default_sample_temperature
-self._on_token = AsyncCallbackSystem[str, Tuple[str, int, bool]]()
+self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
self._on_opaque_status.register("node_status").on_next(self.on_node_status)
self.node_download_progress: Dict[str, RepoProgressEvent] = {}
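
Note: the token callback system's payload changes from a single int to List[int]. Below is a minimal sketch (not exo's AsyncCallbackSystem, and synchronous for brevity) of a keyed callback registry with the register/on_next/trigger_all surface this diff relies on:

from typing import Callable, Dict, Generic, List, TypeVar

K = TypeVar("K")

class CallbackSketch:
  # One listener slot: each event fans out to every on_next handler.
  def __init__(self) -> None:
    self._handlers: List[Callable[..., object]] = []

  def on_next(self, fn: Callable[..., object]) -> "CallbackSketch":
    self._handlers.append(fn)
    return self

  def set(self, *args) -> None:
    for fn in self._handlers:
      fn(*args)

class CallbackSystemSketch(Generic[K]):
  # Keyed registry: trigger_all pushes one event to every registered callback.
  def __init__(self) -> None:
    self._callbacks: Dict[K, CallbackSketch] = {}

  def register(self, key: K) -> CallbackSketch:
    return self._callbacks.setdefault(key, CallbackSketch())

  def trigger_all(self, *args) -> None:
    for cb in self._callbacks.values():
      cb.set(*args)

on_token = CallbackSystemSketch[str]()
on_token.register("printer").on_next(lambda req_id, tokens, done: print(req_id, tokens, done))
on_token.trigger_all("req-1", [101, 102, 103], False)  # a List[int] batch, not a single int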
@@ -130,9 +130,8 @@ class Node:
self.buffered_token_output[request_id][0].append(token.item())
is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
-asyncio.create_task(self.broadcast_result(request_id, [self.buffered_token_output[request_id][0][-1]], is_finished))
forward = token.reshape(1, -1)
-intermediate_result = self.buffered_token_output[request_id][0][-1]
+intermediate_result = [self.buffered_token_output[request_id][0][-1]]
else:
forward = result
else:
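
Note: the one-line fix here wraps the newest token id in a single-element list so intermediate results have the same List[int] shape as every other token batch. Illustrative values:

buffered = [17, 29, 4096]             # token ids generated so far for this request
intermediate_result = [buffered[-1]]  # [4096]: a List[int], matching the batched callbacks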
@@ -575,16 +574,16 @@ class Node:
return self.topology
@property
-def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, int, bool]]:
+def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
return self._on_token
@property
def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
return self._on_opaque_status
-def trigger_on_token_callbacks(self, request_id: str, token: int, is_finished: bool) -> None:
-if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {token=} {is_finished=}")
-self.on_token.trigger_all(request_id, token, is_finished)
+def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
+if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {tokens=} {is_finished=}")
+self.on_token.trigger_all(request_id, tokens, is_finished)
async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
async def send_result_to_peer(peer):
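
Note: broadcast_result (only its opening lines appear in this hunk) already takes a List[int], which is what the batched callbacks above feed it. A generic fan-out sketch; the peer.send_result method name and the timeout below are assumptions for illustration, not exo's confirmed peer API:

import asyncio
from typing import List

async def broadcast_result_sketch(peers, request_id: str, result: List[int], is_finished: bool) -> None:
  async def send_result_to_peer(peer) -> None:
    try:
      # peer.send_result is assumed here; the real peer interface may differ.
      await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
    except Exception as e:
      print(f"Error sending result to peer {peer}: {e}")

  # Fan out concurrently so one slow or failed peer does not block the others.
  await asyncio.gather(*(send_result_to_peer(p) for p in peers), return_exceptions=True)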