Mirror of https://github.com/promptfoo/promptfoo.git (synced 2023-08-15 01:10:51 +03:00)
Add support for llama.cpp server (#94)
.gitignore (vendored): 1 line added

@@ -4,3 +4,4 @@ scratch
.DS_Store
node_modules
dist
.aider*
src/providers.ts

@@ -4,6 +4,7 @@ import { OpenAiCompletionProvider, OpenAiChatCompletionProvider } from './provid
import { AnthropicCompletionProvider } from './providers/anthropic';
import { ReplicateProvider } from './providers/replicate';
import { LocalAiCompletionProvider, LocalAiChatProvider } from './providers/localai';
import { LlamaProvider } from './providers/llama';
import { ScriptCompletionProvider } from './providers/scriptCompletion';
import {
  AzureOpenAiChatCompletionProvider,

@@ -133,7 +134,10 @@ export async function loadApiProvider(
    return new ReplicateProvider(modelName, undefined, context?.config);
  }

-  if (providerPath?.startsWith('localai:')) {
+  if (providerPath === 'llama' || providerPath.startsWith('llama:')) {
+    const modelName = providerPath.split(':')[1];
+    return new LlamaProvider(modelName, context?.config);
+  } else if (providerPath?.startsWith('localai:')) {
    const options = providerPath.split(':');
    const modelType = options[1];
    const modelName = options[2];
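For context, and not part of the diff: the resolution logic above means a provider id of 'llama' or 'llama:<label>' now yields a LlamaProvider. A rough usage sketch follows; the model label is made up, the import path assumes the caller sits in src/, and a llama.cpp server is assumed to already be running on http://localhost:8080.

import { loadApiProvider } from './providers';

async function demo() {
  // 'llama' or 'llama:<label>' resolves to LlamaProvider, which POSTs prompts to a
  // local llama.cpp server's /completion endpoint ('ggml-model' is a placeholder label).
  const provider = await loadApiProvider('llama:ggml-model');
  const result = await provider.callApi('Write a haiku about test coverage.');
  console.log(result.output); // generated text, or inspect result.error on failure
}

demo().catch(console.error);

The segment after the colon only affects the provider's id() and toString(); requests always go to the hardcoded local base URL.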
src/providers/llama.ts (new file): 97 lines

@@ -0,0 +1,97 @@
import { fetchJsonWithCache } from '../cache';
import { REQUEST_TIMEOUT_MS } from './shared';

import type { ApiProvider, ProviderResponse } from '../types.js';

interface LlamaCompletionOptions {
  n_predict?: number;
  temperature?: number;
  top_k?: number;
  top_p?: number;
  n_keep?: number;
  stop?: string[];
  repeat_penalty?: number;
  repeat_last_n?: number;
  penalize_nl?: boolean;
  presence_penalty?: number;
  frequency_penalty?: number;
  mirostat?: boolean;
  mirostat_tau?: number;
  mirostat_eta?: number;
  seed?: number;
  ignore_eos?: boolean;
  logit_bias?: Record<string, number>;
}

export class LlamaProvider implements ApiProvider {
  modelName: string;
  apiBaseUrl: string;
  options?: LlamaCompletionOptions;

  constructor(modelName: string, options?: LlamaCompletionOptions) {
    this.modelName = modelName;
    this.apiBaseUrl = 'http://localhost:8080';
    this.options = options;
  }

  id(): string {
    return `llama:${this.modelName}`;
  }

  toString(): string {
    return `[Llama Provider ${this.modelName}]`;
  }

  async callApi(prompt: string, options?: LlamaCompletionOptions): Promise<ProviderResponse> {
    options = Object.assign({}, this.options, options);
    const body = {
      prompt,
      n_predict: options?.n_predict || 512,
      temperature: options?.temperature,
      top_k: options?.top_k,
      top_p: options?.top_p,
      n_keep: options?.n_keep,
      stop: options?.stop,
      repeat_penalty: options?.repeat_penalty,
      repeat_last_n: options?.repeat_last_n,
      penalize_nl: options?.penalize_nl,
      presence_penalty: options?.presence_penalty,
      frequency_penalty: options?.frequency_penalty,
      mirostat: options?.mirostat,
      mirostat_tau: options?.mirostat_tau,
      mirostat_eta: options?.mirostat_eta,
      seed: options?.seed,
      ignore_eos: options?.ignore_eos,
      logit_bias: options?.logit_bias,
    };

    let response;
    try {
      response = await fetchJsonWithCache(
        `${this.apiBaseUrl}/completion`,
        {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json',
          },
          body: JSON.stringify(body),
        },
        REQUEST_TIMEOUT_MS,
      );
    } catch (err) {
      return {
        error: `API call error: ${String(err)}`,
      };
    }

    try {
      return {
        output: response.data.content,
      };
    } catch (err) {
      return {
        error: `API response error: ${String(err)}: ${JSON.stringify(response.data)}`,
      };
    }
  }
}
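As an illustration of the options surface above, and again not part of the commit: the provider can also be constructed directly with any subset of LlamaCompletionOptions. A minimal sketch, using a hypothetical model label, an import path assumed relative to src/, and a llama.cpp server listening on the hardcoded http://localhost:8080:

import { LlamaProvider } from './providers/llama';

async function sketch() {
  // 'llama-2-7b' is only a display label; the options are forwarded to the
  // llama.cpp /completion endpoint along with the prompt.
  const provider = new LlamaProvider('llama-2-7b', {
    temperature: 0.7,
    top_p: 0.9,
    stop: ['###'],
  });

  const result = await provider.callApi('Q: What does promptfoo do?\nA:');
  if (result.error) {
    console.error(result.error); // fetch/timeout failures surface here
  } else {
    console.log(result.output); // the `content` field of the server's JSON response
  }
}

sketch().catch(console.error);

Only n_predict has a default (512); every other field is passed through as given, so omitted options fall back to the llama.cpp server's own defaults.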
test/providers.test.ts

@@ -1,7 +1,9 @@
import fetch from 'node-fetch';
import { fetchJsonWithCache } from '../src/cache';

import { OpenAiCompletionProvider, OpenAiChatCompletionProvider } from '../src/providers/openai';
import { AnthropicCompletionProvider } from '../src/providers/anthropic';
import { LlamaProvider } from '../src/providers/llama';

import { disableCache, enableCache } from '../src/cache.js';
import { loadApiProvider, loadApiProviders } from '../src/providers.js';

@@ -150,6 +152,21 @@ describe('providers', () => {
    expect(result.tokenUsage).toEqual({});
  });

  test('LlamaProvider callApi', async () => {
    const mockResponse = {
      json: jest.fn().mockResolvedValue({
        content: 'Test output',
      }),
    };
    (fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);

    const provider = new LlamaProvider('llama.cpp');
    const result = await provider.callApi('Test prompt');

    expect(fetch).toHaveBeenCalledTimes(1);
    expect(result.output).toBe('Test output');
  });

  test('loadApiProvider with openai:chat', async () => {
    const provider = await loadApiProvider('openai:chat');
    expect(provider).toBeInstanceOf(OpenAiChatCompletionProvider);

@@ -190,6 +207,11 @@ describe('providers', () => {
    expect(provider).toBeInstanceOf(AnthropicCompletionProvider);
  });

  test('loadApiProvider with llama:modelName', async () => {
    const provider = await loadApiProvider('llama');
    expect(provider).toBeInstanceOf(LlamaProvider);
  });

  test('loadApiProvider with RawProviderConfig', async () => {
    const rawProviderConfig = {
      'openai:chat': {