Skip to content

Commit 1075526

Browse files
committed
set WriteBackupFileName and InputFileName more consistently in the tool calls by passing the toolusedata struct
1 parent ed75e09 commit 1075526

File tree

12 files changed

+77
-65
lines changed

12 files changed

+77
-65
lines changed

pkg/aiusechat/tools.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ func GetAdderToolDefinition() uctypes.ToolDefinition {
271271
"required": []string{"values"},
272272
"additionalProperties": false,
273273
},
274-
ToolAnyCallback: func(input any) (any, error) {
274+
ToolAnyCallback: func(input any, toolUseData *uctypes.UIMessageDataToolUse) (any, error) {
275275
inputMap, ok := input.(map[string]any)
276276
if !ok {
277277
return nil, fmt.Errorf("invalid input format")

pkg/aiusechat/tools_builder.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func GetBuilderWriteAppFileToolDefinition(appId string) uctypes.ToolDefinition {
5858
ToolInputDesc: func(input any) string {
5959
return fmt.Sprintf("writing app.go for %s", appId)
6060
},
61-
ToolAnyCallback: func(input any) (any, error) {
61+
ToolAnyCallback: func(input any, toolUseData *uctypes.UIMessageDataToolUse) (any, error) {
6262
params, err := parseBuilderWriteAppFileInput(input)
6363
if err != nil {
6464
return nil, err
@@ -149,7 +149,7 @@ func GetBuilderEditAppFileToolDefinition(appId string) uctypes.ToolDefinition {
149149
}
150150
return fmt.Sprintf("editing app.go for %s (%d edits)", appId, len(params.Edits))
151151
},
152-
ToolAnyCallback: func(input any) (any, error) {
152+
ToolAnyCallback: func(input any, toolUseData *uctypes.UIMessageDataToolUse) (any, error) {
153153
params, err := parseBuilderEditAppFileInput(input)
154154
if err != nil {
155155
return nil, err
@@ -188,7 +188,7 @@ func GetBuilderListFilesToolDefinition(appId string) uctypes.ToolDefinition {
188188
ToolInputDesc: func(input any) string {
189189
return fmt.Sprintf("listing files for %s", appId)
190190
},
191-
ToolAnyCallback: func(input any) (any, error) {
191+
ToolAnyCallback: func(input any, toolUseData *uctypes.UIMessageDataToolUse) (any, error) {
192192
result, err := waveappstore.ListAllAppFiles(appId)
193193
if err != nil {
194194
return nil, err

pkg/aiusechat/tools_readdir.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func parseReadDirInput(input any) (*readDirParams, error) {
5050
return result, nil
5151
}
5252

53-
func readDirCallback(input any) (any, error) {
53+
func readDirCallback(input any, toolUseData *uctypes.UIMessageDataToolUse) (any, error) {
5454
params, err := parseReadDirInput(input)
5555
if err != nil {
5656
return nil, err

pkg/aiusechat/tools_readdir_test.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"path/filepath"
1010
"strings"
1111
"testing"
12+
13+
"github.com/wavetermdev/waveterm/pkg/aiusechat/uctypes"
1214
)
1315

1416
func TestReadDirCallback(t *testing.T) {
@@ -39,7 +41,7 @@ func TestReadDirCallback(t *testing.T) {
3941
"path": tmpDir,
4042
}
4143

42-
result, err := readDirCallback(input)
44+
result, err := readDirCallback(input, &uctypes.UIMessageDataToolUse{})
4345
if err != nil {
4446
t.Fatalf("readDirCallback failed: %v", err)
4547
}
@@ -100,7 +102,7 @@ func TestReadDirOnFile(t *testing.T) {
100102
"path": tmpFile.Name(),
101103
}
102104

103-
_, err = readDirCallback(input)
105+
_, err = readDirCallback(input, &uctypes.UIMessageDataToolUse{})
104106
if err == nil {
105107
t.Fatalf("Expected error when reading a file with read_dir, got nil")
106108
}
@@ -134,7 +136,7 @@ func TestReadDirMaxEntries(t *testing.T) {
134136
"max_entries": maxEntries,
135137
}
136138

137-
result, err := readDirCallback(input)
139+
result, err := readDirCallback(input, &uctypes.UIMessageDataToolUse{})
138140
if err != nil {
139141
t.Fatalf("readDirCallback failed: %v", err)
140142
}
@@ -200,7 +202,7 @@ func TestReadDirSortBeforeTruncate(t *testing.T) {
200202
"max_entries": maxEntries,
201203
}
202204

203-
result, err := readDirCallback(input)
205+
result, err := readDirCallback(input, &uctypes.UIMessageDataToolUse{})
204206
if err != nil {
205207
t.Fatalf("readDirCallback failed: %v", err)
206208
}

pkg/aiusechat/tools_readfile.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ func isBlockedFile(expandedPath string) (bool, string) {
198198
}
199199

200200

201-
func readTextFileCallback(input any) (any, error) {
201+
func readTextFileCallback(input any, toolUseData *uctypes.UIMessageDataToolUse) (any, error) {
202202
const ReadLimit = 1024 * 1024 * 1024
203203

204204
params, err := parseReadTextFileInput(input)

pkg/aiusechat/tools_term.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ func GetTermGetScrollbackToolDefinition(tabId string) uctypes.ToolDefinition {
190190
lineEnd := parsed.LineStart + parsed.Count
191191
return fmt.Sprintf("reading terminal output from %s (lines %d-%d)", parsed.WidgetId, parsed.LineStart, lineEnd)
192192
},
193-
ToolAnyCallback: func(input any) (any, error) {
193+
ToolAnyCallback: func(input any, toolUseData *uctypes.UIMessageDataToolUse) (any, error) {
194194
parsed, err := parseTermGetScrollbackInput(input)
195195
if err != nil {
196196
return nil, err
@@ -266,7 +266,7 @@ func GetTermCommandOutputToolDefinition(tabId string) uctypes.ToolDefinition {
266266
}
267267
return fmt.Sprintf("reading last command output from %s", parsed.WidgetId)
268268
},
269-
ToolAnyCallback: func(input any) (any, error) {
269+
ToolAnyCallback: func(input any, toolUseData *uctypes.UIMessageDataToolUse) (any, error) {
270270
parsed, err := parseTermCommandOutputInput(input)
271271
if err != nil {
272272
return nil, err

pkg/aiusechat/tools_tsunami.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ func handleTsunamiBlockDesc(block *waveobj.Block) string {
3131
return "tsunami widget - unknown description"
3232
}
3333

34-
func makeTsunamiGetCallback(status *blockcontroller.BlockControllerRuntimeStatus, apiPath string) func(any) (any, error) {
35-
return func(input any) (any, error) {
34+
func makeTsunamiGetCallback(status *blockcontroller.BlockControllerRuntimeStatus, apiPath string) func(any, *uctypes.UIMessageDataToolUse) (any, error) {
35+
return func(input any, toolUseData *uctypes.UIMessageDataToolUse) (any, error) {
3636
if status.TsunamiPort == 0 {
3737
return nil, fmt.Errorf("tsunami port not available")
3838
}
@@ -66,8 +66,8 @@ func makeTsunamiGetCallback(status *blockcontroller.BlockControllerRuntimeStatus
6666
}
6767
}
6868

69-
func makeTsunamiPostCallback(status *blockcontroller.BlockControllerRuntimeStatus, apiPath string) func(any) (any, error) {
70-
return func(input any) (any, error) {
69+
func makeTsunamiPostCallback(status *blockcontroller.BlockControllerRuntimeStatus, apiPath string) func(any, *uctypes.UIMessageDataToolUse) (any, error) {
70+
return func(input any, toolUseData *uctypes.UIMessageDataToolUse) (any, error) {
7171
if status.TsunamiPort == 0 {
7272
return nil, fmt.Errorf("tsunami port not available")
7373
}

pkg/aiusechat/tools_web.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func GetWebNavigateToolDefinition(tabId string) uctypes.ToolDefinition {
7777
}
7878
return fmt.Sprintf("navigating web widget %s to %q", parsed.WidgetId, parsed.Url)
7979
},
80-
ToolAnyCallback: func(input any) (any, error) {
80+
ToolAnyCallback: func(input any, toolUseData *uctypes.UIMessageDataToolUse) (any, error) {
8181
parsed, err := parseWebNavigateInput(input)
8282
if err != nil {
8383
return nil, err

pkg/aiusechat/tools_writefile.go

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ func parseWriteTextFileInput(input any) (*writeTextFileParams, error) {
101101
return result, nil
102102
}
103103

104-
func verifyWriteTextFileInput(input any) error {
104+
func verifyWriteTextFileInput(input any, toolUseData *uctypes.UIMessageDataToolUse) error {
105105
params, err := parseWriteTextFileInput(input)
106106
if err != nil {
107107
return err
@@ -118,10 +118,15 @@ func verifyWriteTextFileInput(input any) error {
118118
}
119119

120120
_, err = validateTextFile(expandedPath, "write to", false)
121-
return err
121+
if err != nil {
122+
return err
123+
}
124+
125+
toolUseData.InputFileName = params.Filename
126+
return nil
122127
}
123128

124-
func writeTextFileCallback(input any) (any, error) {
129+
func writeTextFileCallback(input any, toolUseData *uctypes.UIMessageDataToolUse) (any, error) {
125130
params, err := parseWriteTextFileInput(input)
126131
if err != nil {
127132
return nil, err
@@ -149,10 +154,11 @@ func writeTextFileCallback(input any) (any, error) {
149154
}
150155

151156
if fileInfo != nil {
152-
err = filebackup.MakeFileBackup(expandedPath)
157+
backupPath, err := filebackup.MakeFileBackup(expandedPath)
153158
if err != nil {
154159
return nil, fmt.Errorf("failed to create backup: %w", err)
155160
}
161+
toolUseData.WriteBackupFileName = backupPath
156162
}
157163

158164
err = os.WriteFile(expandedPath, contentsBytes, 0644)
@@ -230,7 +236,7 @@ func parseEditTextFileInput(input any) (*editTextFileParams, error) {
230236
return result, nil
231237
}
232238

233-
func verifyEditTextFileInput(input any) error {
239+
func verifyEditTextFileInput(input any, toolUseData *uctypes.UIMessageDataToolUse) error {
234240
params, err := parseEditTextFileInput(input)
235241
if err != nil {
236242
return err
@@ -242,7 +248,12 @@ func verifyEditTextFileInput(input any) error {
242248
}
243249

244250
_, err = validateTextFile(expandedPath, "edit", true)
245-
return err
251+
if err != nil {
252+
return err
253+
}
254+
255+
toolUseData.InputFileName = params.Filename
256+
return nil
246257
}
247258

248259
// EditTextFileDryRun applies edits to a file and returns the original and modified content
@@ -281,7 +292,7 @@ func EditTextFileDryRun(input any, fileOverride string) ([]byte, []byte, error)
281292
return originalContent, modifiedContent, nil
282293
}
283294

284-
func editTextFileCallback(input any) (any, error) {
295+
func editTextFileCallback(input any, toolUseData *uctypes.UIMessageDataToolUse) (any, error) {
285296
params, err := parseEditTextFileInput(input)
286297
if err != nil {
287298
return nil, err
@@ -297,10 +308,11 @@ func editTextFileCallback(input any) (any, error) {
297308
return nil, err
298309
}
299310

300-
err = filebackup.MakeFileBackup(expandedPath)
311+
backupPath, err := filebackup.MakeFileBackup(expandedPath)
301312
if err != nil {
302313
return nil, fmt.Errorf("failed to create backup: %w", err)
303314
}
315+
toolUseData.WriteBackupFileName = backupPath
304316

305317
err = fileutil.ReplaceInFile(expandedPath, params.Edits)
306318
if err != nil {
@@ -399,7 +411,7 @@ func parseDeleteTextFileInput(input any) (*deleteTextFileParams, error) {
399411
return result, nil
400412
}
401413

402-
func verifyDeleteTextFileInput(input any) error {
414+
func verifyDeleteTextFileInput(input any, toolUseData *uctypes.UIMessageDataToolUse) error {
403415
params, err := parseDeleteTextFileInput(input)
404416
if err != nil {
405417
return err
@@ -411,10 +423,15 @@ func verifyDeleteTextFileInput(input any) error {
411423
}
412424

413425
_, err = validateTextFile(expandedPath, "delete", true)
414-
return err
426+
if err != nil {
427+
return err
428+
}
429+
430+
toolUseData.InputFileName = params.Filename
431+
return nil
415432
}
416433

417-
func deleteTextFileCallback(input any) (any, error) {
434+
func deleteTextFileCallback(input any, toolUseData *uctypes.UIMessageDataToolUse) (any, error) {
418435
params, err := parseDeleteTextFileInput(input)
419436
if err != nil {
420437
return nil, err
@@ -430,10 +447,11 @@ func deleteTextFileCallback(input any) (any, error) {
430447
return nil, err
431448
}
432449

433-
err = filebackup.MakeFileBackup(expandedPath)
450+
backupPath, err := filebackup.MakeFileBackup(expandedPath)
434451
if err != nil {
435452
return nil, fmt.Errorf("failed to create backup: %w", err)
436453
}
454+
toolUseData.WriteBackupFileName = backupPath
437455

438456
err = os.Remove(expandedPath)
439457
if err != nil {

pkg/aiusechat/uctypes/usechat-types.go

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,19 @@ type UIMessageDataUserFile struct {
7676

7777
// ToolDefinition represents a tool that can be used by the AI model
7878
type ToolDefinition struct {
79-
Name string `json:"name"`
80-
DisplayName string `json:"displayname,omitempty"` // internal field (cannot marshal to API, must be stripped)
81-
Description string `json:"description"`
82-
ShortDescription string `json:"shortdescription,omitempty"` // internal field (cannot marshal to API, must be stripped)
83-
ToolLogName string `json:"-"` // short name for telemetry (e.g., "term:getscrollback")
84-
InputSchema map[string]any `json:"input_schema"`
85-
Strict bool `json:"strict,omitempty"`
86-
ToolTextCallback func(any) (string, error) `json:"-"`
87-
ToolAnyCallback func(any) (any, error) `json:"-"`
88-
ToolInputDesc func(any) string `json:"-"`
89-
ToolApproval func(any) string `json:"-"`
90-
ToolVerifyInput func(any) error `json:"-"`
79+
Name string `json:"name"`
80+
DisplayName string `json:"displayname,omitempty"` // internal field (cannot marshal to API, must be stripped)
81+
Description string `json:"description"`
82+
ShortDescription string `json:"shortdescription,omitempty"` // internal field (cannot marshal to API, must be stripped)
83+
ToolLogName string `json:"-"` // short name for telemetry (e.g., "term:getscrollback")
84+
InputSchema map[string]any `json:"input_schema"`
85+
Strict bool `json:"strict,omitempty"`
86+
87+
ToolTextCallback func(any) (string, error) `json:"-"`
88+
ToolAnyCallback func(any, *UIMessageDataToolUse) (any, error) `json:"-"`
89+
ToolInputDesc func(any) string `json:"-"`
90+
ToolApproval func(any) string `json:"-"`
91+
ToolVerifyInput func(any, *UIMessageDataToolUse) error `json:"-"`
9192
}
9293

9394
func (td *ToolDefinition) Clean() *ToolDefinition {

0 commit comments

Comments
 (0)