Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions docetl/operations/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .base import BaseOperation
from .clustering_utils import get_embeddings_for_clustering
from .utils import RichLoopBar, strict_render
from docetl.utils import has_jinja_syntax, prompt_user_for_non_jinja_confirmation


class ClusterOperation(BaseOperation):
Expand All @@ -19,6 +20,19 @@ def __init__(
self.max_batch_size: int = self.config.get(
"max_batch_size", kwargs.get("max_batch_size", float("inf"))
)
# Check for non-Jinja prompts and prompt user for confirmation
if "summary_prompt" in self.config and not has_jinja_syntax(
self.config["summary_prompt"]
):
if not prompt_user_for_non_jinja_confirmation(
self.config["summary_prompt"], self.config["name"], "summary_prompt"
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your summary_prompt."
)
# Mark that we need to append document statement (cluster uses inputs)
self.config["_append_document_to_prompt"] = True
self.config["_is_reduce_operation"] = True

def syntax_check(self) -> None:
"""
Expand Down Expand Up @@ -48,11 +62,16 @@ def syntax_check(self) -> None:
if not isinstance(self.config["summary_prompt"], str):
raise TypeError("'prompt' must be a string")

# Check if the prompt is a valid Jinja2 template
try:
Template(self.config["summary_prompt"])
except Exception as e:
raise ValueError(f"Invalid Jinja2 template in 'prompt': {str(e)}")
# Check if the prompt has Jinja syntax
if not has_jinja_syntax(self.config["summary_prompt"]):
# This will be handled during initialization with user confirmation
pass
else:
# Check if the prompt is a valid Jinja2 template
try:
Template(self.config["summary_prompt"])
except Exception as e:
raise ValueError(f"Invalid Jinja2 template in 'prompt': {str(e)}")

# Check optional parameters
if "max_batch_size" in self.config:
Expand Down
41 changes: 40 additions & 1 deletion docetl/operations/equijoin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from docetl.operations.base import BaseOperation
from docetl.operations.utils import strict_render
from docetl.operations.utils.progress import RichLoopBar
from docetl.utils import completion_cost
from docetl.utils import (
completion_cost,
has_jinja_syntax,
prompt_user_for_non_jinja_confirmation,
)

# Global variables to store shared data
_right_data = None
Expand Down Expand Up @@ -89,6 +93,41 @@ def validate_limits(cls, v):
)
return v

@field_validator("comparison_prompt")
def validate_comparison_prompt(cls, v):
# Check if it has Jinja syntax
if not has_jinja_syntax(v):
# This will be handled during initialization with user confirmation
return v
# If it has Jinja syntax, validate it's a valid template
from jinja2 import Template

try:
Template(v)
except Exception as e:
raise ValueError(
f"Invalid Jinja2 template in 'comparison_prompt': {str(e)}"
)
return v

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Check for non-Jinja prompts and prompt user for confirmation
if "comparison_prompt" in self.config and not has_jinja_syntax(
self.config["comparison_prompt"]
):
if not prompt_user_for_non_jinja_confirmation(
self.config["comparison_prompt"],
self.config["name"],
"comparison_prompt",
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your comparison_prompt."
)
# Mark that we need to append document statement
# Note: equijoin uses left and right, so we'll handle it in strict_render
self.config["_append_document_to_comparison_prompt"] = True

def compare_pair(
self,
comparison_prompt: str,
Expand Down
15 changes: 15 additions & 0 deletions docetl/operations/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from docetl.operations.base import BaseOperation
from docetl.operations.utils import RichLoopBar, strict_render
from docetl.utils import has_jinja_syntax, prompt_user_for_non_jinja_confirmation


class ExtractOperation(BaseOperation):
Expand All @@ -28,6 +29,10 @@ class schema(BaseOperation.schema):

@field_validator("prompt")
def validate_prompt(cls, v):
# Check if it has Jinja syntax
if not has_jinja_syntax(v):
# This will be handled during initialization with user confirmation
return v
try:
Template(v)
except Exception as e:
Expand All @@ -47,6 +52,16 @@ def __init__(
self.extraction_key_suffix = f"_extracted_{self.config['name']}"
else:
self.extraction_key_suffix = self.config["extraction_key_suffix"]
# Check for non-Jinja prompts and prompt user for confirmation
if "prompt" in self.config and not has_jinja_syntax(self.config["prompt"]):
if not prompt_user_for_non_jinja_confirmation(
self.config["prompt"], self.config["name"], "prompt"
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your prompt."
)
# Mark that we need to append document statement
self.config["_append_document_to_prompt"] = True

def _reformat_text_with_line_numbers(self, text: str, line_width: int = 80) -> str:
"""
Expand Down
18 changes: 18 additions & 0 deletions docetl/operations/link_resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,29 @@

from docetl.operations.base import BaseOperation
from docetl.operations.utils import RichLoopBar, strict_render
from docetl.utils import has_jinja_syntax, prompt_user_for_non_jinja_confirmation

from .clustering_utils import get_embeddings_for_clustering


class LinkResolveOperation(BaseOperation):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Check for non-Jinja prompts and prompt user for confirmation
if "comparison_prompt" in self.config and not has_jinja_syntax(
self.config["comparison_prompt"]
):
if not prompt_user_for_non_jinja_confirmation(
self.config["comparison_prompt"],
self.config["name"],
"comparison_prompt",
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your comparison_prompt."
)
# Mark that we need to append document statement
# Note: link_resolve uses link_value, id_value, and item, so strict_render will handle it
self.config["_append_document_to_comparison_prompt"] = True
def execute(self, input_data: list[dict]) -> tuple[list[dict], float]:
"""
Executes the resolve links operation on the provided dataset.
Expand Down
33 changes: 33 additions & 0 deletions docetl/operations/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import requests
from jinja2 import Template

from docetl.utils import has_jinja_syntax, prompt_user_for_non_jinja_confirmation
from litellm.utils import ModelResponse
from pydantic import Field, field_validator, model_validator
from tqdm import tqdm
Expand Down Expand Up @@ -49,6 +51,11 @@ class schema(BaseOperation.schema):
@field_validator("batch_prompt")
def validate_batch_prompt(cls, v):
if v is not None:
# Check if it has Jinja syntax
if not has_jinja_syntax(v):
# This will be handled during initialization with user confirmation
# We'll mark it for later processing
return v
try:
template = Template(v)
# Test render with a minimal inputs list to validate template
Expand All @@ -62,6 +69,11 @@ def validate_batch_prompt(cls, v):
@field_validator("prompt")
def validate_prompt(cls, v):
if v is not None:
# Check if it has Jinja syntax
if not has_jinja_syntax(v):
# This will be handled during initialization with user confirmation
# We'll mark it for later processing
return v
try:
Template(v)
except Exception as e:
Expand Down Expand Up @@ -118,6 +130,27 @@ def __init__(
"max_batch_size", kwargs.get("max_batch_size", None)
)
self.clustering_method = "random"
# Check for non-Jinja prompts and prompt user for confirmation
if "prompt" in self.config and not has_jinja_syntax(self.config["prompt"]):
if not prompt_user_for_non_jinja_confirmation(
self.config["prompt"], self.config["name"], "prompt"
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your prompt."
)
# Mark that we need to append document statement
self.config["_append_document_to_prompt"] = True
if "batch_prompt" in self.config and not has_jinja_syntax(
self.config["batch_prompt"]
):
if not prompt_user_for_non_jinja_confirmation(
self.config["batch_prompt"], self.config["name"], "batch_prompt"
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your batch_prompt."
)
# Mark that we need to append document statement
self.config["_append_document_to_batch_prompt"] = True

def _generate_calibration_context(self, input_data: list[dict]) -> str:
"""
Expand Down
46 changes: 46 additions & 0 deletions docetl/operations/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pydantic import Field, field_validator, model_validator

from docetl.operations.base import BaseOperation
from docetl.utils import has_jinja_syntax, prompt_user_for_non_jinja_confirmation
from docetl.operations.clustering_utils import (
cluster_documents,
get_embeddings_for_clustering,
Expand Down Expand Up @@ -67,6 +68,10 @@ class schema(BaseOperation.schema):
@field_validator("prompt")
def validate_prompt(cls, v):
if v is not None:
# Check if it has Jinja syntax
if not has_jinja_syntax(v):
# This will be handled during initialization with user confirmation
return v
try:
template = Template(v)
template_vars = template.environment.parse(v).find_all(
Expand All @@ -84,6 +89,10 @@ def validate_prompt(cls, v):
@field_validator("fold_prompt")
def validate_fold_prompt(cls, v):
if v is not None:
# Check if it has Jinja syntax
if not has_jinja_syntax(v):
# This will be handled during initialization with user confirmation
return v
try:
fold_template = Template(v)
fold_template_vars = fold_template.environment.parse(v).find_all(
Expand All @@ -104,6 +113,10 @@ def validate_fold_prompt(cls, v):
@field_validator("merge_prompt")
def validate_merge_prompt(cls, v):
if v is not None:
# Check if it has Jinja syntax
if not has_jinja_syntax(v):
# This will be handled during initialization with user confirmation
return v
try:
merge_template = Template(v)
merge_template_vars = merge_template.environment.parse(v).find_all(
Expand Down Expand Up @@ -181,6 +194,39 @@ def __init__(self, *args, **kwargs):
)
self.intermediates = {}
self.lineage_keys = self.config.get("output", {}).get("lineage", [])
# Check for non-Jinja prompts and prompt user for confirmation
if "prompt" in self.config and not has_jinja_syntax(self.config["prompt"]):
if not prompt_user_for_non_jinja_confirmation(
self.config["prompt"], self.config["name"], "prompt"
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your prompt."
)
# Mark that we need to append document statement (for reduce, use inputs)
self.config["_append_document_to_prompt"] = True
self.config["_is_reduce_operation"] = True
if "fold_prompt" in self.config and not has_jinja_syntax(
self.config["fold_prompt"]
):
if not prompt_user_for_non_jinja_confirmation(
self.config["fold_prompt"], self.config["name"], "fold_prompt"
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your fold_prompt."
)
self.config["_append_document_to_fold_prompt"] = True
self.config["_is_reduce_operation"] = True
if "merge_prompt" in self.config and not has_jinja_syntax(
self.config["merge_prompt"]
):
if not prompt_user_for_non_jinja_confirmation(
self.config["merge_prompt"], self.config["name"], "merge_prompt"
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your merge_prompt."
)
self.config["_append_document_to_merge_prompt"] = True
self.config["_is_reduce_operation"] = True

def execute(self, input_data: list[dict]) -> tuple[list[dict], float]:
"""
Expand Down
47 changes: 46 additions & 1 deletion docetl/operations/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@

from docetl.operations.base import BaseOperation
from docetl.operations.utils import RichLoopBar, rich_as_completed, strict_render
from docetl.utils import completion_cost, extract_jinja_variables
from docetl.utils import (
completion_cost,
extract_jinja_variables,
has_jinja_syntax,
prompt_user_for_non_jinja_confirmation,
)


def find_cluster(item, cluster_map):
Expand Down Expand Up @@ -48,6 +53,10 @@ class schema(BaseOperation.schema):
@field_validator("comparison_prompt")
def validate_comparison_prompt(cls, v):
if v is not None:
# Check if it has Jinja syntax
if not has_jinja_syntax(v):
# This will be handled during initialization with user confirmation
return v
try:
comparison_template = Template(v)
comparison_vars = comparison_template.environment.parse(v).find_all(
Expand All @@ -70,6 +79,10 @@ def validate_comparison_prompt(cls, v):
@field_validator("resolution_prompt")
def validate_resolution_prompt(cls, v):
if v is not None:
# Check if it has Jinja syntax
if not has_jinja_syntax(v):
# This will be handled during initialization with user confirmation
return v
try:
reduction_template = Template(v)
reduction_vars = reduction_template.environment.parse(v).find_all(
Expand Down Expand Up @@ -123,6 +136,38 @@ def validate_output_schema(self, info: ValidationInfo):

return self

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Check for non-Jinja prompts and prompt user for confirmation
if "comparison_prompt" in self.config and not has_jinja_syntax(
self.config["comparison_prompt"]
):
if not prompt_user_for_non_jinja_confirmation(
self.config["comparison_prompt"],
self.config["name"],
"comparison_prompt",
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your comparison_prompt."
)
# Mark that we need to append document statement
# Note: comparison_prompt uses input1 and input2, so we'll handle it specially in strict_render
self.config["_append_document_to_comparison_prompt"] = True
if "resolution_prompt" in self.config and not has_jinja_syntax(
self.config["resolution_prompt"]
):
if not prompt_user_for_non_jinja_confirmation(
self.config["resolution_prompt"],
self.config["name"],
"resolution_prompt",
):
raise ValueError(
f"Operation '{self.config['name']}' cancelled by user. Please add Jinja2 template syntax to your resolution_prompt."
)
# Mark that we need to append document statement (resolution uses inputs)
self.config["_append_document_to_resolution_prompt"] = True
self.config["_is_reduce_operation"] = True

def compare_pair(
self,
comparison_prompt: str,
Expand Down
Loading