Add custom provider example
@@ -306,6 +306,8 @@ Below is an example of a custom API provider that returns a predefined output an
 ```javascript
 // customApiProvider.js
 import fetch from 'node-fetch';

 class CustomApiProvider {
   id() {
     return 'my-custom-api';
@@ -328,7 +330,7 @@ class CustomApiProvider {
   }
 }

-module.exports.default = CustomApiProvider;
+export default CustomApiProvider;
 ```

 To use the custom API provider with `promptfoo`, pass the path to the module as the `provider` option in the CLI invocation:
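For reference, the contract this example appears to satisfy can be summarized as a pair of interfaces. This is a sketch inferred from the `ApiProvider` and `ProviderResponse` imports elsewhere in this diff and from the fields the example returns; it is not copied from the real `types.js`:

```typescript
// Sketch of the provider contract, inferred from this diff (not the actual types.js).
interface ProviderResponse {
  output: string;
  tokenUsage?: {
    total: number;
    prompt: number;
    completion: number;
  };
}

interface ApiProvider {
  // Unique identifier reported in results, e.g. 'my-custom-api'.
  id(): string;
  // Called once per prompt/vars combination; returns the model output.
  callApi(prompt: string): Promise<ProviderResponse>;
}
```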
examples/simple-cli-custom-provider/README.md (new file, 5 lines)
@@ -0,0 +1,5 @@
+Run:
+
+```
+promptfoo --prompt prompts.txt --vars vars.csv --provider openai:chat --output output.json --provider customProvider.js
+```
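Note that `--provider` is passed twice here; judging by the `cmdObj.provider.map(...)` call in the `main.ts` hunk further down, repeated flags are collected into an array, so the prompts are evaluated against both the OpenAI chat provider and the custom provider.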
examples/simple-cli-custom-provider/customProvider.js (new file, 37 lines)
@@ -0,0 +1,37 @@
+import fetch from 'node-fetch';
+
+class CustomApiProvider {
+  id() {
+    return 'my-custom-api';
+  }
+
+  async callApi(prompt) {
+    const body = {
+      model: 'text-davinci-002',
+      prompt,
+      max_tokens: 1024,
+      temperature: 0,
+    };
+    const response = await fetch('https://api.openai.com/v1/completions', {
+      method: 'POST',
+      headers: {
+        'Content-Type': 'application/json',
+        Authorization: `Bearer ${process.env.OPENAI_API_KEY}`,
+      },
+      body: JSON.stringify(body),
+    });
+
+    const data = await response.json();
+    const ret = {
+      output: data.choices[0].text,
+      tokenUsage: {
+        total: data.usage.total_tokens,
+        prompt: data.usage.prompt_tokens,
+        completion: data.usage.completion_tokens,
+      },
+    };
+    return ret;
+  }
+}
+
+export default CustomApiProvider;
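Outside of promptfoo, the provider can be exercised on its own. A minimal sketch, assuming an ESM context with top-level await, `node-fetch` installed, and `OPENAI_API_KEY` set; this is not part of the commit:

```typescript
// Hypothetical standalone smoke test for the provider above.
import CustomApiProvider from './customProvider.js';

const provider = new CustomApiProvider();
const result = await provider.callApi('Rephrase this in French: Hello world');
console.log(provider.id()); // 'my-custom-api'
console.log(result.output); // completion text
console.log(result.tokenUsage); // { total, prompt, completion }
```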
examples/simple-cli-custom-provider/prompts.txt (new file, 3 lines)
@@ -0,0 +1,3 @@
+Rephrase this in French: {{body}}
+---
+Rephrase this like a pirate: {{body}}
examples/simple-cli-custom-provider/vars.csv (new file, 3 lines)
@@ -0,0 +1,3 @@
+body
+Hello world
+I'm hungry
@@ -106,4 +106,4 @@
 "I'm hungry"
 ]
 ]
 }
 }
@@ -185,6 +185,7 @@ export async function evaluate(options: EvaluateOptions): Promise<EvaluateSummar
     progressbar.stop();
   }

+  // TODO(ian): Display errors in table UI.
   if (isTest) {
     table.push(
       ...combinedOutputs.map((output, index) => [
@@ -55,7 +55,7 @@ program
     vars = readVars(cmdObj.vars);
   }

-  const providers = cmdObj.provider.map((p) => loadApiProvider(p));
+  const providers = await Promise.all(cmdObj.provider.map(async (p) => await loadApiProvider(p)));
   const options: EvaluateOptions = {
     prompts: readPrompts(cmdObj.prompt),
     vars,
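Since `loadApiProvider` is now async, mapping it over the provider list produces an array of promises, and `Promise.all` is what turns that back into an array of providers. A minimal sketch of the pattern, with hypothetical provider paths and an assumed module path for the loader:

```typescript
import { loadApiProvider } from './providers.js'; // module path assumed from this diff

// Illustration only: map + Promise.all over an async loader.
const providerPaths = ['openai:chat', 'customProvider.js']; // hypothetical

// Without awaiting, this is Promise<ApiProvider>[], not usable as providers.
const pending = providerPaths.map((p) => loadApiProvider(p));

// Promise.all resolves all loads (including dynamic imports) concurrently.
const providers = await Promise.all(pending);
```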
@@ -1,4 +1,5 @@
 import fetch from 'node-fetch';
+import path from 'node:path';

 import { ApiProvider, ProviderResponse } from './types.js';
 import logger from './logger.js';
@@ -44,11 +45,7 @@ export class OpenAiCompletionProvider extends OpenAiGenericProvider {

   constructor(modelName: string, apiKey?: string) {
     if (!OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.includes(modelName)) {
-      throw new Error(
-        `Unknown OpenAI completion model name: ${modelName}. Use one of the following: ${OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.join(
-          ', ',
-        )}`,
-      );
+      logger.warn(`Using unknown OpenAI completion model: ${modelName}`);
     }
     super(modelName, apiKey);
   }
@@ -95,11 +92,7 @@ export class OpenAiChatCompletionProvider extends OpenAiGenericProvider {

   constructor(modelName: string, apiKey?: string) {
     if (!OpenAiChatCompletionProvider.OPENAI_CHAT_MODELS.includes(modelName)) {
-      throw new Error(
-        `Unknown OpenAI completion model name: ${modelName}. Use one of the following: ${OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.join(
-          ', ',
-        )}`,
-      );
+      logger.warn(`Using unknown OpenAI chat model: ${modelName}`);
     }
     super(modelName, apiKey);
   }
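Both constructors now downgrade an unrecognized model name from a thrown error to a logged warning, so the provider is still constructed and the request is attempted anyway. A sketch of the new behavior, with a hypothetical model name and an assumed import path:

```typescript
import { OpenAiCompletionProvider } from './providers.js'; // module path assumed

// Before this commit: threw 'Unknown OpenAI completion model name: ...'.
// After: logs a warning and proceeds, which permits e.g. fine-tuned models.
const provider = new OpenAiCompletionProvider('my-fine-tuned-model'); // hypothetical name
// logger: "Using unknown OpenAI completion model: my-fine-tuned-model"
```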
@@ -142,7 +135,7 @@ export class OpenAiChatCompletionProvider extends OpenAiGenericProvider {
   }
 }

-export function loadApiProvider(providerPath: string): ApiProvider {
+export async function loadApiProvider(providerPath: string): Promise<ApiProvider> {
   if (providerPath?.startsWith('openai:')) {
     // Load OpenAI module
     const options = providerPath.split(':');
@@ -165,6 +158,6 @@ export function loadApiProvider(providerPath: string): ApiProvider {
   }

   // Load custom module
-  const CustomApiProvider = require(providerPath).default;
+  const CustomApiProvider = (await import(path.join(process.cwd(), providerPath))).default;
   return new CustomApiProvider();
 }
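Swapping `require` for a dynamic `import` also changes how the path is resolved: `require(providerPath)` resolves relative to the installed promptfoo module, while `path.join(process.cwd(), providerPath)` resolves relative to wherever the CLI is invoked, which is what a user passing `--provider customProvider.js` expects. A sketch of the resolution, with a hypothetical working directory:

```typescript
import path from 'node:path';

// If the CLI runs in /home/user/project with --provider customProvider.js,
// the module is loaded from /home/user/project/customProvider.js.
const providerPath = 'customProvider.js'; // hypothetical CLI argument
const resolved = path.join(process.cwd(), providerPath);
const CustomApiProvider = (await import(resolved)).default;
const provider = new CustomApiProvider();
```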
@@ -11,4 +11,4 @@ class CustomApiProvider {
   }
 }

-module.exports.default = CustomApiProvider;
+export default CustomApiProvider;
@@ -55,48 +55,23 @@ describe('providers', () => {
     expect(result.tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
   });

-  test('loadApiProvider with openai:chat', () => {
-    const provider = loadApiProvider('openai:chat');
+  test('loadApiProvider with openai:chat', async () => {
+    const provider = await loadApiProvider('openai:chat');
     expect(provider).toBeInstanceOf(OpenAiChatCompletionProvider);
   });

-  test('loadApiProvider with openai:completion', () => {
-    const provider = loadApiProvider('openai:completion');
+  test('loadApiProvider with openai:completion', async () => {
+    const provider = await loadApiProvider('openai:completion');
     expect(provider).toBeInstanceOf(OpenAiCompletionProvider);
   });

-  test('loadApiProvider with openai:chat:modelName', () => {
-    const provider = loadApiProvider('openai:chat:gpt-3.5-turbo');
+  test('loadApiProvider with openai:chat:modelName', async () => {
+    const provider = await loadApiProvider('openai:chat:gpt-3.5-turbo');
     expect(provider).toBeInstanceOf(OpenAiChatCompletionProvider);
   });

-  test('loadApiProvider with openai:completion:modelName', () => {
-    const provider = loadApiProvider('openai:completion:text-davinci-003');
+  test('loadApiProvider with openai:completion:modelName', async () => {
+    const provider = await loadApiProvider('openai:completion:text-davinci-003');
     expect(provider).toBeInstanceOf(OpenAiCompletionProvider);
   });

-  test('loadApiProvider with custom module', () => {
-    // Set up the custom module mock
-    const customModulePath = path.resolve(__dirname, '__mocks__', 'tempCustomModule.js');
-    jest.doMock(customModulePath);
-
-    const CustomApiProvider = require(customModulePath).default;
-    const provider = loadApiProvider(customModulePath);
-    expect(provider).toBeInstanceOf(CustomApiProvider);
-
-    // Clean up the mock
-    jest.dontMock(customModulePath);
-  });
-
-  test('loadApiProvider with invalid openai model', () => {
-    expect(() => {
-      loadApiProvider('openai:invalid');
-    }).toThrowError();
-  });
-
-  test('loadApiProvider with unknown openai model and type', () => {
-    expect(() => {
-      loadApiProvider('openai:unknown:unknown');
-    }).toThrowError();
-  });
 });
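The dropped invalid-model tests follow from the constructor change above: unknown models now warn instead of throwing, so the `toThrowError` assertions no longer apply. A hedged sketch of what an equivalent assertion might look like under the new behavior (inferred from this diff, not part of the commit):

```typescript
test('loadApiProvider with unknown openai completion model', async () => {
  // No longer throws; the constructor only logs a warning now.
  const provider = await loadApiProvider('openai:completion:some-unknown-model'); // hypothetical model
  expect(provider).toBeInstanceOf(OpenAiCompletionProvider);
});
```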