Merge pull request #109 from google-gemini/fixes

Fixes app state and returns
This commit is contained in:
Philipp Schmid
2025-06-18 14:39:15 +02:00
committed by GitHub
5 changed files with 31 additions and 19 deletions

View File

@@ -16,14 +16,14 @@ class Configuration(BaseModel):
) )
reflection_model: str = Field( reflection_model: str = Field(
default="gemini-2.5-flash-preview-04-17", default="gemini-2.5-flash",
metadata={ metadata={
"description": "The name of the language model to use for the agent's reflection." "description": "The name of the language model to use for the agent's reflection."
}, },
) )
answer_model: str = Field( answer_model: str = Field(
default="gemini-2.5-pro-preview-05-06", default="gemini-2.5-pro",
metadata={ metadata={
"description": "The name of the language model to use for the agent's answer." "description": "The name of the language model to use for the agent's answer."
}, },

View File

@@ -78,7 +78,7 @@ def generate_query(state: OverallState, config: RunnableConfig) -> QueryGenerati
) )
# Generate the search queries # Generate the search queries
result = structured_llm.invoke(formatted_prompt) result = structured_llm.invoke(formatted_prompt)
return {"query_list": result.query} return {"search_query": result.query}
def continue_to_web_research(state: QueryGenerationState): def continue_to_web_research(state: QueryGenerationState):
@@ -88,7 +88,7 @@ def continue_to_web_research(state: QueryGenerationState):
""" """
return [ return [
Send("web_research", {"search_query": search_query, "id": int(idx)}) Send("web_research", {"search_query": search_query, "id": int(idx)})
for idx, search_query in enumerate(state["query_list"]) for idx, search_query in enumerate(state["search_query"])
] ]
@@ -153,7 +153,7 @@ def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState:
configurable = Configuration.from_runnable_config(config) configurable = Configuration.from_runnable_config(config)
# Increment the research loop count and get the reasoning model # Increment the research loop count and get the reasoning model
state["research_loop_count"] = state.get("research_loop_count", 0) + 1 state["research_loop_count"] = state.get("research_loop_count", 0) + 1
reasoning_model = state.get("reasoning_model") or configurable.reasoning_model reasoning_model = state.get("reasoning_model", configurable.reflection_model)
# Format the prompt # Format the prompt
current_date = get_current_date() current_date = get_current_date()
@@ -231,7 +231,7 @@ def finalize_answer(state: OverallState, config: RunnableConfig):
Dictionary with state update, including running_summary key containing the formatted final summary with sources Dictionary with state update, including running_summary key containing the formatted final summary with sources
""" """
configurable = Configuration.from_runnable_config(config) configurable = Configuration.from_runnable_config(config)
reasoning_model = state.get("reasoning_model") or configurable.reasoning_model reasoning_model = state.get("reasoning_model") or configurable.answer_model
# Format the prompt # Format the prompt
current_date = get_current_date() current_date = get_current_date()

View File

@@ -37,7 +37,7 @@ class Query(TypedDict):
class QueryGenerationState(TypedDict): class QueryGenerationState(TypedDict):
query_list: list[Query] search_query: list[Query]
class WebSearchState(TypedDict): class WebSearchState(TypedDict):

View File

