From 15ddea3d8eb7dba8442597d7078e5aeabcd7c93a Mon Sep 17 00:00:00 2001 From: dane madsen Date: Sat, 11 Nov 2023 13:16:14 +1000 Subject: [PATCH] work on model settings --- assets/default_parameters.json | 3 +- core/core.cpp | 6 +- core/core.h | 2 +- lib/core/bindings.dart | 2 +- lib/core/local_generation.dart | 2 +- lib/core/remote_generation.dart | 2 - lib/pages/home_page.dart | 6 +- lib/pages/model_page.dart | 102 +++++++++++++++++++++++------- lib/utilities/memory_manager.dart | 12 ++-- lib/utilities/model.dart | 26 ++++---- 10 files changed, 107 insertions(+), 56 deletions(-) diff --git a/assets/default_parameters.json b/assets/default_parameters.json index 25e822f..46a45a2 100644 --- a/assets/default_parameters.json +++ b/assets/default_parameters.json @@ -3,8 +3,7 @@ "instruct":true, "interactive":true, "random_seed":true, - "model_name":"", - "model_path":"", + "path":"", "seed":-1, "n_ctx":512, "n_batch":8, diff --git a/core/core.cpp b/core/core.cpp index 52bbeab..3980401 100644 --- a/core/core.cpp +++ b/core/core.cpp @@ -72,7 +72,7 @@ int core_init(struct maid_params *mparams) { params.sparams.mirostat_eta = (*mparams).mirostat_eta ? (*mparams).mirostat_eta : 0.10f; params.sparams.penalize_nl = (*mparams).penalize_nl != 0; - params.model = (*mparams).model_path; + params.model = (*mparams).path; params.prompt = (*mparams).preprompt; params.input_prefix = (*mparams).input_prefix; params.input_suffix = (*mparams).input_suffix; @@ -84,10 +84,10 @@ int core_init(struct maid_params *mparams) { std::tie(model, ctx) = llama_init_from_gpt_params(params); if (model == NULL) { - fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, (*mparams).model_path); + fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, (*mparams).path); return 1; } else if (ctx == NULL) { - fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, (*mparams).model_path); + fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, (*mparams).path); llama_free_model(model); return 1; } diff --git a/core/core.h b/core/core.h index a91eead..d683b28 100644 --- a/core/core.h +++ b/core/core.h @@ -16,7 +16,7 @@ struct maid_params { unsigned char interactive; unsigned char memory_f16; - char *model_path; + char *path; char *preprompt; char *input_prefix; // string to prefix user inputs with char *input_suffix; // string to suffix user inputs with diff --git a/lib/core/bindings.dart b/lib/core/bindings.dart index baa000c..65820c6 100644 --- a/lib/core/bindings.dart +++ b/lib/core/bindings.dart @@ -77,7 +77,7 @@ final class maid_params extends ffi.Struct { @ffi.UnsignedChar() external int memory_f16; - external ffi.Pointer model_path; + external ffi.Pointer path; external ffi.Pointer preprompt; diff --git a/lib/core/local_generation.dart b/lib/core/local_generation.dart index bc2001f..3f428a4 100644 --- a/lib/core/local_generation.dart +++ b/lib/core/local_generation.dart @@ -81,7 +81,7 @@ class LocalGeneration { _hasStarted = true; final params = calloc(); - params.ref.model_path = model.parameters["model_path"].toString().toNativeUtf8().cast(); + params.ref.path = model.parameters["path"].toString().toNativeUtf8().cast(); params.ref.preprompt = character.getPrePrompt().toNativeUtf8().cast(); params.ref.input_prefix = character.userAliasController.text.trim().toNativeUtf8().cast(); params.ref.input_suffix = character.responseAliasController.text.trim().toNativeUtf8().cast(); diff --git a/lib/core/remote_generation.dart b/lib/core/remote_generation.dart index 681d53e..4217126 100644 --- a/lib/core/remote_generation.dart +++ b/lib/core/remote_generation.dart @@ -45,8 +45,6 @@ class RemoteGeneration { } }); - print(_messages); - try { var request = http.Request("POST", url) ..headers.addAll(headers) diff --git a/lib/pages/home_page.dart b/lib/pages/home_page.dart index a1e3a2b..9e13ac2 100644 --- a/lib/pages/home_page.dart +++ b/lib/pages/home_page.dart @@ -92,7 +92,7 @@ class MaidHomePageState extends State { ); MessageManager.add(UniqueKey()); - if (MemoryManager.checkFileExists(model.parameters["model_path"])) { + if (MemoryManager.checkFileExists(model.parameters["path"])) { GenerationManager.prompt(promptController.text.trim()); setState(() { model.busy = true; @@ -290,7 +290,7 @@ class MaidHomePageState extends State { enableInteractiveSelection: true, onSubmitted: (value) { if (!model.busy) { - if (model.parameters["model_path"] + if (model.parameters["path"] .toString() .isEmpty) { _missingModelDialog(); @@ -311,7 +311,7 @@ class MaidHomePageState extends State { IconButton( onPressed: () { if (!model.busy) { - if (model.parameters["model_path"] + if (model.parameters["path"] .toString() .isEmpty) { _missingModelDialog(); diff --git a/lib/pages/model_page.dart b/lib/pages/model_page.dart index cb15b40..dcf4593 100644 --- a/lib/pages/model_page.dart +++ b/lib/pages/model_page.dart @@ -50,7 +50,7 @@ class _ModelPageState extends State { children: [ const SizedBox(height: 10.0), Text( - model.name, + model.preset, textAlign: TextAlign.center, style: Theme.of(context).textTheme.titleLarge, ), @@ -66,7 +66,7 @@ class _ModelPageState extends State { () async { MemoryManager.save(); model = Model(); - model.name= "New Preset"; + model.preset= "New Preset"; setState(() {}); } ); @@ -88,9 +88,9 @@ class _ModelPageState extends State { child: TextField( cursorColor: Theme.of(context).colorScheme.secondary, decoration: const InputDecoration( - labelText: "Name", + labelText: "Preset", ), - controller: TextEditingController(text: model.name), + controller: TextEditingController(text: model.preset), onSubmitted: (value) { if (MemoryManager.getModels().contains(value)) { MemoryManager.setModel(value); @@ -111,25 +111,91 @@ class _ModelPageState extends State { endIndent: 10, color: Theme.of(context).colorScheme.primary, ), + Text( + "Remote Model", + textAlign: TextAlign.center, + style: Theme.of(context).textTheme.titleSmall, + ), + const SizedBox(height: 20.0), + ListTile( + title: Row( + children: [ + const Expanded( + child: Text("Remote Model"), + ), + Expanded( + flex: 2, + child: TextField( + cursorColor: Theme.of(context).colorScheme.secondary, + decoration: const InputDecoration( + labelText: "Model", + ), + controller: TextEditingController(text: model.parameters["remote_model"]), + onSubmitted: (value) { + setState(() { + model.parameters["remote_model"] = value; + }); + }, + ), + ), + ], + ), + ), + const SizedBox(height: 8.0), + ListTile( + title: Row( + children: [ + const Expanded( + child: Text("Remote Tag"), + ), + Expanded( + flex: 2, + child: TextField( + cursorColor: Theme.of(context).colorScheme.secondary, + decoration: const InputDecoration( + labelText: "Tag", + ), + controller: TextEditingController(text: model.parameters["remote_tag"]), + onSubmitted: (value) { + setState(() { + model.parameters["remote_tag"] = value; + }); + }, + ), + ), + ], + ), + ), + const SizedBox(height: 20.0), + Divider( + height: 20, + indent: 10, + endIndent: 10, + color: Theme.of(context).colorScheme.primary, + ), + Text( + "Local Model", + textAlign: TextAlign.center, + style: Theme.of(context).textTheme.titleSmall, + ), + const SizedBox(height: 20.0), if (model.local) ..._localOptions(), DoubleButtonRow( - leftText: "Load Model", + leftText: "Load GGUF", leftOnPressed: () async { await storageOperationDialog(context, model.loadModelFile); - if (model.parameters["model_path"] != null) model.local = true; + if (model.parameters["path"] != null) model.local = true; setState(() {}); }, - rightText: "Unload Model", + rightText: "Unload GGUF", rightOnPressed: () { - model.parameters["model_path"] = null; - model.parameters["model_name"] = null; + model.parameters["path"] = null; model.local = false; setState(() {}); } ), - //if (model.remote) - // _remoteOptions(), + const SizedBox(height: 20.0), Divider( height: 20, indent: 10, @@ -444,19 +510,7 @@ class _ModelPageState extends State { Padding( padding: const EdgeInsets.all(8.0), child: Text( - "Model Path: ${model.parameters["model_path"]}", - ), - ), - const SizedBox(height: 15.0), - ]; - } - - List _remoteOptions() { - return [ - Padding( - padding: const EdgeInsets.all(8.0), - child: Text( - "Model Name: ${model.parameters["model_name"]}", + "Model Path: ${model.parameters["path"]}", ), ), const SizedBox(height: 15.0), diff --git a/lib/utilities/memory_manager.dart b/lib/utilities/memory_manager.dart index a40f8a7..e895a98 100644 --- a/lib/utilities/memory_manager.dart +++ b/lib/utilities/memory_manager.dart @@ -41,14 +41,14 @@ class MemoryManager { prefs.setBool("remote", GenerationManager.remote); - _models[model.name] = model.toMap(); - Logger.log("Model Saved: ${model.name}"); + _models[model.preset] = model.toMap(); + Logger.log("Model Saved: ${model.preset}"); _characters[character.name] = character.toMap(); Logger.log("Character Saved: ${character.name}"); prefs.setString("models", json.encode(_models)); prefs.setString("characters", json.encode(_characters)); - prefs.setString("current_model", model.name); + prefs.setString("current_model", model.preset); prefs.setString("current_character", character.name); LocalGeneration.instance.cleanup(); @@ -66,9 +66,9 @@ class MemoryManager { } static void updateModel(String newName) { - String oldName = model.name; + String oldName = model.preset; Logger.log("Updating model $oldName ====> $newName"); - model.name = newName; + model.preset = newName; _models.remove(oldName); save(); } @@ -108,7 +108,7 @@ class MemoryManager { static void setModel(String modelName) { save(); model = Model.fromMap(_models[modelName] ?? {}); - Logger.log("Model Set: ${model.name}"); + Logger.log("Model Set: ${model.preset}"); save(); } diff --git a/lib/utilities/model.dart b/lib/utilities/model.dart index 14ea5c8..dfeee69 100644 --- a/lib/utilities/model.dart +++ b/lib/utilities/model.dart @@ -12,11 +12,10 @@ import 'package:maid/utilities/memory_manager.dart'; Model model = Model(); class Model { - String name = "Default"; + String preset = "Default"; Map parameters = {}; - bool local = true; - bool remote = false; + bool local = false; bool busy = false; Model() { @@ -27,7 +26,7 @@ class Model { if (inputJson.isEmpty) { resetAll(); } else { - name = inputJson["name"] ?? "Default"; + preset = inputJson["preset"] ?? "Default"; parameters = inputJson; Logger.log("Model created with name: ${inputJson["name"]}"); } @@ -37,8 +36,9 @@ class Model { Map jsonModel = {}; jsonModel = parameters; - jsonModel["name"] = name; - Logger.log("Model JSON created with name: $name"); + jsonModel["preset"] = preset; + jsonModel["local"] = local; + Logger.log("Model JSON created with name: $preset"); return jsonModel; } @@ -49,18 +49,19 @@ class Model { await rootBundle.loadString('assets/default_parameters.json'); parameters = json.decode(jsonString); + local = false; MemoryManager.save(); } Future exportModelParameters(BuildContext context) async { try { - parameters["name"] = name; - parameters["remote"] = remote; + parameters["preset"] = preset; + parameters["local"] = local; String jsonString = json.encode(parameters); - File? file = await FileManager.save(context, name); + File? file = await FileManager.save(context, preset); if (file == null) return "Error saving file"; @@ -88,8 +89,8 @@ class Model { resetAll(); return "Failed to decode parameters"; } else { - remote = parameters["remote"] ?? false; - name = parameters["name"] ?? "Default"; + local = parameters["local"] ?? false; + preset = parameters["preset"] ?? "Default"; } } catch (e) { resetAll(); @@ -107,8 +108,7 @@ class Model { Logger.log("Loading model from $file"); - parameters["model_path"] = file.path; - parameters["model_name"] = path.basename(file.path); + parameters["path"] = file.path; } catch (e) { return "Error: $e"; }