|
| 1 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 2 | +# you may not use this file except in compliance with the License. |
| 3 | +# You may obtain a copy of the License at |
| 4 | +# |
| 5 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | +# |
| 7 | +# Unless required by applicable law or agreed to in writing, software |
| 8 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 9 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 10 | +# See the License for the specific language governing permissions and |
| 11 | +# limitations under the License. |
| 12 | +import re |
| 13 | +from typing import Dict, Iterator, Type, Union |
| 14 | + |
| 15 | +from sqlalchemy import util |
| 16 | +from sqlalchemy.sql import sqltypes |
| 17 | +from sqlalchemy.sql.type_api import TypeEngine |
| 18 | + |
# Trino type reference: https://trino.io/docs/current/language/types.html
#
# Lookup table from a lower-case Trino type name to the SQLAlchemy type class
# used to represent it. parse_sqltype() consults this table for every type
# name other than "array", "map" and "row", which it handles itself.
_type_map: Dict[str, Type[TypeEngine]] = {
    # --- Boolean ---
    "boolean": sqltypes.BOOLEAN,

    # --- Integer ---
    # Note: tinyint shares the SMALLINT mapping with smallint.
    "tinyint": sqltypes.SMALLINT,
    "smallint": sqltypes.SMALLINT,
    "integer": sqltypes.INTEGER,
    "bigint": sqltypes.BIGINT,

    # --- Floating-point ---
    # Note: both real and double are mapped to FLOAT.
    "real": sqltypes.FLOAT,
    "double": sqltypes.FLOAT,

    # --- Fixed-precision ---
    "decimal": sqltypes.DECIMAL,

    # --- String ---
    "varchar": sqltypes.VARCHAR,
    "char": sqltypes.CHAR,
    "varbinary": sqltypes.VARBINARY,
    "json": sqltypes.JSON,

    # --- Date and time ---
    "date": sqltypes.DATE,
    "time": sqltypes.TIME,
    "timestamp": sqltypes.TIMESTAMP,

    # Trino types with no entry in this table (yet):
    #   interval year to month, interval day to second,
    #   ipaddress, uuid, hyperloglog, p4hyperloglog, qdigest, tdigest
    # The structural types (array, map, row) are intentionally absent as
    # well; they are resolved directly in parse_sqltype().
}

# A SQLAlchemy type given either as an instance or as the (uninstantiated) class.
SQLType = Union[TypeEngine, Type[TypeEngine]]
| 66 | + |
| 67 | + |
class MAP(TypeEngine):
    """SQLAlchemy type describing a ``map`` with a key type and a value type.

    :param key_type: type of the map keys, as a type instance or a type class
    :param value_type: type of the map values, as a type instance or a type class

    Type classes are instantiated without arguments, so ``key_type`` and
    ``value_type`` on the instance are always type instances.
    """

    __visit_name__ = "MAP"

    def __init__(self, key_type: SQLType, value_type: SQLType):
        resolved_key = key_type() if isinstance(key_type, type) else key_type
        resolved_value = value_type() if isinstance(value_type, type) else value_type
        self.key_type: TypeEngine = resolved_key
        self.value_type: TypeEngine = resolved_value

    @property
    def python_type(self):
        """The Python type used to represent values of this SQL type."""
        return dict
| 83 | + |
| 84 | + |
class ROW(TypeEngine):
    """SQLAlchemy type describing a ``row`` made of named, typed attributes.

    :param attr_types: mapping of attribute name to its type, where each type
        is given as a type instance or a type class

    Type classes are instantiated without arguments, so every value in
    ``attr_types`` on the instance is a type instance. The mapping passed in
    by the caller is left untouched.
    """

    __visit_name__ = "ROW"

    def __init__(self, attr_types: Dict[str, SQLType]):
        # Build a fresh dict instead of assigning back into the argument, so
        # that constructing a ROW never modifies the caller's mapping.
        self.attr_types: Dict[str, TypeEngine] = {
            name: attr_type() if isinstance(attr_type, type) else attr_type
            for name, attr_type in attr_types.items()
        }

    @property
    def python_type(self):
        """The Python type used to represent values of this SQL type."""
        return dict
