Skip to content

Commit 77e4af2

Browse files
authored
Serialize data to yaml (#1449)
1 parent eef3689 commit 77e4af2

File tree

2 files changed

+90
-42
lines changed

2 files changed

+90
-42
lines changed

lumen/ai/schemas.py

Lines changed: 52 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@
33
from dataclasses import dataclass, field
44
from typing import Any
55

6+
import yaml
7+
68
from ..sources import Source
79
from .config import SOURCE_TABLE_SEPARATOR
810
from .utils import (
@@ -40,48 +42,52 @@ class VectorMetaset:
4042

4143
def _generate_context(self, include_columns: bool = False, truncate: bool = False) -> str:
    """
    Generate YAML formatted representation of the context.

    Each entry of ``self.vector_metadata_map`` becomes one top-level YAML
    key (the table slug) whose mapping holds ``read_with`` (the base SQL),
    an optional ``info`` description and, on request, a ``columns`` mapping
    of column name -> description (``null`` when a column has none).

    Args:
        include_columns: Whether to include column details in the context
        truncate: Whether to truncate strings and columns for brevity

    Returns:
        YAML formatted context string
    """
    tables_data = {}

    for table_slug, vector_metadata in self.vector_metadata_map.items():
        base_sql = truncate_string(vector_metadata.base_sql, max_length=200) if truncate else vector_metadata.base_sql

        table_data = {'read_with': base_sql}

        if vector_metadata.description:
            desc = truncate_string(vector_metadata.description, max_length=100) if truncate else vector_metadata.description
            table_data['info'] = desc

        # Only include columns if explicitly requested
        if include_columns and vector_metadata.columns:
            max_length = 20
            cols_to_show = vector_metadata.columns

            show_ellipsis = False
            if truncate:
                # truncate_iterable no longer returns the original indices;
                # the starred target accepts both the new
                # (items, show_ellipsis) and the former
                # (items, indices, show_ellipsis) return shapes, so this
                # call cannot fail with an unpacking ValueError.
                cols_to_show, *_, show_ellipsis = truncate_iterable(cols_to_show, max_length)
            else:
                cols_to_show = list(cols_to_show)

            columns_data = {}
            for i, col in enumerate(cols_to_show):
                # Mark the gap between the kept head and tail halves.
                if show_ellipsis and i == len(cols_to_show) // 2:
                    columns_data['...'] = '...'

                col_name = truncate_string(col.name) if truncate else col.name
                if col.description:
                    col_desc = truncate_string(col.description, max_length=100) if truncate else col.description
                    columns_data[col_name] = col_desc
                else:
                    columns_data[col_name] = None

            if columns_data:
                table_data['columns'] = columns_data

        tables_data[table_slug] = table_data

    return yaml.dump(tables_data, default_flow_style=False, allow_unicode=True, sort_keys=False)
8591

8692
@property
8793
def table_context(self) -> str:
@@ -120,63 +126,76 @@ class SQLMetaset:
120126

121127
def _generate_context(self, include_columns: bool = False, truncate: bool = False) -> str:
    """
    Generate YAML formatted context with both vector and SQL data.

    Only tables present in ``self.sql_metadata_map`` that also have an entry
    in ``self.vector_metaset.vector_metadata_map`` are emitted. Each table
    maps to ``read_with``, optional ``info``, optional ``row_count`` and,
    on request, a ``columns`` mapping of column name -> details.

    Args:
        include_columns: Whether to include column details in the context
        truncate: Whether to truncate strings and columns for brevity

    Returns:
        YAML formatted context string
    """
    tables_data = {}

    for table_slug, sql_data in self.sql_metadata_map.items():
        vector_metadata = self.vector_metaset.vector_metadata_map.get(table_slug)
        if not vector_metadata:
            continue

        base_sql = truncate_string(vector_metadata.base_sql, max_length=200) if truncate else vector_metadata.base_sql

        table_data = {'read_with': base_sql}

        if vector_metadata.description:
            desc = truncate_string(vector_metadata.description, max_length=100) if truncate else vector_metadata.description
            table_data['info'] = desc

        if sql_data:
            # Get the count from schema: report the value stored under
            # "__len__", not len(schema), which is only the number of
            # schema keys.
            # NOTE(review): assumes "__len__" holds the table's row count
            # as a plain int -- confirm against the schema producer.
            row_count = sql_data.schema.get("__len__")
            if row_count:
                table_data['row_count'] = row_count

        # Only include columns if explicitly requested
        if include_columns and vector_metadata.columns:
            columns_data = {}

            for col in vector_metadata.columns:
                schema_data = None
                if sql_data and col.name in sql_data.schema:
                    schema_data = sql_data.schema[col.name]
                    # All-null columns carry no information; drop them
                    # when brevity was requested.
                    if truncate and schema_data == "<null>":
                        continue

                col_info = {}

                # Get column description with optional truncation
                if col.description:
                    col_desc = truncate_string(col.description, max_length=100) if truncate else col.description
                    col_info['description'] = col_desc

                # Add schema info for the column if available
                if schema_data and schema_data != "<null>":
                    if isinstance(schema_data, dict):
                        # Remove 'type': 'str' for token efficiency
                        schema_copy = {k: v for k, v in schema_data.items() if not (k == 'type' and v == 'str')}
                        truncated_schema = None
                        if truncate and schema_copy.get('type') == 'enum':
                            schema_str = str(schema_copy)
                            if len(schema_str) > 50:
                                truncated_schema = truncate_string(schema_str, max_length=50)
                        if truncated_schema is None:
                            col_info.update(schema_copy)
                        else:
                            # A truncated schema is a plain string and
                            # cannot be merged with dict.update (that
                            # raises ValueError); keep it under its own key.
                            col_info['schema'] = truncated_schema
                    else:
                        col_info['value'] = schema_data

                columns_data[col.name] = col_info if col_info else None

            if columns_data:
                table_data['columns'] = columns_data

        tables_data[table_slug] = table_data

    return yaml.dump(tables_data, default_flow_style=False, allow_unicode=True, sort_keys=False)
180199

181200
@property
182201
def table_context(self) -> str:

lumen/ai/utils.py

Lines changed: 38 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222

2323
import pandas as pd
2424
import param
25+
import yaml
2526

2627
from jinja2 import (
2728
ChoiceLoader, DictLoader, Environment, FileSystemLoader, StrictUndefined,
@@ -520,9 +521,9 @@ def describe_data_sync(df):
520521
# Use the first column as index to save tokens
521522
if len(df.columns) > 1:
522523
df = df.set_index(df.columns[0])
524+
return yaml.dump(df.to_dict('index'), default_flow_style=False, allow_unicode=True, sort_keys=False)
523525
else:
524-
df = df.to_dict("records")
525-
return df
526+
return yaml.dump(df.to_dict("records"), default_flow_style=False, allow_unicode=True, sort_keys=False)
526527

527528
is_sampled = False
528529
if shape[0] > 5000:
@@ -598,7 +599,7 @@ def describe_data_sync(df):
598599
head_sample = df.head(2).to_dict('records')
599600
tail_sample = df.tail(2).to_dict('records')
600601

601-
return {
602+
result = {
602603
"summary": {
603604
"n_cells": size,
604605
"shape": shape,
@@ -610,9 +611,42 @@ def describe_data_sync(df):
610611
"tail": tail_sample[0] if tail_sample else {},
611612
}
612613

614+
return yaml.dump(result, default_flow_style=False, allow_unicode=True, sort_keys=False)
615+
613616
return await asyncio.to_thread(describe_data_sync, df)
614617

615618

619+
def format_data_as_yaml(data: pd.DataFrame | dict, title: str = "Data Overview") -> str:
    """
    Render a DataFrame or a dictionary as a YAML document under one title key.

    Parameters
    ----------
    data : pd.DataFrame | dict
        The data to format. A DataFrame is converted with ``to_dict``:
        keyed by index when the index has a (truthy) name, otherwise as a
        list of row records. Any other value is emitted as given.
    title : str
        The single top-level key the data is nested under.

    Returns
    -------
    str
        YAML formatted string representation of the data.
    """
    payload = data
    if isinstance(data, pd.DataFrame):
        # A named index is meaningful, so use it as the mapping keys;
        # an unnamed index is dropped in favour of plain row records.
        orient = 'index' if data.index.name else 'records'
        payload = data.to_dict(orient)

    # Wrap in title
    return yaml.dump({title: payload}, default_flow_style=False, allow_unicode=True, sort_keys=False)
648+
649+
616650
def clean_sql(sql_expr: str, dialect: str | None = None) -> str:
617651
"""
618652
Cleans up a SQL expression generated by an LLM by removing
@@ -838,20 +872,15 @@ def truncate_iterable(iterable, max_length=150) -> tuple[list, list, bool]:
838872
iterable_list = list(iterable)
839873
if len(iterable_list) > max_length:
840874
half = max_length // 2
841-
first_half_indices = list(range(half))
842-
second_half_indices = list(range(len(iterable_list) - half, len(iterable_list)))
843-
844875
first_half_items = iterable_list[:half]
845876
second_half_items = iterable_list[-half:]
846877

847878
cols_to_show = first_half_items + second_half_items
848-
original_indices = first_half_indices + second_half_indices
849879
show_ellipsis = True
850880
else:
851881
cols_to_show = iterable_list
852-
original_indices = list(range(len(iterable_list)))
853882
show_ellipsis = False
854-
return cols_to_show, original_indices, show_ellipsis
883+
return cols_to_show, show_ellipsis
855884

856885

857886
async def with_timeout(coro, timeout_seconds=10, default_value=None, error_message=None):

0 commit comments

Comments
 (0)