claude-3.5-sonnet-long, refs #11

This commit is contained in:
Simon Willison
2024-08-30 12:02:57 -07:00
parent 3f0bd4d803
commit 15f31a0717

View File

@@ -10,7 +10,17 @@ def register_models(register):
register(ClaudeMessages("claude-3-opus-20240229"), aliases=("claude-3-opus",))
register(ClaudeMessages("claude-3-sonnet-20240229"), aliases=("claude-3-sonnet",))
register(ClaudeMessages("claude-3-haiku-20240307"), aliases=("claude-3-haiku",))
register(ClaudeMessages("claude-3-5-sonnet-20240620"), aliases=("claude-3.5-sonnet",))
register(
ClaudeMessages("claude-3-5-sonnet-20240620"), aliases=("claude-3.5-sonnet",)
)
register(
ClaudeMessages(
"claude-3-5-sonnet-20240620-long",
claude_model_id="claude-3-5-sonnet-20240620",
extra_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"},
),
aliases=("claude-3.5-sonnet-long",),
)
class ClaudeOptions(llm.Options):
@@ -81,8 +91,10 @@ class ClaudeMessages(llm.Model):
class Options(ClaudeOptions): ...
def __init__(self, model_id):
def __init__(self, model_id, claude_model_id=None, extra_headers=None):
self.model_id = model_id
self.claude_model_id = claude_model_id or model_id
self.extra_headers = extra_headers
def build_messages(self, prompt, conversation) -> List[dict]:
messages = []
@@ -104,7 +116,7 @@ class ClaudeMessages(llm.Model):
client = Anthropic(api_key=self.get_key())
kwargs = {
"model": self.model_id,
"model": self.claude_model_id,
"messages": self.build_messages(prompt, conversation),
"max_tokens": prompt.options.max_tokens,
}
@@ -122,7 +134,9 @@ class ClaudeMessages(llm.Model):
if prompt.system:
kwargs["system"] = prompt.system
usage = None
if self.extra_headers:
kwargs["extra_headers"] = self.extra_headers
if stream:
with client.messages.stream(**kwargs) as stream:
for text in stream.text_stream: