mirror of
https://github.com/promptfoo/promptfoo.git
synced 2023-08-15 01:10:51 +03:00
Add Azure OpenAI Provider (#66)
* update local command azure providers other naming updates Add azure providers * unit tests
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
"name": "promptfoo",
|
||||
"description": "LLM eval & testing toolkit",
|
||||
"author": "Ian Webster",
|
||||
"version": "0.17.4",
|
||||
"version": "0.17.5",
|
||||
"license": "MIT",
|
||||
"type": "commonjs",
|
||||
"main": "dist/src/index.js",
|
||||
@@ -25,7 +25,7 @@
|
||||
"promptfoo": "dist/src/main.js"
|
||||
},
|
||||
"scripts": {
|
||||
"local": "ts-node --esm src/main.ts",
|
||||
"local": "ts-node --esm --files src/main.ts",
|
||||
"install:client": "cd src/web/client && npm install",
|
||||
"build:clean": "rm -rf dist",
|
||||
"build:client": "cd src/web/client && npm run build && cp -r dist/ ../../../dist/src/web/client",
|
||||
|
||||
@@ -6,6 +6,7 @@ import { OpenAiCompletionProvider, OpenAiChatCompletionProvider } from './provid
|
||||
import { AnthropicCompletionProvider } from './providers/anthropic';
|
||||
import { LocalAiCompletionProvider, LocalAiChatProvider } from './providers/localai';
|
||||
import { ScriptCompletionProvider } from './providers/scriptCompletion';
|
||||
import { AzureOpenAiChatCompletionProvider, AzureOpenAiCompletionProvider } from './providers/azureopenai'
|
||||
|
||||
export async function loadApiProviders(
|
||||
providerPaths: ProviderId | ProviderId[] | RawProviderConfig[],
|
||||
@@ -68,6 +69,29 @@ export async function loadApiProvider(
|
||||
`Unknown OpenAI model type: ${modelType}. Use one of the following providers: openai:chat:<model name>, openai:completion:<model name>`,
|
||||
);
|
||||
}
|
||||
} else if (providerPath?.startsWith('azureopenai:')) {
|
||||
// Load Azure OpenAI module
|
||||
const options = providerPath.split(':');
|
||||
const modelType = options[1];
|
||||
const deploymentName = options[2];
|
||||
|
||||
if (modelType === 'chat') {
|
||||
return new AzureOpenAiChatCompletionProvider(
|
||||
deploymentName,
|
||||
undefined,
|
||||
context?.config,
|
||||
);
|
||||
} else if (modelType === 'completion') {
|
||||
return new AzureOpenAiCompletionProvider(
|
||||
deploymentName,
|
||||
undefined,
|
||||
context?.config,
|
||||
);
|
||||
} else {
|
||||
throw new Error(
|
||||
`Unknown Azure OpenAI model type: ${modelType}. Use one of the following providers: openai:chat:<model name>, openai:completion:<model name>`,
|
||||
);
|
||||
}
|
||||
} else if (providerPath?.startsWith('anthropic:')) {
|
||||
// Load Anthropic module
|
||||
const options = providerPath.split(':');
|
||||
|
||||
275
src/providers/azureopenai.ts
Normal file
275
src/providers/azureopenai.ts
Normal file
@@ -0,0 +1,275 @@
|
||||
import logger from '../logger';
|
||||
import { fetchJsonWithCache } from '../cache';
|
||||
import { REQUEST_TIMEOUT_MS } from './shared';
|
||||
|
||||
import type { ApiProvider, ProviderEmbeddingResponse, ProviderResponse } from '../types.js';
|
||||
|
||||
interface AzureOpenAiCompletionOptions {
|
||||
temperature?: number;
|
||||
functions?: {
|
||||
name: string;
|
||||
description?: string;
|
||||
parameters: any;
|
||||
}[];
|
||||
function_call?: 'none' | 'auto';
|
||||
}
|
||||
|
||||
class AzureOpenAiGenericProvider implements ApiProvider {
|
||||
deploymentName: string;
|
||||
apiKey?: string;
|
||||
apiHost: string;
|
||||
|
||||
constructor(deploymentName: string, apiKey?: string) {
|
||||
this.deploymentName = deploymentName;
|
||||
|
||||
this.apiKey = apiKey || process.env.AZURE_OPENAI_API_KEY;
|
||||
|
||||
if (!process.env.AZURE_OPENAI_API_HOST) {
|
||||
throw new Error('Azure OpenAI API host must be set');
|
||||
}
|
||||
|
||||
this.apiHost = process.env.AZURE_OPENAI_API_HOST;
|
||||
}
|
||||
|
||||
id(): string {
|
||||
return `azureopenai:${this.deploymentName}`;
|
||||
}
|
||||
|
||||
toString(): string {
|
||||
return `[Azure OpenAI Provider ${this.deploymentName}]`;
|
||||
}
|
||||
|
||||
// @ts-ignore: Prompt is not used in this implementation
|
||||
async callApi(prompt: string, options?: AzureOpenAiCompletionOptions): Promise<ProviderResponse> {
|
||||
throw new Error('Not implemented');
|
||||
}
|
||||
}
|
||||
|
||||
export class AzureOpenAiEmbeddingProvider extends AzureOpenAiGenericProvider {
|
||||
async callEmbeddingApi(text: string): Promise<ProviderEmbeddingResponse> {
|
||||
if (!this.apiKey) {
|
||||
throw new Error('Azure OpenAI API key must be set for similarity comparison');
|
||||
}
|
||||
|
||||
const body = {
|
||||
input: text,
|
||||
model: this.deploymentName,
|
||||
};
|
||||
let data,
|
||||
cached = false;
|
||||
try {
|
||||
({ data, cached } = (await fetchJsonWithCache(
|
||||
`https://${this.apiHost}/openai/deployments/${this.deploymentName}/embeddings?api-version=2023-07-01-preview`,
|
||||
{
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
"api-key": this.apiKey,
|
||||
},
|
||||
body: JSON.stringify(body),
|
||||
},
|
||||
REQUEST_TIMEOUT_MS,
|
||||
)) as unknown as any);
|
||||
} catch (err) {
|
||||
return {
|
||||
error: `API call error: ${String(err)}`,
|
||||
tokenUsage: {
|
||||
total: 0,
|
||||
prompt: 0,
|
||||
completion: 0,
|
||||
},
|
||||
};
|
||||
}
|
||||
logger.debug(`\tAzure OpenAI API response (embeddings): ${JSON.stringify(data)}`);
|
||||
|
||||
try {
|
||||
const embedding = data?.data?.[0]?.embedding;
|
||||
if (!embedding) {
|
||||
throw new Error('No embedding returned');
|
||||
}
|
||||
const ret = {
|
||||
embedding,
|
||||
tokenUsage: cached
|
||||
? { cached: data.usage.total_tokens }
|
||||
: {
|
||||
total: data.usage.total_tokens,
|
||||
prompt: data.usage.prompt_tokens,
|
||||
completion: data.usage.completion_tokens,
|
||||
},
|
||||
};
|
||||
return ret;
|
||||
} catch (err) {
|
||||
return {
|
||||
error: `API response error: ${String(err)}: ${JSON.stringify(data)}`,
|
||||
tokenUsage: {
|
||||
total: data?.usage?.total_tokens,
|
||||
prompt: data?.usage?.prompt_tokens,
|
||||
completion: data?.usage?.completion_tokens,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class AzureOpenAiCompletionProvider extends AzureOpenAiGenericProvider {
|
||||
options: AzureOpenAiCompletionOptions;
|
||||
|
||||
constructor(deploymentName: string, apiKey?: string, context?: AzureOpenAiCompletionOptions) {
|
||||
super(deploymentName, apiKey);
|
||||
this.options = context || {};
|
||||
}
|
||||
|
||||
async callApi(prompt: string, options?: AzureOpenAiCompletionOptions): Promise<ProviderResponse> {
|
||||
if (!this.apiKey) {
|
||||
throw new Error(
|
||||
'Azure OpenAI API key is not set. Set AZURE_OPENAI_API_KEY environment variable or pass it as an argument to the constructor.',
|
||||
);
|
||||
}
|
||||
|
||||
let stop: string;
|
||||
try {
|
||||
stop = process.env.OPENAI_STOP
|
||||
? JSON.parse(process.env.OPENAI_STOP)
|
||||
: ['<|im_end|>', '<|endoftext|>'];
|
||||
} catch (err) {
|
||||
throw new Error(`OPENAI_STOP is not a valid JSON string: ${err}`);
|
||||
}
|
||||
const body = {
|
||||
model: this.deploymentName,
|
||||
prompt,
|
||||
max_tokens: parseInt(process.env.OPENAI_MAX_TOKENS || '1024'),
|
||||
temperature:
|
||||
options?.temperature ??
|
||||
this.options.temperature ??
|
||||
parseFloat(process.env.OPENAI_TEMPERATURE || '0'),
|
||||
stop,
|
||||
};
|
||||
logger.debug(`Calling Azure OpenAI API: ${JSON.stringify(body)}`);
|
||||
let data,
|
||||
cached = false;
|
||||
try {
|
||||
({ data, cached } = (await fetchJsonWithCache(
|
||||
`https://${this.apiHost}/openai/deployments/${this.deploymentName}/completions?api-version=2023-07-01-preview`,
|
||||
{
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
"api-key": this.apiKey,
|
||||
},
|
||||
body: JSON.stringify(body),
|
||||
},
|
||||
REQUEST_TIMEOUT_MS,
|
||||
)) as unknown as any);
|
||||
} catch (err) {
|
||||
return {
|
||||
error: `API call error: ${String(err)}`,
|
||||
};
|
||||
}
|
||||
logger.debug(`\tAzure OpenAI API response: ${JSON.stringify(data)}`);
|
||||
try {
|
||||
return {
|
||||
output: data.choices[0].text,
|
||||
tokenUsage: cached
|
||||
? { cached: data.usage.total_tokens }
|
||||
: {
|
||||
total: data.usage.total_tokens,
|
||||
prompt: data.usage.prompt_tokens,
|
||||
completion: data.usage.completion_tokens,
|
||||
},
|
||||
};
|
||||
} catch (err) {
|
||||
return {
|
||||
error: `API response error: ${String(err)}: ${JSON.stringify(data)}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class AzureOpenAiChatCompletionProvider extends AzureOpenAiGenericProvider {
|
||||
options: AzureOpenAiCompletionOptions;
|
||||
|
||||
constructor(deploymentName: string, apiKey?: string, context?: AzureOpenAiCompletionOptions) {
|
||||
super(deploymentName, apiKey);
|
||||
this.options = context || {};
|
||||
}
|
||||
|
||||
async callApi(prompt: string, options?: AzureOpenAiCompletionOptions): Promise<ProviderResponse> {
|
||||
if (!this.apiKey) {
|
||||
throw new Error(
|
||||
'Azure OpenAI API key is not set. Set AZURE_OPENAI_API_KEY environment variable or pass it as an argument to the constructor.',
|
||||
);
|
||||
}
|
||||
|
||||
let messages: { role: string; content: string; name?: string }[];
|
||||
try {
|
||||
messages = JSON.parse(prompt) as { role: string; content: string }[];
|
||||
} catch (err) {
|
||||
const trimmedPrompt = prompt.trim();
|
||||
if (
|
||||
process.env.PROMPTFOO_REQUIRE_JSON_PROMPTS ||
|
||||
trimmedPrompt.startsWith('{') ||
|
||||
trimmedPrompt.startsWith('[')
|
||||
) {
|
||||
throw new Error(
|
||||
`Azure OpenAI Chat Completion prompt is not a valid JSON string: ${err}\n\n${prompt}`,
|
||||
);
|
||||
}
|
||||
messages = [{ role: 'user', content: prompt }];
|
||||
}
|
||||
|
||||
const body = {
|
||||
model: this.deploymentName,
|
||||
messages: messages,
|
||||
max_tokens: parseInt(process.env.OPENAI_MAX_TOKENS || '1024'),
|
||||
temperature:
|
||||
options?.temperature ??
|
||||
this.options.temperature ??
|
||||
parseFloat(process.env.OPENAI_TEMPERATURE || '0'),
|
||||
functions: options?.functions || this.options.functions || undefined,
|
||||
function_call: options?.function_call || this.options.function_call || undefined,
|
||||
};
|
||||
logger.debug(`Calling Azure OpenAI API: ${JSON.stringify(body)}`);
|
||||
|
||||
let data,
|
||||
cached = false;
|
||||
try {
|
||||
({ data, cached } = (await fetchJsonWithCache(
|
||||
`https://${this.apiHost}/openai/deployments/${this.deploymentName}/chat/completions?api-version=2023-07-01-preview`,
|
||||
{
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
"api-key": this.apiKey,
|
||||
},
|
||||
body: JSON.stringify(body),
|
||||
},
|
||||
REQUEST_TIMEOUT_MS,
|
||||
)) as unknown as any);
|
||||
} catch (err) {
|
||||
return {
|
||||
error: `API call error: ${String(err)}`,
|
||||
};
|
||||
}
|
||||
|
||||
logger.debug(`\tAzure OpenAI API response: ${JSON.stringify(data)}`);
|
||||
try {
|
||||
const message = data.choices[0].message;
|
||||
const output =
|
||||
message.content === null ? JSON.stringify(message.function_call) : message.content;
|
||||
return {
|
||||
output,
|
||||
tokenUsage: cached
|
||||
? { cached: data.usage.total_tokens }
|
||||
: {
|
||||
total: data.usage.total_tokens,
|
||||
prompt: data.usage.prompt_tokens,
|
||||
completion: data.usage.completion_tokens,
|
||||
},
|
||||
};
|
||||
} catch (err) {
|
||||
return {
|
||||
error: `API response error: ${String(err)}: ${JSON.stringify(data)}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import { AnthropicCompletionProvider } from '../src/providers/anthropic';
|
||||
import { disableCache, enableCache } from '../src/cache.js';
|
||||
import { loadApiProvider, loadApiProviders } from '../src/providers.js';
|
||||
import type { RawProviderConfig } from '../src/types';
|
||||
import { AzureOpenAiChatCompletionProvider, AzureOpenAiCompletionProvider } from '../src/providers/azureopenai';
|
||||
|
||||
jest.mock('node-fetch', () => jest.fn());
|
||||
|
||||
@@ -75,6 +76,65 @@ describe('providers', () => {
|
||||
enableCache();
|
||||
});
|
||||
|
||||
test('AzureOpenAiCompletionProvider callApi', async () => {
|
||||
const mockResponse = {
|
||||
json: jest.fn().mockResolvedValue({
|
||||
choices: [{ text: 'Test output' }],
|
||||
usage: { total_tokens: 10, prompt_tokens: 5, completion_tokens: 5 },
|
||||
}),
|
||||
};
|
||||
(fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);
|
||||
|
||||
const provider = new AzureOpenAiCompletionProvider('text-davinci-003', 'test-api-key');
|
||||
const result = await provider.callApi('Test prompt');
|
||||
|
||||
expect(fetch).toHaveBeenCalledTimes(1);
|
||||
expect(result.output).toBe('Test output');
|
||||
expect(result.tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
|
||||
});
|
||||
|
||||
test('AzureOpenAiChatCompletionProvider callApi', async () => {
|
||||
const mockResponse = {
|
||||
json: jest.fn().mockResolvedValue({
|
||||
choices: [{ message: { content: 'Test output' } }],
|
||||
usage: { total_tokens: 10, prompt_tokens: 5, completion_tokens: 5 },
|
||||
}),
|
||||
};
|
||||
(fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);
|
||||
|
||||
const provider = new AzureOpenAiChatCompletionProvider('gpt-3.5-turbo', 'test-api-key');
|
||||
const result = await provider.callApi(
|
||||
JSON.stringify([{ role: 'user', content: 'Test prompt' }]),
|
||||
);
|
||||
|
||||
expect(fetch).toHaveBeenCalledTimes(1);
|
||||
expect(result.output).toBe('Test output');
|
||||
expect(result.tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
|
||||
});
|
||||
|
||||
test('AzureOpenAiChatCompletionProvider callApi with cache disabled', async () => {
|
||||
disableCache();
|
||||
|
||||
const mockResponse = {
|
||||
json: jest.fn().mockResolvedValue({
|
||||
choices: [{ message: { content: 'Test output' } }],
|
||||
usage: { total_tokens: 10, prompt_tokens: 5, completion_tokens: 5 },
|
||||
}),
|
||||
};
|
||||
(fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);
|
||||
|
||||
const provider = new AzureOpenAiChatCompletionProvider('gpt-3.5-turbo', 'test-api-key');
|
||||
const result = await provider.callApi(
|
||||
JSON.stringify([{ role: 'user', content: 'Test prompt' }]),
|
||||
);
|
||||
|
||||
expect(fetch).toHaveBeenCalledTimes(1);
|
||||
expect(result.output).toBe('Test output');
|
||||
expect(result.tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
|
||||
|
||||
enableCache();
|
||||
});
|
||||
|
||||
test('AnthropicCompletionProvider callApi', async () => {
|
||||
const provider = new AnthropicCompletionProvider('claude-1', 'test-api-key');
|
||||
provider.anthropic.completions.create = jest.fn().mockResolvedValue({
|
||||
@@ -107,6 +167,17 @@ describe('providers', () => {
|
||||
expect(provider).toBeInstanceOf(OpenAiCompletionProvider);
|
||||
});
|
||||
|
||||
test('loadApiProvider with azureopenai:completion:modelName', async () => {
|
||||
const provider = await loadApiProvider('azureopenai:completion:text-davinci-003');
|
||||
expect(provider).toBeInstanceOf(AzureOpenAiCompletionProvider);
|
||||
});
|
||||
|
||||
test('loadApiProvider with azureopenai:chat:modelName', async () => {
|
||||
const provider = await loadApiProvider('azureopenai:chat:gpt-3.5-turbo');
|
||||
expect(provider).toBeInstanceOf(AzureOpenAiChatCompletionProvider);
|
||||
});
|
||||
|
||||
|
||||
test('loadApiProvider with anthropic:completion', async () => {
|
||||
const provider = await loadApiProvider('anthropic:completion');
|
||||
expect(provider).toBeInstanceOf(AnthropicCompletionProvider);
|
||||
|
||||
Reference in New Issue
Block a user