Skip to content

Commit 8f49dd0

Browse files
fix streaming thinking/tool calling with anthropic (#20077)
1 parent 4e5839c commit 8f49dd0

File tree

2 files changed

+82
-45
lines changed

2 files changed

+82
-45
lines changed

llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py

Lines changed: 81 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@
5252
ContentBlockStopEvent,
5353
CitationsSearchResultLocation,
5454
InputJSONDelta,
55+
RawContentBlockDeltaEvent,
56+
RawContentBlockStartEvent,
57+
RawContentBlockStopEvent,
5558
TextBlock,
5659
TextDelta,
5760
ThinkingBlock,
@@ -443,7 +446,8 @@ def stream_chat(
443446
)
444447

445448
def gen() -> Generator[AnthropicChatResponse, None, None]:
446-
content = [LITextBlock(text="")]
449+
content = []
450+
cur_block = None
447451
content_delta = ""
448452
thinking = None
449453
cur_tool_calls: List[ToolUseBlock] = []
@@ -453,13 +457,13 @@ def gen() -> Generator[AnthropicChatResponse, None, None]:
453457
tracked_citations: Set[str] = set()
454458
role = MessageRole.ASSISTANT
455459
for r in response:
456-
if isinstance(r, ContentBlockDeltaEvent):
460+
if isinstance(r, (ContentBlockDeltaEvent, RawContentBlockDeltaEvent)):
457461
if isinstance(r.delta, TextDelta):
458462
content_delta = r.delta.text or ""
459-
if not isinstance(content[-1], LITextBlock):
460-
content.append(LITextBlock(text=content_delta))
463+
if not isinstance(cur_block, LITextBlock):
464+
cur_block = LITextBlock(text=content_delta)
461465
else:
462-
content[-1].text += content_delta
466+
cur_block.text += content_delta
463467

464468
elif isinstance(r.delta, CitationsDelta) and isinstance(
465469
r.delta.citation, CitationsSearchResultLocation
@@ -480,23 +484,23 @@ def gen() -> Generator[AnthropicChatResponse, None, None]:
480484
)
481485
)
482486
elif isinstance(r.delta, SignatureDelta):
483-
if thinking is None:
484-
thinking = ThinkingBlock(
485-
signature=r.delta.signature,
486-
thinking="",
487-
type="thinking",
487+
if not isinstance(cur_block, LIThinkingBlock):
488+
cur_block = LIThinkingBlock(
489+
content="",
490+
additional_information={"signature": r.delta.signature},
488491
)
489492
else:
490-
thinking.signature += r.delta.signature
493+
cur_block.additional_information["signature"] += (
494+
r.delta.signature
495+
)
491496
elif isinstance(r.delta, ThinkingDelta):
492-
if thinking is None:
493-
thinking = ThinkingBlock(
494-
signature="",
495-
thinking=r.delta.thinking,
496-
type="thinking",
497+
if cur_block is None:
498+
cur_block = LIThinkingBlock(
499+
content=r.delta.thinking or "",
500+
additional_information={"signature": ""},
497501
)
498502
else:
499-
thinking.thinking += r.delta.thinking
503+
cur_block.content += r.delta.thinking
500504
elif isinstance(r.delta, CitationsDelta):
501505
# TODO: handle citation deltas
502506
cur_citations.append(r.delta.citation.model_dump())
@@ -523,29 +527,44 @@ def gen() -> Generator[AnthropicChatResponse, None, None]:
523527
tool_calls_to_send = [*cur_tool_calls, cur_tool_call]
524528
else:
525529
tool_calls_to_send = cur_tool_calls
530+
526531
yield AnthropicChatResponse(
527532
message=ChatMessage(
528533
role=role,
529534
blocks=content,
530535
additional_kwargs={
531-
"tool_calls": [
532-
t.model_dump() for t in tool_calls_to_send
533-
],
534-
"thinking": thinking.model_dump() if thinking else None,
536+
"tool_calls": [t.dict() for t in tool_calls_to_send]
535537
},
536538
),
537539
citations=cur_citations,
538540
delta=content_delta,
539541
raw=dict(r),
540542
)
541-
elif isinstance(r, ContentBlockStartEvent):
543+
elif isinstance(r, (ContentBlockStartEvent, RawContentBlockStartEvent)):
542544
if isinstance(r.content_block, ToolUseBlock):
543545
cur_tool_call = r.content_block
544546
cur_tool_json = ""
545-
elif isinstance(r, ContentBlockStopEvent):
547+
elif isinstance(r, (ContentBlockStopEvent, RawContentBlockStopEvent)):
546548
if isinstance(cur_tool_call, ToolUseBlock):
547549
cur_tool_calls.append(cur_tool_call)
548550

551+
if cur_block is not None:
552+
content.append(cur_block)
553+
cur_block = None
554+
555+
yield AnthropicChatResponse(
556+
message=ChatMessage(
557+
role=role,
558+
blocks=content,
559+
additional_kwargs={
560+
"tool_calls": [t.dict() for t in tool_calls_to_send]
561+
},
562+
),
563+
citations=cur_citations,
564+
delta=content_delta,
565+
raw=dict(r),
566+
)
567+
549568
return gen()
550569

551570
@llm_completion_callback()
@@ -615,7 +634,8 @@ async def astream_chat(
615634
)
616635

617636
async def gen() -> ChatResponseAsyncGen:
618-
content = [LITextBlock(text="")]
637+
content = []
638+
cur_block = None
619639
content_delta = ""
620640
thinking = None
621641
cur_tool_calls: List[ToolUseBlock] = []
@@ -625,13 +645,13 @@ async def gen() -> ChatResponseAsyncGen:
625645
tracked_citations: Set[str] = set()
626646
role = MessageRole.ASSISTANT
627647
async for r in response:
628-
if isinstance(r, ContentBlockDeltaEvent):
648+
if isinstance(r, (ContentBlockDeltaEvent, RawContentBlockDeltaEvent)):
629649
if isinstance(r.delta, TextDelta):
630650
content_delta = r.delta.text or ""
631-
if not isinstance(content[-1], LITextBlock):
632-
content.append(LITextBlock(text=content_delta))
651+
if not isinstance(cur_block, LITextBlock):
652+
cur_block = LITextBlock(text=content_delta)
633653
else:
634-
content[-1].text += content_delta
654+
cur_block.text += content_delta
635655

636656
elif isinstance(r.delta, CitationsDelta) and isinstance(
637657
r.delta.citation, CitationsSearchResultLocation
@@ -652,23 +672,23 @@ async def gen() -> ChatResponseAsyncGen:
652672
)
653673
)
654674
elif isinstance(r.delta, SignatureDelta):
655-
if thinking is None:
656-
thinking = ThinkingBlock(
657-
signature=r.delta.signature,
658-
thinking="",
659-
type="thinking",
675+
if not isinstance(cur_block, LIThinkingBlock):
676+
cur_block = LIThinkingBlock(
677+
content="",
678+
additional_information={"signature": r.delta.signature},
660679
)
661680
else:
662-
thinking.signature += r.delta.signature
681+
cur_block.additional_information["signature"] += (
682+
r.delta.signature
683+
)
663684
elif isinstance(r.delta, ThinkingDelta):
664-
if thinking is None:
665-
thinking = ThinkingBlock(
666-
signature="",
667-
thinking=r.delta.thinking,
668-
type="thinking",
685+
if cur_block is None:
686+
cur_block = LIThinkingBlock(
687+
content=r.delta.thinking or "",
688+
additional_information={"signature": ""},
669689
)
670690
else:
671-
thinking.thinking += r.delta.thinking
691+
cur_block.content += r.delta.thinking
672692
elif isinstance(r.delta, CitationsDelta):
673693
# TODO: handle citation deltas
674694
cur_citations.append(r.delta.citation.model_dump())
@@ -695,27 +715,44 @@ async def gen() -> ChatResponseAsyncGen:
695715
tool_calls_to_send = [*cur_tool_calls, cur_tool_call]
696716
else:
697717
tool_calls_to_send = cur_tool_calls
718+
698719
yield AnthropicChatResponse(
699720
message=ChatMessage(
700721
role=role,
701722
blocks=content,
702723
additional_kwargs={
703-
"tool_calls": [t.dict() for t in tool_calls_to_send],
704-
"thinking": thinking.model_dump() if thinking else None,
724+
"tool_calls": [t.dict() for t in tool_calls_to_send]
705725
},
706726
),
707727
citations=cur_citations,
708728
delta=content_delta,
709729
raw=dict(r),
710730
)
711-
elif isinstance(r, ContentBlockStartEvent):
731+
elif isinstance(r, (ContentBlockStartEvent, RawContentBlockStartEvent)):
712732
if isinstance(r.content_block, ToolUseBlock):
713733
cur_tool_call = r.content_block
714734
cur_tool_json = ""
715-
elif isinstance(r, ContentBlockStopEvent):
735+
elif isinstance(r, (ContentBlockStopEvent, RawContentBlockStopEvent)):
716736
if isinstance(cur_tool_call, ToolUseBlock):
717737
cur_tool_calls.append(cur_tool_call)
718738

739+
if cur_block is not None:
740+
content.append(cur_block)
741+
cur_block = None
742+
743+
yield AnthropicChatResponse(
744+
message=ChatMessage(
745+
role=role,
746+
blocks=content,
747+
additional_kwargs={
748+
"tool_calls": [t.dict() for t in tool_calls_to_send]
749+
},
750+
),
751+
citations=cur_citations,
752+
delta=content_delta,
753+
raw=dict(r),
754+
)
755+
719756
return gen()
720757

721758
@llm_completion_callback()

llama-index-integrations/llms/llama-index-llms-anthropic/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ dev = [
2727

2828
[project]
2929
name = "llama-index-llms-anthropic"
30-
version = "0.9.3"
30+
version = "0.9.4"
3131
description = "llama-index llms anthropic integration"
3232
authors = [{name = "Your Name", email = "you@example.com"}]
3333
requires-python = ">=3.9,<4.0"

0 commit comments

Comments
 (0)