Skip to content

Commit 5da182f

Browse files
authored
Merge pull request #489 from Scale3-Labs/obinna/S3EN-1111-fix-agno-non-streaming-bug
Obinna/s3 en 1111 fix agno non streaming bug
2 parents d035cbc + c06ad23 commit 5da182f

File tree

2 files changed

+143
-89
lines changed

2 files changed

+143
-89
lines changed

src/langtrace_python_sdk/instrumentation/agno/patch.py

Lines changed: 142 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,46 @@
1515
from langtrace_python_sdk.utils.llm import get_span_name, set_span_attributes
1616
from langtrace_python_sdk.utils.misc import serialize_args, serialize_kwargs
1717

18+
def _safe_serialize(obj):
19+
"""Safely serialize objects that might not be JSON serializable"""
20+
if hasattr(obj, 'to_dict'):
21+
return obj.to_dict()
22+
elif hasattr(obj, '__dict__'):
23+
return {k: _safe_serialize(v) for k, v in obj.__dict__.items() if not k.startswith('_')}
24+
elif isinstance(obj, dict):
25+
return {k: _safe_serialize(v) for k, v in obj.items()}
26+
elif isinstance(obj, (list, tuple)):
27+
return [_safe_serialize(i) for i in obj]
28+
return str(obj)
29+
30+
def _safe_json_dumps(obj):
31+
"""Safely dump an object to JSON, handling non-serializable types"""
32+
try:
33+
return json.dumps(obj)
34+
except (TypeError, ValueError):
35+
return json.dumps(_safe_serialize(obj))
36+
1837
def _extract_metrics(metrics: Dict[str, Any]) -> Dict[str, Any]:
1938
"""Helper function to extract and format metrics"""
20-
formatted_metrics = {}
39+
if not metrics:
40+
return {}
41+
42+
if hasattr(metrics, 'to_dict'):
43+
metrics = metrics.to_dict()
44+
elif hasattr(metrics, '__dict__'):
45+
metrics = {k: v for k, v in metrics.__dict__.items() if not k.startswith('_')}
2146

22-
# Extract basic metrics
47+
formatted_metrics = {}
48+
2349
for key in ['time', 'time_to_first_token', 'input_tokens', 'output_tokens',
24-
'prompt_tokens', 'completion_tokens', 'total_tokens']:
50+
'prompt_tokens', 'completion_tokens', 'total_tokens',
51+
'prompt_tokens_details', 'completion_tokens_details', 'tool_call_times']:
2552
if key in metrics:
2653
formatted_metrics[key] = metrics[key]
27-
28-
# Extract nested metric details if present
29-
if 'prompt_tokens_details' in metrics:
30-
formatted_metrics['prompt_tokens_details'] = metrics['prompt_tokens_details']
31-
if 'completion_tokens_details' in metrics:
32-
formatted_metrics['completion_tokens_details'] = metrics['completion_tokens_details']
33-
if 'tool_call_times' in metrics:
34-
formatted_metrics['tool_call_times'] = metrics['tool_call_times']
35-
54+
3655
return formatted_metrics
3756

57+
3858
def patch_memory(operation_name, version, tracer: Tracer):
3959
def traced_method(wrapped, instance, args, kwargs):
4060
service_provider = SERVICE_PROVIDERS["AGNO"]
@@ -110,86 +130,120 @@ def traced_method(wrapped, instance, args, kwargs):
110130
try:
111131
set_span_attributes(span, attributes)
112132
AgnoSpanAttributes(span=span, instance=instance)
113-
result_generator = wrapped(*args, **kwargs)
114-
115-
accumulated_content = ""
116-
current_tool_call = None
117-
response_metadata = None
118-
seen_tool_calls = set()
119-
120-
try:
121-
for response in result_generator:
122-
if not hasattr(response, 'to_dict'):
123-
yield response
124-
continue
125-
126-
if not response_metadata:
127-
response_metadata = {
128-
"run_id": response.run_id,
129-
"agent_id": response.agent_id,
130-
"session_id": response.session_id,
131-
"model": response.model,
132-
"content_type": response.content_type,
133-
}
134-
for key, value in response_metadata.items():
135-
if value is not None:
136-
set_span_attribute(span, f"agno.agent.{key}", str(value))
137-
138-
if response.content:
139-
accumulated_content += response.content
140-
set_span_attribute(span, "agno.agent.response", accumulated_content)
141-
142-
if response.messages:
143-
for msg in response.messages:
144-
if msg.tool_calls:
145-
for tool_call in msg.tool_calls:
146-
tool_id = tool_call.get('id')
147-
if tool_id and tool_id not in seen_tool_calls:
148-
seen_tool_calls.add(tool_id)
149-
tool_info = {
150-
'id': tool_id,
151-
'name': tool_call.get('function', {}).get('name'),
152-
'arguments': tool_call.get('function', {}).get('arguments'),
153-
'start_time': msg.created_at,
154-
}
155-
current_tool_call = tool_info
156-
set_span_attribute(span, f"agno.agent.tool_call.{tool_id}", json.dumps(tool_info))
157-
158-
if msg.metrics:
159-
metrics = _extract_metrics(msg.metrics)
160-
role_prefix = f"agno.agent.metrics.{msg.role}"
161-
for key, value in metrics.items():
162-
set_span_attribute(span, f"{role_prefix}.{key}", str(value))
163-
164-
if response.tools:
165-
for tool in response.tools:
166-
tool_id = tool.get('tool_call_id')
167-
if tool_id and current_tool_call and current_tool_call['id'] == tool_id:
168-
tool_result = {
169-
**current_tool_call,
170-
'result': tool.get('content'),
171-
'error': tool.get('tool_call_error'),
172-
'end_time': tool.get('created_at'),
173-
'metrics': tool.get('metrics'),
174-
}
175-
set_span_attribute(span, f"agno.agent.tool_call.{tool_id}", json.dumps(tool_result))
176-
current_tool_call = None
177-
178-
yield response
179-
180-
except Exception as err:
181-
span.record_exception(err)
182-
span.set_status(Status(StatusCode.ERROR, str(err)))
183-
raise
184-
finally:
185-
span.set_status(Status(StatusCode.OK))
186-
if len(seen_tool_calls) > 0:
187-
span.set_attribute("agno.agent.total_tool_calls", len(seen_tool_calls))
133+
is_streaming = kwargs.get('stream', False)
134+
result = wrapped(*args, **kwargs)
135+
136+
if not is_streaming and not operation_name.startswith('Agent._'):
137+
if hasattr(result, 'to_dict'):
138+
_process_response(span, result)
139+
return result
140+
141+
# Handle streaming (generator) case
142+
return _process_generator(span, result)
143+
188144
except Exception as err:
189145
span.record_exception(err)
190146
span.set_status(Status(StatusCode.ERROR, str(err)))
191147
raise
192148

