small bugfixes

This commit is contained in:
Kyle Corbitt
2023-07-07 12:22:27 -07:00
parent 70a1448d73
commit 46344d8fc4
10 changed files with 391 additions and 276 deletions

2
pnpm-lock.yaml generated
View File

@@ -1,4 +1,4 @@
lockfileVersion: '6.0' lockfileVersion: '6.1'
settings: settings:
autoInstallPeers: true autoInstallPeers: true

View File

@@ -37,7 +37,7 @@ await prisma.promptVariant.createMany({
label: "Prompt Variant 1", label: "Prompt Variant 1",
sortIndex: 0, sortIndex: 0,
config: { config: {
model: "gpt-3.5-turbo", model: "gpt-3.5-turbo-0613",
messages: [{ role: "user", content: "What is the capital of {{country}}?" }], messages: [{ role: "user", content: "What is the capital of {{country}}?" }],
temperature: 0, temperature: 0,
}, },
@@ -47,7 +47,7 @@ await prisma.promptVariant.createMany({
label: "Prompt Variant 2", label: "Prompt Variant 2",
sortIndex: 1, sortIndex: 1,
config: { config: {
model: "gpt-3.5-turbo", model: "gpt-3.5-turbo-0613",
messages: [ messages: [
{ {
role: "user", role: "user",

File diff suppressed because one or more lines are too long

View File

@@ -6,9 +6,12 @@ import chroma from "chroma-js";
import { BsCurrencyDollar } from "react-icons/bs"; import { BsCurrencyDollar } from "react-icons/bs";
export default function VariantStats(props: { variant: PromptVariant }) { export default function VariantStats(props: { variant: PromptVariant }) {
const { evalResults, overallCost } = api.promptVariants.stats.useQuery({ const { data } = api.promptVariants.stats.useQuery(
variantId: props.variant.id, {
}).data ?? { evalResults: [] }; variantId: props.variant.id,
},
{ initialData: { evalResults: [], overallCost: 0, scenarioCount: 0, outputCount: 0 } }
);
const [passColor, neutralColor, failColor] = useToken("colors", [ const [passColor, neutralColor, failColor] = useToken("colors", [
"green.500", "green.500",
@@ -18,12 +21,19 @@ export default function VariantStats(props: { variant: PromptVariant }) {
const scale = chroma.scale([failColor, neutralColor, passColor]).domain([0, 0.5, 1]); const scale = chroma.scale([failColor, neutralColor, passColor]).domain([0, 0.5, 1]);
if (!(evalResults.length > 0) && !overallCost) return null; const showNumFinished = data.scenarioCount > 0 && data.scenarioCount !== data.outputCount;
if (!(data.evalResults.length > 0) && !data.overallCost) return null;
return ( return (
<HStack justifyContent="space-between" alignItems="center" mx="2"> <HStack justifyContent="space-between" alignItems="center" mx="2" fontSize="xs">
<HStack px={cellPadding.x} py={cellPadding.y} fontSize="sm"> {showNumFinished && (
{evalResults.map((result) => { <Text>
{data.outputCount} / {data.scenarioCount}
</Text>
)}
<HStack px={cellPadding.x} py={cellPadding.y}>
{data.evalResults.map((result) => {
const passedFrac = result.passCount / (result.passCount + result.failCount); const passedFrac = result.passCount / (result.passCount + result.failCount);
return ( return (
<HStack key={result.id}> <HStack key={result.id}>
@@ -35,10 +45,10 @@ export default function VariantStats(props: { variant: PromptVariant }) {
); );
})} })}
</HStack> </HStack>
{overallCost && ( {data.overallCost && (
<HStack spacing={0} align="center" color="gray.500" fontSize="xs" my="2"> <HStack spacing={0} align="center" color="gray.500" my="2">
<Icon as={BsCurrencyDollar} /> <Icon as={BsCurrencyDollar} />
<Text mr={1}>{overallCost.toFixed(3)}</Text> <Text mr={1}>{data.overallCost.toFixed(3)}</Text>
</HStack> </HStack>
)} )}
</HStack> </HStack>

View File

@@ -43,12 +43,11 @@ export const experimentsRouter = createTRPCRouter({
label: "Prompt Variant 1", label: "Prompt Variant 1",
sortIndex: 0, sortIndex: 0,
config: { config: {
model: "gpt-3.5-turbo", model: "gpt-3.5-turbo-0613",
stream: true,
messages: [ messages: [
{ {
role: "system", role: "system",
content: "count to three in {{input}}...", content: "Return 'Ready to go!'",
}, },
], ],
}, },
@@ -57,13 +56,7 @@ export const experimentsRouter = createTRPCRouter({
prisma.testScenario.create({ prisma.testScenario.create({
data: { data: {
experimentId: exp.id, experimentId: exp.id,
variableValues: { input: "Spanish" }, variableValues: {},
},
}),
prisma.templateVariable.create({
data: {
experimentId: exp.id,
label: "input",
}, },
}), }),
]); ]);

View File

@@ -71,7 +71,12 @@ export const modelOutputsRouter = createTRPCRouter({
completionTokens: existingResponse.completionTokens ?? undefined, completionTokens: existingResponse.completionTokens ?? undefined,
}; };
} else { } else {
modelResponse = await getCompletion(filledTemplate, input.channel); try {
modelResponse = await getCompletion(filledTemplate, input.channel);
} catch (e) {
console.error(e);
throw e;
}
} }
const modelOutput = await prisma.modelOutput.upsert({ const modelOutput = await prisma.modelOutput.upsert({
@@ -79,7 +84,7 @@ export const modelOutputsRouter = createTRPCRouter({
promptVariantId_testScenarioId: { promptVariantId_testScenarioId: {
promptVariantId: input.variantId, promptVariantId: input.variantId,
testScenarioId: input.scenarioId, testScenarioId: input.scenarioId,
} },
}, },
create: { create: {
promptVariantId: input.variantId, promptVariantId: input.variantId,

View File

@@ -34,6 +34,19 @@ export const promptVariantsRouter = createTRPCRouter({
include: { evaluation: true }, include: { evaluation: true },
}); });
const scenarioCount = await prisma.testScenario.count({
where: {
experimentId: variant.experimentId,
visible: true,
},
});
const outputCount = await prisma.modelOutput.count({
where: {
promptVariantId: input.variantId,
testScenario: { visible: true },
},
});
const overallTokens = await prisma.modelOutput.aggregate({ const overallTokens = await prisma.modelOutput.aggregate({
where: { where: {
promptVariantId: input.variantId, promptVariantId: input.variantId,
@@ -53,7 +66,7 @@ export const promptVariantsRouter = createTRPCRouter({
const overallCost = overallPromptCost + overallCompletionCost; const overallCost = overallPromptCost + overallCompletionCost;
return { evalResults, overallCost }; return { evalResults, overallCost, scenarioCount, outputCount };
}), }),
create: publicProcedure create: publicProcedure

View File

@@ -12,7 +12,7 @@ export const evaluateOutput = (
if (!message) return false; if (!message) return false;
const stringifiedMessage = JSON.stringify(message); const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
const matchRegex = fillTemplate(evaluation.matchString, scenario.variableValues as VariableMap); const matchRegex = fillTemplate(evaluation.matchString, scenario.variableValues as VariableMap);

View File

@@ -13,7 +13,7 @@ export const reevaluateVariant = async (variantId: string) => {
}); });
const modelOutputs = await prisma.modelOutput.findMany({ const modelOutputs = await prisma.modelOutput.findMany({
where: { promptVariantId: variantId }, where: { promptVariantId: variantId, statusCode: { notIn: [429] } },
include: { testScenario: true }, include: { testScenario: true },
}); });
@@ -56,7 +56,11 @@ export const reevaluateEvaluation = async (evaluation: Evaluation) => {
}); });
const modelOutputs = await prisma.modelOutput.findMany({ const modelOutputs = await prisma.modelOutput.findMany({
where: { promptVariantId: { in: variants.map((v) => v.id) }, testScenario: { visible: true } }, where: {
promptVariantId: { in: variants.map((v) => v.id) },
testScenario: { visible: true },
statusCode: { notIn: [429] },
},
include: { testScenario: true }, include: { testScenario: true },
}); });

View File

@@ -10,19 +10,20 @@ export default function useSocket(channel?: string) {
const [message, setMessage] = useState<ChatCompletion | null>(null); const [message, setMessage] = useState<ChatCompletion | null>(null);
useEffect(() => { useEffect(() => {
if (!channel) return;
console.log("connecting to channel", channel);
// Create websocket connection // Create websocket connection
socketRef.current = io(url); socketRef.current = io(url);
socketRef.current.on("connect", () => { socketRef.current.on("connect", () => {
// Join the specific room // Join the specific room
if (channel) { socketRef.current?.emit("join", channel);
socketRef.current?.emit("join", channel);
// Listen for 'message' events // Listen for 'message' events
socketRef.current?.on("message", (message: ChatCompletion) => { socketRef.current?.on("message", (message: ChatCompletion) => {
setMessage(message); setMessage(message);
}); });
}
}); });
// Unsubscribe and disconnect on cleanup // Unsubscribe and disconnect on cleanup
@@ -32,6 +33,7 @@ export default function useSocket(channel?: string) {
socketRef.current.off("message"); socketRef.current.off("message");
} }
socketRef.current.disconnect(); socketRef.current.disconnect();
socketRef.current = undefined;
} }
setMessage(null); setMessage(null);
}; };