Merge pull request #356 from transformerlab/add/model-architecture-visualization

Add Model Architecture Visualization
This commit is contained in:
Deep Gandhi
2025-04-08 12:12:43 -07:00
committed by GitHub
7 changed files with 886 additions and 19 deletions

11
package-lock.json generated
View File

@@ -5,7 +5,7 @@
"requires": true,
"packages": {
"": {
"version": "0.11.0",
"version": "0.12.0",
"hasInstallScript": true,
"license": "MIT",
"dependencies": {
@@ -60,6 +60,7 @@
"serve-favicon": "^2.5.0",
"swr": "^2.3.2",
"tail": "^2.2.6",
"three": "^0.175.0",
"tree-kill": "^1.2.2",
"use-debounce": "^9.0.4",
"validator": "^13.7.0"
@@ -23191,6 +23192,12 @@
"integrity": "sha512-N+8UisAXDGk8PFXP4HAzVR9nbfmVJ3zYLAWiTIoqC5v5isinhr+r5uaO8+7r3BMfuNIufIsA7RdpVgacC2cSpw==",
"dev": true
},
"node_modules/three": {
"version": "0.175.0",
"resolved": "https://registry.npmjs.org/three/-/three-0.175.0.tgz",
"integrity": "sha512-nNE3pnTHxXN/Phw768u0Grr7W4+rumGg/H6PgeseNJojkJtmeHJfZWi41Gp2mpXl1pg1pf1zjwR4McM1jTqkpg==",
"license": "MIT"
},
"node_modules/thunky": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/thunky/-/thunky-1.1.0.tgz",
@@ -25281,4 +25288,4 @@
}
}
}
}
}

View File

@@ -143,6 +143,7 @@
"serve-favicon": "^2.5.0",
"swr": "^2.3.2",
"tail": "^2.2.6",
"three": "^0.175.0",
"tree-kill": "^1.2.2",
"use-debounce": "^9.0.4",
"validator": "^13.7.0"
@@ -185,10 +186,10 @@
"eslint-plugin-import": "^2.27.5",
"eslint-plugin-jest": "^27.2.1",
"eslint-plugin-jsx-a11y": "^6.7.1",
"eslint-plugin-prettier": "^5.2.3",
"eslint-plugin-promise": "^6.1.1",
"eslint-plugin-react": "^7.32.2",
"eslint-plugin-react-hooks": "^4.6.0",
"eslint-plugin-prettier": "^5.2.3",
"file-loader": "^6.2.0",
"html-webpack-plugin": "^5.6.3",
"identity-obj-proxy": "^3.0.0",
@@ -304,4 +305,4 @@
],
"logLevel": "quiet"
}
}
}

View File

@@ -6,7 +6,7 @@
"packages": {
"": {
"name": "transformerlab",
"version": "0.11.0",
"version": "0.12.0",
"hasInstallScript": true,
"license": "AGPL-3.0",
"dependencies": {
@@ -147,4 +147,4 @@
"integrity": "sha512-KXXFFdAbFXY4geFIwoyNK+f5Z1b7swfXABfL7HXCmoIWMKU3dmS26672A4EeQtDzLKy7SXmfBu51JolvEKwtGA=="
}
}
}
}

View File

