Skip to content

Commit 66f6666

Browse files
committed
feat: add trino sqlalchemy dialect
Signed-off-by: Đặng Minh Dũng <[email protected]>
1 parent 1c7149f commit 66f6666

File tree

6 files changed

+648
-3
lines changed

6 files changed

+648
-3
lines changed

trino/dbapi.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,44 @@
2929
import trino.exceptions
3030
import trino.client
3131
import trino.logging
32-
from trino.transaction import Transaction, IsolationLevel, NO_TRANSACTION
33-
32+
from trino.transaction import (
33+
Transaction,
34+
IsolationLevel,
35+
NO_TRANSACTION
36+
)
37+
from trino.exceptions import (
38+
Warning,
39+
Error,
40+
InterfaceError,
41+
DatabaseError,
42+
DataError,
43+
OperationalError,
44+
IntegrityError,
45+
InternalError,
46+
ProgrammingError,
47+
NotSupportedError,
48+
)
3449

35-
__all__ = ["connect", "Connection", "Cursor"]
50+
__all__ = [
51+
# https://www.python.org/dev/peps/pep-0249/#globals
52+
"apilevel",
53+
"threadsafety",
54+
"paramstyle",
55+
"connect",
56+
"Connection",
57+
"Cursor",
58+
# https://www.python.org/dev/peps/pep-0249/#exceptions
59+
"Warning",
60+
"Error",
61+
"InterfaceError",
62+
"DatabaseError",
63+
"DataError",
64+
"OperationalError",
65+
"IntegrityError",
66+
"InternalError",
67+
"ProgrammingError",
68+
"NotSupportedError",
69+
]
3670

3771

3872
apilevel = "2.0"

trino/sqlalchemy/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
from sqlalchemy.dialects import registry
13+
14+
registry.register("trino", "trino.sqlalchemy.dialect.TrinoDialect", "TrinoDialect")

trino/sqlalchemy/compiler.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
from sqlalchemy.sql import compiler
13+
14+
# https://trino.io/docs/current/language/reserved.html
15+
RESERVED_WORDS = {
16+
"alter",
17+
"and",
18+
"as",
19+
"between",
20+
"by",
21+
"case",
22+
"cast",
23+
"constraint",
24+
"create",
25+
"cross",
26+
"cube",
27+
"current_date",
28+
"current_path",
29+
"current_role",
30+
"current_time",
31+
"current_timestamp",
32+
"current_user",
33+
"deallocate",
34+
"delete",
35+
"describe",
36+
"distinct",
37+
"drop",
38+
"else",
39+
"end",
40+
"escape",
41+
"except",
42+
"execute",
43+
"exists",
44+
"extract",
45+
"false",
46+
"for",
47+
"from",
48+
"full",
49+
"group",
50+
"grouping",
51+
"having",
52+
"in",
53+
"inner",
54+
"insert",
55+
"intersect",
56+
"into",
57+
"is",
58+
"join",
59+
"left",
60+
"like",
61+
"localtime",
62+
"localtimestamp",
63+
"natural",
64+
"normalize",
65+
"not",
66+
"null",
67+
"on",
68+
"or",
69+
"order",
70+
"outer",
71+
"prepare",
72+
"recursive",
73+
"right",
74+
"rollup",
75+
"select",
76+
"table",
77+
"then",
78+
"true",
79+
"uescape",
80+
"union",
81+
"unnest",
82+
"using",
83+
"values",
84+
"when",
85+
"where",
86+
"with",
87+
}
88+
89+
90+
class TrinoSQLCompiler(compiler.SQLCompiler):
    """SQL statement compiler for the Trino dialect.

    Placeholder: defines no overrides, so all behavior is inherited from
    ``compiler.SQLCompiler``.
    """
92+
93+
94+
class TrinoDDLCompiler(compiler.DDLCompiler):
    """DDL compiler for the Trino dialect.

    Placeholder: defines no overrides, so all behavior is inherited from
    ``compiler.DDLCompiler``.
    """
96+
97+
98+
class TrinoTypeCompiler(compiler.GenericTypeCompiler):
    """Type compiler for the Trino dialect.

    Placeholder: defines no overrides, so all behavior is inherited from
    ``compiler.GenericTypeCompiler``.
    """
100+
101+
102+
class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
    """Identifier preparer for the Trino dialect.

    Only the reserved-word set is customised; everything else is inherited
    from ``compiler.IdentifierPreparer``.
    """

    # Module-level set of lowercase Trino reserved words.
    reserved_words = RESERVED_WORDS

