mirror of
				https://github.com/kardolus/chatgpt-cli.git
				synced 2024-09-08 23:15:00 +03:00 
			
		
		
		
	Preserve comments in config.yaml
This commit is contained in:
		| @@ -6,6 +6,8 @@ import ( | ||||
| 	"gopkg.in/yaml.v3" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"reflect" | ||||
| 	"strconv" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| @@ -106,12 +108,37 @@ func (f *FileIO) ReadDefaults() types.Config { | ||||
| } | ||||
|  | ||||
| func (f *FileIO) Write(config types.Config) error { | ||||
| 	data, err := yaml.Marshal(config) | ||||
| 	rootNode, err := f.readNode() | ||||
|  | ||||
| 	// If readNode returns an error or there was a problem reading the rootNode, initialize a new rootNode. | ||||
| 	if err != nil || rootNode.Kind == 0 { | ||||
| 		rootNode = yaml.Node{Kind: yaml.DocumentNode} | ||||
| 		rootNode.Content = append(rootNode.Content, &yaml.Node{Kind: yaml.MappingNode}) | ||||
| 	} | ||||
|  | ||||
| 	updateNodeFromConfig(&rootNode, config) | ||||
|  | ||||
| 	modifiedContent, err := yaml.Marshal(&rootNode) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return os.WriteFile(f.configFilePath, data, 0644) | ||||
| 	return os.WriteFile(f.configFilePath, modifiedContent, 0644) | ||||
| } | ||||
|  | ||||
| func (f *FileIO) readNode() (yaml.Node, error) { | ||||
| 	var rootNode yaml.Node | ||||
|  | ||||
| 	content, err := os.ReadFile(f.configFilePath) | ||||
| 	if err != nil { | ||||
| 		return rootNode, err | ||||
| 	} | ||||
|  | ||||
| 	if err := yaml.Unmarshal(content, &rootNode); err != nil { | ||||
| 		return rootNode, err | ||||
| 	} | ||||
|  | ||||
| 	return rootNode, nil | ||||
| } | ||||
|  | ||||
| func getPath() (string, error) { | ||||
| @@ -150,3 +177,65 @@ func parseFile(fileName string) (types.Config, error) { | ||||
|  | ||||
| 	return result, nil | ||||
| } | ||||
|  | ||||
| // updateNodeFromConfig updates the specified yaml.Node with values from the Config struct. | ||||
| // It uses reflection to match struct fields with YAML tags, updating the node accordingly. | ||||
| func updateNodeFromConfig(node *yaml.Node, config types.Config) { | ||||
| 	t := reflect.TypeOf(config) | ||||
| 	v := reflect.ValueOf(config) | ||||
|  | ||||
| 	for i := 0; i < t.NumField(); i++ { | ||||
| 		field := t.Field(i) | ||||
| 		value := v.Field(i) | ||||
| 		yamlTag := field.Tag.Get("yaml") | ||||
|  | ||||
| 		if yamlTag == "" || yamlTag == "-" { | ||||
| 			continue // Skip fields without yaml tag or marked to be ignored | ||||
| 		} | ||||
|  | ||||
| 		// Convert value to string; adjust for different data types as needed | ||||
| 		var strValue string | ||||
| 		switch value.Kind() { | ||||
| 		case reflect.String: | ||||
| 			strValue = value.String() | ||||
| 		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||||
| 			strValue = strconv.FormatInt(value.Int(), 10) | ||||
| 		case reflect.Float32, reflect.Float64: | ||||
| 			strValue = strconv.FormatFloat(value.Float(), 'f', -1, 64) | ||||
| 		case reflect.Bool: | ||||
| 			strValue = strconv.FormatBool(value.Bool()) | ||||
| 		default: | ||||
| 			continue // Skip unsupported types for simplicity | ||||
| 		} | ||||
|  | ||||
| 		setField(node, yamlTag, strValue) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // setField either updates an existing field or adds a new field to the YAML mapping node. | ||||
| // It assumes the root node is a DocumentNode containing a MappingNode. | ||||
| func setField(root *yaml.Node, key string, newValue string) { | ||||
| 	found := false | ||||
|  | ||||
| 	if root.Kind == yaml.DocumentNode { | ||||
| 		root = root.Content[0] // Move from document node to the actual mapping node. | ||||
| 	} | ||||
|  | ||||
| 	if root.Kind != yaml.MappingNode { | ||||
| 		return // If the root is not a mapping node, we can't do anything. | ||||
| 	} | ||||
|  | ||||
| 	for i := 0; i < len(root.Content); i += 2 { | ||||
| 		keyNode := root.Content[i] | ||||
| 		if keyNode.Value == key { | ||||
| 			valueNode := root.Content[i+1] | ||||
| 			valueNode.Value = newValue | ||||
| 			found = true | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if !found { // If the key wasn't found, add it. | ||||
| 		root.Content = append(root.Content, &yaml.Node{Kind: yaml.ScalarNode, Value: key}, &yaml.Node{Kind: yaml.ScalarNode, Value: newValue}) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -184,6 +184,40 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) { | ||||
|  | ||||
| 				Expect(readConfig).To(Equal(expectedConfig)) | ||||
| 			}) | ||||
| 			it("preserves comments when writing the config to the file", func() { | ||||
| 				configFile := filepath.Join(tmpDir, "config_with_comments.yaml") | ||||
| 				err := os.WriteFile(configFile, []byte(` | ||||
| # This is a model configuration | ||||
| model: gpt-3.5-turbo # Default model | ||||
| # Maximum number of tokens | ||||
| max_tokens: 100 | ||||
| `), 0644) | ||||
| 				Expect(err).NotTo(HaveOccurred()) | ||||
|  | ||||
| 				configIO = config.New().WithConfigPath(configFile) | ||||
|  | ||||
| 				// Read the existing configuration, modify a value, and write it back | ||||
| 				readConfig, err := configIO.Read() | ||||
| 				Expect(err).NotTo(HaveOccurred()) | ||||
|  | ||||
| 				// Modify the configuration | ||||
| 				readConfig.Model = "new-model" | ||||
| 				err = configIO.Write(readConfig) | ||||
| 				Expect(err).NotTo(HaveOccurred()) | ||||
|  | ||||
| 				// Read the file content back | ||||
| 				content, err := os.ReadFile(configFile) | ||||
| 				Expect(err).NotTo(HaveOccurred()) | ||||
|  | ||||
| 				// Check if the comments are still present | ||||
| 				Expect(string(content)).To(ContainSubstring("# This is a model configuration")) | ||||
| 				Expect(string(content)).To(ContainSubstring("# Default model")) | ||||
| 				Expect(string(content)).To(ContainSubstring("# Maximum number of tokens")) | ||||
|  | ||||
| 				// Verify the configuration was updated | ||||
| 				Expect(string(content)).To(ContainSubstring("model: new-model")) | ||||
| 			}) | ||||
|  | ||||
| 		}) | ||||
|  | ||||
| 		it("lists all the threads", func() { | ||||
| @@ -624,6 +658,34 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) { | ||||
| 					Expect(output).To(ContainSubstring("* " + newModel + " (current)")) | ||||
| 				}) | ||||
|  | ||||
| 				it("has a configurable default context-window", func() { | ||||
| 					defaults := config.New().ReadDefaults() | ||||
|  | ||||
| 					// Initial check for default context-window | ||||
| 					output := runCommand("--config") | ||||
| 					Expect(output).To(ContainSubstring(strconv.Itoa(defaults.ContextWindow))) | ||||
|  | ||||
| 					// Update and verify context-window | ||||
| 					newContextWindow := "100000" | ||||
| 					runCommand("--set-context-window", newContextWindow) | ||||
| 					Expect(configFile).To(BeAnExistingFile()) | ||||
| 					checkConfigFileContent(newContextWindow) | ||||
|  | ||||
| 					// Verify update through --config | ||||
| 					output = runCommand("--config") | ||||
| 					Expect(output).To(ContainSubstring(newContextWindow)) | ||||
|  | ||||
| 					// Environment variable takes precedence | ||||
| 					envContext := "123" | ||||
| 					modelEnvKey := strings.Replace(apiKeyEnvVar, "API_KEY", "CONTEXT_WINDOW", 1) | ||||
| 					Expect(os.Setenv(modelEnvKey, envContext)).To(Succeed()) | ||||
|  | ||||
| 					// Verify environment variable override | ||||
| 					output = runCommand("--config") | ||||
| 					Expect(output).To(ContainSubstring(envContext)) | ||||
| 					Expect(os.Unsetenv(modelEnvKey)).To(Succeed()) | ||||
| 				}) | ||||
|  | ||||
| 				it("has a configurable default max-tokens", func() { | ||||
| 					defaults := config.New().ReadDefaults() | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Guillermo Kardolus
					Guillermo Kardolus