Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 52 additions & 33 deletions lumen/ai/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import dataclass, field
from typing import Any

import yaml

from ..sources import Source
from .config import SOURCE_TABLE_SEPARATOR
from .utils import (
Expand Down Expand Up @@ -40,48 +42,52 @@ class VectorMetaset:

def _generate_context(self, include_columns: bool = False, truncate: bool = False) -> str:
"""
Generate formatted text representation of the context.
Generate YAML formatted representation of the context.

Args:
include_columns: Whether to include column details in the context
truncate: Whether to truncate strings and columns for brevity
"""
context = ""
tables_data = {}

for table_slug, vector_metadata in self.vector_metadata_map.items():
base_sql = truncate_string(vector_metadata.base_sql, max_length=200) if truncate else vector_metadata.base_sql
context += f"{table_slug!r} (access this table with: {base_sql})\n"

table_data = {'read_with': base_sql}

if vector_metadata.description:
desc = truncate_string(vector_metadata.description, max_length=100) if truncate else vector_metadata.description
context += f"Info: {desc}\n"
table_data['info'] = desc

# Only include columns if explicitly requested
if include_columns and vector_metadata.columns:
max_length = 20
cols_to_show = vector_metadata.columns

show_ellipsis = False
original_indices = []
if truncate:
cols_to_show, original_indices, show_ellipsis = truncate_iterable(cols_to_show, max_length)
else:
cols_to_show = list(cols_to_show)
original_indices = list(range(len(cols_to_show)))

for i, (col, orig_idx) in enumerate(zip(cols_to_show, original_indices, strict=False)):
columns_data = {}
for i, col in enumerate(cols_to_show):
if show_ellipsis and i == len(cols_to_show) // 2:
context += "...\n"

if i == 0:
context += "Cols:\n"
columns_data['...'] = '...'

col_name = truncate_string(col.name) if truncate else col.name
context += f"{orig_idx}. {col_name!r}"
if col.description:
col_desc = truncate_string(col.description, max_length=100) if truncate else col.description
context += f": {col_desc}"
context += "\n"
return context
columns_data[col_name] = col_desc
else:
columns_data[col_name] = None

if columns_data:
table_data['columns'] = columns_data

tables_data[table_slug] = table_data

return yaml.dump(tables_data, default_flow_style=False, allow_unicode=True, sort_keys=False)

@property
def table_context(self) -> str:
Expand Down Expand Up @@ -120,63 +126,76 @@ class SQLMetaset:

def _generate_context(self, include_columns: bool = False, truncate: bool = False) -> str:
"""
Generate formatted context with both vector and SQL data.
Generate YAML formatted context with both vector and SQL data.

Args:
include_columns: Whether to include column details in the context
truncate: Whether to truncate strings and columns for brevity

Returns:
Formatted context string
YAML formatted context string
"""
context = ""
tables_data = {}

for table_slug in self.sql_metadata_map.keys():
vector_metadata = self.vector_metaset.vector_metadata_map.get(table_slug)
if not vector_metadata:
continue

base_sql = truncate_string(vector_metadata.base_sql, max_length=200) if truncate else vector_metadata.base_sql
context += f"\n{table_slug!r} (access this table with: {base_sql})\n"

table_data = {'read_with': base_sql}

if vector_metadata.description:
desc = truncate_string(vector_metadata.description, max_length=100) if truncate else vector_metadata.description
context += f"Info: {desc}\n"
table_data['info'] = desc

sql_data: SQLMetadata = self.sql_metadata_map.get(table_slug)
if sql_data:
# Get the count from schema
if sql_data.schema.get("__len__"):
context += f"Row count: {len(sql_data.schema)}\n"
table_data['row_count'] = len(sql_data.schema)

# Only include columns if explicitly requested
if include_columns and vector_metadata.columns:
cols_to_show = vector_metadata.columns
context += "Columns:"
columns_data = {}

for col in cols_to_show:
schema_data = None
if sql_data and col.name in sql_data.schema:
schema_data = sql_data.schema[col.name]
if truncate and schema_data == "<null>":
continue

# Get column name
context += f"\n- {col.name}"
col_info = {}

# Get column description with optional truncation
if col.description:
col_desc = truncate_string(col.description, max_length=100) if truncate else col.description
context += f": {col_desc}"
else:
context += ": "
col_info['description'] = col_desc

# Add schema info for the column if available
if schema_data:
if truncate and schema_data.get('type') == 'enum':
schema_data = truncate_string(str(schema_data), max_length=50)
context += f" `{schema_data}`"
context += "\n"
return context.replace("'type': 'str', ", "") # Remove type info for lower token
if schema_data and schema_data != "<null>":
if isinstance(schema_data, dict):
# Remove 'type': 'str' for token efficiency
schema_copy = {k: v for k, v in schema_data.items() if not (k == 'type' and v == 'str')}
if truncate and schema_copy.get('type') == 'enum':
schema_str = str(schema_copy)
if len(schema_str) > 50:
schema_copy = truncate_string(schema_str, max_length=50)
col_info.update(schema_copy)
else:
col_info['value'] = schema_data

columns_data[col.name] = col_info if col_info else None

if columns_data:
table_data['columns'] = columns_data

tables_data[table_slug] = table_data

return yaml.dump(tables_data, default_flow_style=False, allow_unicode=True, sort_keys=False)

@property
def table_context(self) -> str:
Expand Down
47 changes: 38 additions & 9 deletions lumen/ai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pandas as pd
import param
import yaml

from jinja2 import (
ChoiceLoader, DictLoader, Environment, FileSystemLoader, StrictUndefined,
Expand Down Expand Up @@ -520,9 +521,9 @@ def describe_data_sync(df):
# Use the first column as index to save tokens
if len(df.columns) > 1:
df = df.set_index(df.columns[0])
return yaml.dump(df.to_dict('index'), default_flow_style=False, allow_unicode=True, sort_keys=False)
else:
df = df.to_dict("records")
return df
return yaml.dump(df.to_dict("records"), default_flow_style=False, allow_unicode=True, sort_keys=False)

is_sampled = False
if shape[0] > 5000:
Expand Down Expand Up @@ -598,7 +599,7 @@ def describe_data_sync(df):
head_sample = df.head(2).to_dict('records')
tail_sample = df.tail(2).to_dict('records')

return {
result = {
"summary": {
"n_cells": size,
"shape": shape,
Expand All @@ -610,9 +611,42 @@ def describe_data_sync(df):
"tail": tail_sample[0] if tail_sample else {},
}

return yaml.dump(result, default_flow_style=False, allow_unicode=True, sort_keys=False)

return await asyncio.to_thread(describe_data_sync, df)


def format_data_as_yaml(data: pd.DataFrame | dict, title: str = "Data Overview") -> str:
    """
    Format a DataFrame or data dictionary as YAML.

    Parameters
    ----------
    data : pd.DataFrame | dict
        The data to format. A DataFrame whose index has a (truthy) name and
        unique labels is emitted as a mapping keyed by index label. A
        DataFrame with a named but non-unique index is emitted as a list of
        row records with the index restored as a regular column. Any other
        DataFrame is emitted as a list of row records. A dictionary is
        emitted unchanged.
    title : str
        The top-level key under which the data is nested in the YAML output.

    Returns
    -------
    str
        YAML formatted string representation of the data.
    """
    if isinstance(data, pd.DataFrame):
        index_name = data.index.name
        if not index_name:
            # Unnamed index carries no information worth keying on.
            data_dict = data.to_dict('records')
        elif data.index.is_unique:
            data_dict = data.to_dict('index')
        else:
            # to_dict('index') raises ValueError on duplicate index labels,
            # so fall back to records. Keep the index as a column unless a
            # column of the same name already exists (reset_index would
            # refuse to insert it).
            name_collides = index_name in data.columns
            data_dict = data.reset_index(drop=name_collides).to_dict('records')
    else:
        data_dict = data

    # Nest everything under the title so the output has a single root key.
    output = {title: data_dict}

    return yaml.dump(output, default_flow_style=False, allow_unicode=True, sort_keys=False)


def clean_sql(sql_expr: str, dialect: str | None = None) -> str:
"""
Cleans up a SQL expression generated by an LLM by removing
Expand Down Expand Up @@ -838,20 +872,15 @@ def truncate_iterable(iterable, max_length=150) -> tuple[list, list, bool]:
iterable_list = list(iterable)
if len(iterable_list) > max_length:
half = max_length // 2
first_half_indices = list(range(half))
second_half_indices = list(range(len(iterable_list) - half, len(iterable_list)))

first_half_items = iterable_list[:half]
second_half_items = iterable_list[-half:]

cols_to_show = first_half_items + second_half_items
original_indices = first_half_indices + second_half_indices
show_ellipsis = True
else:
cols_to_show = iterable_list
original_indices = list(range(len(iterable_list)))
show_ellipsis = False
return cols_to_show, original_indices, show_ellipsis
return cols_to_show, show_ellipsis


async def with_timeout(coro, timeout_seconds=10, default_value=None, error_message=None):
Expand Down
Loading