Mirror of https://github.com/promptfoo/promptfoo.git, synced 2023-08-15 01:10:51 +03:00
Add ollama provider (#102)
Renames fetchJsonWithCache to fetchWithCache, adds an optional format: 'json' | 'text' argument, and uses the raw-text mode in a new OllamaProvider that handles ollama:<model> provider IDs by POSTing to /api/generate on OLLAMA_BASE_URL (default http://localhost:11434) and concatenating the streamed response chunks.
@@ -42,10 +42,11 @@ export function getCache() {
   return cacheInstance;
 }
 
-export async function fetchJsonWithCache(
+export async function fetchWithCache(
   url: RequestInfo,
   options: RequestInit = {},
   timeout: number,
+  format: 'json' | 'text' = 'json',
 ): Promise<{ data: any; cached: boolean }> {
   if (!enabled) {
     const resp = await fetchWithRetries(url, options, timeout);
@@ -75,7 +76,7 @@ export async function fetchJsonWithCache(
   // Fetch the actual data and store it in the cache
   const response = await fetchWithRetries(url, options, timeout);
   try {
-    const data = await response.json();
+    const data = format === 'json' ? await response.json() : await response.text();
     if (response.ok) {
       logger.debug(`Storing ${url} response in cache: ${JSON.stringify(data)}`);
       await cache.set(cacheKey, JSON.stringify(data));
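Note: the renamed helper keeps its old default behaviour; the new format argument only changes how the response body is decoded. A minimal usage sketch based on the signature above (the URLs, payloads, and relative import path are illustrative, not part of this commit):

import { fetchWithCache } from './cache'; // illustrative import path

async function demo() {
  // Default format 'json': the body is parsed with response.json(), as before.
  const { data, cached } = await fetchWithCache(
    'https://example.com/api/items', // illustrative URL
    { method: 'GET' },
    5000, // timeout in ms
  );

  // format 'text' returns the raw body instead, which is what the new Ollama
  // provider needs for its newline-delimited JSON stream.
  const { data: raw } = await fetchWithCache(
    'https://example.com/api/generate', // illustrative URL
    { method: 'POST', body: JSON.stringify({ prompt: 'hi' }) },
    5000,
    'text',
  );

  console.log(cached, data, raw);
}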
@@ -5,6 +5,7 @@ import { AnthropicCompletionProvider } from './providers/anthropic';
 import { ReplicateProvider } from './providers/replicate';
 import { LocalAiCompletionProvider, LocalAiChatProvider } from './providers/localai';
 import { LlamaProvider } from './providers/llama';
+import { OllamaProvider } from './providers/ollama';
 import { ScriptCompletionProvider } from './providers/scriptCompletion';
 import {
   AzureOpenAiChatCompletionProvider,
@@ -148,6 +149,9 @@ export async function loadApiProvider(
   if (providerPath === 'llama' || providerPath.startsWith('llama:')) {
     const modelName = providerPath.split(':')[1];
     return new LlamaProvider(modelName, context?.config);
+  } else if (providerPath.startsWith('ollama:')) {
+    const modelName = providerPath.split(':')[1];
+    return new OllamaProvider(modelName);
   } else if (providerPath?.startsWith('localai:')) {
     const options = providerPath.split(':');
     const modelType = options[1];
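Note: the new branch mirrors the existing llama handling, splitting the provider string on ':' to pick the model name. A hedged sketch of how resolution behaves (the import path and model names are illustrative):

import { loadApiProvider } from './providers'; // illustrative import path

async function pickProvider() {
  // 'ollama:<model>' now resolves to an OllamaProvider for that model.
  const provider = await loadApiProvider('ollama:llama2'); // example model name
  console.log(provider.id()); // "ollama:llama2"

  // providerPath.split(':')[1] keeps only the first segment after the prefix,
  // so 'ollama:llama2:13b' would resolve to the model name 'llama2' here.
}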
@@ -1,5 +1,5 @@
 import logger from '../logger';
-import { fetchJsonWithCache } from '../cache';
+import { fetchWithCache } from '../cache';
 import { REQUEST_TIMEOUT_MS, parseChatPrompt } from './shared';
 
 import type { ApiProvider, ProviderEmbeddingResponse, ProviderResponse } from '../types.js';
@@ -61,7 +61,7 @@ export class AzureOpenAiEmbeddingProvider extends AzureOpenAiGenericProvider {
     let data,
       cached = false;
     try {
-      ({ data, cached } = (await fetchJsonWithCache(
+      ({ data, cached } = (await fetchWithCache(
         `https://${this.apiHost}/openai/deployments/${this.deploymentName}/embeddings?api-version=2023-07-01-preview`,
         {
           method: 'POST',
@@ -171,7 +171,7 @@ export class AzureOpenAiCompletionProvider extends AzureOpenAiGenericProvider {
     let data,
       cached = false;
     try {
-      ({ data, cached } = (await fetchJsonWithCache(
+      ({ data, cached } = (await fetchWithCache(
         `https://${this.apiHost}/openai/deployments/${this.deploymentName}/completions?api-version=2023-07-01-preview`,
         {
           method: 'POST',
@@ -258,7 +258,7 @@ export class AzureOpenAiChatCompletionProvider extends AzureOpenAiGenericProvide
     let data,
       cached = false;
     try {
-      ({ data, cached } = (await fetchJsonWithCache(
+      ({ data, cached } = (await fetchWithCache(
         `https://${this.apiHost}/openai/deployments/${this.deploymentName}/chat/completions?api-version=2023-07-01-preview`,
         {
           method: 'POST',
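Note: each of these call sites assigns into variables declared just above it, which is why the destructuring is wrapped in parentheses. A standalone sketch of the pattern (placeholder URL; the second opening parenthesis seen in the real call sites belongs to code outside these hunks and is omitted):

import { fetchWithCache } from '../cache';

async function callEndpoint() {
  let data: any,
    cached = false;
  try {
    // Without the wrapping parentheses, a statement starting with '{' would be
    // parsed as a block rather than as a destructuring assignment.
    ({ data, cached } = await fetchWithCache(
      'https://example.com/v1/endpoint', // placeholder URL
      { method: 'POST', body: JSON.stringify({ input: 'hello' }) },
      30_000,
    ));
  } catch (err) {
    return { error: `API call error: ${String(err)}` };
  }
  return { data, cached };
}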
@@ -1,4 +1,4 @@
-import { fetchJsonWithCache } from '../cache';
+import { fetchWithCache } from '../cache';
 import { REQUEST_TIMEOUT_MS } from './shared';
 
 import type { ApiProvider, ProviderResponse } from '../types.js';
@@ -65,7 +65,7 @@ export class LlamaProvider implements ApiProvider {
 
     let response;
     try {
-      response = await fetchJsonWithCache(
+      response = await fetchWithCache(
        `${process.env.LLAMA_BASE_URL || 'http://localhost:8080'}/completion`,
        {
          method: 'POST',

@@ -1,5 +1,5 @@
 import logger from '../logger';
-import { fetchJsonWithCache } from '../cache';
+import { fetchWithCache } from '../cache';
 import { REQUEST_TIMEOUT_MS, parseChatPrompt } from './shared';
 
 import type { ApiProvider, ProviderResponse } from '../types.js';
@@ -40,7 +40,7 @@ export class LocalAiChatProvider extends LocalAiGenericProvider {
     let data,
       cached = false;
     try {
-      ({ data, cached } = (await fetchJsonWithCache(
+      ({ data, cached } = (await fetchWithCache(
         `${this.apiBaseUrl}/chat/completions`,
         {
           method: 'POST',
@@ -81,7 +81,7 @@ export class LocalAiCompletionProvider extends LocalAiGenericProvider {
     let data,
       cached = false;
     try {
-      ({ data, cached } = (await fetchJsonWithCache(
+      ({ data, cached } = (await fetchWithCache(
         `${this.apiBaseUrl}/completions`,
         {
           method: 'POST',

src/providers/ollama.ts (new file, 88 lines)
@@ -0,0 +1,88 @@
+import logger from '../logger';
+import { fetchWithCache } from '../cache';
+
+import type { ApiProvider, ProviderResponse } from '../types.js';
+import { REQUEST_TIMEOUT_MS } from './shared';
+
+interface OllamaJsonL {
+  model: string;
+  created_at: string;
+  response?: string;
+  done: boolean;
+  context?: number[];
+  total_duration?: number;
+  load_duration?: number;
+  sample_count?: number;
+  sample_duration?: number;
+  prompt_eval_count?: number;
+  prompt_eval_duration?: number;
+  eval_count?: number;
+  eval_duration?: number;
+}
+
+export class OllamaProvider implements ApiProvider {
+  modelName: string;
+
+  constructor(modelName: string) {
+    this.modelName = modelName;
+  }
+
+  id(): string {
+    return `ollama:${this.modelName}`;
+  }
+
+  toString(): string {
+    return `[Ollama Provider ${this.modelName}]`;
+  }
+
+  async callApi(prompt: string): Promise<ProviderResponse> {
+    const params = {
+      model: this.modelName,
+      prompt,
+    };
+
+    logger.debug(`Calling Ollama API: ${JSON.stringify(params)}`);
+    let response;
+    try {
+      response = await fetchWithCache(
+        `${process.env.OLLAMA_BASE_URL || 'http://localhost:11434'}/api/generate`,
+        {
+          method: 'POST',
+          headers: {
+            'Content-Type': 'application/json',
+          },
+          body: JSON.stringify(params),
+        },
+        REQUEST_TIMEOUT_MS,
+        'text',
+      );
+    } catch (err) {
+      return {
+        error: `API call error: ${String(err)}`,
+      };
+    }
+    logger.debug(`\tOllama API response: ${response.data}`);
+
+    try {
+      const output = response.data
+        .split('\n')
+        .map((line: string) => {
+          const parsed = JSON.parse(line) as OllamaJsonL;
+          if (parsed.response) {
+            return parsed.response;
+          }
+          return null;
+        })
+        .filter((s: string | null) => s !== null)
+        .join('');
+
+      return {
+        output,
+      };
+    } catch (err) {
+      return {
+        error: `API response error: ${String(err)}: ${JSON.stringify(response.data)}`,
+      };
+    }
+  }
+}
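Note: a short usage sketch of the new provider outside the test harness, assuming a reachable Ollama server; the model name and prompt are examples:

import { OllamaProvider } from './providers/ollama';

async function main() {
  // callApi POSTs to `${OLLAMA_BASE_URL || 'http://localhost:11434'}/api/generate`
  // and concatenates the streamed `response` chunks into a single output string.
  const provider = new OllamaProvider('llama2'); // example model name
  const result = await provider.callApi('Why is the sky blue?');

  if (result.error) {
    console.error(result.error);
  } else {
    console.log(result.output);
  }
}

main();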
@@ -1,5 +1,5 @@
 import logger from '../logger';
-import { fetchJsonWithCache } from '../cache';
+import { fetchWithCache } from '../cache';
 import { REQUEST_TIMEOUT_MS, parseChatPrompt } from './shared';
 
 import type { ApiProvider, ProviderEmbeddingResponse, ProviderResponse } from '../types.js';
@@ -61,7 +61,7 @@ export class OpenAiEmbeddingProvider extends OpenAiGenericProvider {
     let data,
       cached = false;
     try {
-      ({ data, cached } = (await fetchJsonWithCache(
+      ({ data, cached } = (await fetchWithCache(
         `https://${this.apiHost}/v1/embeddings`,
         {
           method: 'POST',
@@ -177,7 +177,7 @@ export class OpenAiCompletionProvider extends OpenAiGenericProvider {
     let data,
       cached = false;
     try {
-      ({ data, cached } = (await fetchJsonWithCache(
+      ({ data, cached } = (await fetchWithCache(
         `https://${this.apiHost}/v1/completions`,
         {
           method: 'POST',
@@ -275,7 +275,7 @@ export class OpenAiChatCompletionProvider extends OpenAiGenericProvider {
     let data,
       cached = false;
     try {
-      ({ data, cached } = (await fetchJsonWithCache(
+      ({ data, cached } = (await fetchWithCache(
         `https://${this.apiHost}/v1/chat/completions`,
         {
           method: 'POST',

@@ -1,10 +1,10 @@
-import { fetchJsonWithCache, disableCache, enableCache } from '../src/cache.js';
+import { fetchWithCache, disableCache, enableCache } from '../src/cache.js';
 import fetch, { Response } from 'node-fetch';
 
 jest.mock('node-fetch');
 const mockedFetch = fetch as jest.MockedFunction<typeof fetch>;
 
-describe('fetchJsonWithCache', () => {
+describe('fetchWithCache', () => {
   afterEach(() => {
     mockedFetch.mockReset();
   });
@@ -20,7 +20,7 @@ describe('fetchJsonWithCache', () => {
       json: () => Promise.resolve(response),
     } as Response);
 
-    const result = await fetchJsonWithCache(url, {}, 1000);
+    const result = await fetchWithCache(url, {}, 1000);
 
     expect(mockedFetch).toHaveBeenCalledTimes(1);
     expect(result).toEqual({ cached: false, data: response });
@@ -37,7 +37,7 @@ describe('fetchJsonWithCache', () => {
       json: () => Promise.resolve(response),
     } as Response);
 
-    const result = await fetchJsonWithCache(url, {}, 1000);
+    const result = await fetchWithCache(url, {}, 1000);
 
     expect(mockedFetch).toHaveBeenCalledTimes(1);
     expect(result).toEqual({ cached: false, data: response });
@@ -52,7 +52,7 @@ describe('fetchJsonWithCache', () => {
       json: () => Promise.resolve(response),
     } as Response);
 
-    const result = await fetchJsonWithCache(url, {}, 1000);
+    const result = await fetchWithCache(url, {}, 1000);
 
     expect(mockedFetch).toHaveBeenCalledTimes(0);
     expect(result).toEqual({ cached: true, data: response });
@@ -68,7 +68,7 @@ describe('fetchJsonWithCache', () => {
       json: () => Promise.resolve(response),
     } as Response);
 
-    const result = await fetchJsonWithCache(url, {}, 1000);
+    const result = await fetchWithCache(url, {}, 1000);
 
     expect(mockedFetch).toHaveBeenCalledTimes(1);
     expect(result).toEqual({ cached: false, data: response });
@@ -87,7 +87,7 @@ describe('fetchJsonWithCache', () => {
       json: () => Promise.resolve(response),
     } as Response);
 
-    const result = await fetchJsonWithCache(url, {}, 1000);
+    const result = await fetchWithCache(url, {}, 1000);
 
     expect(mockedFetch).toHaveBeenCalledTimes(1);
     expect(result).toEqual({ cached: false, data: response });
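Note: the test file also imports disableCache and enableCache; a sketch of how those toggles might be used around a single call (placeholder URL, not from this commit):

import { fetchWithCache, disableCache, enableCache } from '../src/cache.js';

async function fetchUncached() {
  disableCache();
  try {
    // With the cache disabled, fetchWithCache falls through to fetchWithRetries
    // and reports cached: false.
    return await fetchWithCache('https://example.com/data', {}, 1000); // placeholder URL
  } finally {
    enableCache();
  }
}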
@@ -1,5 +1,5 @@
 import fetch from 'node-fetch';
-import { fetchJsonWithCache } from '../src/cache';
+import { fetchWithCache } from '../src/cache';
 
 import { OpenAiCompletionProvider, OpenAiChatCompletionProvider } from '../src/providers/openai';
 import { AnthropicCompletionProvider } from '../src/providers/anthropic';
@@ -12,6 +12,7 @@ import {
   AzureOpenAiChatCompletionProvider,
   AzureOpenAiCompletionProvider,
 } from '../src/providers/azureopenai';
+import { OllamaProvider } from '../src/providers/ollama';
 
 jest.mock('node-fetch', () => jest.fn());
 
@@ -167,6 +168,28 @@ describe('providers', () => {
     expect(result.output).toBe('Test output');
   });
 
+  test('OllamaProvider callApi', async () => {
+    const mockResponse = {
+      text: jest.fn()
+        .mockResolvedValue(`{"model":"llama2:13b","created_at":"2023-08-08T21:50:34.898068Z","response":"Gre","done":false}
+{"model":"llama2:13b","created_at":"2023-08-08T21:50:34.929199Z","response":"at","done":false}
+{"model":"llama2:13b","created_at":"2023-08-08T21:50:34.959989Z","response":" question","done":false}
+{"model":"llama2:13b","created_at":"2023-08-08T21:50:34.992117Z","response":"!","done":false}
+{"model":"llama2:13b","created_at":"2023-08-08T21:50:35.023658Z","response":" The","done":false}
+{"model":"llama2:13b","created_at":"2023-08-08T21:50:35.0551Z","response":" sky","done":false}
+{"model":"llama2:13b","created_at":"2023-08-08T21:50:35.086103Z","response":" appears","done":false}
+{"model":"llama2:13b","created_at":"2023-08-08T21:50:35.117166Z","response":" blue","done":false}
+{"model":"llama2:13b","created_at":"2023-08-08T21:50:41.695299Z","done":true,"context":[1,29871,1,13,9314],"total_duration":10411943458,"load_duration":458333,"sample_count":217,"sample_duration":154566000,"prompt_eval_count":11,"prompt_eval_duration":3334582000,"eval_count":216,"eval_duration":6905134000}`),
+    };
+    (fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);
+
+    const provider = new OllamaProvider('llama');
+    const result = await provider.callApi('Test prompt');
+
+    expect(fetch).toHaveBeenCalledTimes(1);
+    expect(result.output).toBe('Great question! The sky appears blue');
+  });
+
   test('loadApiProvider with openai:chat', async () => {
     const provider = await loadApiProvider('openai:chat');
     expect(provider).toBeInstanceOf(OpenAiChatCompletionProvider);