
Commit 8fd8ce7

feat: Add limit parameter to operations

Adds a limit parameter to Extract, Map, Filter, and Reduce operations to control the number of processed items.

Co-authored-by: ss.shankar505 <[email protected]>

1 parent 81d1104

11 files changed: +211 −86 lines

docetl/operations/extract.py
Lines changed: 5 additions & 0 deletions

@@ -26,6 +26,7 @@ class schema(BaseOperation.schema):
         timeout: int | None = None
         skip_on_error: bool = False
         litellm_completion_kwargs: dict[str, Any] = Field(default_factory=dict)
+        limit: int | None = Field(None, gt=0)

         @field_validator("prompt")
         def validate_prompt(cls, v):
@@ -392,6 +393,10 @@ def execute(self, input_data: list[dict]) -> tuple[list[dict], float]:
         Returns:
             tuple[list[dict], float]: A tuple containing the processed data and the total cost of the operation.
         """
+        limit_value = self.config.get("limit")
+        if limit_value is not None:
+            input_data = input_data[:limit_value]
+
         if not input_data:
             return [], 0.0
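The new field uses `Field(None, gt=0)`, so an omitted `limit` means "process everything" while zero or negative values fail config validation. A minimal pydantic sketch of that behavior (`OpSchema` here is a stand-in name, not the real class):

```python
from pydantic import BaseModel, Field, ValidationError

class OpSchema(BaseModel):
    # Mirrors the new field: omitted means "no limit"; zero/negative rejected.
    limit: int | None = Field(None, gt=0)

print(OpSchema().limit)         # None -> process everything
print(OpSchema(limit=5).limit)  # 5
try:
    OpSchema(limit=0)
except ValidationError as err:
    print("rejected:", err.errors()[0]["msg"])
```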

docetl/operations/filter.py
Lines changed: 30 additions & 51 deletions

@@ -33,6 +33,30 @@ def validate_filter_output_schema(self):

         return self

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._filter_key = next(
+            iter(
+                [
+                    k
+                    for k in self.config["output"]["schema"].keys()
+                    if k != "_short_explanation"
+                ]
+            )
+        )
+        self._filter_is_build = False
+
+    def _limit_applies_to_inputs(self) -> bool:
+        return False
+
+    def _handle_result(self, result: dict[str, Any]) -> tuple[dict | None, bool]:
+        keep_record = bool(result.get(self._filter_key))
+        result.pop(self._filter_key, None)
+
+        if self._filter_is_build or keep_record:
+            return result, keep_record
+        return None, False
+
     def execute(
         self, input_data: list[dict], is_build: bool = False
     ) -> tuple[list[dict], float]:
@@ -46,55 +70,10 @@ def execute(
         Returns:
             tuple[list[dict], float]: A tuple containing the filtered list of dictionaries
             and the total cost of the operation.
-
-        This method performs the following steps:
-        1. Processes each input item using an LLM model
-        2. Validates the output
-        3. Filters the results based on the specified filter key
-        4. Calculates the total cost of the operation
-
-        The method uses multi-threading to process items in parallel, improving performance
-        for large datasets.
-
-        Usage:
-        ```python
-        from docetl.operations import FilterOperation
-
-        config = {
-            "prompt": "Determine if the following item is important: {{input}}",
-            "output": {
-                "schema": {"is_important": "bool"}
-            },
-            "model": "gpt-3.5-turbo"
-        }
-        filter_op = FilterOperation(config)
-        input_data = [
-            {"id": 1, "text": "Critical update"},
-            {"id": 2, "text": "Regular maintenance"}
-        ]
-        results, cost = filter_op.execute(input_data)
-        print(f"Filtered results: {results}")
-        print(f"Total cost: {cost}")
-        ```
         """
-        filter_key = next(
-            iter(
-                [
-                    k
-                    for k in self.config["output"]["schema"].keys()
-                    if k != "_short_explanation"
-                ]
-            )
-        )
-
-        results, total_cost = super().execute(input_data)
-
-        # Drop records with filter_key values that are False
-        if not is_build:
-            results = [result for result in results if result[filter_key]]
-
-        # Drop the filter_key from the results
-        for result in results:
-            result.pop(filter_key, None)
-
-        return results, total_cost
+        previous_state = self._filter_is_build
+        self._filter_is_build = is_build
+        try:
+            return super().execute(input_data)
+        finally:
+            self._filter_is_build = previous_state
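The refactor moves filter-specific logic into two hooks (`_limit_applies_to_inputs`, `_handle_result`) that the shared map execution loop consumes. A minimal sketch of the pattern, with `BaseOp` and `FilterOp` as hypothetical stand-ins for the real docetl classes:

```python
# Hypothetical stand-ins for docetl's shared map loop; illustrative only.
class BaseOp:
    def run_llm(self, item):
        return dict(item)  # placeholder for the real LLM call

    def execute(self, items, limit=None):
        kept, counted = [], 0
        for item in items:
            result, counts_towards_limit = self._handle_result(self.run_llm(item))
            if result is not None:
                kept.append(result)
            if limit is not None and counts_towards_limit:
                counted += 1
                if counted >= limit:  # stop scheduling further work
                    break
        return kept

class FilterOp(BaseOp):
    def _handle_result(self, result):
        keep = bool(result.pop("keep", False))
        return (result, True) if keep else (None, False)

# Only records with keep=True count toward the limit:
op = FilterOp()
data = [{"id": 1, "keep": False}, {"id": 2, "keep": True}, {"id": 3, "keep": True}]
print(op.execute(data, limit=1))  # [{'id': 2}]
```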

docetl/operations/map.py
Lines changed: 74 additions & 35 deletions

@@ -44,6 +44,7 @@ class schema(BaseOperation.schema):
         litellm_completion_kwargs: dict[str, Any] = {}
         pdf_url_key: str | None = None
         flush_partial_result: bool = False
+        limit: int | None = Field(None, gt=0)
         # Calibration parameters
         calibrate: bool = False
         num_calibration_docs: int = Field(10, gt=0)
@@ -152,6 +153,12 @@ def __init__(
             # Mark that we need to append document statement
             self.config["_append_document_to_batch_prompt"] = True

+    def _limit_applies_to_inputs(self) -> bool:
+        return True
+
+    def _handle_result(self, result: dict[str, Any]) -> tuple[dict | None, bool]:
+        return result, True
+
     def _generate_calibration_context(self, input_data: list[dict]) -> str:
         """
         Generate calibration context by running the operation on a sample of documents
@@ -272,17 +279,27 @@ def execute(self, input_data: list[dict]) -> tuple[list[dict], float]:

         The method uses parallel processing to improve performance.
         """
+        limit_value = self.config.get("limit")
+
         # Check if there's no prompt and only drop_keys
         if "prompt" not in self.config and "drop_keys" in self.config:
+            data_to_process = input_data
+            if limit_value is not None and self._limit_applies_to_inputs():
+                data_to_process = input_data[:limit_value]
             # If only drop_keys is specified, simply drop the keys and return
             dropped_results = []
-            for item in input_data:
+            for item in data_to_process:
                 new_item = {
                     k: v for k, v in item.items() if k not in self.config["drop_keys"]
                 }
                 dropped_results.append(new_item)
+                if limit_value is not None and len(dropped_results) >= limit_value:
+                    break
             return dropped_results, 0.0  # Return the modified data with no cost

+        if limit_value is not None and self._limit_applies_to_inputs():
+            input_data = input_data[:limit_value]
+
         # Generate calibration context if enabled
         calibration_context = ""
         if self.config.get("calibrate", False) and "prompt" in self.config:
@@ -512,40 +529,62 @@ def _process_map_batch(items: list[dict]) -> tuple[list[dict], float]:

             return all_results, total_cost

-        with ThreadPoolExecutor(max_workers=self.max_batch_size) as executor:
-            batch_size = self.max_batch_size if self.max_batch_size is not None else 1
-            futures = []
-            for i in range(0, len(input_data), batch_size):
-                batch = input_data[i : i + batch_size]
-                futures.append(executor.submit(_process_map_batch, batch))
-            results = []
-            total_cost = 0
-            pbar = RichLoopBar(
-                range(len(futures)),
-                desc=f"Processing {self.config['name']} (map) on all documents",
-                console=self.console,
-            )
-            for batch_index in pbar:
-                result_list, item_cost = futures[batch_index].result()
-                if result_list:
-                    if "drop_keys" in self.config:
-                        result_list = [
-                            {
-                                k: v
-                                for k, v in result.items()
-                                if k not in self.config["drop_keys"]
-                            }
-                            for result in result_list
-                        ]
-                    results.extend(result_list)
-                    # --- BEGIN: Flush partial checkpoint ---
-                    if self.config.get("flush_partial_results", False):
-                        op_name = self.config["name"]
-                        self.runner._flush_partial_results(
-                            op_name, batch_index, result_list
-                        )
-                    # --- END: Flush partial checkpoint ---
-                total_cost += item_cost
+        limit_counter = 0
+
+        batch_size = self.max_batch_size if self.max_batch_size is not None else 1
+        total_batches = (
+            (len(input_data) + batch_size - 1) // batch_size if input_data else 0
+        )
+        results: list[dict] = []
+        total_cost = 0.0
+        limit_reached = False
+
+        pbar = RichLoopBar(
+            range(total_batches),
+            desc=f"Processing {self.config['name']} (map) on all documents",
+            console=self.console,
+        )
+
+        for batch_index in pbar:
+            if limit_value is not None and limit_counter >= limit_value:
+                break
+
+            batch_start = batch_index * batch_size
+            batch = input_data[batch_start : batch_start + batch_size]
+            if not batch:
+                break
+
+            result_list, item_cost = _process_map_batch(batch)
+            total_cost += item_cost
+
+            if result_list:
+                if "drop_keys" in self.config:
+                    result_list = [
+                        {
+                            k: v
+                            for k, v in result.items()
+                            if k not in self.config["drop_keys"]
+                        }
+                        for result in result_list
+                    ]
+
+                if self.config.get("flush_partial_results", False):
+                    op_name = self.config["name"]
+                    self.runner._flush_partial_results(op_name, batch_index, result_list)
+
+                for result in result_list:
+                    processed_result, counts_towards_limit = self._handle_result(result)
+                    if processed_result is not None:
+                        results.append(processed_result)
+
+                    if limit_value is not None and counts_towards_limit:
+                        limit_counter += 1
+                        if limit_counter >= limit_value:
+                            limit_reached = True
+                            break
+
+            if limit_reached:
+                break

         if self.status:
             self.status.start()
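One detail in the rewritten loop: the progress bar is sized with integer ceiling division rather than `len(futures)`. A quick sketch of that batch math:

```python
def total_batches(n_items: int, batch_size: int) -> int:
    # Integer ceiling division: ceil(n_items / batch_size) without floats.
    return (n_items + batch_size - 1) // batch_size if n_items else 0

data = list(range(7))
bs = 3
for i in range(total_batches(len(data), bs)):
    print(data[i * bs : (i + 1) * bs])  # [0, 1, 2] / [3, 4, 5] / [6]
```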

docetl/operations/reduce.py
Lines changed: 8 additions & 0 deletions

@@ -64,6 +64,7 @@ class schema(BaseOperation.schema):
         timeout: int | None = None
         litellm_completion_kwargs: dict[str, Any] = Field(default_factory=dict)
         enable_observability: bool = False
+        limit: int | None = Field(None, gt=0)

         @field_validator("prompt")
         def validate_prompt(cls, v):
@@ -282,6 +283,10 @@ def get_group_key(item):
         # Convert the grouped data to a list of tuples
         grouped_data = list(grouped_data.items())

+        limit_value = self.config.get("limit")
+        if limit_value is not None:
+            grouped_data = grouped_data[:limit_value]
+
         def process_group(
             key: tuple, group_elems: list[dict]
         ) -> tuple[dict | None, float]:
@@ -388,6 +393,9 @@ def process_group(
             if output is not None:
                 results.append(output)

+        if limit_value is not None and len(results) > limit_value:
+            results = results[:limit_value]
+
         if self.config.get("persist_intermediates", False):
             for result in results:
                 key = tuple(result[k] for k in self.config["reduce_key"])

docs/operators/extract.md
Lines changed: 3 additions & 0 deletions

@@ -140,6 +140,9 @@ This strategy asks the LLM to generate regex patterns matching the desired content
 | `timeout` | Timeout for LLM calls in seconds | 120 |
 | `skip_on_error` | Continue processing if errors occur | false |
 | `litellm_completion_kwargs` | Additional parameters for LiteLLM calls | {} |
+| `limit` | Maximum number of documents to extract from before stopping | Processes all data |
+
+When `limit` is set, Extract only reformats and submits the first _N_ documents. This is handy when the upstream dataset is large and you want to cap cost while previewing results.

 ## Best Practices
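A capped extract config might look like the sketch below; aside from `limit`, the keys shown are illustrative placeholders rather than a verified schema:

```python
# Illustrative only: every key except "limit" is an assumption.
extract_config = {
    "name": "extract_key_findings",
    "type": "extract",
    "prompt": "Extract the key findings from the document.",
    "limit": 50,  # reformat and submit only the first 50 documents
}
```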

docs/operators/filter.md
Lines changed: 4 additions & 0 deletions

@@ -83,6 +83,10 @@ This example demonstrates how the Filter operation distinguishes between high-impact

 See [map optional parameters](./map.md#optional-parameters) for additional configuration options, including `batch_prompt` and `max_batch_size`.

+### Limiting filtered outputs
+
+`limit` behaves slightly differently for filter operations than for map operations. Because filter drops documents whose predicate evaluates to `false`, the limit counts only the documents that would be retained (i.e., the ones whose boolean output is `true`). DocETL will continue evaluating additional inputs until it has collected `limit` passing documents and then stop scheduling further LLM calls. This ensures you can request “the first N matches” without paying to score the entire dataset.
+
 !!! info "Validation"

     For more details on validation techniques and implementation, see [operators](../concepts/operators.md#validation).
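A pure-Python mock of the "first N matches" rule described above, with a plain function standing in for the LLM predicate:

```python
# Only records whose predicate is True count toward the limit, but extra
# inputs are still evaluated until enough matches are collected.
def filter_with_limit(items, predicate, limit):
    kept = []
    for item in items:
        if predicate(item):         # stand-in for the LLM boolean output
            kept.append(item)
            if len(kept) >= limit:  # stop scheduling further evaluations
                break
    return kept

docs = [{"id": i, "urgent": i % 3 == 0} for i in range(10)]
print(filter_with_limit(docs, lambda d: d["urgent"], limit=2))
# [{'id': 0, 'urgent': True}, {'id': 3, 'urgent': True}]
```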

docs/operators/map.md
Lines changed: 5 additions & 0 deletions

@@ -140,6 +140,7 @@ This example demonstrates how the Map operation can transform long, unstructured
 | `optimize` | Flag to enable operation optimization | `True` |
 | `recursively_optimize` | Flag to enable recursive optimization of operators synthesized as part of rewrite rules | `false` |
 | `sample` | Number of samples to use for the operation | Processes all data |
+| `limit` | Maximum number of outputs to produce before stopping | Processes all data |
 | `tools` | List of tool definitions for LLM use | None |
 | `validate` | List of Python expressions to validate the output | None |
 | `flush_partial_results` | Write results of individual batches of map operation to disk for faster inspection | False |
@@ -158,6 +159,10 @@ This example demonstrates how the Map operation can transform long, unstructured

 Note: If `drop_keys` is specified, `prompt` and `output` become optional parameters.

+### Limiting execution
+
+Set `limit` when you only need the first _N_ map results or want to cap LLM spend. The operation slices the processed dataset to the first `limit` entries and also stops scheduling new prompts once that many outputs have been produced, even if a prompt returns multiple records. Filter operations inherit this behavior but redefine the count so the limit only applies to records whose filter predicate evaluates to `true` (see [Filter](./filter.md#optional-parameters)).
+

 !!! info "Validation and Gleaning"
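A small sketch of the output-based accounting described above (plain Python, independent of the docetl API):

```python
# Outputs, not inputs, are compared against the limit: a prompt that
# returns multiple records can hit the cap mid-batch.
def collect_with_limit(batches, limit):
    results = []
    for batch_outputs in batches:      # each entry: records from one prompt
        for record in batch_outputs:
            results.append(record)
            if len(results) >= limit:  # stop scheduling new prompts
                return results
    return results

print(collect_with_limit([["s1", "s2"], ["s3"], ["s4"]], limit=3))
# ['s1', 's2', 's3'] -- the last prompt is never issued
```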

docs/operators/reduce.md
Lines changed: 5 additions & 0 deletions

@@ -52,6 +52,7 @@ This Reduce operation processes customer feedback grouped by department:
 | Parameter | Description | Default |
 | --------- | ----------- | ------- |
 | `sample` | Number of samples to use for the operation | None |
+| `limit` | Maximum number of groups to process before stopping | All groups |
 | `synthesize_resolve` | If false, won't synthesize a resolve operation between map and reduce | true |
 | `model` | The language model to use | Falls back to default_model |
 | `input` | Specifies the schema or keys to subselect from each item | All keys from input items |
@@ -67,6 +68,10 @@ This Reduce operation processes customer feedback grouped by department:
 | `litellm_completion_kwargs` | Additional parameters to pass to LiteLLM completion calls. | {} |
 | `bypass_cache` | If true, bypass the cache for this operation. | False |

+### Limiting group processing
+
+Set `limit` to short-circuit the reduce phase after the first _N_ groups have been aggregated. This is useful for previewing results or capping LLM usage when you only need the earliest groups (according to the original input order). Groups beyond the limit are never scheduled, so you avoid extra fold/merge calls. If a grouped reduce returns more than one record per group, the final output list is truncated to `limit`.
+
 ## Advanced Features

 ### Incremental Folding
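For reference, the "Limiting group processing" note above boils down to this plain-Python sketch, with `itertools.groupby` standing in for DocETL's grouping step:

```python
# Groups beyond the first N are never scheduled for fold/merge calls.
from itertools import groupby

rows = sorted(
    [{"dept": "a", "score": 1}, {"dept": "b", "score": 3},
     {"dept": "a", "score": 2}, {"dept": "c", "score": 4}],
    key=lambda r: r["dept"],
)
grouped = [(k, list(g)) for k, g in groupby(rows, key=lambda r: r["dept"])]

limit = 2
for key, group in grouped[:limit]:  # only the first two groups are aggregated
    print(key, sum(r["score"] for r in group))
# a 3
# b 3
```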
