Merge pull request #359 from transformerlab/add/model-architecture-visualization

Merging remaining parts for visualization
This commit is contained in:
ali asaria
2025-04-09 13:59:56 -04:00
committed by GitHub

View File

@@ -277,11 +277,14 @@ export default function ModelLayerVisualization({
// Position vertically stacked from bottom to top
box.position.set(0, yOffset + height / 2, 0);
// console.log('Layer:', layer);
box.userData = {
name: layer.name,
paramCount: layer.param_count,
type: layerType,
index: index,
shape: layer?.shape,
};
scene.add(box);
@@ -517,8 +520,20 @@ export default function ModelLayerVisualization({
controls.enableDamping = true;
controls.dampingFactor = 0.05;
// Create a 100x100 grid of boxes
const gridSize = 20;
// Get the grid size from the shape of layer in the userData?.shape:
const shape = selectedLayer?.userData?.shape;
// break the shape text which looks like (576, 192) to width and length.
// If the shape looks like (415,) then set the width to 415 and length to 1.
const shapeArray = shape.replace(/[()]/g, '').split(',').map(Number);
let width = shapeArray[0] || 1;
let length = shapeArray[1] || 1;
width = Math.ceil(Math.sqrt(width));
length = Math.ceil(Math.sqrt(length));
// clamp width and length to max 40:
const maxSize = 40;
width = Math.min(width, maxSize);
length = Math.min(length, maxSize);
const boxSize = 0.1;
const spacing = 0.0; // Space between boxes
const color = selectedLayer?.material?.color || 0x0077ff;
@@ -533,13 +548,13 @@ export default function ModelLayerVisualization({
transparent,
});
for (let i = 0; i < gridSize; i++) {
for (let j = 0; j < gridSize; j++) {
for (let i = 0; i < width; i++) {
for (let j = 0; j < length; j++) {
const box = new THREE.Mesh(geometry, material);
box.position.set(
i * (boxSize + spacing) - (gridSize * (boxSize + spacing)) / 2,
i * (boxSize + spacing) - (width * (boxSize + spacing)) / 2,
0,
j * (boxSize + spacing) - (gridSize * (boxSize + spacing)) / 2,
j * (boxSize + spacing) - (length * (boxSize + spacing)) / 2,
);
scene.add(box);
}
@@ -762,7 +777,13 @@ export default function ModelLayerVisualization({
overflow: 'hidden',
}}
/>
{/* {JSON.stringify(selectedLayer)} */}
<Typography
level="title-md"
sx={{ mt: 1, mb: 0.5, color: 'primary.500' }}
>
{selectedLayer?.userData?.name}
<pre>{JSON.stringify(selectedLayer?.userData, null, 2)}</pre>
</Typography>
</Box>
</>
)}