diff --git a/experiments/module-summarization/main.py b/experiments/module-summarization/main.py index 260c26b..9e8f4a3 100644 --- a/experiments/module-summarization/main.py +++ b/experiments/module-summarization/main.py @@ -19,7 +19,7 @@ from itertools import cycle class LLMGenerator: def __init__(self, model_name, device, **model_args): # Create a vllm LLM instance - engine_args = EngineArgs(model=model_name, device=device, **model_args) + engine_args = EngineArgs(model=model_name, gpu_memory_utilization=0.8, device=device, **model_args) self.model = LLM(**vars(engine_args)) self.model_name = model_name self.device = device @@ -58,7 +58,7 @@ class LLMGenerator: class LLMScorer: def __init__(self, model_name, device, **model_args): # Create a vllm LLM instance - engine_args = EngineArgs(model=model_name, device=device, **model_args) + engine_args = EngineArgs(model=model_name, gpu_memory_utilization=0.8, device=device, **model_args) self.model = LLM(**vars(engine_args)) self.model_name = model_name self.device = device @@ -1315,4 +1315,4 @@ def run_documentation_task( if __name__ == "__main__": - fire.Fire(run_documentation_task) \ No newline at end of file + fire.Fire(run_documentation_task)