-
Notifications
You must be signed in to change notification settings - Fork 18.7k
feat(langchain): add stuff and map reduce chains #32333
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
"""Internal document utilities.""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
if TYPE_CHECKING: | ||
from langchain_core.documents import Document | ||
|
||
|
||
def format_document_xml(doc: Document) -> str: | ||
"""Format a document as XML-like structure for LLM consumption. | ||
Args: | ||
doc: Document to format | ||
Returns: | ||
Document wrapped in XML tags: | ||
<document> | ||
<id>...</id> | ||
<content>...</content> | ||
<metadata>...</metadata> | ||
</document> | ||
Note: | ||
Does not generate valid XML or escape special characters. | ||
Intended for semi-structured LLM input only. | ||
""" | ||
id_str = f"<id>{doc.id}</id>" if doc.id is not None else "<id></id>" | ||
metadata_str = "" | ||
if doc.metadata: | ||
metadata_items = [f"{k}: {v!s}" for k, v in doc.metadata.items()] | ||
metadata_str = f"<metadata>{', '.join(metadata_items)}</metadata>" | ||
return ( | ||
f"<document>{id_str}" | ||
f"<content>{doc.page_content}</content>" | ||
f"{metadata_str}" | ||
f"</document>" | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
"""Lazy import utilities.""" | ||
|
||
from importlib import import_module | ||
from typing import Union | ||
|
||
|
||
def import_attr( | ||
attr_name: str, | ||
module_name: Union[str, None], | ||
package: Union[str, None], | ||
) -> object: | ||
"""Import an attribute from a module located in a package. | ||
This utility function is used in custom __getattr__ methods within __init__.py | ||
files to dynamically import attributes. | ||
Args: | ||
attr_name: The name of the attribute to import. | ||
module_name: The name of the module to import from. If None, the attribute | ||
is imported from the package itself. | ||
package: The name of the package where the module is located. | ||
""" | ||
if module_name == "__module__" or module_name is None: | ||
try: | ||
result = import_module(f".{attr_name}", package=package) | ||
except ModuleNotFoundError: | ||
msg = f"module '{package!r}' has no attribute {attr_name!r}" | ||
raise AttributeError(msg) from None | ||
else: | ||
try: | ||
module = import_module(f".{module_name}", package=package) | ||
except ModuleNotFoundError as err: | ||
msg = f"module '{package!r}.{module_name!r}' not found ({err})" | ||
raise ImportError(msg) from None | ||
result = getattr(module, attr_name) | ||
return result |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
"""Internal prompt resolution utilities. | ||
This module provides utilities for resolving different types of prompt specifications | ||
into standardized message formats for language models. It supports both synchronous | ||
and asynchronous prompt resolution with automatic detection of callable types. | ||
The module is designed to handle common prompt patterns across LangChain components, | ||
particularly for summarization chains and other document processing workflows. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
import inspect | ||
from typing import TYPE_CHECKING, Callable, Union | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Awaitable | ||
|
||
from langchain_core.messages import MessageLikeRepresentation | ||
from langgraph.runtime import Runtime | ||
|
||
from langchain._internal._typing import ContextT, StateT | ||
|
||
|
||
def resolve_prompt( | ||
prompt: Union[ | ||
str, | ||
None, | ||
Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]], | ||
], | ||
state: StateT, | ||
runtime: Runtime[ContextT], | ||
default_user_content: str, | ||
default_system_content: str, | ||
) -> list[MessageLikeRepresentation]: | ||
"""Resolve a prompt specification into a list of messages. | ||
Handles prompt resolution across different strategies. Supports callable functions, | ||
string system messages, and None for default behavior. | ||
Args: | ||
prompt: The prompt specification to resolve. Can be: | ||
- Callable: Function taking (state, runtime) returning message list. | ||
- str: A system message string. | ||
- None: Use the provided default system message. | ||
state: Current state, passed to callable prompts. | ||
runtime: LangGraph runtime instance, passed to callable prompts. | ||
default_user_content: User content to include (e.g., document text). | ||
default_system_content: Default system message when prompt is None. | ||
Returns: | ||
List of message dictionaries for language models, typically containing | ||
a system message and user message with content. | ||
Raises: | ||
TypeError: If prompt type is not str, None, or callable. | ||
Example: | ||
```python | ||
def custom_prompt(state, runtime): | ||
return [{"role": "system", "content": "Custom"}] | ||
messages = resolve_prompt(custom_prompt, state, runtime, "content", "default") | ||
messages = resolve_prompt("Custom system", state, runtime, "content", "default") | ||
messages = resolve_prompt(None, state, runtime, "content", "Default") | ||
``` | ||
Note: | ||
Callable prompts have full control over message structure and content | ||
parameter is ignored. String/None prompts create standard system + user | ||
structure. | ||
""" | ||
if callable(prompt): | ||
return prompt(state, runtime) | ||
if isinstance(prompt, str): | ||
system_msg = prompt | ||
elif prompt is None: | ||
system_msg = default_system_content | ||
else: | ||
msg = f"Invalid prompt type: {type(prompt)}. Expected str, None, or callable." | ||
raise TypeError(msg) | ||
|
||
return [ | ||
{"role": "system", "content": system_msg}, | ||
{"role": "user", "content": default_user_content}, | ||
] | ||
|
||
|
||
async def aresolve_prompt( | ||
prompt: Union[ | ||
str, | ||
None, | ||
Callable[[StateT, Runtime[ContextT]], list[MessageLikeRepresentation]], | ||
Callable[ | ||
[StateT, Runtime[ContextT]], Awaitable[list[MessageLikeRepresentation]] | ||
], | ||
], | ||
state: StateT, | ||
runtime: Runtime[ContextT], | ||
default_user_content: str, | ||
default_system_content: str, | ||
) -> list[MessageLikeRepresentation]: | ||
"""Async version of resolve_prompt supporting both sync and async callables. | ||
Handles prompt resolution across different strategies. Supports sync/async callable | ||
functions, string system messages, and None for default behavior. | ||
Args: | ||
prompt: The prompt specification to resolve. Can be: | ||
- Callable (sync): Function taking (state, runtime) returning message list. | ||
- Callable (async): Async function taking (state, runtime) returning | ||
awaitable message list. | ||
- str: A system message string. | ||
- None: Use the provided default system message. | ||
state: Current state, passed to callable prompts. | ||
runtime: LangGraph runtime instance, passed to callable prompts. | ||
default_user_content: User content to include (e.g., document text). | ||
default_system_content: Default system message when prompt is None. | ||
Returns: | ||
List of message dictionaries for language models, typically containing | ||
a system message and user message with content. | ||
Raises: | ||
TypeError: If prompt type is not str, None, or callable. | ||
Example: | ||
```python | ||
async def async_prompt(state, runtime): | ||
return [{"role": "system", "content": "Async"}] | ||
def sync_prompt(state, runtime): | ||
return [{"role": "system", "content": "Sync"}] | ||
messages = await aresolve_prompt( | ||
async_prompt, state, runtime, "content", "default" | ||
) | ||
messages = await aresolve_prompt( | ||
sync_prompt, state, runtime, "content", "default" | ||
) | ||
messages = await aresolve_prompt("Custom", state, runtime, "content", "default") | ||
``` | ||
Note: | ||
Callable prompts have full control over message structure and content | ||
parameter is ignored. Automatically detects and handles async | ||
callables. | ||
""" | ||
if callable(prompt): | ||
result = prompt(state, runtime) | ||
# Check if the result is awaitable (async function) | ||
if inspect.isawaitable(result): | ||
return await result | ||
return result | ||
if isinstance(prompt, str): | ||
system_msg = prompt | ||
elif prompt is None: | ||
system_msg = default_system_content | ||
else: | ||
msg = f"Invalid prompt type: {type(prompt)}. Expected str, None, or callable." | ||
raise TypeError(msg) | ||
|
||
return [ | ||
{"role": "system", "content": system_msg}, | ||
{"role": "user", "content": default_user_content}, | ||
] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
"""Private typing utilities for langchain.""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeVar, Union | ||
|
||
from langgraph.graph._node import StateNode | ||
from pydantic import BaseModel | ||
from typing_extensions import TypeAlias | ||
|
||
if TYPE_CHECKING: | ||
from dataclasses import Field | ||
|
||
|
||
class TypedDictLikeV1(Protocol): | ||
"""Protocol to represent types that behave like TypedDicts. | ||
Version 1: using `ClassVar` for keys. | ||
""" | ||
|
||
__required_keys__: ClassVar[frozenset[str]] | ||
__optional_keys__: ClassVar[frozenset[str]] | ||
|
||
|
||
class TypedDictLikeV2(Protocol): | ||
"""Protocol to represent types that behave like TypedDicts. | ||
Version 2: not using `ClassVar` for keys. | ||
""" | ||
|
||
__required_keys__: frozenset[str] | ||
__optional_keys__: frozenset[str] | ||
|
||
|
||
class DataclassLike(Protocol): | ||
"""Protocol to represent types that behave like dataclasses. | ||
Inspired by the private _DataclassT from dataclasses that uses a similar | ||
protocol as a bound. | ||
""" | ||
|
||
__dataclass_fields__: ClassVar[dict[str, Field[Any]]] | ||
|
||
|
||
StateLike: TypeAlias = Union[TypedDictLikeV1, TypedDictLikeV2, DataclassLike, BaseModel] | ||
"""Type alias for state-like types. | ||
It can either be a `TypedDict`, `dataclass`, or Pydantic `BaseModel`. | ||
Note: we cannot use either `TypedDict` or `dataclass` directly due to limitations in | ||
type checking. | ||
""" | ||
|
||
StateT = TypeVar("StateT", bound=StateLike) | ||
"""Type variable used to represent the state in a graph.""" | ||
|
||
ContextT = TypeVar("ContextT", bound=Union[StateLike, None]) | ||
"""Type variable for context types.""" | ||
|
||
|
||
__all__ = [ | ||
"ContextT", | ||
"StateLike", | ||
"StateNode", | ||
"StateT", | ||
] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Re-exporting internal utilities from LangGraph for internal use in LangChain. | ||
# A different wrapper needs to be created for this purpose in LangChain. | ||
from langgraph._internal._runnable import RunnableCallable | ||
|
||
__all__ = [ | ||
"RunnableCallable", | ||
] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from langchain.chains.documents import ( | ||
create_map_reduce_chain, | ||
create_stuff_documents_chain, | ||
) | ||
|
||
__all__ = [ | ||
"create_map_reduce_chain", | ||
"create_stuff_documents_chain", | ||
] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
"""Document extraction chains. | ||
This module provides different strategies for extracting information from collections | ||
of documents using LangGraph and modern language models. | ||
Available Strategies: | ||
- Stuff: Processes all documents together in a single context window | ||
- Map-Reduce: Processes documents in parallel (map), then combines results (reduce) | ||
""" | ||
|
||
from langchain.chains.documents.map_reduce import create_map_reduce_chain | ||
from langchain.chains.documents.stuff import create_stuff_documents_chain | ||
|
||
__all__ = [ | ||
"create_map_reduce_chain", | ||
"create_stuff_documents_chain", | ||
] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can make these public in langgraph if it'd help you out 🫡, easy to fix later though bc internal