Add custom provider example

This commit is contained in:
Ian Webster
2023-05-03 16:23:41 -07:00
parent 0cd577bc00
commit 4f6b6e2a1a
11 changed files with 68 additions and 49 deletions

View File

@@ -306,6 +306,8 @@ Below is an example of a custom API provider that returns a predefined output an
```javascript
// customApiProvider.js
import fetch from 'node-fetch';
class CustomApiProvider {
id() {
return 'my-custom-api';
@@ -328,7 +330,7 @@ class CustomApiProvider {
}
}
module.exports.default = CustomApiProvider;
export default CustomApiProvider;
```
To use the custom API provider with `promptfoo`, pass the path to the module as the `provider` option in the CLI invocation:

View File

@@ -0,0 +1,5 @@
Run:
```
promptfoo --prompt prompts.txt --vars vars.csv --provider openai:chat --output output.json --provider customProvider.js
```

View File

@@ -0,0 +1,37 @@
import fetch from 'node-fetch';
class CustomApiProvider {
id() {
return 'my-custom-api';
}
async callApi(prompt) {
const body = {
model: 'text-davinci-002',
prompt,
max_tokens: 1024,
temperature: 0,
};
const response = await fetch('https://api.openai.com/v1/completions', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${process.env.OPENAI_API_KEY}`,
},
body: JSON.stringify(body),
});
const data = await response.json();
const ret = {
output: data.choices[0].text,
tokenUsage: {
total: data.usage.total_tokens,
prompt: data.usage.prompt_tokens,
completion: data.usage.completion_tokens,
},
};
return ret;
}
}
export default CustomApiProvider;

View File

@@ -0,0 +1,3 @@
Rephrase this in French: {{body}}
---
Rephrase this like a pirate: {{body}}

View File

@@ -0,0 +1,3 @@
body
Hello world
I'm hungry
1 body
2 Hello world
3 I'm hungry

View File

@@ -106,4 +106,4 @@
"I'm hungry"
]
]
}
}

View File

@@ -185,6 +185,7 @@ export async function evaluate(options: EvaluateOptions): Promise<EvaluateSummar
progressbar.stop();
}
// TODO(ian): Display errors in table UI.
if (isTest) {
table.push(
...combinedOutputs.map((output, index) => [

View File

@@ -55,7 +55,7 @@ program
vars = readVars(cmdObj.vars);
}
const providers = cmdObj.provider.map((p) => loadApiProvider(p));
const providers = await Promise.all(cmdObj.provider.map(async (p) => await loadApiProvider(p)));
const options: EvaluateOptions = {
prompts: readPrompts(cmdObj.prompt),
vars,

View File

@@ -1,4 +1,5 @@
import fetch from 'node-fetch';
import path from 'node:path';
import { ApiProvider, ProviderResponse } from './types.js';
import logger from './logger.js';
@@ -44,11 +45,7 @@ export class OpenAiCompletionProvider extends OpenAiGenericProvider {
constructor(modelName: string, apiKey?: string) {
if (!OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.includes(modelName)) {
throw new Error(
`Unknown OpenAI completion model name: ${modelName}. Use one of the following: ${OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.join(
', ',
)}`,
);
logger.warn(`Using unknown OpenAI completion model: ${modelName}`);
}
super(modelName, apiKey);
}
@@ -95,11 +92,7 @@ export class OpenAiChatCompletionProvider extends OpenAiGenericProvider {
constructor(modelName: string, apiKey?: string) {
if (!OpenAiChatCompletionProvider.OPENAI_CHAT_MODELS.includes(modelName)) {
throw new Error(
`Unknown OpenAI completion model name: ${modelName}. Use one of the following: ${OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.join(
', ',
)}`,
);
logger.warn(`Using unknown OpenAI chat model: ${modelName}`);
}
super(modelName, apiKey);
}
@@ -142,7 +135,7 @@ export class OpenAiChatCompletionProvider extends OpenAiGenericProvider {
}
}
export function loadApiProvider(providerPath: string): ApiProvider {
export async function loadApiProvider(providerPath: string): Promise<ApiProvider> {
if (providerPath?.startsWith('openai:')) {
// Load OpenAI module
const options = providerPath.split(':');
@@ -165,6 +158,6 @@ export function loadApiProvider(providerPath: string): ApiProvider {
}
// Load custom module
const CustomApiProvider = require(providerPath).default;
const CustomApiProvider = (await import(path.join(process.cwd(), providerPath))).default;
return new CustomApiProvider();
}

View File

@@ -11,4 +11,4 @@ class CustomApiProvider {
}
}
module.exports.default = CustomApiProvider;
export default CustomApiProvider;

View File

@@ -55,48 +55,23 @@ describe('providers', () => {
expect(result.tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
});
test('loadApiProvider with openai:chat', () => {
const provider = loadApiProvider('openai:chat');
test('loadApiProvider with openai:chat', async () => {
const provider = await loadApiProvider('openai:chat');
expect(provider).toBeInstanceOf(OpenAiChatCompletionProvider);
});
test('loadApiProvider with openai:completion', () => {
const provider = loadApiProvider('openai:completion');
test('loadApiProvider with openai:completion', async () => {
const provider = await loadApiProvider('openai:completion');
expect(provider).toBeInstanceOf(OpenAiCompletionProvider);
});
test('loadApiProvider with openai:chat:modelName', () => {
const provider = loadApiProvider('openai:chat:gpt-3.5-turbo');
test('loadApiProvider with openai:chat:modelName', async () => {
const provider = await loadApiProvider('openai:chat:gpt-3.5-turbo');
expect(provider).toBeInstanceOf(OpenAiChatCompletionProvider);
});
test('loadApiProvider with openai:completion:modelName', () => {
const provider = loadApiProvider('openai:completion:text-davinci-003');
test('loadApiProvider with openai:completion:modelName', async () => {
const provider = await loadApiProvider('openai:completion:text-davinci-003');
expect(provider).toBeInstanceOf(OpenAiCompletionProvider);
});
test('loadApiProvider with custom module', () => {
// Set up the custom module mock
const customModulePath = path.resolve(__dirname, '__mocks__', 'tempCustomModule.js');
jest.doMock(customModulePath);
const CustomApiProvider = require(customModulePath).default;
const provider = loadApiProvider(customModulePath);
expect(provider).toBeInstanceOf(CustomApiProvider);
// Clean up the mock
jest.dontMock(customModulePath);
});
test('loadApiProvider with invalid openai model', () => {
expect(() => {
loadApiProvider('openai:invalid');
}).toThrowError();
});
test('loadApiProvider with unknown openai model and type', () => {
expect(() => {
loadApiProvider('openai:unknown:unknown');
}).toThrowError();
});
});