mirror of
https://github.com/MadcowD/ell.git
synced 2024-09-22 16:14:36 +03:00
added websockets
This commit is contained in:
@@ -8,22 +8,41 @@ import Traces from './pages/Traces';
|
||||
import { ThemeProvider } from './contexts/ThemeContext';
|
||||
import './styles/globals.css';
|
||||
import './styles/sourceCode.css';
|
||||
import { useWebSocketConnection } from './hooks/useBackend';
|
||||
import { Toaster, toast } from 'react-hot-toast';
|
||||
|
||||
const WebSocketConnectionProvider = ({children}) => {
|
||||
const { isConnected } = useWebSocketConnection();
|
||||
|
||||
React.useEffect(() => {
|
||||
if (isConnected) {
|
||||
toast.success('Store connected', {
|
||||
duration: 1000,
|
||||
});
|
||||
} else {
|
||||
toast('Connecting to store...', {
|
||||
icon: '🔄',
|
||||
duration: 500,
|
||||
});
|
||||
}
|
||||
}, [isConnected]);
|
||||
|
||||
return (
|
||||
<>
|
||||
{children}
|
||||
<Toaster position="top-right" />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
// Create a client
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
refetchOnWindowFocus: false, // default: true
|
||||
retry: false, // default: 3
|
||||
staleTime: 5 * 60 * 1000, // 5 minutes
|
||||
},
|
||||
},
|
||||
});
|
||||
const queryClient = new QueryClient();
|
||||
|
||||
function App() {
|
||||
return (
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<ThemeProvider>
|
||||
<WebSocketConnectionProvider>
|
||||
<Router>
|
||||
<div className="flex min-h-screen max-h-screen bg-gray-900 text-gray-100">
|
||||
<Sidebar />
|
||||
@@ -38,6 +57,7 @@ function App() {
|
||||
</div>
|
||||
</div>
|
||||
</Router>
|
||||
</WebSocketConnectionProvider>
|
||||
</ThemeProvider>
|
||||
</QueryClientProvider>
|
||||
);
|
||||
|
||||
@@ -9,13 +9,24 @@ const TableRow = ({ item, schema, level = 0, onRowClick, columnWidths, updateWid
|
||||
const hasChildren = item.children && item.children.length > 0;
|
||||
const isExpanded = expandedRows[item.id];
|
||||
const isSelected = isItemSelected(item);
|
||||
const [isNew, setIsNew] = useState(true);
|
||||
|
||||
const customRowClassName = rowClassName ? rowClassName(item) : '';
|
||||
|
||||
useEffect(() => {
|
||||
if (isNew) {
|
||||
const timer = setTimeout(() => setIsNew(false), 200);
|
||||
return () => clearTimeout(timer);
|
||||
}
|
||||
}, [isNew]);
|
||||
|
||||
return (
|
||||
<React.Fragment>
|
||||
<tr
|
||||
className={`border-b border-gray-800 hover:bg-gray-800/30 cursor-pointer transition-colors duration-500 ease-in-out ${isSelected ? 'bg-blue-900/30' : ''} ${customRowClassName}`}
|
||||
className={`border-b border-gray-800 hover:bg-gray-800/30 cursor-pointer transition-all duration-500 ease-in-out
|
||||
${isSelected ? 'bg-blue-900/30' : ''}
|
||||
${customRowClassName}
|
||||
${isNew ? 'animate-fade-in bg-green-900/30' : ''}`}
|
||||
onClick={() => {
|
||||
if (onRowClick) onRowClick(item);
|
||||
}}
|
||||
|
||||
@@ -11,7 +11,6 @@ const InvocationsTable = ({ invocations, currentPage, setCurrentPage, pageSize,
|
||||
const navigate = useNavigate();
|
||||
|
||||
|
||||
|
||||
const onClickLMP = useCallback(({lmp, id : invocationId}) => {
|
||||
navigate(`/lmp/${lmp.name}/${lmp.lmp_id}?i=${invocationId}`);
|
||||
}, [navigate]);
|
||||
|
||||
@@ -1,8 +1,45 @@
|
||||
import { useQuery, useQueries } from '@tanstack/react-query';
|
||||
import { useQuery, useQueryClient, useQueries } from '@tanstack/react-query';
|
||||
import axios from 'axios';
|
||||
|
||||
import { useEffect, useState } from 'react';
|
||||
|
||||
const API_BASE_URL = "http://localhost:8080";
|
||||
const WS_URL = "ws://localhost:8080/ws";
|
||||
|
||||
export const useWebSocketConnection = () => {
|
||||
const queryClient = useQueryClient();
|
||||
const [isConnected, setIsConnected] = useState(false);
|
||||
useEffect(() => {
|
||||
const socket = new WebSocket(WS_URL);
|
||||
|
||||
socket.onopen = () => {
|
||||
console.log('WebSocket connected');
|
||||
setIsConnected(true);
|
||||
};
|
||||
|
||||
socket.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
if (data.entity === 'database_updated') {
|
||||
// Invalidate relevant queries
|
||||
queryClient.invalidateQueries({queryKey: ['traces']});
|
||||
queryClient.invalidateQueries({queryKey: ['latestLMPs']});
|
||||
queryClient.invalidateQueries({queryKey: ['invocations']}) ;
|
||||
queryClient.invalidateQueries({queryKey: ['lmpDetails']});
|
||||
console.log('Database updated, invalidating queries');
|
||||
}
|
||||
};
|
||||
|
||||
socket.onclose = () => {
|
||||
console.log('WebSocket disconnected');
|
||||
setIsConnected(false);
|
||||
};
|
||||
|
||||
return () => {
|
||||
console.log('WebSocket connection closed');
|
||||
socket.close();
|
||||
};
|
||||
}, [queryClient]);
|
||||
return { isConnected };
|
||||
};
|
||||
|
||||
export const useLMPs = (name, id) => {
|
||||
return useQuery({
|
||||
@@ -21,7 +58,7 @@ export const useLMPs = (name, id) => {
|
||||
});
|
||||
};
|
||||
|
||||
export const useInvocations = (name, id, page = 0, pageSize = 50) => {
|
||||
export const useInvocationsFromLMP = (name, id, page = 0, pageSize = 50) => {
|
||||
return useQuery({
|
||||
queryKey: ['invocations', name, id, page, pageSize],
|
||||
queryFn: async () => {
|
||||
@@ -39,6 +76,18 @@ export const useInvocations = (name, id, page = 0, pageSize = 50) => {
|
||||
});
|
||||
};
|
||||
|
||||
export const useInvocation = (id) => {
|
||||
return useQuery({
|
||||
queryKey: ['invocation', id],
|
||||
queryFn: async () => {
|
||||
const response = await axios.get(`${API_BASE_URL}/api/invocation/${id}`);
|
||||
return response.data;
|
||||
},
|
||||
enabled: !!id,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
export const useMultipleLMPs = (usesIds) => {
|
||||
const multipleLMPs = useQueries({
|
||||
queries: (usesIds || []).map(use => ({
|
||||
@@ -55,26 +104,17 @@ export const useMultipleLMPs = (usesIds) => {
|
||||
return { isLoading, data };
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
export const useLatestLMPs = (page = 0, pageSize = 100) => {
|
||||
return useQuery({
|
||||
queryKey: ['allLMPs', page, pageSize],
|
||||
queryKey: ['latestLMPs', page, pageSize],
|
||||
queryFn: async () => {
|
||||
const skip = page * pageSize;
|
||||
const response = await axios.get(`${API_BASE_URL}/api/latest/lmps?skip=${skip}&limit=${pageSize}`);
|
||||
const lmps = response.data;
|
||||
|
||||
return lmps;
|
||||
return response.data;
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
export const useTraces = (lmps) => {
|
||||
return useQuery({
|
||||
queryKey: ['traces', lmps],
|
||||
@@ -103,4 +143,4 @@ export const useTraces = (lmps) => {
|
||||
},
|
||||
enabled: !!lmps && lmps.length > 0,
|
||||
});
|
||||
};
|
||||
};
|
||||
@@ -1,6 +1,6 @@
|
||||
import React, { useState, useEffect, useMemo } from "react";
|
||||
import { useParams, useSearchParams, useNavigate, Link } from "react-router-dom";
|
||||
import { useLMPs, useInvocations, useMultipleLMPs } from "../hooks/useBackend";
|
||||
import { useLMPs, useInvocationsFromLMP, useMultipleLMPs, useInvocation } from "../hooks/useBackend";
|
||||
import InvocationsTable from "../components/invocations/InvocationsTable";
|
||||
import DependencyGraphPane from "../components/DependencyGraphPane";
|
||||
import LMPSourceView from "../components/source/LMPSourceView";
|
||||
@@ -35,6 +35,7 @@ function LMP() {
|
||||
const requestedInvocationId = searchParams.get("i");
|
||||
|
||||
const [currentPage, setCurrentPage] = useState(0);
|
||||
const pageSize = 50;
|
||||
|
||||
// TODO: Could be expensive if you have a funct on of versions.
|
||||
const { data: versionHistory, isLoading: isLoadingLMP } = useLMPs(name);
|
||||
@@ -47,7 +48,7 @@ function LMP() {
|
||||
}
|
||||
}, [versionHistory, id]);
|
||||
|
||||
const { data: invocations } = useInvocations(name, id);
|
||||
const { data: invocations } = useInvocationsFromLMP(name, id, currentPage, pageSize);
|
||||
const { data: uses } = useMultipleLMPs(lmp?.uses);
|
||||
|
||||
|
||||
@@ -65,9 +66,14 @@ function LMP() {
|
||||
: null;
|
||||
}, [versionHistory, lmp]);
|
||||
|
||||
const requestedInvocation = useMemo(() => invocations?.find(
|
||||
(invocation) => invocation.id === requestedInvocationId
|
||||
), [invocations, requestedInvocationId]);
|
||||
const {data: requestedInvocationQueryData} = useInvocation(requestedInvocationId);
|
||||
const requestedInvocation = useMemo(() => {
|
||||
if (!requestedInvocationQueryData)
|
||||
return invocations?.find(i => i.id === requestedInvocationId);
|
||||
else
|
||||
return requestedInvocationQueryData;
|
||||
|
||||
}, [requestedInvocationQueryData, invocations, requestedInvocationId]);
|
||||
|
||||
useEffect(() => {
|
||||
setSelectedTrace(requestedInvocation);
|
||||
@@ -233,6 +239,7 @@ function LMP() {
|
||||
<InvocationsTable
|
||||
invocations={invocations}
|
||||
currentPage={currentPage}
|
||||
pageSize={pageSize}
|
||||
setCurrentPage={setCurrentPage}
|
||||
producingLmp={lmp}
|
||||
onSelectTrace={(trace) => {
|
||||
|
||||
@@ -3,7 +3,7 @@ import { FiCopy, FiZap, FiEdit2, FiFilter, FiClock, FiColumns, FiPause, FiPlay }
|
||||
import InvocationsTable from '../components/invocations/InvocationsTable';
|
||||
import InvocationsLayout from '../components/invocations/InvocationsLayout';
|
||||
import { useNavigate, useLocation } from 'react-router-dom';
|
||||
import { useInvocations } from '../hooks/useBackend';
|
||||
import { useInvocationsFromLMP } from '../hooks/useBackend';
|
||||
|
||||
const Traces = () => {
|
||||
const [selectedTrace, setSelectedTrace] = useState(null);
|
||||
@@ -11,21 +11,13 @@ const Traces = () => {
|
||||
const navigate = useNavigate();
|
||||
const location = useLocation();
|
||||
|
||||
|
||||
// TODO Unify invocation search behaviour with the LMP page.
|
||||
const [currentPage, setCurrentPage] = useState(0);
|
||||
const pageSize = 10;
|
||||
const pageSize = 50;
|
||||
|
||||
const { data: invocations, refetch , isLoading } = useInvocations(null, null, currentPage, pageSize);
|
||||
const { data: invocations , isLoading } = useInvocationsFromLMP(null, null, currentPage, pageSize);
|
||||
|
||||
useEffect(() => {
|
||||
let intervalId;
|
||||
if (isPolling) {
|
||||
intervalId = setInterval(refetch, 200); // Poll every 200ms
|
||||
}
|
||||
|
||||
return () => {
|
||||
if (intervalId) clearInterval(intervalId);
|
||||
};
|
||||
}, [isPolling, refetch]);
|
||||
|
||||
useEffect(() => {
|
||||
const searchParams = new URLSearchParams(location.search);
|
||||
|
||||
BIN
examples/sqlite_example/ell.db-shm
Normal file
BIN
examples/sqlite_example/ell.db-shm
Normal file
Binary file not shown.
0
examples/sqlite_example/ell.db-wal
Normal file
0
examples/sqlite_example/ell.db-wal
Normal file
@@ -141,6 +141,7 @@ class SQLStore(ell.store.Store):
|
||||
))
|
||||
|
||||
if filters:
|
||||
print(f"Filters: {filters}")
|
||||
for key, value in filters.items():
|
||||
query = query.where(getattr(SerializedLMP, key) == value)
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import asyncio
|
||||
import os
|
||||
import uvicorn
|
||||
from argparse import ArgumentParser
|
||||
from ell.studio.data_server import create_app
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse
|
||||
from watchfiles import run_process
|
||||
from watchfiles import awatch
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser(description="ELL Studio Data Server")
|
||||
@@ -26,8 +28,23 @@ def main():
|
||||
async def serve_react_app(full_path: str):
|
||||
return FileResponse(os.path.join(static_dir, "index.html"))
|
||||
|
||||
# In production mode, run without auto-reloading
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
db_path = os.path.join(args.storage_dir, "ell.db")
|
||||
|
||||
async def db_watcher():
|
||||
async for changes in awatch(db_path):
|
||||
print(f"Database changed: {changes}")
|
||||
await app.notify_clients("database_updated")
|
||||
|
||||
# Start the database watcher
|
||||
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
config = uvicorn.Config(app=app, port=args.port, loop=loop)
|
||||
server = uvicorn.Server(config)
|
||||
loop.create_task(server.serve())
|
||||
loop.create_task(db_watcher())
|
||||
loop.run_forever()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -2,14 +2,33 @@ from datetime import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from ell.stores.sql import SQLiteStore
|
||||
from ell import __version__
|
||||
from fastapi import FastAPI, Query, HTTPException, Depends
|
||||
from fastapi import FastAPI, Query, HTTPException, Depends, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.active_connections = []
|
||||
|
||||
async def connect(self, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
self.active_connections.append(websocket)
|
||||
|
||||
def disconnect(self, websocket: WebSocket):
|
||||
self.active_connections.remove(websocket)
|
||||
|
||||
async def broadcast(self, message: str):
|
||||
for connection in self.active_connections:
|
||||
print(f"Broadcasting message to {connection} {message}")
|
||||
await connection.send_text(message)
|
||||
|
||||
|
||||
def create_app(storage_dir: Optional[str] = None):
|
||||
storage_path = storage_dir or os.environ.get("ELL_STORAGE_DIR") or os.getcwd()
|
||||
assert storage_path, "ELL_STORAGE_DIR must be set"
|
||||
@@ -26,13 +45,18 @@ def create_app(storage_dir: Optional[str] = None):
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
@app.get("/api/lmps")
|
||||
def get_lmps(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=100)
|
||||
):
|
||||
lmps = serializer.get_lmps(skip=skip, limit=limit)
|
||||
return lmps
|
||||
manager = ConnectionManager()
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await manager.connect(websocket)
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
# Handle incoming WebSocket messages if needed
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket)
|
||||
|
||||
|
||||
@app.get("/api/latest/lmps")
|
||||
def get_latest_lmps(
|
||||
@@ -78,7 +102,7 @@ def create_app(storage_dir: Optional[str] = None):
|
||||
def get_invocation(
|
||||
invocation_id: str,
|
||||
):
|
||||
invocation = serializer.get_invocations(id=invocation_id)[0]
|
||||
invocation = serializer.get_invocations(lmp_filters=dict(), filters={"id": invocation_id})[0]
|
||||
return invocation
|
||||
|
||||
@app.get("/api/invocations")
|
||||
@@ -107,14 +131,6 @@ def create_app(storage_dir: Optional[str] = None):
|
||||
)
|
||||
return invocations
|
||||
|
||||
@app.post("/api/invocations/search")
|
||||
def search_invocations(
|
||||
q: str = Query(...),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=100)
|
||||
):
|
||||
invocations = serializer.search_invocations(q, skip=skip, limit=limit)
|
||||
return invocations
|
||||
|
||||
@app.get("/api/traces")
|
||||
def get_consumption_graph(
|
||||
@@ -129,5 +145,11 @@ def create_app(storage_dir: Optional[str] = None):
|
||||
traces = serializer.get_all_traces_leading_to(invocation_id)
|
||||
return traces
|
||||
|
||||
return app
|
||||
async def notify_clients(entity: str, id: Optional[str] = None):
|
||||
message = json.dumps({"entity": entity, "id": id})
|
||||
await manager.broadcast(message)
|
||||
|
||||
# Add this method to the app object
|
||||
app.notify_clients = notify_clients
|
||||
|
||||
return app
|
||||
@@ -8,6 +8,15 @@ module.exports = {
|
||||
850: '#22272e',
|
||||
},
|
||||
},
|
||||
keyframes: {
|
||||
highlight: {
|
||||
'0%': { backgroundColor: 'rgba(59, 130, 246, 0.5)' },
|
||||
'100%': { backgroundColor: 'rgba(59, 130, 246, 0)' },
|
||||
}
|
||||
},
|
||||
animation: {
|
||||
highlight: 'highlight 1s ease-in-out',
|
||||
}
|
||||
},
|
||||
},
|
||||
plugins: [
|
||||
|
||||
Reference in New Issue
Block a user