@@ -6,6 +6,7 @@
 import threading
 from concurrent.futures import as_completed
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from openai import OpenAI
 
 from dotenv import load_dotenv
 from frozendict import frozendict
@@ -17,6 +18,7 @@
 from diskcache import Cache
 import tiktoken
 from rich import print as rprint
+from pydantic import BaseModel, create_model
 
 from docetl.utils import count_tokens
 
@@ -28,6 +30,8 @@
 LLM_CACHE_DIR = os.path.join(DOCETL_HOME_DIR, "llm_cache")
 cache = Cache(LLM_CACHE_DIR)
 
+client = OpenAI()
+
 
 def freezeargs(func):
     """
@@ -150,6 +154,49 @@ def clear_cache(console: Console = Console()):
         console.log(f"[bold red]Error clearing cache: {str(e)}[/bold red]")
 
 
+def create_dynamic_model(schema: Dict[str, Any], model_name: str = "DynamicModel"):
+    fields = {}
+
+    def process_schema(s: Dict[str, Any], prefix: str = "") -> None:
+        for key, value in s.items():
+            field_name = f"{prefix}__{key}" if prefix else key
+            if isinstance(value, dict):
+                process_schema(value, field_name)
+            else:
+                fields[field_name] = parse_type(value, field_name)
+
+    def parse_type(type_str: str, field_name: str) -> tuple:
+        type_str = type_str.strip().lower()
+        if type_str in ["str", "text", "string", "varchar"]:
+            return (str, ...)
+        elif type_str in ["int", "integer"]:
+            return (int, ...)
+        elif type_str in ["float", "decimal", "number"]:
+            return (float, ...)
+        elif type_str in ["bool", "boolean"]:
+            return (bool, ...)
+        elif type_str.startswith("list["):
+            inner_type = type_str[5:-1].strip()
+            item_type = parse_type(inner_type, f"{field_name}_item")[0]
+            return (List[item_type], ...)
+        elif type_str == "list":
+            return (List[Any], ...)
+        elif type_str.startswith("{") and type_str.endswith("}"):
+            subfields = {}
+            for item in type_str[1:-1].split(","):
+                sub_key, sub_type = item.strip().split(":")
+                subfields[sub_key.strip()] = parse_type(
+                    sub_type.strip(), f"{field_name}_{sub_key}"
+                )
+            SubModel = create_model(f"{model_name}_{field_name}", **subfields)
+            return (SubModel, ...)
+        else:
+            return (Any, ...)
+
+    process_schema(schema)
+    return create_model(model_name, **fields)
+
+
 def convert_val(value: Any) -> Dict[str, Any]:
     """
     Convert a string representation of a type to a dictionary representation.
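
For a rough sense of what the new create_dynamic_model helper returns, the sketch below feeds it a made-up output schema (the keys, types, and OutputModel name are illustrative, not from this PR) and assumes the function is imported from this module:

    # Hypothetical schema; nested dicts are flattened into double-underscore
    # field names, and "list[...]" strings map to typing.List[...].
    schema = {
        "title": "str",
        "tags": "list[str]",
        "stats": {"views": "int", "score": "float"},
    }

    OutputModel = create_dynamic_model(schema, model_name="OutputModel")
    # Generated fields: title: str, tags: List[str], stats__views: int, stats__score: float
    row = OutputModel(title="Example", tags=["a", "b"], stats__views=3, stats__score=0.9)
    print(row)

Note that unrecognized type strings fall back to Any, so a schema typo degrades quietly instead of raising.
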
@@ -419,47 +466,62 @@ def call_llm_with_cache(
         parameters["required"] = list(props.keys())
         parameters["additionalProperties"] = False
 
-        tools = [
-            {
-                "type": "function",
-                "function": {
-                    "name": "write_output",
-                    "description": "Write processing output to a database",
-                    "strict": True,
-                    "parameters": parameters,
-                    "additionalProperties": False,
-                },
-            }
-        ]
-        tool_choice = {"type": "function", "function": {"name": "write_output"}}
+        response_format = {
+            "type": "json_schema",
+            "json_schema": {
+                "name": "write_output",
+                "description": "Write task output to a database",
+                "strict": True,
+                "schema": parameters,
+                # "additionalProperties": False,
+            },
+        }
+
+        tools = []
+        # tool_choice = {"type": "function", "function": {"name": "write_output"}}
+        tool_choice = "auto"
 
     else:
         tools = json.loads(tools)
         tool_choice = (
             "required" if any(tool.get("required", False) for tool in tools) else "auto"
         )
         tools = [{"type": "function", "function": tool["function"]} for tool in tools]
+        response_format = None
 
-    system_prompt = f"You are a helpful assistant, intelligently processing data. This is a {op_type} operation."
+    system_prompt = f"You are a helpful assistant, intelligently processing data. This is a {op_type} operation. You will perform the task on the user-provided data and write the output to a database."
     if scratchpad:
         system_prompt += f"\n\nYou are incrementally processing data across multiple batches. Your task is to {op_type} the data. Consider what intermediate state you need to maintain between batches to accomplish this task effectively.\n\nYour current scratchpad contains: {scratchpad}\n\nAs you process each batch, update your scratchpad with information crucial for processing subsequent batches. This may include partial results, counters, or any other relevant data that doesn't fit into {output_schema.keys()}. For example, if you're counting occurrences, track items that have appeared once.\n\nKeep your scratchpad concise (~500 chars) and use a format you can easily parse in future batches. You may use bullet points, key-value pairs, or any other clear structure."
     messages = json.loads(messages)
 
     # Truncate messages if they exceed the model's context length
     messages = truncate_messages(messages, model)
 
-    response = completion(
-        model=model,
-        messages=[
-            {
-                "role": "system",
-                "content": system_prompt,
-            },
-        ]
-        + messages,
-        tools=tools,
-        tool_choice=tool_choice,
-    )
+    if response_format is None:
+        response = completion(
+            model=model,
+            messages=[
+                {
+                    "role": "system",
+                    "content": system_prompt,
+                },
+            ]
+            + messages,
+            tools=tools,
+            tool_choice=tool_choice,
+        )
+    else:
+        response = completion(
+            model=model,
+            messages=[
+                {
+                    "role": "system",
+                    "content": system_prompt,
+                },
+            ]
+            + messages,
+            response_format=response_format,
+        )
 
     return response
 
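
To make the change concrete: instead of forcing a write_output tool call, this path now requests OpenAI-style structured outputs. A minimal standalone request in the same shape, assuming completion here is litellm's and that the model name and schema below are placeholders:

    from litellm import completion

    parameters = {
        "type": "object",
        "properties": {"summary": {"type": "string"}},
        "required": ["summary"],
        "additionalProperties": False,
    }

    response = completion(
        model="gpt-4o-mini",  # placeholder model
        messages=[{"role": "user", "content": "Summarize: the cat sat on the mat."}],
        response_format={
            "type": "json_schema",
            "json_schema": {"name": "write_output", "strict": True, "schema": parameters},
        },
    )
    # With strict json_schema output, the JSON arrives in message.content rather
    # than in tool_calls, which is why parse_llm_response gains a second branch below.
    print(response.choices[0].message.content)
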
@@ -612,22 +674,20 @@ def call_llm_with_gleaning(
         messages.append({"role": "user", "content": improvement_prompt})
 
         # Call LLM for improvement
+        # TODO: support gleaning and tools
         response = completion(
             model=model,
             messages=truncate_messages(messages, model),
-            tools=[
-                {
-                    "type": "function",
-                    "function": {
-                        "name": "write_output",
-                        "description": "Write processing output to a database",
-                        "strict": True,
-                        "parameters": parameters,
-                        "additionalProperties": False,
-                    },
-                }
-            ],
-            tool_choice={"type": "function", "function": {"name": "write_output"}},
+            response_format={
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "write_output",
+                    "description": "Write processing output to a database",
+                    "strict": True,
+                    "schema": parameters,
+                    # "additionalProperties": False,
+                },
+            },
         )
 
         # Update messages with the new response
@@ -682,16 +742,20 @@ def parse_llm_response(
             results.append(function_args)
         return results
     else:
-        # Default behavior for write_output function
-        tool_calls = response.choices[0].message.tool_calls
-        outputs = []
-        for tool_call in tool_calls:
-            if tool_call.function.name == "write_output":
-                try:
-                    outputs.append(json.loads(tool_call.function.arguments))
-                except json.JSONDecodeError:
-                    return [{}]
-        return outputs
+        if "tool_calls" in response.choices[0].message:
+            # Default behavior for write_output function
+            tool_calls = response.choices[0].message.tool_calls
+            outputs = []
+            for tool_call in tool_calls:
+                if tool_call.function.name == "write_output":
+                    try:
+                        outputs.append(json.loads(tool_call.function.arguments))
+                    except json.JSONDecodeError:
+                        return [{}]
+            return outputs
+
+        else:
+            return [json.loads(response.choices[0].message.content)]
 
     # message = response.choices[0].message
     # return [json.loads(message.content)]
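
As a cross-check on the new branching, here is a hypothetical helper (extract_output is not in the codebase) that mirrors the updated logic: prefer tool_calls when the model used function calling, otherwise parse the structured-output JSON from message.content:

    import json

    def extract_output(response):
        message = response.choices[0].message
        tool_calls = getattr(message, "tool_calls", None)
        if tool_calls:
            # Function-calling path: arguments are a JSON string per tool call.
            return [
                json.loads(tc.function.arguments)
                for tc in tool_calls
                if tc.function.name == "write_output"
            ]
        # Structured-output path: the whole message body is the JSON payload.
        return [json.loads(message.content)]
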
|
0 commit comments