Add Azure OpenAI Provider (#66)

* update local command

azure providers

other naming updates

Add azure providers

* unit tests
This commit is contained in:
Paul
2023-07-13 17:47:17 -07:00
committed by GitHub
parent 51c254fe3a
commit d15dbc9464
4 changed files with 372 additions and 2 deletions

View File

@@ -2,7 +2,7 @@
"name": "promptfoo",
"description": "LLM eval & testing toolkit",
"author": "Ian Webster",
"version": "0.17.4",
"version": "0.17.5",
"license": "MIT",
"type": "commonjs",
"main": "dist/src/index.js",
@@ -25,7 +25,7 @@
"promptfoo": "dist/src/main.js"
},
"scripts": {
"local": "ts-node --esm src/main.ts",
"local": "ts-node --esm --files src/main.ts",
"install:client": "cd src/web/client && npm install",
"build:clean": "rm -rf dist",
"build:client": "cd src/web/client && npm run build && cp -r dist/ ../../../dist/src/web/client",

View File

@@ -6,6 +6,7 @@ import { OpenAiCompletionProvider, OpenAiChatCompletionProvider } from './provid
import { AnthropicCompletionProvider } from './providers/anthropic';
import { LocalAiCompletionProvider, LocalAiChatProvider } from './providers/localai';
import { ScriptCompletionProvider } from './providers/scriptCompletion';
import { AzureOpenAiChatCompletionProvider, AzureOpenAiCompletionProvider } from './providers/azureopenai'
export async function loadApiProviders(
providerPaths: ProviderId | ProviderId[] | RawProviderConfig[],
@@ -68,6 +69,29 @@ export async function loadApiProvider(
`Unknown OpenAI model type: ${modelType}. Use one of the following providers: openai:chat:<model name>, openai:completion:<model name>`,
);
}
} else if (providerPath?.startsWith('azureopenai:')) {
    // Load Azure OpenAI module.
    // Provider id format: azureopenai:<chat|completion>:<deployment name>
    const options = providerPath.split(':');
    const modelType = options[1];
    const deploymentName = options[2];
    if (modelType === 'chat') {
      return new AzureOpenAiChatCompletionProvider(
        deploymentName,
        undefined,
        context?.config,
      );
    } else if (modelType === 'completion') {
      return new AzureOpenAiCompletionProvider(
        deploymentName,
        undefined,
        context?.config,
      );
    } else {
      // FIX: the message previously pointed users at the openai:* providers,
      // copy-pasted from the OpenAI branch above.
      throw new Error(
        `Unknown Azure OpenAI model type: ${modelType}. Use one of the following providers: azureopenai:chat:<deployment name>, azureopenai:completion:<deployment name>`,
      );
    }
} else if (providerPath?.startsWith('anthropic:')) {
// Load Anthropic module
const options = providerPath.split(':');

View File

@@ -0,0 +1,275 @@
import logger from '../logger';
import { fetchJsonWithCache } from '../cache';
import { REQUEST_TIMEOUT_MS } from './shared';
import type { ApiProvider, ProviderEmbeddingResponse, ProviderResponse } from '../types.js';
// Options accepted by the Azure OpenAI providers; a subset of the OpenAI
// completion/chat-completion request parameters, forwarded in the request body.
interface AzureOpenAiCompletionOptions {
  // Sampling temperature; when unset, providers fall back to the
  // OPENAI_TEMPERATURE environment variable (default 0).
  temperature?: number;
  // Function-calling definitions, forwarded verbatim to the chat completions API.
  functions?: {
    name: string;
    description?: string;
    parameters: any;
  }[];
  // Function-calling mode; only forwarded when explicitly provided.
  function_call?: 'none' | 'auto';
}
class AzureOpenAiGenericProvider implements ApiProvider {
deploymentName: string;
apiKey?: string;
apiHost: string;
constructor(deploymentName: string, apiKey?: string) {
this.deploymentName = deploymentName;
this.apiKey = apiKey || process.env.AZURE_OPENAI_API_KEY;
if (!process.env.AZURE_OPENAI_API_HOST) {
throw new Error('Azure OpenAI API host must be set');
}
this.apiHost = process.env.AZURE_OPENAI_API_HOST;
}
id(): string {
return `azureopenai:${this.deploymentName}`;
}
toString(): string {
return `[Azure OpenAI Provider ${this.deploymentName}]`;
}
// @ts-ignore: Prompt is not used in this implementation
async callApi(prompt: string, options?: AzureOpenAiCompletionOptions): Promise<ProviderResponse> {
throw new Error('Not implemented');
}
}
/**
 * Provider for Azure OpenAI embedding deployments (used for similarity
 * comparisons). Sends the input text to the deployment's /embeddings endpoint
 * through the shared JSON-fetch cache.
 */
export class AzureOpenAiEmbeddingProvider extends AzureOpenAiGenericProvider {
  async callEmbeddingApi(text: string): Promise<ProviderEmbeddingResponse> {
    if (!this.apiKey) {
      throw new Error('Azure OpenAI API key must be set for similarity comparison');
    }

    const url = `https://${this.apiHost}/openai/deployments/${this.deploymentName}/embeddings?api-version=2023-07-01-preview`;
    const requestBody = JSON.stringify({ input: text, model: this.deploymentName });

    let data;
    let cached = false;
    try {
      const response = (await fetchJsonWithCache(
        url,
        {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json',
            'api-key': this.apiKey,
          },
          body: requestBody,
        },
        REQUEST_TIMEOUT_MS,
      )) as unknown as any;
      data = response.data;
      cached = response.cached;
    } catch (err) {
      // Transport failure: report zero token usage alongside the error.
      return {
        error: `API call error: ${String(err)}`,
        tokenUsage: { total: 0, prompt: 0, completion: 0 },
      };
    }

    logger.debug(`\tAzure OpenAI API response (embeddings): ${JSON.stringify(data)}`);
    try {
      const embedding = data?.data?.[0]?.embedding;
      if (!embedding) {
        throw new Error('No embedding returned');
      }
      const tokenUsage = cached
        ? { cached: data.usage.total_tokens }
        : {
            total: data.usage.total_tokens,
            prompt: data.usage.prompt_tokens,
            completion: data.usage.completion_tokens,
          };
      return { embedding, tokenUsage };
    } catch (err) {
      // Malformed payload: include whatever usage fields are present.
      return {
        error: `API response error: ${String(err)}: ${JSON.stringify(data)}`,
        tokenUsage: {
          total: data?.usage?.total_tokens,
          prompt: data?.usage?.prompt_tokens,
          completion: data?.usage?.completion_tokens,
        },
      };
    }
  }
}
/**
 * Provider for Azure OpenAI text-completion deployments.
 */
