mirror of
https://github.com/charmbracelet/crush.git
synced 2025-08-02 05:20:46 +03:00
Merge branch 'main' into tools-bg
This commit is contained in:
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@@ -2,7 +2,7 @@ name: build
|
||||
on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
build-go:
|
||||
build:
|
||||
uses: charmbracelet/meta/.github/workflows/build.yml@main
|
||||
with:
|
||||
go-version: ""
|
||||
|
||||
1
.github/workflows/release.yml
vendored
1
.github/workflows/release.yml
vendored
@@ -20,3 +20,4 @@ jobs:
|
||||
fury_token: ${{ secrets.FURY_TOKEN }}
|
||||
nfpm_gpg_key: ${{ secrets.NFPM_GPG_KEY }}
|
||||
nfpm_passphrase: ${{ secrets.NFPM_PASSPHRASE }}
|
||||
npm_token: ${{ secrets.NPM_TOKEN }}
|
||||
|
||||
@@ -45,7 +45,7 @@ builds:
|
||||
goos:
|
||||
- linux
|
||||
- darwin
|
||||
# - windows
|
||||
- windows
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
@@ -111,8 +111,15 @@ homebrew_casks:
|
||||
owner: charmbracelet
|
||||
name: homebrew-tap
|
||||
|
||||
npms:
|
||||
- name: "@charmland/crush"
|
||||
repository: "git+https://github.com/charmbracelet/crush.git"
|
||||
bugs: https://github.com/charmbracelet/crush/issues
|
||||
access: public
|
||||
|
||||
nfpms:
|
||||
- formats:
|
||||
- apk
|
||||
- deb
|
||||
- rpm
|
||||
- archlinux
|
||||
|
||||
45
README.md
45
README.md
@@ -12,10 +12,12 @@ Crush is a tool for building software with AI.
|
||||
|
||||
## Installation
|
||||
|
||||
Crush has first class support for macOS, Linux, and Windows.
|
||||
|
||||
Nightly builds are available while Crush is in development.
|
||||
|
||||
- [Packages](https://github.com/charmbracelet/crush/releases/tag/nightly) are available in Debian and RPM formats
|
||||
- [Binaries](https://github.com/charmbracelet/crush/releases/tag/nightly) are available for Linux and macOS
|
||||
- [Packages](https://github.com/charmbracelet/crush/releases/tag/nightly) are available in Debian, RPM, APK, and PKG formats
|
||||
- [Binaries](https://github.com/charmbracelet/crush/releases/tag/nightly) are available for Linux, macOS and Windows
|
||||
|
||||
You can also just install it with go:
|
||||
|
||||
@@ -28,7 +30,7 @@ go install
|
||||
<details>
|
||||
<summary>Not a developer? Here’s a quick how-to.</summary>
|
||||
|
||||
Download the latest [nightly release](https://github.com/charmbracelet/crush/releases) for your system. The [macOS ARM64](https://github.com/charmbracelet/crush/releases/download/nightly/crush_0.1.0-nightly_Darwin_arm64.tar.gz) is most likely what you want.
|
||||
Download the latest [nightly release](https://github.com/charmbracelet/crush/releases) for your system. The [macOS ARM64 one](https://github.com/charmbracelet/crush/releases/download/nightly/crush_0.1.0-nightly_Darwin_arm64.tar.gz) is most likely what you want.
|
||||
|
||||
Next, open a terminal and run the following commands:
|
||||
|
||||
@@ -36,17 +38,15 @@ Next, open a terminal and run the following commands:
|
||||
cd ~/Downloads
|
||||
tar -xvzf crush_0.1.0-nightly_Darwin_arm64.tar.gz -C crush
|
||||
sudo mv ./crush/crush /usr/local/bin/crush
|
||||
rm -rf crush
|
||||
rm -rf ./crush
|
||||
```
|
||||
|
||||
Then, run Crush by typing `crush`.
|
||||
|
||||
***
|
||||
---
|
||||
|
||||
</details>
|
||||
|
||||
Note that Crush doesn't support Windows yet, however Windows support is planned and in progress.
|
||||
|
||||
## Getting Started
|
||||
|
||||
The quickest way to get started to grab an API key for your preferred
|
||||
@@ -108,7 +108,7 @@ Crush supports Model Context Protocol (MCP) servers through three transport type
|
||||
"mcp": {
|
||||
"filesystem": {
|
||||
"type": "stdio",
|
||||
"command": "node",
|
||||
"command": "node",
|
||||
"args": ["/path/to/mcp-server.js"],
|
||||
"env": {
|
||||
"NODE_ENV": "production"
|
||||
@@ -143,7 +143,7 @@ crush -d
|
||||
# View last 1000 lines
|
||||
crush logs
|
||||
|
||||
# Follow logs in real-time
|
||||
# Follow logs in real-time
|
||||
crush logs -f
|
||||
|
||||
# Show last 500 lines
|
||||
@@ -161,6 +161,31 @@ Add to your `crush.json` config file:
|
||||
}
|
||||
```
|
||||
|
||||
### Configurable Default Permissions
|
||||
|
||||
Crush includes a permission system to control which tools can be executed without prompting. You can configure allowed tools in your `crush.json` config file:
|
||||
|
||||
```json
|
||||
{
|
||||
"permissions": {
|
||||
"allowed_tools": [
|
||||
"view",
|
||||
"ls",
|
||||
"grep",
|
||||
"edit:write",
|
||||
"mcp_context7_get-library-doc"
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
The `allowed_tools` array accepts:
|
||||
|
||||
- Tool names (e.g., `"view"`) - allows all actions for that tool
|
||||
- Tool:action combinations (e.g., `"edit:write"`) - allows only specific actions
|
||||
|
||||
You can also skip all permission prompts entirely by running Crush with the `--yolo` flag.
|
||||
|
||||
### OpenAI-Compatible APIs
|
||||
|
||||
Crush supports all OpenAI-compatible APIs. Here's an example configuration for Deepseek, which uses an OpenAI-compatible API. Don't forget to set `DEEPSEEK_API_KEY` in your environment.
|
||||
@@ -174,7 +199,7 @@ Crush supports all OpenAI-compatible APIs. Here's an example configuration for D
|
||||
"models": [
|
||||
{
|
||||
"id": "deepseek-chat",
|
||||
"model": "Deepseek V3",
|
||||
"name": "Deepseek V3",
|
||||
"cost_per_1m_in": 0.27,
|
||||
"cost_per_1m_out": 1.1,
|
||||
"cost_per_1m_in_cached": 0.07,
|
||||
|
||||
6
go.mod
6
go.mod
@@ -13,6 +13,7 @@ require (
|
||||
github.com/charlievieth/fastwalk v1.0.11
|
||||
github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250716191546-1e2ffbbcf5c5
|
||||
github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.4.0.20250717140350-bb75e8f6b6ac
|
||||
github.com/charmbracelet/catwalk v0.3.1
|
||||
github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674
|
||||
github.com/charmbracelet/glamour/v2 v2.0.0-20250516160903-6f1e2c8f9ebe
|
||||
github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3.0.20250716211347-10c048e36112
|
||||
@@ -39,6 +40,7 @@ require (
|
||||
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/u-root/u-root v0.14.1-0.20250722142936-bf4e78a90dfc
|
||||
github.com/zeebo/xxh3 v1.0.2
|
||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
@@ -70,7 +72,7 @@ require (
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/aymerick/douceur v0.2.0 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.3.1 // indirect
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20250721205647-f6ac6eda5d42 // indirect
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20250723145313-809e6f5b43a1 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.14-0.20250516160309-24eee56f89fa // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20250611152503-f53cdd7e01ef
|
||||
github.com/charmbracelet/x/term v0.2.1
|
||||
@@ -108,7 +110,7 @@ require (
|
||||
github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
|
||||
github.com/sethvargo/go-retry v0.3.0 // indirect
|
||||
github.com/spf13/cast v1.7.1 // indirect
|
||||
github.com/spf13/pflag v1.0.6 // indirect
|
||||
github.com/spf13/pflag v1.0.7 // indirect
|
||||
github.com/tetratelabs/wazero v1.9.0 // indirect
|
||||
github.com/tidwall/gjson v1.18.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
|
||||
11
go.sum
11
go.sum
@@ -72,6 +72,8 @@ github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250716191546-1e2ffbbcf5c5
|
||||
github.com/charmbracelet/bubbles/v2 v2.0.0-beta.1.0.20250716191546-1e2ffbbcf5c5/go.mod h1:6HamsBKWqEC/FVHuQMHgQL+knPyvHH55HwJDHl/adMw=
|
||||
github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.4.0.20250717140350-bb75e8f6b6ac h1:murtkvFYxZ/73vk4Z/tpE4biB+WDZcFmmBp8je/yV6M=
|
||||
github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.4.0.20250717140350-bb75e8f6b6ac/go.mod h1:m240IQxo1/eDQ7klblSzOCAUyc3LddHcV3Rc/YEGAgw=
|
||||
github.com/charmbracelet/catwalk v0.3.1 h1:MkGWspcMyE659zDkqS+9wsaCMTKRFEDBFY2A2sap6+U=
|
||||
github.com/charmbracelet/catwalk v0.3.1/go.mod h1:gUUCqqZ8bk4D7ZzGTu3I77k7cC2x4exRuJBN1H2u2pc=
|
||||
github.com/charmbracelet/colorprofile v0.3.1 h1:k8dTHMd7fgw4bnFd7jXTLZrSU/CQrKnL3m+AxCzDz40=
|
||||
github.com/charmbracelet/colorprofile v0.3.1/go.mod h1:/GkGusxNs8VB/RSOh3fu0TJmQ4ICMMPApIIVn0KszZ0=
|
||||
github.com/charmbracelet/fang v0.3.1-0.20250711140230-d5ebb8c1d674 h1:+Cz+VfxD5DO+JT1LlswXWhre0HYLj6l2HW8HVGfMuC0=
|
||||
@@ -82,8 +84,8 @@ github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3.0.20250716211347-10c048e36112
|
||||
github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.3.0.20250716211347-10c048e36112/go.mod h1:BXY7j7rZgAprFwzNcO698++5KTd6GKI6lU83Pr4o0r0=
|
||||
github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706 h1:WkwO6Ks3mSIGnGuSdKl9qDSyfbYK50z2wc2gGMggegE=
|
||||
github.com/charmbracelet/log/v2 v2.0.0-20250226163916-c379e29ff706/go.mod h1:mjJGp00cxcfvD5xdCa+bso251Jt4owrQvuimJtVmEmM=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20250721205647-f6ac6eda5d42 h1:Zqw2oP9Wo8VzMijVJbtIJcAaZviYyU07stvmCFCfn0Y=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20250721205647-f6ac6eda5d42/go.mod h1:XrrgNFfXLrFAyd9DUmrqVc3yQFVv8Uk+okj4PsNNzpc=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20250723145313-809e6f5b43a1 h1:tsw1mOuIEIKlmm614bXctvJ3aavaFhyPG+y+wrKtuKQ=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20250723145313-809e6f5b43a1/go.mod h1:XrrgNFfXLrFAyd9DUmrqVc3yQFVv8Uk+okj4PsNNzpc=
|
||||
github.com/charmbracelet/x/ansi v0.9.3 h1:BXt5DHS/MKF+LjuK4huWrC6NCvHtexww7dMayh6GXd0=
|
||||
github.com/charmbracelet/x/ansi v0.9.3/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.14-0.20250516160309-24eee56f89fa h1:lphz0Z3rsiOtMYiz8axkT24i9yFiueDhJbzyNUADmME=
|
||||
@@ -234,8 +236,9 @@ github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y=
|
||||
github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo=
|
||||
github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0=
|
||||
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
|
||||
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/pflag v1.0.7 h1:vN6T9TfwStFPFM5XzjsvmzZkLuaLX+HS+0SeFLRgU6M=
|
||||
github.com/spf13/pflag v1.0.7/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c h1:km8GpoQut05eY3GiYWEedbTT0qnSxrCjsVbb7yKY1KE=
|
||||
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c/go.mod h1:cNQ3dwVJtS5Hmnjxy6AgTPd0Inb3pW05ftPSX7NZO7Q=
|
||||
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef h1:Ch6Q+AZUxDBCVqdkI8FSpFyZDtCVBc2VmejdNrm5rRQ=
|
||||
@@ -258,6 +261,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/u-root/u-root v0.14.1-0.20250722142936-bf4e78a90dfc h1:HjI/UCF4dRyzizePQrhGUSQvuU7z4tOqMqz6GRGlFCM=
|
||||
github.com/u-root/u-root v0.14.1-0.20250722142936-bf4e78a90dfc/go.mod h1:/0Qr7qJeDwWxoKku2xKQ4Szc+SwBE3g9VE8jNiamsmc=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
|
||||
25
internal/ansiext/ansi.go
Normal file
25
internal/ansiext/ansi.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package ansiext
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/x/ansi"
|
||||
)
|
||||
|
||||
// Escape replaces control characters with their Unicode Control Picture
|
||||
// representations to ensure they are displayed correctly in the UI.
|
||||
func Escape(content string) string {
|
||||
var sb strings.Builder
|
||||
sb.Grow(len(content))
|
||||
for _, r := range content {
|
||||
switch {
|
||||
case r >= 0 && r <= 0x1f: // Control characters 0x00-0x1F
|
||||
sb.WriteRune('\u2400' + r)
|
||||
case r == ansi.DEL:
|
||||
sb.WriteRune('\u2421')
|
||||
default:
|
||||
sb.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
@@ -59,13 +59,17 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
|
||||
sessions := session.NewService(q)
|
||||
messages := message.NewService(q)
|
||||
files := history.NewService(q, conn)
|
||||
skipPermissionsRequests := cfg.Options != nil && cfg.Options.SkipPermissionsRequests
|
||||
skipPermissionsRequests := cfg.Permissions != nil && cfg.Permissions.SkipRequests
|
||||
allowedTools := []string{}
|
||||
if cfg.Permissions != nil && cfg.Permissions.AllowedTools != nil {
|
||||
allowedTools = cfg.Permissions.AllowedTools
|
||||
}
|
||||
|
||||
app := &App{
|
||||
Sessions: sessions,
|
||||
Messages: messages,
|
||||
History: files,
|
||||
Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests),
|
||||
Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools),
|
||||
LSPClients: make(map[string]*lsp.Client),
|
||||
|
||||
globalCtx: ctx,
|
||||
@@ -157,16 +161,20 @@ func (app *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool
|
||||
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, context.Canceled) || errors.Is(result.Error, agent.ErrRequestCancelled) {
|
||||
slog.Info("Agent processing cancelled", "session_id", sess.ID)
|
||||
slog.Info("Non-interactive: agent processing cancelled", "session_id", sess.ID)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("agent processing failed: %w", result.Error)
|
||||
}
|
||||
|
||||
part := result.Message.Content().String()[readBts:]
|
||||
fmt.Println(part)
|
||||
msgContent := result.Message.Content().String()
|
||||
if len(msgContent) < readBts {
|
||||
slog.Error("Non-interactive: message content is shorter than read bytes", "message_length", len(msgContent), "read_bytes", readBts)
|
||||
return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(msgContent), readBts)
|
||||
}
|
||||
fmt.Println(msgContent[readBts:])
|
||||
|
||||
slog.Info("Non-interactive run completed", "session_id", sess.ID)
|
||||
slog.Info("Non-interactive: run completed", "session_id", sess.ID)
|
||||
return nil
|
||||
|
||||
case event := <-messageEvents:
|
||||
|
||||
@@ -73,7 +73,10 @@ to assist developers in writing, debugging, and understanding code directly from
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.Options.SkipPermissionsRequests = yolo
|
||||
if cfg.Permissions == nil {
|
||||
cfg.Permissions = &config.Permissions{}
|
||||
}
|
||||
cfg.Permissions.SkipRequests = yolo
|
||||
|
||||
ctx := cmd.Context()
|
||||
|
||||
@@ -85,14 +88,14 @@ to assist developers in writing, debugging, and understanding code directly from
|
||||
|
||||
app, err := app.New(ctx, conn, cfg)
|
||||
if err != nil {
|
||||
slog.Error(fmt.Sprintf("Failed to create app instance: %v", err))
|
||||
slog.Error("Failed to create app instance", "error", err)
|
||||
return err
|
||||
}
|
||||
defer app.Shutdown()
|
||||
|
||||
prompt, err = maybePrependStdin(prompt)
|
||||
if err != nil {
|
||||
slog.Error(fmt.Sprintf("Failed to read from stdin: %v", err))
|
||||
slog.Error("Failed to read from stdin", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -114,7 +117,7 @@ to assist developers in writing, debugging, and understanding code directly from
|
||||
go app.Subscribe(program)
|
||||
|
||||
if _, err := program.Run(); err != nil {
|
||||
slog.Error(fmt.Sprintf("TUI run error: %v", err))
|
||||
slog.Error("TUI run error", "error", err)
|
||||
return fmt.Errorf("TUI error: %v", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/csync"
|
||||
"github.com/charmbracelet/crush/internal/env"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/tidwall/sjson"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
@@ -71,7 +71,7 @@ type ProviderConfig struct {
|
||||
// The provider's API endpoint.
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
// The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
|
||||
Type provider.Type `json:"type,omitempty"`
|
||||
Type catwalk.Type `json:"type,omitempty"`
|
||||
// The provider's API key.
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
// Marks the provider as disabled.
|
||||
@@ -86,7 +86,7 @@ type ProviderConfig struct {
|
||||
ExtraParams map[string]string `json:"-"`
|
||||
|
||||
// The provider models
|
||||
Models []provider.Model `json:"models,omitempty"`
|
||||
Models []catwalk.Model `json:"models,omitempty"`
|
||||
}
|
||||
|
||||
type MCPType string
|
||||
@@ -121,14 +121,18 @@ type TUIOptions struct {
|
||||
// Here we can add themes later or any TUI related options
|
||||
}
|
||||
|
||||
type Permissions struct {
|
||||
AllowedTools []string `json:"allowed_tools,omitempty"` // Tools that don't require permission prompts
|
||||
SkipRequests bool `json:"-"` // Automatically accept all permissions (YOLO mode)
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
ContextPaths []string `json:"context_paths,omitempty"`
|
||||
TUI *TUIOptions `json:"tui,omitempty"`
|
||||
Debug bool `json:"debug,omitempty"`
|
||||
DebugLSP bool `json:"debug_lsp,omitempty"`
|
||||
DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"`
|
||||
DataDirectory string `json:"data_directory,omitempty"` // Relative to the cwd
|
||||
SkipPermissionsRequests bool `json:"-"` // Automatically accept all permissions (YOLO mode)
|
||||
ContextPaths []string `json:"context_paths,omitempty"`
|
||||
TUI *TUIOptions `json:"tui,omitempty"`
|
||||
Debug bool `json:"debug,omitempty"`
|
||||
DebugLSP bool `json:"debug_lsp,omitempty"`
|
||||
DisableAutoSummarize bool `json:"disable_auto_summarize,omitempty"`
|
||||
DataDirectory string `json:"data_directory,omitempty"` // Relative to the cwd
|
||||
}
|
||||
|
||||
type MCPs map[string]MCPConfig
|
||||
@@ -245,14 +249,16 @@ type Config struct {
|
||||
|
||||
Options *Options `json:"options,omitempty"`
|
||||
|
||||
Permissions *Permissions `json:"permissions,omitempty"`
|
||||
|
||||
// Internal
|
||||
workingDir string `json:"-"`
|
||||
// TODO: most likely remove this concept when I come back to it
|
||||
Agents map[string]Agent `json:"-"`
|
||||
// TODO: find a better way to do this this should probably not be part of the config
|
||||
resolver VariableResolver
|
||||
dataConfigDir string `json:"-"`
|
||||
knownProviders []provider.Provider `json:"-"`
|
||||
dataConfigDir string `json:"-"`
|
||||
knownProviders []catwalk.Provider `json:"-"`
|
||||
}
|
||||
|
||||
func (c *Config) WorkingDir() string {
|
||||
@@ -274,7 +280,7 @@ func (c *Config) IsConfigured() bool {
|
||||
return len(c.EnabledProviders()) > 0
|
||||
}
|
||||
|
||||
func (c *Config) GetModel(provider, model string) *provider.Model {
|
||||
func (c *Config) GetModel(provider, model string) *catwalk.Model {
|
||||
if providerConfig, ok := c.Providers.Get(provider); ok {
|
||||
for _, m := range providerConfig.Models {
|
||||
if m.ID == model {
|
||||
@@ -296,7 +302,7 @@ func (c *Config) GetProviderForModel(modelType SelectedModelType) *ProviderConfi
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
|
||||
func (c *Config) GetModelByType(modelType SelectedModelType) *catwalk.Model {
|
||||
model, ok := c.Models[modelType]
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -304,7 +310,7 @@ func (c *Config) GetModelByType(modelType SelectedModelType) *provider.Model {
|
||||
return c.GetModel(model.Provider, model.Model)
|
||||
}
|
||||
|
||||
func (c *Config) LargeModel() *provider.Model {
|
||||
func (c *Config) LargeModel() *catwalk.Model {
|
||||
model, ok := c.Models[SelectedModelTypeLarge]
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -312,7 +318,7 @@ func (c *Config) LargeModel() *provider.Model {
|
||||
return c.GetModel(model.Provider, model.Model)
|
||||
}
|
||||
|
||||
func (c *Config) SmallModel() *provider.Model {
|
||||
func (c *Config) SmallModel() *catwalk.Model {
|
||||
model, ok := c.Models[SelectedModelTypeSmall]
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -378,7 +384,7 @@ func (c *Config) SetProviderAPIKey(providerID, apiKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var foundProvider *provider.Provider
|
||||
var foundProvider *catwalk.Provider
|
||||
for _, p := range c.knownProviders {
|
||||
if string(p.ID) == providerID {
|
||||
foundProvider = &p
|
||||
@@ -447,14 +453,14 @@ func (c *ProviderConfig) TestConnection(resolver VariableResolver) error {
|
||||
headers := make(map[string]string)
|
||||
apiKey, _ := resolver.ResolveValue(c.APIKey)
|
||||
switch c.Type {
|
||||
case provider.TypeOpenAI:
|
||||
case catwalk.TypeOpenAI:
|
||||
baseURL, _ := resolver.ResolveValue(c.BaseURL)
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com/v1"
|
||||
}
|
||||
testURL = baseURL + "/models"
|
||||
headers["Authorization"] = "Bearer " + apiKey
|
||||
case provider.TypeAnthropic:
|
||||
case catwalk.TypeAnthropic:
|
||||
baseURL, _ := resolver.ResolveValue(c.BaseURL)
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.anthropic.com/v1"
|
||||
|
||||
@@ -12,12 +12,14 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/csync"
|
||||
"github.com/charmbracelet/crush/internal/env"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/log"
|
||||
)
|
||||
|
||||
const catwalkURL = "https://catwalk.charm.sh"
|
||||
|
||||
// LoadReader config via io.Reader.
|
||||
func LoadReader(fd io.Reader) (*Config, error) {
|
||||
data, err := io.ReadAll(fd)
|
||||
@@ -61,7 +63,7 @@ func Load(workingDir string, debug bool) (*Config, error) {
|
||||
cfg.Options.Debug,
|
||||
)
|
||||
|
||||
// Load known providers, this loads the config from fur
|
||||
// Load known providers, this loads the config from catwalk
|
||||
providers, err := Providers()
|
||||
if err != nil || len(providers) == 0 {
|
||||
return nil, fmt.Errorf("failed to load providers: %w", err)
|
||||
@@ -97,7 +99,7 @@ func (c *Config) removeUnresponsiveProviders() {
|
||||
slog.Info("Testing provider connections")
|
||||
defer slog.Info("Provider connection tests completed")
|
||||
for _, p := range c.Providers.Seq2() {
|
||||
if p.Type == provider.TypeOpenAI || p.Type == provider.TypeAnthropic {
|
||||
if p.Type == catwalk.TypeOpenAI || p.Type == catwalk.TypeAnthropic {
|
||||
wg.Add(1)
|
||||
go func(provider ProviderConfig) {
|
||||
defer wg.Done()
|
||||
@@ -122,7 +124,7 @@ func (c *Config) removeUnresponsiveProviders() {
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error {
|
||||
func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error {
|
||||
knownProviderNames := make(map[string]bool)
|
||||
for _, p := range knownProviders {
|
||||
knownProviderNames[string(p.ID)] = true
|
||||
@@ -141,7 +143,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
|
||||
p.APIKey = config.APIKey
|
||||
}
|
||||
if len(config.Models) > 0 {
|
||||
models := []provider.Model{}
|
||||
models := []catwalk.Model{}
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, model := range config.Models {
|
||||
@@ -149,8 +151,8 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
|
||||
continue
|
||||
}
|
||||
seen[model.ID] = true
|
||||
if model.Model == "" {
|
||||
model.Model = model.ID
|
||||
if model.Name == "" {
|
||||
model.Name = model.ID
|
||||
}
|
||||
models = append(models, model)
|
||||
}
|
||||
@@ -159,8 +161,8 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
|
||||
continue
|
||||
}
|
||||
seen[model.ID] = true
|
||||
if model.Model == "" {
|
||||
model.Model = model.ID
|
||||
if model.Name == "" {
|
||||
model.Name = model.ID
|
||||
}
|
||||
models = append(models, model)
|
||||
}
|
||||
@@ -183,7 +185,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
|
||||
|
||||
switch p.ID {
|
||||
// Handle specific providers that require additional configuration
|
||||
case provider.InferenceProviderVertexAI:
|
||||
case catwalk.InferenceProviderVertexAI:
|
||||
if !hasVertexCredentials(env) {
|
||||
if configExists {
|
||||
slog.Warn("Skipping Vertex AI provider due to missing credentials")
|
||||
@@ -193,7 +195,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
|
||||
}
|
||||
prepared.ExtraParams["project"] = env.Get("GOOGLE_CLOUD_PROJECT")
|
||||
prepared.ExtraParams["location"] = env.Get("GOOGLE_CLOUD_LOCATION")
|
||||
case provider.InferenceProviderAzure:
|
||||
case catwalk.InferenceProviderAzure:
|
||||
endpoint, err := resolver.ResolveValue(p.APIEndpoint)
|
||||
if err != nil || endpoint == "" {
|
||||
if configExists {
|
||||
@@ -204,7 +206,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
|
||||
}
|
||||
prepared.BaseURL = endpoint
|
||||
prepared.ExtraParams["apiVersion"] = env.Get("AZURE_OPENAI_API_VERSION")
|
||||
case provider.InferenceProviderBedrock:
|
||||
case catwalk.InferenceProviderBedrock:
|
||||
if !hasAWSCredentials(env) {
|
||||
if configExists {
|
||||
slog.Warn("Skipping Bedrock provider due to missing AWS credentials")
|
||||
@@ -244,7 +246,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
|
||||
}
|
||||
// default to OpenAI if not set
|
||||
if providerConfig.Type == "" {
|
||||
providerConfig.Type = provider.TypeOpenAI
|
||||
providerConfig.Type = catwalk.TypeOpenAI
|
||||
}
|
||||
|
||||
if providerConfig.Disable {
|
||||
@@ -265,7 +267,7 @@ func (c *Config) configureProviders(env env.Env, resolver VariableResolver, know
|
||||
c.Providers.Del(id)
|
||||
continue
|
||||
}
|
||||
if providerConfig.Type != provider.TypeOpenAI {
|
||||
if providerConfig.Type != catwalk.TypeOpenAI {
|
||||
slog.Warn("Skipping custom provider because the provider type is not supported", "provider", id, "type", providerConfig.Type)
|
||||
c.Providers.Del(id)
|
||||
continue
|
||||
@@ -320,7 +322,7 @@ func (c *Config) setDefaults(workingDir string) {
|
||||
c.Options.ContextPaths = slices.Compact(c.Options.ContextPaths)
|
||||
}
|
||||
|
||||
func (c *Config) defaultModelSelection(knownProviders []provider.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) {
|
||||
func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) {
|
||||
if len(knownProviders) == 0 && c.Providers.Len() == 0 {
|
||||
err = fmt.Errorf("no providers configured, please configure at least one provider")
|
||||
return
|
||||
@@ -389,7 +391,7 @@ func (c *Config) defaultModelSelection(knownProviders []provider.Provider) (larg
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Config) configureSelectedModels(knownProviders []provider.Provider) error {
|
||||
func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) error {
|
||||
defaultLarge, defaultSmall, err := c.defaultModelSelection(knownProviders)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to select default models: %w", err)
|
||||
|
||||
@@ -8,9 +8,9 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/csync"
|
||||
"github.com/charmbracelet/crush/internal/env"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -56,12 +56,12 @@ func TestConfig_setDefaults(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProviders(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$OPENAI_API_KEY",
|
||||
APIEndpoint: "https://api.openai.com/v1",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -83,12 +83,12 @@ func TestConfig_configureProviders(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersWithOverride(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$OPENAI_API_KEY",
|
||||
APIEndpoint: "https://api.openai.com/v1",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -100,10 +100,10 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) {
|
||||
cfg.Providers.Set("openai", ProviderConfig{
|
||||
APIKey: "xyz",
|
||||
BaseURL: "https://api.openai.com/v2",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "test-model",
|
||||
Model: "Updated",
|
||||
ID: "test-model",
|
||||
Name: "Updated",
|
||||
},
|
||||
{
|
||||
ID: "another-model",
|
||||
@@ -125,16 +125,16 @@ func TestConfig_configureProvidersWithOverride(t *testing.T) {
|
||||
assert.Equal(t, "xyz", pc.APIKey)
|
||||
assert.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
|
||||
assert.Len(t, pc.Models, 2)
|
||||
assert.Equal(t, "Updated", pc.Models[0].Model)
|
||||
assert.Equal(t, "Updated", pc.Models[0].Name)
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$OPENAI_API_KEY",
|
||||
APIEndpoint: "https://api.openai.com/v1",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -145,7 +145,7 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
|
||||
"custom": {
|
||||
APIKey: "xyz",
|
||||
BaseURL: "https://api.someendpoint.com/v2",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "test-model",
|
||||
},
|
||||
@@ -176,12 +176,12 @@ func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: provider.InferenceProviderBedrock,
|
||||
ID: catwalk.InferenceProviderBedrock,
|
||||
APIKey: "",
|
||||
APIEndpoint: "",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
}},
|
||||
},
|
||||
@@ -205,12 +205,12 @@ func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: provider.InferenceProviderBedrock,
|
||||
ID: catwalk.InferenceProviderBedrock,
|
||||
APIKey: "",
|
||||
APIEndpoint: "",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
}},
|
||||
},
|
||||
@@ -227,12 +227,12 @@ func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: provider.InferenceProviderBedrock,
|
||||
ID: catwalk.InferenceProviderBedrock,
|
||||
APIKey: "",
|
||||
APIEndpoint: "",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "some-random-model",
|
||||
}},
|
||||
},
|
||||
@@ -250,12 +250,12 @@ func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: provider.InferenceProviderVertexAI,
|
||||
ID: catwalk.InferenceProviderVertexAI,
|
||||
APIKey: "",
|
||||
APIEndpoint: "",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "gemini-pro",
|
||||
}},
|
||||
},
|
||||
@@ -282,12 +282,12 @@ func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: provider.InferenceProviderVertexAI,
|
||||
ID: catwalk.InferenceProviderVertexAI,
|
||||
APIKey: "",
|
||||
APIEndpoint: "",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "gemini-pro",
|
||||
}},
|
||||
},
|
||||
@@ -308,12 +308,12 @@ func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: provider.InferenceProviderVertexAI,
|
||||
ID: catwalk.InferenceProviderVertexAI,
|
||||
APIKey: "",
|
||||
APIEndpoint: "",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "gemini-pro",
|
||||
}},
|
||||
},
|
||||
@@ -333,12 +333,12 @@ func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersSetProviderID(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$OPENAI_API_KEY",
|
||||
APIEndpoint: "https://api.openai.com/v1",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -455,12 +455,12 @@ func TestConfig_IsConfigured(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$OPENAI_API_KEY",
|
||||
APIEndpoint: "https://api.openai.com/v1",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -494,7 +494,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
Providers: csync.NewMapFrom(map[string]ProviderConfig{
|
||||
"custom": {
|
||||
BaseURL: "https://api.custom.com/v1",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -507,7 +507,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
|
||||
env := env.NewFromMap(map[string]string{})
|
||||
resolver := NewEnvironmentVariableResolver(env)
|
||||
err := cfg.configureProviders(env, resolver, []provider.Provider{})
|
||||
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, cfg.Providers.Len(), 1)
|
||||
@@ -520,7 +520,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
Providers: csync.NewMapFrom(map[string]ProviderConfig{
|
||||
"custom": {
|
||||
APIKey: "test-key",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -530,7 +530,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
|
||||
env := env.NewFromMap(map[string]string{})
|
||||
resolver := NewEnvironmentVariableResolver(env)
|
||||
err := cfg.configureProviders(env, resolver, []provider.Provider{})
|
||||
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, cfg.Providers.Len(), 0)
|
||||
@@ -544,7 +544,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
"custom": {
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://api.custom.com/v1",
|
||||
Models: []provider.Model{},
|
||||
Models: []catwalk.Model{},
|
||||
},
|
||||
}),
|
||||
}
|
||||
@@ -552,7 +552,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
|
||||
env := env.NewFromMap(map[string]string{})
|
||||
resolver := NewEnvironmentVariableResolver(env)
|
||||
err := cfg.configureProviders(env, resolver, []provider.Provider{})
|
||||
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, cfg.Providers.Len(), 0)
|
||||
@@ -567,7 +567,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://api.custom.com/v1",
|
||||
Type: "unsupported",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -577,7 +577,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
|
||||
env := env.NewFromMap(map[string]string{})
|
||||
resolver := NewEnvironmentVariableResolver(env)
|
||||
err := cfg.configureProviders(env, resolver, []provider.Provider{})
|
||||
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, cfg.Providers.Len(), 0)
|
||||
@@ -591,8 +591,8 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
"custom": {
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://api.custom.com/v1",
|
||||
Type: provider.TypeOpenAI,
|
||||
Models: []provider.Model{{
|
||||
Type: catwalk.TypeOpenAI,
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -602,7 +602,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
|
||||
env := env.NewFromMap(map[string]string{})
|
||||
resolver := NewEnvironmentVariableResolver(env)
|
||||
err := cfg.configureProviders(env, resolver, []provider.Provider{})
|
||||
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, cfg.Providers.Len(), 1)
|
||||
@@ -619,9 +619,9 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
"custom": {
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://api.custom.com/v1",
|
||||
Type: provider.TypeOpenAI,
|
||||
Type: catwalk.TypeOpenAI,
|
||||
Disable: true,
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -631,7 +631,7 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
|
||||
env := env.NewFromMap(map[string]string{})
|
||||
resolver := NewEnvironmentVariableResolver(env)
|
||||
err := cfg.configureProviders(env, resolver, []provider.Provider{})
|
||||
err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, cfg.Providers.Len(), 0)
|
||||
@@ -642,12 +642,12 @@ func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
|
||||
|
||||
func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
|
||||
t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: provider.InferenceProviderVertexAI,
|
||||
ID: catwalk.InferenceProviderVertexAI,
|
||||
APIKey: "",
|
||||
APIEndpoint: "",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "gemini-pro",
|
||||
}},
|
||||
},
|
||||
@@ -675,12 +675,12 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: provider.InferenceProviderBedrock,
|
||||
ID: catwalk.InferenceProviderBedrock,
|
||||
APIKey: "",
|
||||
APIEndpoint: "",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
}},
|
||||
},
|
||||
@@ -706,12 +706,12 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$MISSING_API_KEY",
|
||||
APIEndpoint: "https://api.openai.com/v1",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -737,12 +737,12 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$OPENAI_API_KEY",
|
||||
APIEndpoint: "$MISSING_ENDPOINT",
|
||||
Models: []provider.Model{{
|
||||
Models: []catwalk.Model{{
|
||||
ID: "test-model",
|
||||
}},
|
||||
},
|
||||
@@ -772,13 +772,13 @@ func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
|
||||
|
||||
func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "abc",
|
||||
DefaultLargeModelID: "large-model",
|
||||
DefaultSmallModelID: "small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
@@ -808,13 +808,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
assert.Equal(t, int64(500), small.MaxTokens)
|
||||
})
|
||||
t.Run("should error if no providers configured", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$MISSING_KEY",
|
||||
DefaultLargeModelID: "large-model",
|
||||
DefaultSmallModelID: "small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
@@ -838,13 +838,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("should error if model is missing", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "abc",
|
||||
DefaultLargeModelID: "large-model",
|
||||
DefaultSmallModelID: "small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "not-large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
@@ -868,13 +868,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("should configure the default models with a custom provider", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$MISSING", // will not be included in the config
|
||||
DefaultLargeModelID: "large-model",
|
||||
DefaultSmallModelID: "small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "not-large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
@@ -892,7 +892,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
"custom": {
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://api.custom.com/v1",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "model",
|
||||
DefaultMaxTokens: 600,
|
||||
@@ -917,13 +917,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("should fail if no model configured", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "$MISSING", // will not be included in the config
|
||||
DefaultLargeModelID: "large-model",
|
||||
DefaultSmallModelID: "small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "not-large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
@@ -941,7 +941,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
"custom": {
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://api.custom.com/v1",
|
||||
Models: []provider.Model{},
|
||||
Models: []catwalk.Model{},
|
||||
},
|
||||
}),
|
||||
}
|
||||
@@ -954,13 +954,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
})
|
||||
t.Run("should use the default provider first", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "set",
|
||||
DefaultLargeModelID: "large-model",
|
||||
DefaultSmallModelID: "small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
@@ -978,7 +978,7 @@ func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
"custom": {
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://api.custom.com/v1",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
@@ -1005,13 +1005,13 @@ func TestConfig_defaultModelSelection(t *testing.T) {
|
||||
|
||||
func TestConfig_configureSelectedModels(t *testing.T) {
|
||||
t.Run("should override defaults", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "abc",
|
||||
DefaultLargeModelID: "large-model",
|
||||
DefaultSmallModelID: "small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "larger-model",
|
||||
DefaultMaxTokens: 2000,
|
||||
@@ -1053,13 +1053,13 @@ func TestConfig_configureSelectedModels(t *testing.T) {
|
||||
assert.Equal(t, int64(500), small.MaxTokens)
|
||||
})
|
||||
t.Run("should be possible to use multiple providers", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "abc",
|
||||
DefaultLargeModelID: "large-model",
|
||||
DefaultSmallModelID: "small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
@@ -1075,7 +1075,7 @@ func TestConfig_configureSelectedModels(t *testing.T) {
|
||||
APIKey: "abc",
|
||||
DefaultLargeModelID: "a-large-model",
|
||||
DefaultSmallModelID: "a-small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "a-large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
@@ -1116,13 +1116,13 @@ func TestConfig_configureSelectedModels(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("should override the max tokens only", func(t *testing.T) {
|
||||
knownProviders := []provider.Provider{
|
||||
knownProviders := []catwalk.Provider{
|
||||
{
|
||||
ID: "openai",
|
||||
APIKey: "abc",
|
||||
DefaultLargeModelID: "large-model",
|
||||
DefaultSmallModelID: "small-model",
|
||||
Models: []provider.Model{
|
||||
Models: []catwalk.Model{
|
||||
{
|
||||
ID: "large-model",
|
||||
DefaultMaxTokens: 1000,
|
||||
|
||||
@@ -10,17 +10,16 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/fur/client"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
)
|
||||
|
||||
type ProviderClient interface {
|
||||
GetProviders() ([]provider.Provider, error)
|
||||
GetProviders() ([]catwalk.Provider, error)
|
||||
}
|
||||
|
||||
var (
|
||||
providerOnce sync.Once
|
||||
providerList []provider.Provider
|
||||
providerList []catwalk.Provider
|
||||
)
|
||||
|
||||
// file to cache provider data
|
||||
@@ -44,7 +43,7 @@ func providerCacheFileData() string {
|
||||
return filepath.Join(os.Getenv("HOME"), ".local", "share", appName, "providers.json")
|
||||
}
|
||||
|
||||
func saveProvidersInCache(path string, providers []provider.Provider) error {
|
||||
func saveProvidersInCache(path string, providers []catwalk.Provider) error {
|
||||
slog.Info("Caching provider data")
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return fmt.Errorf("failed to create directory for provider cache: %w", err)
|
||||
@@ -61,26 +60,26 @@ func saveProvidersInCache(path string, providers []provider.Provider) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadProvidersFromCache(path string) ([]provider.Provider, error) {
|
||||
func loadProvidersFromCache(path string) ([]catwalk.Provider, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read provider cache file: %w", err)
|
||||
}
|
||||
|
||||
var providers []provider.Provider
|
||||
var providers []catwalk.Provider
|
||||
if err := json.Unmarshal(data, &providers); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal provider data from cache: %w", err)
|
||||
}
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
func Providers() ([]provider.Provider, error) {
|
||||
client := client.New()
|
||||
func Providers() ([]catwalk.Provider, error) {
|
||||
client := catwalk.NewWithURL(catwalkURL)
|
||||
path := providerCacheFileData()
|
||||
return loadProvidersOnce(client, path)
|
||||
}
|
||||
|
||||
func loadProvidersOnce(client ProviderClient, path string) ([]provider.Provider, error) {
|
||||
func loadProvidersOnce(client ProviderClient, path string) ([]catwalk.Provider, error) {
|
||||
var err error
|
||||
providerOnce.Do(func() {
|
||||
providerList, err = loadProviders(client, path)
|
||||
@@ -91,7 +90,7 @@ func loadProvidersOnce(client ProviderClient, path string) ([]provider.Provider,
|
||||
return providerList, nil
|
||||
}
|
||||
|
||||
func loadProviders(client ProviderClient, path string) (providerList []provider.Provider, err error) {
|
||||
func loadProviders(client ProviderClient, path string) (providerList []catwalk.Provider, err error) {
|
||||
// if cache is not stale, load from it
|
||||
stale, exists := isCacheStale(path)
|
||||
if !stale {
|
||||
|
||||
@@ -5,14 +5,14 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type emptyProviderClient struct{}
|
||||
|
||||
func (m *emptyProviderClient) GetProviders() ([]provider.Provider, error) {
|
||||
return []provider.Provider{}, nil
|
||||
func (m *emptyProviderClient) GetProviders() ([]catwalk.Provider, error) {
|
||||
return []catwalk.Provider{}, nil
|
||||
}
|
||||
|
||||
func TestProvider_loadProvidersEmptyResult(t *testing.T) {
|
||||
@@ -33,7 +33,7 @@ func TestProvider_loadProvidersEmptyCache(t *testing.T) {
|
||||
tmpPath := t.TempDir() + "/providers.json"
|
||||
|
||||
// Create an empty cache file
|
||||
emptyProviders := []provider.Provider{}
|
||||
emptyProviders := []catwalk.Provider{}
|
||||
data, err := json.Marshal(emptyProviders)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(tmpPath, data, 0o644))
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -14,11 +14,11 @@ type mockProviderClient struct {
|
||||
shouldFail bool
|
||||
}
|
||||
|
||||
func (m *mockProviderClient) GetProviders() ([]provider.Provider, error) {
|
||||
func (m *mockProviderClient) GetProviders() ([]catwalk.Provider, error) {
|
||||
if m.shouldFail {
|
||||
return nil, errors.New("failed to load providers")
|
||||
}
|
||||
return []provider.Provider{
|
||||
return []catwalk.Provider{
|
||||
{
|
||||
Name: "Mock",
|
||||
},
|
||||
@@ -43,7 +43,7 @@ func TestProvider_loadProvidersWithIssues(t *testing.T) {
|
||||
client := &mockProviderClient{shouldFail: true}
|
||||
tmpPath := t.TempDir() + "/providers.json"
|
||||
// store providers to a temporary file
|
||||
oldProviders := []provider.Provider{
|
||||
oldProviders := []catwalk.Provider{
|
||||
{
|
||||
Name: "OldProvider",
|
||||
},
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
// Package client provides a client for interacting with the fur service.
|
||||
package client
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
)
|
||||
|
||||
const defaultURL = "https://fur.charm.sh"
|
||||
|
||||
// Client represents a client for the fur service.
|
||||
type Client struct {
|
||||
baseURL string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// New creates a new client instance
|
||||
// Uses FUR_URL environment variable or falls back to localhost:8080.
|
||||
func New() *Client {
|
||||
baseURL := os.Getenv("FUR_URL")
|
||||
if baseURL == "" {
|
||||
baseURL = defaultURL
|
||||
}
|
||||
|
||||
return &Client{
|
||||
baseURL: baseURL,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
// NewWithURL creates a new client with a specific URL.
|
||||
func NewWithURL(url string) *Client {
|
||||
return &Client{
|
||||
baseURL: url,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
// GetProviders retrieves all available providers from the service.
|
||||
func (c *Client) GetProviders() ([]provider.Provider, error) {
|
||||
url := fmt.Sprintf("%s/providers", c.baseURL)
|
||||
|
||||
resp, err := c.httpClient.Get(url) //nolint:noctx
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close() //nolint:errcheck
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var providers []provider.Provider
|
||||
if err := json.NewDecoder(resp.Body).Decode(&providers); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
return providers, nil
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
// Package provider provides types and constants for AI providers.
|
||||
package provider
|
||||
|
||||
// Type represents the type of AI provider.
|
||||
type Type string
|
||||
|
||||
// All the supported AI provider types.
|
||||
const (
|
||||
TypeOpenAI Type = "openai"
|
||||
TypeAnthropic Type = "anthropic"
|
||||
TypeGemini Type = "gemini"
|
||||
TypeAzure Type = "azure"
|
||||
TypeBedrock Type = "bedrock"
|
||||
TypeVertexAI Type = "vertexai"
|
||||
TypeXAI Type = "xai"
|
||||
)
|
||||
|
||||
// InferenceProvider represents the inference provider identifier.
|
||||
type InferenceProvider string
|
||||
|
||||
// All the inference providers supported by the system.
|
||||
const (
|
||||
InferenceProviderOpenAI InferenceProvider = "openai"
|
||||
InferenceProviderAnthropic InferenceProvider = "anthropic"
|
||||
InferenceProviderGemini InferenceProvider = "gemini"
|
||||
InferenceProviderAzure InferenceProvider = "azure"
|
||||
InferenceProviderBedrock InferenceProvider = "bedrock"
|
||||
InferenceProviderVertexAI InferenceProvider = "vertexai"
|
||||
InferenceProviderXAI InferenceProvider = "xai"
|
||||
InferenceProviderGROQ InferenceProvider = "groq"
|
||||
InferenceProviderOpenRouter InferenceProvider = "openrouter"
|
||||
)
|
||||
|
||||
// Provider represents an AI provider configuration.
|
||||
type Provider struct {
|
||||
Name string `json:"name"`
|
||||
ID InferenceProvider `json:"id"`
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
APIEndpoint string `json:"api_endpoint,omitempty"`
|
||||
Type Type `json:"type,omitempty"`
|
||||
DefaultLargeModelID string `json:"default_large_model_id,omitempty"`
|
||||
DefaultSmallModelID string `json:"default_small_model_id,omitempty"`
|
||||
Models []Model `json:"models,omitempty"`
|
||||
}
|
||||
|
||||
// Model represents an AI model configuration.
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
CostPer1MIn float64 `json:"cost_per_1m_in"`
|
||||
CostPer1MOut float64 `json:"cost_per_1m_out"`
|
||||
CostPer1MInCached float64 `json:"cost_per_1m_in_cached"`
|
||||
CostPer1MOutCached float64 `json:"cost_per_1m_out_cached"`
|
||||
ContextWindow int64 `json:"context_window"`
|
||||
DefaultMaxTokens int64 `json:"default_max_tokens"`
|
||||
CanReason bool `json:"can_reason"`
|
||||
HasReasoningEffort bool `json:"has_reasoning_efforts"`
|
||||
DefaultReasoningEffort string `json:"default_reasoning_effort,omitempty"`
|
||||
SupportsImages bool `json:"supports_attachments"`
|
||||
}
|
||||
|
||||
// KnownProviders returns all the known inference providers.
|
||||
func KnownProviders() []InferenceProvider {
|
||||
return []InferenceProvider{
|
||||
InferenceProviderOpenAI,
|
||||
InferenceProviderAnthropic,
|
||||
InferenceProviderGemini,
|
||||
InferenceProviderAzure,
|
||||
InferenceProviderBedrock,
|
||||
InferenceProviderVertexAI,
|
||||
InferenceProviderXAI,
|
||||
InferenceProviderGROQ,
|
||||
InferenceProviderOpenRouter,
|
||||
}
|
||||
}
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/csync"
|
||||
fur "github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/history"
|
||||
"github.com/charmbracelet/crush/internal/llm/prompt"
|
||||
"github.com/charmbracelet/crush/internal/llm/provider"
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/charmbracelet/crush/internal/permission"
|
||||
"github.com/charmbracelet/crush/internal/pubsub"
|
||||
"github.com/charmbracelet/crush/internal/session"
|
||||
"github.com/charmbracelet/crush/internal/shell"
|
||||
)
|
||||
|
||||
// Common errors
|
||||
@@ -52,7 +53,7 @@ type AgentEvent struct {
|
||||
|
||||
type Service interface {
|
||||
pubsub.Suscriber[AgentEvent]
|
||||
Model() fur.Model
|
||||
Model() catwalk.Model
|
||||
Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
|
||||
Cancel(sessionID string)
|
||||
CancelAll()
|
||||
@@ -226,7 +227,7 @@ func NewAgent(
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *agent) Model() fur.Model {
|
||||
func (a *agent) Model() catwalk.Model {
|
||||
return *config.Get().GetModelByType(a.agentCfg.Model)
|
||||
}
|
||||
|
||||
@@ -234,7 +235,7 @@ func (a *agent) Cancel(sessionID string) {
|
||||
// Cancel regular requests
|
||||
if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
|
||||
if cancel, ok := cancelFunc.(context.CancelFunc); ok {
|
||||
slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
|
||||
slog.Info("Request cancellation initiated", "session_id", sessionID)
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
@@ -242,7 +243,7 @@ func (a *agent) Cancel(sessionID string) {
|
||||
// Also check for summarize requests
|
||||
if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
|
||||
if cancel, ok := cancelFunc.(context.CancelFunc); ok {
|
||||
slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
|
||||
slog.Info("Summarize cancellation initiated", "session_id", sessionID)
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
@@ -372,7 +373,7 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
|
||||
})
|
||||
titleErr := a.generateTitle(context.Background(), sessionID, content)
|
||||
if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
|
||||
slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
|
||||
slog.Error("failed to generate title", "error", titleErr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -645,7 +646,7 @@ func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
|
||||
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
|
||||
sess, err := a.sessions.Get(ctx, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get session: %w", err)
|
||||
@@ -770,6 +771,8 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
|
||||
a.Publish(pubsub.CreatedEvent, event)
|
||||
return
|
||||
}
|
||||
shell := shell.GetPersistentShell(config.Get().WorkingDir())
|
||||
summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
|
||||
event = AgentEvent{
|
||||
Type: AgentEventTypeSummarize,
|
||||
Progress: "Creating new session...",
|
||||
|
||||
@@ -9,17 +9,17 @@ import (
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/tools"
|
||||
)
|
||||
|
||||
func CoderPrompt(p string, contextFiles ...string) string {
|
||||
var basePrompt string
|
||||
switch p {
|
||||
case string(provider.InferenceProviderOpenAI):
|
||||
case string(catwalk.InferenceProviderOpenAI):
|
||||
basePrompt = baseOpenAICoderPrompt
|
||||
case string(provider.InferenceProviderGemini), string(provider.InferenceProviderVertexAI):
|
||||
case string(catwalk.InferenceProviderGemini), string(catwalk.InferenceProviderVertexAI):
|
||||
basePrompt = baseGeminiCoderPrompt
|
||||
default:
|
||||
basePrompt = baseAnthropicCoderPrompt
|
||||
|
||||
@@ -15,8 +15,8 @@ import (
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/bedrock"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/tools"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
)
|
||||
@@ -71,7 +71,7 @@ func (a *anthropicClient) convertMessages(messages []message.Message) (anthropic
|
||||
var contentBlocks []anthropic.ContentBlockParamUnion
|
||||
contentBlocks = append(contentBlocks, content)
|
||||
for _, binaryContent := range msg.BinaryContent() {
|
||||
base64Image := binaryContent.String(provider.InferenceProviderAnthropic)
|
||||
base64Image := binaryContent.String(catwalk.InferenceProviderAnthropic)
|
||||
imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
|
||||
contentBlocks = append(contentBlocks, imageBlock)
|
||||
}
|
||||
@@ -248,7 +248,7 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message,
|
||||
return nil, retryErr
|
||||
}
|
||||
if retry {
|
||||
slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
|
||||
slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
@@ -401,7 +401,7 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message
|
||||
return
|
||||
}
|
||||
if retry {
|
||||
slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
|
||||
slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// context cancelled
|
||||
@@ -529,6 +529,6 @@ func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *anthropicClient) Model() provider.Model {
|
||||
func (a *anthropicClient) Model() catwalk.Model {
|
||||
return a.providerOptions.model(a.providerOptions.modelType)
|
||||
}
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/tools"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
)
|
||||
@@ -32,7 +32,7 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
|
||||
}
|
||||
}
|
||||
|
||||
opts.model = func(modelType config.SelectedModelType) provider.Model {
|
||||
opts.model = func(modelType config.SelectedModelType) catwalk.Model {
|
||||
model := config.Get().GetModelByType(modelType)
|
||||
|
||||
// Prefix the model name with region
|
||||
@@ -88,6 +88,6 @@ func (b *bedrockClient) stream(ctx context.Context, messages []message.Message,
|
||||
return b.childProvider.stream(ctx, messages, tools)
|
||||
}
|
||||
|
||||
func (b *bedrockClient) Model() provider.Model {
|
||||
func (b *bedrockClient) Model() catwalk.Model {
|
||||
return b.providerOptions.model(b.providerOptions.modelType)
|
||||
}
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/tools"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
"github.com/google/uuid"
|
||||
@@ -210,7 +210,7 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
|
||||
return nil, retryErr
|
||||
}
|
||||
if retry {
|
||||
slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
|
||||
slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
@@ -323,7 +323,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
|
||||
return
|
||||
}
|
||||
if retry {
|
||||
slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
|
||||
slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if ctx.Err() != nil {
|
||||
@@ -463,7 +463,7 @@ func (g *geminiClient) usage(resp *genai.GenerateContentResponse) TokenUsage {
|
||||
}
|
||||
}
|
||||
|
||||
func (g *geminiClient) Model() provider.Model {
|
||||
func (g *geminiClient) Model() catwalk.Model {
|
||||
return g.providerOptions.model(g.providerOptions.modelType)
|
||||
}
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/tools"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
"github.com/openai/openai-go"
|
||||
@@ -66,7 +66,7 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
|
||||
textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
|
||||
content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
|
||||
for _, binaryContent := range msg.BinaryContent() {
|
||||
imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)}
|
||||
imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(catwalk.InferenceProviderOpenAI)}
|
||||
imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
|
||||
|
||||
content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
|
||||
@@ -222,7 +222,7 @@ func (o *openaiClient) send(ctx context.Context, messages []message.Message, too
|
||||
return nil, retryErr
|
||||
}
|
||||
if retry {
|
||||
slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
|
||||
slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
@@ -395,7 +395,7 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
|
||||
return
|
||||
}
|
||||
if retry {
|
||||
slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
|
||||
slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// context cancelled
|
||||
@@ -486,6 +486,6 @@ func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
|
||||
}
|
||||
}
|
||||
|
||||
func (o *openaiClient) Model() provider.Model {
|
||||
func (o *openaiClient) Model() catwalk.Model {
|
||||
return o.providerOptions.model(o.providerOptions.modelType)
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/tools"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
"github.com/openai/openai-go"
|
||||
@@ -55,10 +55,10 @@ func TestOpenAIClientStreamChoices(t *testing.T) {
|
||||
modelType: config.SelectedModelTypeLarge,
|
||||
apiKey: "test-key",
|
||||
systemMessage: "test",
|
||||
model: func(config.SelectedModelType) provider.Model {
|
||||
return provider.Model{
|
||||
ID: "test-model",
|
||||
Model: "test-model",
|
||||
model: func(config.SelectedModelType) catwalk.Model {
|
||||
return catwalk.Model{
|
||||
ID: "test-model",
|
||||
Name: "test-model",
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/tools"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
)
|
||||
@@ -57,7 +57,7 @@ type Provider interface {
|
||||
|
||||
StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
|
||||
|
||||
Model() provider.Model
|
||||
Model() catwalk.Model
|
||||
}
|
||||
|
||||
type providerClientOptions struct {
|
||||
@@ -65,7 +65,7 @@ type providerClientOptions struct {
|
||||
config config.ProviderConfig
|
||||
apiKey string
|
||||
modelType config.SelectedModelType
|
||||
model func(config.SelectedModelType) provider.Model
|
||||
model func(config.SelectedModelType) catwalk.Model
|
||||
disableCache bool
|
||||
systemMessage string
|
||||
maxTokens int64
|
||||
@@ -80,7 +80,7 @@ type ProviderClient interface {
|
||||
send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
|
||||
stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
|
||||
|
||||
Model() provider.Model
|
||||
Model() catwalk.Model
|
||||
}
|
||||
|
||||
type baseProvider[C ProviderClient] struct {
|
||||
@@ -109,7 +109,7 @@ func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message
|
||||
return p.client.stream(ctx, messages, tools)
|
||||
}
|
||||
|
||||
func (p *baseProvider[C]) Model() provider.Model {
|
||||
func (p *baseProvider[C]) Model() catwalk.Model {
|
||||
return p.client.Model()
|
||||
}
|
||||
|
||||
@@ -149,7 +149,7 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
|
||||
apiKey: resolvedAPIKey,
|
||||
extraHeaders: cfg.ExtraHeaders,
|
||||
extraBody: cfg.ExtraBody,
|
||||
model: func(tp config.SelectedModelType) provider.Model {
|
||||
model: func(tp config.SelectedModelType) catwalk.Model {
|
||||
return *config.Get().GetModelByType(tp)
|
||||
},
|
||||
}
|
||||
@@ -157,37 +157,37 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
|
||||
o(&clientOptions)
|
||||
}
|
||||
switch cfg.Type {
|
||||
case provider.TypeAnthropic:
|
||||
case catwalk.TypeAnthropic:
|
||||
return &baseProvider[AnthropicClient]{
|
||||
options: clientOptions,
|
||||
client: newAnthropicClient(clientOptions, false),
|
||||
}, nil
|
||||
case provider.TypeOpenAI:
|
||||
case catwalk.TypeOpenAI:
|
||||
return &baseProvider[OpenAIClient]{
|
||||
options: clientOptions,
|
||||
client: newOpenAIClient(clientOptions),
|
||||
}, nil
|
||||
case provider.TypeGemini:
|
||||
case catwalk.TypeGemini:
|
||||
return &baseProvider[GeminiClient]{
|
||||
options: clientOptions,
|
||||
client: newGeminiClient(clientOptions),
|
||||
}, nil
|
||||
case provider.TypeBedrock:
|
||||
case catwalk.TypeBedrock:
|
||||
return &baseProvider[BedrockClient]{
|
||||
options: clientOptions,
|
||||
client: newBedrockClient(clientOptions),
|
||||
}, nil
|
||||
case provider.TypeAzure:
|
||||
case catwalk.TypeAzure:
|
||||
return &baseProvider[AzureClient]{
|
||||
options: clientOptions,
|
||||
client: newAzureClient(clientOptions),
|
||||
}, nil
|
||||
case provider.TypeVertexAI:
|
||||
case catwalk.TypeVertexAI:
|
||||
return &baseProvider[VertexAIClient]{
|
||||
options: clientOptions,
|
||||
client: newVertexAIClient(clientOptions),
|
||||
}, nil
|
||||
case provider.TypeXAI:
|
||||
case catwalk.TypeXAI:
|
||||
clientOptions.baseURL = "https://api.x.ai/v1"
|
||||
return &baseProvider[OpenAIClient]{
|
||||
options: clientOptions,
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -23,8 +22,10 @@ type BashPermissionsParams struct {
|
||||
}
|
||||
|
||||
type BashResponseMetadata struct {
|
||||
StartTime int64 `json:"start_time"`
|
||||
EndTime int64 `json:"end_time"`
|
||||
StartTime int64 `json:"start_time"`
|
||||
EndTime int64 `json:"end_time"`
|
||||
Output string `json:"output"`
|
||||
WorkingDirectory string `json:"working_directory"`
|
||||
}
|
||||
type bashTool struct {
|
||||
permissions permission.Service
|
||||
@@ -146,6 +147,7 @@ Before executing the command, please follow these steps:
|
||||
5. Return Result:
|
||||
- Provide the processed output of the command.
|
||||
- If any errors occurred during execution, include those in the output.
|
||||
- The result will also have metadata like the cwd (current working directory) at the end, included with <cwd></cwd> tags.
|
||||
|
||||
Usage notes:
|
||||
- The command argument is required.
|
||||
@@ -389,9 +391,12 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
|
||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Millisecond)
|
||||
defer cancel()
|
||||
}
|
||||
stdout, stderr, err := shell.
|
||||
GetPersistentShell(b.workingDir).
|
||||
Exec(ctx, params.Command)
|
||||
|
||||
persistentShell := shell.GetPersistentShell(b.workingDir)
|
||||
stdout, stderr, err := persistentShell.Exec(ctx, params.Command)
|
||||
|
||||
// Get the current working directory after command execution
|
||||
currentWorkingDir := persistentShell.GetWorkingDir()
|
||||
interrupted := shell.IsInterrupt(err)
|
||||
exitCode := shell.ExitCode(err)
|
||||
if exitCode == 0 && !interrupted && err != nil {
|
||||
@@ -401,15 +406,6 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
|
||||
stdout = truncateOutput(stdout)
|
||||
stderr = truncateOutput(stderr)
|
||||
|
||||
slog.Info("Bash command executed",
|
||||
"command", params.Command,
|
||||
"stdout", stdout,
|
||||
"stderr", stderr,
|
||||
"exit_code", exitCode,
|
||||
"interrupted", interrupted,
|
||||
"err", err,
|
||||
)
|
||||
|
||||
errorMessage := stderr
|
||||
if errorMessage == "" && err != nil {
|
||||
errorMessage = err.Error()
|
||||
@@ -438,9 +434,12 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
|
||||
}
|
||||
|
||||
metadata := BashResponseMetadata{
|
||||
StartTime: startTime.UnixMilli(),
|
||||
EndTime: time.Now().UnixMilli(),
|
||||
StartTime: startTime.UnixMilli(),
|
||||
EndTime: time.Now().UnixMilli(),
|
||||
Output: stdout,
|
||||
WorkingDirectory: currentWorkingDir,
|
||||
}
|
||||
stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", currentWorkingDir)
|
||||
if stdout == "" {
|
||||
return WithResponseMetadata(NewTextResponse(BashNoOutput), metadata), nil
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ func globFiles(pattern, searchPath string, limit int) ([]string, bool, error) {
|
||||
if err == nil {
|
||||
return matches, len(matches) >= limit && limit > 0, nil
|
||||
}
|
||||
slog.Warn(fmt.Sprintf("Ripgrep execution failed: %v. Falling back to doublestar.", err))
|
||||
slog.Warn("Ripgrep execution failed, falling back to doublestar", "error", err)
|
||||
}
|
||||
|
||||
return fsext.GlobWithDoubleStar(pattern, searchPath, limit)
|
||||
|
||||
@@ -90,7 +90,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
|
||||
slog.Debug("BaseURI", "baseURI", u)
|
||||
}
|
||||
default:
|
||||
slog.Debug("GlobPattern", "unknown type", fmt.Sprintf("%T", v))
|
||||
slog.Debug("GlobPattern unknown type", "type", fmt.Sprintf("%T", v))
|
||||
}
|
||||
|
||||
// Log WatchKind
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
)
|
||||
|
||||
type MessageRole string
|
||||
@@ -74,9 +74,9 @@ type BinaryContent struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (bc BinaryContent) String(p provider.InferenceProvider) string {
|
||||
func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
|
||||
base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
|
||||
if p == provider.InferenceProviderOpenAI {
|
||||
if p == catwalk.InferenceProviderOpenAI {
|
||||
return "data:" + bc.MIMEType + ";base64," + base64Encoded
|
||||
}
|
||||
return base64Encoded
|
||||
|
||||
@@ -50,6 +50,7 @@ type permissionService struct {
|
||||
autoApproveSessions []string
|
||||
autoApproveSessionsMu sync.RWMutex
|
||||
skip bool
|
||||
allowedTools []string
|
||||
}
|
||||
|
||||
func (s *permissionService) GrantPersistent(permission PermissionRequest) {
|
||||
@@ -82,6 +83,12 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if the tool/action combination is in the allowlist
|
||||
commandKey := opts.ToolName + ":" + opts.Action
|
||||
if slices.Contains(s.allowedTools, commandKey) || slices.Contains(s.allowedTools, opts.ToolName) {
|
||||
return true
|
||||
}
|
||||
|
||||
s.autoApproveSessionsMu.RLock()
|
||||
autoApprove := slices.Contains(s.autoApproveSessions, opts.SessionID)
|
||||
s.autoApproveSessionsMu.RUnlock()
|
||||
@@ -130,11 +137,12 @@ func (s *permissionService) AutoApproveSession(sessionID string) {
|
||||
s.autoApproveSessionsMu.Unlock()
|
||||
}
|
||||
|
||||
func NewPermissionService(workingDir string, skip bool) Service {
|
||||
func NewPermissionService(workingDir string, skip bool, allowedTools []string) Service {
|
||||
return &permissionService{
|
||||
Broker: pubsub.NewBroker[PermissionRequest](),
|
||||
workingDir: workingDir,
|
||||
sessionPermissions: make([]PermissionRequest, 0),
|
||||
skip: skip,
|
||||
allowedTools: allowedTools,
|
||||
}
|
||||
}
|
||||
|
||||
92
internal/permission/permission_test.go
Normal file
92
internal/permission/permission_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package permission
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPermissionService_AllowedCommands(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedTools []string
|
||||
toolName string
|
||||
action string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "tool in allowlist",
|
||||
allowedTools: []string{"bash", "view"},
|
||||
toolName: "bash",
|
||||
action: "execute",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "tool:action in allowlist",
|
||||
allowedTools: []string{"bash:execute", "edit:create"},
|
||||
toolName: "bash",
|
||||
action: "execute",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "tool not in allowlist",
|
||||
allowedTools: []string{"view", "ls"},
|
||||
toolName: "bash",
|
||||
action: "execute",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "tool:action not in allowlist",
|
||||
allowedTools: []string{"bash:read", "edit:create"},
|
||||
toolName: "bash",
|
||||
action: "execute",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty allowlist",
|
||||
allowedTools: []string{},
|
||||
toolName: "bash",
|
||||
action: "execute",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
service := NewPermissionService("/tmp", false, tt.allowedTools)
|
||||
|
||||
// Create a channel to capture the permission request
|
||||
// Since we're testing the allowlist logic, we need to simulate the request
|
||||
ps := service.(*permissionService)
|
||||
|
||||
// Test the allowlist logic directly
|
||||
commandKey := tt.toolName + ":" + tt.action
|
||||
allowed := false
|
||||
for _, cmd := range ps.allowedTools {
|
||||
if cmd == commandKey || cmd == tt.toolName {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allowed != tt.expected {
|
||||
t.Errorf("expected %v, got %v for tool %s action %s with allowlist %v",
|
||||
tt.expected, allowed, tt.toolName, tt.action, tt.allowedTools)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionService_SkipMode(t *testing.T) {
|
||||
service := NewPermissionService("/tmp", true, []string{})
|
||||
|
||||
result := service.Request(CreatePermissionRequest{
|
||||
SessionID: "test-session",
|
||||
ToolName: "bash",
|
||||
Action: "execute",
|
||||
Description: "test command",
|
||||
Path: "/tmp",
|
||||
})
|
||||
|
||||
if !result {
|
||||
t.Error("expected permission to be granted in skip mode")
|
||||
}
|
||||
}
|
||||
59
internal/shell/coreutils.go
Normal file
59
internal/shell/coreutils.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package shell
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/u-root/u-root/pkg/core"
|
||||
"github.com/u-root/u-root/pkg/core/cat"
|
||||
"github.com/u-root/u-root/pkg/core/chmod"
|
||||
"github.com/u-root/u-root/pkg/core/cp"
|
||||
"github.com/u-root/u-root/pkg/core/find"
|
||||
"github.com/u-root/u-root/pkg/core/ls"
|
||||
"github.com/u-root/u-root/pkg/core/mkdir"
|
||||
"github.com/u-root/u-root/pkg/core/mv"
|
||||
"github.com/u-root/u-root/pkg/core/rm"
|
||||
"github.com/u-root/u-root/pkg/core/touch"
|
||||
"github.com/u-root/u-root/pkg/core/xargs"
|
||||
"mvdan.cc/sh/v3/interp"
|
||||
)
|
||||
|
||||
var coreUtils = map[string]func() core.Command{
|
||||
"cat": func() core.Command { return cat.New() },
|
||||
"chmod": func() core.Command { return chmod.New() },
|
||||
"cp": func() core.Command { return cp.New() },
|
||||
"find": func() core.Command { return find.New() },
|
||||
"ls": func() core.Command { return ls.New() },
|
||||
"mkdir": func() core.Command { return mkdir.New() },
|
||||
"mv": func() core.Command { return mv.New() },
|
||||
"rm": func() core.Command { return rm.New() },
|
||||
"touch": func() core.Command { return touch.New() },
|
||||
"xargs": func() core.Command { return xargs.New() },
|
||||
}
|
||||
|
||||
func (s *Shell) coreUtilsHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
|
||||
return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
|
||||
return func(ctx context.Context, args []string) error {
|
||||
if len(args) == 0 {
|
||||
return next(ctx, args)
|
||||
}
|
||||
|
||||
program, programArgs := args[0], args[1:]
|
||||
|
||||
newCoreUtil, ok := coreUtils[program]
|
||||
if !ok {
|
||||
return next(ctx, args)
|
||||
}
|
||||
|
||||
c := interp.HandlerCtx(ctx)
|
||||
|
||||
cmd := newCoreUtil()
|
||||
cmd.SetIO(c.Stdin, c.Stdout, c.Stderr)
|
||||
cmd.SetWorkingDir(c.Dir)
|
||||
cmd.SetLookupEnv(func(key string) (string, bool) {
|
||||
v := c.Env.Get(key)
|
||||
return v.Str, v.Set
|
||||
})
|
||||
return cmd.RunContext(ctx, programArgs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -221,7 +221,7 @@ func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string,
|
||||
interp.Interactive(false),
|
||||
interp.Env(expand.ListEnviron(s.env...)),
|
||||
interp.Dir(s.cwd),
|
||||
interp.ExecHandlers(s.blockHandler()),
|
||||
interp.ExecHandlers(s.blockHandler(), s.coreUtilsHandler()),
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("could not run command: %w", err)
|
||||
|
||||
@@ -289,7 +289,7 @@ func (a Anim) View() string {
|
||||
var b strings.Builder
|
||||
for i := range a.width {
|
||||
switch {
|
||||
case !a.initialized && time.Since(a.startTime) < a.birthOffsets[i]:
|
||||
case !a.initialized && i < len(a.birthOffsets) && time.Since(a.startTime) < a.birthOffsets[i]:
|
||||
// Birth offset not reached: render initial character.
|
||||
b.WriteString(a.initialFrames[a.step][i])
|
||||
case i < a.cyclingCharWidth:
|
||||
|
||||
@@ -187,9 +187,11 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
value = value[:m.completionsStartIndex]
|
||||
value += item.Path
|
||||
m.textarea.SetValue(value)
|
||||
m.isCompletionsOpen = false
|
||||
m.currentQuery = ""
|
||||
m.completionsStartIndex = 0
|
||||
if !msg.Insert {
|
||||
m.isCompletionsOpen = false
|
||||
m.currentQuery = ""
|
||||
m.completionsStartIndex = 0
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
case openEditorMsg:
|
||||
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
|
||||
"github.com/charmbracelet/bubbles/v2/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea/v2"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/lipgloss/v2"
|
||||
"github.com/charmbracelet/x/ansi"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/message"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/anim"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/core"
|
||||
@@ -369,11 +369,11 @@ func (m *assistantSectionModel) View() string {
|
||||
model := config.Get().GetModel(m.message.Provider, m.message.Model)
|
||||
if model == nil {
|
||||
// This means the model is not configured anymore
|
||||
model = &provider.Model{
|
||||
Model: "Unknown Model",
|
||||
model = &catwalk.Model{
|
||||
Name: "Unknown Model",
|
||||
}
|
||||
}
|
||||
modelFormatted := t.S().Muted.Render(model.Model)
|
||||
modelFormatted := t.S().Muted.Render(model.Name)
|
||||
assistant := fmt.Sprintf("%s %s %s", icon, modelFormatted, infoMsg)
|
||||
return t.S().Base.PaddingLeft(2).Render(
|
||||
core.Section(assistant, m.width-2),
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/ansiext"
|
||||
"github.com/charmbracelet/crush/internal/fsext"
|
||||
"github.com/charmbracelet/crush/internal/llm/agent"
|
||||
"github.com/charmbracelet/crush/internal/llm/tools"
|
||||
@@ -212,10 +213,19 @@ func (br bashRenderer) Render(v *toolCallCmp) string {
|
||||
args := newParamBuilder().addMain(cmd).build()
|
||||
|
||||
return br.renderWithParams(v, "Bash", args, func() string {
|
||||
if v.result.Content == tools.BashNoOutput {
|
||||
var meta tools.BashResponseMetadata
|
||||
if err := br.unmarshalParams(v.result.Metadata, &meta); err != nil {
|
||||
return renderPlainContent(v, v.result.Content)
|
||||
}
|
||||
// for backwards compatibility with older tool calls.
|
||||
if meta.Output == "" && v.result.Content != tools.BashNoOutput {
|
||||
meta.Output = v.result.Content
|
||||
}
|
||||
|
||||
if meta.Output == "" {
|
||||
return ""
|
||||
}
|
||||
return renderPlainContent(v, v.result.Content)
|
||||
return renderPlainContent(v, meta.Output)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -693,7 +703,7 @@ func renderPlainContent(v *toolCallCmp, content string) string {
|
||||
if i >= responseContextHeight {
|
||||
break
|
||||
}
|
||||
ln = escapeContent(ln)
|
||||
ln = ansiext.Escape(ln)
|
||||
ln = " " + ln // left padding
|
||||
if len(ln) > width {
|
||||
ln = v.fit(ln, width)
|
||||
@@ -731,7 +741,7 @@ func renderCodeContent(v *toolCallCmp, path, content string, offset int) string
|
||||
|
||||
lines := strings.Split(truncated, "\n")
|
||||
for i, ln := range lines {
|
||||
lines[i] = escapeContent(ln)
|
||||
lines[i] = ansiext.Escape(ln)
|
||||
}
|
||||
|
||||
highlighted, _ := highlight.SyntaxHighlight(strings.Join(lines, "\n"), path, t.BgBase)
|
||||
@@ -807,20 +817,3 @@ func prettifyToolName(name string) string {
|
||||
return name
|
||||
}
|
||||
}
|
||||
|
||||
// escapeContent replaces control characters with their Unicode Control Picture
|
||||
// representations to ensure they are displayed correctly in the UI.
|
||||
func escapeContent(content string) string {
|
||||
var sb strings.Builder
|
||||
for _, r := range content {
|
||||
switch {
|
||||
case r >= 0 && r <= 0x1f: // Control characters 0x00-0x1F
|
||||
sb.WriteRune('\u2400' + r)
|
||||
case r == ansi.DEL:
|
||||
sb.WriteRune('\u2421')
|
||||
default:
|
||||
sb.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
@@ -9,10 +9,10 @@ import (
|
||||
"sync"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea/v2"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/diff"
|
||||
"github.com/charmbracelet/crush/internal/fsext"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/history"
|
||||
"github.com/charmbracelet/crush/internal/lsp"
|
||||
"github.com/charmbracelet/crush/internal/lsp/protocol"
|
||||
@@ -897,7 +897,7 @@ func (s *sidebarCmp) currentModelBlock() string {
|
||||
t := styles.CurrentTheme()
|
||||
|
||||
modelIcon := t.S().Base.Foreground(t.FgSubtle).Render(styles.ModelIcon)
|
||||
modelName := t.S().Text.Render(model.Model)
|
||||
modelName := t.S().Text.Render(model.Name)
|
||||
modelInfo := fmt.Sprintf("%s %s", modelIcon, modelName)
|
||||
parts := []string{
|
||||
modelInfo,
|
||||
@@ -905,14 +905,14 @@ func (s *sidebarCmp) currentModelBlock() string {
|
||||
if model.CanReason {
|
||||
reasoningInfoStyle := t.S().Subtle.PaddingLeft(2)
|
||||
switch modelProvider.Type {
|
||||
case provider.TypeOpenAI:
|
||||
case catwalk.TypeOpenAI:
|
||||
reasoningEffort := model.DefaultReasoningEffort
|
||||
if selectedModel.ReasoningEffort != "" {
|
||||
reasoningEffort = selectedModel.ReasoningEffort
|
||||
}
|
||||
formatter := cases.Title(language.English, cases.NoLower)
|
||||
parts = append(parts, reasoningInfoStyle.Render(formatter.String(fmt.Sprintf("Reasoning %s", reasoningEffort))))
|
||||
case provider.TypeAnthropic:
|
||||
case catwalk.TypeAnthropic:
|
||||
formatter := cases.Title(language.English, cases.NoLower)
|
||||
if selectedModel.Think {
|
||||
parts = append(parts, reasoningInfoStyle.Render(formatter.String("Thinking on")))
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"github.com/charmbracelet/bubbles/v2/key"
|
||||
"github.com/charmbracelet/bubbles/v2/spinner"
|
||||
tea "github.com/charmbracelet/bubbletea/v2"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/prompt"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/chat"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/completions"
|
||||
@@ -109,7 +109,7 @@ func (s *splashCmp) SetOnboarding(onboarding bool) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
filteredProviders := []provider.Provider{}
|
||||
filteredProviders := []catwalk.Provider{}
|
||||
simpleProviders := []string{
|
||||
"anthropic",
|
||||
"openai",
|
||||
@@ -407,7 +407,7 @@ func (s *splashCmp) setPreferredModel(selectedItem models.ModelOption) tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *splashCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
|
||||
func (s *splashCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
|
||||
providers, err := config.Providers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -36,7 +36,8 @@ type CompletionsOpenedMsg struct{}
|
||||
type CloseCompletionsMsg struct{}
|
||||
|
||||
type SelectCompletionMsg struct {
|
||||
Value any // The value of the selected completion item
|
||||
Value any // The value of the selected completion item
|
||||
Insert bool
|
||||
}
|
||||
|
||||
type Completions interface {
|
||||
@@ -115,6 +116,30 @@ func (c *completionsCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
d, cmd := c.list.Update(msg)
|
||||
c.list = d.(list.ListModel)
|
||||
return c, cmd
|
||||
case key.Matches(msg, c.keyMap.UpInsert):
|
||||
selectedItemInx := c.list.SelectedIndex() - 1
|
||||
items := c.list.Items()
|
||||
if selectedItemInx == list.NoSelection || selectedItemInx < 0 {
|
||||
return c, nil // No item selected, do nothing
|
||||
}
|
||||
selectedItem := items[selectedItemInx].(CompletionItem).Value()
|
||||
c.list.SetSelected(selectedItemInx)
|
||||
return c, util.CmdHandler(SelectCompletionMsg{
|
||||
Value: selectedItem,
|
||||
Insert: true,
|
||||
})
|
||||
case key.Matches(msg, c.keyMap.DownInsert):
|
||||
selectedItemInx := c.list.SelectedIndex() + 1
|
||||
items := c.list.Items()
|
||||
if selectedItemInx == list.NoSelection || selectedItemInx >= len(items) {
|
||||
return c, nil // No item selected, do nothing
|
||||
}
|
||||
selectedItem := items[selectedItemInx].(CompletionItem).Value()
|
||||
c.list.SetSelected(selectedItemInx)
|
||||
return c, util.CmdHandler(SelectCompletionMsg{
|
||||
Value: selectedItem,
|
||||
Insert: true,
|
||||
})
|
||||
case key.Matches(msg, c.keyMap.Select):
|
||||
selectedItemInx := c.list.SelectedIndex()
|
||||
if selectedItemInx == list.NoSelection {
|
||||
|
||||
@@ -9,6 +9,8 @@ type KeyMap struct {
|
||||
Up,
|
||||
Select,
|
||||
Cancel key.Binding
|
||||
DownInsert,
|
||||
UpInsert key.Binding
|
||||
}
|
||||
|
||||
func DefaultKeyMap() KeyMap {
|
||||
@@ -29,6 +31,14 @@ func DefaultKeyMap() KeyMap {
|
||||
key.WithKeys("esc"),
|
||||
key.WithHelp("esc", "cancel"),
|
||||
),
|
||||
DownInsert: key.NewBinding(
|
||||
key.WithKeys("ctrl+n"),
|
||||
key.WithHelp("ctrl+n", "insert next"),
|
||||
),
|
||||
UpInsert: key.NewBinding(
|
||||
key.WithKeys("ctrl+p"),
|
||||
key.WithHelp("ctrl+p", "insert previous"),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
package status
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/v2/help"
|
||||
tea "github.com/charmbracelet/bubbletea/v2"
|
||||
"github.com/charmbracelet/crush/internal/tui/styles"
|
||||
"github.com/charmbracelet/crush/internal/tui/util"
|
||||
"github.com/charmbracelet/lipgloss/v2"
|
||||
"github.com/charmbracelet/x/ansi"
|
||||
)
|
||||
|
||||
@@ -74,18 +72,15 @@ func (m *statusCmp) infoMsg() string {
|
||||
switch m.info.Type {
|
||||
case util.InfoTypeError:
|
||||
infoType = t.S().Base.Background(t.Red).Padding(0, 1).Render("ERROR")
|
||||
width := m.width - lipgloss.Width(infoType)
|
||||
message = t.S().Base.Background(t.Error).Foreground(t.White).Padding(0, 1).Width(width).Render(ansi.Truncate(m.info.Msg, width, "…"))
|
||||
message = t.S().Base.Background(t.Error).Width(m.width).Foreground(t.White).Padding(0, 1).Render(m.info.Msg)
|
||||
case util.InfoTypeWarn:
|
||||
infoType = t.S().Base.Foreground(t.BgOverlay).Background(t.Yellow).Padding(0, 1).Render("WARNING")
|
||||
width := m.width - lipgloss.Width(infoType)
|
||||
message = t.S().Base.Foreground(t.BgOverlay).Background(t.Warning).Padding(0, 1).Width(width).Render(ansi.Truncate(m.info.Msg, width, "…"))
|
||||
message = t.S().Base.Foreground(t.BgOverlay).Width(m.width).Background(t.Warning).Padding(0, 1).Render(m.info.Msg)
|
||||
default:
|
||||
infoType = t.S().Base.Foreground(t.BgOverlay).Background(t.Green).Padding(0, 1).Render("OKAY!")
|
||||
width := m.width - lipgloss.Width(infoType)
|
||||
message = t.S().Base.Background(t.Success).Foreground(t.White).Padding(0, 1).Width(width).Render(ansi.Truncate(m.info.Msg, width, "…"))
|
||||
message = t.S().Base.Background(t.Success).Width(m.width).Foreground(t.White).Padding(0, 1).Render(m.info.Msg)
|
||||
}
|
||||
return strings.Join([]string{infoType, message}, "")
|
||||
return ansi.Truncate(infoType+message, m.width, "…")
|
||||
}
|
||||
|
||||
func (m *statusCmp) ToggleFullHelp() {
|
||||
|
||||
@@ -4,10 +4,10 @@ import (
|
||||
"github.com/charmbracelet/bubbles/v2/help"
|
||||
"github.com/charmbracelet/bubbles/v2/key"
|
||||
tea "github.com/charmbracelet/bubbletea/v2"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/lipgloss/v2"
|
||||
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/llm/prompt"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/chat"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/completions"
|
||||
@@ -270,7 +270,7 @@ func (c *commandDialogCmp) defaultCommands() []Command {
|
||||
providerCfg := cfg.GetProviderForModel(agentCfg.Model)
|
||||
model := cfg.GetModelByType(agentCfg.Model)
|
||||
if providerCfg != nil && model != nil &&
|
||||
providerCfg.Type == provider.TypeAnthropic && model.CanReason {
|
||||
providerCfg.Type == catwalk.TypeAnthropic && model.CanReason {
|
||||
selectedModel := cfg.Models[agentCfg.Model]
|
||||
status := "Enable"
|
||||
if selectedModel.Think {
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"slices"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea/v2"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/completions"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/core/list"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/dialogs/commands"
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
type ModelListComponent struct {
|
||||
list list.ListModel
|
||||
modelType int
|
||||
providers []provider.Provider
|
||||
providers []catwalk.Provider
|
||||
}
|
||||
|
||||
func NewModelListComponent(keyMap list.KeyMap, inputStyle lipgloss.Style, inputPlaceholder string) *ModelListComponent {
|
||||
@@ -109,19 +109,19 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
|
||||
}
|
||||
|
||||
// Check if this provider is not in the known providers list
|
||||
if !slices.ContainsFunc(knownProviders, func(p provider.Provider) bool { return p.ID == provider.InferenceProvider(providerID) }) {
|
||||
if !slices.ContainsFunc(knownProviders, func(p catwalk.Provider) bool { return p.ID == catwalk.InferenceProvider(providerID) }) {
|
||||
// Convert config provider to provider.Provider format
|
||||
configProvider := provider.Provider{
|
||||
configProvider := catwalk.Provider{
|
||||
Name: providerConfig.Name,
|
||||
ID: provider.InferenceProvider(providerID),
|
||||
Models: make([]provider.Model, len(providerConfig.Models)),
|
||||
ID: catwalk.InferenceProvider(providerID),
|
||||
Models: make([]catwalk.Model, len(providerConfig.Models)),
|
||||
}
|
||||
|
||||
// Convert models
|
||||
for i, model := range providerConfig.Models {
|
||||
configProvider.Models[i] = provider.Model{
|
||||
configProvider.Models[i] = catwalk.Model{
|
||||
ID: model.ID,
|
||||
Model: model.Model,
|
||||
Name: model.Name,
|
||||
CostPer1MIn: model.CostPer1MIn,
|
||||
CostPer1MOut: model.CostPer1MOut,
|
||||
CostPer1MInCached: model.CostPer1MInCached,
|
||||
@@ -144,7 +144,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
|
||||
section.SetInfo(configured)
|
||||
modelItems = append(modelItems, section)
|
||||
for _, model := range configProvider.Models {
|
||||
modelItems = append(modelItems, completions.NewCompletionItem(model.Model, ModelOption{
|
||||
modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
|
||||
Provider: configProvider,
|
||||
Model: model,
|
||||
}))
|
||||
@@ -179,7 +179,7 @@ func (m *ModelListComponent) SetModelType(modelType int) tea.Cmd {
|
||||
}
|
||||
modelItems = append(modelItems, section)
|
||||
for _, model := range provider.Models {
|
||||
modelItems = append(modelItems, completions.NewCompletionItem(model.Model, ModelOption{
|
||||
modelItems = append(modelItems, completions.NewCompletionItem(model.Name, ModelOption{
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
}))
|
||||
@@ -201,6 +201,6 @@ func (m *ModelListComponent) SetInputPlaceholder(placeholder string) {
|
||||
m.list.SetFilterPlaceholder(placeholder)
|
||||
}
|
||||
|
||||
func (m *ModelListComponent) SetProviders(providers []provider.Provider) {
|
||||
func (m *ModelListComponent) SetProviders(providers []catwalk.Provider) {
|
||||
m.providers = providers
|
||||
}
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"github.com/charmbracelet/bubbles/v2/key"
|
||||
"github.com/charmbracelet/bubbles/v2/spinner"
|
||||
tea "github.com/charmbracelet/bubbletea/v2"
|
||||
"github.com/charmbracelet/catwalk/pkg/catwalk"
|
||||
"github.com/charmbracelet/crush/internal/config"
|
||||
"github.com/charmbracelet/crush/internal/fur/provider"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/completions"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/core"
|
||||
"github.com/charmbracelet/crush/internal/tui/components/core/list"
|
||||
@@ -48,8 +48,8 @@ type ModelDialog interface {
|
||||
}
|
||||
|
||||
type ModelOption struct {
|
||||
Provider provider.Provider
|
||||
Model provider.Model
|
||||
Provider catwalk.Provider
|
||||
Model catwalk.Model
|
||||
}
|
||||
|
||||
type modelDialogCmp struct {
|
||||
@@ -363,7 +363,7 @@ func (m *modelDialogCmp) isProviderConfigured(providerID string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *modelDialogCmp) getProvider(providerID provider.InferenceProvider) (*provider.Provider, error) {
|
||||
func (m *modelDialogCmp) getProvider(providerID catwalk.InferenceProvider) (*catwalk.Provider, error) {
|
||||
providers, err := config.Providers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -4,8 +4,10 @@ import (
|
||||
"fmt"
|
||||
"image/color"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/alecthomas/chroma/v2"
|
||||
"github.com/charmbracelet/crush/internal/ansiext"
|
||||
"github.com/charmbracelet/lipgloss/v2"
|
||||
)
|
||||
|
||||
@@ -20,9 +22,12 @@ type chromaFormatter struct {
|
||||
// Format implements the chroma.Formatter interface.
|
||||
func (c chromaFormatter) Format(w io.Writer, style *chroma.Style, it chroma.Iterator) error {
|
||||
for token := it(); token != chroma.EOF; token = it() {
|
||||
value := strings.TrimRight(token.Value, "\n")
|
||||
value = ansiext.Escape(value)
|
||||
|
||||
entry := style.Get(token.Type)
|
||||
if entry.IsZero() {
|
||||
if _, err := fmt.Fprint(w, token.Value); err != nil {
|
||||
if _, err := fmt.Fprint(w, value); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
@@ -44,7 +49,7 @@ func (c chromaFormatter) Format(w io.Writer, style *chroma.Style, it chroma.Iter
|
||||
s = s.Foreground(lipgloss.Color(entry.Colour.String()))
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprint(w, s.Render(token.Value)); err != nil {
|
||||
if _, err := fmt.Fprint(w, s.Render(value)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -193,6 +193,7 @@ func (dv *DiffView) clearSyntaxCache() {
|
||||
|
||||
// String returns the string representation of the DiffView.
|
||||
func (dv *DiffView) String() string {
|
||||
dv.normalizeLineEndings()
|
||||
dv.replaceTabs()
|
||||
if err := dv.computeDiff(); err != nil {
|
||||
return err.Error()
|
||||
@@ -227,6 +228,12 @@ func (dv *DiffView) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeLineEndings ensures the file contents use Unix-style line endings.
|
||||
func (dv *DiffView) normalizeLineEndings() {
|
||||
dv.before.content = strings.ReplaceAll(dv.before.content, "\r\n", "\n")
|
||||
dv.after.content = strings.ReplaceAll(dv.after.content, "\r\n", "\n")
|
||||
}
|
||||
|
||||
// replaceTabs replaces tabs in the before and after file contents with spaces
|
||||
// according to the specified tab width.
|
||||
func (dv *DiffView) replaceTabs() {
|
||||
@@ -396,8 +403,7 @@ func (dv *DiffView) renderUnified() string {
|
||||
shouldWrite := func() bool { return printedLines >= 0 }
|
||||
|
||||
getContent := func(in string, ls LineStyle) (content string, leadingEllipsis bool) {
|
||||
content = strings.ReplaceAll(in, "\r\n", "\n")
|
||||
content = strings.TrimSuffix(content, "\n")
|
||||
content = strings.TrimSuffix(in, "\n")
|
||||
content = dv.hightlightCode(content, ls.Code.GetBackground())
|
||||
content = ansi.GraphemeWidth.Cut(content, dv.xOffset, len(content))
|
||||
content = ansi.Truncate(content, dv.codeWidth, "…")
|
||||
@@ -520,8 +526,7 @@ func (dv *DiffView) renderSplit() string {
|
||||
shouldWrite := func() bool { return printedLines >= 0 }
|
||||
|
||||
getContent := func(in string, ls LineStyle) (content string, leadingEllipsis bool) {
|
||||
content = strings.ReplaceAll(in, "\r\n", "\n")
|
||||
content = strings.TrimSuffix(content, "\n")
|
||||
content = strings.TrimSuffix(in, "\n")
|
||||
content = dv.hightlightCode(content, ls.Code.GetBackground())
|
||||
content = ansi.GraphemeWidth.Cut(content, dv.xOffset, len(content))
|
||||
content = ansi.Truncate(content, dv.codeWidth, "…")
|
||||
|
||||
@@ -36,6 +36,12 @@ var TestTabsBefore string
|
||||
//go:embed testdata/TestTabs.after
|
||||
var TestTabsAfter string
|
||||
|
||||
//go:embed testdata/TestLineBreakIssue.before
|
||||
var TestLineBreakIssueBefore string
|
||||
|
||||
//go:embed testdata/TestLineBreakIssue.after
|
||||
var TestLineBreakIssueAfter string
|
||||
|
||||
type (
|
||||
TestFunc func(dv *diffview.DiffView) *diffview.DiffView
|
||||
TestFuncs map[string]TestFunc
|
||||
@@ -177,6 +183,26 @@ func TestDiffViewTabs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiffViewLineBreakIssue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for layoutName, layoutFunc := range LayoutFuncs {
|
||||
t.Run(layoutName, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dv := diffview.New().
|
||||
Before("index.js", TestLineBreakIssueBefore).
|
||||
After("index.js", TestLineBreakIssueAfter).
|
||||
Style(diffview.DefaultLightStyle()).
|
||||
ChromaStyle(styles.Get("catppuccin-latte"))
|
||||
dv = layoutFunc(dv)
|
||||
|
||||
output := dv.String()
|
||||
golden.RequireEqual(t, []byte(output))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiffViewWidth(t *testing.T) {
|
||||
for layoutName, layoutFunc := range LayoutFuncs {
|
||||
t.Run(layoutName, func(t *testing.T) {
|
||||
|
||||
9
internal/tui/exp/diffview/testdata/TestDiffViewLineBreakIssue/Split.golden
generated
vendored
Normal file
9
internal/tui/exp/diffview/testdata/TestDiffViewLineBreakIssue/Split.golden
generated
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
[48;2;71;118;255m [m[38;2;77;76;87;48;2;71;118;255m …[m[48;2;71;118;255m [m[38;2;96;95;107;48;2;113;154;252m @@ -1,6 +1,8 @@ [m[48;2;113;154;252m [m[48;2;71;118;255m [m[38;2;77;76;87;48;2;71;118;255m …[m[48;2;71;118;255m [m[38;2;96;95;107;48;2;113;154;252m [m[48;2;113;154;252m [m
|
||||
[48;2;255;205;210m [m[38;2;255;56;139;48;2;255;205;210m 1[m[48;2;255;205;210m [m[38;2;255;56;139;48;2;255;235;238m- [m[38;2;32;31;38;48;2;255;235;238m[3;38;2;156;160;176;48;2;255;235;238m// this is[m[m[48;2;255;235;238m [m[48;2;200;230;201m [m[38;2;10;220;217;48;2;200;230;201m 1[m[48;2;200;230;201m [m[38;2;10;220;217;48;2;232;245;233m+ [m[38;2;32;31;38;48;2;232;245;233m[38;2;210;15;57;48;2;232;245;233m/**[m[m[48;2;232;245;233m [m
|
||||
[48;2;223;219;221m [m[48;2;223;219;221m [m[48;2;223;219;221m [m[48;2;223;219;221m [m[48;2;223;219;221m [m[48;2;200;230;201m [m[38;2;10;220;217;48;2;200;230;201m 2[m[48;2;200;230;201m [m[38;2;10;220;217;48;2;232;245;233m+ [m[38;2;32;31;38;48;2;232;245;233m[38;2;76;79;105;48;2;232;245;233m [m[1;38;2;4;165;229;48;2;232;245;233m*[m[38;2;76;79;105;48;2;232;245;233m [m[38;2;136;57;239;48;2;232;245;233mthis[m[38;2;76;79;105;48;2;232;245;233m [m[38;2;76;79;105;48;2;232;245;233mis[m[m[48;2;232;245;233m [m
|
||||
[48;2;255;205;210m [m[38;2;255;56;139;48;2;255;205;210m 2[m[48;2;255;205;210m [m[38;2;255;56;139;48;2;255;235;238m- [m[38;2;32;31;38;48;2;255;235;238m[3;38;2;156;160;176;48;2;255;235;238m// a regular[m[m[48;2;255;235;238m [m[48;2;200;230;201m [m[38;2;10;220;217;48;2;200;230;201m 3[m[48;2;200;230;201m [m[38;2;10;220;217;48;2;232;245;233m+ [m[38;2;32;31;38;48;2;232;245;233m[38;2;76;79;105;48;2;232;245;233m [m[1;38;2;4;165;229;48;2;232;245;233m*[m[38;2;76;79;105;48;2;232;245;233m [m[38;2;76;79;105;48;2;232;245;233ma[m[38;2;76;79;105;48;2;232;245;233m [m[38;2;76;79;105;48;2;232;245;233mblock[m[m[48;2;232;245;233m [m
|
||||
[48;2;255;205;210m [m[38;2;255;56;139;48;2;255;205;210m 3[m[48;2;255;205;210m [m[38;2;255;56;139;48;2;255;235;238m- [m[38;2;32;31;38;48;2;255;235;238m[3;38;2;156;160;176;48;2;255;235;238m// comment[m[m[48;2;255;235;238m [m[48;2;200;230;201m [m[38;2;10;220;217;48;2;200;230;201m 4[m[48;2;200;230;201m [m[38;2;10;220;217;48;2;232;245;233m+ [m[38;2;32;31;38;48;2;232;245;233m[38;2;76;79;105;48;2;232;245;233m [m[1;38;2;4;165;229;48;2;232;245;233m*[m[38;2;76;79;105;48;2;232;245;233m [m[38;2;76;79;105;48;2;232;245;233mcomment[m[m[48;2;232;245;233m [m
|
||||
[48;2;223;219;221m [m[48;2;223;219;221m [m[48;2;223;219;221m [m[48;2;223;219;221m [m[48;2;223;219;221m [m[48;2;200;230;201m [m[38;2;10;220;217;48;2;200;230;201m 5[m[48;2;200;230;201m [m[38;2;10;220;217;48;2;232;245;233m+ [m[38;2;32;31;38;48;2;232;245;233m[38;2;76;79;105;48;2;232;245;233m [m[1;38;2;4;165;229;48;2;232;245;233m*[m[38;2;210;15;57;48;2;232;245;233m/[m[m[48;2;232;245;233m [m
|
||||
[48;2;223;219;221m [m[38;2;58;57;67;48;2;223;219;221m 4[m[48;2;223;219;221m [m[38;2;32;31;38;48;2;241;239;239m [38;2;76;79;105;48;2;241;239;239m$[m[38;2;76;79;105;48;2;241;239;239m([m[38;2;210;15;57;48;2;241;239;239mfunction[m[38;2;76;79;105;48;2;241;239;239m()[m[38;2;76;79;105;48;2;241;239;239m [m[38;2;76;79;105;48;2;241;239;239m{[m[m[48;2;241;239;239m [m[48;2;223;219;221m [m[38;2;58;57;67;48;2;223;219;221m 6[m[48;2;223;219;221m [m[38;2;32;31;38;48;2;241;239;239m [38;2;76;79;105;48;2;241;239;239m$[m[38;2;76;79;105;48;2;241;239;239m([m[38;2;210;15;57;48;2;241;239;239mfunction[m[38;2;76;79;105;48;2;241;239;239m()[m[38;2;76;79;105;48;2;241;239;239m [m[38;2;76;79;105;48;2;241;239;239m{[m[m[48;2;241;239;239m [m
|
||||
[48;2;223;219;221m [m[38;2;58;57;67;48;2;223;219;221m 5[m[48;2;223;219;221m [m[38;2;32;31;38;48;2;241;239;239m [38;2;76;79;105;48;2;241;239;239m [m[38;2;76;79;105;48;2;241;239;239mconsole[m[38;2;76;79;105;48;2;241;239;239m.[m[38;2;76;79;105;48;2;241;239;239mlog[m[38;2;76;79;105;48;2;241;239;239m([m[38;2;64;160;43;48;2;241;239;239m"Hello, world!"[m[38;2;76;79;105;48;2;241;239;239m);[m[m[48;2;241;239;239m [m[48;2;223;219;221m [m[38;2;58;57;67;48;2;223;219;221m 7[m[48;2;223;219;221m [m[38;2;32;31;38;48;2;241;239;239m [38;2;76;79;105;48;2;241;239;239m [m[38;2;76;79;105;48;2;241;239;239mconsole[m[38;2;76;79;105;48;2;241;239;239m.[m[38;2;76;79;105;48;2;241;239;239mlog[m[38;2;76;79;105;48;2;241;239;239m([m[38;2;64;160;43;48;2;241;239;239m"Hello, world!"[m[38;2;76;79;105;48;2;241;239;239m);[m[m[48;2;241;239;239m [m
|
||||
[48;2;223;219;221m [m[38;2;58;57;67;48;2;223;219;221m 6[m[48;2;223;219;221m [m[38;2;32;31;38;48;2;241;239;239m [38;2;76;79;105;48;2;241;239;239m});[m[m[48;2;241;239;239m [m[48;2;223;219;221m [m[38;2;58;57;67;48;2;223;219;221m 8[m[48;2;223;219;221m [m[38;2;32;31;38;48;2;241;239;239m [38;2;76;79;105;48;2;241;239;239m});[m[m[48;2;241;239;239m [m
|
||||
12
internal/tui/exp/diffview/testdata/TestDiffViewLineBreakIssue/Unified.golden
generated
vendored
Normal file
12
internal/tui/exp/diffview/testdata/TestDiffViewLineBreakIssue/Unified.golden
generated
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
[48;2;71;118;255m [m[38;2;77;76;87;48;2;71;118;255m …[m[48;2;71;118;255m [m[48;2;71;118;255m [m[38;2;77;76;87;48;2;71;118;255m …[m[48;2;71;118;255m [m[38;2;96;95;107;48;2;113;154;252m @@ -1,6 +1,8 @@ [m[48;2;113;154;252m [m
|
||||
[48;2;255;205;210m [m[38;2;255;56;139;48;2;255;205;210m 1[m[48;2;255;205;210m [m[48;2;255;205;210m [m[38;2;255;56;139;48;2;255;205;210m [m[48;2;255;205;210m [m[38;2;255;56;139;48;2;255;235;238m- [m[38;2;32;31;38;48;2;255;235;238m[3;38;2;156;160;176;48;2;255;235;238m// this is[m[m[48;2;255;235;238m [m
|
||||
[48;2;200;230;201m [m[38;2;10;220;217;48;2;200;230;201m [m[48;2;200;230;201m [m[48;2;200;230;201m [m[38;2;10;220;217;48;2;200;230;201m 1[m[48;2;200;230;201m [m[38;2;10;220;217;48;2;232;245;233m+ [m[38;2;32;31;38;48;2;232;245;233m[38;2;210;15;57;48;2;232;245;233m/**[m[m[48;2;232;245;233m [m
|
||||
[48;2;200;230;201m [m[38;2;10;220;217;48;2;200;230;201m [m[48;2;200;230;201m [m[48;2;200;230;201m [m[38;2;10;220;217;48;2;200;230;201m 2[m[48;2;200;230;201m [m[38;2;10;220;217;48;2;232;245;233m+ [m[38;2;32;31;38;48;2;232;245;233m[38;2;76;79;105;48;2;232;245;233m [m[1;38;2;4;165;229;48;2;232;245;233m*[m[38;2;76;79;105;48;2;232;245;233m [m[38;2;136;57;239;48;2;232;245;233mthis[m[38;2;76;79;105;48;2;232;245;233m [m[38;2;76;79;105;48;2;232;245;233mis[m[m[48;2;232;245;233m [m
|
||||
[48;2;255;205;210m [m[38;2;255;56;139;48;2;255;205;210m 2[m[48;2;255;205;210m [m[48;2;255;205;210m [m[38;2;255;56;139;48;2;255;205;210m [m[48;2;255;205;210m [m[38;2;255;56;139;48;2;255;235;238m- [m[38;2;32;31;38;48;2;255;235;238m[3;38;2;156;160;176;48;2;255;235;238m// a regular[m[m[48;2;255;235;238m [m
|
||||
[48;2;200;230;201m [m[38;2;10;220;217;48;2;200;230;201m [m[48;2;200;230;201m [m[48;2;200;230;201m [m[38;2;10;220;217;48;2;200;230;201m 3[m[48;2;200;230;201m [m[38;2;10;220;217;48;2;232;245;233m+ [m[38;2;32;31;38;48;2;232;245;233m[38;2;76;79;105;48;2;232;245;233m [m[1;38;2;4;165;229;48;2;232;245;233m*[m[38;2;76;79;105;48;2;232;245;233m [m[38;2;76;79;105;48;2;232;245;233ma[m[38;2;76;79;105;48;2;232;245;233m [m[38;2;76;79;105;48;2;232;245;233mblock[m[m[48;2;232;245;233m [m
|
||||
[48;2;255;205;210m [m[38;2;255;56;139;48;2;255;205;210m 3[m[48;2;255;205;210m [m[48;2;255;205;210m [m[38;2;255;56;139;48;2;255;205;210m [m[48;2;255;205;210m [m[38;2;255;56;139;48;2;255;235;238m- [m[38;2;32;31;38;48;2;255;235;238m[3;38;2;156;160;176;48;2;255;235;238m// comment[m[m[48;2;255;235;238m [m
|
||||
[48;2;200;230;201m [m[38;2;10;220;217;48;2;200;230;201m [m[48;2;200;230;201m [m[48;2;200;230;201m [m[38;2;10;220;217;48;2;200;230;201m 4[m[48;2;200;230;201m [m[38;2;10;220;217;48;2;232;245;233m+ [m[38;2;32;31;38;48;2;232;245;233m[38;2;76;79;105;48;2;232;245;233m [m[1;38;2;4;165;229;48;2;232;245;233m*[m[38;2;76;79;105;48;2;232;245;233m [m[38;2;76;79;105;48;2;232;245;233mcomment[m[m[48;2;232;245;233m [m
|
||||
[48;2;200;230;201m [m[38;2;10;220;217;48;2;200;230;201m [m[48;2;200;230;201m [m[48;2;200;230;201m [m[38;2;10;220;217;48;2;200;230;201m 5[m[48;2;200;230;201m [m[38;2;10;220;217;48;2;232;245;233m+ [m[38;2;32;31;38;48;2;232;245;233m[38;2;76;79;105;48;2;232;245;233m [m[1;38;2;4;165;229;48;2;232;245;233m*[m[38;2;210;15;57;48;2;232;245;233m/[m[m[48;2;232;245;233m [m
|
||||
[48;2;223;219;221m [m[38;2;58;57;67;48;2;223;219;221m 4[m[48;2;223;219;221m [m[48;2;223;219;221m [m[38;2;58;57;67;48;2;223;219;221m 6[m[48;2;223;219;221m [m[38;2;32;31;38;48;2;241;239;239m [38;2;76;79;105;48;2;241;239;239m$[m[38;2;76;79;105;48;2;241;239;239m([m[38;2;210;15;57;48;2;241;239;239mfunction[m[38;2;76;79;105;48;2;241;239;239m()[m[38;2;76;79;105;48;2;241;239;239m [m[38;2;76;79;105;48;2;241;239;239m{[m[m[48;2;241;239;239m [m
|
||||
[48;2;223;219;221m [m[38;2;58;57;67;48;2;223;219;221m 5[m[48;2;223;219;221m [m[48;2;223;219;221m [m[38;2;58;57;67;48;2;223;219;221m 7[m[48;2;223;219;221m [m[38;2;32;31;38;48;2;241;239;239m [38;2;76;79;105;48;2;241;239;239m [m[38;2;76;79;105;48;2;241;239;239mconsole[m[38;2;76;79;105;48;2;241;239;239m.[m[38;2;76;79;105;48;2;241;239;239mlog[m[38;2;76;79;105;48;2;241;239;239m([m[38;2;64;160;43;48;2;241;239;239m"Hello, world!"[m[38;2;76;79;105;48;2;241;239;239m);[m[m[48;2;241;239;239m [m
|
||||
[48;2;223;219;221m [m[38;2;58;57;67;48;2;223;219;221m 6[m[48;2;223;219;221m [m[48;2;223;219;221m [m[38;2;58;57;67;48;2;223;219;221m 8[m[48;2;223;219;221m [m[38;2;32;31;38;48;2;241;239;239m [38;2;76;79;105;48;2;241;239;239m});[m[m[48;2;241;239;239m [m
|
||||
8
internal/tui/exp/diffview/testdata/TestLineBreakIssue.after
vendored
Normal file
8
internal/tui/exp/diffview/testdata/TestLineBreakIssue.after
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
/**
|
||||
* this is
|
||||
* a block
|
||||
* comment
|
||||
*/
|
||||
$(function() {
|
||||
console.log("Hello, world!");
|
||||
});
|
||||
6
internal/tui/exp/diffview/testdata/TestLineBreakIssue.before
vendored
Normal file
6
internal/tui/exp/diffview/testdata/TestLineBreakIssue.before
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
// this is
|
||||
// a regular
|
||||
// comment
|
||||
$(function() {
|
||||
console.log("Hello, world!");
|
||||
});
|
||||
@@ -279,7 +279,7 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if model.SupportsImages {
|
||||
return p, util.CmdHandler(OpenFilePickerMsg{})
|
||||
} else {
|
||||
return p, util.ReportWarn("File attachments are not supported by the current model: " + model.Model)
|
||||
return p, util.ReportWarn("File attachments are not supported by the current model: " + model.Name)
|
||||
}
|
||||
case key.Matches(msg, p.keyMap.Tab):
|
||||
if p.session.ID == "" {
|
||||
|
||||
@@ -3,6 +3,7 @@ package tui
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/v2/key"
|
||||
tea "github.com/charmbracelet/bubbletea/v2"
|
||||
@@ -112,6 +113,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
return a, tea.Batch(cmds...)
|
||||
case tea.WindowSizeMsg:
|
||||
a.wWidth, a.wHeight = msg.Width, msg.Height
|
||||
a.completions.Update(msg)
|
||||
return a, a.handleWindowResize(msg.Width, msg.Height)
|
||||
|
||||
@@ -290,7 +292,6 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// handleWindowResize processes window resize events and updates all components.
|
||||
func (a *appModel) handleWindowResize(width, height int) tea.Cmd {
|
||||
var cmds []tea.Cmd
|
||||
a.wWidth, a.wHeight = width, height
|
||||
if a.showingFullHelp {
|
||||
height -= 5
|
||||
} else {
|
||||
@@ -319,26 +320,20 @@ func (a *appModel) handleWindowResize(width, height int) tea.Cmd {
|
||||
|
||||
// handleKeyPressMsg processes keyboard input and routes to appropriate handlers.
|
||||
func (a *appModel) handleKeyPressMsg(msg tea.KeyPressMsg) tea.Cmd {
|
||||
if a.completions.Open() {
|
||||
// completions
|
||||
keyMap := a.completions.KeyMap()
|
||||
switch {
|
||||
case key.Matches(msg, keyMap.Up), key.Matches(msg, keyMap.Down),
|
||||
key.Matches(msg, keyMap.Select), key.Matches(msg, keyMap.Cancel),
|
||||
key.Matches(msg, keyMap.UpInsert), key.Matches(msg, keyMap.DownInsert):
|
||||
u, cmd := a.completions.Update(msg)
|
||||
a.completions = u.(completions.Completions)
|
||||
return cmd
|
||||
}
|
||||
}
|
||||
switch {
|
||||
// completions
|
||||
case a.completions.Open() && key.Matches(msg, a.completions.KeyMap().Up):
|
||||
u, cmd := a.completions.Update(msg)
|
||||
a.completions = u.(completions.Completions)
|
||||
return cmd
|
||||
|
||||
case a.completions.Open() && key.Matches(msg, a.completions.KeyMap().Down):
|
||||
u, cmd := a.completions.Update(msg)
|
||||
a.completions = u.(completions.Completions)
|
||||
return cmd
|
||||
case a.completions.Open() && key.Matches(msg, a.completions.KeyMap().Select):
|
||||
u, cmd := a.completions.Update(msg)
|
||||
a.completions = u.(completions.Completions)
|
||||
return cmd
|
||||
case a.completions.Open() && key.Matches(msg, a.completions.KeyMap().Cancel):
|
||||
u, cmd := a.completions.Update(msg)
|
||||
a.completions = u.(completions.Completions)
|
||||
return cmd
|
||||
// help
|
||||
// help
|
||||
case key.Matches(msg, a.keyMap.Help):
|
||||
a.status.ToggleFullHelp()
|
||||
a.showingFullHelp = !a.showingFullHelp
|
||||
@@ -429,6 +424,27 @@ func (a *appModel) moveToPage(pageID page.PageID) tea.Cmd {
|
||||
|
||||
// View renders the complete application interface including pages, dialogs, and overlays.
|
||||
func (a *appModel) View() tea.View {
|
||||
var view tea.View
|
||||
t := styles.CurrentTheme()
|
||||
view.BackgroundColor = t.BgBase
|
||||
if a.wWidth < 25 || a.wHeight < 15 {
|
||||
view.Layer = lipgloss.NewCanvas(
|
||||
lipgloss.NewLayer(
|
||||
t.S().Base.Width(a.wWidth).Height(a.wHeight).
|
||||
Align(lipgloss.Center, lipgloss.Center).
|
||||
Render(
|
||||
t.S().Base.
|
||||
Padding(1, 4).
|
||||
Foreground(t.White).
|
||||
BorderStyle(lipgloss.RoundedBorder()).
|
||||
BorderForeground(t.Primary).
|
||||
Render("Window too small!"),
|
||||
),
|
||||
),
|
||||
)
|
||||
return view
|
||||
}
|
||||
|
||||
page := a.pages[a.currentPage]
|
||||
if withHelp, ok := page.(core.KeyMapHelp); ok {
|
||||
a.status.SetKeyMap(withHelp.Help())
|
||||
@@ -453,6 +469,11 @@ func (a *appModel) View() tea.View {
|
||||
var cursor *tea.Cursor
|
||||
if v, ok := page.(util.Cursor); ok {
|
||||
cursor = v.Cursor()
|
||||
// Hide the cursor if it's positioned outside the textarea
|
||||
statusHeight := a.height - strings.Count(pageView, "\n") + 1
|
||||
if cursor != nil && cursor.Y+statusHeight+chat.EditorHeight-2 <= a.height { // 2 for the top and bottom app padding
|
||||
cursor = nil
|
||||
}
|
||||
}
|
||||
activeView := a.dialog.ActiveModel()
|
||||
if activeView != nil {
|
||||
@@ -475,10 +496,7 @@ func (a *appModel) View() tea.View {
|
||||
layers...,
|
||||
)
|
||||
|
||||
var view tea.View
|
||||
t := styles.CurrentTheme()
|
||||
view.Layer = canvas
|
||||
view.BackgroundColor = t.BgBase
|
||||
view.Cursor = cursor
|
||||
return view
|
||||
}
|
||||
|
||||
22
main.go
22
main.go
@@ -1,12 +1,9 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
_ "net/http/pprof" // profiling
|
||||
|
||||
@@ -14,14 +11,9 @@ import (
|
||||
|
||||
"github.com/charmbracelet/crush/internal/cmd"
|
||||
"github.com/charmbracelet/crush/internal/log"
|
||||
"github.com/charmbracelet/lipgloss/v2"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if runtime.GOOS == "windows" {
|
||||
showWindowsWarning()
|
||||
}
|
||||
|
||||
defer log.RecoverPanic("main", func() {
|
||||
slog.Error("Application terminated due to unhandled panic")
|
||||
})
|
||||
@@ -30,22 +22,10 @@ func main() {
|
||||
go func() {
|
||||
slog.Info("Serving pprof at localhost:6060")
|
||||
if httpErr := http.ListenAndServe("localhost:6060", nil); httpErr != nil {
|
||||
slog.Error(fmt.Sprintf("Failed to pprof listen: %v", httpErr))
|
||||
slog.Error("Failed to pprof listen", "error", httpErr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
cmd.Execute()
|
||||
}
|
||||
|
||||
func showWindowsWarning() {
|
||||
content := strings.Join([]string{
|
||||
lipgloss.NewStyle().Bold(true).Render("WARNING:") + " Crush is experimental on Windows!",
|
||||
"While we work on it, we recommend WSL2 for a better experience.",
|
||||
lipgloss.NewStyle().Italic(true).Render("Press Enter to continue..."),
|
||||
}, "\n")
|
||||
fmt.Print(content)
|
||||
|
||||
var input string
|
||||
fmt.Scanln(&input)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user