Preserve comments in config.yaml

This commit is contained in:
Guillermo Kardolus
2024-04-01 11:57:36 -04:00
parent 18c42277bc
commit 12b56868d7
2 changed files with 153 additions and 2 deletions

View File

@@ -6,6 +6,8 @@ import (
"gopkg.in/yaml.v3"
"os"
"path/filepath"
"reflect"
"strconv"
)
const (
@@ -106,12 +108,37 @@ func (f *FileIO) ReadDefaults() types.Config {
}
func (f *FileIO) Write(config types.Config) error {
data, err := yaml.Marshal(config)
rootNode, err := f.readNode()
// If readNode returns an error or there was a problem reading the rootNode, initialize a new rootNode.
if err != nil || rootNode.Kind == 0 {
rootNode = yaml.Node{Kind: yaml.DocumentNode}
rootNode.Content = append(rootNode.Content, &yaml.Node{Kind: yaml.MappingNode})
}
updateNodeFromConfig(&rootNode, config)
modifiedContent, err := yaml.Marshal(&rootNode)
if err != nil {
return err
}
return os.WriteFile(f.configFilePath, data, 0644)
return os.WriteFile(f.configFilePath, modifiedContent, 0644)
}
func (f *FileIO) readNode() (yaml.Node, error) {
var rootNode yaml.Node
content, err := os.ReadFile(f.configFilePath)
if err != nil {
return rootNode, err
}
if err := yaml.Unmarshal(content, &rootNode); err != nil {
return rootNode, err
}
return rootNode, nil
}
func getPath() (string, error) {
@@ -150,3 +177,65 @@ func parseFile(fileName string) (types.Config, error) {
return result, nil
}
// updateNodeFromConfig updates the specified yaml.Node with values from the Config struct.
// It uses reflection to match struct fields with YAML tags, updating the node accordingly.
func updateNodeFromConfig(node *yaml.Node, config types.Config) {
t := reflect.TypeOf(config)
v := reflect.ValueOf(config)
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
value := v.Field(i)
yamlTag := field.Tag.Get("yaml")
if yamlTag == "" || yamlTag == "-" {
continue // Skip fields without yaml tag or marked to be ignored
}
// Convert value to string; adjust for different data types as needed
var strValue string
switch value.Kind() {
case reflect.String:
strValue = value.String()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
strValue = strconv.FormatInt(value.Int(), 10)
case reflect.Float32, reflect.Float64:
strValue = strconv.FormatFloat(value.Float(), 'f', -1, 64)
case reflect.Bool:
strValue = strconv.FormatBool(value.Bool())
default:
continue // Skip unsupported types for simplicity
}
setField(node, yamlTag, strValue)
}
}
// setField either updates an existing field or adds a new field to the YAML mapping node.
// It assumes the root node is a DocumentNode containing a MappingNode.
func setField(root *yaml.Node, key string, newValue string) {
found := false
if root.Kind == yaml.DocumentNode {
root = root.Content[0] // Move from document node to the actual mapping node.
}
if root.Kind != yaml.MappingNode {
return // If the root is not a mapping node, we can't do anything.
}
for i := 0; i < len(root.Content); i += 2 {
keyNode := root.Content[i]
if keyNode.Value == key {
valueNode := root.Content[i+1]
valueNode.Value = newValue
found = true
break
}
}
if !found { // If the key wasn't found, add it.
root.Content = append(root.Content, &yaml.Node{Kind: yaml.ScalarNode, Value: key}, &yaml.Node{Kind: yaml.ScalarNode, Value: newValue})
}
}

View File

@@ -184,6 +184,40 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
Expect(readConfig).To(Equal(expectedConfig))
})
it("preserves comments when writing the config to the file", func() {
configFile := filepath.Join(tmpDir, "config_with_comments.yaml")
err := os.WriteFile(configFile, []byte(`
# This is a model configuration
model: gpt-3.5-turbo # Default model
# Maximum number of tokens
max_tokens: 100
`), 0644)
Expect(err).NotTo(HaveOccurred())
configIO = config.New().WithConfigPath(configFile)
// Read the existing configuration, modify a value, and write it back
readConfig, err := configIO.Read()
Expect(err).NotTo(HaveOccurred())
// Modify the configuration
readConfig.Model = "new-model"
err = configIO.Write(readConfig)
Expect(err).NotTo(HaveOccurred())
// Read the file content back
content, err := os.ReadFile(configFile)
Expect(err).NotTo(HaveOccurred())
// Check if the comments are still present
Expect(string(content)).To(ContainSubstring("# This is a model configuration"))
Expect(string(content)).To(ContainSubstring("# Default model"))
Expect(string(content)).To(ContainSubstring("# Maximum number of tokens"))
// Verify the configuration was updated
Expect(string(content)).To(ContainSubstring("model: new-model"))
})
})
it("lists all the threads", func() {
@@ -624,6 +658,34 @@ func testIntegration(t *testing.T, when spec.G, it spec.S) {
Expect(output).To(ContainSubstring("* " + newModel + " (current)"))
})
it("has a configurable default context-window", func() {
defaults := config.New().ReadDefaults()
// Initial check for default context-window
output := runCommand("--config")
Expect(output).To(ContainSubstring(strconv.Itoa(defaults.ContextWindow)))
// Update and verify context-window
newContextWindow := "100000"
runCommand("--set-context-window", newContextWindow)
Expect(configFile).To(BeAnExistingFile())
checkConfigFileContent(newContextWindow)
// Verify update through --config
output = runCommand("--config")
Expect(output).To(ContainSubstring(newContextWindow))
// Environment variable takes precedence
envContext := "123"
modelEnvKey := strings.Replace(apiKeyEnvVar, "API_KEY", "CONTEXT_WINDOW", 1)
Expect(os.Setenv(modelEnvKey, envContext)).To(Succeed())
// Verify environment variable override
output = runCommand("--config")
Expect(output).To(ContainSubstring(envContext))
Expect(os.Unsetenv(modelEnvKey)).To(Succeed())
})
it("has a configurable default max-tokens", func() {
defaults := config.New().ReadDefaults()