trino/sqlalchemy/datatype.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
# Licensed under the Apache License, Version 2.0 (the "License");
2+
# you may not use this file except in compliance with the License.
3+
# You may obtain a copy of the License at
4+
#
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
import re
13+
from typing import Dict, Iterator, Type, Union
14+
15+
from sqlalchemy import util
16+
from sqlalchemy.sql import sqltypes
17+
from sqlalchemy.sql.type_api import TypeEngine
18+
19+
# https://trino.io/docs/current/language/types.html
20+
_type_map = {
21+
# === Boolean ===
22+
'boolean': sqltypes.BOOLEAN,
23+
24+
# === Integer ===
25+
'tinyint': sqltypes.SMALLINT,
26+
'smallint': sqltypes.SMALLINT,
27+
'integer': sqltypes.INTEGER,
28+
'bigint': sqltypes.BIGINT,
29+
30+
# === Floating-point ===
31+
'real': sqltypes.FLOAT,
32+
'double': sqltypes.FLOAT,
33+
34+
# === Fixed-precision ===
35+
'decimal': sqltypes.DECIMAL,
36+
37+
# === String ===
38+
'varchar': sqltypes.VARCHAR,
39+
'char': sqltypes.CHAR,
40+
'varbinary': sqltypes.VARBINARY,
41+
'json': sqltypes.JSON,
42+
43+
# === Date and time ===
44+
'date': sqltypes.DATE,
45+
'time': sqltypes.TIME,
46+
'timestamp': sqltypes.TIMESTAMP,
47+
48+
# 'interval year to month':
49+
# 'interval day to second':
50+
#
51+
# === Structural ===
52+
# 'array': ARRAY,
53+
# 'map': MAP
54+
# 'row': ROW
55+
#
56+
# === Mixed ===
57+
# 'ipaddress': IPADDRESS
58+
# 'uuid': UUID,
59+
# 'hyperloglog': HYPERLOGLOG,
60+
# 'p4hyperloglog': P4HYPERLOGLOG,
61+
# 'qdigest': QDIGEST,
62+
# 'tdigest': TDIGEST,
63+
}
64+
65+
SQLType = Union[TypeEngine, Type[TypeEngine]]
66+
67+
68+
class MAP(TypeEngine):
    """SQLAlchemy type describing a map with a typed key and a typed value.

    :param key_type: type of the map keys, given either as a ``TypeEngine``
        instance or as a ``TypeEngine`` class (instantiated with no arguments)
    :param value_type: type of the map values, same conventions as ``key_type``
    """

    __visit_name__ = "MAP"

    def __init__(self, key_type: SQLType, value_type: SQLType):
        # Accept both type classes and type instances; always store instances.
        self.key_type: TypeEngine = (
            key_type() if isinstance(key_type, type) else key_type
        )
        self.value_type: TypeEngine = (
            value_type() if isinstance(value_type, type) else value_type
        )

    @property
    def python_type(self):
        """Python type used to represent values of this SQL type."""
        return dict
83+
84+
85+
class ROW(TypeEngine):
    """SQLAlchemy type describing a row made of named, typed attributes.

    :param attr_types: mapping of attribute name to its type, where each type
        is either a ``TypeEngine`` instance or a ``TypeEngine`` class
        (instantiated with no arguments)
    """

    __visit_name__ = "ROW"

    def __init__(self, attr_types: Dict[str, SQLType]):
        # Build a fresh dict instead of writing the instantiated types back
        # into the caller's mapping, so the argument is never mutated.
        self.attr_types: Dict[str, TypeEngine] = {
            name: attr_type() if isinstance(attr_type, type) else attr_type
            for name, attr_type in attr_types.items()
        }

    @property
    def python_type(self):
        """Python type used to represent values of this SQL type."""
        return dict