export class AzureOpenAiCompletionProvider extends AzureOpenAiGenericProvider {
  // Default options merged under per-call options in callApi.
  options: AzureOpenAiCompletionOptions;

  /**
   * @param deploymentName Azure deployment to call.
   * @param apiKey Optional key; falls back to AZURE_OPENAI_API_KEY.
   * @param context Default completion options.
   */
  constructor(deploymentName: string, apiKey?: string, context?: AzureOpenAiCompletionOptions) {
    super(deploymentName, apiKey);
    this.options = context || {};
  }

  async callApi(prompt: string, options?: AzureOpenAiCompletionOptions): Promise<ProviderResponse> {
    if (!this.apiKey) {
      throw new Error(
        'Azure OpenAI API key is not set. Set AZURE_OPENAI_API_KEY environment variable or pass it as an argument to the constructor.',
      );
    }

    // FIX: was `let stop: string;`, which rejects the string-array default
    // below (and OPENAI_STOP may parse to an array) under strict tsc.
    let stop: string | string[];
    try {
      stop = process.env.OPENAI_STOP
        ? JSON.parse(process.env.OPENAI_STOP)
        : ['<|im_end|>', '<|endoftext|>'];
    } catch (err) {
      throw new Error(`OPENAI_STOP is not a valid JSON string: ${err}`);
    }

    const body = {
      model: this.deploymentName,
      prompt,
      max_tokens: parseInt(process.env.OPENAI_MAX_TOKENS || '1024'),
      // Per-call option wins, then constructor default, then env var.
      temperature:
        options?.temperature ??
        this.options.temperature ??
        parseFloat(process.env.OPENAI_TEMPERATURE || '0'),
      stop,
    };
    logger.debug(`Calling Azure OpenAI API: ${JSON.stringify(body)}`);

    let data,
      cached = false;
    try {
      ({ data, cached } = (await fetchJsonWithCache(
        `https://${this.apiHost}/openai/deployments/${this.deploymentName}/completions?api-version=2023-07-01-preview`,
        {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json',
            'api-key': this.apiKey,
          },
          body: JSON.stringify(body),
        },
        REQUEST_TIMEOUT_MS,
      )) as unknown as any);
    } catch (err) {
      return {
        error: `API call error: ${String(err)}`,
      };
    }

