Add support for native function ApiProviders and Assertions (#93)

This commit is contained in:
Ian Webster
2023-07-27 07:34:34 -07:00
committed by GitHub
parent bef4b436c6
commit 5aff43afcf
8 changed files with 134 additions and 23 deletions

View File

@@ -3,7 +3,15 @@ import promptfoo from '../../dist/src/index.js';
(async () => {
const results = await promptfoo.evaluate({
prompts: ['Rephrase this in French: {{body}}', 'Rephrase this like a pirate: {{body}}'],
providers: ['openai:gpt-3.5-turbo'],
providers: [
'openai:gpt-3.5-turbo',
(prompt) => {
// Call LLM here...
return {
output: '<LLM output>',
};
},
],
tests: [
{
vars: {
@@ -14,6 +22,19 @@ import promptfoo from '../../dist/src/index.js';
vars: {
body: "I'm hungry",
},
assert: [
{
type: 'javascript',
value: (output) => {
const pass = output.includes("J'ai faim");
return {
pass,
score: pass ? 1.0 : 0.0,
reason: pass ? 'Output contained substring' : 'Output did not contain substring',
};
},
},
],
},
],
});

View File

@@ -240,6 +240,9 @@ export async function runAssertion(
if (baseType === 'javascript') {
try {
if (typeof assertion.value === 'function') {
return assertion.value(output, test, assertion);
}
const customFunction = new Function('output', 'context', `return ${assertion.value}`);
const result = customFunction(output, context) as any;
if (typeof result === 'boolean') {
@@ -378,7 +381,10 @@ ${assertion.value}`,
}
if (baseType === 'rouge-n') {
invariant(assertion.value, '"rouge" assertion type must a value (string or string array)');
invariant(
typeof assertion.value === 'string' || Array.isArray(assertion.value),
'"rouge" assertion type must be a value (string or string array)',
);
return handleRougeScore(baseType, assertion, assertion.value, output, inverse);
}

View File

@@ -1,7 +1,5 @@
import path from 'path';
import { ApiProvider, ProviderConfig, ProviderId, RawProviderConfig } from './types';
import { OpenAiCompletionProvider, OpenAiChatCompletionProvider } from './providers/openai';
import { AnthropicCompletionProvider } from './providers/anthropic';
import { ReplicateProvider } from './providers/replicate';
@@ -12,17 +10,37 @@ import {
AzureOpenAiCompletionProvider,
} from './providers/azureopenai';
import type {
ApiProvider,
ProviderConfig,
ProviderFunction,
ProviderId,
RawProviderConfig,
} from './types';
export async function loadApiProviders(
providerPaths: ProviderId | ProviderId[] | RawProviderConfig[],
providerPaths: ProviderId | ProviderId[] | RawProviderConfig[] | ProviderFunction,
basePath?: string,
): Promise<ApiProvider[]> {
if (typeof providerPaths === 'string') {
return [await loadApiProvider(providerPaths, undefined, basePath)];
} else if (typeof providerPaths === 'function') {
return [
{
id: () => 'custom-function',
callApi: providerPaths,
},
];
} else if (Array.isArray(providerPaths)) {
return Promise.all(
providerPaths.map((provider) => {
providerPaths.map((provider, idx) => {
if (typeof provider === 'string') {
return loadApiProvider(provider, undefined, basePath);
} else if (typeof provider === 'function') {
return {
id: () => `custom-function-${idx}`,
callApi: provider,
};
} else {
const id = Object.keys(provider)[0];
const context = { ...provider[id], id };

View File

@@ -169,7 +169,10 @@ export interface Assertion {
type: AssertionType;
// The expected value, if applicable
value?: string | string[];
value?:
| string
| string[]
| ((output: string, testCase: AtomicTestCase, assertion: Assertion) => Promise<GradingResult>);
// The threshold value, only applicable for similarity (cosine distance)
threshold?: number;
@@ -228,6 +231,8 @@ export interface TestSuite {
export type ProviderId = string;
export type ProviderFunction = (prompt: string) => Promise<ProviderResponse>;
export type RawProviderConfig = Record<ProviderId, Omit<ProviderConfig, 'id'>>;
// TestSuiteConfig = Test Suite, but before everything is parsed and resolved. Providers are just strings, prompts are filepaths, tests can be filepath or inline.
@@ -236,7 +241,7 @@ export interface TestSuiteConfig {
description?: string;
// One or more LLM APIs to use, for example: openai:gpt-3.5-turbo, openai:gpt-4, localai:chat:vicuna
providers: ProviderId | ProviderId[] | RawProviderConfig[];
providers: ProviderId | ProviderId[] | RawProviderConfig[] | ProviderFunction;
// One or more prompt files to load
prompts: string | string[];

View File

@@ -4,6 +4,7 @@ import * as os from 'os';
import $RefParser from '@apidevtools/json-schema-ref-parser';
import fetch from 'node-fetch';
import invariant from 'tiny-invariant';
import yaml from 'js-yaml';
import nunjucks from 'nunjucks';
import { globSync } from 'glob';
@@ -44,6 +45,15 @@ export function readProviderPromptMap(
allPrompts.push(prompt.display);
}
invariant(
typeof config.providers !== 'string',
'In order to use a provider-prompt map, config.providers should be an array of objects, not a string',
);
invariant(
typeof config.providers !== 'function',
'In order to use a provider-prompt map, config.providers should be an array of objects, not a function',
);
for (const provider of config.providers) {
if (typeof provider === 'object') {
const rawProvider = provider as RawProviderConfig;

View File

@@ -58,16 +58,26 @@ describe('runAssertion', () => {
type: 'contains-json',
};
const functionAssertion: Assertion = {
const javascriptStringAssertion: Assertion = {
type: 'javascript',
value: 'output === "Expected output"',
};
const functionAssertionWithNumber: Assertion = {
const javascriptStringAssertionWithNumber: Assertion = {
type: 'javascript',
value: 'output.length * 10',
};
const javascriptFunctionAssertion: Assertion = {
type: 'javascript',
value: async (output: string) => ({ pass: true, score: 0.5, reason: 'Assertion passed' }),
};
const javascriptFunctionFailAssertion: Assertion = {
type: 'javascript',
value: async (output: string) => ({ pass: false, score: 0.5, reason: 'Assertion failed' }),
};
it('should pass when the equality assertion passes', async () => {
const output = 'Expected output';
@@ -133,11 +143,11 @@ describe('runAssertion', () => {
expect(result.reason).toContain('Expected output to contain valid JSON');
});
it('should pass when the function assertion passes', async () => {
it('should pass when the javascript assertion passes', async () => {
const output = 'Expected output';
const result: GradingResult = await runAssertion(
functionAssertion,
javascriptStringAssertion,
{} as AtomicTestCase,
output,
);
@@ -145,11 +155,11 @@ describe('runAssertion', () => {
expect(result.reason).toBe('Assertion passed');
});
it('should pass a score through when the function returns a number', async () => {
it('should pass a score through when the javascript returns a number', async () => {
const output = 'Expected output';
const result: GradingResult = await runAssertion(
functionAssertionWithNumber,
javascriptStringAssertionWithNumber,
{} as AtomicTestCase,
output,
);
@@ -158,11 +168,11 @@ describe('runAssertion', () => {
expect(result.reason).toBe('Assertion passed');
});
it('should fail when the function assertion fails', async () => {
it('should fail when the javascript assertion fails', async () => {
const output = 'Different output';
const result: GradingResult = await runAssertion(
functionAssertion,
javascriptStringAssertion,
{} as AtomicTestCase,
output,
);
@@ -173,12 +183,12 @@ describe('runAssertion', () => {
it('should pass when the function assertion passes - with vars', async () => {
const output = 'Expected output';
const functionAssertionWithVars: Assertion = {
const javascriptStringAssertionWithVars: Assertion = {
type: 'javascript',
value: 'output === "Expected output" && context.vars.foo === "bar"',
};
const result: GradingResult = await runAssertion(
functionAssertionWithVars,
javascriptStringAssertionWithVars,
{ vars: { foo: 'bar' } } as AtomicTestCase,
output,
);
@@ -186,15 +196,15 @@ describe('runAssertion', () => {
expect(result.reason).toBe('Assertion passed');
});
it('should fail when the function does not match vars', async () => {
it('should fail when the javascript does not match vars', async () => {
const output = 'Expected output';
const functionAssertionWithVars: Assertion = {
const javascriptStringAssertionWithVars: Assertion = {
type: 'javascript',
value: 'output === "Expected output" && context.vars.foo === "something else"',
};
const result: GradingResult = await runAssertion(
functionAssertionWithVars,
javascriptStringAssertionWithVars,
{ vars: { foo: 'bar' } } as AtomicTestCase,
output,
);
@@ -204,6 +214,32 @@ describe('runAssertion', () => {
);
});
it('should pass when the function returns pass', async () => {
const output = 'Expected output';
const result: GradingResult = await runAssertion(
javascriptFunctionAssertion,
{} as AtomicTestCase,
output,
);
expect(result.pass).toBeTruthy();
expect(result.score).toBe(0.5);
expect(result.reason).toBe('Assertion passed');
});
it('should fail when the function returns fail', async () => {
const output = 'Expected output';
const result: GradingResult = await runAssertion(
javascriptFunctionFailAssertion,
{} as AtomicTestCase,
output,
);
expect(result.pass).toBeFalsy();
expect(result.score).toBe(0.5);
expect(result.reason).toBe('Assertion failed');
});
const notContainsAssertion: Assertion = {
type: 'not-contains',
value: 'Unexpected output',

View File

@@ -5,7 +5,7 @@ import { AnthropicCompletionProvider } from '../src/providers/anthropic';
import { disableCache, enableCache } from '../src/cache.js';
import { loadApiProvider, loadApiProviders } from '../src/providers.js';
import type { RawProviderConfig } from '../src/types';
import type { RawProviderConfig, ProviderFunction } from '../src/types';
import {
AzureOpenAiChatCompletionProvider,
AzureOpenAiCompletionProvider,
@@ -201,6 +201,21 @@ describe('providers', () => {
expect(provider).toBeInstanceOf(OpenAiChatCompletionProvider);
});
test('loadApiProviders with ProviderFunction', async () => {
const providerFunction: ProviderFunction = async (prompt: string) => {
return {
output: `Output for ${prompt}`,
tokenUsage: { total: 10, prompt: 5, completion: 5 },
};
};
const providers = await loadApiProviders(providerFunction);
expect(providers).toHaveLength(1);
expect(providers[0].id()).toBe('custom-function');
const response = await providers[0].callApi('Test prompt');
expect(response.output).toBe('Output for Test prompt');
expect(response.tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
});
test('loadApiProviders with RawProviderConfig[]', async () => {
const rawProviderConfigs: RawProviderConfig[] = [
{

View File

@@ -496,6 +496,6 @@ describe('readTests', () => {
const result = await readTests(testsPaths);
expect(fs.readFileSync).toHaveBeenCalledTimes(2);
expect(result).toEqual([Object.assign({}, test1Content[0], {vars: vars1Content})]);
expect(result).toEqual([Object.assign({}, test1Content[0], { vars: vars1Content })]);
});
});