moving cli/server/shared to app, working on long tasks and more robust builds

This commit is contained in:
Dane Schneider
2024-01-22 15:33:26 -08:00
parent a320c4aac9
commit f5f50d7c2c
161 changed files with 504 additions and 9586 deletions

View File

@@ -874,7 +874,11 @@ func (a *Api) RewindPlan(planId string, req shared.RewindPlanRequest) (*shared.R
}
func (a *Api) SignIn(req shared.SignInRequest, customHost string) (*shared.SessionResponse, *shared.ApiError) {
serverUrl := customHost + "/accounts/sign_in"
host := customHost
if host == "" {
host = cloudApiHost
}
serverUrl := host + "/accounts/sign_in"
reqBytes, err := json.Marshal(req)
if err != nil {
return nil, &shared.ApiError{Type: shared.ApiErrorTypeOther, Msg: fmt.Sprintf("error marshalling request: %v", err)}
@@ -902,7 +906,11 @@ func (a *Api) SignIn(req shared.SignInRequest, customHost string) (*shared.Sessi
}
func (a *Api) CreateAccount(req shared.CreateAccountRequest, customHost string) (*shared.SessionResponse, *shared.ApiError) {
serverUrl := customHost + "/accounts"
host := customHost
if host == "" {
host = cloudApiHost
}
serverUrl := host + "/accounts"
reqBytes, err := json.Marshal(req)
if err != nil {
return nil, &shared.ApiError{Type: shared.ApiErrorTypeOther, Msg: fmt.Sprintf("error marshalling request: %v", err)}
@@ -1206,7 +1214,11 @@ func (a *Api) DeleteInvite(inviteId string) *shared.ApiError {
}
func (a *Api) CreateEmailVerification(email, customHost, userId string) (*shared.CreateEmailVerificationResponse, *shared.ApiError) {
serverUrl := customHost + "/accounts/email_verifications"
host := customHost
if host == "" {
host = cloudApiHost
}
serverUrl := host + "/accounts/email_verifications"
req := shared.CreateEmailVerificationRequest{Email: email, UserId: userId}
reqBytes, err := json.Marshal(req)
if err != nil {

View File

@@ -170,7 +170,9 @@ func verifyEmail(email, host string) (bool, string, error) {
return false, "", fmt.Errorf("error creating email verification: %v", apiErr.Msg)
}
pin, err := term.GetUserStringInput("You'll now receive a 6 character pin by email. Please enter it here:")
fmt.Println("✉️ You'll now receive a 6 character pin by email")
pin, err := term.GetUserPasswordInput("Please enter it here:")
if err != nil {
return false, "", fmt.Errorf("error prompting pin: %v", err)

View File

@@ -83,11 +83,12 @@ func createOrg() (*shared.Org, error) {
}
func promptAutoAddUsersIfValid(email string) (bool, error) {
userDomain := strings.Split(Current.Email, "@")[1]
userDomain := strings.Split(email, "@")[1]
var autoAddDomainUsers bool
var err error
if !shared.IsEmailServiceDomain(userDomain) {
autoAddDomainUsers, err = term.ConfirmYesNo(fmt.Sprintf("Do you want to allow any user with an email ending in @%s to auto-join this org?", userDomain))
fmt.Println("With domain auto-join, you can allow any user with an email ending in @", userDomain, "to auto-join this org")
autoAddDomainUsers, err = term.ConfirmYesNo(fmt.Sprintf("Enable auto-join for %s?", userDomain))
if err != nil {
return false, err

View File

@@ -65,10 +65,11 @@ func (m *changesUIModel) right() {
func (m *changesUIModel) up() {
if m.selectedReplacementIndex > 0 {
// log.Println("up")
m.selectedReplacementIndex--
m.setSelectionInfo()
m.updateMainView(true)
}
m.setSelectionInfo()
m.updateMainView(true)
}
func (m *changesUIModel) down() {
@@ -84,10 +85,11 @@ func (m *changesUIModel) down() {
}
if m.selectedReplacementIndex < max {
// log.Println("down")
m.selectedReplacementIndex++
m.setSelectionInfo()
m.updateMainView(true)
}
m.setSelectionInfo()
m.updateMainView(true)
}
func (m *changesUIModel) scrollUp() {
@@ -148,27 +150,41 @@ func (m *changesUIModel) windowResized(w, h int) {
}
func (m *changesUIModel) updateMainView(scrollReplacement bool) {
// log.Println("updateMainView")
// var updateMsg types.ChangesUIViewportsUpdate
if m.selectedFullFile() {
context := m.currentPlan.ContextsByPath[m.selectionInfo.currentPath]
var originalFile string
if context != nil {
originalFile = context.Body
}
updatedFile := m.currentPlan.CurrentPlanFiles.Files[m.selectionInfo.currentPath]
wrapWidth := m.fileViewport.Width - 2
replacements := m.selectionInfo.currentReplacements
fileSegments := []string{}
replacementSegments := map[int]bool{}
lastReplacementIdx := 0
for _, rep := range replacements {
idx := strings.Index(updatedFile, rep.New)
if idx == -1 || idx < lastReplacementIdx {
continue
}
fileSegments = append(fileSegments, updatedFile[lastReplacementIdx:idx])
fileSegments = append(fileSegments, rep.New)
replacementSegments[len(fileSegments)-1] = true
lastReplacementIdx = idx + len(rep.New)
if context == nil {
// the file is new, so all lines are new and should be highlighted
fileSegments = append(fileSegments, updatedFile)
replacementSegments[0] = true
} else {
lastFoundIdx := 0
updatedLines := strings.Split(updatedFile, "\n")
for i, line := range updatedLines {
fileSegments = append(fileSegments, line+"\n")
originalIdx := strings.Index(originalFile, line)
if originalIdx == -1 || originalIdx < lastFoundIdx {
replacementSegments[i] = true
} else {
lastFoundIdx = originalIdx + len(line)
replacementSegments[i] = false
}
}
}
fileSegments = append(fileSegments, updatedFile[lastReplacementIdx:])
for i, segment := range fileSegments {
wrapped := wrap.String(segment, wrapWidth)
@@ -184,18 +200,19 @@ func (m *changesUIModel) updateMainView(scrollReplacement bool) {
}
m.fileViewport.SetContent(strings.Join(fileSegments, ""))
m.updateViewportSizes()
} else {
oldRes := m.getReplacementOldDisplay()
m.changeOldViewport.SetContent(oldRes.oldDisplay)
// log.Println("set old content")
newContent, newContentDisplay := m.getReplacementNewDisplay(oldRes.prependContent, oldRes.appendContent)
m.changeNewViewport.SetContent(newContentDisplay)
m.updateViewportSizes()
// log.Println("set new content")
if scrollReplacement {
m.scrollReplacementIntoView(oldRes.old, newContent, oldRes.numLinesPrepended)
}
}
m.updateViewportSizes()
}

View File

@@ -7,6 +7,9 @@ import (
)
func (m changesUIModel) renderMainView() string {
// log.Println()
// log.Println("renderMainView")
mainViewHeader := m.renderMainViewHeader()
if m.selectedFullFile() {
@@ -31,9 +34,9 @@ func (m changesUIModel) renderMainView() string {
oldViews := []string{oldView}
newViews := []string{newView}
if m.oldScrollable() && m.selectedViewport == 0 {
if m.oldScrollable() && (m.selectedViewport == 0 || !m.newScrollable()) {
oldViews = append(oldViews, m.renderScrollFooter())
} else if m.newScrollable() {
} else if m.newScrollable() && (m.selectedViewport == 1 || !m.oldScrollable()) {
newViews = append(newViews, m.renderScrollFooter())
}

View File

@@ -78,6 +78,6 @@ func (m changesUIModel) renderPathTabs() string {
}
style := lipgloss.NewStyle().BorderStyle(lipgloss.NormalBorder()).BorderBottom(true).BorderForeground(lipgloss.Color(borderColor)).Width(m.width)
return style.Render(strings.Join(resRows, "\n"))
tabs := strings.Join(resRows, "\n")
return style.Render(tabs)
}

View File

@@ -26,15 +26,18 @@ func (m *changesUIModel) setSelectionInfo() {
var pathReplacements []*shared.Replacement
for i, res := range results {
for j, rep := range res.Replacements {
pathReplacements = append(pathReplacements, rep)
for _, res := range results {
pathReplacements = append(pathReplacements, res.Replacements...)
}
flatIndex := i*len(res.Replacements) + j
if flatIndex == m.selectedReplacementIndex {
i := 0
for _, res := range results {
for _, rep := range res.Replacements {
if i == m.selectedReplacementIndex {
currentRes = res
currentRep = rep
}
i++
}
}

View File

@@ -5,6 +5,7 @@ import (
"github.com/charmbracelet/lipgloss"
"github.com/fatih/color"
"github.com/plandex/plandex/shared"
)
func (m changesUIModel) renderSidebar() string {
@@ -22,48 +23,51 @@ func (m changesUIModel) renderSidebar() string {
anyApplied := false
anyReplacements := false
var replacements []*shared.Replacement
for _, result := range results {
replacements = append(replacements, result.Replacements...)
}
// Change entries
for i, result := range results {
for j, rep := range result.Replacements {
anyReplacements = true
flatIndex := i*len(result.Replacements) + j
selected := currentRep != nil && rep.Id == currentRep.Id
s := ""
for i, rep := range replacements {
anyReplacements = true
selected := currentRep != nil && rep.Id == currentRep.Id
s := ""
fgColor := color.FgHiGreen
bgColor := color.BgGreen
if rep.Failed {
fgColor = color.FgHiRed
bgColor = color.BgRed
anyFailed = true
} else if rep.RejectedAt != nil {
fgColor = color.FgWhite
bgColor = color.BgBlack
}
var icon string
if rep.RejectedAt != nil {
icon = "👎"
} else if rep.Failed {
icon = "🚫"
} else {
icon = "📝"
}
if !rep.Failed && rep.RejectedAt == nil {
anyApplied = true
}
if selected {
s += color.New(color.Bold, bgColor, color.FgHiWhite).Sprintf(" > %s %d ", icon, flatIndex+1)
} else {
s += color.New(fgColor).Sprintf(" - %s %d ", icon, flatIndex+1)
}
s += "\n"
sb.WriteString(s)
fgColor := color.FgHiGreen
bgColor := color.BgGreen
if rep.Failed {
fgColor = color.FgHiRed
bgColor = color.BgRed
anyFailed = true
} else if rep.RejectedAt != nil {
fgColor = color.FgWhite
bgColor = color.BgBlack
}
var icon string
if rep.RejectedAt != nil {
icon = "👎"
} else if rep.Failed {
icon = "🚫"
} else {
icon = "📝"
}
if !rep.Failed && rep.RejectedAt == nil {
anyApplied = true
}
if selected {
s += color.New(color.Bold, bgColor, color.FgHiWhite).Sprintf(" > %s %d ", icon, i+1)
} else {
s += color.New(fgColor).Sprintf(" - %s %d ", icon, i+1)
}
s += "\n"
sb.WriteString(s)
}
if !anyReplacements {

View File

@@ -1,6 +1,7 @@
package changes_tui
import (
"plandex/types"
"time"
bubbleKey "github.com/charmbracelet/bubbles/key"
@@ -19,6 +20,12 @@ func (m changesUIModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case tea.WindowSizeMsg:
m.windowResized(msg.Width, msg.Height)
case types.ChangesUIViewportsUpdate:
m.updateViewportSizes()
if msg.ScrollReplacement != nil {
m.scrollReplacementIntoView(msg.ScrollReplacement.OldContent, msg.ScrollReplacement.NewContent, msg.ScrollReplacement.NumLinesPrepended)
}
case tea.KeyMsg:
switch {

View File

@@ -25,11 +25,13 @@ func (m changesUIModel) View() string {
layout := lipgloss.JoinHorizontal(lipgloss.Top, sidebar, mainView)
return lipgloss.JoinVertical(lipgloss.Left,
view := lipgloss.JoinVertical(lipgloss.Left,
tabs,
layout,
help,
)
return view
}
func (m changesUIModel) getMainViewDims() (int, int) {
@@ -62,6 +64,10 @@ func (m *changesUIModel) initViewports() {
}
func (m *changesUIModel) updateViewportSizes() {
// log.Println()
// log.Println()
// log.Println("updateViewportSizes")
mainViewWidth, mainViewHeight := m.getMainViewDims()
if m.selectedFullFile() {
@@ -76,27 +82,48 @@ func (m *changesUIModel) updateViewportSizes() {
m.fileViewport.Height = fileViewHeight
} else {
// log.Println("mainViewHeight", mainViewHeight)
mainViewHeight := mainViewHeight
oldViewHeight := mainViewHeight
newViewHeight := mainViewHeight
// set widths and reset heights
// log.Println("resetting widths and heights")
m.resetViewportDims()
// log.Println("oldScrollable", m.oldScrollable())
// log.Println("newScrollable", m.newScrollable())
// log.Println("selectedViewport", m.selectedViewport)
if m.oldScrollable() && (m.selectedViewport == 0 || !m.newScrollable()) {
footerHeight := lipgloss.Height(m.renderScrollFooter())
oldViewHeight -= footerHeight
}
newViewHeight := mainViewHeight
if m.newScrollable() && (m.selectedViewport == 1 || !m.oldScrollable()) {
} else if m.newScrollable() && (m.selectedViewport == 1 || !m.oldScrollable()) {
footerHeight := lipgloss.Height(m.renderScrollFooter())
newViewHeight -= footerHeight
}
m.changeOldViewport.Width = mainViewWidth / 2
// log.Println("oldViewHeight", oldViewHeight)
// log.Println("newViewHeight", newViewHeight)
// set updated heights
m.changeOldViewport.Height = oldViewHeight
m.changeNewViewport.Width = mainViewWidth / 2
m.changeNewViewport.Height = newViewHeight
// log.Println("updated heights")
}
}
func (m *changesUIModel) resetViewportDims() {
mainViewWidth, mainViewHeight := m.getMainViewDims()
m.fileViewport.Width = mainViewWidth
m.fileViewport.Height = mainViewHeight
m.changeOldViewport.Width = mainViewWidth / 2
m.changeOldViewport.Height = mainViewHeight
m.changeNewViewport.Width = mainViewWidth / 2
m.changeNewViewport.Height = mainViewHeight
}
func (m changesUIModel) renderHelp() string {
help := " "
@@ -114,10 +141,18 @@ func (m changesUIModel) renderHelp() string {
}
func (m changesUIModel) oldScrollable() bool {
// log.Println("oldScrollable")
// log.Println("TotalLineCount", m.changeOldViewport.TotalLineCount())
// log.Println("VisibleLineCount", m.changeOldViewport.VisibleLineCount())
return m.changeOldViewport.TotalLineCount() > m.changeOldViewport.VisibleLineCount()
}
func (m changesUIModel) newScrollable() bool {
// log.Println("newScrollable")
// log.Println("TotalLineCount", m.changeNewViewport.TotalLineCount())
// log.Println("VisibleLineCount", m.changeNewViewport.VisibleLineCount())
return m.changeNewViewport.TotalLineCount() > m.changeNewViewport.VisibleLineCount()
}

View File

@@ -37,33 +37,44 @@ func next(cmd *cobra.Command, args []string) {
lib.MustResolveProject()
lib.MustCheckOutdatedContextWithOutput()
apiErr := api.Client.TellPlan(lib.CurrentPlanId, shared.TellPlanRequest{
Prompt: continuePrompt,
ConnectStream: !continueBg,
}, lib.OnStreamPlan)
if apiErr != nil {
if apiErr.Type == shared.ApiErrorTypeTrialMessagesExceeded {
fmt.Fprintf(os.Stderr, "🚨 You've reached the free trial limit of %d messages per plan\n", apiErr.TrialMessagesExceededError.MaxMessages)
var fn func() bool
fn = func() bool {
apiErr := api.Client.TellPlan(lib.CurrentPlanId, shared.TellPlanRequest{
Prompt: continuePrompt,
ConnectStream: !continueBg,
}, lib.OnStreamPlan)
if apiErr != nil {
if apiErr.Type == shared.ApiErrorTypeTrialMessagesExceeded {
fmt.Fprintf(os.Stderr, "🚨 You've reached the free trial limit of %d messages per plan\n", apiErr.TrialMessagesExceededError.MaxMessages)
res, err := term.ConfirmYesNo("Upgrade now?")
res, err := term.ConfirmYesNo("Upgrade now?")
if err != nil {
fmt.Fprintln(os.Stderr, "Error prompting upgrade trial:", err)
return
}
if res {
err := auth.ConvertTrial()
if err != nil {
fmt.Fprintln(os.Stderr, "Error converting trial:", err)
return
fmt.Fprintln(os.Stderr, "Error prompting upgrade trial:", err)
return false
}
if res {
err := auth.ConvertTrial()
if err != nil {
fmt.Fprintln(os.Stderr, "Error converting trial:", err)
return false
}
// retry action after converting trial
return fn()
}
return false
}
return
fmt.Fprintln(os.Stderr, "Prompt error:", apiErr.Msg)
return false
}
return true
}
fmt.Fprintln(os.Stderr, "Prompt error:", apiErr.Msg)
shouldContinue := fn()
if !shouldContinue {
return
}

View File

@@ -56,19 +56,19 @@ func convo(cmd *cobra.Command, args []string) {
}
// format as above but start with day of week
formattedTs := msg.CreatedAt.Local().Format("Mon Jan 2, 2006 | 3:04:05pm MST")
formattedTs := msg.CreatedAt.Local().Format("Mon Jan 2, 2006 | 3:04pm MST")
// if it's today then use 'Today' instead of the date
if msg.CreatedAt.Day() == time.Now().Day() {
formattedTs = msg.CreatedAt.Local().Format("Today | 3:04:05pm MST")
formattedTs = msg.CreatedAt.Local().Format("Today | 3:04pm MST")
}
// if it's yesterday then use 'Yesterday' instead of the date
if msg.CreatedAt.Day() == time.Now().AddDate(0, 0, -1).Day() {
formattedTs = msg.CreatedAt.Local().Format("Yesterday | 3:04:05pm MST")
formattedTs = msg.CreatedAt.Local().Format("Yesterday | 3:04pm MST")
}
header := fmt.Sprintf("#### %d | %s | %s | %d 🪙", i+1,
header := fmt.Sprintf("#### %d | %s | %s | %d 🪙 | ", i+1,
author, formattedTs, msg.Tokens)
convMarkdown = append(convMarkdown, header, msg.Message, "")
totalTokens += msg.Tokens
@@ -85,5 +85,5 @@ func convo(cmd *cobra.Command, args []string) {
term.GetDivisionLine() +
color.New(color.Bold, color.FgCyan).Sprint(" Conversation size →") + fmt.Sprintf(" %d 🪙", totalTokens) + "\n\n"
term.PageOutputReverse(output)
term.PageOutput(output)
}

View File

@@ -26,14 +26,9 @@ var rewindCmd = &cobra.Command{
Run: rewind,
}
var sha string
func init() {
// Add rewind command
RootCmd.AddCommand(rewindCmd)
// Add sha flag
rewindCmd.Flags().StringVar(&sha, "sha", "", "Specify a commit sha to rewind to")
}
func rewind(cmd *cobra.Command, args []string) {
@@ -45,11 +40,7 @@ func rewind(cmd *cobra.Command, args []string) {
return
}
// Check if either steps or sha is provided and not both
stepsOrSha := ""
if len(args) > 1 {
stepsOrSha = args[1]
}
stepsOrSha := args[0]
logsRes, err := api.Client.ListLogs(lib.CurrentPlanId)
@@ -63,13 +54,15 @@ func rewind(cmd *cobra.Command, args []string) {
log.Println("shas:", logsRes.Shas)
if steps, err := strconv.Atoi(stepsOrSha); err == nil && steps > 0 {
log.Println("steps:", steps)
// log.Println("steps:", steps)
// Rewind by the specified number of steps
targetSha = logsRes.Shas[steps]
} else if sha := stepsOrSha; sha != "" {
// log.Println("sha provided:", sha)
// Rewind to the specified Sha
targetSha = sha
} else if stepsOrSha == "" {
// log.Println("No steps or sha provided, rewinding by 1 step")
// Rewind by 1 step
targetSha = logsRes.Shas[1]
} else {
@@ -77,7 +70,7 @@ func rewind(cmd *cobra.Command, args []string) {
os.Exit(1)
}
log.Println("Rewinding to", targetSha)
// log.Println("Rewinding to", targetSha)
// Rewind to the target sha
rwRes, err := api.Client.RewindPlan(lib.CurrentPlanId, shared.RewindPlanRequest{Sha: targetSha})

View File

@@ -135,33 +135,43 @@ func tell(cmd *cobra.Command, args []string) {
}()
}
apiErr := api.Client.TellPlan(lib.CurrentPlanId, shared.TellPlanRequest{
Prompt: prompt,
ConnectStream: !tellBg,
}, lib.OnStreamPlan)
if apiErr != nil {
if apiErr.Type == shared.ApiErrorTypeTrialMessagesExceeded {
fmt.Fprintf(os.Stderr, "🚨 You've reached the free trial limit of %d messages per plan\n", apiErr.TrialMessagesExceededError.MaxMessages)
var fn func() bool
fn = func() bool {
apiErr := api.Client.TellPlan(lib.CurrentPlanId, shared.TellPlanRequest{
Prompt: prompt,
ConnectStream: !tellBg,
}, lib.OnStreamPlan)
if apiErr != nil {
if apiErr.Type == shared.ApiErrorTypeTrialMessagesExceeded {
fmt.Fprintf(os.Stderr, "\n🚨 You've reached the free trial limit of %d messages per plan\n", apiErr.TrialMessagesExceededError.MaxMessages)
res, err := term.ConfirmYesNo("Upgrade now?")
res, err := term.ConfirmYesNo("Upgrade now?")
if err != nil {
fmt.Fprintln(os.Stderr, "Error prompting upgrade trial:", err)
return
}
if res {
err := auth.ConvertTrial()
if err != nil {
fmt.Fprintln(os.Stderr, "Error converting trial:", err)
return
fmt.Fprintln(os.Stderr, "Error prompting upgrade trial:", err)
return false
}
if res {
err := auth.ConvertTrial()
if err != nil {
fmt.Fprintln(os.Stderr, "Error converting trial:", err)
return false
}
// retry action after converting trial
return fn()
}
return false
}
return
fmt.Fprintln(os.Stderr, "Prompt error:", apiErr.Msg)
return false
}
return true
}
fmt.Fprintln(os.Stderr, "Prompt error:", apiErr.Msg)
shouldContinue := fn()
if !shouldContinue {
return
}

View File

@@ -5,6 +5,7 @@ go 1.21.3
require (
github.com/atotto/clipboard v0.1.4
github.com/charmbracelet/lipgloss v0.9.1
github.com/davecgh/go-spew v1.1.1
github.com/fatih/color v1.16.0
github.com/looplab/fsm v1.0.1
github.com/muesli/reflow v0.3.0
@@ -62,4 +63,4 @@ require (
replace github.com/plandex/plandex/shared => ../shared
replace github.com/plandex-ai/survey/v2 => ../../survey
replace github.com/plandex-ai/survey/v2 => ../../../survey

View File

@@ -11,6 +11,7 @@ import (
"golang.org/x/term"
"github.com/cqroot/prompt"
"github.com/cqroot/prompt/input"
)
var CmdDesc = map[string][2]string{
@@ -115,6 +116,10 @@ func GetUserStringInput(msg string) (string, error) {
return prompt.New().Ask(msg).Input("")
}
func GetUserPasswordInput(msg string) (string, error) {
return prompt.New().Ask(msg).Input("", input.WithEchoMode(input.EchoPassword))
}
func GetUserKeyInput() (rune, error) {
if err := keyboard.Open(); err != nil {
return 0, fmt.Errorf("failed to open keyboard: %s", err)

View File

@@ -54,3 +54,13 @@ type StreamTUIUpdate struct {
Processing bool
PlanTokenCount *shared.PlanTokenCount
}
type ChangesUIScrollReplacement struct {
OldContent string
NewContent string
NumLinesPrepended int
}
type ChangesUIViewportsUpdate struct {
ScrollReplacement *ChangesUIScrollReplacement
}

View File

@@ -1,58 +0,0 @@
package model
import (
"context"
"encoding/json"
"fmt"
"plandex-server/model/prompts"
"github.com/sashabaranov/go-openai"
)
type PlanFinishedParams struct {
Conversation []openai.ChatCompletionMessage
}
type PlanFinishedRes struct {
Reasoning string `json:"reasoning"`
Finished bool `json:"finished"`
}
func PlanFinished(params PlanFinishedParams) (PlanFinishedRes, error) {
messages := []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: prompts.SysFinished,
},
}
messages = append(messages, params.Conversation...)
resp, err := Client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: PlanSummaryModel,
Messages: messages,
},
)
if err != nil {
fmt.Println("PlanFinished err:", err)
return PlanFinishedRes{}, err
}
if len(resp.Choices) == 0 {
return PlanFinishedRes{}, fmt.Errorf("no response from GPT")
}
content := resp.Choices[0].Message.Content
// Here, we assume that if the AI assistant says "finished", the plan is finished.
// You might want to adjust this according to your needs.
isFinished := content == "finished"
return PlanFinishedRes{
Reasoning: content,
Finished: isFinished,
}, nil
}

View File

@@ -16,6 +16,7 @@ var BuilderModel = mediumModel
var ShortSummaryModel = weakModel
var NameModel = weakModel
var CommitMsgModel = weakModel
var PlanExecStatusModel = mediumModel
func init() {
if os.Getenv("PLANNER_MODEL") != "" {

View File

@@ -33,14 +33,14 @@ func genPlanDescription(planId string, ctx context.Context) (*db.ConvoMessageDes
},
)
var descStrRes string
var desc shared.ConvoMessageDescription
if err != nil {
fmt.Printf("Error during plan description model call: %v\n", err)
return nil, err
}
var descStrRes string
var desc shared.ConvoMessageDescription
for _, choice := range descResp.Choices {
if choice.FinishReason == "function_call" &&
choice.Message.FunctionCall != nil &&

View File

@@ -0,0 +1,153 @@
package plan
import (
"context"
"encoding/json"
"fmt"
"plandex-server/model"
"plandex-server/model/prompts"
"github.com/sashabaranov/go-openai"
)
type PlanExecStatus struct {
NeedsInput bool `json:"needs_input"`
Finished bool `json:"finished"`
}
func ExecStatus(conversation []openai.ChatCompletionMessage, ctx context.Context) (*PlanExecStatus, error) {
var res PlanExecStatus
errCh := make(chan error, 2)
go func() {
needsInput, err := ExecStatusNeedsInput(conversation, ctx)
if err != nil {
errCh <- err
}
res.NeedsInput = needsInput
errCh <- nil
}()
go func() {
finished, err := ExecStatusIsFinished(conversation, ctx)
if err != nil {
errCh <- err
}
res.Finished = finished
errCh <- nil
}()
for i := 0; i < 2; i++ {
err := <-errCh
if err != nil {
return nil, err
}
}
return &res, nil
}
func ExecStatusIsFinished(conversation []openai.ChatCompletionMessage, ctx context.Context) (bool, error) {
messages := []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: prompts.GetExecStatusIsFinishedPrompt(conversation),
},
}
resp, err := model.Client.CreateChatCompletion(
ctx,
openai.ChatCompletionRequest{
Model: model.PlanExecStatusModel,
Functions: []openai.FunctionDefinition{prompts.PlanIsFinishedFn},
Messages: messages,
ResponseFormat: &openai.ChatCompletionResponseFormat{Type: "json_object"},
},
)
if err != nil {
fmt.Printf("Error during plan exec status check model call: %v\n", err)
return false, err
}
var strRes string
var res PlanExecStatus
for _, choice := range resp.Choices {
if choice.FinishReason == "function_call" &&
choice.Message.FunctionCall != nil &&
choice.Message.FunctionCall.Name == "planIsFinished" {
fnCall := choice.Message.FunctionCall
strRes = fnCall.Arguments
}
}
if strRes == "" {
fmt.Println("no planIsFinished function call found in response")
return false, err
}
byteRes := []byte(strRes)
err = json.Unmarshal(byteRes, &res)
if err != nil {
fmt.Printf("Error unmarshalling plan exec status response: %v\n", err)
return false, err
}
return res.Finished, nil
}
func ExecStatusNeedsInput(conversation []openai.ChatCompletionMessage, ctx context.Context) (bool, error) {
messages := []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleSystem,
Content: prompts.GetExecStatusNeedsInputPrompt(&conversation[len(conversation)-1]),
},
}
messages = append(messages, conversation...)
resp, err := model.Client.CreateChatCompletion(
ctx,
openai.ChatCompletionRequest{
Model: model.PlanExecStatusModel,
Functions: []openai.FunctionDefinition{prompts.PlanNeedsInputFn},
Messages: messages,
ResponseFormat: &openai.ChatCompletionResponseFormat{Type: "json_object"},
},
)
if err != nil {
fmt.Printf("Error during plan exec status check model call: %v\n", err)
return false, err
}
var strRes string
var res PlanExecStatus
for _, choice := range resp.Choices {
if choice.FinishReason == "function_call" &&
choice.Message.FunctionCall != nil &&
choice.Message.FunctionCall.Name == "planNeedsInput" {
fnCall := choice.Message.FunctionCall
strRes = fnCall.Arguments
}
}
if strRes == "" {
fmt.Println("no planNeedsInput function call found in response")
return false, err
}
byteRes := []byte(strRes)
err = json.Unmarshal(byteRes, &res)
if err != nil {
fmt.Printf("Error unmarshalling plan exec status response: %v\n", err)
return false, err
}
return res.NeedsInput, nil
}

View File

@@ -242,9 +242,6 @@ func execTellPlan(plan *db.Plan, auth *types.ServerAuth, req *shared.TellPlanReq
// token limit exceeded after adding conversation
// get summary for as much as the conversation as necessary to stay under the token limit
for _, s := range summaries {
log.Printf("summary: ")
spew.Dump(s)
timestamp := s.LatestConvoMessageCreatedAt.UnixNano() / int64(time.Millisecond)
tokens, ok := tokensUpToTimestamp[timestamp]

View File

@@ -55,7 +55,7 @@ const SysCreate = Identity + ` A plan is a set of files with an attached context
**Don't ask the user to take an action that you are able to do.** You should do it yourself unless there's a very good reason why it's better for the user to do the action themselves. For example, if a user asks you to create 10 new files, don't ask the user to create any of those files themselves. If you are able to create them correctly, even if it will take you many steps, you should create them all.
At the end of a plan, you can suggest additional iterations to make the plan better. You can also ask the user to load more files into context or give you more information if it would help you make a better plan.
At the end of your response, you can suggest additional iterations to make the plan better. You can also ask the user to load more files into context or give you more information if it would help you make a better plan. **If all tasks and subtasks have been completed, explicitly say "All tasks have been completed."**
Be aware that since the plan started, the context may have been updated. It may have been updated by the user implementing your suggestions, by the user implementing their own work, or by the user adding more files or information to context. Be sure to consider the current state of the context when continuing with the plan, and whether the plan needs to be updated to reflect the latest context. For example, if you are working on a plan that has been broken up into subtasks, and you've reached the point of implementing a particular subtask, first consider whether the subtask is still necessary looking at the files in context. If it has already been implemented or is no longer necessary, say so, revise the plan if needed, and move on. Otherwise, implement the subtask.
` +

View File

@@ -0,0 +1,60 @@
package prompts
import (
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema"
)
const SysExecStatusIsFinished = `You are an AI assistant that determines the execution status of a coding AI's plan for a programming task. Analyze the AI's latest message to determine whether the plan is finished.
The plan is finished if all the plan's tasks and subtasks have been completed. When a plan is finished, the coding AI will say something like "All tasks have been completed." If the response is a list of tasks, then the plan is not finished. If the response is a list of tasks and a message saying that all tasks have been completed, then the plan is finished.
Return a JSON object with the 'finished' key set to true or false. Only call the 'planIsFinished' function in your response. Don't call any other function.`
func GetExecStatusIsFinishedPrompt(conversation []openai.ChatCompletionMessage) string {
s := ""
for _, m := range conversation {
s += m.Role + ":\n"
s += m.Content + "\n"
}
return SysExecStatusIsFinished + "\n\nConversation:\n" + s
}
var PlanIsFinishedFn = openai.FunctionDefinition{
Name: "planIsFinished",
Parameters: &jsonschema.Definition{
Type: jsonschema.Object,
Properties: map[string]jsonschema.Definition{
"finished": {
Type: jsonschema.Boolean,
Description: "Whether the plan is finished.",
},
},
Required: []string{"finished"},
},
}
const SysExecStatusNeedsInput = `You are an AI assistant that determines the execution status of a coding AI's plan for a programming task. Analyze the AI's latest message to determine whether the plan needs more input. The plan needs more input if the coding AI requires the user to add more context, provide information, or answer questions the AI has asked.
When the coding AI needs more input, it will say something like "I need more information or context to make a plan for this task."
If the coding AI says or implies that additional information would be helpful or useful, but that information isn't *required* to continue the plan, then the plan *does not* need more input. It only needs more input if the AI says or implies that more information is necessary and required to continue. Return a JSON object with the 'needs_input' key set to true or false. Only call the 'planNeedsInput' function in your response. Don't call any other function.`
func GetExecStatusNeedsInputPrompt(message *openai.ChatCompletionMessage) string {
return SysExecStatusNeedsInput + "\nLatest message from coding AI:\n" + message.Content
}
var PlanNeedsInputFn = openai.FunctionDefinition{
Name: "planNeedsInput",
Parameters: &jsonschema.Definition{
Type: jsonschema.Object,
Properties: map[string]jsonschema.Definition{
"needs_input": {
Type: jsonschema.Boolean,
Description: "Whether the plan needs more input. If ambiguous or unclear, assume the plan does not need more input.",
},
},
Required: []string{"needs_input"},
},
}

View File

@@ -1,7 +0,0 @@
package prompts
const SysFinished = "You are an AI assistant that determines whether a software development plan has been finished or not. Analyze the conversation and decide whether all tasks have been completed."
func GetFinishedPrompt(conversation string) string {
return SysFinished + "\n\nConversation:\n" + conversation
}

View File

@@ -7,39 +7,43 @@ import (
"github.com/google/uuid"
)
type ActivePlan struct {
Prompt string
StreamCh chan string
StreamDoneCh chan error
Ctx context.Context
CancelFn context.CancelFunc
Contexts []*db.Context
ContextsByPath map[string]*db.Context
Content string
NumTokens int
PromptMessageNum int
type ActiveBuild struct {
AssistantMessageId string
Files []string
BuildBuffers map[string]string
BuiltFiles map[string]bool
subscriptions map[string]chan string
Error error
ErrorReason string
}
type ActivePlan struct {
Prompt string
StreamCh chan string
StreamDoneCh chan error
Ctx context.Context
CancelFn context.CancelFunc
Contexts []*db.Context
ContextsByPath map[string]*db.Context
Content string
NumTokens int
PromptMessageNum int
BuildQueuesByPath map[string]*[]ActiveBuild
subscriptions map[string]chan string
}
func NewActivePlan(prompt string) *ActivePlan {
ctx, cancel := context.WithCancel(context.Background())
active := ActivePlan{
Prompt: prompt,
StreamCh: make(chan string),
StreamDoneCh: make(chan error),
Ctx: ctx,
CancelFn: cancel,
Files: []string{},
BuiltFiles: map[string]bool{},
Contexts: []*db.Context{},
ContextsByPath: map[string]*db.Context{},
BuildBuffers: map[string]string{},
subscriptions: map[string]chan string{},
Prompt: prompt,
StreamCh: make(chan string),
StreamDoneCh: make(chan error),
Ctx: ctx,
CancelFn: cancel,
BuildQueuesByPath: map[string]*[]ActiveBuild{},
Contexts: []*db.Context{},
ContextsByPath: map[string]*db.Context{},
subscriptions: map[string]chan string{},
}
go func() {
@@ -58,8 +62,8 @@ func NewActivePlan(prompt string) *ActivePlan {
return &active
}
func (ap *ActivePlan) BuildFinished() bool {
return len(ap.Files) == len(ap.BuiltFiles)
func (b *ActiveBuild) BuildFinished() bool {
return len(b.Files) == len(b.BuiltFiles)
}
func (ap *ActivePlan) Subscribe() (string, chan string) {

View File

@@ -2,7 +2,6 @@ package shared
import (
"fmt"
"log"
"sort"
"strings"
"time"
@@ -141,7 +140,7 @@ func (planState *CurrentPlanState) GetFilesBeforeReplacement(
_, hasPath := planRes.FileResultsByPath[contextPart.FilePath]
log.Printf("hasPath: %v\n", hasPath)
// log.Printf("hasPath: %v\n", hasPath)
if hasPath {
files[contextPart.FilePath] = contextPart.Body
@@ -152,8 +151,8 @@ func (planState *CurrentPlanState) GetFilesBeforeReplacement(
for path, planResults := range planRes.FileResultsByPath {
updated := files[path]
log.Printf("path: %s\n", path)
log.Printf("updated: %s\n", updated)
// log.Printf("path: %s\n", path)
// log.Printf("updated: %s\n", updated)
PlanResLoop:
for _, planRes := range planResults {
@@ -161,8 +160,6 @@ func (planState *CurrentPlanState) GetFilesBeforeReplacement(
continue
}
log.Println("planRes:", planRes)
if len(planRes.Replacements) == 0 {
if updated != "" {
return nil, fmt.Errorf("plan updates out of order: %s", path)

View File

@@ -12,7 +12,6 @@ const EVENT_DESCRIBE = "describe"
const EVENT_BUILD = "build"
const EVENT_FINISH = "finish"
const EVENT_ABORT = "abort"
const EVENT_REVISE = "revise"
const EVENT_CANCEL = "cancel"
const EVENT_ERROR = "error"
@@ -34,7 +33,6 @@ func NewPlanStreamState() *fsm.FSM {
{Name: EVENT_FINISH, Src: []string{STATE_DESCRIBING, STATE_BUILDING}, Dst: STATE_FINISHED},
{Name: EVENT_ABORT, Src: []string{STATE_REPLYING, STATE_DESCRIBING, STATE_BUILDING},
Dst: STATE_ABORTED},
{Name: EVENT_REVISE, Src: []string{STATE_ABORTED}, Dst: STATE_REPLYING},
{Name: EVENT_CANCEL, Src: []string{STATE_ABORTED}, Dst: STATE_CANCELED},
{Name: EVENT_ERROR, Src: []string{STATE_REPLYING, STATE_DESCRIBING, STATE_BUILDING}},
},

2
dev.sh
View File

@@ -8,6 +8,8 @@ terminate() {
trap terminate SIGTERM SIGINT
cd app
(cd cli && ./dev.sh)
reflex -r '^(cli|shared)/.*\.(go|mod|sum)$' -- sh -c 'cd cli && ./dev.sh' &

View File

@@ -1,135 +0,0 @@
package db
import (
"crypto/sha256"
"database/sql"
"encoding/hex"
"fmt"
"time"
"github.com/google/uuid"
"github.com/pkg/errors"
)
const tokenExpirationDays = 90 // (free trial tokens don't expire)
func CreateAuthToken(userId string, isTrial bool, tx *sql.Tx) (token, id string, err error) {
uid := uuid.New()
bytes := uid[:]
hashBytes := sha256.Sum256(bytes)
hash := hex.EncodeToString(hashBytes[:])
err = tx.QueryRow("INSERT INTO auth_tokens (user_id, token_hash, is_trial) VALUES ($1, $2, $3) RETURNING id", userId, hash, isTrial).Scan(&id)
if err != nil {
return "", "", fmt.Errorf("error creating auth token: %v", err)
}
return uid.String(), id, nil
}
func ValidateAuthToken(token string) (*AuthToken, error) {
uid, err := uuid.Parse(token)
if err != nil {
return nil, errors.New("invalid token")
}
bytes := uid[:]
hashBytes := sha256.Sum256(bytes)
tokenHash := hex.EncodeToString(hashBytes[:])
var authToken AuthToken
// free trial tokens don't expire
err = Conn.Get(&authToken, "SELECT * FROM auth_tokens WHERE token_hash = $1 AND (created_at > $2 OR is_trial = TRUE) AND deleted_at IS NULL", tokenHash, time.Now().AddDate(0, 0, -tokenExpirationDays))
if err != nil {
if err == sql.ErrNoRows {
return nil, errors.New("invalid token")
}
return nil, fmt.Errorf("error validating token: %v", err)
}
return &authToken, nil
}
func CreateEmailVerification(email string, userId, pinHash string) error {
_, err := Conn.Exec("INSERT INTO email_verifications (email, pin_hash, user_id) VALUES ($1, $2, $3)", email, pinHash, userId)
if err != nil {
return fmt.Errorf("error creating email verification: %v", err)
}
return nil
}
// email verifications expire in 5 minutes
const emailVerificationExpirationMinutes = 5
func ValidateEmailVerification(email, pin string) (id string, err error) {
pinHashBytes := sha256.Sum256([]byte(pin))
pinHash := hex.EncodeToString(pinHashBytes[:])
var authTokenId string
query := `SELECT id, auth_token_id
FROM email_verifications
WHERE pin_hash = $1
AND email = $2
AND created_at > $3`
err = Conn.QueryRow(query, pinHash, email, time.Now().Add(-emailVerificationExpirationMinutes*time.Minute)).Scan(&id, &authTokenId)
if err != nil {
if err == sql.ErrNoRows {
return "", errors.New("invalid pin")
}
return "", fmt.Errorf("error validating email verification: %v", err)
}
if authTokenId != "" {
return "", errors.New("pin already verified")
}
return id, nil
}
func GetUserPermissions(userId, orgId string) ([]string, error) {
var permissions []string
query := `
SELECT p.name, p.resource_id
FROM permissions p
JOIN org_roles_permissions orp ON p.id = orp.permission_id
JOIN orgs_users ou ON orp.org_role_id = ou.org_role_id
WHERE ou.user_id = $1 AND ou.org_id = $2
`
rows, err := Conn.Query(query, userId, orgId)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var permission string
var resourceId sql.NullString
if err := rows.Scan(&permission, &resourceId); err != nil {
return nil, err
}
toAdd := permission
if resourceId.Valid {
toAdd = toAdd + "|" + resourceId.String
}
permissions = append(permissions, toAdd)
}
// Check for errors from iterating over rows.
if err = rows.Err(); err != nil {
return nil, err
}
return permissions, nil
}

View File

@@ -1,42 +0,0 @@
package db
import (
"fmt"
"time"
)
func StorePlanBuild(build *PlanBuild) error {
query := `INSERT INTO plan_builds (org_id, plan_id, convo_message_id) VALUES (:org_id, :plan_id, :convo_message_id) RETURNING id, created_at, updated_at`
row, err := Conn.NamedQuery(query, build)
if err != nil {
return fmt.Errorf("error storing plan build: %v", err)
}
defer row.Close()
if row.Next() {
var createdAt, updatedAt time.Time
var id string
if err := row.Scan(&id, &createdAt, &updatedAt); err != nil {
return fmt.Errorf("error storing plan build: %v", err)
}
build.Id = id
build.CreatedAt = createdAt
build.UpdatedAt = updatedAt
}
return nil
}
func SetBuildError(build *PlanBuild) error {
_, err := Conn.Exec("UPDATE plan_builds SET error = $1, error_path = $2 WHERE id = $3", build.Error, build.ErrorPath, build.Id)
if err != nil {
return fmt.Errorf("error setting build error: %v", err)
}
return nil
}

View File

@@ -1,152 +0,0 @@
package db
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"time"
"github.com/google/uuid"
)
func GetPlanContexts(orgId, planId string, includeBody bool) ([]*Context, error) {
var contexts []*Context
contextDir := getPlanContextDir(orgId, planId)
// get all context files
files, err := os.ReadDir(contextDir)
if err != nil {
return nil, fmt.Errorf("error reading context dir: %v", err)
}
errCh := make(chan error, len(files)/2)
contextCh := make(chan *Context, len(files)/2)
// read each context file
for _, file := range files {
if strings.HasSuffix(file.Name(), ".meta") {
go func(file os.DirEntry) {
context, err := GetContext(orgId, planId, strings.TrimSuffix(file.Name(), ".meta"), includeBody)
if err != nil {
errCh <- fmt.Errorf("error reading context file: %v", err)
return
}
contextCh <- context
}(file)
}
}
for i := 0; i < len(files)/2; i++ {
select {
case err := <-errCh:
return nil, fmt.Errorf("error reading context files: %v", err)
case context := <-contextCh:
contexts = append(contexts, context)
}
}
// sort contexts by CreatedAt
sort.Slice(contexts, func(i, j int) bool {
return contexts[i].CreatedAt.Before(contexts[j].CreatedAt)
})
return contexts, nil
}
func GetContext(orgId, planId, contextId string, includeBody bool) (*Context, error) {
contextDir := getPlanContextDir(orgId, planId)
// read the meta file
metaPath := filepath.Join(contextDir, contextId+".meta")
metaBytes, err := os.ReadFile(metaPath)
if err != nil {
return nil, fmt.Errorf("error reading context meta file: %v", err)
}
var context Context
err = json.Unmarshal(metaBytes, &context)
if err != nil {
return nil, fmt.Errorf("error unmarshalling context meta file: %v", err)
}
if includeBody {
// read the body file
bodyPath := filepath.Join(contextDir, strings.TrimSuffix(contextId, ".meta")+".body")
bodyBytes, err := os.ReadFile(bodyPath)
if err != nil {
return nil, fmt.Errorf("error reading context body file: %v", err)
}
context.Body = string(bodyBytes)
}
return &context, nil
}
func ContextRemove(contexts []*Context) error {
// remove files
numFiles := len(contexts) * 2
errCh := make(chan error, numFiles)
for _, context := range contexts {
contextDir := getPlanContextDir(context.OrgId, context.PlanId)
for _, ext := range []string{".meta", ".body"} {
go func(context *Context, dir, ext string) {
errCh <- os.Remove(filepath.Join(dir, context.Id+ext))
}(context, contextDir, ext)
}
}
for i := 0; i < numFiles; i++ {
err := <-errCh
if err != nil {
return fmt.Errorf("error removing context file: %v", err)
}
}
return nil
}
func StoreContext(context *Context) error {
contextDir := getPlanContextDir(context.OrgId, context.PlanId)
ts := time.Now().UTC()
if context.Id == "" {
context.Id = uuid.New().String()
context.CreatedAt = ts
}
context.UpdatedAt = ts
metaFilename := context.Id + ".meta"
metaPath := filepath.Join(contextDir, metaFilename)
bodyFilename := context.Id + ".body"
bodyPath := filepath.Join(contextDir, bodyFilename)
body := []byte(context.Body)
context.Body = ""
// Convert the ModelContextPart to JSON
data, err := json.MarshalIndent(context, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal context context: %v", err)
}
// Write the body to the file
if err = os.WriteFile(bodyPath, body, 0644); err != nil {
return fmt.Errorf("failed to write context body to file %s: %v", bodyPath, err)
}
// Write the meta data to the file
if err = os.WriteFile(metaPath, data, 0644); err != nil {
return fmt.Errorf("failed to write context meta to file %s: %v", metaPath, err)
}
return nil
}

View File

@@ -1,149 +0,0 @@
package db
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"time"
"github.com/google/uuid"
"github.com/sashabaranov/go-openai"
)
func GetPlanConvo(orgId, planId string) ([]*ConvoMessage, error) {
var convo []*ConvoMessage
convoDir := getPlanConversationDir(orgId, planId)
files, err := os.ReadDir(convoDir)
if err != nil {
return nil, fmt.Errorf("error reading convo dir: %v", err)
}
errCh := make(chan error, len(files))
convoCh := make(chan *ConvoMessage, len(files))
for _, file := range files {
go func(file os.DirEntry) {
bytes, err := os.ReadFile(filepath.Join(convoDir, file.Name()))
if err != nil {
errCh <- fmt.Errorf("error reading convo file: %v", err)
return
}
var convoMessage ConvoMessage
err = json.Unmarshal(bytes, &convoMessage)
if err != nil {
errCh <- fmt.Errorf("error unmarshalling convo file: %v", err)
return
}
convoCh <- &convoMessage
}(file)
}
for i := 0; i < len(files); i++ {
select {
case err := <-errCh:
return nil, fmt.Errorf("error reading convo files: %v", err)
case convoMessage := <-convoCh:
convo = append(convo, convoMessage)
}
}
sort.Slice(convo, func(i, j int) bool {
return convo[i].CreatedAt.Before(convo[j].CreatedAt)
})
return convo, nil
}
func StoreConvoMessage(message *ConvoMessage, commit bool) (string, error) {
convoDir := getPlanConversationDir(message.OrgId, message.PlanId)
id := uuid.New().String()
ts := time.Now().UTC()
message.Id = id
message.CreatedAt = ts
bytes, err := json.Marshal(message)
if err != nil {
return "", fmt.Errorf("error marshalling convo message: %v", err)
}
err = os.WriteFile(filepath.Join(convoDir, message.Id+".json"), bytes, os.ModePerm)
if err != nil {
return "", fmt.Errorf("error writing convo message: %v", err)
}
err = AddPlanConvoMessage(message.PlanId, message.Tokens)
if err != nil {
return "", fmt.Errorf("error adding convo tokens: %v", err)
}
var desc string
if message.Role == openai.ChatMessageRoleUser {
desc = "💬 User prompt"
// TODO: add user name
} else {
desc = "🤖 Plandex reply"
if message.Stopped {
desc += " | 🛑 stopped early"
}
}
msg := fmt.Sprintf("Message #%d | %s | %d 🪙", message.Num, desc, message.Tokens)
if commit {
err = GitAddAndCommit(message.OrgId, message.PlanId, msg)
if err != nil {
return "", fmt.Errorf("error committing convo message: %v", err)
}
}
return msg, nil
}
func GetPlanSummaries(planId string) ([]*ConvoSummary, error) {
var summaries []*ConvoSummary
err := Conn.Select(&summaries, "SELECT * FROM convo_summaries WHERE plan_id = $1 ORDER BY created_at", planId)
if err != nil {
return nil, fmt.Errorf("error getting plan summaries: %v", err)
}
return summaries, nil
}
func StoreSummary(summary *ConvoSummary) error {
query := "INSERT INTO convo_summaries (org_id, plan_id, latest_convo_message_id, latest_convo_message_created_at, summary, tokens, num_messages, created_at) VALUES (:org_id, :plan_id, :latest_convo_message_id, :latest_convo_message_created_at, :summary, :tokens, :num_messages, :created_at) RETURNING id, created_at"
row, err := Conn.NamedQuery(query, summary)
if err != nil {
return fmt.Errorf("error storing summary: %v", err)
}
defer row.Close()
if row.Next() {
var createdAt time.Time
var id string
if err := row.Scan(&id, &createdAt); err != nil {
return fmt.Errorf("error storing summary: %v", err)
}
summary.Id = id
summary.CreatedAt = createdAt
}
return nil
}

View File

@@ -1,85 +0,0 @@
package db
import (
"errors"
"fmt"
"log"
"os"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/postgres"
_ "github.com/golang-migrate/migrate/v4/source/file"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
)
var Conn *sqlx.DB
func Connect() error {
var err error
if os.Getenv("DATABASE_URL") == "" {
return errors.New("DATABASE_URL not set")
}
Conn, err = sqlx.Connect("postgres", os.Getenv("DATABASE_URL"))
if err != nil {
return err
}
return nil
}
func MigrationsUp() error {
if Conn == nil {
return errors.New("db not initialized")
}
driver, err := postgres.WithInstance(Conn.DB, &postgres.Config{})
if err != nil {
return fmt.Errorf("error creating postgres driver: %v", err)
}
m, err := migrate.NewWithDatabaseInstance(
"file://migrations",
"postgres", driver)
if err != nil {
return fmt.Errorf("error creating migration instance: %v", err)
}
// Uncomment below to reset migration state to a specific version after a failure
// if err := m.Force(2024011700); err != nil {
// return fmt.Errorf("error forcing migration version: %v", err)
// }
// Uncomment below to run down migrations in development (resets database)
// if os.Getenv("GOENV") == "development" {
// err = m.Down()
// if err != nil {
// if err == migrate.ErrNoChange {
// log.Println("no migrations to run down")
// } else {
// return fmt.Errorf("error running down migrations: %v", err)
// }
// }
// log.Println("ran down migrations - database was reset")
// }
err = m.Up()
if err != nil {
if err == migrate.ErrNoChange {
log.Println("migration state is up to date")
} else {
return fmt.Errorf("error running migrations: %v", err)
}
}
if err == nil {
log.Println("ran migrations successfully")
}
return nil
}

View File

@@ -1,78 +0,0 @@
package db
import (
"fmt"
"os"
"path/filepath"
)
var BaseDir string
func init() {
home, err := os.UserHomeDir()
if err != nil {
panic(fmt.Errorf("error getting user home dir: %v", err))
}
BaseDir = filepath.Join(home, "plandex-server")
}
func InitPlan(orgId, planId string) error {
dir := getPlanDir(orgId, planId)
err := os.MkdirAll(dir, os.ModePerm)
if err != nil {
return fmt.Errorf("error creating plan dir: %v", err)
}
for _, subdirFn := range [](func(orgId, planId string) string){
getPlanContextDir,
getPlanConversationDir,
getPlanResultsDir,
getPlanDescriptionsDir} {
err = os.MkdirAll(subdirFn(orgId, planId), os.ModePerm)
if err != nil {
return fmt.Errorf("error creating plan subdir: %v", err)
}
}
err = InitGitRepo(orgId, planId)
if err != nil {
return fmt.Errorf("error initializing git repo: %v", err)
}
return nil
}
func DeletePlanDir(orgId, planId string) error {
dir := getPlanDir(orgId, planId)
err := os.RemoveAll(dir)
if err != nil {
return fmt.Errorf("error deleting plan dir: %v", err)
}
return nil
}
func getPlanDir(orgId, planId string) string {
return filepath.Join(BaseDir, "orgs", orgId, "plans", planId)
}
func getPlanContextDir(orgId, planId string) string {
return filepath.Join(getPlanDir(orgId, planId), "context")
}
func getPlanConversationDir(orgId, planId string) string {
return filepath.Join(getPlanDir(orgId, planId), "conversation")
}
func getPlanResultsDir(orgId, planId string) string {
return filepath.Join(getPlanDir(orgId, planId), "results")
}
func getPlanDescriptionsDir(orgId, planId string) string {
return filepath.Join(getPlanDir(orgId, planId), "descriptions")
}

View File

@@ -1,204 +0,0 @@
package db
import (
"bytes"
"fmt"
"os/exec"
"strconv"
"strings"
"time"
"github.com/fatih/color"
)
func init() {
// ensure git is available
cmd := exec.Command("git", "--version")
if err := cmd.Run(); err != nil {
panic(fmt.Errorf("error running git --version: %v", err))
}
}
func InitGitRepo(orgId, planId string) error {
dir := getPlanDir(orgId, planId)
res, err := exec.Command("git", "init", dir).CombinedOutput()
if err != nil {
return fmt.Errorf("error initializing git repository for dir: %s, err: %v, output: %s", dir, err, string(res))
}
return nil
}
func GitAddAndCommit(orgId, planId, message string) error {
dir := getPlanDir(orgId, planId)
err := gitAdd(dir, ".")
if err != nil {
return fmt.Errorf("error adding files to git repository for dir: %s, err: %v", dir, err)
}
err = gitCommit(dir, message)
if err != nil {
return fmt.Errorf("error committing files to git repository for dir: %s, err: %v", dir, err)
}
return nil
}
func GitRewindToSha(orgId, planId, sha string) error {
dir := getPlanDir(orgId, planId)
err := gitRewindToSha(dir, sha)
if err != nil {
return fmt.Errorf("error rewinding git repository for dir: %s, err: %v", dir, err)
}
return nil
}
func GetGitCommitHistory(orgId, planId string) (body string, shas []string, err error) {
dir := getPlanDir(orgId, planId)
body, shas, err = getGitCommitHistory(dir)
if err != nil {
return "", nil, fmt.Errorf("error getting git history for dir: %s, err: %v", dir, err)
}
return body, shas, nil
}
func GetLatestCommit(orgId, planId string) (sha, body string, err error) {
dir := getPlanDir(orgId, planId)
sha, body, err = getLatestCommit(dir)
if err != nil {
return "", "", fmt.Errorf("error getting latest commit for dir: %s, err: %v", dir, err)
}
return sha, body, nil
}
func gitRewindToSha(repoDir, sha string) error {
res, err := exec.Command("git", "-C", repoDir, "reset", "--hard",
sha).CombinedOutput()
if err != nil {
return fmt.Errorf("error executing git reset for dir: %s, sha: %s, err: %v, output: %s", repoDir, sha, err, string(res))
}
return nil
}
func getLatestCommit(dir string) (sha, body string, err error) {
var out bytes.Buffer
cmd := exec.Command("git", "log", "--pretty=%h@@|@@%at@@|@@%B@>>>@")
cmd.Dir = dir
cmd.Stdout = &out
err = cmd.Run()
if err != nil {
return "", "", fmt.Errorf("error getting git history for dir: %s, err: %v",
dir, err)
}
// Process the log output to get it in the desired format.
history := processGitHistoryOutput(strings.TrimSpace(out.String()))
first := history[0]
sha = first[0]
body = first[1]
return sha, body, nil
}
func getGitCommitHistory(dir string) (body string, shas []string, err error) {
var out bytes.Buffer
cmd := exec.Command("git", "log", "--pretty=%h@@|@@%at@@|@@%B@>>>@")
cmd.Dir = dir
cmd.Stdout = &out
err = cmd.Run()
if err != nil {
return "", nil, fmt.Errorf("error getting git history for dir: %s, err: %v",
dir, err)
}
// Process the log output to get it in the desired format.
history := processGitHistoryOutput(strings.TrimSpace(out.String()))
var output []string
for _, el := range history {
shas = append(shas, el[0])
output = append(output, el[1])
}
return strings.Join(output, "\n\n"), shas, nil
}
// processGitHistoryOutput processes the raw output from the git log command and returns a formatted string.
func processGitHistoryOutput(raw string) [][2]string {
var history [][2]string
entries := strings.Split(raw, "@>>>@") // Split entries using the custom separator.
for _, entry := range entries {
// First clean up any leading/trailing whitespace or newlines from each entry.
entry = strings.TrimSpace(entry)
// Now split the cleaned entry into its parts.
parts := strings.Split(entry, "@@|@@")
if len(parts) == 3 {
sha := parts[0]
timestampStr := parts[1]
message := strings.TrimSpace(parts[2]) // Trim whitespace from message as well.
// Extract and format timestamp.
timestamp, err := strconv.ParseInt(timestampStr, 10, 64)
if err != nil {
continue // Skip entries with invalid timestamps.
}
dt := time.Unix(timestamp, 0).Local()
formattedTs := dt.Format("Mon Jan 2, 2006 | 3:04:05pm MST")
if dt.Day() == time.Now().Day() {
formattedTs = dt.Format("Today | 3:04:05pm MST")
} else if dt.Day() == time.Now().AddDate(0, 0, -1).Day() {
formattedTs = dt.Format("Yesterday | 3:04:05pm MST")
}
// Prepare the header with colors.
headerColor := color.New(color.FgCyan, color.Bold)
dateColor := color.New(color.FgCyan)
// Combine sha, formatted timestamp, and message header into one string.
header := fmt.Sprintf("%s | %s", headerColor.Sprintf("📝 Update %s", sha), dateColor.Sprintf("%s", formattedTs))
// Combine header and message with a newline only if the message is not empty.
fullEntry := header
if message != "" {
fullEntry += "\n" + message
}
history = append(history, [2]string{sha, fullEntry})
}
}
return history
}
func gitAdd(repoDir, path string) error {
res, err := exec.Command("git", "-C", repoDir, "add", path).CombinedOutput()
if err != nil {
return fmt.Errorf("error adding files to git repository for dir: %s, err: %v, output: %s", repoDir, err, string(res))
}
return nil
}
func gitCommit(repoDir, commitMsg string) error {
res, err := exec.Command("git", "-C", repoDir, "commit", "-m", commitMsg).CombinedOutput()
if err != nil {
return fmt.Errorf("error committing files to git repository for dir: %s, err: %v, output: %s", repoDir, err, string(res))
}
return nil
}

View File

@@ -1,149 +0,0 @@
package db
import (
"database/sql"
"fmt"
"log"
"strings"
)
func CreateInvite(invite *Invite) error {
_, err := Conn.NamedExec(`INSERT INTO invites (id, org_id, email, name, inviter_id) VALUES (:id, :org_id, :email, :name, :inviter_id)`, invite)
if err != nil {
return fmt.Errorf("error creating invite: %v", err)
}
return nil
}
func GetInvite(id string) (*Invite, error) {
var invite Invite
err := Conn.Get(&invite, "SELECT * FROM invites WHERE id = $1", id)
if err != nil {
return nil, fmt.Errorf("error getting invite: %v", err)
}
return &invite, nil
}
func GetInviteForOrgUser(orgId, userId string) (*Invite, error) {
var invite Invite
err := Conn.Get(&invite, "SELECT * FROM invites WHERE org_id = $1 AND invitee_id = $2", orgId, userId)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("error getting invite: %v", err)
}
return &invite, nil
}
func ListPendingInvites(orgId string) ([]*Invite, error) {
var invites []*Invite
err := Conn.Select(&invites, "SELECT * FROM invites WHERE org_id = $1 AND accepted_at IS NULL", orgId)
if err != nil {
return nil, fmt.Errorf("error getting invites for org: %v", err)
}
return invites, nil
}
func ListAllInvites(orgId string) ([]*Invite, error) {
var invites []*Invite
err := Conn.Select(&invites, "SELECT * FROM invites WHERE org_id = $1", orgId)
if err != nil {
return nil, fmt.Errorf("error getting invites for org: %v", err)
}
return invites, nil
}
func ListAcceptedInvites(orgId string) ([]*Invite, error) {
var invites []*Invite
err := Conn.Select(&invites, "SELECT * FROM invites WHERE org_id = $1 AND accepted_at IS NOT NULL", orgId)
if err != nil {
return nil, fmt.Errorf("error getting invites for org: %v", err)
}
return invites, nil
}
func GetPendingInvitesForEmail(email string) ([]*Invite, error) {
email = strings.ToLower(email)
var invites []*Invite
err := Conn.Select(&invites, "SELECT * FROM invites WHERE email = $1 AND accepted_at IS NULL", email)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("error getting invites and org names for email: %v", err)
}
return invites, nil
}
func DeleteInvite(id string, tx *sql.Tx) error {
query := "DELETE FROM invites WHERE id = $1"
var err error
if tx == nil {
_, err = tx.Exec(query, id)
} else {
_, err = Conn.Exec(query, id)
}
if err != nil {
return fmt.Errorf("error deleting invite: %v", err)
}
return nil
}
func AcceptInvite(invite *Invite, inviteeId string) error {
// start a transaction
tx, err := Conn.Begin()
if err != nil {
return fmt.Errorf("error starting transaction: %v", err)
}
// Ensure that rollback is attempted in case of failure
defer func() {
if err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
log.Printf("transaction rollback error: %v\n", rbErr)
} else {
log.Println("transaction rolled back")
}
}
}()
_, err = tx.Exec(`UPDATE invites SET accepted_at = NOW(), invitee_id = $1 WHERE id = $2`, inviteeId, invite.Id)
if err != nil {
return fmt.Errorf("error accepting invite: %v", err)
}
// create org user
err = CreateOrgUser(invite.OrgId, invite.InviteeId, invite.OrgRoleId, tx)
if err != nil {
return fmt.Errorf("error creating org user: %v", err)
}
// commit transaction
err = tx.Commit()
if err != nil {
return fmt.Errorf("error committing transaction: %v", err)
}
invite.InviteeId = inviteeId
return nil
}

Some files were not shown because too many files have changed in this diff Show More