Mirror of https://github.com/transformerlab/transformerlab-app.git, synced 2025-04-14 07:48:20 +03:00
Add v1 of model architecture visualization
package-lock.json (generated, 2681 lines changed; diff suppressed because it is too large)
package.json

@@ -105,6 +105,7 @@
     "@rjsf/utils": "^5.24.3",
     "@rjsf/validator-ajv8": "^5.24.3",
     "@segment/analytics-next": "^1.77.0",
+    "@types/three": "^0.175.0",
     "@uppy/core": "^3.8.0",
     "@uppy/dashboard": "^3.5.1",
     "@uppy/drag-drop": "^3.0.3",
@@ -127,9 +128,11 @@
     "electron-ssh2": "^0.1.2",
     "electron-store": "^8.1.0",
     "electron-updater": "^6.3.0-alpha.6",
+    "install": "^0.13.0",
     "lucide-react": "^0.477.0",
     "monaco-themes": "^0.4.4",
     "morgan": "~1.10.0",
+    "npm": "^11.2.0",
     "react": "^18.2.0",
     "react-dom": "^18.2.0",
     "react-dropzone": "^14.3.8",
@@ -143,6 +146,7 @@
     "serve-favicon": "^2.5.0",
     "swr": "^2.3.2",
     "tail": "^2.2.6",
+    "three": "^0.175.0",
     "tree-kill": "^1.2.2",
     "use-debounce": "^9.0.4",
     "validator": "^13.7.0"
@@ -185,10 +189,10 @@
     "eslint-plugin-import": "^2.27.5",
     "eslint-plugin-jest": "^27.2.1",
     "eslint-plugin-jsx-a11y": "^6.7.1",
+    "eslint-plugin-prettier": "^5.2.3",
     "eslint-plugin-promise": "^6.1.1",
     "eslint-plugin-react": "^7.32.2",
     "eslint-plugin-react-hooks": "^4.6.0",
-    "eslint-plugin-prettier": "^5.2.3",
     "file-loader": "^6.2.0",
     "html-webpack-plugin": "^5.6.3",
     "identity-obj-proxy": "^3.0.0",
@@ -304,4 +308,4 @@
     ],
     "logLevel": "quiet"
   }
-}
+}
@@ -37,6 +37,7 @@ import {
 import Batched from './Batched/Batched';
 import VisualizeLogProbs from './VisualizeLogProbs';
 import VisualizeGeneration from './VisualizeGeneration';
+import ModelLayerVisualization from './ModelLayerVisualization';

 const fetcher = (url) => fetch(url).then((res) => res.json());

@@ -63,7 +64,7 @@ export default function Chat({

   const { data: defaultPromptConfigForModel } = useSWR(
     chatAPI.TEMPLATE_FOR_MODEL_URL(experimentInfo?.config?.foundation),
-    fetcher
+    fetcher,
   );

   const parsedPromptData = experimentInfo?.config?.prompt_template;
@@ -194,8 +195,8 @@ export default function Chat({
       chatAPI.Endpoints.Experiment.UpdateConfig(
         experimentInfo?.id,
         'generationParams',
-        JSON.stringify(generationParameters)
-      )
+        JSON.stringify(generationParameters),
+      ),
     ).then(() => {
       experimentInfoMutate();
     });
@@ -307,7 +308,7 @@ export default function Chat({

     try {
       generationParameters.stop_str = JSON.parse(
-        generationParameters?.stop_str
+        generationParameters?.stop_str,
       );
     } catch (e) {
       console.log('Error parsing stop strings as JSON');
@@ -324,7 +325,7 @@ export default function Chat({
       generationParameters?.frequencyPenalty,
       systemMessage,
       generationParameters?.stop_str,
-      image
+      image,
     );

     clearTimeout(timeoutId);
@@ -471,7 +472,7 @@ export default function Chat({

     try {
       generationParameters.stop_str = JSON.parse(
-        generationParameters?.stop_str
+        generationParameters?.stop_str,
       );
     } catch (e) {
       console.log('Error parsing stop strings as JSON');
@@ -488,7 +489,7 @@ export default function Chat({
       generationParameters?.frequencyPenalty,
       systemMessage,
       generationParameters?.stop_str,
-      image
+      image,
     );

     // The model may make repeated tool calls but don't let it get stuck in a loop
@@ -537,7 +538,7 @@ export default function Chat({
           tool_responses.push(func_response.message);
         } else {
           tool_responses.push(
-            'There was an unknown error calling the tool.'
+            'There was an unknown error calling the tool.',
           );
         }
       }
@@ -567,7 +568,7 @@ export default function Chat({
         generationParameters?.frequencyPenalty,
         systemMessage,
         generationParameters?.stop_str,
-        image
+        image,
       );
     }
   }
@@ -647,7 +648,7 @@ export default function Chat({
     mutate: conversationsMutate,
   } = useSWR(
     chatAPI.Endpoints.Experiment.GetConversations(experimentInfo?.id),
-    fetcher
+    fetcher,
   );

   const sendCompletionToLLM = async (element, targetElement) => {
@@ -667,7 +668,7 @@ export default function Chat({

     try {
       generationParameters.stop_str = JSON.parse(
-        generationParameters?.stop_str
+        generationParameters?.stop_str,
       );
     } catch (e) {
       console.log('Error parsing stop strings as JSON');
@@ -682,7 +683,7 @@ export default function Chat({
       generationParameters?.topP,
       false,
       generationParameters?.stop_str,
-      targetElement
+      targetElement,
     );
     setIsThinking(false);

@@ -744,7 +745,7 @@ export default function Chat({
           value={mode}
           onChange={(
             event: React.SyntheticEvent | null,
-            newValue: string | null
+            newValue: string | null,
           ) => setMode(newValue)}
           variant="soft"
           size="md"
@@ -766,6 +767,7 @@ export default function Chat({
             <Option value="chat">Chat</Option>
             <Option value="completion">Completion</Option>
             <Option value="visualize_model">Model Activations</Option>
+            <Option value="model_layers">Model Architecture</Option>
             <Option value="rag">Query Docs (RAG)</Option>
             <Option value="tools">Tool Calling</Option>
             <Option value="template">Templated Prompt</Option>
@@ -853,6 +855,23 @@ export default function Chat({
           experimentInfoMutate={experimentInfoMutate}
         />
       )}
+      {mode === 'model_layers' && (
+        <ModelLayerVisualization
+          tokenCount={tokenCount}
+          stopStreaming={stopStreaming}
+          generationParameters={generationParameters}
+          setGenerationParameters={setGenerationParameters}
+          defaultPromptConfigForModel={defaultPromptConfigForModel}
+          conversations={conversations}
+          conversationsIsLoading={conversationsIsLoading}
+          conversationsMutate={conversationsMutate}
+          setChats={setChats}
+          setConversationId={setConversationId}
+          conversationId={conversationId}
+          experimentInfo={experimentInfo}
+          experimentInfoMutate={experimentInfoMutate}
+        />
+      )}
       {mode === 'tools' && (
         <ChatPage
           key={conversationId}
src/renderer/components/Experiment/Interact/test (new file, 19 lines)
@@ -0,0 +1,19 @@
+def generate_architecture(model_name):
+    model = AutoModelForCausalLM.from_pretrained(model_name)
+    state_dict = model.state_dict()
+    cube_list = []
+    unique_layers = sorted(set(clean_layer_name(layer) for layer in state_dict.keys()))
+    max_param_size = max(v.numel() for v in state_dict.values())
+    min_param_size = min(v.numel() for v in state_dict.values())
+    min_size = 0.5
+    max_size = 2.0
+    for layer, params in state_dict.items():
+        param_size = params.numel()
+        size = float(min_size + ((np.log(param_size) - np.log(min_param_size)) / (np.log(max_param_size) - np.log(min_param_size))) * (max_size - min_size))
+        clean_name = clean_layer_name(layer)
+        cube_list.append({
+            'name': clean_name,
+            'size': size,
+            'param_count': param_size,
+        })
+    return cube_list
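As committed, the file is a sketch rather than a runnable module: it references AutoModelForCausalLM, np, and clean_layer_name without importing or defining them, and unique_layers is computed but never used. Below is a minimal self-contained version for reference. The imports, the clean_layer_name implementation, the division-by-zero guard, and the example model id are assumptions, not part of the commit, and the unused unique_layers line is dropped.

    import numpy as np
    from transformers import AutoModelForCausalLM


    def clean_layer_name(layer: str) -> str:
        # Hypothetical helper (not defined in the commit): drop the
        # parameter suffix so 'model.layers.0.self_attn.q_proj.weight'
        # and '...q_proj.bias' collapse to the same display name.
        return layer.rsplit('.', 1)[0]


    def generate_architecture(model_name):
        model = AutoModelForCausalLM.from_pretrained(model_name)
        state_dict = model.state_dict()

        # Map each tensor's element count onto a cube edge length in
        # [min_size, max_size] on a log scale.
        max_param_size = max(v.numel() for v in state_dict.values())
        min_param_size = min(v.numel() for v in state_dict.values())
        min_size, max_size = 0.5, 2.0
        # Guard against a degenerate model where every tensor is the same size.
        log_range = float(np.log(max_param_size) - np.log(min_param_size)) or 1.0

        cube_list = []
        for layer, params in state_dict.items():
            param_size = params.numel()
            scale = (np.log(param_size) - np.log(min_param_size)) / log_range
            cube_list.append({
                'name': clean_layer_name(layer),
                'size': float(min_size + scale * (max_size - min_size)),
                'param_count': param_size,
            })
        return cube_list


    if __name__ == '__main__':
        # Any small causal LM works here; this id is only an example.
        for cube in generate_architecture('sshleifer/tiny-gpt2')[:5]:
            print(cube)

The log scaling is the load-bearing choice: tensor sizes within one model can span six or more orders of magnitude, and a linear mapping would render every bias vector invisibly small next to the embedding matrix, while the log mapping keeps all cubes within a 4x size spread.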