Files
lagent/examples/internlm2_agent_cli_demo.py
liujiangning30 7b71988d09 max_tokens to max_new_tokens (#149)
* Fix: max_new_tokens to max_tokens

* change `max_tokens` to `max_new_tokens` in API models

* max_tokens to max_new_tokens

* inject parameter 'max_new_tokens' for examples

---------

Co-authored-by: wangzy <wangziyi@pjlab.org.cn>
2024-02-06 12:08:53 +08:00

100 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from argparse import ArgumentParser
from lagent.actions import ActionExecutor, ArxivSearch, IPythonInterpreter
from lagent.agents.internlm2_agent import INTERPRETER_CN, META_CN, PLUGIN_CN, Internlm2Agent, Internlm2Protocol
from lagent.llms import HFTransformer
from lagent.llms.meta_template import INTERNLM2_META as META
from lagent.schema import AgentStatusCode
def parse_args():
    """Parse command-line options for the chatbot demo.

    Returns:
        argparse.Namespace: parsed arguments with a single ``path``
        attribute pointing at the model to load.
    """
    arg_parser = ArgumentParser(description='chatbot')
    arg_parser.add_argument(
        '--path',
        type=str,
        default='internlm/internlm2-chat-20b',
        help='The path to the model')
    return arg_parser.parse_args()
def main():
    """Run an interactive CLI chat loop against an InternLM2 agent.

    Reads multi-line prompts from stdin (finish with a double Enter),
    streams the agent's reply to stdout token-delta by token-delta, and
    folds the agent's inner steps back into the conversation history.
    Type ``exit`` to quit.
    """
    args = parse_args()
    # Initialize the HFTransformer-based Language Model (llm)
    model = HFTransformer(
        path=args.path,
        meta_template=META,
        max_new_tokens=1024,
        top_p=0.8,
        top_k=None,
        temperature=0.1,
        repetition_penalty=1.0,
        stop_words=['<|im_end|>'])
    # Built but deliberately unused (plugin_executor=None below keeps the
    # demo in interpreter-only mode); kept for easy opt-in.
    plugin_executor = ActionExecutor(actions=[ArxivSearch()])  # noqa: F841
    interpreter_executor = ActionExecutor(actions=[IPythonInterpreter()])

    chatbot = Internlm2Agent(
        llm=model,
        plugin_executor=None,
        interpreter_executor=interpreter_executor,
        protocol=Internlm2Protocol(
            meta_prompt=META_CN,
            interpreter_prompt=INTERPRETER_CN,
            plugin_prompt=PLUGIN_CN,
            tool=dict(
                begin='{start_token}{name}\n',
                start_token='<|action_start|>',
                name_map=dict(
                    plugin='<|plugin|>', interpreter='<|interpreter|>'),
                belong='assistant',
                end='<|action_end|>\n',
            ),
        ),
    )

    def input_prompt():
        """Read a multi-line prompt; the first empty line ends input."""
        print('\ndouble enter to end input >>> ', end='', flush=True)
        sentinel = ''  # ends when this string is seen
        return '\n'.join(iter(input, sentinel))

    history = []
    while True:
        try:
            prompt = input_prompt()
        except UnicodeDecodeError:
            print('UnicodeDecodeError')
            continue
        if prompt == 'exit':
            exit(0)
        history.append(dict(role='user', content=prompt))
        print('\nInternLm2', end='')
        current_length = 0
        last_status = None
        # Fix: if stream_chat yields nothing, the post-loop use of
        # agent_return would raise NameError; initialize and guard it.
        agent_return = None
        for agent_return in chatbot.stream_chat(history):
            status = agent_return.state
            # Only render states that carry printable streaming content.
            if status not in [
                    AgentStatusCode.STREAM_ING, AgentStatusCode.CODING,
                    AgentStatusCode.PLUGIN_START
            ]:
                continue
            if status != last_status:
                # Phase changed (e.g. text -> code): restart delta cursor.
                current_length = 0
                print('')
            if isinstance(agent_return.response, dict):
                # Tool invocation: show the action name plus its input;
                # for the interpreter, show only the code to run.
                action = f"\n\n {agent_return.response['name']}: \n\n"
                action_input = agent_return.response['parameters']
                if agent_return.response['name'] == 'IPythonInterpreter':
                    action_input = action_input['command']
                response = action + action_input
            else:
                response = agent_return.response
            # Print only the not-yet-printed suffix of the response.
            print(response[current_length:], end='', flush=True)
            current_length = len(response)
            last_status = status
        print('')
        if agent_return is not None:
            # Fold the agent's full tool/reasoning steps into history so
            # the next turn sees the complete conversation.
            history.extend(agent_return.inner_steps)
# Script entry point: launch the interactive demo when run directly.
if __name__ == '__main__':
    main()