From 70c808d6a232d71039e4ea71f0d58071da8e2089 Mon Sep 17 00:00:00 2001 From: kardolus Date: Fri, 16 Aug 2024 11:50:37 -0400 Subject: [PATCH] Add token usage tracking in query mode --- README.md | 2 ++ cmd/chatgpt/main.go | 16 ++++++++------ configmanager/configmanager_test.go | 33 ++++++++++++++++++++++++++++- integration/integration_test.go | 10 +++++++++ types/config.go | 1 + 5 files changed, 55 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 8f6e71f..c2e6032 100644 --- a/README.md +++ b/README.md @@ -205,6 +205,7 @@ Configuration variables: | `auth_token_prefix` | The prefix to be added before the token in the `auth_header`. | 'Bearer ' | | `command_prompt` | The command prompt in interactive mode. Should be single-quoted. | '[%datetime] [Q%counter]' | | `auto_create_new_thread` | If set to `true`, a new thread with a unique identifier (e.g., `int_a1b2`) will be created for each interactive session. If `false`, the CLI will use the thread specified by the `thread` parameter. | `false` | +| `track_token_usage` | If set to true, displays the total token usage after each query in --query mode, helping you monitor API usage. | `false` | #### Variables for interactive mode: @@ -285,6 +286,7 @@ auth_header: api-key auth_token_prefix: " " command_prompt: '[%datetime] [Q%counter]' auto_create_new_thread: false +track_token_usage: false ``` You can set the API key either in the `config.yaml` file as shown above or export it as an environment variable: diff --git a/cmd/chatgpt/main.go b/cmd/chatgpt/main.go index ebd4a11..652f034 100644 --- a/cmd/chatgpt/main.go +++ b/cmd/chatgpt/main.go @@ -211,17 +211,16 @@ func run(cmd *cobra.Command, args []string) error { } defer rl.Close() - prompt := func(counter int) string { - cm := configmanager.New(config.New()) - return config.FormatPrompt(cm.Config.CommandPrompt, counter, 0, time.Now()) + prompt := func(counter, usage int) string { + return config.FormatPrompt(client.Config.CommandPrompt, counter, usage, time.Now()) } qNum, usage := 1, 0 for { if queryMode { - rl.SetPrompt(prompt(usage)) + rl.SetPrompt(prompt(qNum, usage)) } else { - rl.SetPrompt(prompt(qNum)) + rl.SetPrompt(prompt(qNum, usage)) } line, err := rl.Readline() @@ -245,6 +244,7 @@ func run(cmd *cobra.Command, args []string) error { } else { fmt.Printf("%s\n\n", result) usage += qUsage + qNum++ } } else { if err := client.Stream(line); err != nil { @@ -260,11 +260,15 @@ func run(cmd *cobra.Command, args []string) error { return errors.New("you must specify your query") } if queryMode { - result, _, err := client.Query(strings.Join(args, " ")) + result, usage, err := client.Query(strings.Join(args, " ")) if err != nil { return err } fmt.Println(result) + + if client.Config.TrackTokenUsage { + fmt.Printf("\n[Token Usage: %d]\n", usage) + } } else { if err := client.Stream(strings.Join(args, " ")); err != nil { return err diff --git a/configmanager/configmanager_test.go b/configmanager/configmanager_test.go index f4191b3..0d68dde 100644 --- a/configmanager/configmanager_test.go +++ b/configmanager/configmanager_test.go @@ -37,6 +37,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { defaultAuthTokenPrefix = "default-auth-token-prefix" defaultOmitHistory = false defaultAutoCreateNewThread = false + defaultTrackTokenUsage = false defaultTemperature = 1.1 defaultTopP = 2.2 defaultFrequencyPenalty = 3.3 @@ -76,6 +77,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { PresencePenalty: defaultPresencePenalty, CommandPrompt: defaultCommandPrompt, AutoCreateNewThread: defaultAutoCreateNewThread, + TrackTokenUsage: defaultTrackTokenUsage, } envPrefix = strings.ToUpper(defaultConfig.Name) + "_" @@ -111,6 +113,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.PresencePenalty).To(Equal(defaultPresencePenalty)) Expect(subject.Config.CommandPrompt).To(Equal(defaultCommandPrompt)) Expect(subject.Config.AutoCreateNewThread).To(Equal(defaultAutoCreateNewThread)) + Expect(subject.Config.TrackTokenUsage).To(Equal(defaultTrackTokenUsage)) }) it("should prioritize user-provided config over defaults", func() { @@ -132,6 +135,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { PresencePenalty: 5.5, CommandPrompt: "user-command-prompt", AutoCreateNewThread: true, + TrackTokenUsage: true, } mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1) @@ -149,6 +153,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.AuthTokenPrefix).To(Equal("user-auth-token-prefix")) Expect(subject.Config.OmitHistory).To(BeTrue()) Expect(subject.Config.AutoCreateNewThread).To(BeTrue()) + Expect(subject.Config.TrackTokenUsage).To(BeTrue()) Expect(subject.Config.Role).To(Equal("user-role")) Expect(subject.Config.Thread).To(Equal("user-thread")) Expect(subject.Config.Temperature).To(Equal(2.5)) @@ -170,6 +175,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { os.Setenv(envPrefix+"AUTH_TOKEN_PREFIX", "env-auth-token-prefix") os.Setenv(envPrefix+"OMIT_HISTORY", "true") os.Setenv(envPrefix+"AUTO_CREATE_NEW_THREAD", "true") + os.Setenv(envPrefix+"TRACK_TOKEN_USAGE", "true") os.Setenv(envPrefix+"ROLE", "env-role") os.Setenv(envPrefix+"THREAD", "env-thread") os.Setenv(envPrefix+"TEMPERATURE", "2.2") @@ -194,6 +200,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.AuthTokenPrefix).To(Equal("env-auth-token-prefix")) Expect(subject.Config.OmitHistory).To(BeTrue()) Expect(subject.Config.AutoCreateNewThread).To(BeTrue()) + Expect(subject.Config.TrackTokenUsage).To(BeTrue()) Expect(subject.Config.Role).To(Equal("env-role")) Expect(subject.Config.Thread).To(Equal("env-thread")) Expect(subject.Config.Temperature).To(Equal(2.2)) @@ -215,6 +222,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { os.Setenv(envPrefix+"AUTH_TOKEN_PREFIX", "env-auth-token-prefix") os.Setenv(envPrefix+"OMIT_HISTORY", "true") os.Setenv(envPrefix+"AUTO_CREATE_NEW_THREAD", "true") + os.Setenv(envPrefix+"TRACK_TOKEN_USAGE", "true") os.Setenv(envPrefix+"ROLE", "env-role") os.Setenv(envPrefix+"THREAD", "env-thread") os.Setenv(envPrefix+"TEMPERATURE", "2.2") @@ -235,6 +243,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { AuthTokenPrefix: "user-auth-token-prefix", OmitHistory: false, AutoCreateNewThread: false, + TrackTokenUsage: false, Role: "user-role", Thread: "user-thread", Temperature: 1.5, @@ -260,6 +269,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) { Expect(subject.Config.AuthTokenPrefix).To(Equal("env-auth-token-prefix")) Expect(subject.Config.OmitHistory).To(BeTrue()) Expect(subject.Config.AutoCreateNewThread).To(BeTrue()) + Expect(subject.Config.TrackTokenUsage).To(BeTrue()) Expect(subject.Config.Role).To(Equal("env-role")) Expect(subject.Config.Thread).To(Equal("env-thread")) Expect(subject.Config.Temperature).To(Equal(2.2)) @@ -398,7 +408,28 @@ func setValue(config *types.Config, fieldName string, value interface{}) { } func unsetEnvironmentVariables(envPrefix string) { - variables := []string{"API_KEY", "MODEL", "MAX_TOKENS", "CONTEXT_WINDOW", "URL", "COMPLETIONS_PATH", "MODELS_PATH", "AUTH_HEADER", "AUTH_TOKEN_PREFIX", "OMIT_HISTORY", "ROLE", "THREAD", "TEMPERATURE", "TOP_P", "FREQUENCY_PENALTY", "PRESENCE_PENALTY", "COMMAND_PROMPT", "AUTO_CREATE_NEW_THREAD"} + variables := []string{ + "API_KEY", + "MODEL", + "MAX_TOKENS", + "CONTEXT_WINDOW", + "URL", + "COMPLETIONS_PATH", + "MODELS_PATH", + "AUTH_HEADER", + "AUTH_TOKEN_PREFIX", + "OMIT_HISTORY", + "ROLE", + "THREAD", + "TEMPERATURE", + "TOP_P", + "FREQUENCY_PENALTY", + "PRESENCE_PENALTY", + "COMMAND_PROMPT", + "AUTO_CREATE_NEW_THREAD", + "TRACK_TOKEN_USAGE", + } + for _, variable := range variables { Expect(os.Unsetenv(envPrefix + variable)).To(Succeed()) } diff --git a/integration/integration_test.go b/integration/integration_test.go index b69ee01..78b808d 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -471,10 +471,20 @@ max_tokens: 100 }) it("should return the expected result for the --query flag", func() { + Expect(os.Setenv("OPENAI_TRACK_TOKEN_USAGE", "false")).To(Succeed()) + output := runCommand("--query", "some-query") expectedResponse := `I don't have personal opinions about bars, but here are some popular bars in Red Hook, Brooklyn:` Expect(output).To(ContainSubstring(expectedResponse)) + Expect(output).NotTo(ContainSubstring("Token Usage:")) + }) + + it("should display token usage after a query when configured to do so", func() { + Expect(os.Setenv("OPENAI_TRACK_TOKEN_USAGE", "true")).To(Succeed()) + + output := runCommand("--query", "tell me a 5 line joke") + Expect(output).To(ContainSubstring("Token Usage:")) }) it("should assemble http errors as expected", func() { diff --git a/types/config.go b/types/config.go index 4600fe6..430bd15 100644 --- a/types/config.go +++ b/types/config.go @@ -20,4 +20,5 @@ type Config struct { AuthTokenPrefix string `yaml:"auth_token_prefix"` CommandPrompt string `yaml:"command_prompt"` AutoCreateNewThread bool `yaml:"auto_create_new_thread"` + TrackTokenUsage bool `yaml:"track_token_usage"` }