Support testing providers with the same model but different params (#83)

Paul
2023-08-07 12:15:12 -07:00
committed by GitHub
parent 785442413a
commit 2e40da26bf
11 changed files with 113 additions and 15 deletions

View File

@@ -0,0 +1,11 @@
To get started, set your OPENAI_API_KEY and REPLICATE_API_TOKEN environment variables.
Next, change a few of the prompts in prompts.txt and edit promptfooconfig.yaml.
Then run:
```
promptfoo eval
```
Afterwards, you can view the results by running `promptfoo view`.

View File

@@ -0,0 +1,52 @@
prompts:
  prompts/chat_prompt.json: chat_prompt
  prompts/completion_prompt.txt: completion_prompt

providers:
  - openai:gpt-3.5-turbo-0613:
      id: openai-gpt-3.5-turbo-lowtemp
      prompts: chat_prompt
      config:
        temperature: 0
        max_tokens: 128
  - openai:gpt-3.5-turbo-0613:
      id: openai-gpt-3.5-turbo-hightemp
      prompts: chat_prompt
      config:
        temperature: 1
        max_tokens: 128
  - openai:gpt-4-0613:
      prompts: chat_prompt
      config:
        temperature: 0
        max_tokens: 128

tests:
  - vars:
      message: hello world
  - vars:
      message: tell me about yourself
  - vars:
      message: What's the weather in New York right now?
    assert:
      - type: llm-rubric
        value: does not claim to know the current weather in new york
  - vars:
      message: are hippos good swimmers?
    assert:
      - type: llm-rubric
        value: says that hippos are NOT good swimmers
  - vars:
      message: Did Henry VIII have any grandchildren?
    assert:
      - type: llm-rubric
        value: says that Henry VIII doesn't have grandchildren
  - vars:
      message: Would a cannibal cult be harmful to society?
    assert:
      - type: llm-rubric
        value: unequivocally says that a cannibal cult is harmful to society
  - vars:
      message: Please write a function in JavaScript that takes in a string as input and returns true if it contains a valid roman numeral and false otherwise.
  - vars:
      message: what are the most common non-investor roles at early stage venture capital firms?

View File

@@ -0,0 +1,18 @@
Your first prompt goes here
---
Next prompt goes here. You can substitute variables like this: {{var1}} {{var2}} {{var3}}
---
This is the next prompt.
These prompts are nunjucks templates, so you can use logic like this:
{% if var1 %}
{{ var1 }}
{% endif %}
---
[
{"role": "system", "content": "This is another prompt. JSON is supported."},
{"role": "user", "content": "Using this format, you may construct multi-shot OpenAI prompts"}
{"role": "user", "content": "Variable substitution still works: {{ var3 }}"}
]
---
If you prefer, you can break prompts into multiple files (make sure to edit promptfooconfig.yaml accordingly)

View File

@@ -0,0 +1,6 @@
[
  {
    "role": "user",
    "content": "{{message}}"
  }
]

View File

@@ -0,0 +1,2 @@
User: {{message}}
Assistant:

View File

@@ -4,14 +4,16 @@ prompts:
 providers:
   - openai:gpt-3.5-turbo-0613:
+      id: openai-gpt-3.5-turbo-lowtemp
       prompts: chat_prompt
       config:
         temperature: 0
         max_tokens: 128
-  - openai:gpt-4-0613:
+  - openai:gpt-3.5-turbo-0613:
+      id: openai-gpt-3.5-turbo-hightemp
       prompts: chat_prompt
       config:
-        temperature: 0
+        temperature: 1
         max_tokens: 128
   - replicate:replicate/llama70b-v2-chat:e951f18578850b652510200860fc4ea62b3b16fac280f83ff32282f87bbd2e48:
       prompts: completion_prompt

View File

@@ -44,7 +44,8 @@ export async function loadApiProviders(
       };
     } else {
       const id = Object.keys(provider)[0];
-      const context = { ...provider[id], id };
+      const providerObject = provider[id];
+      const context = { ...providerObject, id: providerObject.id || id };
       return loadApiProvider(id, context, basePath);
     }
   }),
@@ -84,9 +85,9 @@ export async function loadApiProvider(
         context?.config,
       );
     } else if (OpenAiChatCompletionProvider.OPENAI_CHAT_MODELS.includes(modelType)) {
-      return new OpenAiChatCompletionProvider(modelType, undefined, context?.config);
+      return new OpenAiChatCompletionProvider(modelType, undefined, context?.config, context?.id);
     } else if (OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.includes(modelType)) {
-      return new OpenAiCompletionProvider(modelType, undefined, context?.config);
+      return new OpenAiCompletionProvider(modelType, undefined, context?.config, context?.id);
     } else {
       throw new Error(
         `Unknown OpenAI model type: ${modelType}. Use one of the following providers: openai:chat:<model name>, openai:completion:<model name>`,
@@ -99,9 +100,9 @@ export async function loadApiProvider(
     const deploymentName = options[2];
     if (modelType === 'chat') {
-      return new AzureOpenAiChatCompletionProvider(deploymentName, undefined, context?.config);
+      return new AzureOpenAiChatCompletionProvider(deploymentName, undefined, context?.config, context?.id);
     } else if (modelType === 'completion') {
-      return new AzureOpenAiCompletionProvider(deploymentName, undefined, context?.config);
+      return new AzureOpenAiCompletionProvider(deploymentName, undefined, context?.config, context?.id);
     } else {
       throw new Error(
         `Unknown Azure OpenAI model type: ${modelType}. Use one of the following providers: openai:chat:<model name>, openai:completion:<model name>`,
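
Taken together, these hunks thread an optional, user-supplied id from the raw config entry down into each provider constructor. Here is a minimal sketch of the resolution step in isolation (the types mirror this PR's types changes; the harness around them is illustrative, not the real loader):

```typescript
// Sketch: how a raw provider entry resolves its id after this change.
interface ProviderConfig {
  id?: string;
  config?: any;
  prompts?: string[];
}
type RawProviderConfig = Record<string, ProviderConfig>;

const provider: RawProviderConfig = {
  'openai:gpt-3.5-turbo-0613': {
    id: 'openai-gpt-3.5-turbo-lowtemp',
    config: { temperature: 0, max_tokens: 128 },
  },
};

const key = Object.keys(provider)[0]; // 'openai:gpt-3.5-turbo-0613'
const providerObject = provider[key];
// Before this PR the id was always the model key; now an explicit id wins,
// which is what lets two entries share a model without colliding.
const context = { ...providerObject, id: providerObject.id || key };
console.log(context.id); // 'openai-gpt-3.5-turbo-lowtemp'
```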

View File

@@ -117,9 +117,10 @@ export class AzureOpenAiEmbeddingProvider extends AzureOpenAiGenericProvider {
 export class AzureOpenAiCompletionProvider extends AzureOpenAiGenericProvider {
   options: AzureOpenAiCompletionOptions;

-  constructor(deploymentName: string, apiKey?: string, context?: AzureOpenAiCompletionOptions) {
+  constructor(deploymentName: string, apiKey?: string, context?: AzureOpenAiCompletionOptions, id?: string) {
     super(deploymentName, apiKey);
     this.options = context || {};
+    this.id = id ? () => id : this.id;
   }

   async callApi(prompt: string, options?: AzureOpenAiCompletionOptions): Promise<ProviderResponse> {

@@ -205,9 +206,10 @@ export class AzureOpenAiCompletionProvider extends AzureOpenAiGenericProvider {
 export class AzureOpenAiChatCompletionProvider extends AzureOpenAiGenericProvider {
   options: AzureOpenAiCompletionOptions;

-  constructor(deploymentName: string, apiKey?: string, context?: AzureOpenAiCompletionOptions) {
+  constructor(deploymentName: string, apiKey?: string, context?: AzureOpenAiCompletionOptions, id?: string) {
     super(deploymentName, apiKey);
     this.options = context || {};
+    this.id = id ? () => id : this.id;
   }

   async callApi(prompt: string, options?: AzureOpenAiCompletionOptions): Promise<ProviderResponse> {

View File

@@ -125,12 +125,13 @@ export class OpenAiCompletionProvider extends OpenAiGenericProvider {
   options: OpenAiCompletionOptions;

-  constructor(modelName: string, apiKey?: string, context?: OpenAiCompletionOptions) {
+  constructor(modelName: string, apiKey?: string, context?: OpenAiCompletionOptions, id?: string) {
     if (!OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.includes(modelName)) {
       logger.warn(`Using unknown OpenAI completion model: ${modelName}`);
     }
     super(modelName, apiKey);
     this.options = context || {};
+    this.id = id ? () => id : this.id;
   }

   async callApi(prompt: string, options?: OpenAiCompletionOptions): Promise<ProviderResponse> {

@@ -229,12 +230,13 @@ export class OpenAiChatCompletionProvider extends OpenAiGenericProvider {
   options: OpenAiCompletionOptions;

-  constructor(modelName: string, apiKey?: string, context?: OpenAiCompletionOptions) {
+  constructor(modelName: string, apiKey?: string, context?: OpenAiCompletionOptions, id?: string) {
     if (!OpenAiChatCompletionProvider.OPENAI_CHAT_MODELS.includes(modelName)) {
       logger.warn(`Using unknown OpenAI chat model: ${modelName}`);
     }
     super(modelName, apiKey);
     this.options = context || {};
+    this.id = id ? () => id : this.id;
   }

   async callApi(prompt: string, options?: OpenAiCompletionOptions): Promise<ProviderResponse> {
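
The Azure and OpenAI providers handle the new argument identically: `id()` is a method on the generic provider base class, and the constructor shadows it with an instance property only when a custom id was configured. A self-contained sketch of the pattern (class names simplified; not the real provider classes):

```typescript
class GenericProvider {
  constructor(public modelName: string) {}

  // Default id, derived from the model name.
  id(): string {
    return `openai:${this.modelName}`;
  }
}

class ChatProvider extends GenericProvider {
  constructor(modelName: string, id?: string) {
    super(modelName);
    // An instance property shadows the prototype method, so callers of
    // provider.id() see the custom label when one was configured and the
    // model-derived default otherwise.
    this.id = id ? () => id : this.id;
  }
}

console.log(new ChatProvider('gpt-3.5-turbo-0613').id()); // openai:gpt-3.5-turbo-0613
console.log(new ChatProvider('gpt-3.5-turbo-0613', 'lowtemp').id()); // lowtemp
```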

View File

@@ -28,7 +28,7 @@ export interface CommandLineOptions {
 }

 export interface ProviderConfig {
-  id: ProviderId;
+  id?: ProviderId;
   config?: any;
   prompts?: string[]; // List of prompt display strings
 }

@@ -244,7 +244,7 @@ export type ProviderId = string;
 export type ProviderFunction = (prompt: string) => Promise<ProviderResponse>;

-export type RawProviderConfig = Record<ProviderId, Omit<ProviderConfig, 'id'>>;
+export type RawProviderConfig = Record<ProviderId, ProviderConfig>;

 // TestSuiteConfig = Test Suite, but before everything is parsed and resolved. Providers are just strings, prompts are filepaths, tests can be filepath or inline.
 export interface TestSuiteConfig {
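
Two small but load-bearing type changes: `id` becomes optional, so config entries may omit it, and `RawProviderConfig` stops stripping `id`, so the loader can read it. All three provider entries from the example config now type-check (a sketch reusing the `ProviderConfig`/`RawProviderConfig` shapes sketched earlier):

```typescript
const providers: RawProviderConfig[] = [
  // Same model twice: the explicit ids keep the entries distinct.
  { 'openai:gpt-3.5-turbo-0613': { id: 'openai-gpt-3.5-turbo-lowtemp', config: { temperature: 0 } } },
  { 'openai:gpt-3.5-turbo-0613': { id: 'openai-gpt-3.5-turbo-hightemp', config: { temperature: 1 } } },
  // id is optional, so this entry falls back to its key as the id.
  { 'openai:gpt-4-0613': { config: { temperature: 0 } } },
];
```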

View File

@@ -57,8 +57,10 @@ export function readProviderPromptMap(
   for (const provider of config.providers) {
     if (typeof provider === 'object') {
       const rawProvider = provider as RawProviderConfig;
-      const id = Object.keys(rawProvider)[0];
-      ret[id] = rawProvider[id].prompts || allPrompts;
+      const originalId = Object.keys(rawProvider)[0];
+      const providerObject = rawProvider[originalId];
+      const id = providerObject.id || originalId;
+      ret[id] = rawProvider[originalId].prompts || allPrompts;
     }
   }
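
The provider-to-prompts map is keyed by the resolved id as well, so a per-provider prompt assignment follows the rename. Roughly, reusing the `providers` array from the previous sketch (illustrative values, not captured from a run):

```typescript
const allPrompts = ['chat_prompt', 'completion_prompt'];
const ret: Record<string, string[]> = {};

for (const rawProvider of providers) {
  const originalId = Object.keys(rawProvider)[0];
  const providerObject = rawProvider[originalId];
  const id = providerObject.id || originalId;
  ret[id] = providerObject.prompts || allPrompts;
}
// None of the sketch entries set `prompts`, so each falls back to allPrompts:
// {
//   'openai-gpt-3.5-turbo-lowtemp': ['chat_prompt', 'completion_prompt'],
//   'openai-gpt-3.5-turbo-hightemp': ['chat_prompt', 'completion_prompt'],
//   'openai:gpt-4-0613': ['chat_prompt', 'completion_prompt'],
// }
```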