// llm-promptfoo/test/providers.test.ts

import fetch from 'node-fetch';
import { fetchWithCache, disableCache, enableCache } from '../src/cache';
import { OpenAiCompletionProvider, OpenAiChatCompletionProvider } from '../src/providers/openai';
import { AnthropicCompletionProvider } from '../src/providers/anthropic';
import { LlamaProvider } from '../src/providers/llama';
import { loadApiProvider, loadApiProviders } from '../src/providers';
import type { RawProviderConfig, ProviderFunction } from '../src/types';
import {
  AzureOpenAiChatCompletionProvider,
  AzureOpenAiCompletionProvider,
} from '../src/providers/azureopenai';
import { OllamaProvider } from '../src/providers/ollama';
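
// node-fetch is mocked so no test issues a real network request; esm.js is
// mocked to keep its ESM-only helpers from running under Jest.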
jest.mock('node-fetch', () => jest.fn());
jest.mock('../src/esm.js');
describe('providers', () => {
  afterEach(() => {
    jest.clearAllMocks();
  });
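
  // Each OpenAI-style test stubs the fetch response with the provider API's
  // JSON shape, then checks that callApi parses output and token usage.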
  test('OpenAiCompletionProvider callApi', async () => {
    const mockResponse = {
      json: jest.fn().mockResolvedValue({
        choices: [{ text: 'Test output' }],
        usage: { total_tokens: 10, prompt_tokens: 5, completion_tokens: 5 },
      }),
    };
    (fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);
    const provider = new OpenAiCompletionProvider('text-davinci-003', 'test-api-key');
    const result = await provider.callApi('Test prompt');
    expect(fetch).toHaveBeenCalledTimes(1);
    expect(result.output).toBe('Test output');
    expect(result.tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
  });
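
  // The chat provider takes its prompt as a JSON-serialized message array
  // rather than a plain string.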
  test('OpenAiChatCompletionProvider callApi', async () => {
    const mockResponse = {
      json: jest.fn().mockResolvedValue({
        choices: [{ message: { content: 'Test output' } }],
        usage: { total_tokens: 10, prompt_tokens: 5, completion_tokens: 5 },
      }),
    };
    (fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);
    const provider = new OpenAiChatCompletionProvider('gpt-3.5-turbo', 'test-api-key');
    const result = await provider.callApi(
      JSON.stringify([{ role: 'user', content: 'Test prompt' }]),
    );
    expect(fetch).toHaveBeenCalledTimes(1);
    expect(result.output).toBe('Test output');
    expect(result.tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
  });
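
  // Same call with the response cache disabled: the request must still go out
  // through fetch instead of being served from cache. The cache is re-enabled
  // at the end so later tests are unaffected.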
  test('OpenAiChatCompletionProvider callApi with cache disabled', async () => {
    disableCache();
    const mockResponse = {
      json: jest.fn().mockResolvedValue({
        choices: [{ message: { content: 'Test output' } }],
        usage: { total_tokens: 10, prompt_tokens: 5, completion_tokens: 5 },
      }),
    };
    (fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);
    const provider = new OpenAiChatCompletionProvider('gpt-3.5-turbo', 'test-api-key');
    const result = await provider.callApi(
      JSON.stringify([{ role: 'user', content: 'Test prompt' }]),
    );
    expect(fetch).toHaveBeenCalledTimes(1);
    expect(result.output).toBe('Test output');
    expect(result.tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
    enableCache();
  });
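
  // The Azure OpenAI providers mirror their OpenAI counterparts (including a
  // cache-disabled variant); the mocked response shapes are identical.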
  test('AzureOpenAiCompletionProvider callApi', async () => {
    const mockResponse = {
      json: jest.fn().mockResolvedValue({
        choices: [{ text: 'Test output' }],
        usage: { total_tokens: 10, prompt_tokens: 5, completion_tokens: 5 },
      }),
    };
    (fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);
    const provider = new AzureOpenAiCompletionProvider('text-davinci-003', 'test-api-key');
    const result = await provider.callApi('Test prompt');
    expect(fetch).toHaveBeenCalledTimes(1);
    expect(result.output).toBe('Test output');
    expect(result.tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
  });
  test('AzureOpenAiChatCompletionProvider callApi', async () => {
    const mockResponse = {
      json: jest.fn().mockResolvedValue({
        choices: [{ message: { content: 'Test output' } }],
        usage: { total_tokens: 10, prompt_tokens: 5, completion_tokens: 5 },
      }),
    };
    (fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);
    const provider = new AzureOpenAiChatCompletionProvider('gpt-3.5-turbo', 'test-api-key');
    const result = await provider.callApi(
      JSON.stringify([{ role: 'user', content: 'Test prompt' }]),
    );
    expect(fetch).toHaveBeenCalledTimes(1);
    expect(result.output).toBe('Test output');
    expect(result.tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
  });
  test('AzureOpenAiChatCompletionProvider callApi with cache disabled', async () => {
    disableCache();
    const mockResponse = {
      json: jest.fn().mockResolvedValue({
        choices: [{ message: { content: 'Test output' } }],
        usage: { total_tokens: 10, prompt_tokens: 5, completion_tokens: 5 },
      }),
    };
    (fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);
    const provider = new AzureOpenAiChatCompletionProvider('gpt-3.5-turbo', 'test-api-key');
    const result = await provider.callApi(
      JSON.stringify([{ role: 'user', content: 'Test prompt' }]),
    );
    expect(fetch).toHaveBeenCalledTimes(1);
    expect(result.output).toBe('Test output');
    expect(result.tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
    enableCache();
  });
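
  // Anthropic goes through its SDK client rather than fetch, so the client
  // method is stubbed directly; the mocked completion carries no usage data,
  // hence the empty tokenUsage.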
  test('AnthropicCompletionProvider callApi', async () => {
    const provider = new AnthropicCompletionProvider('claude-1', 'test-api-key');
    provider.anthropic.completions.create = jest.fn().mockResolvedValue({
      completion: 'Test output',
    });
    const result = await provider.callApi('Test prompt');
    expect(provider.anthropic.completions.create).toHaveBeenCalledTimes(1);
    expect(result.output).toBe('Test output');
    expect(result.tokenUsage).toEqual({});
  });
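
  // A llama.cpp server responds with a single JSON object whose `content`
  // field holds the completion.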
  test('LlamaProvider callApi', async () => {
    const mockResponse = {
      json: jest.fn().mockResolvedValue({
        content: 'Test output',
      }),
    };
    (fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);
    const provider = new LlamaProvider('llama.cpp');
    const result = await provider.callApi('Test prompt');
    expect(fetch).toHaveBeenCalledTimes(1);
    expect(result.output).toBe('Test output');
  });
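
  // Ollama streams newline-delimited JSON: one object per token fragment with
  // done:false, then a final done:true record carrying timing metadata. The
  // provider is expected to concatenate the response fragments.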
  test('OllamaProvider callApi', async () => {
    const mockResponse = {
      text: jest.fn()
        .mockResolvedValue(`{"model":"llama2:13b","created_at":"2023-08-08T21:50:34.898068Z","response":"Gre","done":false}
{"model":"llama2:13b","created_at":"2023-08-08T21:50:34.929199Z","response":"at","done":false}
{"model":"llama2:13b","created_at":"2023-08-08T21:50:34.959989Z","response":" question","done":false}
{"model":"llama2:13b","created_at":"2023-08-08T21:50:34.992117Z","response":"!","done":false}
{"model":"llama2:13b","created_at":"2023-08-08T21:50:35.023658Z","response":" The","done":false}
{"model":"llama2:13b","created_at":"2023-08-08T21:50:35.0551Z","response":" sky","done":false}
{"model":"llama2:13b","created_at":"2023-08-08T21:50:35.086103Z","response":" appears","done":false}
{"model":"llama2:13b","created_at":"2023-08-08T21:50:35.117166Z","response":" blue","done":false}
{"model":"llama2:13b","created_at":"2023-08-08T21:50:41.695299Z","done":true,"context":[1,29871,1,13,9314],"total_duration":10411943458,"load_duration":458333,"sample_count":217,"sample_duration":154566000,"prompt_eval_count":11,"prompt_eval_duration":3334582000,"eval_count":216,"eval_duration":6905134000}`),
    };
    (fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);
    const provider = new OllamaProvider('llama');
    const result = await provider.callApi('Test prompt');
    expect(fetch).toHaveBeenCalledTimes(1);
    expect(result.output).toBe('Great question! The sky appears blue');
  });
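
  // loadApiProvider resolves provider id strings of the form
  // provider:type[:model] to concrete provider instances.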
  test('loadApiProvider with openai:chat', async () => {
    const provider = await loadApiProvider('openai:chat');
    expect(provider).toBeInstanceOf(OpenAiChatCompletionProvider);
  });
  test('loadApiProvider with openai:completion', async () => {
    const provider = await loadApiProvider('openai:completion');
    expect(provider).toBeInstanceOf(OpenAiCompletionProvider);
  });
  test('loadApiProvider with openai:chat:modelName', async () => {
    const provider = await loadApiProvider('openai:chat:gpt-3.5-turbo');
    expect(provider).toBeInstanceOf(OpenAiChatCompletionProvider);
  });
  test('loadApiProvider with openai:completion:modelName', async () => {
    const provider = await loadApiProvider('openai:completion:text-davinci-003');
    expect(provider).toBeInstanceOf(OpenAiCompletionProvider);
  });
  test('loadApiProvider with azureopenai:completion:modelName', async () => {
    const provider = await loadApiProvider('azureopenai:completion:text-davinci-003');
    expect(provider).toBeInstanceOf(AzureOpenAiCompletionProvider);
  });
  test('loadApiProvider with azureopenai:chat:modelName', async () => {
    const provider = await loadApiProvider('azureopenai:chat:gpt-3.5-turbo');
    expect(provider).toBeInstanceOf(AzureOpenAiChatCompletionProvider);
  });
  test('loadApiProvider with anthropic:completion', async () => {
    const provider = await loadApiProvider('anthropic:completion');
    expect(provider).toBeInstanceOf(AnthropicCompletionProvider);
  });
  test('loadApiProvider with anthropic:completion:modelName', async () => {
    const provider = await loadApiProvider('anthropic:completion:claude-1');
    expect(provider).toBeInstanceOf(AnthropicCompletionProvider);
  });
  test('loadApiProvider with llama', async () => {
    const provider = await loadApiProvider('llama');
    expect(provider).toBeInstanceOf(LlamaProvider);
  });
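
  // A RawProviderConfig entry keys the options by provider id and can
  // override the instance id.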
  test('loadApiProvider with RawProviderConfig', async () => {
    const rawProviderConfig: RawProviderConfig = {
      'openai:chat': {
        id: 'test',
        config: { foo: 'bar' },
      },
    };
    const provider = await loadApiProvider('openai:chat', rawProviderConfig['openai:chat']);
    expect(provider).toBeInstanceOf(OpenAiChatCompletionProvider);
  });
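
  // A bare async function can stand in as a provider; loadApiProviders wraps
  // it with the id 'custom-function'.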
  test('loadApiProviders with ProviderFunction', async () => {
    const providerFunction: ProviderFunction = async (prompt: string) => {
      return {
        output: `Output for ${prompt}`,
        tokenUsage: { total: 10, prompt: 5, completion: 5 },
      };
    };
    const providers = await loadApiProviders(providerFunction);
    expect(providers).toHaveLength(1);
    expect(providers[0].id()).toBe('custom-function');
    const response = await providers[0].callApi('Test prompt');
    expect(response.output).toBe('Output for Test prompt');
    expect(response.tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
  });
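
  // An array of single-key config objects loads one provider per entry, with
  // the model name embedded in each key.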
  test('loadApiProviders with RawProviderConfig[]', async () => {
    const rawProviderConfigs: RawProviderConfig[] = [
      {
        'openai:chat:abc123': {
          config: { foo: 'bar' },
        },
      },
      {
        'openai:completion:def456': {
          config: { foo: 'bar' },
        },
      },
      {
        'anthropic:completion:ghi789': {
          config: { foo: 'bar' },
        },
      },
    ];
    const providers = await loadApiProviders(rawProviderConfigs);
    expect(providers).toHaveLength(3);
    expect(providers[0]).toBeInstanceOf(OpenAiChatCompletionProvider);
    expect(providers[1]).toBeInstanceOf(OpenAiCompletionProvider);
    expect(providers[2]).toBeInstanceOf(AnthropicCompletionProvider);
  });
});