
Commit 9db82e9: Add mkdocs
Parent: 1583f8a
40 files changed: +2149, -156 lines

README.md (2 additions, 2 deletions)

@@ -480,7 +480,7 @@ Example of a reduce operation with value sampling:
 enabled: true
 method: cluster
 sample_size: 50
-embedding_model: text-embedding-ada-002
+embedding_model: text-embedding-3-small
 embedding_keys:
 - name
 - price

@@ -609,7 +609,7 @@ Example:
 blocking_keys:
 - record
 blocking_threshold: 0.8
-embedding_model: text-embedding-ada-002
+embedding_model: text-embedding-3-small
 resolution_model: gpt-4o-mini
 comparison_model: gpt-4o-mini
 ```

docetl/__init__.py (1 addition, 0 deletions)

@@ -0,0 +1 @@
+__version__ = "0.1.0"

docetl/builder.py (1 addition, 1 deletion)

@@ -709,7 +709,7 @@ def _optimize_step(
         )

         if (
-            not op_object.get("optimize", True)
+            not op_object.get("optimize", False)  # Default don't optimize
             or op_object.get("type") not in SUPPORTED_OPS
         ):
             # If optimize is False or operation type is not supported, just use the operation without optimization
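With optimization now opt-in, a pipeline config has to request it per operation. A minimal sketch of what that could look like in pipeline YAML; the operation name and type here are illustrative, not from this commit:

```yaml
operations:
  - name: extract_products  # hypothetical operation name
    type: map
    optimize: true  # must now be set explicitly; the default is false
```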

docetl/cli.py (10 additions, 0 deletions)

@@ -73,5 +73,15 @@ def clear_cache():
     cc()


+@app.command()
+def version():
+    """
+    Display the current version of DocETL.
+    """
+    import docetl
+
+    typer.echo(f"DocETL version: {docetl.__version__}")
+
+
 if __name__ == "__main__":
     app()
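Assuming the installed console script is named docetl (the entry point name is not confirmed by this diff), the new command would be used like this, with the output reflecting the version set in docetl/__init__.py:

```
$ docetl version
DocETL version: 0.1.0
```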

docetl/operations/resolve.py (9 additions, 8 deletions)

@@ -6,6 +6,8 @@
 from typing import Any, Dict, List, Tuple
 import random

+import numpy as np
+
 import jinja2
 from jinja2 import Template
 from docetl.utils import completion_cost

@@ -291,17 +293,16 @@ def meets_blocking_conditions(pair):
             else float("inf")
         )
         if remaining_comparisons > 0 and blocking_threshold is not None:
+            # Compute cosine similarity for all pairs at once
+            all_embeddings = np.array([embeddings[i] for i in range(len(input_data))])
+            similarity_matrix = cosine_similarity(all_embeddings)
+
             cosine_pairs = []
             for i, j in all_pairs:
                 if (i, j) not in blocked_pairs and find_cluster(i) != find_cluster(j):
-                    try:
-                        similarity = cosine_similarity(
-                            [embeddings[i]], [embeddings[j]]
-                        )[0][0]
-                        if similarity >= blocking_threshold:
-                            cosine_pairs.append((i, j, similarity))
-                    except Exception as e:
-                        self.console.log(f"Error comparing pair {i} and {j}: {e}")
+                    similarity = similarity_matrix[i, j]
+                    if similarity >= blocking_threshold:
+                        cosine_pairs.append((i, j, similarity))

         if remaining_comparisons != float("inf"):
             cosine_pairs.sort(key=lambda x: x[2], reverse=True)
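The rewritten block trades one cosine_similarity call per candidate pair for a single matrix computation followed by constant-time lookups. A standalone sketch of the same idea using scikit-learn's cosine_similarity; the embeddings here are random placeholders, not DocETL data:

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Placeholder embeddings: 100 items, 64 dimensions each
embeddings = np.random.rand(100, 64)

# One call yields the full 100x100 pairwise similarity matrix
similarity_matrix = cosine_similarity(embeddings)

# Each candidate pair then becomes a constant-time lookup
threshold = 0.8
candidate_pairs = [
    (i, j, similarity_matrix[i, j])
    for i in range(len(embeddings))
    for j in range(i + 1, len(embeddings))
    if similarity_matrix[i, j] >= threshold
]
print(f"{len(candidate_pairs)} pairs above threshold")
```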

docetl/operations/utils.py (113 additions, 49 deletions)

@@ -6,6 +6,7 @@
 import threading
 from concurrent.futures import as_completed
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from openai import OpenAI

 from dotenv import load_dotenv
 from frozendict import frozendict

@@ -17,6 +18,7 @@
 from diskcache import Cache
 import tiktoken
 from rich import print as rprint
+from pydantic import BaseModel, create_model

 from docetl.utils import count_tokens

@@ -28,6 +30,8 @@
 LLM_CACHE_DIR = os.path.join(DOCETL_HOME_DIR, "llm_cache")
 cache = Cache(LLM_CACHE_DIR)

+client = OpenAI()
+

 def freezeargs(func):
     """
@@ -150,6 +154,49 @@ def clear_cache(console: Console = Console()):
         console.log(f"[bold red]Error clearing cache: {str(e)}[/bold red]")


+def create_dynamic_model(schema: Dict[str, Any], model_name: str = "DynamicModel"):
+    fields = {}
+
+    def process_schema(s: Dict[str, Any], prefix: str = "") -> None:
+        for key, value in s.items():
+            field_name = f"{prefix}__{key}" if prefix else key
+            if isinstance(value, dict):
+                process_schema(value, field_name)
+            else:
+                fields[field_name] = parse_type(value, field_name)
+
+    def parse_type(type_str: str, field_name: str) -> tuple:
+        type_str = type_str.strip().lower()
+        if type_str in ["str", "text", "string", "varchar"]:
+            return (str, ...)
+        elif type_str in ["int", "integer"]:
+            return (int, ...)
+        elif type_str in ["float", "decimal", "number"]:
+            return (float, ...)
+        elif type_str in ["bool", "boolean"]:
+            return (bool, ...)
+        elif type_str.startswith("list["):
+            inner_type = type_str[5:-1].strip()
+            item_type = parse_type(inner_type, f"{field_name}_item")[0]
+            return (List[item_type], ...)
+        elif type_str == "list":
+            return (List[Any], ...)
+        elif type_str.startswith("{") and type_str.endswith("}"):
+            subfields = {}
+            for item in type_str[1:-1].split(","):
+                sub_key, sub_type = item.strip().split(":")
+                subfields[sub_key.strip()] = parse_type(
+                    sub_type.strip(), f"{field_name}_{sub_key}"
+                )
+            SubModel = create_model(f"{model_name}_{field_name}", **subfields)
+            return (SubModel, ...)
+        else:
+            return (Any, ...)
+
+    process_schema(schema)
+    return create_model(model_name, **fields)
+
+
 def convert_val(value: Any) -> Dict[str, Any]:
     """
     Convert a string representation of a type to a dictionary representation.
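For reference, a quick usage sketch of the new helper: given a dict of string type names, create_dynamic_model builds a Pydantic model (nested dicts become double-underscore-prefixed field names). The schema below is made up for illustration:

```python
from docetl.operations.utils import create_dynamic_model

# Hypothetical schema; values are the string type names parse_type understands
schema = {"title": "str", "year": "int", "tags": "list[str]"}

Movie = create_dynamic_model(schema, model_name="Movie")

# The generated Pydantic model validates parsed LLM output
record = Movie(title="Alien", year=1979, tags=["sci-fi", "horror"])
print(record)  # title='Alien' year=1979 tags=['sci-fi', 'horror']
```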
@@ -419,47 +466,62 @@ def call_llm_with_cache(
         parameters["required"] = list(props.keys())
         parameters["additionalProperties"] = False

-        tools = [
-            {
-                "type": "function",
-                "function": {
-                    "name": "write_output",
-                    "description": "Write processing output to a database",
-                    "strict": True,
-                    "parameters": parameters,
-                    "additionalProperties": False,
-                },
-            }
-        ]
-        tool_choice = {"type": "function", "function": {"name": "write_output"}}
+        response_format = {
+            "type": "json_schema",
+            "json_schema": {
+                "name": "write_output",
+                "description": "Write task output to a database",
+                "strict": True,
+                "schema": parameters,
+                # "additionalProperties": False,
+            },
+        }
+
+        tools = []
+        # tool_choice = {"type": "function", "function": {"name": "write_output"}}
+        tool_choice = "auto"

     else:
         tools = json.loads(tools)
         tool_choice = (
             "required" if any(tool.get("required", False) for tool in tools) else "auto"
         )
         tools = [{"type": "function", "function": tool["function"]} for tool in tools]
+        response_format = None

-    system_prompt = f"You are a helpful assistant, intelligently processing data. This is a {op_type} operation."
+    system_prompt = f"You are a helpful assistant, intelligently processing data. This is a {op_type} operation. You will perform the task on the user-provided data and write the output to a database."
     if scratchpad:
         system_prompt += f"\n\nYou are incrementally processing data across multiple batches. Your task is to {op_type} the data. Consider what intermediate state you need to maintain between batches to accomplish this task effectively.\n\nYour current scratchpad contains: {scratchpad}\n\nAs you process each batch, update your scratchpad with information crucial for processing subsequent batches. This may include partial results, counters, or any other relevant data that doesn't fit into {output_schema.keys()}. For example, if you're counting occurrences, track items that have appeared once.\n\nKeep your scratchpad concise (~500 chars) and use a format you can easily parse in future batches. You may use bullet points, key-value pairs, or any other clear structure."
     messages = json.loads(messages)

     # Truncate messages if they exceed the model's context length
     messages = truncate_messages(messages, model)

-    response = completion(
-        model=model,
-        messages=[
-            {
-                "role": "system",
-                "content": system_prompt,
-            },
-        ]
-        + messages,
-        tools=tools,
-        tool_choice=tool_choice,
-    )
+    if response_format is None:
+        response = completion(
+            model=model,
+            messages=[
+                {
+                    "role": "system",
+                    "content": system_prompt,
+                },
+            ]
+            + messages,
+            tools=tools,
+            tool_choice=tool_choice,
+        )
+    else:
+        response = completion(
+            model=model,
+            messages=[
+                {
+                    "role": "system",
+                    "content": system_prompt,
+                },
+            ]
+            + messages,
+            response_format=response_format,
+        )

     return response

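This change swaps the forced write_output tool call for OpenAI-style structured outputs. A minimal standalone sketch of the new response_format shape, assuming litellm's completion and a model that supports json_schema responses; the schema and prompt are illustrative, not from this commit:

```python
from litellm import completion

# Illustrative schema; in the real code this is built from the operation's output schema
schema = {
    "type": "object",
    "properties": {"summary": {"type": "string"}},
    "required": ["summary"],
    "additionalProperties": False,
}

response = completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Summarize: DocETL builds LLM pipelines."}],
    response_format={
        "type": "json_schema",
        "json_schema": {"name": "write_output", "strict": True, "schema": schema},
    },
)

# With structured outputs, the JSON arrives in message.content rather than a tool call
print(response.choices[0].message.content)
```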
@@ -612,22 +674,20 @@ def call_llm_with_gleaning(
         messages.append({"role": "user", "content": improvement_prompt})

         # Call LLM for improvement
+        # TODO: support gleaning and tools
         response = completion(
             model=model,
             messages=truncate_messages(messages, model),
-            tools=[
-                {
-                    "type": "function",
-                    "function": {
-                        "name": "write_output",
-                        "description": "Write processing output to a database",
-                        "strict": True,
-                        "parameters": parameters,
-                        "additionalProperties": False,
-                    },
-                }
-            ],
-            tool_choice={"type": "function", "function": {"name": "write_output"}},
+            response_format={
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "write_output",
+                    "description": "Write processing output to a database",
+                    "strict": True,
+                    "schema": parameters,
+                    # "additionalProperties": False,
+                },
+            },
         )

         # Update messages with the new response
@@ -682,16 +742,20 @@ def parse_llm_response(
             results.append(function_args)
         return results
     else:
-        # Default behavior for write_output function
-        tool_calls = response.choices[0].message.tool_calls
-        outputs = []
-        for tool_call in tool_calls:
-            if tool_call.function.name == "write_output":
-                try:
-                    outputs.append(json.loads(tool_call.function.arguments))
-                except json.JSONDecodeError:
-                    return [{}]
-        return outputs
+        if "tool_calls" in response.choices[0].message:
+            # Default behavior for write_output function
+            tool_calls = response.choices[0].message.tool_calls
+            outputs = []
+            for tool_call in tool_calls:
+                if tool_call.function.name == "write_output":
+                    try:
+                        outputs.append(json.loads(tool_call.function.arguments))
+                    except json.JSONDecodeError:
+                        return [{}]
+            return outputs
+
+        else:
+            return [json.loads(response.choices[0].message.content)]

     # message = response.choices[0].message
     # return [json.loads(message.content)]
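parse_llm_response now falls back to parsing message.content when the response carries no tool calls, which is what structured outputs produce. A toy re-implementation of just that branching, with SimpleNamespace standing in for the real response message objects (this is not the library's parser):

```python
import json
from types import SimpleNamespace

def parse_message(message):
    # Tool-call shape: JSON arguments attached to each call
    if getattr(message, "tool_calls", None):
        return [json.loads(tc.function.arguments) for tc in message.tool_calls]
    # Structured-output shape: a JSON string in message.content
    return [json.loads(message.content)]

# Structured-output example
print(parse_message(SimpleNamespace(content='{"summary": "hi"}', tool_calls=None)))

# Tool-call example
call = SimpleNamespace(function=SimpleNamespace(name="write_output", arguments='{"n": 1}'))
print(parse_message(SimpleNamespace(content=None, tool_calls=[call])))
```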

docs/advanced/custom-operators.md (whitespace-only changes)
docs/advanced/extending-agents.md (whitespace-only changes)
docs/advanced/performance-tuning.md (whitespace-only changes)
docs/api-reference/docetl.md (whitespace-only changes)