98+
99+
100+
def split(string: str, delimiter: str = ',',
          quote: str = '"', escaped_quote: str = r'\"',
          open_bracket: str = '(', close_bracket: str = ')') -> Iterator[str]:
    """
    A split function that is aware of quotes and brackets/parentheses.

    The string is only split on delimiters that are outside of any quoted
    section and outside of any bracket pair. Brackets that appear inside a
    quoted section are treated as plain text and do not affect nesting.
    The pieces are yielded unmodified (no stripping); an empty input yields
    a single empty string.

    :param string: string to split
    :param delimiter: single character defining where to split, usually a
        comma or space (the input is scanned character by character)
    :param quote: string, either a single or a double quote
    :param escaped_quote: string representing an escaped quote
    :param open_bracket: string, either [, {, < or (
    :param close_bracket: string, either ], }, > or )
    """
    depth = 0
    in_quotes = False
    start = 0
    for pos, character in enumerate(string):
        if depth == 0 and not in_quotes and character == delimiter:
            yield string[start:pos]
            start = pos + len(delimiter)
        elif not in_quotes and character == open_bracket:
            depth += 1
        elif not in_quotes and character == close_bracket:
            depth -= 1
        elif character == quote:
            if not in_quotes:
                in_quotes = True
            # A quote closes the section unless the text up to and including
            # it ends with the escaped-quote sequence. endswith() with an
            # explicit end avoids a negative slice start near the beginning.
            elif not string.endswith(escaped_quote, 0, pos + 1):
                in_quotes = False
    yield string[start:]
131+
132+
133+
def parse_sqltype(type_str: str) -> TypeEngine:
    """
    Convert a type name string into a SQLAlchemy type object.

    The input is stripped and lower-cased before parsing. Handling by kind:

    * ``array(T)``: returns ``sqltypes.ARRAY``; an array of arrays is
      collapsed into a single ``ARRAY`` with an increased ``dimensions``.
    * ``map(K, V)``: returns ``MAP`` with recursively parsed key/value types.
    * ``row(name T, ...)``: returns ``ROW``; each field is split into its name
      (first token) and its type (the remainder, which may contain spaces
      such as ``timestamp with time zone``).
    * ``time`` / ``timestamp``: only the ``timezone`` flag is set, based on a
      trailing ``with time zone``.
    * any other name found in ``_type_map``: the parenthesised options are
      passed to the type class as integer positional arguments.

    :param type_str: type name, optionally followed by options in parentheses
    :return: the matching type, or ``sqltypes.NULLTYPE`` (after emitting a
        warning) when the string cannot be parsed or the type is unknown
    """
    type_str = type_str.strip().lower()
    match = re.match(r'^(?P<type>\w+)\s*(?:\((?P<options>.*)\))?', type_str)
    if not match:
        util.warn(f"Could not parse type name '{type_str}'")
        return sqltypes.NULLTYPE
    type_name = match.group("type")
    type_opts = match.group("options")

    # Structural types need their element types; without the parenthesised
    # part there is nothing to recurse into.
    if type_name in ("array", "map", "row") and type_opts is None:
        util.warn(f"Could not parse type name '{type_str}'")
        return sqltypes.NULLTYPE

    if type_name == "array":
        item_type = parse_sqltype(type_opts)
        if isinstance(item_type, sqltypes.ARRAY):
            # Flatten nested arrays into one ARRAY with more dimensions.
            dimensions = (item_type.dimensions or 1) + 1
            return sqltypes.ARRAY(item_type.item_type, dimensions=dimensions)
        return sqltypes.ARRAY(item_type)
    elif type_name == "map":
        key_type_str, value_type_str = split(type_opts)
        key_type = parse_sqltype(key_type_str)
        value_type = parse_sqltype(value_type_str)
        return MAP(key_type, value_type)
    elif type_name == "row":
        attr_types: Dict[str, SQLType] = {}
        for attr_str in split(type_opts):
            attr_str = attr_str.strip()
            # The first token is the field name; everything after it is the
            # type, which itself may consist of several words.
            name, *type_parts = split(attr_str, delimiter=' ')
            if not type_parts:
                # A field without a name cannot be stored in the
                # name -> type mapping that ROW expects.
                util.warn(f"Could not parse row field '{attr_str}'")
                return sqltypes.NULLTYPE
            attr_types[name] = parse_sqltype(' '.join(type_parts))
        return ROW(attr_types)

    if type_name not in _type_map:
        util.warn(f"Did not recognize type '{type_name}'")
        return sqltypes.NULLTYPE
    type_class = _type_map[type_name]
    if type_name in ('time', 'timestamp'):
        # TODO: handle time/timestamp(p) precision
        return type_class(timezone=type_str.endswith("with time zone"))
    type_args = [int(o.strip()) for o in type_opts.split(',')] if type_opts else []
    return type_class(*type_args)

0 commit comments

Comments
 (0)