Skip to content

Commit 1b46223

Browse files
authored
Fix issue on context id generation (#45)
* fix issue on context id generation * fix comment * optimize code
1 parent 3bb198d commit 1b46223

File tree

9 files changed

+352
-80
lines changed

9 files changed

+352
-80
lines changed

examples/go.mod

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ require (
99
github.com/lestrrat-go/jwx/v2 v2.1.4
1010
github.com/redis/go-redis/v9 v9.10.0
1111
golang.org/x/oauth2 v0.29.0
12-
trpc.group/trpc-go/trpc-a2a-go v0.0.0
12+
trpc.group/trpc-go/trpc-a2a-go v0.0.3
13+
trpc.group/trpc-go/trpc-a2a-go/taskmanager/redis v0.0.0-20250625115112-3bb198d0dc98
1314
)
1415

1516
require (
@@ -31,3 +32,5 @@ require (
3132
)
3233

3334
replace trpc.group/trpc-go/trpc-a2a-go => ../
35+
36+
replace trpc.group/trpc-go/trpc-a2a-go/taskmanager/redis => ../taskmanager/redis

examples/simple/client/main.go

Lines changed: 143 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,11 @@ package main
1010
import (
1111
"context"
1212
"flag"
13-
"log"
13+
"fmt"
1414
"time"
1515

16-
"github.com/google/uuid"
17-
1816
"trpc.group/trpc-go/trpc-a2a-go/client"
17+
"trpc.group/trpc-go/trpc-a2a-go/log"
1918
"trpc.group/trpc-go/trpc-a2a-go/protocol"
2019
)
2120

@@ -24,6 +23,7 @@ func main() {
2423
agentURL := flag.String("agent", "http://localhost:8080/", "Target A2A agent URL")
2524
timeout := flag.Duration("timeout", 30*time.Second, "Request timeout (e.g., 30s, 1m)")
2625
message := flag.String("message", "Hello, world!", "Message to send to the agent")
26+
streaming := flag.Bool("streaming", false, "Use streaming mode (message/stream)")
2727
flag.Parse()
2828

2929
// Create A2A client.
@@ -33,56 +33,105 @@ func main() {
3333
}
3434

3535
// Display connection information.
36-
log.Printf("Connecting to agent: %s (Timeout: %v)", *agentURL, *timeout)
37-
38-
// Create a new unique message ID and context ID.
39-
contextID := uuid.New().String()
40-
log.Printf("Context ID: %s", contextID)
36+
log.Infof("Connecting to agent: %s (Timeout: %v)", *agentURL, *timeout)
37+
log.Infof("Mode: %s", map[bool]string{true: "Streaming", false: "Standard"}[*streaming])
4138

4239
// Create the message to send using the new constructor.
43-
userMessage := protocol.NewMessageWithContext(
40+
userMessage := protocol.NewMessage(
4441
protocol.MessageRoleUser,
4542
[]protocol.Part{protocol.NewTextPart(*message)},
46-
nil, // taskID
47-
&contextID,
4843
)
4944

5045
// Create message parameters using the new SendMessageParams structure.
5146
params := protocol.SendMessageParams{
5247
Message: userMessage,
5348
Configuration: &protocol.SendMessageConfiguration{
54-
Blocking: boolPtr(true), // Wait for completion
49+
Blocking: boolPtr(false), // Non-blocking for streaming, blocking for standard
5550
},
5651
}
5752

58-
log.Printf("Sending message with content: %s", *message)
53+
log.Infof("Sending message with content: %s", *message)
5954

60-
// Send message to the agent using the new message API.
55+
// Create context for the request.
6156
ctx, cancel := context.WithTimeout(context.Background(), *timeout)
6257
defer cancel()
6358

59+
if *streaming {
60+
// Use streaming mode
61+
handleStreamingMode(ctx, a2aClient, params)
62+
} else {
63+
// Use standard mode
64+
handleStandardMode(ctx, a2aClient, params)
65+
}
66+
}
67+
68+
// handleStreamingMode handles streaming message sending and event processing
69+
func handleStreamingMode(ctx context.Context, a2aClient *client.A2AClient, params protocol.SendMessageParams) {
70+
log.Infof("Starting streaming request...")
71+
72+
// Send streaming message request
73+
eventChan, err := a2aClient.StreamMessage(ctx, params)
74+
if err != nil {
75+
log.Fatalf("Failed to start streaming: %v", err)
76+
}
77+
78+
log.Infof("Processing streaming events...")
79+
80+
eventCount := 0
81+
var finalResult string
82+
83+
// Process streaming events
84+
for {
85+
select {
86+
case event, ok := <-eventChan:
87+
if !ok {
88+
log.Infof("Stream completed. Total events received: %d", eventCount)
89+
if finalResult != "" {
90+
log.Infof("Final result: %s", finalResult)
91+
}
92+
return
93+
}
94+
95+
eventCount++
96+
log.Infof("Event %d received: %s", eventCount, getEventDescription(event))
97+
98+
// Extract final result from completed events
99+
if result := extractFinalResult(event); result != "" {
100+
finalResult = result
101+
log.Infof("Received msg: [Text: %s]", result)
102+
}
103+
104+
case <-ctx.Done():
105+
log.Infof("Request timed out after receiving %d events", eventCount)
106+
return
107+
}
108+
}
109+
}
110+
111+
// handleStandardMode handles standard (non-streaming) message sending
112+
func handleStandardMode(ctx context.Context, a2aClient *client.A2AClient, params protocol.SendMessageParams) {
64113
messageResult, err := a2aClient.SendMessage(ctx, params)
65114
if err != nil {
66115
log.Fatalf("Failed to send message: %v", err)
67116
}
68117

69118
// Display the result.
70-
log.Printf("Message sent successfully")
119+
log.Infof("Message sent successfully")
71120

72121
// Handle the result based on its type
73122
switch result := messageResult.Result.(type) {
74123
case *protocol.Message:
75-
log.Printf("Received message response:")
124+
log.Infof("Received message response:")
76125
printMessage(*result)
77126
case *protocol.Task:
78-
log.Printf("Received task response - ID: %s, State: %s", result.ID, result.Status.State)
127+
log.Infof("Received task response - ID: %s, State: %s", result.ID, result.Status.State)
79128

80129
// If task is not completed, wait and check again
81130
if result.Status.State != protocol.TaskStateCompleted &&
82131
result.Status.State != protocol.TaskStateFailed &&
83132
result.Status.State != protocol.TaskStateCanceled {
84133

85-
log.Printf("Task %s is %s, fetching final state...", result.ID, result.Status.State)
134+
log.Infof("Task %s is %s, fetching final state...", result.ID, result.Status.State)
86135

87136
// Get the task's final state.
88137
queryParams := protocol.TaskQueryParams{
@@ -97,65 +146,124 @@ func main() {
97146
log.Fatalf("Failed to get task status: %v", err)
98147
}
99148

100-
log.Printf("Task %s final state: %s", task.ID, task.Status.State)
149+
log.Infof("Task %s final state: %s", task.ID, task.Status.State)
101150
printTaskResult(task)
102151
} else {
103152
printTaskResult(result)
104153
}
105154
default:
106-
log.Printf("Received unknown result type: %T", result)
155+
log.Infof("Received unknown result type: %T", result)
156+
}
157+
}
158+
159+
// getEventDescription returns a human-readable description of the streaming event
160+
func getEventDescription(event protocol.StreamingMessageEvent) string {
161+
switch result := event.Result.(type) {
162+
case *protocol.Message:
163+
ctxID := "unknown"
164+
if result.ContextID != nil {
165+
ctxID = *result.ContextID
166+
}
167+
return fmt.Sprintf("Message from %s, ContextID: %v", result.Role, ctxID)
168+
case *protocol.Task:
169+
return fmt.Sprintf("Task %s - State: %s, ContextID: %v", result.ID, result.Status.State, result.ContextID)
170+
case *protocol.TaskStatusUpdateEvent:
171+
return fmt.Sprintf(
172+
"Status Update - Task: %s, State: %s, ContextID: %v",
173+
result.TaskID,
174+
result.Status.State,
175+
result.ContextID,
176+
)
177+
case *protocol.TaskArtifactUpdateEvent:
178+
artifactName := "Unnamed"
179+
if result.Artifact.Name != nil {
180+
artifactName = *result.Artifact.Name
181+
}
182+
return fmt.Sprintf("Artifact Update - %s, ContextID: %v", artifactName, result.ContextID)
183+
default:
184+
return fmt.Sprintf("Unknown event type: %T", result)
185+
}
186+
}
187+
188+
// extractFinalResult extracts the final text result from streaming events
189+
func extractFinalResult(event protocol.StreamingMessageEvent) string {
190+
switch result := event.Result.(type) {
191+
case *protocol.Message:
192+
// Extract text from message parts
193+
for _, part := range result.Parts {
194+
if textPart, ok := part.(*protocol.TextPart); ok {
195+
return textPart.Text
196+
}
197+
}
198+
case *protocol.Task:
199+
// Extract text from task status message
200+
if result.Status.Message != nil {
201+
for _, part := range result.Status.Message.Parts {
202+
if textPart, ok := part.(*protocol.TextPart); ok {
203+
return textPart.Text
204+
}
205+
}
206+
}
207+
case *protocol.TaskArtifactUpdateEvent:
208+
// Extract text from artifact parts
209+
for _, part := range result.Artifact.Parts {
210+
if textPart, ok := part.(*protocol.TextPart); ok {
211+
return textPart.Text
212+
}
213+
}
107214
}
215+
return ""
108216
}
109217

110218
// printMessage prints the contents of a message.
111219
func printMessage(message protocol.Message) {
112-
log.Printf("Message ID: %s", message.MessageID)
220+
log.Infof("Message ID: %s", message.MessageID)
113221
if message.ContextID != nil {
114-
log.Printf("Context ID: %s", *message.ContextID)
222+
log.Infof("Context ID: %s", *message.ContextID)
115223
}
116-
log.Printf("Role: %s", message.Role)
224+
log.Infof("Role: %s", message.Role)
117225

118-
log.Printf("Message parts:")
226+
log.Infof("Message parts:")
119227
for i, part := range message.Parts {
120228
switch p := part.(type) {
121229
case *protocol.TextPart:
122-
log.Printf(" Part %d (text): %s", i+1, p.Text)
230+
log.Infof(" Part %d (text): %s", i+1, p.Text)
123231
case *protocol.FilePart:
124-
log.Printf(" Part %d (file): [file content]", i+1)
232+
log.Infof(" Part %d (file): [file content]", i+1)
125233
case *protocol.DataPart:
126-
log.Printf(" Part %d (data): %+v", i+1, p.Data)
234+
log.Infof(" Part %d (data): %+v", i+1, p.Data)
127235
default:
128-
log.Printf(" Part %d (unknown): %+v", i+1, part)
236+
log.Infof(" Part %d (unknown): %+v", i+1, part)
129237
}
130238
}
131239
}
132240

133241
// printTaskResult prints the contents of a task result.
134242
func printTaskResult(task *protocol.Task) {
135243
if task.Status.Message != nil {
136-
log.Printf("Task result message:")
244+
log.Infof("Task result message:")
137245
printMessage(*task.Status.Message)
138246
}
139247

140248
// Print artifacts if any
141249
if len(task.Artifacts) > 0 {
142-
log.Printf("Task artifacts:")
250+
log.Infof("Task artifacts:")
143251
for i, artifact := range task.Artifacts {
144252
name := "Unnamed"
145253
if artifact.Name != nil {
146254
name = *artifact.Name
147255
}
148-
log.Printf(" Artifact %d: %s", i+1, name)
256+
log.Infof(" Artifact %d: %s", i+1, name)
149257
for j, part := range artifact.Parts {
150258
switch p := part.(type) {
151259
case *protocol.TextPart:
152-
log.Printf(" Part %d (text): %s", j+1, p.Text)
260+
log.Infof(" Part %d (text): %s", j+1, p.Text)
153261
case *protocol.FilePart:
154-
log.Printf(" Part %d (file): [file content]", j+1)
262+
log.Infof(" Part %d (file): [file content]", j+1)
155263
case *protocol.DataPart:
156-
log.Printf(" Part %d (data): %+v", j+1, p.Data)
264+
log.Infof(" Part %d (data): %+v", j+1, p.Data)
157265
default:
158-
log.Printf(" Part %d (unknown): %+v", j+1, part)
266+
log.Infof(" Part %d (unknown): %+v", j+1, part)
159267
}
160268
}
161269
}

0 commit comments

Comments
 (0)