[Fix] Fix generate kwargs mismatch for PEFT model & Set import * for agents (#44)

* set `import *` for agents

* adapt to the peft model

* fix pre-commit
This commit is contained in:
Zhihao Lin
2023-10-13 17:11:14 +08:00
committed by GitHub
parent 276154c212
commit d908d1b47a
4 changed files with 8 additions and 9 deletions

View File

@@ -1,6 +1,4 @@
from .autogpt import AutoGPT
from .base_agent import BaseAgent
from .react import ReAct
from .rewoo import ReWOO
__all__ = ['BaseAgent', 'ReAct', 'AutoGPT', 'ReWOO']
from .autogpt import * # noqa: F401, F403
from .base_agent import * # noqa: F401, F403
from .react import * # noqa: F401, F403
from .rewoo import * # noqa: F401, F403

View File

@@ -76,7 +76,8 @@ class LMTemplateParser:
"""Intermediate prompt template parser, specifically for language models.
Args:
meta_template (list of dict, optional): The meta template for the model.
meta_template (list of dict, optional): The meta template for the
model.
"""
def __init__(self, meta_template: Optional[List[Dict]] = None):

View File

@@ -100,7 +100,7 @@ class HFTransformer(BaseModel):
max_length=self.max_seq_len - max_out_len)['input_ids']
input_ids = torch.tensor(input_ids, device=self.model.device)
outputs = self.model.generate(
input_ids, max_new_tokens=max_out_len, **kwargs)
input_ids=input_ids, max_new_tokens=max_out_len, **kwargs)
if not self.extract_pred_after_decode:
outputs = outputs[:, input_ids.shape[1]:]

View File

@@ -1,4 +1,4 @@
lmdeploy
streamlit
torch
transformers
streamlit