Add ollama provider (#102)

Ian Webster
2023-08-08 16:30:55 -07:00
committed by GitHub
parent a603eeea6a
commit 4e6b580f7b
9 changed files with 139 additions and 23 deletions

View File

@@ -42,10 +42,11 @@ export function getCache() {
return cacheInstance;
}
export async function fetchJsonWithCache(
export async function fetchWithCache(
url: RequestInfo,
options: RequestInit = {},
timeout: number,
format: 'json' | 'text' = 'json',
): Promise<{ data: any; cached: boolean }> {
if (!enabled) {
const resp = await fetchWithRetries(url, options, timeout);
@@ -75,7 +76,7 @@ export async function fetchJsonWithCache(
// Fetch the actual data and store it in the cache
const response = await fetchWithRetries(url, options, timeout);
try {
const data = await response.json();
const data = format === 'json' ? await response.json() : await response.text();
if (response.ok) {
logger.debug(`Storing ${url} response in cache: ${JSON.stringify(data)}`);
await cache.set(cacheKey, JSON.stringify(data));
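Not part of this commit, just a sketch of how callers use the generalized fetchWithCache. The first URL and the timeout constant are made up for illustration; the Ollama endpoint and request body mirror the new provider further down.

import { fetchWithCache } from './cache'; // adjust the relative path to the caller's location

const TIMEOUT_MS = 60_000; // stand-in for REQUEST_TIMEOUT_MS from providers/shared

// Default format ('json'): the body is parsed with response.json(), as before the rename.
const { data, cached } = await fetchWithCache(
  'https://api.example.com/v1/items', // hypothetical endpoint
  { method: 'GET' },
  TIMEOUT_MS,
);

// New 'text' format: the raw body is returned via response.text(). The Ollama
// provider uses this to receive newline-delimited JSON that it parses itself.
const { data: raw } = await fetchWithCache(
  'http://localhost:11434/api/generate',
  {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ model: 'llama2', prompt: 'Why is the sky blue?' }),
  },
  TIMEOUT_MS,
  'text',
);

console.log(cached, typeof data, typeof raw); // data is parsed JSON, raw is a string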

View File

@@ -5,6 +5,7 @@ import { AnthropicCompletionProvider } from './providers/anthropic';
import { ReplicateProvider } from './providers/replicate';
import { LocalAiCompletionProvider, LocalAiChatProvider } from './providers/localai';
import { LlamaProvider } from './providers/llama';
import { OllamaProvider } from './providers/ollama';
import { ScriptCompletionProvider } from './providers/scriptCompletion';
import {
AzureOpenAiChatCompletionProvider,
@@ -148,6 +149,9 @@ export async function loadApiProvider(
if (providerPath === 'llama' || providerPath.startsWith('llama:')) {
const modelName = providerPath.split(':')[1];
return new LlamaProvider(modelName, context?.config);
} else if (providerPath.startsWith('ollama:')) {
const modelName = providerPath.split(':')[1];
return new OllamaProvider(modelName);
} else if (providerPath?.startsWith('localai:')) {
const options = providerPath.split(':');
const modelType = options[1];
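For context (not part of the diff): with this branch in place, any provider id of the form ollama:<model> resolves to an OllamaProvider, so it can be listed like any other provider id in an eval config (e.g. ollama:llama2) or loaded programmatically. A small sketch, assuming the file shown above is src/providers.ts and llama2 is just an example model name:

import { loadApiProvider } from './providers'; // path assumes src/providers.ts

// 'ollama:llama2' → new OllamaProvider('llama2'); the model name is taken
// from the segment after 'ollama:' (providerPath.split(':')[1]).
const provider = await loadApiProvider('ollama:llama2');
console.log(provider.id()); // "ollama:llama2"

const result = await provider.callApi('Why is the sky blue?');
console.log(result.output ?? result.error);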

View File

@@ -1,5 +1,5 @@
import logger from '../logger';
import { fetchJsonWithCache } from '../cache';
import { fetchWithCache } from '../cache';
import { REQUEST_TIMEOUT_MS, parseChatPrompt } from './shared';
import type { ApiProvider, ProviderEmbeddingResponse, ProviderResponse } from '../types.js';
@@ -61,7 +61,7 @@ export class AzureOpenAiEmbeddingProvider extends AzureOpenAiGenericProvider {
let data,
cached = false;
try {
({ data, cached } = (await fetchJsonWithCache(
({ data, cached } = (await fetchWithCache(
`https://${this.apiHost}/openai/deployments/${this.deploymentName}/embeddings?api-version=2023-07-01-preview`,
{
method: 'POST',
@@ -171,7 +171,7 @@ export class AzureOpenAiCompletionProvider extends AzureOpenAiGenericProvider {
let data,
cached = false;
try {
({ data, cached } = (await fetchJsonWithCache(
({ data, cached } = (await fetchWithCache(
`https://${this.apiHost}/openai/deployments/${this.deploymentName}/completions?api-version=2023-07-01-preview`,
{
method: 'POST',
@@ -258,7 +258,7 @@ export class AzureOpenAiChatCompletionProvider extends AzureOpenAiGenericProvide
let data,
cached = false;
try {
({ data, cached } = (await fetchJsonWithCache(
({ data, cached } = (await fetchWithCache(
`https://${this.apiHost}/openai/deployments/${this.deploymentName}/chat/completions?api-version=2023-07-01-preview`,
{
method: 'POST',

View File

@@ -1,4 +1,4 @@
import { fetchJsonWithCache } from '../cache';
import { fetchWithCache } from '../cache';
import { REQUEST_TIMEOUT_MS } from './shared';
import type { ApiProvider, ProviderResponse } from '../types.js';
@@ -65,7 +65,7 @@ export class LlamaProvider implements ApiProvider {
let response;
try {
response = await fetchJsonWithCache(
response = await fetchWithCache(
`${process.env.LLAMA_BASE_URL || 'http://localhost:8080'}/completion`,
{
method: 'POST',

View File

@@ -1,5 +1,5 @@
import logger from '../logger';
import { fetchJsonWithCache } from '../cache';
import { fetchWithCache } from '../cache';
import { REQUEST_TIMEOUT_MS, parseChatPrompt } from './shared';
import type { ApiProvider, ProviderResponse } from '../types.js';
@@ -40,7 +40,7 @@ export class LocalAiChatProvider extends LocalAiGenericProvider {
let data,
cached = false;
try {
({ data, cached } = (await fetchJsonWithCache(
({ data, cached } = (await fetchWithCache(
`${this.apiBaseUrl}/chat/completions`,
{
method: 'POST',
@@ -81,7 +81,7 @@ export class LocalAiCompletionProvider extends LocalAiGenericProvider {
let data,
cached = false;
try {
({ data, cached } = (await fetchJsonWithCache(
({ data, cached } = (await fetchWithCache(
`${this.apiBaseUrl}/completions`,
{
method: 'POST',

src/providers/ollama.ts (new file, 88 lines)
View File

@@ -0,0 +1,88 @@
import logger from '../logger';
import { fetchWithCache } from '../cache';
import type { ApiProvider, ProviderResponse } from '../types.js';
import { REQUEST_TIMEOUT_MS } from './shared';
interface OllamaJsonL {
model: string;
created_at: string;
response?: string;
done: boolean;
context?: number[];
total_duration?: number;
load_duration?: number;
sample_count?: number;
sample_duration?: number;
prompt_eval_count?: number;
prompt_eval_duration?: number;
eval_count?: number;
eval_duration?: number;
}
export class OllamaProvider implements ApiProvider {
modelName: string;
constructor(modelName: string) {
this.modelName = modelName;
}
id(): string {
return `ollama:${this.modelName}`;
}
toString(): string {
return `[Ollama Provider ${this.modelName}]`;
}
async callApi(prompt: string): Promise<ProviderResponse> {
const params = {
model: this.modelName,
prompt,
};
logger.debug(`Calling Ollama API: ${JSON.stringify(params)}`);
let response;
try {
response = await fetchWithCache(
`${process.env.OLLAMA_BASE_URL || 'http://localhost:11434'}/api/generate`,
{
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(params),
},
REQUEST_TIMEOUT_MS,
'text',
);
} catch (err) {
return {
error: `API call error: ${String(err)}`,
};
}
logger.debug(`\tOllama API response: ${response.data}`);
try {
const output = response.data
.split('\n')
.map((line: string) => {
const parsed = JSON.parse(line) as OllamaJsonL;
if (parsed.response) {
return parsed.response;
}
return null;
})
.filter((s: string | null) => s !== null)
.join('');
return {
output,
};
} catch (err) {
return {
error: `API response error: ${String(err)}: ${JSON.stringify(response.data)}`,
};
}
}
}
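The provider POSTs the prompt to Ollama's /api/generate endpoint (OLLAMA_BASE_URL, defaulting to http://localhost:11434), fetches the body as text, and joins the response fragments of the newline-delimited JSON stream into a single output string. A usage sketch, not part of the commit, assuming a local Ollama server with a llama2 model pulled:

import { OllamaProvider } from './providers/ollama';

// Optional: point at a non-default Ollama server.
// process.env.OLLAMA_BASE_URL = 'http://my-ollama-host:11434';

const provider = new OllamaProvider('llama2');
console.log(provider.id()); // "ollama:llama2"

const { output, error } = await provider.callApi('Why is the sky blue?');
if (error) {
  console.error(error); // network failures and unparsable lines surface here
} else {
  console.log(output);  // fragments joined, e.g. "Great question! The sky appears blue..."
}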

View File

@@ -1,5 +1,5 @@
import logger from '../logger';
import { fetchJsonWithCache } from '../cache';
import { fetchWithCache } from '../cache';
import { REQUEST_TIMEOUT_MS, parseChatPrompt } from './shared';
import type { ApiProvider, ProviderEmbeddingResponse, ProviderResponse } from '../types.js';
@@ -61,7 +61,7 @@ export class OpenAiEmbeddingProvider extends OpenAiGenericProvider {
let data,
cached = false;
try {
({ data, cached } = (await fetchJsonWithCache(
({ data, cached } = (await fetchWithCache(
`https://${this.apiHost}/v1/embeddings`,
{
method: 'POST',
@@ -177,7 +177,7 @@ export class OpenAiCompletionProvider extends OpenAiGenericProvider {
let data,
cached = false;
try {
({ data, cached } = (await fetchJsonWithCache(
({ data, cached } = (await fetchWithCache(
`https://${this.apiHost}/v1/completions`,
{
method: 'POST',
@@ -275,7 +275,7 @@ export class OpenAiChatCompletionProvider extends OpenAiGenericProvider {
let data,
cached = false;
try {
({ data, cached } = (await fetchJsonWithCache(
({ data, cached } = (await fetchWithCache(
`https://${this.apiHost}/v1/chat/completions`,
{
method: 'POST',

View File

@@ -1,10 +1,10 @@
import { fetchJsonWithCache, disableCache, enableCache } from '../src/cache.js';
import { fetchWithCache, disableCache, enableCache } from '../src/cache.js';
import fetch, { Response } from 'node-fetch';
jest.mock('node-fetch');
const mockedFetch = fetch as jest.MockedFunction<typeof fetch>;
describe('fetchJsonWithCache', () => {
describe('fetchWithCache', () => {
afterEach(() => {
mockedFetch.mockReset();
});
@@ -20,7 +20,7 @@ describe('fetchJsonWithCache', () => {
json: () => Promise.resolve(response),
} as Response);
const result = await fetchJsonWithCache(url, {}, 1000);
const result = await fetchWithCache(url, {}, 1000);
expect(mockedFetch).toHaveBeenCalledTimes(1);
expect(result).toEqual({ cached: false, data: response });
@@ -37,7 +37,7 @@ describe('fetchJsonWithCache', () => {
json: () => Promise.resolve(response),
} as Response);
const result = await fetchJsonWithCache(url, {}, 1000);
const result = await fetchWithCache(url, {}, 1000);
expect(mockedFetch).toHaveBeenCalledTimes(1);
expect(result).toEqual({ cached: false, data: response });
@@ -52,7 +52,7 @@ describe('fetchJsonWithCache', () => {
json: () => Promise.resolve(response),
} as Response);
const result = await fetchJsonWithCache(url, {}, 1000);
const result = await fetchWithCache(url, {}, 1000);
expect(mockedFetch).toHaveBeenCalledTimes(0);
expect(result).toEqual({ cached: true, data: response });
@@ -68,7 +68,7 @@ describe('fetchJsonWithCache', () => {
json: () => Promise.resolve(response),
} as Response);
const result = await fetchJsonWithCache(url, {}, 1000);
const result = await fetchWithCache(url, {}, 1000);
expect(mockedFetch).toHaveBeenCalledTimes(1);
expect(result).toEqual({ cached: false, data: response });
@@ -87,7 +87,7 @@ describe('fetchJsonWithCache', () => {
json: () => Promise.resolve(response),
} as Response);
const result = await fetchJsonWithCache(url, {}, 1000);
const result = await fetchWithCache(url, {}, 1000);
expect(mockedFetch).toHaveBeenCalledTimes(1);
expect(result).toEqual({ cached: false, data: response });
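The updated tests above only exercise the default 'json' path. Not part of this commit: a sketch of how the new 'text' branch could be covered with the same mocking setup (the URL and body are made up):

test('fetches fresh text data when format is "text"', async () => {
  enableCache();

  const url = 'https://api.example.com/raw'; // hypothetical endpoint
  const response = '{"done":false}\n{"done":true}'; // returned verbatim, not parsed

  mockedFetch.mockResolvedValueOnce({
    ok: true,
    text: () => Promise.resolve(response),
  } as Response);

  const result = await fetchWithCache(url, {}, 1000, 'text');

  expect(mockedFetch).toHaveBeenCalledTimes(1);
  expect(result).toEqual({ cached: false, data: response });
});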

View File

@@ -1,5 +1,5 @@
import fetch from 'node-fetch';
import { fetchJsonWithCache } from '../src/cache';
import { fetchWithCache } from '../src/cache';
import { OpenAiCompletionProvider, OpenAiChatCompletionProvider } from '../src/providers/openai';
import { AnthropicCompletionProvider } from '../src/providers/anthropic';
@@ -12,6 +12,7 @@ import {
AzureOpenAiChatCompletionProvider,
AzureOpenAiCompletionProvider,
} from '../src/providers/azureopenai';
import { OllamaProvider } from '../src/providers/ollama';
jest.mock('node-fetch', () => jest.fn());
@@ -167,6 +168,28 @@ describe('providers', () => {
expect(result.output).toBe('Test output');
});
test('OllamaProvider callApi', async () => {
const mockResponse = {
text: jest.fn()
.mockResolvedValue(`{"model":"llama2:13b","created_at":"2023-08-08T21:50:34.898068Z","response":"Gre","done":false}
{"model":"llama2:13b","created_at":"2023-08-08T21:50:34.929199Z","response":"at","done":false}
{"model":"llama2:13b","created_at":"2023-08-08T21:50:34.959989Z","response":" question","done":false}
{"model":"llama2:13b","created_at":"2023-08-08T21:50:34.992117Z","response":"!","done":false}
{"model":"llama2:13b","created_at":"2023-08-08T21:50:35.023658Z","response":" The","done":false}
{"model":"llama2:13b","created_at":"2023-08-08T21:50:35.0551Z","response":" sky","done":false}
{"model":"llama2:13b","created_at":"2023-08-08T21:50:35.086103Z","response":" appears","done":false}
{"model":"llama2:13b","created_at":"2023-08-08T21:50:35.117166Z","response":" blue","done":false}
{"model":"llama2:13b","created_at":"2023-08-08T21:50:41.695299Z","done":true,"context":[1,29871,1,13,9314],"total_duration":10411943458,"load_duration":458333,"sample_count":217,"sample_duration":154566000,"prompt_eval_count":11,"prompt_eval_duration":3334582000,"eval_count":216,"eval_duration":6905134000}`),
};
(fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);
const provider = new OllamaProvider('llama');
const result = await provider.callApi('Test prompt');
expect(fetch).toHaveBeenCalledTimes(1);
expect(result.output).toBe('Great question! The sky appears blue');
});
test('loadApiProvider with openai:chat', async () => {
const provider = await loadApiProvider('openai:chat');
expect(provider).toBeInstanceOf(OpenAiChatCompletionProvider);