mirror of
https://github.com/lm-sys/RouteLLM.git
synced 2024-07-11 08:05:43 +03:00
Fix progress bar
This commit is contained in:
@@ -9,8 +9,6 @@ from tqdm import tqdm
|
||||
from routellm.controller import Controller
|
||||
from routellm.routers.routers import ROUTER_CLS
|
||||
|
||||
tqdm.pandas()
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
|
||||
@@ -4,8 +4,8 @@ from types import SimpleNamespace
|
||||
from typing import Any, Optional
|
||||
|
||||
import pandas as pd
|
||||
import tqdm
|
||||
from litellm import acompletion, completion
|
||||
from tqdm import tqdm
|
||||
|
||||
from routellm.routers.routers import ROUTER_CLS
|
||||
|
||||
@@ -54,10 +54,16 @@ class Controller:
|
||||
self.api_base = api_base
|
||||
self.api_key = api_key
|
||||
self.model_counts = defaultdict(lambda: defaultdict(int))
|
||||
self.progress_bar = progress_bar
|
||||
|
||||
if config is None:
|
||||
config = GPT_4_AUGMENTED_CONFIG
|
||||
|
||||
router_pbar = tqdm.tqdm(routers) if progress_bar else None
|
||||
router_pbar = None
|
||||
if progress_bar:
|
||||
router_pbar = tqdm(routers)
|
||||
tqdm.pandas()
|
||||
|
||||
for router in routers:
|
||||
if router_pbar is not None:
|
||||
router_pbar.set_description(f"Loading {router}")
|
||||
@@ -116,8 +122,10 @@ class Controller:
|
||||
):
|
||||
self._validate_router_threshold(router, 0)
|
||||
router_instance = self.routers[router]
|
||||
if router_instance.NO_PARALLEL:
|
||||
if router_instance.NO_PARALLEL and self.progress_bar:
|
||||
return prompts.progress_apply(router_instance.calculate_strong_win_rate)
|
||||
elif router_instance.NO_PARALLEL:
|
||||
return prompts.apply(router_instance.calculate_strong_win_rate)
|
||||
else:
|
||||
return prompts.parallel_apply(router_instance.calculate_strong_win_rate)
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from routellm.routers.routers import Router
|
||||
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
pd.options.mode.copy_on_write = True
|
||||
tqdm.pandas()
|
||||
|
||||
|
||||
class Benchmark(abc.ABC):
|
||||
|
||||
Reference in New Issue
Block a user