| 98 | + |
| 99 | + |
def split(string: str, delimiter: str = ',',
          quote: str = '"', escaped_quote: str = r'\"',
          open_bracket: str = '(', close_bracket: str = ')') -> Iterator[str]:
    """
    A split function that is aware of quotes and brackets/parentheses.

    The delimiter only splits the string when it occurs outside of any quoted
    section and outside of any bracket pair. Brackets that appear inside a
    quoted section are treated as plain text and do not affect nesting.
    The yielded pieces are not stripped, and an empty input yields a single
    empty string.

    :param string: string to split
    :param delimiter: single character defining where to split, usually a
        comma or space
    :param quote: string, either a single or a double quote
    :param escaped_quote: string representing an escaped quote
    :param open_bracket: string, either [, {, < or (
    :param close_bracket: string, either ], }, > or )
    :return: iterator over the pieces of ``string``
    """
    depth = 0          # number of currently open brackets
    in_quotes = False  # whether we are inside a quoted section
    start = 0          # index where the current piece begins
    for pos, char in enumerate(string):
        if in_quotes:
            # Inside quotes nothing is special except the closing quote;
            # a quote that terminates an escaped-quote sequence does not
            # close the section. endswith() with an explicit end bound
            # avoids negative slice indices near the start of the string.
            if char == quote and not string.endswith(escaped_quote, 0, pos + 1):
                in_quotes = False
            continue
        if depth == 0 and char == delimiter:
            yield string[start:pos]
            start = pos + len(delimiter)
        elif char == open_bracket:
            depth += 1
        elif char == close_bracket:
            depth -= 1
        elif char == quote:
            in_quotes = True
    yield string[start:]
| 131 | + |
| 132 | + |
def parse_sqltype(type_str: str) -> TypeEngine:
    """
    Convert a type string into a SQLAlchemy type instance.

    The string is stripped and lower-cased as a whole (this includes any row
    attribute names it contains) and then split into a type name and an
    optional parenthesised option list.

    * ``array(T)`` becomes ``sqltypes.ARRAY``; nested arrays are collapsed
      into a single ARRAY with an increased ``dimensions``.
    * ``map(K, V)`` becomes :class:`MAP`.
    * ``row(name T, ...)`` becomes :class:`ROW`. The attribute type may itself
      contain spaces, e.g. ``ts timestamp(3) with time zone``.
    * Any other name is looked up in ``_type_map``; integer options are passed
      as positional arguments (e.g. ``decimal(10, 2)``). For ``time`` and
      ``timestamp`` only the ``timezone`` flag is set, based on whether the
      string ends with ``with time zone``.

    :param type_str: type string to parse
    :return: the parsed type, or ``sqltypes.NULLTYPE`` (after a warning) when
        the string cannot be parsed or the type name is not recognized
    :raises ValueError: if a row attribute has no type, or a map does not
        have exactly two type arguments
    """
    type_str = type_str.strip().lower()
    match = re.match(r'^(?P<type>\w+)\s*(?:\((?P<options>.*)\))?', type_str)
    if not match:
        util.warn(f"Could not parse type name '{type_str}'")
        return sqltypes.NULLTYPE
    type_name = match.group("type")
    type_opts = match.group("options")

    if type_name == "array":
        item_type = parse_sqltype(type_opts)
        if isinstance(item_type, sqltypes.ARRAY):
            # array(array(T)) -> one ARRAY of T with an extra dimension
            dimensions = (item_type.dimensions or 1) + 1
            return sqltypes.ARRAY(item_type.item_type, dimensions=dimensions)
        return sqltypes.ARRAY(item_type)
    elif type_name == "map":
        key_type_str, value_type_str = split(type_opts)
        key_type = parse_sqltype(key_type_str)
        value_type = parse_sqltype(value_type_str)
        return MAP(key_type, value_type)
    elif type_name == "row":
        attr_types: Dict[str, SQLType] = {}
        for attr_str in split(type_opts):
            # The first space-separated token is the attribute name; all
            # remaining tokens belong to the type, which may contain spaces
            # of its own (e.g. "timestamp with time zone").
            name, *type_parts = split(attr_str.strip(), delimiter=' ')
            if not type_parts:
                raise ValueError(
                    f"Could not parse row attribute '{attr_str.strip()}': "
                    "expected '<name> <type>'"
                )
            attr_types[name] = parse_sqltype(' '.join(type_parts))
        return ROW(attr_types)

    if type_name not in _type_map:
        util.warn(f"Did not recognize type '{type_name}'")
        return sqltypes.NULLTYPE
    type_class = _type_map[type_name]
    if type_name in ('time', 'timestamp'):
        type_kwargs = dict(timezone=type_str.endswith("with time zone"))
        return type_class(**type_kwargs)  # TODO: handle time/timestamp(p) precision
    # Remaining types take their options as integers, e.g. decimal(10, 2).
    type_args = [int(o.strip()) for o in type_opts.split(',')] if type_opts else []
    return type_class(*type_args)
0 commit comments