mirror of
https://github.com/promptfoo/promptfoo.git
synced 2023-08-15 01:10:51 +03:00
Add support for native function ApiProviders and Assertions (#93)
This commit is contained in:
@@ -3,7 +3,15 @@ import promptfoo from '../../dist/src/index.js';
|
||||
(async () => {
|
||||
const results = await promptfoo.evaluate({
|
||||
prompts: ['Rephrase this in French: {{body}}', 'Rephrase this like a pirate: {{body}}'],
|
||||
providers: ['openai:gpt-3.5-turbo'],
|
||||
providers: [
|
||||
'openai:gpt-3.5-turbo',
|
||||
(prompt) => {
|
||||
// Call LLM here...
|
||||
return {
|
||||
output: '<LLM output>',
|
||||
};
|
||||
},
|
||||
],
|
||||
tests: [
|
||||
{
|
||||
vars: {
|
||||
@@ -14,6 +22,19 @@ import promptfoo from '../../dist/src/index.js';
|
||||
vars: {
|
||||
body: "I'm hungry",
|
||||
},
|
||||
assert: [
|
||||
{
|
||||
type: 'javascript',
|
||||
value: (output) => {
|
||||
const pass = output.includes("J'ai faim");
|
||||
return {
|
||||
pass,
|
||||
score: pass ? 1.0 : 0.0,
|
||||
reason: pass ? 'Output contained substring' : 'Output did not contain substring',
|
||||
};
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
@@ -240,6 +240,9 @@ export async function runAssertion(
|
||||
|
||||
if (baseType === 'javascript') {
|
||||
try {
|
||||
if (typeof assertion.value === 'function') {
|
||||
return assertion.value(output, test, assertion);
|
||||
}
|
||||
const customFunction = new Function('output', 'context', `return ${assertion.value}`);
|
||||
const result = customFunction(output, context) as any;
|
||||
if (typeof result === 'boolean') {
|
||||
@@ -378,7 +381,10 @@ ${assertion.value}`,
|
||||
}
|
||||
|
||||
if (baseType === 'rouge-n') {
|
||||
invariant(assertion.value, '"rouge" assertion type must a value (string or string array)');
|
||||
invariant(
|
||||
typeof assertion.value === 'string' || Array.isArray(assertion.value),
|
||||
'"rouge" assertion type must be a value (string or string array)',
|
||||
);
|
||||
return handleRougeScore(baseType, assertion, assertion.value, output, inverse);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import path from 'path';
|
||||
|
||||
import { ApiProvider, ProviderConfig, ProviderId, RawProviderConfig } from './types';
|
||||
|
||||
import { OpenAiCompletionProvider, OpenAiChatCompletionProvider } from './providers/openai';
|
||||
import { AnthropicCompletionProvider } from './providers/anthropic';
|
||||
import { ReplicateProvider } from './providers/replicate';
|
||||
@@ -12,17 +10,37 @@ import {
|
||||
AzureOpenAiCompletionProvider,
|
||||
} from './providers/azureopenai';
|
||||
|
||||
import type {
|
||||
ApiProvider,
|
||||
ProviderConfig,
|
||||
ProviderFunction,
|
||||
ProviderId,
|
||||
RawProviderConfig,
|
||||
} from './types';
|
||||
|
||||
export async function loadApiProviders(
|
||||
providerPaths: ProviderId | ProviderId[] | RawProviderConfig[],
|
||||
providerPaths: ProviderId | ProviderId[] | RawProviderConfig[] | ProviderFunction,
|
||||
basePath?: string,
|
||||
): Promise<ApiProvider[]> {
|
||||
if (typeof providerPaths === 'string') {
|
||||
return [await loadApiProvider(providerPaths, undefined, basePath)];
|
||||
} else if (typeof providerPaths === 'function') {
|
||||
return [
|
||||
{
|
||||
id: () => 'custom-function',
|
||||
callApi: providerPaths,
|
||||
},
|
||||
];
|
||||
} else if (Array.isArray(providerPaths)) {
|
||||
return Promise.all(
|
||||
providerPaths.map((provider) => {
|
||||
providerPaths.map((provider, idx) => {
|
||||
if (typeof provider === 'string') {
|
||||
return loadApiProvider(provider, undefined, basePath);
|
||||
} else if (typeof provider === 'function') {
|
||||
return {
|
||||
id: () => `custom-function-${idx}`,
|
||||
callApi: provider,
|
||||
};
|
||||
} else {
|
||||
const id = Object.keys(provider)[0];
|
||||
const context = { ...provider[id], id };
|
||||
|
||||
@@ -169,7 +169,10 @@ export interface Assertion {
|
||||
type: AssertionType;
|
||||
|
||||
// The expected value, if applicable
|
||||
value?: string | string[];
|
||||
value?:
|
||||
| string
|
||||
| string[]
|
||||
| ((output: string, testCase: AtomicTestCase, assertion: Assertion) => Promise<GradingResult>);
|
||||
|
||||
// The threshold value, only applicable for similarity (cosine distance)
|
||||
threshold?: number;
|
||||
@@ -228,6 +231,8 @@ export interface TestSuite {
|
||||
|
||||
export type ProviderId = string;
|
||||
|
||||
export type ProviderFunction = (prompt: string) => Promise<ProviderResponse>;
|
||||
|
||||
export type RawProviderConfig = Record<ProviderId, Omit<ProviderConfig, 'id'>>;
|
||||
|
||||
// TestSuiteConfig = Test Suite, but before everything is parsed and resolved. Providers are just strings, prompts are filepaths, tests can be filepath or inline.
|
||||
@@ -236,7 +241,7 @@ export interface TestSuiteConfig {
|
||||
description?: string;
|
||||
|
||||
// One or more LLM APIs to use, for example: openai:gpt-3.5-turbo, openai:gpt-4, localai:chat:vicuna
|
||||
providers: ProviderId | ProviderId[] | RawProviderConfig[];
|
||||
providers: ProviderId | ProviderId[] | RawProviderConfig[] | ProviderFunction;
|
||||
|
||||
// One or more prompt files to load
|
||||
prompts: string | string[];
|
||||
|
||||
10
src/util.ts
10
src/util.ts
@@ -4,6 +4,7 @@ import * as os from 'os';
|
||||
|
||||
import $RefParser from '@apidevtools/json-schema-ref-parser';
|
||||
import fetch from 'node-fetch';
|
||||
import invariant from 'tiny-invariant';
|
||||
import yaml from 'js-yaml';
|
||||
import nunjucks from 'nunjucks';
|
||||
import { globSync } from 'glob';
|
||||
@@ -44,6 +45,15 @@ export function readProviderPromptMap(
|
||||
allPrompts.push(prompt.display);
|
||||
}
|
||||
|
||||
invariant(
|
||||
typeof config.providers !== 'string',
|
||||
'In order to use a provider-prompt map, config.providers should be an array of objects, not a string',
|
||||
);
|
||||
invariant(
|
||||
typeof config.providers !== 'function',
|
||||
'In order to use a provider-prompt map, config.providers should be an array of objects, not a function',
|
||||
);
|
||||
|
||||
for (const provider of config.providers) {
|
||||
if (typeof provider === 'object') {
|
||||
const rawProvider = provider as RawProviderConfig;
|
||||
|
||||
@@ -58,16 +58,26 @@ describe('runAssertion', () => {
|
||||
type: 'contains-json',
|
||||
};
|
||||
|
||||
const functionAssertion: Assertion = {
|
||||
const javascriptStringAssertion: Assertion = {
|
||||
type: 'javascript',
|
||||
value: 'output === "Expected output"',
|
||||
};
|
||||
|
||||
const functionAssertionWithNumber: Assertion = {
|
||||
const javascriptStringAssertionWithNumber: Assertion = {
|
||||
type: 'javascript',
|
||||
value: 'output.length * 10',
|
||||
};
|
||||
|
||||
const javascriptFunctionAssertion: Assertion = {
|
||||
type: 'javascript',
|
||||
value: async (output: string) => ({ pass: true, score: 0.5, reason: 'Assertion passed' }),
|
||||
};
|
||||
|
||||
const javascriptFunctionFailAssertion: Assertion = {
|
||||
type: 'javascript',
|
||||
value: async (output: string) => ({ pass: false, score: 0.5, reason: 'Assertion failed' }),
|
||||
};
|
||||
|
||||
it('should pass when the equality assertion passes', async () => {
|
||||
const output = 'Expected output';
|
||||
|
||||
@@ -133,11 +143,11 @@ describe('runAssertion', () => {
|
||||
expect(result.reason).toContain('Expected output to contain valid JSON');
|
||||
});
|
||||
|
||||
it('should pass when the function assertion passes', async () => {
|
||||
it('should pass when the javascript assertion passes', async () => {
|
||||
const output = 'Expected output';
|
||||
|
||||
const result: GradingResult = await runAssertion(
|
||||
functionAssertion,
|
||||
javascriptStringAssertion,
|
||||
{} as AtomicTestCase,
|
||||
output,
|
||||
);
|
||||
@@ -145,11 +155,11 @@ describe('runAssertion', () => {
|
||||
expect(result.reason).toBe('Assertion passed');
|
||||
});
|
||||
|
||||
it('should pass a score through when the function returns a number', async () => {
|
||||
it('should pass a score through when the javascript returns a number', async () => {
|
||||
const output = 'Expected output';
|
||||
|
||||
const result: GradingResult = await runAssertion(
|
||||
functionAssertionWithNumber,
|
||||
javascriptStringAssertionWithNumber,
|
||||
{} as AtomicTestCase,
|
||||
output,
|
||||
);
|
||||
@@ -158,11 +168,11 @@ describe('runAssertion', () => {
|
||||
expect(result.reason).toBe('Assertion passed');
|
||||
});
|
||||
|
||||
it('should fail when the function assertion fails', async () => {
|
||||
it('should fail when the javascript assertion fails', async () => {
|
||||
const output = 'Different output';
|
||||
|
||||
const result: GradingResult = await runAssertion(
|
||||
functionAssertion,
|
||||
javascriptStringAssertion,
|
||||
{} as AtomicTestCase,
|
||||
output,
|
||||
);
|
||||
@@ -173,12 +183,12 @@ describe('runAssertion', () => {
|
||||
it('should pass when the function assertion passes - with vars', async () => {
|
||||
const output = 'Expected output';
|
||||
|
||||
const functionAssertionWithVars: Assertion = {
|
||||
const javascriptStringAssertionWithVars: Assertion = {
|
||||
type: 'javascript',
|
||||
value: 'output === "Expected output" && context.vars.foo === "bar"',
|
||||
};
|
||||
const result: GradingResult = await runAssertion(
|
||||
functionAssertionWithVars,
|
||||
javascriptStringAssertionWithVars,
|
||||
{ vars: { foo: 'bar' } } as AtomicTestCase,
|
||||
output,
|
||||
);
|
||||
@@ -186,15 +196,15 @@ describe('runAssertion', () => {
|
||||
expect(result.reason).toBe('Assertion passed');
|
||||
});
|
||||
|
||||
it('should fail when the function does not match vars', async () => {
|
||||
it('should fail when the javascript does not match vars', async () => {
|
||||
const output = 'Expected output';
|
||||
|
||||
const functionAssertionWithVars: Assertion = {
|
||||
const javascriptStringAssertionWithVars: Assertion = {
|
||||
type: 'javascript',
|
||||
value: 'output === "Expected output" && context.vars.foo === "something else"',
|
||||
};
|
||||
const result: GradingResult = await runAssertion(
|
||||
functionAssertionWithVars,
|
||||
javascriptStringAssertionWithVars,
|
||||
{ vars: { foo: 'bar' } } as AtomicTestCase,
|
||||
output,
|
||||
);
|
||||
@@ -204,6 +214,32 @@ describe('runAssertion', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should pass when the function returns pass', async () => {
|
||||
const output = 'Expected output';
|
||||
|
||||
const result: GradingResult = await runAssertion(
|
||||
javascriptFunctionAssertion,
|
||||
{} as AtomicTestCase,
|
||||
output,
|
||||
);
|
||||
expect(result.pass).toBeTruthy();
|
||||
expect(result.score).toBe(0.5);
|
||||
expect(result.reason).toBe('Assertion passed');
|
||||
});
|
||||
|
||||
it('should fail when the function returns fail', async () => {
|
||||
const output = 'Expected output';
|
||||
|
||||
const result: GradingResult = await runAssertion(
|
||||
javascriptFunctionFailAssertion,
|
||||
{} as AtomicTestCase,
|
||||
output,
|
||||
);
|
||||
expect(result.pass).toBeFalsy();
|
||||
expect(result.score).toBe(0.5);
|
||||
expect(result.reason).toBe('Assertion failed');
|
||||
});
|
||||
|
||||
const notContainsAssertion: Assertion = {
|
||||
type: 'not-contains',
|
||||
value: 'Unexpected output',
|
||||
|
||||
@@ -5,7 +5,7 @@ import { AnthropicCompletionProvider } from '../src/providers/anthropic';
|
||||
|
||||
import { disableCache, enableCache } from '../src/cache.js';
|
||||
import { loadApiProvider, loadApiProviders } from '../src/providers.js';
|
||||
import type { RawProviderConfig } from '../src/types';
|
||||
import type { RawProviderConfig, ProviderFunction } from '../src/types';
|
||||
import {
|
||||
AzureOpenAiChatCompletionProvider,
|
||||
AzureOpenAiCompletionProvider,
|
||||
@@ -201,6 +201,21 @@ describe('providers', () => {
|
||||
expect(provider).toBeInstanceOf(OpenAiChatCompletionProvider);
|
||||
});
|
||||
|
||||
test('loadApiProviders with ProviderFunction', async () => {
|
||||
const providerFunction: ProviderFunction = async (prompt: string) => {
|
||||
return {
|
||||
output: `Output for ${prompt}`,
|
||||
tokenUsage: { total: 10, prompt: 5, completion: 5 },
|
||||
};
|
||||
};
|
||||
const providers = await loadApiProviders(providerFunction);
|
||||
expect(providers).toHaveLength(1);
|
||||
expect(providers[0].id()).toBe('custom-function');
|
||||
const response = await providers[0].callApi('Test prompt');
|
||||
expect(response.output).toBe('Output for Test prompt');
|
||||
expect(response.tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
|
||||
});
|
||||
|
||||
test('loadApiProviders with RawProviderConfig[]', async () => {
|
||||
const rawProviderConfigs: RawProviderConfig[] = [
|
||||
{
|
||||
|
||||
@@ -496,6 +496,6 @@ describe('readTests', () => {
|
||||
const result = await readTests(testsPaths);
|
||||
|
||||
expect(fs.readFileSync).toHaveBeenCalledTimes(2);
|
||||
expect(result).toEqual([Object.assign({}, test1Content[0], {vars: vars1Content})]);
|
||||
expect(result).toEqual([Object.assign({}, test1Content[0], { vars: vars1Content })]);
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user