added websockets

This commit is contained in:
William Guss
2024-08-05 15:14:03 -07:00
parent 9d29c0d3b6
commit 576afa20ee
12 changed files with 183 additions and 65 deletions

View File

@@ -8,22 +8,41 @@ import Traces from './pages/Traces';
import { ThemeProvider } from './contexts/ThemeContext';
import './styles/globals.css';
import './styles/sourceCode.css';
import { useWebSocketConnection } from './hooks/useBackend';
import { Toaster, toast } from 'react-hot-toast';
const WebSocketConnectionProvider = ({children}) => {
const { isConnected } = useWebSocketConnection();
React.useEffect(() => {
if (isConnected) {
toast.success('Store connected', {
duration: 1000,
});
} else {
toast('Connecting to store...', {
icon: '🔄',
duration: 500,
});
}
}, [isConnected]);
return (
<>
{children}
<Toaster position="top-right" />
</>
);
};
// Create a client
const queryClient = new QueryClient({
defaultOptions: {
queries: {
refetchOnWindowFocus: false, // default: true
retry: false, // default: 3
staleTime: 5 * 60 * 1000, // 5 minutes
},
},
});
const queryClient = new QueryClient();
function App() {
return (
<QueryClientProvider client={queryClient}>
<ThemeProvider>
<WebSocketConnectionProvider>
<Router>
<div className="flex min-h-screen max-h-screen bg-gray-900 text-gray-100">
<Sidebar />
@@ -38,6 +57,7 @@ function App() {
</div>
</div>
</Router>
</WebSocketConnectionProvider>
</ThemeProvider>
</QueryClientProvider>
);

View File

@@ -9,13 +9,24 @@ const TableRow = ({ item, schema, level = 0, onRowClick, columnWidths, updateWid
const hasChildren = item.children && item.children.length > 0;
const isExpanded = expandedRows[item.id];
const isSelected = isItemSelected(item);
const [isNew, setIsNew] = useState(true);
const customRowClassName = rowClassName ? rowClassName(item) : '';
useEffect(() => {
if (isNew) {
const timer = setTimeout(() => setIsNew(false), 200);
return () => clearTimeout(timer);
}
}, [isNew]);
return (
<React.Fragment>
<tr
className={`border-b border-gray-800 hover:bg-gray-800/30 cursor-pointer transition-colors duration-500 ease-in-out ${isSelected ? 'bg-blue-900/30' : ''} ${customRowClassName}`}
className={`border-b border-gray-800 hover:bg-gray-800/30 cursor-pointer transition-all duration-500 ease-in-out
${isSelected ? 'bg-blue-900/30' : ''}
${customRowClassName}
${isNew ? 'animate-fade-in bg-green-900/30' : ''}`}
onClick={() => {
if (onRowClick) onRowClick(item);
}}

View File

@@ -11,7 +11,6 @@ const InvocationsTable = ({ invocations, currentPage, setCurrentPage, pageSize,
const navigate = useNavigate();
const onClickLMP = useCallback(({lmp, id : invocationId}) => {
navigate(`/lmp/${lmp.name}/${lmp.lmp_id}?i=${invocationId}`);
}, [navigate]);

View File

@@ -1,8 +1,45 @@
import { useQuery, useQueries } from '@tanstack/react-query';
import { useQuery, useQueryClient, useQueries } from '@tanstack/react-query';
import axios from 'axios';
import { useEffect, useState } from 'react';
const API_BASE_URL = "http://localhost:8080";
const WS_URL = "ws://localhost:8080/ws";
export const useWebSocketConnection = () => {
const queryClient = useQueryClient();
const [isConnected, setIsConnected] = useState(false);
useEffect(() => {
const socket = new WebSocket(WS_URL);
socket.onopen = () => {
console.log('WebSocket connected');
setIsConnected(true);
};
socket.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.entity === 'database_updated') {
// Invalidate relevant queries
queryClient.invalidateQueries({queryKey: ['traces']});
queryClient.invalidateQueries({queryKey: ['latestLMPs']});
queryClient.invalidateQueries({queryKey: ['invocations']}) ;
queryClient.invalidateQueries({queryKey: ['lmpDetails']});
console.log('Database updated, invalidating queries');
}
};
socket.onclose = () => {
console.log('WebSocket disconnected');
setIsConnected(false);
};
return () => {
console.log('WebSocket connection closed');
socket.close();
};
}, [queryClient]);
return { isConnected };
};
export const useLMPs = (name, id) => {
return useQuery({
@@ -21,7 +58,7 @@ export const useLMPs = (name, id) => {
});
};
export const useInvocations = (name, id, page = 0, pageSize = 50) => {
export const useInvocationsFromLMP = (name, id, page = 0, pageSize = 50) => {
return useQuery({
queryKey: ['invocations', name, id, page, pageSize],
queryFn: async () => {
@@ -39,6 +76,18 @@ export const useInvocations = (name, id, page = 0, pageSize = 50) => {
});
};
export const useInvocation = (id) => {
return useQuery({
queryKey: ['invocation', id],
queryFn: async () => {
const response = await axios.get(`${API_BASE_URL}/api/invocation/${id}`);
return response.data;
},
enabled: !!id,
});
}
export const useMultipleLMPs = (usesIds) => {
const multipleLMPs = useQueries({
queries: (usesIds || []).map(use => ({
@@ -55,26 +104,17 @@ export const useMultipleLMPs = (usesIds) => {
return { isLoading, data };
};
export const useLatestLMPs = (page = 0, pageSize = 100) => {
return useQuery({
queryKey: ['allLMPs', page, pageSize],
queryKey: ['latestLMPs', page, pageSize],
queryFn: async () => {
const skip = page * pageSize;
const response = await axios.get(`${API_BASE_URL}/api/latest/lmps?skip=${skip}&limit=${pageSize}`);
const lmps = response.data;
return lmps;
return response.data;
}
});
};
export const useTraces = (lmps) => {
return useQuery({
queryKey: ['traces', lmps],
@@ -103,4 +143,4 @@ export const useTraces = (lmps) => {
},
enabled: !!lmps && lmps.length > 0,
});
};
};

View File

@@ -1,6 +1,6 @@
import React, { useState, useEffect, useMemo } from "react";
import { useParams, useSearchParams, useNavigate, Link } from "react-router-dom";
import { useLMPs, useInvocations, useMultipleLMPs } from "../hooks/useBackend";
import { useLMPs, useInvocationsFromLMP, useMultipleLMPs, useInvocation } from "../hooks/useBackend";
import InvocationsTable from "../components/invocations/InvocationsTable";
import DependencyGraphPane from "../components/DependencyGraphPane";
import LMPSourceView from "../components/source/LMPSourceView";
@@ -35,6 +35,7 @@ function LMP() {
const requestedInvocationId = searchParams.get("i");
const [currentPage, setCurrentPage] = useState(0);
const pageSize = 50;
// TODO: Could be expensive if you have a funct on of versions.
const { data: versionHistory, isLoading: isLoadingLMP } = useLMPs(name);
@@ -47,7 +48,7 @@ function LMP() {
}
}, [versionHistory, id]);
const { data: invocations } = useInvocations(name, id);
const { data: invocations } = useInvocationsFromLMP(name, id, currentPage, pageSize);
const { data: uses } = useMultipleLMPs(lmp?.uses);
@@ -65,9 +66,14 @@ function LMP() {
: null;
}, [versionHistory, lmp]);
const requestedInvocation = useMemo(() => invocations?.find(
(invocation) => invocation.id === requestedInvocationId
), [invocations, requestedInvocationId]);
const {data: requestedInvocationQueryData} = useInvocation(requestedInvocationId);
const requestedInvocation = useMemo(() => {
if (!requestedInvocationQueryData)
return invocations?.find(i => i.id === requestedInvocationId);
else
return requestedInvocationQueryData;
}, [requestedInvocationQueryData, invocations, requestedInvocationId]);
useEffect(() => {
setSelectedTrace(requestedInvocation);
@@ -233,6 +239,7 @@ function LMP() {
<InvocationsTable
invocations={invocations}
currentPage={currentPage}
pageSize={pageSize}
setCurrentPage={setCurrentPage}
producingLmp={lmp}
onSelectTrace={(trace) => {

View File

@@ -3,7 +3,7 @@ import { FiCopy, FiZap, FiEdit2, FiFilter, FiClock, FiColumns, FiPause, FiPlay }
import InvocationsTable from '../components/invocations/InvocationsTable';
import InvocationsLayout from '../components/invocations/InvocationsLayout';
import { useNavigate, useLocation } from 'react-router-dom';
import { useInvocations } from '../hooks/useBackend';
import { useInvocationsFromLMP } from '../hooks/useBackend';
const Traces = () => {
const [selectedTrace, setSelectedTrace] = useState(null);
@@ -11,21 +11,13 @@ const Traces = () => {
const navigate = useNavigate();
const location = useLocation();
// TODO Unify invocation search behaviour with the LMP page.
const [currentPage, setCurrentPage] = useState(0);
const pageSize = 10;
const pageSize = 50;
const { data: invocations, refetch , isLoading } = useInvocations(null, null, currentPage, pageSize);
const { data: invocations , isLoading } = useInvocationsFromLMP(null, null, currentPage, pageSize);
useEffect(() => {
let intervalId;
if (isPolling) {
intervalId = setInterval(refetch, 200); // Poll every 200ms
}
return () => {
if (intervalId) clearInterval(intervalId);
};
}, [isPolling, refetch]);
useEffect(() => {
const searchParams = new URLSearchParams(location.search);

Binary file not shown.

View File

View File

@@ -141,6 +141,7 @@ class SQLStore(ell.store.Store):
))
if filters:
print(f"Filters: {filters}")
for key, value in filters.items():
query = query.where(getattr(SerializedLMP, key) == value)

View File

@@ -1,10 +1,12 @@
import asyncio
import os
import uvicorn
from argparse import ArgumentParser
from ell.studio.data_server import create_app
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from watchfiles import run_process
from watchfiles import awatch
def main():
parser = ArgumentParser(description="ELL Studio Data Server")
@@ -26,8 +28,23 @@ def main():
async def serve_react_app(full_path: str):
return FileResponse(os.path.join(static_dir, "index.html"))
# In production mode, run without auto-reloading
uvicorn.run(app, host=args.host, port=args.port)
db_path = os.path.join(args.storage_dir, "ell.db")
async def db_watcher():
async for changes in awatch(db_path):
print(f"Database changed: {changes}")
await app.notify_clients("database_updated")
# Start the database watcher
loop = asyncio.new_event_loop()
config = uvicorn.Config(app=app, port=args.port, loop=loop)
server = uvicorn.Server(config)
loop.create_task(server.serve())
loop.create_task(db_watcher())
loop.run_forever()
if __name__ == "__main__":
main()

View File

@@ -2,14 +2,33 @@ from datetime import datetime
from typing import Optional, Dict, Any, List
from ell.stores.sql import SQLiteStore
from ell import __version__
from fastapi import FastAPI, Query, HTTPException, Depends
from fastapi import FastAPI, Query, HTTPException, Depends, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
import os
import logging
import asyncio
import json
logger = logging.getLogger(__name__)
class ConnectionManager:
def __init__(self):
self.active_connections = []
async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
async def broadcast(self, message: str):
for connection in self.active_connections:
print(f"Broadcasting message to {connection} {message}")
await connection.send_text(message)
def create_app(storage_dir: Optional[str] = None):
storage_path = storage_dir or os.environ.get("ELL_STORAGE_DIR") or os.getcwd()
assert storage_path, "ELL_STORAGE_DIR must be set"
@@ -26,13 +45,18 @@ def create_app(storage_dir: Optional[str] = None):
allow_headers=["*"],
)
@app.get("/api/lmps")
def get_lmps(
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=100)
):
lmps = serializer.get_lmps(skip=skip, limit=limit)
return lmps
manager = ConnectionManager()
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await manager.connect(websocket)
try:
while True:
data = await websocket.receive_text()
# Handle incoming WebSocket messages if needed
except WebSocketDisconnect:
manager.disconnect(websocket)
@app.get("/api/latest/lmps")
def get_latest_lmps(
@@ -78,7 +102,7 @@ def create_app(storage_dir: Optional[str] = None):
def get_invocation(
invocation_id: str,
):
invocation = serializer.get_invocations(id=invocation_id)[0]
invocation = serializer.get_invocations(lmp_filters=dict(), filters={"id": invocation_id})[0]
return invocation
@app.get("/api/invocations")
@@ -107,14 +131,6 @@ def create_app(storage_dir: Optional[str] = None):
)
return invocations
@app.post("/api/invocations/search")
def search_invocations(
q: str = Query(...),
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=100)
):
invocations = serializer.search_invocations(q, skip=skip, limit=limit)
return invocations
@app.get("/api/traces")
def get_consumption_graph(
@@ -129,5 +145,11 @@ def create_app(storage_dir: Optional[str] = None):
traces = serializer.get_all_traces_leading_to(invocation_id)
return traces
return app
async def notify_clients(entity: str, id: Optional[str] = None):
message = json.dumps({"entity": entity, "id": id})
await manager.broadcast(message)
# Add this method to the app object
app.notify_clients = notify_clients
return app

View File

@@ -8,6 +8,15 @@ module.exports = {
850: '#22272e',
},
},
keyframes: {
highlight: {
'0%': { backgroundColor: 'rgba(59, 130, 246, 0.5)' },
'100%': { backgroundColor: 'rgba(59, 130, 246, 0)' },
}
},
animation: {
highlight: 'highlight 1s ease-in-out',
}
},
},
plugins: [