small bugfixes
This commit is contained in:
2
pnpm-lock.yaml
generated
2
pnpm-lock.yaml
generated
@@ -1,4 +1,4 @@
|
|||||||
lockfileVersion: '6.0'
|
lockfileVersion: '6.1'
|
||||||
|
|
||||||
settings:
|
settings:
|
||||||
autoInstallPeers: true
|
autoInstallPeers: true
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ await prisma.promptVariant.createMany({
|
|||||||
label: "Prompt Variant 1",
|
label: "Prompt Variant 1",
|
||||||
sortIndex: 0,
|
sortIndex: 0,
|
||||||
config: {
|
config: {
|
||||||
model: "gpt-3.5-turbo",
|
model: "gpt-3.5-turbo-0613",
|
||||||
messages: [{ role: "user", content: "What is the capital of {{country}}?" }],
|
messages: [{ role: "user", content: "What is the capital of {{country}}?" }],
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
},
|
},
|
||||||
@@ -47,7 +47,7 @@ await prisma.promptVariant.createMany({
|
|||||||
label: "Prompt Variant 2",
|
label: "Prompt Variant 2",
|
||||||
sortIndex: 1,
|
sortIndex: 1,
|
||||||
config: {
|
config: {
|
||||||
model: "gpt-3.5-turbo",
|
model: "gpt-3.5-turbo-0613",
|
||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
role: "user",
|
role: "user",
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -6,9 +6,12 @@ import chroma from "chroma-js";
|
|||||||
import { BsCurrencyDollar } from "react-icons/bs";
|
import { BsCurrencyDollar } from "react-icons/bs";
|
||||||
|
|
||||||
export default function VariantStats(props: { variant: PromptVariant }) {
|
export default function VariantStats(props: { variant: PromptVariant }) {
|
||||||
const { evalResults, overallCost } = api.promptVariants.stats.useQuery({
|
const { data } = api.promptVariants.stats.useQuery(
|
||||||
|
{
|
||||||
variantId: props.variant.id,
|
variantId: props.variant.id,
|
||||||
}).data ?? { evalResults: [] };
|
},
|
||||||
|
{ initialData: { evalResults: [], overallCost: 0, scenarioCount: 0, outputCount: 0 } }
|
||||||
|
);
|
||||||
|
|
||||||
const [passColor, neutralColor, failColor] = useToken("colors", [
|
const [passColor, neutralColor, failColor] = useToken("colors", [
|
||||||
"green.500",
|
"green.500",
|
||||||
@@ -18,12 +21,19 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
|||||||
|
|
||||||
const scale = chroma.scale([failColor, neutralColor, passColor]).domain([0, 0.5, 1]);
|
const scale = chroma.scale([failColor, neutralColor, passColor]).domain([0, 0.5, 1]);
|
||||||
|
|
||||||
if (!(evalResults.length > 0) && !overallCost) return null;
|
const showNumFinished = data.scenarioCount > 0 && data.scenarioCount !== data.outputCount;
|
||||||
|
|
||||||
|
if (!(data.evalResults.length > 0) && !data.overallCost) return null;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<HStack justifyContent="space-between" alignItems="center" mx="2">
|
<HStack justifyContent="space-between" alignItems="center" mx="2" fontSize="xs">
|
||||||
<HStack px={cellPadding.x} py={cellPadding.y} fontSize="sm">
|
{showNumFinished && (
|
||||||
{evalResults.map((result) => {
|
<Text>
|
||||||
|
{data.outputCount} / {data.scenarioCount}
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
<HStack px={cellPadding.x} py={cellPadding.y}>
|
||||||
|
{data.evalResults.map((result) => {
|
||||||
const passedFrac = result.passCount / (result.passCount + result.failCount);
|
const passedFrac = result.passCount / (result.passCount + result.failCount);
|
||||||
return (
|
return (
|
||||||
<HStack key={result.id}>
|
<HStack key={result.id}>
|
||||||
@@ -35,10 +45,10 @@ export default function VariantStats(props: { variant: PromptVariant }) {
|
|||||||
);
|
);
|
||||||
})}
|
})}
|
||||||
</HStack>
|
</HStack>
|
||||||
{overallCost && (
|
{data.overallCost && (
|
||||||
<HStack spacing={0} align="center" color="gray.500" fontSize="xs" my="2">
|
<HStack spacing={0} align="center" color="gray.500" my="2">
|
||||||
<Icon as={BsCurrencyDollar} />
|
<Icon as={BsCurrencyDollar} />
|
||||||
<Text mr={1}>{overallCost.toFixed(3)}</Text>
|
<Text mr={1}>{data.overallCost.toFixed(3)}</Text>
|
||||||
</HStack>
|
</HStack>
|
||||||
)}
|
)}
|
||||||
</HStack>
|
</HStack>
|
||||||
|
|||||||
@@ -43,12 +43,11 @@ export const experimentsRouter = createTRPCRouter({
|
|||||||
label: "Prompt Variant 1",
|
label: "Prompt Variant 1",
|
||||||
sortIndex: 0,
|
sortIndex: 0,
|
||||||
config: {
|
config: {
|
||||||
model: "gpt-3.5-turbo",
|
model: "gpt-3.5-turbo-0613",
|
||||||
stream: true,
|
|
||||||
messages: [
|
messages: [
|
||||||
{
|
{
|
||||||
role: "system",
|
role: "system",
|
||||||
content: "count to three in {{input}}...",
|
content: "Return 'Ready to go!'",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
@@ -57,13 +56,7 @@ export const experimentsRouter = createTRPCRouter({
|
|||||||
prisma.testScenario.create({
|
prisma.testScenario.create({
|
||||||
data: {
|
data: {
|
||||||
experimentId: exp.id,
|
experimentId: exp.id,
|
||||||
variableValues: { input: "Spanish" },
|
variableValues: {},
|
||||||
},
|
|
||||||
}),
|
|
||||||
prisma.templateVariable.create({
|
|
||||||
data: {
|
|
||||||
experimentId: exp.id,
|
|
||||||
label: "input",
|
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
]);
|
]);
|
||||||
|
|||||||
@@ -71,7 +71,12 @@ export const modelOutputsRouter = createTRPCRouter({
|
|||||||
completionTokens: existingResponse.completionTokens ?? undefined,
|
completionTokens: existingResponse.completionTokens ?? undefined,
|
||||||
};
|
};
|
||||||
} else {
|
} else {
|
||||||
|
try {
|
||||||
modelResponse = await getCompletion(filledTemplate, input.channel);
|
modelResponse = await getCompletion(filledTemplate, input.channel);
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const modelOutput = await prisma.modelOutput.upsert({
|
const modelOutput = await prisma.modelOutput.upsert({
|
||||||
@@ -79,7 +84,7 @@ export const modelOutputsRouter = createTRPCRouter({
|
|||||||
promptVariantId_testScenarioId: {
|
promptVariantId_testScenarioId: {
|
||||||
promptVariantId: input.variantId,
|
promptVariantId: input.variantId,
|
||||||
testScenarioId: input.scenarioId,
|
testScenarioId: input.scenarioId,
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
create: {
|
create: {
|
||||||
promptVariantId: input.variantId,
|
promptVariantId: input.variantId,
|
||||||
|
|||||||
@@ -34,6 +34,19 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
include: { evaluation: true },
|
include: { evaluation: true },
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const scenarioCount = await prisma.testScenario.count({
|
||||||
|
where: {
|
||||||
|
experimentId: variant.experimentId,
|
||||||
|
visible: true,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
const outputCount = await prisma.modelOutput.count({
|
||||||
|
where: {
|
||||||
|
promptVariantId: input.variantId,
|
||||||
|
testScenario: { visible: true },
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
const overallTokens = await prisma.modelOutput.aggregate({
|
const overallTokens = await prisma.modelOutput.aggregate({
|
||||||
where: {
|
where: {
|
||||||
promptVariantId: input.variantId,
|
promptVariantId: input.variantId,
|
||||||
@@ -53,7 +66,7 @@ export const promptVariantsRouter = createTRPCRouter({
|
|||||||
|
|
||||||
const overallCost = overallPromptCost + overallCompletionCost;
|
const overallCost = overallPromptCost + overallCompletionCost;
|
||||||
|
|
||||||
return { evalResults, overallCost };
|
return { evalResults, overallCost, scenarioCount, outputCount };
|
||||||
}),
|
}),
|
||||||
|
|
||||||
create: publicProcedure
|
create: publicProcedure
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ export const evaluateOutput = (
|
|||||||
|
|
||||||
if (!message) return false;
|
if (!message) return false;
|
||||||
|
|
||||||
const stringifiedMessage = JSON.stringify(message);
|
const stringifiedMessage = message.content ?? JSON.stringify(message.function_call);
|
||||||
|
|
||||||
const matchRegex = fillTemplate(evaluation.matchString, scenario.variableValues as VariableMap);
|
const matchRegex = fillTemplate(evaluation.matchString, scenario.variableValues as VariableMap);
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ export const reevaluateVariant = async (variantId: string) => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
const modelOutputs = await prisma.modelOutput.findMany({
|
const modelOutputs = await prisma.modelOutput.findMany({
|
||||||
where: { promptVariantId: variantId },
|
where: { promptVariantId: variantId, statusCode: { notIn: [429] } },
|
||||||
include: { testScenario: true },
|
include: { testScenario: true },
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -56,7 +56,11 @@ export const reevaluateEvaluation = async (evaluation: Evaluation) => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
const modelOutputs = await prisma.modelOutput.findMany({
|
const modelOutputs = await prisma.modelOutput.findMany({
|
||||||
where: { promptVariantId: { in: variants.map((v) => v.id) }, testScenario: { visible: true } },
|
where: {
|
||||||
|
promptVariantId: { in: variants.map((v) => v.id) },
|
||||||
|
testScenario: { visible: true },
|
||||||
|
statusCode: { notIn: [429] },
|
||||||
|
},
|
||||||
include: { testScenario: true },
|
include: { testScenario: true },
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -10,19 +10,20 @@ export default function useSocket(channel?: string) {
|
|||||||
const [message, setMessage] = useState<ChatCompletion | null>(null);
|
const [message, setMessage] = useState<ChatCompletion | null>(null);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
if (!channel) return;
|
||||||
|
|
||||||
|
console.log("connecting to channel", channel);
|
||||||
// Create websocket connection
|
// Create websocket connection
|
||||||
socketRef.current = io(url);
|
socketRef.current = io(url);
|
||||||
|
|
||||||
socketRef.current.on("connect", () => {
|
socketRef.current.on("connect", () => {
|
||||||
// Join the specific room
|
// Join the specific room
|
||||||
if (channel) {
|
|
||||||
socketRef.current?.emit("join", channel);
|
socketRef.current?.emit("join", channel);
|
||||||
|
|
||||||
// Listen for 'message' events
|
// Listen for 'message' events
|
||||||
socketRef.current?.on("message", (message: ChatCompletion) => {
|
socketRef.current?.on("message", (message: ChatCompletion) => {
|
||||||
setMessage(message);
|
setMessage(message);
|
||||||
});
|
});
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Unsubscribe and disconnect on cleanup
|
// Unsubscribe and disconnect on cleanup
|
||||||
@@ -32,6 +33,7 @@ export default function useSocket(channel?: string) {
|
|||||||
socketRef.current.off("message");
|
socketRef.current.off("message");
|
||||||
}
|
}
|
||||||
socketRef.current.disconnect();
|
socketRef.current.disconnect();
|
||||||
|
socketRef.current = undefined;
|
||||||
}
|
}
|
||||||
setMessage(null);
|
setMessage(null);
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user