use log scale for distribution

This commit is contained in:
ali asaria
2025-04-11 21:26:30 -04:00
parent 483b21dfab
commit 6fc7da7f16

View File

@@ -26,7 +26,8 @@ import ChatSettingsOnLeftHandSide from './ChatSettingsOnLeftHandSide';
import * as THREE from 'three';
import { OrbitControls } from 'three/examples/jsm/controls/OrbitControls';
import useSWR from 'swr';
import { ResponsiveBar } from '@nivo/bar';
import { Bar, ResponsiveBar } from '@nivo/bar';
import { ResponsiveLine } from '@nivo/line';
// write a fetcher that uses POST:
const fetcher = (url: string, body: Record<string, unknown>) =>
@@ -69,24 +70,27 @@ function SingleLayerHistogram({
const histogramData = data.histogram.map((value: number, index: number) => ({
bin: `${data.bin_edges[index].toFixed(2)} - ${data.bin_edges[index + 1].toFixed(2)}`,
count: value,
count: Math.log10(value + 1), // Convert value to logarithmic scale
}));
return (
<Box
sx={{
width: '100%',
height: '150px',
borderRadius: 'md',
overflow: 'hidden',
display: 'flex',
flexDirection: 'column',
}}
>
<Typography level="title-md" sx={{ mb: 1 }}>
Layer Weights Distribution:
Layer Weights Distribution (log scale):
</Typography>
<ResponsiveBar
<Bar
data={histogramData}
keys={['count']}
width={300}
height={150}
indexBy="bin"
margin={{ top: 0, right: 0, bottom: 0, left: 0 }}
padding={0.0}
@@ -712,7 +716,6 @@ export default function ModelLayerVisualization({
sx={{
display: 'flex',
flexDirection: 'row',
height: '100%',
width: '100%',
overflow: 'hidden',
gap: 2,
@@ -742,22 +745,13 @@ export default function ModelLayerVisualization({
overflow: 'hidden',
}}
>
<Alert
color="neutral"
variant="outlined"
startDecorator={<ConstructionIcon />}
>
This feature is currently in development and works with the FastChat
and MLX Model Server only.
</Alert>
<Box
sx={{
display: 'flex',
justifyContent: 'space-between',
alignItems: 'center',
overflow: 'hidden',
marginY: 1,
marginBottom: 1,
}}
>
<Typography level="h2">Model Layer Visualization</Typography>
@@ -821,7 +815,6 @@ export default function ModelLayerVisualization({
display: 'flex',
justifyContent: 'center',
alignItems: 'center',
height: '100%',
}}
>
<CircularProgress size="lg" />
@@ -837,7 +830,7 @@ export default function ModelLayerVisualization({
position: 'absolute',
top: '1rem',
left: '1rem',
maxWidth: '300px',
width: '300px',
bgcolor: 'rgba(255,255,255,0.9)',
p: 2,
borderRadius: 'md',
@@ -878,7 +871,14 @@ export default function ModelLayerVisualization({
overflow: 'hidden',
}}
/>
<Box sx={{ width: '300px' }} id="detailed-layer">
<Box
sx={{
width: '300px',
display: 'flex',
flexDirection: 'column',
}}
id="detailed-layer"
>
<Typography level="title-md" sx={{ mb: 1 }}>
Layer Details
</Typography>
@@ -886,10 +886,11 @@ export default function ModelLayerVisualization({
ref={layerCanvasRef}
sx={{
width: '100%',
height: '300px',
height: '200px',
bgcolor: 'background.level1',
borderRadius: 'md',
overflow: 'hidden',
display: 'flex',
}}
/>
{selectedLayer && (
@@ -898,21 +899,30 @@ export default function ModelLayerVisualization({
modelName={currentModel}
layerName={selectedLayer?.userData?.original_name}
/>
<Typography
level="body-md"
sx={{ mb: 0.5, color: 'primary.500' }}
>
<Typography level="body-md">
Name: {selectedLayer?.userData?.original_name}
<br />
Type: {selectedLayer?.userData?.type}
<br />
Parameters: {selectedLayer?.userData?.paramCount}
<br />
index: {selectedLayer?.userData?.index}
<br />
Shape: {selectedLayer?.userData?.shape}
<br />
</Typography>
<Box sx={{ display: 'flex', flexDirection: 'row', gap: 2 }}>
<Box sx={{ flex: 1 }}>
<Typography level="body-md">
<strong>Type:</strong> {selectedLayer?.userData?.type}
</Typography>
<Typography level="body-md">
<strong>Parameters:</strong>{' '}
{selectedLayer?.userData?.paramCount}
</Typography>
</Box>
<Box sx={{ flex: 1 }}>
<Typography level="body-md">
<strong>Index:</strong>{' '}
{selectedLayer?.userData?.index}
</Typography>
<Typography level="body-md">
<strong>Shape:</strong>{' '}
{selectedLayer?.userData?.shape}
</Typography>
</Box>
</Box>
</>
)}
</Box>