// Mirror of https://github.com/promptfoo/promptfoo.git
// Synced 2023-08-15 01:10:51 +03:00
// 289 lines · 11 KiB · TypeScript
import fetch from 'node-fetch';
|
|
import { fetchWithCache } from '../src/cache';
|
|
|
|
import { OpenAiCompletionProvider, OpenAiChatCompletionProvider } from '../src/providers/openai';
|
|
import { AnthropicCompletionProvider } from '../src/providers/anthropic';
|
|
import { LlamaProvider } from '../src/providers/llama';
|
|
|
|
import { disableCache, enableCache } from '../src/cache.js';
|
|
import { loadApiProvider, loadApiProviders } from '../src/providers.js';
|
|
import type { RawProviderConfig, ProviderFunction } from '../src/types';
|
|
import {
|
|
AzureOpenAiChatCompletionProvider,
|
|
AzureOpenAiCompletionProvider,
|
|
} from '../src/providers/azureopenai';
|
|
import { OllamaProvider } from '../src/providers/ollama';
|
|
|
|
// Replace `node-fetch` with a bare Jest mock function; the tests below
// configure what it resolves to on a per-case basis.
jest.mock('node-fetch', () => {
  return jest.fn();
});

// No factory is supplied here, so Jest substitutes its automatic mock for the module.
jest.mock('../src/esm.js');
|
|
|
|
describe('providers', () => {
|
|
// Clear recorded calls on every mock after each test so that
// `toHaveBeenCalledTimes` assertions only see the current test's calls.
afterEach(function resetRecordedMockCalls() {
  jest.clearAllMocks();
});
|
|
|
|
test('OpenAiCompletionProvider callApi', async () => {
  // The mocked `fetch` resolves to a response whose `json()` yields a
  // completion-style payload (text choice plus token usage).
  const completionPayload = {
    choices: [{ text: 'Test output' }],
    usage: { total_tokens: 10, prompt_tokens: 5, completion_tokens: 5 },
  };
  const fetchMock = fetch as unknown as jest.Mock;
  fetchMock.mockResolvedValue({ json: jest.fn().mockResolvedValue(completionPayload) });

  const completionProvider = new OpenAiCompletionProvider('text-davinci-003', 'test-api-key');
  const { output, tokenUsage } = await completionProvider.callApi('Test prompt');

  // Exactly one HTTP call, and the payload is surfaced as output + token usage.
  expect(fetchMock).toHaveBeenCalledTimes(1);
  expect(output).toBe('Test output');
  expect(tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
});
|
|
|
|
test('OpenAiChatCompletionProvider callApi', async () => {
  // Chat-style payload: the text lives under `choices[0].message.content`.
  const chatPayload = {
    choices: [{ message: { content: 'Test output' } }],
    usage: { total_tokens: 10, prompt_tokens: 5, completion_tokens: 5 },
  };
  const fetchMock = fetch as unknown as jest.Mock;
  fetchMock.mockResolvedValue({ json: jest.fn().mockResolvedValue(chatPayload) });

  // The prompt is passed as a JSON-encoded message list.
  const serializedMessages = JSON.stringify([{ role: 'user', content: 'Test prompt' }]);
  const chatProvider = new OpenAiChatCompletionProvider('gpt-3.5-turbo', 'test-api-key');
  const { output, tokenUsage } = await chatProvider.callApi(serializedMessages);

  expect(fetchMock).toHaveBeenCalledTimes(1);
  expect(output).toBe('Test output');
  expect(tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
});
|
|
|
|
test('OpenAiChatCompletionProvider callApi with cache disabled', async () => {
  disableCache();

  try {
    // Chat-style payload returned by the mocked `fetch` response's `json()`.
    const mockResponse = {
      json: jest.fn().mockResolvedValue({
        choices: [{ message: { content: 'Test output' } }],
        usage: { total_tokens: 10, prompt_tokens: 5, completion_tokens: 5 },
      }),
    };
    (fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);

    const provider = new OpenAiChatCompletionProvider('gpt-3.5-turbo', 'test-api-key');
    const result = await provider.callApi(
      JSON.stringify([{ role: 'user', content: 'Test prompt' }]),
    );

    expect(fetch).toHaveBeenCalledTimes(1);
    expect(result.output).toBe('Test output');
    expect(result.tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
  } finally {
    // Restore caching even when an assertion above throws; otherwise the
    // disabled-cache state would leak into every test that runs afterwards.
    enableCache();
  }
});
|
|
|
|
test('AzureOpenAiCompletionProvider callApi', async () => {
  // Completion-style payload served by the mocked `fetch`.
  const completionPayload = {
    choices: [{ text: 'Test output' }],
    usage: { total_tokens: 10, prompt_tokens: 5, completion_tokens: 5 },
  };
  const fetchMock = fetch as unknown as jest.Mock;
  fetchMock.mockResolvedValue({ json: jest.fn().mockResolvedValue(completionPayload) });

  const azureProvider = new AzureOpenAiCompletionProvider('text-davinci-003', 'test-api-key');
  const { output, tokenUsage } = await azureProvider.callApi('Test prompt');

  expect(fetchMock).toHaveBeenCalledTimes(1);
  expect(output).toBe('Test output');
  expect(tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
});
|
|
|
|
test('AzureOpenAiChatCompletionProvider callApi', async () => {
  // Chat-style payload served by the mocked `fetch`.
  const chatPayload = {
    choices: [{ message: { content: 'Test output' } }],
    usage: { total_tokens: 10, prompt_tokens: 5, completion_tokens: 5 },
  };
  const fetchMock = fetch as unknown as jest.Mock;
  fetchMock.mockResolvedValue({ json: jest.fn().mockResolvedValue(chatPayload) });

  // The prompt is passed as a JSON-encoded message list.
  const serializedMessages = JSON.stringify([{ role: 'user', content: 'Test prompt' }]);
  const azureChatProvider = new AzureOpenAiChatCompletionProvider('gpt-3.5-turbo', 'test-api-key');
  const { output, tokenUsage } = await azureChatProvider.callApi(serializedMessages);

  expect(fetchMock).toHaveBeenCalledTimes(1);
  expect(output).toBe('Test output');
  expect(tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
});
|
|
|
|
test('AzureOpenAiChatCompletionProvider callApi with cache disabled', async () => {
  disableCache();

  try {
    // Chat-style payload returned by the mocked `fetch` response's `json()`.
    const mockResponse = {
      json: jest.fn().mockResolvedValue({
        choices: [{ message: { content: 'Test output' } }],
        usage: { total_tokens: 10, prompt_tokens: 5, completion_tokens: 5 },
      }),
    };
    (fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);

    const provider = new AzureOpenAiChatCompletionProvider('gpt-3.5-turbo', 'test-api-key');
    const result = await provider.callApi(
      JSON.stringify([{ role: 'user', content: 'Test prompt' }]),
    );

    expect(fetch).toHaveBeenCalledTimes(1);
    expect(result.output).toBe('Test output');
    expect(result.tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
  } finally {
    // Restore caching even when an assertion above throws; otherwise the
    // disabled-cache state would leak into every test that runs afterwards.
    enableCache();
  }
});
|
|
|
|
test('AnthropicCompletionProvider callApi', async () => {
  const anthropicProvider = new AnthropicCompletionProvider('claude-1', 'test-api-key');

  // Stub the client's `completions.create` directly on the provider instance,
  // so this test does not go through the mocked `fetch`.
  const createStub = jest.fn().mockResolvedValue({ completion: 'Test output' });
  anthropicProvider.anthropic.completions.create = createStub;

  const { output, tokenUsage } = await anthropicProvider.callApi('Test prompt');

  expect(anthropicProvider.anthropic.completions.create).toHaveBeenCalledTimes(1);
  expect(output).toBe('Test output');
  // The stubbed response carries no usage data; an empty object is expected.
  expect(tokenUsage).toEqual({});
});
|
|
|
|
test('LlamaProvider callApi', async () => {
  // The mocked response exposes the generated text under `content`.
  const fetchMock = fetch as unknown as jest.Mock;
  fetchMock.mockResolvedValue({
    json: jest.fn().mockResolvedValue({ content: 'Test output' }),
  });

  const llamaProvider = new LlamaProvider('llama.cpp');
  const { output } = await llamaProvider.callApi('Test prompt');

  expect(fetchMock).toHaveBeenCalledTimes(1);
  expect(output).toBe('Test output');
});
|
|
|
|
test('OllamaProvider callApi', async () => {
  // The mocked body is newline-delimited JSON: one object per line, each
  // carrying a `response` fragment, followed by a final `done: true` record.
  // The lines are joined explicitly with '\n' so that stray characters or
  // whitespace can never end up inside the payload, which is a risk when the
  // body is written as a multi-line template literal.
  const streamedLines = [
    '{"model":"llama2:13b","created_at":"2023-08-08T21:50:34.898068Z","response":"Gre","done":false}',
    '{"model":"llama2:13b","created_at":"2023-08-08T21:50:34.929199Z","response":"at","done":false}',
    '{"model":"llama2:13b","created_at":"2023-08-08T21:50:34.959989Z","response":" question","done":false}',
    '{"model":"llama2:13b","created_at":"2023-08-08T21:50:34.992117Z","response":"!","done":false}',
    '{"model":"llama2:13b","created_at":"2023-08-08T21:50:35.023658Z","response":" The","done":false}',
    '{"model":"llama2:13b","created_at":"2023-08-08T21:50:35.0551Z","response":" sky","done":false}',
    '{"model":"llama2:13b","created_at":"2023-08-08T21:50:35.086103Z","response":" appears","done":false}',
    '{"model":"llama2:13b","created_at":"2023-08-08T21:50:35.117166Z","response":" blue","done":false}',
    '{"model":"llama2:13b","created_at":"2023-08-08T21:50:41.695299Z","done":true,"context":[1,29871,1,13,9314],"total_duration":10411943458,"load_duration":458333,"sample_count":217,"sample_duration":154566000,"prompt_eval_count":11,"prompt_eval_duration":3334582000,"eval_count":216,"eval_duration":6905134000}',
  ];
  const mockResponse = {
    text: jest.fn().mockResolvedValue(streamedLines.join('\n')),
  };
  (fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);

  const provider = new OllamaProvider('llama');
  const result = await provider.callApi('Test prompt');

  expect(fetch).toHaveBeenCalledTimes(1);
  // The expected output is the `response` fragments concatenated in order.
  expect(result.output).toBe('Great question! The sky appears blue');
});
|
|
|
|
test('loadApiProvider with openai:chat', async () => {
  // A bare `openai:chat` id (no model suffix) should resolve to the chat provider class.
  await expect(loadApiProvider('openai:chat')).resolves.toBeInstanceOf(
    OpenAiChatCompletionProvider,
  );
});
|
|
|
|
test('loadApiProvider with openai:completion', async () => {
  // A bare `openai:completion` id should resolve to the completion provider class.
  const loaded = await loadApiProvider('openai:completion');

  expect(loaded).toBeInstanceOf(OpenAiCompletionProvider);
});
|
|
|
|
test('loadApiProvider with openai:chat:modelName', async () => {
  // Same as the bare id, but with an explicit model name as the third segment.
  await expect(loadApiProvider('openai:chat:gpt-3.5-turbo')).resolves.toBeInstanceOf(
    OpenAiChatCompletionProvider,
  );
});
|
|
|
|
test('loadApiProvider with openai:completion:modelName', async () => {
  // Completion id with an explicit model name as the third segment.
  const loaded = await loadApiProvider('openai:completion:text-davinci-003');

  expect(loaded).toBeInstanceOf(OpenAiCompletionProvider);
});
|
|
|
|
test('loadApiProvider with azureopenai:completion:modelName', async () => {
  // The `azureopenai:completion:` prefix should select the Azure completion provider.
  await expect(
    loadApiProvider('azureopenai:completion:text-davinci-003'),
  ).resolves.toBeInstanceOf(AzureOpenAiCompletionProvider);
});
|
|
|
|
test('loadApiProvider with azureopenai:chat:modelName', async () => {
  // The `azureopenai:chat:` prefix should select the Azure chat provider.
  const loaded = await loadApiProvider('azureopenai:chat:gpt-3.5-turbo');

  expect(loaded).toBeInstanceOf(AzureOpenAiChatCompletionProvider);
});
|
|
|
|
test('loadApiProvider with anthropic:completion', async () => {
  // A bare `anthropic:completion` id should resolve to the Anthropic provider class.
  await expect(loadApiProvider('anthropic:completion')).resolves.toBeInstanceOf(
    AnthropicCompletionProvider,
  );
});
|
|
|
|
test('loadApiProvider with anthropic:completion:modelName', async () => {
  // Anthropic id with an explicit model name as the third segment.
  const loaded = await loadApiProvider('anthropic:completion:claude-1');

  expect(loaded).toBeInstanceOf(AnthropicCompletionProvider);
});
|
|
|
|
test('loadApiProvider with llama:modelName', async () => {
  // NOTE(review): the title mentions `llama:modelName`, but the id passed is
  // the bare `llama` — confirm whether a model-suffixed case was intended.
  await expect(loadApiProvider('llama')).resolves.toBeInstanceOf(LlamaProvider);
});
|
|
|
|
test('loadApiProvider with RawProviderConfig', async () => {
  // Options object that accompanies the provider id as the second argument.
  const providerId = 'openai:chat';
  const providerOptions = {
    id: 'test',
    config: { foo: 'bar' },
  };

  const loaded = await loadApiProvider(providerId, providerOptions);

  expect(loaded).toBeInstanceOf(OpenAiChatCompletionProvider);
});
|
|
|
|
test('loadApiProviders with ProviderFunction', async () => {
  // A plain async function acts as the provider: it echoes the prompt back
  // together with fixed token-usage numbers.
  const echoProvider: ProviderFunction = async (prompt: string) => ({
    output: `Output for ${prompt}`,
    tokenUsage: { total: 10, prompt: 5, completion: 5 },
  });

  const loadedProviders = await loadApiProviders(echoProvider);
  expect(loadedProviders).toHaveLength(1);

  // The function is wrapped in a provider whose id is `custom-function`.
  const [wrapped] = loadedProviders;
  expect(wrapped.id()).toBe('custom-function');

  const { output, tokenUsage } = await wrapped.callApi('Test prompt');
  expect(output).toBe('Output for Test prompt');
  expect(tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
});
|
|
|
|
test('loadApiProviders with RawProviderConfig[]', async () => {
  // Each entry is a single-key map from provider id to its options.
  const configList: RawProviderConfig[] = [
    { 'openai:chat:abc123': { config: { foo: 'bar' } } },
    { 'openai:completion:def456': { config: { foo: 'bar' } } },
    { 'anthropic:completion:ghi789': { config: { foo: 'bar' } } },
  ];

  const loadedProviders = await loadApiProviders(configList);
  expect(loadedProviders).toHaveLength(3);

  // Providers are expected back in the same order as their configs.
  const expectedClasses = [
    OpenAiChatCompletionProvider,
    OpenAiCompletionProvider,
    AnthropicCompletionProvider,
  ];
  expectedClasses.forEach((expectedClass, position) => {
    expect(loadedProviders[position]).toBeInstanceOf(expectedClass);
  });
});
|
|
});
|