mirror of
https://github.com/transformerlab/transformerlab-app.git
synced 2025-04-14 07:48:20 +03:00
Merge pull request #359 from transformerlab/add/model-architecture-visualization
Merging remaining parts for visualization
This commit is contained in:
@@ -277,11 +277,14 @@ export default function ModelLayerVisualization({
|
||||
// Position vertically stacked from bottom to top
|
||||
box.position.set(0, yOffset + height / 2, 0);
|
||||
|
||||
// console.log('Layer:', layer);
|
||||
|
||||
box.userData = {
|
||||
name: layer.name,
|
||||
paramCount: layer.param_count,
|
||||
type: layerType,
|
||||
index: index,
|
||||
shape: layer?.shape,
|
||||
};
|
||||
|
||||
scene.add(box);
|
||||
@@ -517,8 +520,20 @@ export default function ModelLayerVisualization({
|
||||
controls.enableDamping = true;
|
||||
controls.dampingFactor = 0.05;
|
||||
|
||||
// Create a 100x100 grid of boxes
|
||||
const gridSize = 20;
|
||||
// Get the grid size from the shape of layer in the userData?.shape:
|
||||
const shape = selectedLayer?.userData?.shape;
|
||||
// break the shape text which looks like (576, 192) to width and length.
|
||||
// If the shape looks like (415,) then set the width to 415 and length to 1.
|
||||
const shapeArray = shape.replace(/[()]/g, '').split(',').map(Number);
|
||||
let width = shapeArray[0] || 1;
|
||||
let length = shapeArray[1] || 1;
|
||||
width = Math.ceil(Math.sqrt(width));
|
||||
length = Math.ceil(Math.sqrt(length));
|
||||
// clamp width and length to max 40:
|
||||
const maxSize = 40;
|
||||
width = Math.min(width, maxSize);
|
||||
length = Math.min(length, maxSize);
|
||||
|
||||
const boxSize = 0.1;
|
||||
const spacing = 0.0; // Space between boxes
|
||||
const color = selectedLayer?.material?.color || 0x0077ff;
|
||||
@@ -533,13 +548,13 @@ export default function ModelLayerVisualization({
|
||||
transparent,
|
||||
});
|
||||
|
||||
for (let i = 0; i < gridSize; i++) {
|
||||
for (let j = 0; j < gridSize; j++) {
|
||||
for (let i = 0; i < width; i++) {
|
||||
for (let j = 0; j < length; j++) {
|
||||
const box = new THREE.Mesh(geometry, material);
|
||||
box.position.set(
|
||||
i * (boxSize + spacing) - (gridSize * (boxSize + spacing)) / 2,
|
||||
i * (boxSize + spacing) - (width * (boxSize + spacing)) / 2,
|
||||
0,
|
||||
j * (boxSize + spacing) - (gridSize * (boxSize + spacing)) / 2,
|
||||
j * (boxSize + spacing) - (length * (boxSize + spacing)) / 2,
|
||||
);
|
||||
scene.add(box);
|
||||
}
|
||||
@@ -762,7 +777,13 @@ export default function ModelLayerVisualization({
|
||||
overflow: 'hidden',
|
||||
}}
|
||||
/>
|
||||
{/* {JSON.stringify(selectedLayer)} */}
|
||||
<Typography
|
||||
level="title-md"
|
||||
sx={{ mt: 1, mb: 0.5, color: 'primary.500' }}
|
||||
>
|
||||
{selectedLayer?.userData?.name}
|
||||
<pre>{JSON.stringify(selectedLayer?.userData, null, 2)}</pre>
|
||||
</Typography>
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
|
||||
Reference in New Issue
Block a user