    logger.debug(`\tAzure OpenAI API response: ${JSON.stringify(data)}`);
    try {
      return {
        output: data.choices[0].text,
        tokenUsage: cached
          ? { cached: data.usage.total_tokens }
          : {
              total: data.usage.total_tokens,
              prompt: data.usage.prompt_tokens,
              completion: data.usage.completion_tokens,
            },
      };
    } catch (err) {
      return {
        error: `API response error: ${String(err)}: ${JSON.stringify(data)}`,
      };
    }
  }
}
/**
 * Provider for Azure OpenAI chat-completion deployments.
 */
export class AzureOpenAiChatCompletionProvider extends AzureOpenAiGenericProvider {
  // Default options merged under per-call options in callApi.
  options: AzureOpenAiCompletionOptions;

  constructor(deploymentName: string, apiKey?: string, context?: AzureOpenAiCompletionOptions) {
    super(deploymentName, apiKey);
    this.options = context || {};
  }

  async callApi(prompt: string, options?: AzureOpenAiCompletionOptions): Promise<ProviderResponse> {
    if (!this.apiKey) {
      throw new Error(
        'Azure OpenAI API key is not set. Set AZURE_OPENAI_API_KEY environment variable or pass it as an argument to the constructor.',
      );
    }

    // The prompt may be a JSON chat transcript or a plain string. Plain
    // strings become a single user message; strings that look like JSON but
    // fail to parse are treated as author errors.
    let messages: { role: string; content: string; name?: string }[];
    try {
      messages = JSON.parse(prompt) as { role: string; content: string }[];
    } catch (err) {
      const trimmed = prompt.trim();
      const looksLikeJson = trimmed.startsWith('{') || trimmed.startsWith('[');
      if (process.env.PROMPTFOO_REQUIRE_JSON_PROMPTS || looksLikeJson) {
        throw new Error(
          `Azure OpenAI Chat Completion prompt is not a valid JSON string: ${err}\n\n${prompt}`,
        );
      }
      messages = [{ role: 'user', content: prompt }];
    }

    const body = {
      model: this.deploymentName,
      messages: messages,
      max_tokens: parseInt(process.env.OPENAI_MAX_TOKENS || '1024'),
      // Per-call option wins, then constructor default, then env var.
      temperature:
        options?.temperature ??
        this.options.temperature ??
        parseFloat(process.env.OPENAI_TEMPERATURE || '0'),
      functions: options?.functions || this.options.functions || undefined,
      function_call: options?.function_call || this.options.function_call || undefined,
    };
    logger.debug(`Calling Azure OpenAI API: ${JSON.stringify(body)}`);

    let data;
    let cached = false;
    try {
      const response = (await fetchJsonWithCache(
        `https://${this.apiHost}/openai/deployments/${this.deploymentName}/chat/completions?api-version=2023-07-01-preview`,
        {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json',
            'api-key': this.apiKey,
          },
          body: JSON.stringify(body),
        },
        REQUEST_TIMEOUT_MS,
      )) as unknown as any;
      data = response.data;
      cached = response.cached;
    } catch (err) {
      return { error: `API call error: ${String(err)}` };
    }

    logger.debug(`\tAzure OpenAI API response: ${JSON.stringify(data)}`);
    try {
      const message = data.choices[0].message;
      // Function-call responses have null content; surface the call payload instead.
      const output =
        message.content === null ? JSON.stringify(message.function_call) : message.content;
      const tokenUsage = cached
        ? { cached: data.usage.total_tokens }
        : {
            total: data.usage.total_tokens,
            prompt: data.usage.prompt_tokens,
            completion: data.usage.completion_tokens,
          };
      return { output, tokenUsage };
    } catch (err) {
      return { error: `API response error: ${String(err)}: ${JSON.stringify(data)}` };
    }
  }
}

View File