@@ -21,6 +21,7 @@ import {
Trash2Icon,
Undo2Icon,
XCircleIcon,
LayersIcon,
} from 'lucide-react';
import useSWR from 'swr';
@@ -204,6 +205,27 @@ export default function CurrentFoundationInfo({
});
};
const handleModelVisualizationClick = async () => {
try {
// Check if the local model server is running by checking worker health
const response = await fetch(
`${chatAPI.INFERENCE_SERVER_URL()}server/worker_healthz`,
);
const data = await response.json();
if (response.status === 200 && Array.isArray(data) && data.length > 0) {
// Model server is running, navigate to visualization page
navigate('/experiment/model_architecture_visualization');
} else {
// Server responded but workers aren't ready
alert('Please Run the model before visualizing its architecture');
}
} catch (error) {
console.error('Failed to check model server status:', error);
alert('Please Run the model before visualizing its architecture');
}
};
return (
<Sheet
sx={{
@@ -220,6 +242,17 @@ export default function CurrentFoundationInfo({
setFoundation={setFoundation}
/>
<Box sx={{ display: 'flex', justifyContent: 'flex-end', mt: 1, mb: 2 }}>
<Button
variant="outlined"
color="primary"
startDecorator={<LayersIcon size={18} />}
onClick={handleModelVisualizationClick}
>
Visualize Model Architecture
</Button>
</Box>
<Sheet sx={{ overflow: 'auto' }}>
<Box sx={{ mt: 3 }}>
<Typography level="title-lg" marginBottom={1}>

View File

@@ -37,6 +37,7 @@ import {
import Batched from './Batched/Batched';
import VisualizeLogProbs from './VisualizeLogProbs';
import VisualizeGeneration from './VisualizeGeneration';
import ModelLayerVisualization from './ModelLayerVisualization';
const fetcher = (url) => fetch(url).then((res) => res.json());
@@ -63,7 +64,7 @@ export default function Chat({
const { data: defaultPromptConfigForModel } = useSWR(
chatAPI.TEMPLATE_FOR_MODEL_URL(experimentInfo?.config?.foundation),
fetcher
fetcher,
);
const parsedPromptData = experimentInfo?.config?.prompt_template;
@@ -194,8 +195,8 @@ export default function Chat({
chatAPI.Endpoints.Experiment.UpdateConfig(
experimentInfo?.id,
'generationParams',
JSON.stringify(generationParameters)
)
JSON.stringify(generationParameters),
),
).then(() => {
experimentInfoMutate();
});
@@ -307,7 +308,7 @@ export default function Chat({
try {
generationParameters.stop_str = JSON.parse(
generationParameters?.stop_str
generationParameters?.stop_str,
);
} catch (e) {
console.log('Error parsing stop strings as JSON');
@@ -324,7 +325,7 @@ export default function Chat({
generationParameters?.frequencyPenalty,
systemMessage,
generationParameters?.stop_str,
image
image,
);
clearTimeout(timeoutId);
@@ -471,7 +472,7 @@ export default function Chat({
try {
generationParameters.stop_str = JSON.parse(
generationParameters?.stop_str
generationParameters?.stop_str,
);
} catch (e) {
console.log('Error parsing stop strings as JSON');
@@ -488,7 +489,7 @@ export default function Chat({
generationParameters?.frequencyPenalty,
systemMessage,
generationParameters?.stop_str,
image
image,
);
// The model may make repeated tool calls but don't let it get stuck in a loop
@@ -537,7 +538,7 @@ export default function Chat({
tool_responses.push(func_response.message);
} else {
tool_responses.push(
'There was an unknown error calling the tool.'
'There was an unknown error calling the tool.',
);
}
}
@@ -567,7 +568,7 @@ export default function Chat({
generationParameters?.frequencyPenalty,
systemMessage,
generationParameters?.stop_str,
image
image,
);
}
}
@@ -647,7 +648,7 @@ export default function Chat({
mutate: conversationsMutate,
} = useSWR(
chatAPI.Endpoints.Experiment.GetConversations(experimentInfo?.id),
fetcher
fetcher,
);
const sendCompletionToLLM = async (element, targetElement) => {
@@ -667,7 +668,7 @@ export default function Chat({
try {
generationParameters.stop_str = JSON.parse(
generationParameters?.stop_str
generationParameters?.stop_str,
);
} catch (e) {
console.log('Error parsing stop strings as JSON');
@@ -682,7 +683,7 @@ export default function Chat({
generationParameters?.topP,
false,
generationParameters?.stop_str,
targetElement
targetElement,
);
setIsThinking(false);
@@ -744,7 +745,7 @@ export default function Chat({
value={mode}
onChange={(
event: React.SyntheticEvent | null,
newValue: string | null
newValue: string | null,
) => setMode(newValue)}
variant="soft"
size="md"
@@ -766,6 +767,7 @@ export default function Chat({
<Option value="chat">Chat</Option>
<Option value="completion">Completion</Option>
<Option value="visualize_model">Model Activations</Option>
<Option value="model_layers">Model Architecture</Option>
<Option value="rag">Query Docs (RAG)</Option>
<Option value="tools">Tool Calling</Option>
<Option value="template">Templated Prompt</Option>
@@ -853,6 +855,23 @@ export default function Chat({
experimentInfoMutate={experimentInfoMutate}
/>
)}
{mode === 'model_layers' && (
<ModelLayerVisualization
tokenCount={tokenCount}
stopStreaming={stopStreaming}
generationParameters={generationParameters}
setGenerationParameters={setGenerationParameters}
defaultPromptConfigForModel={defaultPromptConfigForModel}
conversations={conversations}
conversationsIsLoading={conversationsIsLoading}
conversationsMutate={conversationsMutate}
setChats={setChats}
setConversationId={setConversationId}
conversationId={conversationId}
experimentInfo={experimentInfo}
experimentInfoMutate={experimentInfoMutate}
/>
)}
{mode === 'tools' && (
<ChatPage
key={conversationId}

View File

@@ -0,0 +1,795 @@
import {
Alert,
Box,
Button,
CircularProgress,
FormControl,
IconButton,
Sheet,
Slider,
Stack,
Typography,
Select,
Option,
} from '@mui/joy';
import {
SendIcon,
StopCircle,
RotateCcw,
ZoomIn,
ZoomOut,
ConstructionIcon,
} from 'lucide-react';
import { useEffect, useRef, useState } from 'react';
import * as chatAPI from '../../../lib/transformerlab-api-sdk';
import ChatSettingsOnLeftHandSide from './ChatSettingsOnLeftHandSide';
import * as THREE from 'three';
import { OrbitControls } from 'three/examples/jsm/controls/OrbitControls';
export default function ModelLayerVisualization({
tokenCount,
stopStreaming,
generationParameters,
setGenerationParameters,
defaultPromptConfigForModel,
conversations,
conversationsIsLoading,
conversationsMutate,
setChats,
setConversationId,
conversationId,
experimentInfo,
experimentInfoMutate,
currentModel,
currentAdaptor,
}) {
const [isLoading, setIsLoading] = useState(false);
const [error, setError] = useState(null);
const [modelLayers, setModelLayers] = useState([]);
const [elevation, setElevation] = useState(30);
const [azimuth, setAzimuth] = useState(45);
const [zoom, setZoom] = useState(1.0);
const [selectedLayer, setSelectedLayer] = useState(null);
// Canvas ref for Three.js
const canvasRef = useRef(null);
const rendererRef = useRef(null);
const sceneRef = useRef(null);
const cameraRef = useRef(null);
const controlsRef = useRef(null);
const layerCanvasRef = useRef(null);
// Add these near the top of your component
const [hoveredLayer, setHoveredLayer] = useState(null);
const raycasterRef = useRef(new THREE.Raycaster());
const mouseRef = useRef(new THREE.Vector2());
const layerMeshesRef = useRef([]);
let hoveredLayerSavedBeforeNextFrame = null;
// Add this function to your component
const updateHoveredLayer = () => {
if (
!raycasterRef.current ||
!mouseRef.current ||
!cameraRef.current ||
!sceneRef.current
)
return;
// Update the raycaster with the current mouse position
raycasterRef.current.setFromCamera(mouseRef.current, cameraRef.current);
// Calculate intersections with layer objects
const intersects = raycasterRef.current.intersectObjects(
layerMeshesRef.current,
);
if (intersects.length > 0) {
// Found a hovered layer
const hoveredMesh = intersects[0].object;
hoveredLayerSavedBeforeNextFrame = hoveredMesh;
setHoveredLayer({
name: hoveredMesh.userData.name,
paramCount: hoveredMesh.userData.paramCount,
type: hoveredMesh.userData.type,
index: hoveredMesh.userData.index,
});
} else {
// No layer being hovered
hoveredLayerSavedBeforeNextFrame = null;
setHoveredLayer(null);
}
};
const handleClick = () => {
if (hoveredLayerSavedBeforeNextFrame !== null) {
setSelectedLayer(hoveredLayerSavedBeforeNextFrame);
} else {
setSelectedLayer(null);
}
};
// Get current model
if (!currentModel) {
currentModel = experimentInfo?.config?.foundation;
}
// const currentModel = experimentInfo?.config?.foundation;
// console.log('FOUNDATION', experimentInfo?.config);
// Fetch model layer data
const fetchModelArchitecture = async () => {
if (!currentModel) return;
setIsLoading(true);
setError(null);
console.log('Fetching model architecture for:', currentModel);
try {
const url = `${chatAPI.INFERENCE_SERVER_URL()}v1/model_architecture`;
console.log(
'REQUEST BODY',
JSON.stringify({
model: currentModel,
adaptor: currentAdaptor || '',
}),
);
const response = await fetch(url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
model: currentModel,
adaptor: currentAdaptor || '',
}),
});
if (!response.ok) {
throw new Error(
`Server responded with ${response.status}: ${response.statusText}`,
);
}
const data = await response.json();
console.log('DATA', data);
setModelLayers(data.layers || []);
} catch (err) {
setError(`Failed to fetch model architecture: ${err.message}`);
} finally {
setIsLoading(false);
}
};
// Initialize 3D visualization
const initVisualization = () => {
if (!canvasRef.current || modelLayers.length === 0) return;
// Clean up existing renderer
if (rendererRef.current) {
rendererRef.current.dispose();
if (
canvasRef.current instanceof HTMLElement &&
rendererRef.current?.domElement instanceof HTMLElement &&
canvasRef.current.contains(rendererRef.current.domElement)
) {
canvasRef.current.removeChild(rendererRef.current.domElement);
}
}
// Create scene
const scene = new THREE.Scene();
scene.background = new THREE.Color(0xf0f0f0);
sceneRef.current = scene;
// Create camera
const canvas = canvasRef.current;
const camera = new THREE.PerspectiveCamera(
75,
canvas.clientWidth / canvas.clientHeight,
0.1,
1000,
);
// Position camera to view horizontal layout from above
camera.position.set(0, 5, 10);
cameraRef.current = camera;
// Create renderer
const renderer = new THREE.WebGLRenderer({ antialias: true });
renderer.setSize(canvas.clientWidth, canvas.clientHeight);
renderer.shadowMap.enabled = true; // Enable shadow mapping
renderer.shadowMap.type = THREE.PCFSoftShadowMap; // Use soft shadows
canvasRef.current.appendChild(renderer.domElement);
rendererRef.current = renderer;
// Add orbit controls
const controls = new OrbitControls(camera, renderer.domElement);
controls.enableDamping = true;
controls.dampingFactor = 0.05;
controls.rotateSpeed = 0.5;
controls.maxDistance = 200; // Set a reasonable max zoom out distance
controls.update();
controlsRef.current = controls;
// Setup raycaster for hover detection
const raycaster = new THREE.Raycaster();
const mouse = new THREE.Vector2();
raycasterRef.current = raycaster;
mouseRef.current = mouse;
// Find min/max parameter sizes for scaling
const paramSizes = modelLayers.map((layer) => layer.param_count);
const maxParamSize = Math.max(...paramSizes);
const minParamSize = Math.min(...paramSizes);
const minWidth = 0.5;
const maxWidth = 5.0; // Length based on param count
const thickness = 0.2; // Thickness of the rectangle
const spacing = 0.05; // Minimal spacing between components for stacked look
// Create rectangles for each layer - vertical stack layout
let yOffset = 0;
// Calculate colors for layers based on their type
const uniqueTypes = [
...new Set(
modelLayers.map((layer) => layer.name.split('.').slice(-2)[0]),
),
];
// Color map for layer types
const colorMap = {};
uniqueTypes.forEach((type, index) => {
colorMap[type] = `hsl(${(index / uniqueTypes.length) * 360}, 70%, 60%)`;
});
// Create vertical stack of layers
modelLayers.forEach((layer, index) => {
// Calculate width based on parameter count (logarithmic scale)
const paramCount = layer.param_count;
const width =
minWidth +
((Math.log(paramCount) - Math.log(minParamSize)) /
(Math.log(maxParamSize) - Math.log(minParamSize))) *
(maxWidth - minWidth);
// Create box geometry and material
const height = 0.25; // Fixed height for each layer in the stack
const geometry = new THREE.BoxGeometry(width, height, width);
// Get color based on layer type
const layerType = layer.name.split('.').slice(-2)[0];
const color =
colorMap[layerType] ||
`hsl(${(index / modelLayers.length) * 360}, 70%, 60%)`;
const material = new THREE.MeshLambertMaterial({
color: color,
transparent: true,
opacity: 0.9,
});
// Create mesh and position it in the stack
const box = new THREE.Mesh(geometry, material);
// box.castShadow = true; // Enable shadow casting for the object
// box.receiveShadow = true; // Enable shadow receiving for the object
// Position vertically stacked from bottom to top
box.position.set(0, yOffset + height / 2, 0);
box.userData = {
name: layer.name,
paramCount: layer.param_count,
type: layerType,
index: index,
};
scene.add(box);
layerMeshesRef.current.push(box);
yOffset += height + spacing; // Move to next position vertically
});
// Add input arrow below the first layer
if (modelLayers.length > 0) {
const inputArrowGroup = new THREE.Group();
// Shaft of the arrow
const inputShaftGeometry = new THREE.CylinderGeometry(
0.05,
0.05,
0.5,
16,
);
const inputShaftMaterial = new THREE.MeshBasicMaterial({
color: 0xff0000,
});
const inputShaft = new THREE.Mesh(inputShaftGeometry, inputShaftMaterial);
inputShaft.position.y = -0.25; // Center the shaft vertically
inputArrowGroup.add(inputShaft);
// Arrowhead
const inputHeadGeometry = new THREE.ConeGeometry(0.1, 0.2, 16);
const inputHeadMaterial = new THREE.MeshBasicMaterial({
color: 0xff0000,
});
const inputHead = new THREE.Mesh(inputHeadGeometry, inputHeadMaterial);
inputHead.position.y = -0.6; // Position the arrowhead below the shaft
inputHead.rotation.x = Math.PI; // Flip the cone to point upward
inputArrowGroup.add(inputHead);
inputArrowGroup.rotation.x = Math.PI; // Point the arrow upward
inputArrowGroup.position.set(0, -0.75, 0); // Position below the first layer
scene.add(inputArrowGroup);
}
// Add output arrow after the last layer
if (modelLayers.length > 0) {
const outputArrowGroup = new THREE.Group();
// Shaft of the arrow
const outputShaftGeometry = new THREE.CylinderGeometry(
0.05,
0.05,
0.5,
16,
);
const outputShaftMaterial = new THREE.MeshBasicMaterial({
color: 0xff0000,
});
const outputShaft = new THREE.Mesh(
outputShaftGeometry,
outputShaftMaterial,
);
outputShaft.position.y = 0.25; // Center the shaft vertically
outputArrowGroup.add(outputShaft);
// Arrowhead
const outputHeadGeometry = new THREE.ConeGeometry(0.1, 0.2, 16);
const outputHeadMaterial = new THREE.MeshBasicMaterial({
color: 0xff0000,
});
const outputHead = new THREE.Mesh(outputHeadGeometry, outputHeadMaterial);
outputHead.position.y = 0.6; // Position the arrowhead above the shaft
outputArrowGroup.add(outputHead);
outputArrowGroup.position.set(0, yOffset + 0.5, 0); // Position above the last layer
scene.add(outputArrowGroup);
}
// Add lights
const ambientLight = new THREE.AmbientLight(0x404040, 30); // Increased intensity to brighten up the objects
scene.add(ambientLight);
const directionalLight = new THREE.DirectionalLight(0xffffff, 1.2); // Slightly increased intensity
directionalLight.position.set(1, yOffset + 5, 4);
// directionalLight.castShadow = true; // Enable shadow casting for the light
directionalLight.shadow.mapSize.width = 1024; // Shadow map resolution
directionalLight.shadow.mapSize.height = 1024;
directionalLight.shadow.camera.near = 0.5;
directionalLight.shadow.camera.far = yOffset * 2;
// scene.add(directionalLight);
// Add a visible representation for the directional light
const lightHelperGeometry = new THREE.SphereGeometry(3, 16, 16); // Small sphere
const lightHelperMaterial = new THREE.MeshBasicMaterial({
color: 0xff0000,
});
const lightHelper = new THREE.Mesh(
lightHelperGeometry,
lightHelperMaterial,
);
lightHelper.position.copy(directionalLight.position); // Match the light's position
// scene.add(lightHelper);
// Add ground plane to receive shadows
const groundGeometry = new THREE.PlaneGeometry(50, 50);
const groundMaterial = new THREE.ShadowMaterial({ opacity: 0.5 });
const ground = new THREE.Mesh(groundGeometry, groundMaterial);
ground.rotation.x = -Math.PI / 2; // Rotate to lie flat
ground.position.y = 0; // Position at the base of the stack
ground.receiveShadow = true; // Enable shadow receiving
scene.add(ground);
// Adjust camera position for better viewing of vertical stack
const totalHeight = yOffset;
camera.position.set(8, totalHeight / 2, 8);
camera.lookAt(new THREE.Vector3(0, totalHeight / 2, 0)); // Look at the middle of the stack
controls.target.set(0, totalHeight / 2, 0);
controls.update();
// Add hover detection
const handleMouseMove = (event) => {
// Calculate mouse position in normalized device coordinates
const rect = renderer.domElement.getBoundingClientRect();
mouse.x = ((event.clientX - rect.left) / rect.width) * 2 - 1;
mouse.y = -((event.clientY - rect.top) / rect.height) * 2 + 1;
// Update the hoveredLayer state
updateHoveredLayer();
};
// Add hover detection events
renderer.domElement.addEventListener('mousemove', handleMouseMove);
renderer.domElement.addEventListener('click', handleClick);
// Animation function with hover detection
const animate = () => {
requestAnimationFrame(animate);
if (controlsRef.current) controlsRef.current.update();
if (rendererRef.current && sceneRef.current && cameraRef.current) {
rendererRef.current.render(sceneRef.current, cameraRef.current);
}
};
// Start animation
animate();
// Handle window resize
const handleResize = () => {
if (!canvasRef.current || !cameraRef.current || !rendererRef.current)
return;
const canvas = canvasRef.current;
cameraRef.current.aspect = canvas.clientWidth / canvas.clientHeight;
cameraRef.current.updateProjectionMatrix();
rendererRef.current.setSize(canvas.clientWidth, canvas.clientHeight);
};
window.addEventListener('resize', handleResize);
// Cleanup function
return () => {
window.removeEventListener('resize', handleResize);
renderer.domElement.removeEventListener('mousemove', handleMouseMove);
renderer.domElement.removeEventListener('click', handleClick);
layerMeshesRef.current = [];
};
};
// Update camera view based on elevation and azimuth
const updateCameraView = () => {
if (!controlsRef.current || !cameraRef.current) return;
// Convert degrees to radians
const elevRad = (elevation * Math.PI) / 180;
const azimRad = (azimuth * Math.PI) / 180;
// Calculate new camera position
const distance = 10 / zoom;
const x = distance * Math.sin(azimRad) * Math.cos(elevRad);
const y = distance * Math.sin(elevRad);
const z = distance * Math.cos(azimRad) * Math.cos(elevRad);
// Update camera position
cameraRef.current.position.set(x, y, z);
cameraRef.current.lookAt(0, 0, 0);
controlsRef.current.update();
};
// Effect to fetch data when model changes
useEffect(() => {
if (currentModel) {
fetchModelArchitecture();
}
}, [currentModel]);
// Effect to initialize visualization when data is loaded
useEffect(() => {
if (modelLayers.length > 0) {
initVisualization();
}
}, [modelLayers]);
// Effect to update camera when controls change
useEffect(() => {
updateCameraView();
}, [elevation, azimuth, zoom]);
const renderSelectedLayer = () => {
if (!layerCanvasRef.current || !selectedLayer) return;
// Clear the canvas
while (layerCanvasRef.current.firstChild) {
layerCanvasRef.current.removeChild(layerCanvasRef.current.firstChild);
}
// Create a new scene
const scene = new THREE.Scene();
scene.background = new THREE.Color(0xf0f0f0);
// Create a camera
const canvas = layerCanvasRef.current;
const camera = new THREE.PerspectiveCamera(
75,
canvas.clientWidth / canvas.clientHeight,
0.1,
1000,
);
camera.position.set(0, 2, 0.5);
// Create a renderer
const renderer = new THREE.WebGLRenderer({ antialias: true });
renderer.setSize(canvas.clientWidth, canvas.clientHeight);
layerCanvasRef.current.appendChild(renderer.domElement);
// Add orbit controls
const controls = new OrbitControls(camera, renderer.domElement);
controls.enableDamping = true;
controls.dampingFactor = 0.05;
// Create a 100x100 grid of boxes
const gridSize = 20;
const boxSize = 0.1;
const spacing = 0.0; // Space between boxes
const color = selectedLayer?.material?.color || 0x0077ff;
// change opacity to match the selected layer:
const opacity = selectedLayer?.material?.opacity || 0.5;
const transparent = selectedLayer?.material?.transparent || true;
const geometry = new THREE.BoxGeometry(boxSize, boxSize, boxSize);
const material = new THREE.MeshLambertMaterial({
color,
opacity,
transparent,
});
for (let i = 0; i < gridSize; i++) {
for (let j = 0; j < gridSize; j++) {
const box = new THREE.Mesh(geometry, material);
box.position.set(
i * (boxSize + spacing) - (gridSize * (boxSize + spacing)) / 2,
0,
j * (boxSize + spacing) - (gridSize * (boxSize + spacing)) / 2,
);
scene.add(box);
}
}
// Point the camera at the center of the grid
camera.lookAt(0, 0, 0);
controls.update();
// Add lights
const ambientLight = new THREE.AmbientLight(0x404040, 2);
scene.add(ambientLight);
const directionalLight = new THREE.DirectionalLight(0xffffff, 1);
directionalLight.position.set(5, 10, 7.5);
scene.add(directionalLight);
// Render the scene
const animate = () => {
requestAnimationFrame(animate);
controls.update();
renderer.render(scene, camera);
};
animate();
};
// Effect to render the selected layer when it changes
useEffect(() => {
renderSelectedLayer();
}, [selectedLayer]);
// Handle manual refresh
const handleRefresh = () => {
fetchModelArchitecture();
};
// Handle zoom controls
const handleZoomIn = () => {
setZoom((prev) => Math.min(prev * 1.2, 5.0));
};
const handleZoomOut = () => {
setZoom((prev) => Math.max(prev / 1.2, 0.5));
};
return (
<Sheet
sx={{
display: 'flex',
flexDirection: 'row',
height: '100%',
width: '100%',
overflow: 'hidden',
gap: 2,
}}
>
{/* <ChatSettingsOnLeftHandSide
generationParameters={generationParameters}
setGenerationParameters={setGenerationParameters}
tokenCount={tokenCount}
defaultPromptConfigForModel={defaultPromptConfigForModel}
conversations={conversations}
conversationsIsLoading={conversationsIsLoading}
conversationsMutate={conversationsMutate}
setChats={setChats}
setConversationId={setConversationId}
conversationId={conversationId}
experimentInfo={experimentInfo}
experimentInfoMutate={experimentInfoMutate}
/> */}
<Sheet
sx={{
display: 'flex',
flexDirection: 'column',
flexGrow: 1,
height: '100%',
overflow: 'hidden',
}}
>
<Alert
color="neutral"
variant="outlined"
startDecorator={<ConstructionIcon />}
>
This feature is currently in development and works with the FastChat
and MLX Model Server only.
</Alert>
<Box
sx={{
p: 2,
display: 'flex',
justifyContent: 'space-between',
alignItems: 'center',
}}
>
<Typography level="h2">Model Layer Visualization</Typography>
<Stack direction="row" spacing={1}>
<IconButton
color="neutral"
onClick={handleZoomOut}
aria-label="Zoom out"
>
<ZoomOut />
</IconButton>
<IconButton
color="neutral"
onClick={handleZoomIn}
aria-label="Zoom in"
>
<ZoomIn />
</IconButton>
<Button
color="neutral"
startDecorator={
isLoading ? (
<CircularProgress thickness={2} size="sm" color="neutral" />
) : (
<RotateCcw size="20px" />
)
}
onClick={handleRefresh}
disabled={isLoading}
>
Refresh
</Button>
</Stack>
</Box>
{/* Add this after the canvas box in your return statement */}
{hoveredLayer && (
<Box
sx={{
position: 'absolute',
bottom: '16px',
right: '16px',
maxWidth: '300px',
bgcolor: 'rgba(255,255,255,0.9)',
p: 2,
borderRadius: 'md',
backdropFilter: 'blur(5px)',
boxShadow: 'sm',
zIndex: 1000,
}}
>
<Typography level="title-md" sx={{ mb: 1, color: 'primary.500' }}>
Layer Information
</Typography>
<Typography level="body-sm" sx={{ mb: 0.5 }}>
<strong>Name:</strong> {hoveredLayer.name}
</Typography>
<Typography level="body-sm" sx={{ mb: 0.5 }}>
<strong>Type:</strong> {hoveredLayer.type}
</Typography>
<Typography level="body-sm" sx={{ mb: 0.5 }}>
<strong>Parameters:</strong>{' '}
{hoveredLayer.paramCount.toLocaleString()}
</Typography>
<Typography level="body-sm">
<strong>Layer Index:</strong> {hoveredLayer.index + 1} of{' '}
{modelLayers.length}
</Typography>
</Box>
)}
{error && (
<Alert color="danger" sx={{ mx: 2, mb: 2 }}>
{error}
</Alert>
)}
<Box
sx={{
flexGrow: 1,
p: 2,
pt: 0,
gap: 2,
display: 'flex',
flexDirection: 'row',
position: 'relative',
}}
>
{isLoading ? (
<Box
sx={{
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
height: '100%',
}}
>
<CircularProgress size="lg" />
</Box>
) : (
<>
<Box
ref={canvasRef}
sx={{
flexGrow: 1,
height: '100%',
bgcolor: 'background.level1',
borderRadius: 'md',
overflow: 'hidden',
}}
/>
<Box sx={{ width: '300px' }} id="detailed-layer">
<Box
ref={layerCanvasRef}
sx={{
width: '100%',
height: '300px',
bgcolor: 'background.level1',
borderRadius: 'md',
overflow: 'hidden',
}}
/>
{/* {JSON.stringify(selectedLayer)} */}
</Box>
</>
)}
{modelLayers.length > 0 && (
<Box
sx={{
position: 'absolute',
bottom: '16px',
left: '16px',
maxWidth: '300px',
bgcolor: 'rgba(255,255,255,0.8)',
p: 2,
borderRadius: 'md',
backdropFilter: 'blur(5px)',
}}
>
<Typography level="body-sm" sx={{ mb: 1 }}>
<strong>Model:</strong> {currentModel.split('/').pop()}
</Typography>
<Typography level="body-sm">
<strong>Total Layers:</strong> {modelLayers.length}
</Typography>
</Box>
)}
</Box>
</Sheet>
</Sheet>
);
}

View File

@@ -354,6 +354,18 @@ export default function MainAppPanel({
/>
}
/>
<Route
path="/experiment/model_architecture_visualization"
element={
<Interact
experimentInfo={experimentInfo}
experimentInfoMutate={experimentInfoMutate}
setRagEngine={setRagEngine}
mode={'model_layers'}
setMode={setSelectedInteractSubpage}
/>
}
/>
<Route
path="/experiment/embeddings"
element={<Embeddings experimentInfo={experimentInfo} />}