@@ -4,6 +4,7 @@ import { useState, useEffect, useRef, useCallback } from "react";
import { ProcessedEvent } from "@/components/ActivityTimeline"; import { ProcessedEvent } from "@/components/ActivityTimeline";
import { WelcomeScreen } from "@/components/WelcomeScreen"; import { WelcomeScreen } from "@/components/WelcomeScreen";
import { ChatMessagesView } from "@/components/ChatMessagesView"; import { ChatMessagesView } from "@/components/ChatMessagesView";
import { Button } from "@/components/ui/button";
export default function App() { export default function App() {
const [processedEventsTimeline, setProcessedEventsTimeline] = useState< const [processedEventsTimeline, setProcessedEventsTimeline] = useState<
@@ -14,7 +15,7 @@ export default function App() {
>({}); >({});
const scrollAreaRef = useRef<HTMLDivElement>(null); const scrollAreaRef = useRef<HTMLDivElement>(null);
const hasFinalizeEventOccurredRef = useRef(false); const hasFinalizeEventOccurredRef = useRef(false);
const [error, setError] = useState<string | null>(null);
const thread = useStream<{ const thread = useStream<{
messages: Message[]; messages: Message[];
initial_search_query_count: number; initial_search_query_count: number;
@@ -26,15 +27,12 @@ export default function App() {
: "http://localhost:8123", : "http://localhost:8123",
assistantId: "agent", assistantId: "agent",
messagesKey: "messages", messagesKey: "messages",
onFinish: (event: any) => {
console.log(event);
},
onUpdateEvent: (event: any) => { onUpdateEvent: (event: any) => {
let processedEvent: ProcessedEvent | null = null; let processedEvent: ProcessedEvent | null = null;
if (event.generate_query) { if (event.generate_query) {
processedEvent = { processedEvent = {
title: "Generating Search Queries", title: "Generating Search Queries",
data: event.generate_query.query_list.join(", "), data: event.generate_query?.search_query?.join(", ") || "",
}; };
} else if (event.web_research) { } else if (event.web_research) {
const sources = event.web_research.sources_gathered || []; const sources = event.web_research.sources_gathered || [];
@@ -52,11 +50,7 @@ export default function App() {
} else if (event.reflection) { } else if (event.reflection) {
processedEvent = { processedEvent = {
title: "Reflection", title: "Reflection",
data: event.reflection.is_sufficient data: "Analysing Web Research Results",
? "Search successful, generating final answer."
: `Need more information, searching for ${event.reflection.follow_up_queries?.join(
", "
) || "additional information"}`,
}; };
} else if (event.finalize_answer) { } else if (event.finalize_answer) {
processedEvent = { processedEvent = {
@@ -72,6 +66,9 @@ export default function App() {
]); ]);
} }
}, },
onError: (error: any) => {
setError(error.message);
},
}); });
useEffect(() => { useEffect(() => {
@@ -166,6 +163,20 @@ export default function App() {
isLoading={thread.isLoading} isLoading={thread.isLoading}
onCancel={handleCancel} onCancel={handleCancel}
/> />
) : error ? (
<div className="flex flex-col items-center justify-center h-full">
<div className="flex flex-col items-center justify-center gap-4">
<h1 className="text-2xl text-red-400 font-bold">Error</h1>
<p className="text-red-400">{JSON.stringify(error)}</p>
<Button
variant="destructive"
onClick={() => window.location.reload()}
>
Retry
</Button>
</div>
</div>
) : ( ) : (
<ChatMessagesView <ChatMessagesView
messages={thread.messages} messages={thread.messages}

View File

@@ -203,7 +203,9 @@ const AiMessageBubble: React.FC<AiMessageBubbleProps> = ({
</ReactMarkdown> </ReactMarkdown>
<Button <Button
variant="default" variant="default"
className="cursor-pointer bg-neutral-700 border-neutral-600 text-neutral-300 self-end" className={`cursor-pointer bg-neutral-700 border-neutral-600 text-neutral-300 self-end ${
message.content.length > 0 ? "visible" : "hidden"
}`}
onClick={() => onClick={() =>
handleCopy( handleCopy(
typeof message.content === "string" typeof message.content === "string"
@@ -250,7 +252,6 @@ export function ChatMessagesView({
console.error("Failed to copy text: ", err); console.error("Failed to copy text: ", err);
} }
}; };
return ( return (
<div className="flex flex-col h-full overflow-hidden"> <div className="flex flex-col h-full overflow-hidden">
<ScrollArea className="flex-1 min-h-0" ref={scrollAreaRef}> <ScrollArea className="flex-1 min-h-0" ref={scrollAreaRef}>