Add token usage tracking in query mode

This commit is contained in:
kardolus
2024-08-16 11:50:37 -04:00
parent 4338c0e07d
commit 70c808d6a2
5 changed files with 55 additions and 7 deletions

View File

@@ -205,6 +205,7 @@ Configuration variables:
| `auth_token_prefix` | The prefix to be added before the token in the `auth_header`. | 'Bearer ' | | `auth_token_prefix` | The prefix to be added before the token in the `auth_header`. | 'Bearer ' |
| `command_prompt` | The command prompt in interactive mode. Should be single-quoted. | '[%datetime] [Q%counter]' | | `command_prompt` | The command prompt in interactive mode. Should be single-quoted. | '[%datetime] [Q%counter]' |
| `auto_create_new_thread` | If set to `true`, a new thread with a unique identifier (e.g., `int_a1b2`) will be created for each interactive session. If `false`, the CLI will use the thread specified by the `thread` parameter. | `false` | | `auto_create_new_thread` | If set to `true`, a new thread with a unique identifier (e.g., `int_a1b2`) will be created for each interactive session. If `false`, the CLI will use the thread specified by the `thread` parameter. | `false` |
| `track_token_usage` | If set to `true`, displays the total token usage after each query in `--query` mode, helping you monitor API usage. | `false` |
#### Variables for interactive mode: #### Variables for interactive mode:
@@ -285,6 +286,7 @@ auth_header: api-key
auth_token_prefix: " " auth_token_prefix: " "
command_prompt: '[%datetime] [Q%counter]' command_prompt: '[%datetime] [Q%counter]'
auto_create_new_thread: false auto_create_new_thread: false
track_token_usage: false
``` ```
You can set the API key either in the `config.yaml` file as shown above or export it as an environment variable: You can set the API key either in the `config.yaml` file as shown above or export it as an environment variable:

View File

