-
Notifications
You must be signed in to change notification settings - Fork 2.5k
fix: prompt-builder - jinja2 template set vars still shows required #9932
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
136b7e0
f58094e
492073e
c78de27
b102f58
de715ff
1f24f61
3db96ab
6a5705e
d3f4b80
54f22c4
cb2ca2d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -12,7 +12,7 @@ | |||||
| from haystack import component, default_from_dict, default_to_dict, logging | ||||||
| from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent | ||||||
| from haystack.lazy_imports import LazyImport | ||||||
| from haystack.utils import Jinja2TimeExtension | ||||||
| from haystack.utils import Jinja2TimeExtension, extract_declared_variables | ||||||
| from haystack.utils.jinja2_chat_extension import ChatMessageExtension, templatize_part | ||||||
|
|
||||||
| logger = logging.getLogger(__name__) | ||||||
|
|
@@ -171,21 +171,30 @@ def __init__( | |||||
|
|
||||||
| extracted_variables = [] | ||||||
| if template and not variables: | ||||||
|
|
||||||
| def _extract_from_text( | ||||||
| text: Optional[str], role: Optional[str] = None, is_filter_allowed: bool = False | ||||||
| ) -> list: | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| if text is None: | ||||||
| raise ValueError(NO_TEXT_ERROR_MESSAGE.format(role=role or "unknown", message=text)) | ||||||
| if is_filter_allowed and "templatize_part" in text: | ||||||
| raise ValueError(FILTER_NOT_ALLOWED_ERROR_MESSAGE) | ||||||
|
|
||||||
| ast = self._env.parse(text) | ||||||
| template_variables = meta.find_undeclared_variables(ast) | ||||||
| assigned_variables = extract_declared_variables(text, env=self._env) | ||||||
| return list(template_variables - assigned_variables) | ||||||
|
|
||||||
| if isinstance(template, list): | ||||||
| for message in template: | ||||||
| if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM): | ||||||
| # infer variables from template | ||||||
| if message.text is None: | ||||||
| raise ValueError(NO_TEXT_ERROR_MESSAGE.format(role=message.role.value, message=message)) | ||||||
| if message.text and "templatize_part" in message.text: | ||||||
| raise ValueError(FILTER_NOT_ALLOWED_ERROR_MESSAGE) | ||||||
| ast = self._env.parse(message.text) | ||||||
| template_variables = meta.find_undeclared_variables(ast) | ||||||
| extracted_variables += list(template_variables) | ||||||
| extracted_variables += _extract_from_text( | ||||||
| message.text, role=message.role.value, is_filter_allowed=True | ||||||
| ) | ||||||
| elif isinstance(template, str): | ||||||
| ast = self._env.parse(template) | ||||||
| extracted_variables = list(meta.find_undeclared_variables(ast)) | ||||||
| extracted_variables = _extract_from_text(template, is_filter_allowed=False) | ||||||
|
|
||||||
| extracted_variables = extracted_variables or [] | ||||||
| self.variables = variables or extracted_variables | ||||||
| self.required_variables = required_variables or [] | ||||||
|
|
||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,39 @@ | ||||||||||||||||||||||||||||||||
| # SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]> | ||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| from typing import Optional | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| from jinja2 import Environment, nodes | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def extract_declared_variables(template_str: str, env: Optional[Environment] = None) -> set: | ||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
| Extract declared variables from a Jinja2 template string. | ||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||
| template_str (str): The Jinja2 template string to analyze. | ||||||||||||||||||||||||||||||||
| env (Environment, optional): The Jinja2 Environment. Defaults to None. | ||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||
| A list of variable names used in the template. | ||||||||||||||||||||||||||||||||
|
Comment on lines
+12
to
+19
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's make sure to use docstring formatting consistent with the rest of the library, like so:
Suggested change
Also update the Returns section of the docstring to reflect that we are returning a set, not a list. |
||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
| env = env or Environment() | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||
| ast = env.parse(template_str) | ||||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||||
| raise RuntimeError(f"Failed to parse Jinja2 template: {e}") | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Collect all variables assigned inside the template via {% set %} | ||||||||||||||||||||||||||||||||
| assigned_variables = set() | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| for node in ast.find_all(nodes.Assign): | ||||||||||||||||||||||||||||||||
| if isinstance(node.target, nodes.Name): | ||||||||||||||||||||||||||||||||
| assigned_variables.add(node.target.name) | ||||||||||||||||||||||||||||||||
| elif isinstance(node.target, (nodes.List, nodes.Tuple)): | ||||||||||||||||||||||||||||||||
| for name_node in node.target.items: | ||||||||||||||||||||||||||||||||
| if isinstance(name_node, nodes.Name): | ||||||||||||||||||||||||||||||||
| assigned_variables.add(name_node.name) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| return assigned_variables | ||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| fixes: | ||
| - | | ||
| Fixed an issue where Jinja2 variable assignments using the `set` directive | ||
| were not being parsed correctly in certain contexts. This fix ensures that | ||
| variables assigned with `{% set var = value %}` are now properly recognized | ||
| and can be used as expected within templates inside `PromptBuilder` and | ||
| `ChatPromptBuilder`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we make this a class method instead of an inline function? Please also keep it private, so it's fine to leave the name as
`_extract_from_text`.