Add support for switching to Llama models (#80)

* Add support for switching to Llama models

* Fix prettier
This commit is contained in:
arcticfly
2023-07-21 20:10:59 -07:00
committed by GitHub
parent 4ea30a3ba3
commit 6fb7a82d72
19 changed files with 293 additions and 285 deletions

View File

@@ -1,18 +1,18 @@
import { type PromptVariant } from "@prisma/client";
import { type SupportedModel } from "../types";
import ivm from "isolated-vm";
import dedent from "dedent";
import { openai } from "./openai";
import { getApiShapeForModel } from "./getTypesForModel";
import { isObject } from "lodash-es";
import { type CompletionCreateParams } from "openai/resources/chat/completions";
import formatPromptConstructor from "~/utils/formatPromptConstructor";
import { type SupportedProvider, type Model } from "~/modelProviders/types";
import modelProviders from "~/modelProviders/modelProviders";
const isolate = new ivm.Isolate({ memoryLimit: 128 });
export async function deriveNewConstructFn(
originalVariant: PromptVariant | null,
newModel?: SupportedModel,
newModel?: Model,
instructions?: string,
) {
if (originalVariant && !newModel && !instructions) {
@@ -36,10 +36,11 @@ export async function deriveNewConstructFn(
const NUM_RETRIES = 5;
const requestUpdatedPromptFunction = async (
originalVariant: PromptVariant,
newModel?: SupportedModel,
newModel?: Model,
instructions?: string,
) => {
const originalModel = originalVariant.model as SupportedModel;
const originalModelProvider = modelProviders[originalVariant.modelProvider as SupportedProvider];
const originalModel = originalModelProvider.models[originalVariant.model] as Model;
let newContructionFn = "";
for (let i = 0; i < NUM_RETRIES; i++) {
try {
@@ -47,7 +48,7 @@ const requestUpdatedPromptFunction = async (
{
role: "system",
content: `Your job is to update prompt constructor functions. Here is the api shape for the current model:\n---\n${JSON.stringify(
getApiShapeForModel(originalModel),
originalModelProvider.inputSchema,
null,
2,
)}\n\nDo not add any assistant messages.`,
@@ -60,8 +61,20 @@ const requestUpdatedPromptFunction = async (
if (newModel) {
messages.push({
role: "user",
content: `Return the prompt constructor function for ${newModel} given the existing prompt constructor function for ${originalModel}`,
content: `Return the prompt constructor function for ${newModel.name} given the existing prompt constructor function for ${originalModel.name}`,
});
if (newModel.provider !== originalModel.provider) {
messages.push({
role: "user",
content: `The old provider was ${originalModel.provider}. The new provider is ${
newModel.provider
}. Here is the schema for the new model:\n---\n${JSON.stringify(
modelProviders[newModel.provider].inputSchema,
null,
2,
)}`,
});
}
}
if (instructions) {
messages.push({

View File

@@ -1,6 +0,0 @@
import { type SupportedModel } from "../types";
export const getApiShapeForModel = (model: SupportedModel) => {
// if (model in OpenAIChatModel) return openAIChatApiShape;
return "";
};