Add custom context parsing

This commit is contained in:
kardolus
2023-05-03 18:38:12 -04:00
parent dd791fbeb9
commit c5426d68af
4 changed files with 96 additions and 25 deletions

View File

@@ -13,6 +13,8 @@ environment, demonstrating its practicality and effectiveness.
history. history.
* Sliding window history management: Automatically trims conversation history while maintaining context to stay within * Sliding window history management: Automatically trims conversation history while maintaining context to stay within
token limits, enabling seamless and efficient conversations with the GPT model across CLI calls. token limits, enabling seamless and efficient conversations with the GPT model across CLI calls.
* Custom context from local files: Easily provide custom context through piping, allowing the GPT model to reference
specific data during the conversation.
* Viper integration for robust configuration management. * Viper integration for robust configuration management.
## Development ## Development
@@ -69,10 +71,6 @@ For more options, see:
./bin/chatgpt --help ./bin/chatgpt --help
``` ```
## Coming Soon
Enable piping custom context for seamless interaction with the ChatGPT API.
## Useful Links ## Useful Links
* [ChatGPT API Documentation](https://platform.openai.com/docs/introduction/overview) * [ChatGPT API Documentation](https://platform.openai.com/docs/introduction/overview)

View File

@@ -23,9 +23,9 @@ const (
) )
type Client struct { type Client struct {
History []types.Message
caller http.Caller caller http.Caller
readWriter history.Store readWriter history.Store
history []types.Message
capacity int capacity int
} }
@@ -52,7 +52,8 @@ func NewDefault(caller http.Caller, rw history.Store) *Client {
// call using the Post method. If the response is not empty, it decodes the // call using the Post method. If the response is not empty, it decodes the
// response JSON and returns the content of the first choice. // response JSON and returns the content of the first choice.
func (c *Client) Query(input string) (string, error) { func (c *Client) Query(input string) (string, error) {
c.initHistory(input) c.initHistory()
c.addQuery(input)
body, err := c.createBody(false) body, err := c.createBody(false)
if err != nil { if err != nil {
@@ -88,7 +89,8 @@ func (c *Client) Query(input string) (string, error) {
// input and then makes an API call using the Post method. The actual // input and then makes an API call using the Post method. The actual
// processing of the streamed response is done in the Post method. // processing of the streamed response is done in the Post method.
func (c *Client) Stream(input string) error { func (c *Client) Stream(input string) error {
c.initHistory(input) c.initHistory()
c.addQuery(input)
body, err := c.createBody(true) body, err := c.createBody(true)
if err != nil { if err != nil {
@@ -105,36 +107,55 @@ func (c *Client) Stream(input string) error {
return nil return nil
} }
// ProvideContext adds custom context to the client's history by converting the
// provided string into a series of messages. This allows the ChatGPT API to have
// prior knowledge of the provided context when generating responses.
//
// The context string should contain the text you want to provide as context,
// and the method will split it into messages, preserving punctuation and special
// characters.
func (c *Client) ProvideContext(context string) {
c.initHistory()
messages := createMessagesFromString(context)
c.History = append(c.History, messages...)
}
func (c *Client) createBody(stream bool) ([]byte, error) { func (c *Client) createBody(stream bool) ([]byte, error) {
body := types.Request{ body := types.Request{
Model: GPTModel, Model: GPTModel,
Messages: c.history, Messages: c.History,
Stream: stream, Stream: stream,
} }
return json.Marshal(body) return json.Marshal(body)
} }
func (c *Client) initHistory(query string) { func (c *Client) initHistory() {
if len(c.History) != 0 {
return
}
c.History, _ = c.readWriter.Read()
if len(c.History) == 0 {
c.History = []types.Message{{
Role: SystemRole,
Content: AssistantContent,
}}
}
}
func (c *Client) addQuery(query string) {
message := types.Message{ message := types.Message{
Role: UserRole, Role: UserRole,
Content: query, Content: query,
} }
c.history, _ = c.readWriter.Read() c.History = append(c.History, message)
if len(c.history) == 0 {
c.history = []types.Message{{
Role: SystemRole,
Content: AssistantContent,
}}
}
c.history = append(c.history, message)
c.truncateHistory() c.truncateHistory()
} }
func (c *Client) truncateHistory() { func (c *Client) truncateHistory() {
tokens, rolling := countTokens(c.history) tokens, rolling := countTokens(c.History)
effectiveTokenSize := calculateEffectiveTokenSize(c.capacity, MaxTokenBufferPercentage) effectiveTokenSize := calculateEffectiveTokenSize(c.capacity, MaxTokenBufferPercentage)
if tokens <= effectiveTokenSize { if tokens <= effectiveTokenSize {
@@ -153,15 +174,15 @@ func (c *Client) truncateHistory() {
} }
} }
c.history = append(c.history[:1], c.history[index+1:]...) c.History = append(c.History[:1], c.History[index+1:]...)
} }
func (c *Client) updateHistory(response string) { func (c *Client) updateHistory(response string) {
c.history = append(c.history, types.Message{ c.History = append(c.History, types.Message{
Role: AssistantRole, Role: AssistantRole,
Content: response, Content: response,
}) })
_ = c.readWriter.Write(c.history) _ = c.readWriter.Write(c.History)
} }
func calculateEffectiveTokenSize(maxTokenSize int, bufferPercentage int) int { func calculateEffectiveTokenSize(maxTokenSize int, bufferPercentage int) int {
@@ -192,3 +213,25 @@ func countTokens(messages []types.Message) (int, []int) {
return result, rolling return result, rolling
} }
func createMessagesFromString(input string) []types.Message {
words := strings.Fields(input)
var messages []types.Message
for i := 0; i < len(words); i += 100 {
end := i + 100
if end > len(words) {
end = len(words)
}
content := strings.Join(words[i:end], " ")
message := types.Message{
Role: UserRole,
Content: content,
}
messages = append(messages, message)
}
return messages
}

View File

@@ -278,6 +278,23 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
}) })
}) })
}) })
when("ProvideContext()", func() {
it("updates the history with the provided context", func() {
context := "This is a story about a dog named Kya. Kya loves to play fetch and swim in the lake."
mockStore.EXPECT().Read().Return(nil, nil).Times(1)
subject.ProvideContext(context)
Expect(len(subject.History)).To(Equal(2)) // The system message and the provided context
systemMessage := subject.History[0]
Expect(systemMessage.Role).To(Equal(client.SystemRole))
Expect(systemMessage.Content).To(Equal("You are a helpful assistant."))
contextMessage := subject.History[1]
Expect(contextMessage.Role).To(Equal(client.UserRole))
Expect(contextMessage.Content).To(Equal(context))
})
})
} }
func createBody(messages []types.Message, stream bool) ([]byte, error) { func createBody(messages []types.Message, stream bool) ([]byte, error) {

View File

@@ -8,6 +8,7 @@ import (
"github.com/kardolus/chatgpt-cli/http" "github.com/kardolus/chatgpt-cli/http"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper" "github.com/spf13/viper"
"io/ioutil"
"os" "os"
"strings" "strings"
) )
@@ -20,9 +21,11 @@ var clearHistory bool
func main() { func main() {
var rootCmd = &cobra.Command{ var rootCmd = &cobra.Command{
Use: "chatgpt", Use: "chatgpt",
Short: "ChatGPT Proof of Concept", Short: "ChatGPT CLI Tool",
Long: "A Proof of Concept for building ChatGPT clients.", Long: "A powerful ChatGPT client that enables seamless interactions with the GPT model. " +
RunE: run, "Provides multiple modes and context management features, including the ability to " +
"pipe custom context into the conversation.",
RunE: run,
} }
rootCmd.PersistentFlags().BoolVarP(&queryMode, "query", "q", false, "Use query mode instead of stream mode") rootCmd.PersistentFlags().BoolVarP(&queryMode, "query", "q", false, "Use query mode instead of stream mode")
@@ -60,6 +63,16 @@ func run(cmd *cobra.Command, args []string) error {
} }
client := client.NewDefault(http.New().WithSecret(secret), history.NewDefault()) client := client.NewDefault(http.New().WithSecret(secret), history.NewDefault())
// Check if there is input from the pipe (stdin)
stat, _ := os.Stdin.Stat()
if (stat.Mode() & os.ModeCharDevice) == 0 {
pipeContent, err := ioutil.ReadAll(os.Stdin)
if err != nil {
return fmt.Errorf("failed to read from pipe: %w", err)
}
client.ProvideContext(string(pipeContent))
}
if queryMode { if queryMode {
result, err := client.Query(strings.Join(args, " ")) result, err := client.Query(strings.Join(args, " "))
if err != nil { if err != nil {