149+
# Helper function to process a generator
150+
def _process_generator(span, result_generator):
151+
accumulated_content = ""
152+
current_tool_call = None
153+
response_metadata = None
154+
seen_tool_calls = set()
155+
156+
try:
157+
for response in result_generator:
158+
if not hasattr(response, 'to_dict'):
159+
yield response
160+
continue
161+
162+
_process_response(span, response,
163+
accumulated_content=accumulated_content,
164+
current_tool_call=current_tool_call,
165+
response_metadata=response_metadata,
166+
seen_tool_calls=seen_tool_calls)
167+
168+
if response.content:
169+
accumulated_content += response.content
170+
171+
yield response
172+
173+
except Exception as err:
174+
span.record_exception(err)
175+
span.set_status(Status(StatusCode.ERROR, str(err)))
176+
raise
177+
finally:
178+
span.set_status(Status(StatusCode.OK))
179+
if len(seen_tool_calls) > 0:
180+
span.set_attribute("agno.agent.total_tool_calls", len(seen_tool_calls))
181+
182+
def _process_response(span, response, accumulated_content="", current_tool_call=None,
183+
response_metadata=None, seen_tool_calls=set()):
184+
if not response_metadata:
185+
response_metadata = {
186+
"run_id": response.run_id,
187+
"agent_id": response.agent_id,
188+
"session_id": response.session_id,
189+
"model": response.model,
190+
"content_type": response.content_type,
191+
}
192+
for key, value in response_metadata.items():
193+
if value is not None:
194+
set_span_attribute(span, f"agno.agent.{key}", str(value))
195+
196+
if response.content:
197+
if accumulated_content:
198+
accumulated_content += response.content
199+
else:
200+
accumulated_content = response.content
201+
set_span_attribute(span, "agno.agent.response", accumulated_content)
202+
203+
if response.messages:
204+
for msg in response.messages:
205+
if msg.tool_calls:
206+
for tool_call in msg.tool_calls:
207+
tool_id = tool_call.get('id')
208+
if tool_id and tool_id not in seen_tool_calls:
209+
seen_tool_calls.add(tool_id)
210+
tool_info = {
211+
'id': tool_id,
212+
'name': tool_call.get('function', {}).get('name'),
213+
'arguments': tool_call.get('function', {}).get('arguments'),
214+
'start_time': msg.created_at,
215+
}
216+
current_tool_call = tool_info
217+
set_span_attribute(span, f"agno.agent.tool_call.{tool_id}", _safe_json_dumps(tool_info))
218+
219+
if msg.metrics:
220+
metrics = _extract_metrics(msg.metrics)
221+
role_prefix = f"agno.agent.metrics.{msg.role}"
222+
for key, value in metrics.items():
223+
set_span_attribute(span, f"{role_prefix}.{key}", str(value))
224+
225+
if response.tools:
226+
for tool in response.tools:
227+
tool_id = tool.get('tool_call_id')
228+
if tool_id and current_tool_call and current_tool_call['id'] == tool_id:
229+
tool_result = {
230+
**current_tool_call,
231+
'result': tool.get('content'),
232+
'error': tool.get('tool_call_error'),
233+
'end_time': tool.get('created_at'),
234+
'metrics': tool.get('metrics'),
235+
}
236+
set_span_attribute(span, f"agno.agent.tool_call.{tool_id}", _safe_json_dumps(tool_result))
237+
current_tool_call = None
238+
239+
if response.metrics:
240+
metrics = _extract_metrics(response.metrics)
241+
for key, value in metrics.items():
242+
set_span_attribute(span, f"agno.agent.metrics.{key}", str(value))
243+
244+
if len(seen_tool_calls) > 0:
245+
span.set_attribute("agno.agent.total_tool_calls", len(seen_tool_calls))
246+
193247
return traced_method
194248

195249
class AgnoSpanAttributes:
@@ -238,7 +292,7 @@ def run(self):
238292

239293
if hasattr(self.instance.model, 'metrics') and self.instance.model.metrics:
240294
metrics = _extract_metrics(self.instance.model.metrics)
241-
set_span_attribute(self.span, "agno.agent.model.metrics", json.dumps(metrics))
295+
set_span_attribute(self.span, "agno.agent.model.metrics", _safe_json_dumps(metrics))
242296

243297
if self.instance.tools:
244298
tool_list = []
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.8.1"
1+
__version__ = "3.8.2"

0 commit comments

Comments
 (0)