Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,30 +181,47 @@ def extract_sql(self, llm_response: str) -> str:
str: The extracted SQL query.
"""

# If the llm_response contains a CTE (WITH clause), extract the last SQL statement spanning from WITH to the terminating semicolon
sqls = re.findall(r"\bWITH\b .*?;", llm_response, re.DOTALL)
import re
"""
Extracts the SQL query from the LLM response, handling various formats including:
- WITH clause
- SELECT statement
- CREATE TABLE AS SELECT
- Markdown code blocks
"""

# Match CREATE TABLE ... AS SELECT
sqls = re.findall(r"\bCREATE\s+TABLE\b.*?\bAS\b.*?;", llm_response, re.DOTALL | re.IGNORECASE)
if sqls:
sql = sqls[-1]
self.log(title="Extracted SQL", message=f"{sql}")
return sql

# If the llm_response is not Markdown-formatted, extract the last SQL statement by finding SELECT and the terminating semicolon in the response
sqls = re.findall(r"SELECT.*?;", llm_response, re.DOTALL)
# Match WITH clause (CTEs)
sqls = re.findall(r"\bWITH\b .*?;", llm_response, re.DOTALL | re.IGNORECASE)
if sqls:
sql = sqls[-1]
self.log(title="Extracted SQL", message=f"{sql}")
return sql

# If the llm_response contains a Markdown code block, with or without the sql tag, extract the last SQL statement from it
sqls = re.findall(r"```sql\n(.*)```", llm_response, re.DOTALL)
# Match SELECT ... ;
sqls = re.findall(r"\bSELECT\b .*?;", llm_response, re.DOTALL | re.IGNORECASE)
if sqls:
sql = sqls[-1]
self.log(title="Extracted SQL", message=f"{sql}")
return sql

sqls = re.findall(r"```(.*)```", llm_response, re.DOTALL)
# Match ```sql ... ``` blocks
sqls = re.findall(r"```sql\s*\n(.*?)```", llm_response, re.DOTALL | re.IGNORECASE)
if sqls:
sql = sqls[-1]
sql = sqls[-1].strip()
self.log(title="Extracted SQL", message=f"{sql}")
return sql

# Match any ``` ... ``` code blocks
sqls = re.findall(r"```(.*?)```", llm_response, re.DOTALL | re.IGNORECASE)
if sqls:
sql = sqls[-1].strip()
self.log(title="Extracted SQL", message=f"{sql}")
return sql

Expand Down