
Commit 2965fc8

Adding a ThinkingBlock among content blocks (#19919)
* wip: add thinking block
* ci: lint
* feat: add/modify tests
* feat: add thinking block handling also as input
* fix: fixes for Gemini message merging; ci: tests tests tests
* chore: add support for ThinkingBlock also in OpenAI streaming; fix: DevX improvement in Anthropic by automatically setting temperature to 1 if thinking detected
* fix: add correct support for thinking in OpenAIResponses streaming
* fix: fix test and streaming
* chore: vbump (minor)
1 parent 05fb789 commit 2965fc8
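
In user terms, the headline change is that reasoning output which integrations previously stashed in ad-hoc additional_kwargs keys now arrives as a typed ThinkingBlock among the message's content blocks. A minimal helper sketch against llama-index-core at this commit (thinking_text is a name invented here for illustration, not part of the library):

from llama_index.core.base.llms.types import ChatMessage, ThinkingBlock

def thinking_text(message: ChatMessage) -> str:
    """Concatenate whatever reasoning content the LLM attached to a message."""
    return "".join(
        block.content or ""
        for block in message.blocks
        if isinstance(block, ThinkingBlock)
    )

Because ThinkingBlock is just another member of the ContentBlock union, the same helper works across the Anthropic, Gemini, and OpenAI integrations touched by this commit.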

File tree: 21 files changed (+540, -96 lines)


llama-index-core/llama_index/core/base/llms/types.py

Lines changed: 19 additions & 0 deletions
@@ -425,6 +425,24 @@ def validate_cited_content(cls, v: Any) -> Any:
         return v


+class ThinkingBlock(BaseModel):
+    """A representation of the content streamed from reasoning/thinking processes by LLMs"""
+
+    block_type: Literal["thinking"] = "thinking"
+    content: Optional[str] = Field(
+        description="Content of the reasoning/thinking process, if available",
+        default=None,
+    )
+    num_tokens: Optional[int] = Field(
+        description="Number of tokens used for reasoning/thinking, if available",
+        default=None,
+    )
+    additional_information: Dict[str, Any] = Field(
+        description="Additional information related to the thinking/reasoning process, if available",
+        default_factory=dict,
+    )
+
+
 ContentBlock = Annotated[
     Union[
         TextBlock,
@@ -435,6 +453,7 @@ def validate_cited_content(cls, v: Any) -> Any:
         CachePoint,
         CitableBlock,
         CitationBlock,
+        ThinkingBlock,
     ],
     Field(discriminator="block_type"),
 ]
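
Since block_type is the pydantic discriminator on the ContentBlock union, a serialized message round-trips back to the right block class. A short sketch, under the assumption that ChatMessage exposes the usual pydantic v2 API (the example values are made up):

from llama_index.core.base.llms.types import ChatMessage, TextBlock, ThinkingBlock

msg = ChatMessage(
    role="assistant",
    blocks=[
        ThinkingBlock(content="Factor the quadratic first...", num_tokens=42),
        TextBlock(text="x = -6 +/- sqrt(29)"),
    ],
)

# "block_type": "thinking" in the dump is what lets pydantic pick
# ThinkingBlock back out of the ContentBlock union on validation.
restored = ChatMessage.model_validate(msg.model_dump())
assert isinstance(restored.blocks[0], ThinkingBlock)
assert restored.blocks[0].num_tokens == 42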

llama-index-core/llama_index/core/memory/memory.py

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@
     CachePoint,
     CitableBlock,
     CitationBlock,
+    ThinkingBlock,
 )
 from llama_index.core.bridge.pydantic import (
     BaseModel,
@@ -343,6 +344,7 @@ def _estimate_token_count(
                 DocumentBlock,
                 CitableBlock,
                 CitationBlock,
+                ThinkingBlock,
             ]
         ] = []

llama-index-core/tests/base/llms/test_types.py

Lines changed: 18 additions & 0 deletions
@@ -17,6 +17,7 @@
     AudioBlock,
     CachePoint,
     CacheControl,
+    ThinkingBlock,
 )
 from llama_index.core.bridge.pydantic import BaseModel
 from llama_index.core.bridge.pydantic import ValidationError
@@ -455,3 +456,20 @@ def test_video_block_store_as_base64(mp4_bytes: bytes, mp4_base64: bytes):
     assert VideoBlock(video=mp4_bytes).video == mp4_base64
     # Store already encoded data
     assert VideoBlock(video=mp4_base64).video == mp4_base64
+
+
+def test_thinking_block():
+    block = ThinkingBlock()
+    assert block.block_type == "thinking"
+    assert block.additional_information == {}
+    assert block.content is None
+    assert block.num_tokens is None
+    block = ThinkingBlock(
+        content="hello world",
+        num_tokens=100,
+        additional_information={"total_thinking_tokens": 1000},
+    )
+    assert block.block_type == "thinking"
+    assert block.additional_information == {"total_thinking_tokens": 1000}
+    assert block.content == "hello world"
+    assert block.num_tokens == 100

llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/base.py

Lines changed: 18 additions & 12 deletions
@@ -25,6 +25,7 @@
 )
 from llama_index.core.base.llms.types import TextBlock as LITextBlock
 from llama_index.core.base.llms.types import CitationBlock as LICitationBlock
+from llama_index.core.base.llms.types import ThinkingBlock as LIThinkingBlock
 from llama_index.core.bridge.pydantic import Field, PrivateAttr
 from llama_index.core.callbacks import CallbackManager
 from llama_index.core.constants import DEFAULT_TEMPERATURE
@@ -204,6 +205,9 @@ def __init__(
     ) -> None:
         additional_kwargs = additional_kwargs or {}
         callback_manager = callback_manager or CallbackManager([])
+        # set the temperature to 1 when thinking is enabled, as per: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
+        if thinking_dict and thinking_dict.get("type") == "enabled":
+            temperature = 1

         super().__init__(
             temperature=temperature,
@@ -340,11 +344,8 @@ def _completion_response_from_chat_response(

     def _get_blocks_and_tool_calls_and_thinking(
         self, response: Any
-    ) -> Tuple[
-        List[ContentBlock], List[Dict[str, Any]], Dict[str, Any], List[Dict[str, Any]]
-    ]:
+    ) -> Tuple[List[ContentBlock], List[Dict[str, Any]], List[Dict[str, Any]]]:
         tool_calls = []
-        thinking = None
         blocks: List[ContentBlock] = []
         citations: List[TextCitation] = []
         tracked_citations: Set[str] = set()
@@ -375,11 +376,18 @@ def _get_blocks_and_tool_calls_and_thinking(
                 citations.extend(content_block.citations)
             # this assumes a single thinking block, which as of 2025-03-06, is always true
             elif isinstance(content_block, ThinkingBlock):
-                thinking = content_block.model_dump()
+                blocks.append(
+                    LIThinkingBlock(
+                        content=content_block.thinking,
+                        additional_information=content_block.model_dump(
+                            exclude={"thinking"}
+                        ),
+                    )
+                )
             elif isinstance(content_block, ToolUseBlock):
                 tool_calls.append(content_block.model_dump())

-        return blocks, tool_calls, thinking, [x.model_dump() for x in citations]
+        return blocks, tool_calls, [x.model_dump() for x in citations]

     @llm_chat_callback()
     def chat(
@@ -397,8 +405,8 @@ def chat(
             **all_kwargs,
         )

-        blocks, tool_calls, thinking, citations = (
-            self._get_blocks_and_tool_calls_and_thinking(response)
+        blocks, tool_calls, citations = self._get_blocks_and_tool_calls_and_thinking(
+            response
         )

         return AnthropicChatResponse(
@@ -407,7 +415,6 @@
                 blocks=blocks,
                 additional_kwargs={
                     "tool_calls": tool_calls,
-                    "thinking": thinking,
                 },
             ),
             citations=citations,
@@ -570,8 +577,8 @@ async def achat(
             **all_kwargs,
         )

-        blocks, tool_calls, thinking, citations = (
-            self._get_blocks_and_tool_calls_and_thinking(response)
+        blocks, tool_calls, citations = self._get_blocks_and_tool_calls_and_thinking(
+            response
         )

         return AnthropicChatResponse(
@@ -580,7 +587,6 @@
                 blocks=blocks,
                 additional_kwargs={
                     "tool_calls": tool_calls,
-                    "thinking": thinking,
                 },
             ),
             citations=citations,
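
Two effects of this file are worth noting for callers, sketched below (the model name and token budget are copied from this commit's tests, and running it needs ANTHROPIC_API_KEY). First, the constructor now pins temperature to 1 whenever thinking is enabled, silently overriding whatever was passed. Second, thinking no longer appears under additional_kwargs["thinking"]; read ThinkingBlock entries from message.blocks instead:

from llama_index.core.base.llms.types import ChatMessage, ThinkingBlock
from llama_index.llms.anthropic import Anthropic

llm = Anthropic(
    model="claude-sonnet-4-0",
    max_tokens=64000,
    temperature=0.2,  # overridden: extended thinking requires temperature=1
    thinking_dict={"type": "enabled", "budget_tokens": 1600},
)
assert llm.temperature == 1  # set automatically by the new __init__ logic

resp = llm.chat(messages=[ChatMessage(content="Solve x^2 + 12x + 7 = 0.")])

# Before this commit: resp.message.additional_kwargs["thinking"]
# After this commit:
thinking_blocks = [
    b for b in resp.message.blocks if isinstance(b, ThinkingBlock)
]

The raw Anthropic fields (for example the signature) are preserved in each block's additional_information.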

llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/utils.py

Lines changed: 13 additions & 4 deletions
@@ -14,6 +14,7 @@
     CachePoint,
     CitableBlock,
     CitationBlock,
+    ThinkingBlock,
     ContentBlock,
 )

@@ -26,7 +27,7 @@
     CacheControlEphemeralParam,
     Base64PDFSourceParam,
 )
-from anthropic.types import ContentBlock as AnthropicContentBlock
+from anthropic.types import ContentBlockParam as AnthropicContentBlock
 from anthropic.types.beta import (
     BetaSearchResultBlockParam,
     BetaTextBlockParam,
@@ -201,9 +202,6 @@ def blocks_to_anthropic_blocks(
     if kwargs.get("cache_control"):
         global_cache_control = CacheControlEphemeralParam(**kwargs["cache_control"])

-    if kwargs.get("thinking"):
-        anthropic_blocks.append(ThinkingBlockParam(**kwargs["thinking"]))
-
     for block in blocks:
         if isinstance(block, TextBlock):
             if block.text:
@@ -251,6 +249,17 @@ def blocks_to_anthropic_blocks(
             if global_cache_control:
                 anthropic_blocks[-1]["cache_control"] = global_cache_control

+        elif isinstance(block, ThinkingBlock):
+            if block.content:
+                signature = block.additional_information.get("signature", "")
+                anthropic_blocks.append(
+                    ThinkingBlockParam(
+                        signature=signature, thinking=block.content, type="thinking"
+                    )
+                )
+                if global_cache_control:
+                    anthropic_blocks[-1]["cache_control"] = global_cache_control
+
         elif isinstance(block, CachePoint):
             if len(anthropic_blocks) > 0:
                 anthropic_blocks[-1]["cache_control"] = CacheControlEphemeralParam(
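
The input side mirrors the output side: when an assistant message containing a ThinkingBlock is sent back to Anthropic, blocks_to_anthropic_blocks rebuilds a ThinkingBlockParam, pulling the signature out of additional_information. A sketch modeled on this commit's test_thinking_input (the import path is taken from this file and the signature value is made up):

from llama_index.core.base.llms.types import ChatMessage, TextBlock, ThinkingBlock
from llama_index.llms.anthropic.utils import messages_to_anthropic_messages

msg = ChatMessage(
    role="assistant",
    blocks=[
        ThinkingBlock(
            content="Hello",
            # a signature captured from an earlier Anthropic response
            additional_information={"signature": "sig-abc123"},
        ),
        TextBlock(text="World"),
    ],
)
ant_messages, _ = messages_to_anthropic_messages([msg])

assert ant_messages[0]["content"][0]["type"] == "thinking"
assert ant_messages[0]["content"][0]["thinking"] == "Hello"
assert ant_messages[0]["content"][0]["signature"] == "sig-abc123"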

llama-index-integrations/llms/llama-index-llms-anthropic/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ dev = [

 [project]
 name = "llama-index-llms-anthropic"
-version = "0.8.6"
+version = "0.9.0"
 description = "llama-index llms anthropic integration"
 authors = [{name = "Your Name", email = "[email protected]"}]
 requires-python = ">=3.9,<4.0"

llama-index-integrations/llms/llama-index-llms-anthropic/tests/test_llms_anthropic.py

Lines changed: 54 additions & 0 deletions
@@ -17,6 +17,7 @@
     CachePoint,
     CacheControl,
 )
+from llama_index.core.base.llms.types import ThinkingBlock
 from llama_index.core.tools import FunctionTool
 from llama_index.llms.anthropic import Anthropic
 from llama_index.llms.anthropic.base import AnthropicChatResponse
@@ -384,6 +385,59 @@ def test_cache_point_to_cache_control() -> None:
     assert ant_messages[0]["content"][-1]["cache_control"]["ttl"] == "5m"


+def test_thinking_input():
+    messages = [
+        ChatMessage(
+            role="assistant",
+            blocks=[
+                ThinkingBlock(content="Hello"),
+                TextBlock(text="World"),
+            ],
+        ),
+    ]
+    ant_messages, _ = messages_to_anthropic_messages(messages)
+    assert ant_messages[0]["role"] == "assistant"
+    assert ant_messages[0]["content"][0]["type"] == "thinking"
+    assert ant_messages[0]["content"][0]["thinking"] == "Hello"
+    assert ant_messages[0]["content"][1]["type"] == "text"
+    assert ant_messages[0]["content"][1]["text"] == "World"
+
+
+@pytest.mark.skipif(
+    os.getenv("ANTHROPIC_API_KEY") is None,
+    reason="Anthropic API key not available to test Anthropic thinking",
+)
+def test_thinking():
+    llm = Anthropic(
+        model="claude-sonnet-4-0",
+        # max_tokens must be greater than budget_tokens
+        max_tokens=64000,
+        # temperature must be 1.0 for thinking to work
+        temperature=1.0,
+        thinking_dict={"type": "enabled", "budget_tokens": 1600},
+    )
+    res = llm.chat(
+        messages=[
+            ChatMessage(
+                content="Please solve the following equation for x: x^2+12x+7=0. Please think before providing a response."
+            )
+        ]
+    )
+    assert any(isinstance(block, ThinkingBlock) for block in res.message.blocks)
+    assert (
+        len(
+            "".join(
+                [
+                    block.content or ""
+                    for block in res.message.blocks
+                    if isinstance(block, ThinkingBlock)
+                ]
+            )
+        )
+        > 0
+    )
+
+
 @pytest.mark.skipif(
     os.getenv("ANTHROPIC_API_KEY") is None,
     reason="Anthropic API key not available to test Anthropic document uploading ",

llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/base.py

Lines changed: 5 additions & 2 deletions
@@ -36,6 +36,7 @@
     CompletionResponseGen,
     LLMMetadata,
     MessageRole,
+    ThinkingBlock,
 )
 from llama_index.core.bridge.pydantic import BaseModel, Field, PrivateAttr
 from llama_index.core.callbacks import CallbackManager
@@ -393,7 +394,7 @@ def gen() -> ChatResponseGen:
                 )
                 llama_resp.delta = content_delta
                 llama_resp.message.content = content
-                llama_resp.message.additional_kwargs["thoughts"] = thoughts
+                llama_resp.message.blocks.append(ThinkingBlock(content=thoughts))
                 llama_resp.message.additional_kwargs["tool_calls"] = existing_tool_calls
                 yield llama_resp

@@ -453,7 +454,9 @@ async def gen() -> ChatResponseAsyncGen:
                 )
                 llama_resp.delta = content_delta
                 llama_resp.message.content = content
-                llama_resp.message.additional_kwargs["thoughts"] = thoughts
+                llama_resp.message.blocks.append(
+                    ThinkingBlock(content=thoughts)
+                )
                 llama_resp.message.additional_kwargs["tool_calls"] = (
                     existing_tool_calls
                 )
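
For consumers of the Gemini integration, streamed thought summaries now land in the message's blocks rather than in additional_kwargs["thoughts"]. A consumption sketch, assuming a thinking-capable model (the model name is illustrative, and whether thoughts appear depends on the model's thinking configuration):

from llama_index.core.base.llms.types import ChatMessage, ThinkingBlock
from llama_index.llms.google_genai import GoogleGenAI

llm = GoogleGenAI(model="gemini-2.5-flash")  # illustrative model name
resp = None
for resp in llm.stream_chat([ChatMessage(content="Why is the sky blue?")]):
    pass  # each yielded chunk carries the message as merged so far

# Before this commit: resp.message.additional_kwargs["thoughts"]
# After this commit:
thoughts = [
    b.content for b in resp.message.blocks if isinstance(b, ThinkingBlock)
]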
