Fix progress bar

This commit is contained in:
Isaac Ong
2024-07-07 12:25:53 -07:00
parent ac422b9226
commit 1da650ac2b
3 changed files with 11 additions and 6 deletions

View File

@@ -9,8 +9,6 @@ from tqdm import tqdm
from routellm.controller import Controller
from routellm.routers.routers import ROUTER_CLS
tqdm.pandas()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(

View File

@@ -4,8 +4,8 @@ from types import SimpleNamespace
from typing import Any, Optional
import pandas as pd
import tqdm
from litellm import acompletion, completion
from tqdm import tqdm
from routellm.routers.routers import ROUTER_CLS
@@ -54,10 +54,16 @@ class Controller:
self.api_base = api_base
self.api_key = api_key
self.model_counts = defaultdict(lambda: defaultdict(int))
self.progress_bar = progress_bar
if config is None:
config = GPT_4_AUGMENTED_CONFIG
router_pbar = tqdm.tqdm(routers) if progress_bar else None
router_pbar = None
if progress_bar:
router_pbar = tqdm(routers)
tqdm.pandas()
for router in routers:
if router_pbar is not None:
router_pbar.set_description(f"Loading {router}")
@@ -116,8 +122,10 @@ class Controller:
):
self._validate_router_threshold(router, 0)
router_instance = self.routers[router]
if router_instance.NO_PARALLEL:
if router_instance.NO_PARALLEL and self.progress_bar:
return prompts.progress_apply(router_instance.calculate_strong_win_rate)
elif router_instance.NO_PARALLEL:
return prompts.apply(router_instance.calculate_strong_win_rate)
else:
return prompts.parallel_apply(router_instance.calculate_strong_win_rate)

View File

@@ -12,7 +12,6 @@ from routellm.routers.routers import Router
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
pd.options.mode.copy_on_write = True
tqdm.pandas()
class Benchmark(abc.ABC):