@@ -6,6 +6,7 @@ import { AnthropicCompletionProvider } from '../src/providers/anthropic';
import { disableCache, enableCache } from '../src/cache.js';
import { loadApiProvider, loadApiProviders } from '../src/providers.js';
import type { RawProviderConfig } from '../src/types';
import { AzureOpenAiChatCompletionProvider, AzureOpenAiCompletionProvider } from '../src/providers/azureopenai';
jest.mock('node-fetch', () => jest.fn());
@@ -75,6 +76,65 @@ describe('providers', () => {
enableCache();
});
test('AzureOpenAiCompletionProvider callApi', async () => {
  // The provider constructor throws unless AZURE_OPENAI_API_HOST is set
  // (see AzureOpenAiGenericProvider), so pin it here instead of relying on
  // the outer environment.
  process.env.AZURE_OPENAI_API_HOST = 'test-host';
  const mockResponse = {
    json: jest.fn().mockResolvedValue({
      choices: [{ text: 'Test output' }],
      usage: { total_tokens: 10, prompt_tokens: 5, completion_tokens: 5 },
    }),
  };
  (fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);

  const provider = new AzureOpenAiCompletionProvider('text-davinci-003', 'test-api-key');
  const result = await provider.callApi('Test prompt');

  expect(fetch).toHaveBeenCalledTimes(1);
  expect(result.output).toBe('Test output');
  expect(result.tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
});
test('AzureOpenAiChatCompletionProvider callApi', async () => {
  // Pin the required host env var so the test does not depend on the
  // outer environment (the provider constructor throws without it).
  process.env.AZURE_OPENAI_API_HOST = 'test-host';
  const mockResponse = {
    json: jest.fn().mockResolvedValue({
      choices: [{ message: { content: 'Test output' } }],
      usage: { total_tokens: 10, prompt_tokens: 5, completion_tokens: 5 },
    }),
  };
  (fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);

  const provider = new AzureOpenAiChatCompletionProvider('gpt-3.5-turbo', 'test-api-key');
  const result = await provider.callApi(
    JSON.stringify([{ role: 'user', content: 'Test prompt' }]),
  );

  expect(fetch).toHaveBeenCalledTimes(1);
  expect(result.output).toBe('Test output');
  expect(result.tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
});
test('AzureOpenAiChatCompletionProvider callApi with cache disabled', async () => {
  disableCache();
  // Pin the required host env var so the test does not depend on the
  // outer environment (the provider constructor throws without it).
  process.env.AZURE_OPENAI_API_HOST = 'test-host';
  const mockResponse = {
    json: jest.fn().mockResolvedValue({
      choices: [{ message: { content: 'Test output' } }],
      usage: { total_tokens: 10, prompt_tokens: 5, completion_tokens: 5 },
    }),
  };
  (fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);

  const provider = new AzureOpenAiChatCompletionProvider('gpt-3.5-turbo', 'test-api-key');
  const result = await provider.callApi(
    JSON.stringify([{ role: 'user', content: 'Test prompt' }]),
  );

  expect(fetch).toHaveBeenCalledTimes(1);
  expect(result.output).toBe('Test output');
  expect(result.tokenUsage).toEqual({ total: 10, prompt: 5, completion: 5 });
  enableCache();
});
test('AnthropicCompletionProvider callApi', async () => {
const provider = new AnthropicCompletionProvider('claude-1', 'test-api-key');
provider.anthropic.completions.create = jest.fn().mockResolvedValue({
@@ -107,6 +167,17 @@ describe('providers', () => {
expect(provider).toBeInstanceOf(OpenAiCompletionProvider);
});
test('loadApiProvider with azureopenai:completion:modelName', async () => {
  // The Azure provider constructor requires AZURE_OPENAI_API_HOST; set it so
  // the test is hermetic.
  process.env.AZURE_OPENAI_API_HOST = 'test-host';
  const provider = await loadApiProvider('azureopenai:completion:text-davinci-003');
  expect(provider).toBeInstanceOf(AzureOpenAiCompletionProvider);
});
test('loadApiProvider with azureopenai:chat:modelName', async () => {
  // The Azure provider constructor requires AZURE_OPENAI_API_HOST; set it so
  // the test is hermetic.
  process.env.AZURE_OPENAI_API_HOST = 'test-host';
  const provider = await loadApiProvider('azureopenai:chat:gpt-3.5-turbo');
  expect(provider).toBeInstanceOf(AzureOpenAiChatCompletionProvider);
});
test('loadApiProvider with anthropic:completion', async () => {
const provider = await loadApiProvider('anthropic:completion');
expect(provider).toBeInstanceOf(AnthropicCompletionProvider);