small bugfixes

This commit is contained in:
Kyle Corbitt
2023-07-07 12:22:27 -07:00
parent 70a1448d73
commit 46344d8fc4
10 changed files with 391 additions and 276 deletions

2
pnpm-lock.yaml generated
View File

@@ -1,4 +1,4 @@
lockfileVersion: '6.0'
lockfileVersion: '6.1'
settings:
autoInstallPeers: true

View File

@@ -37,7 +37,7 @@ await prisma.promptVariant.createMany({
label: "Prompt Variant 1",
sortIndex: 0,
config: {
model: "gpt-3.5-turbo",
model: "gpt-3.5-turbo-0613",
messages: [{ role: "user", content: "What is the capital of {{country}}?" }],
temperature: 0,
},
@@ -47,7 +47,7 @@ await prisma.promptVariant.createMany({
label: "Prompt Variant 2",
sortIndex: 1,
config: {
model: "gpt-3.5-turbo",
model: "gpt-3.5-turbo-0613",
messages: [
{
role: "user",

File diff suppressed because one or more lines are too long

View File

@@ -6,9 +6,12 @@ import chroma from "chroma-js";
import { BsCurrencyDollar } from "react-icons/bs";
export default function VariantStats(props: { variant: PromptVariant }) {
const { evalResults, overallCost } = api.promptVariants.stats.useQuery({
variantId: props.variant.id,
}).data ?? { evalResults: [] };
const { data } = api.promptVariants.stats.useQuery(
{
variantId: props.variant.id,
},
{ initialData: { evalResults: [], overallCost: 0, scenarioCount: 0, outputCount: 0 } }
);
const [passColor, neutralColor, failColor] = useToken("colors", [
"green.500",
@@ -18,12 +21,19 @@ export default function VariantStats(props: { variant: PromptVariant }) {
const scale = chroma.scale([failColor, neutralColor, passColor]).domain([0, 0.5, 1]);
if (!(evalResults.length > 0) && !overallCost) return null;
const showNumFinished = data.scenarioCount > 0 && data.scenarioCount !== data.outputCount;
if (!(data.evalResults.length > 0) && !data.overallCost) return null;
return (
<HStack justifyContent="space-between" alignItems="center" mx="2">
<HStack px={cellPadding.x} py={cellPadding.y} fontSize="sm">
{evalResults.map((result) => {
<HStack justifyContent="space-between" alignItems="center" mx="2" fontSize="xs">
{showNumFinished && (
<Text>
{data.outputCount} / {data.scenarioCount}
</Text>
)}
<HStack px={cellPadding.x} py={cellPadding.y}>
{data.evalResults.map((result) => {
const passedFrac = result.passCount / (result.passCount + result.failCount);
return (
<HStack key={result.id}>
@@ -35,10 +45,10 @@ export default function VariantStats(props: { variant: PromptVariant }) {
);
})}
</HStack>
{overallCost && (
<HStack spacing={0} align="center" color="gray.500" fontSize="xs" my="2">
{data.overallCost && (
<HStack spacing={0} align="center" color="gray.500" my="2">
<Icon as={BsCurrencyDollar} />
<Text mr={1}>{overallCost.toFixed(3)}</Text>
<Text mr={1}>{data.overallCost.toFixed(3)}</Text>
</HStack>
)}
</HStack>

View File

@@ -43,12 +43,11 @@ export const experimentsRouter = createTRPCRouter({
label: "Prompt Variant 1",
sortIndex: 0,
config: {
model: "gpt-3.5-turbo",
stream: true,
model: "gpt-3.5-turbo-0613",
messages: [
{
role: "system",
content: "count to three in {{input}}...",
content: "Return 'Ready to go!'",
},
],
},
@@ -57,13 +56,7 @@ export const experimentsRouter = createTRPCRouter({
prisma.testScenario.create({
data: {
experimentId: exp.id,
variableValues: { input: "Spanish" },
},
}),
prisma.templateVariable.create({
data: {
experimentId: exp.id,
label: "input",
variableValues: {},
},
}),
]);

View File

@@ -71,7 +71,12 @@ export const modelOutputsRouter = createTRPCRouter({
completionTokens: existingResponse.completionTokens ?? undefined,
};
} else {
modelResponse = await getCompletion(filledTemplate, input.channel);
try {
modelResponse = await getCompletion(filledTemplate, input.channel);
} catch (e) {
console.error(e);
throw e;
}
}
const modelOutput = await prisma.modelOutput.upsert({
@@ -79,7 +84,7 @@ export const modelOutputsRouter = createTRPCRouter({
promptVariantId_testScenarioId: {
promptVariantId: input.variantId,
testScenarioId: input.scenarioId,
}
},
},
create: {
promptVariantId: input.variantId,

View File

@@ -34,6 +34,19 @@ export const promptVariantsRouter = createTRPCRouter({
include: { evaluation: true },
});
const scenarioCount = await prisma.testScenario.count({
where: {
experimentId: variant.experimentId,
visible: true,
},
});
const outputCount = await prisma.modelOutput.count({
where: {
promptVariantId: input.variantId,
testScenario: { visible: true },
},
});
const overallTokens = await prisma.modelOutput.aggregate({
where: {
promptVariantId: input.variantId,
@@ -53,7 +66,7 @@ export const promptVariantsRouter = createTRPCRouter({
const overallCost = overallPromptCost + overallCompletionCost;
return { evalResults, overallCost };
return { evalResults, overallCost, scenarioCount, outputCount };
}),
create: publicProcedure

View File

@@ -12,7 +12,7 @@ export const evaluateOutput = (
if (!message) return false;
const stringifiedMessage = JSON.stringify(message);
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
const matchRegex = fillTemplate(evaluation.matchString, scenario.variableValues as VariableMap);

View File

@@ -13,7 +13,7 @@ export const reevaluateVariant = async (variantId: string) => {
});
const modelOutputs = await prisma.modelOutput.findMany({
where: { promptVariantId: variantId },
where: { promptVariantId: variantId, statusCode: { notIn: [429] } },
include: { testScenario: true },
});
@@ -56,7 +56,11 @@ export const reevaluateEvaluation = async (evaluation: Evaluation) => {
});
const modelOutputs = await prisma.modelOutput.findMany({
where: { promptVariantId: { in: variants.map((v) => v.id) }, testScenario: { visible: true } },
where: {
promptVariantId: { in: variants.map((v) => v.id) },
testScenario: { visible: true },
statusCode: { notIn: [429] },
},
include: { testScenario: true },
});

View File

@@ -10,19 +10,20 @@ export default function useSocket(channel?: string) {
const [message, setMessage] = useState<ChatCompletion | null>(null);
useEffect(() => {
if (!channel) return;
console.log("connecting to channel", channel);
// Create websocket connection
socketRef.current = io(url);
socketRef.current.on("connect", () => {
// Join the specific room
if (channel) {
socketRef.current?.emit("join", channel);
socketRef.current?.emit("join", channel);
// Listen for 'message' events
socketRef.current?.on("message", (message: ChatCompletion) => {
setMessage(message);
});
}
// Listen for 'message' events
socketRef.current?.on("message", (message: ChatCompletion) => {
setMessage(message);
});
});
// Unsubscribe and disconnect on cleanup
@@ -32,6 +33,7 @@ export default function useSocket(channel?: string) {
socketRef.current.off("message");
}
socketRef.current.disconnect();
socketRef.current = undefined;
}
setMessage(null);
};