@@ -211,17 +211,16 @@ func run(cmd *cobra.Command, args []string) error {
} }
defer rl.Close() defer rl.Close()
prompt := func(counter int) string { prompt := func(counter, usage int) string {
cm := configmanager.New(config.New()) return config.FormatPrompt(client.Config.CommandPrompt, counter, usage, time.Now())
return config.FormatPrompt(cm.Config.CommandPrompt, counter, 0, time.Now())
} }
qNum, usage := 1, 0 qNum, usage := 1, 0
for { for {
if queryMode { if queryMode {
rl.SetPrompt(prompt(usage)) rl.SetPrompt(prompt(qNum, usage))
} else { } else {
rl.SetPrompt(prompt(qNum)) rl.SetPrompt(prompt(qNum, usage))
} }
line, err := rl.Readline() line, err := rl.Readline()
@@ -245,6 +244,7 @@ func run(cmd *cobra.Command, args []string) error {
} else { } else {
fmt.Printf("%s\n\n", result) fmt.Printf("%s\n\n", result)
usage += qUsage usage += qUsage
qNum++
} }
} else { } else {
if err := client.Stream(line); err != nil { if err := client.Stream(line); err != nil {
@@ -260,11 +260,15 @@ func run(cmd *cobra.Command, args []string) error {
return errors.New("you must specify your query") return errors.New("you must specify your query")
} }
if queryMode { if queryMode {
result, _, err := client.Query(strings.Join(args, " ")) result, usage, err := client.Query(strings.Join(args, " "))
if err != nil { if err != nil {
return err return err
} }
fmt.Println(result) fmt.Println(result)
if client.Config.TrackTokenUsage {
fmt.Printf("\n[Token Usage: %d]\n", usage)
}
} else { } else {
if err := client.Stream(strings.Join(args, " ")); err != nil { if err := client.Stream(strings.Join(args, " ")); err != nil {
return err return err

View File

@@ -37,6 +37,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
defaultAuthTokenPrefix = "default-auth-token-prefix" defaultAuthTokenPrefix = "default-auth-token-prefix"
defaultOmitHistory = false defaultOmitHistory = false
defaultAutoCreateNewThread = false defaultAutoCreateNewThread = false
defaultTrackTokenUsage = false
defaultTemperature = 1.1 defaultTemperature = 1.1
defaultTopP = 2.2 defaultTopP = 2.2
defaultFrequencyPenalty = 3.3 defaultFrequencyPenalty = 3.3
@@ -76,6 +77,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
PresencePenalty: defaultPresencePenalty, PresencePenalty: defaultPresencePenalty,
CommandPrompt: defaultCommandPrompt, CommandPrompt: defaultCommandPrompt,
AutoCreateNewThread: defaultAutoCreateNewThread, AutoCreateNewThread: defaultAutoCreateNewThread,
TrackTokenUsage: defaultTrackTokenUsage,
} }
envPrefix = strings.ToUpper(defaultConfig.Name) + "_" envPrefix = strings.ToUpper(defaultConfig.Name) + "_"
@@ -111,6 +113,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.PresencePenalty).To(Equal(defaultPresencePenalty)) Expect(subject.Config.PresencePenalty).To(Equal(defaultPresencePenalty))
Expect(subject.Config.CommandPrompt).To(Equal(defaultCommandPrompt)) Expect(subject.Config.CommandPrompt).To(Equal(defaultCommandPrompt))
Expect(subject.Config.AutoCreateNewThread).To(Equal(defaultAutoCreateNewThread)) Expect(subject.Config.AutoCreateNewThread).To(Equal(defaultAutoCreateNewThread))
Expect(subject.Config.TrackTokenUsage).To(Equal(defaultTrackTokenUsage))
}) })
it("should prioritize user-provided config over defaults", func() { it("should prioritize user-provided config over defaults", func() {
@@ -132,6 +135,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
PresencePenalty: 5.5, PresencePenalty: 5.5,
CommandPrompt: "user-command-prompt", CommandPrompt: "user-command-prompt",
AutoCreateNewThread: true, AutoCreateNewThread: true,
TrackTokenUsage: true,
} }
mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1) mockConfigStore.EXPECT().ReadDefaults().Return(defaultConfig).Times(1)
@@ -149,6 +153,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.AuthTokenPrefix).To(Equal("user-auth-token-prefix")) Expect(subject.Config.AuthTokenPrefix).To(Equal("user-auth-token-prefix"))
Expect(subject.Config.OmitHistory).To(BeTrue()) Expect(subject.Config.OmitHistory).To(BeTrue())
Expect(subject.Config.AutoCreateNewThread).To(BeTrue()) Expect(subject.Config.AutoCreateNewThread).To(BeTrue())
Expect(subject.Config.TrackTokenUsage).To(BeTrue())
Expect(subject.Config.Role).To(Equal("user-role")) Expect(subject.Config.Role).To(Equal("user-role"))
Expect(subject.Config.Thread).To(Equal("user-thread")) Expect(subject.Config.Thread).To(Equal("user-thread"))
Expect(subject.Config.Temperature).To(Equal(2.5)) Expect(subject.Config.Temperature).To(Equal(2.5))
@@ -170,6 +175,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
os.Setenv(envPrefix+"AUTH_TOKEN_PREFIX", "env-auth-token-prefix") os.Setenv(envPrefix+"AUTH_TOKEN_PREFIX", "env-auth-token-prefix")
os.Setenv(envPrefix+"OMIT_HISTORY", "true") os.Setenv(envPrefix+"OMIT_HISTORY", "true")
os.Setenv(envPrefix+"AUTO_CREATE_NEW_THREAD", "true") os.Setenv(envPrefix+"AUTO_CREATE_NEW_THREAD", "true")
os.Setenv(envPrefix+"TRACK_TOKEN_USAGE", "true")
os.Setenv(envPrefix+"ROLE", "env-role") os.Setenv(envPrefix+"ROLE", "env-role")
os.Setenv(envPrefix+"THREAD", "env-thread") os.Setenv(envPrefix+"THREAD", "env-thread")
os.Setenv(envPrefix+"TEMPERATURE", "2.2") os.Setenv(envPrefix+"TEMPERATURE", "2.2")
@@ -194,6 +200,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.AuthTokenPrefix).To(Equal("env-auth-token-prefix")) Expect(subject.Config.AuthTokenPrefix).To(Equal("env-auth-token-prefix"))
Expect(subject.Config.OmitHistory).To(BeTrue()) Expect(subject.Config.OmitHistory).To(BeTrue())
Expect(subject.Config.AutoCreateNewThread).To(BeTrue()) Expect(subject.Config.AutoCreateNewThread).To(BeTrue())
Expect(subject.Config.TrackTokenUsage).To(BeTrue())
Expect(subject.Config.Role).To(Equal("env-role")) Expect(subject.Config.Role).To(Equal("env-role"))
Expect(subject.Config.Thread).To(Equal("env-thread")) Expect(subject.Config.Thread).To(Equal("env-thread"))
Expect(subject.Config.Temperature).To(Equal(2.2)) Expect(subject.Config.Temperature).To(Equal(2.2))
@@ -215,6 +222,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
os.Setenv(envPrefix+"AUTH_TOKEN_PREFIX", "env-auth-token-prefix") os.Setenv(envPrefix+"AUTH_TOKEN_PREFIX", "env-auth-token-prefix")
os.Setenv(envPrefix+"OMIT_HISTORY", "true") os.Setenv(envPrefix+"OMIT_HISTORY", "true")
os.Setenv(envPrefix+"AUTO_CREATE_NEW_THREAD", "true") os.Setenv(envPrefix+"AUTO_CREATE_NEW_THREAD", "true")
os.Setenv(envPrefix+"TRACK_TOKEN_USAGE", "true")
os.Setenv(envPrefix+"ROLE", "env-role") os.Setenv(envPrefix+"ROLE", "env-role")
os.Setenv(envPrefix+"THREAD", "env-thread") os.Setenv(envPrefix+"THREAD", "env-thread")
os.Setenv(envPrefix+"TEMPERATURE", "2.2") os.Setenv(envPrefix+"TEMPERATURE", "2.2")
@@ -235,6 +243,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
AuthTokenPrefix: "user-auth-token-prefix", AuthTokenPrefix: "user-auth-token-prefix",
OmitHistory: false, OmitHistory: false,
AutoCreateNewThread: false, AutoCreateNewThread: false,
TrackTokenUsage: false,
Role: "user-role", Role: "user-role",
Thread: "user-thread", Thread: "user-thread",
Temperature: 1.5, Temperature: 1.5,
@@ -260,6 +269,7 @@ func testConfig(t *testing.T, when spec.G, it spec.S) {
Expect(subject.Config.AuthTokenPrefix).To(Equal("env-auth-token-prefix")) Expect(subject.Config.AuthTokenPrefix).To(Equal("env-auth-token-prefix"))
Expect(subject.Config.OmitHistory).To(BeTrue()) Expect(subject.Config.OmitHistory).To(BeTrue())
Expect(subject.Config.AutoCreateNewThread).To(BeTrue()) Expect(subject.Config.AutoCreateNewThread).To(BeTrue())
Expect(subject.Config.TrackTokenUsage).To(BeTrue())
Expect(subject.Config.Role).To(Equal("env-role")) Expect(subject.Config.Role).To(Equal("env-role"))
Expect(subject.Config.Thread).To(Equal("env-thread")) Expect(subject.Config.Thread).To(Equal("env-thread"))
Expect(subject.Config.Temperature).To(Equal(2.2)) Expect(subject.Config.Temperature).To(Equal(2.2))
@@ -398,7 +408,28 @@ func setValue(config *types.Config, fieldName string, value interface{}) {
} }
func unsetEnvironmentVariables(envPrefix string) { func unsetEnvironmentVariables(envPrefix string) {
variables := []string{"API_KEY", "MODEL", "MAX_TOKENS", "CONTEXT_WINDOW", "URL", "COMPLETIONS_PATH", "MODELS_PATH", "AUTH_HEADER", "AUTH_TOKEN_PREFIX", "OMIT_HISTORY", "ROLE", "THREAD", "TEMPERATURE", "TOP_P", "FREQUENCY_PENALTY", "PRESENCE_PENALTY", "COMMAND_PROMPT", "AUTO_CREATE_NEW_THREAD"} variables := []string{
"API_KEY",
"MODEL",
"MAX_TOKENS",
"CONTEXT_WINDOW",
"URL",
"COMPLETIONS_PATH",
"MODELS_PATH",
"AUTH_HEADER",
"AUTH_TOKEN_PREFIX",
"OMIT_HISTORY",
"ROLE",
"THREAD",
"TEMPERATURE",
"TOP_P",
"FREQUENCY_PENALTY",
"PRESENCE_PENALTY",
"COMMAND_PROMPT",
"AUTO_CREATE_NEW_THREAD",
"TRACK_TOKEN_USAGE",
}
for _, variable := range variables { for _, variable := range variables {
Expect(os.Unsetenv(envPrefix + variable)).To(Succeed()) Expect(os.Unsetenv(envPrefix + variable)).To(Succeed())
} }

View File

@@ -471,10 +471,20 @@ max_tokens: 100
}) })
it("should return the expected result for the --query flag", func() { it("should return the expected result for the --query flag", func() {
Expect(os.Setenv("OPENAI_TRACK_TOKEN_USAGE", "false")).To(Succeed())
output := runCommand("--query", "some-query") output := runCommand("--query", "some-query")
expectedResponse := `I don't have personal opinions about bars, but here are some popular bars in Red Hook, Brooklyn:` expectedResponse := `I don't have personal opinions about bars, but here are some popular bars in Red Hook, Brooklyn:`
Expect(output).To(ContainSubstring(expectedResponse)) Expect(output).To(ContainSubstring(expectedResponse))
Expect(output).NotTo(ContainSubstring("Token Usage:"))
})
it("should display token usage after a query when configured to do so", func() {
Expect(os.Setenv("OPENAI_TRACK_TOKEN_USAGE", "true")).To(Succeed())
output := runCommand("--query", "tell me a 5 line joke")
Expect(output).To(ContainSubstring("Token Usage:"))
}) })
it("should assemble http errors as expected", func() { it("should assemble http errors as expected", func() {

View File

@@ -20,4 +20,5 @@ type Config struct {
AuthTokenPrefix string `yaml:"auth_token_prefix"` AuthTokenPrefix string `yaml:"auth_token_prefix"`
CommandPrompt string `yaml:"command_prompt"` CommandPrompt string `yaml:"command_prompt"`
AutoCreateNewThread bool `yaml:"auto_create_new_thread"` AutoCreateNewThread bool `yaml:"auto_create_new_thread"`
TrackTokenUsage bool `yaml:"track_token_usage"`
} }