14 changes: 3 additions & 11 deletions ci/make_geography_db.py
@@ -70,10 +70,7 @@ def make_geography_db(
     table = sa.Table(
         table_name,
         metadata,
-        *(
-            sa.Column(col_name, col_type)
-            for col_name, col_type in schema
-        ),
+        *(sa.Column(col_name, col_type) for col_name, col_type in schema),
     )
     table_columns = table.c.keys()
     post_parse = POST_PARSE_FUNCTIONS.get(table_name, toolz.identity)
@@ -82,16 +79,11 @@ def make_geography_db(
         table.create(bind=bind)
         bind.execute(
             table.insert().values(),
-            [
-                post_parse(dict(zip(table_columns, row)))
-                for row in data[table_name]
-            ],
+            [post_parse(dict(zip(table_columns, row))) for row in data[table_name]],
         )


-@click.command(
-    help="Create the geography SQLite database for the Ibis tutorial"
-)
+@click.command(help="Create the geography SQLite database for the Ibis tutorial")
 @click.option(
     "-d",
     "--output-directory",
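The insert above uses SQLAlchemy's executemany form: an insert construct plus a list of parameter dicts built with zip. A minimal, self-contained sketch of that pattern, assuming SQLAlchemy 1.4; the table name and rows here are illustrative, not the tutorial's actual data:

import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
metadata = sa.MetaData()
countries = sa.Table(
    "countries",
    metadata,
    sa.Column("name", sa.TEXT),
    sa.Column("iso_alpha2", sa.TEXT),
)
countries.create(bind=engine)

columns = countries.c.keys()
rows = [("France", "FR"), ("Japan", "JP")]

# Passing a list of dicts triggers an executemany-style insert;
# engine.begin() commits the transaction on exit.
with engine.begin() as conn:
    conn.execute(countries.insert(), [dict(zip(columns, row)) for row in rows])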
4 changes: 1 addition & 3 deletions docs/example.py
@@ -11,9 +11,7 @@
     .aggregate(
         num_investments=c.permalink.nunique(),
         acq_ipos=(
-            c.status.isin(("ipo", "acquired"))
-            .ifelse(c.permalink, ibis.NA)
-            .nunique()
+            c.status.isin(("ipo", "acquired")).ifelse(c.permalink, ibis.NA).nunique()
         ),
     )
     .mutate(acq_rate=lambda t: t.acq_ipos / t.num_investments)
8 changes: 2 additions & 6 deletions docs/sqlalchemy_example.py
@@ -12,16 +12,12 @@
             ).label("investor_name"),
             sa.func.count(c.c.permalink.distinct()).label("num_investments"),
             sa.func.count(
-                sa.case(
-                    [(c.status.in_(("ipo", "acquired")), c.c.permalink)]
-                ).distinct()
+                sa.case([(c.status.in_(("ipo", "acquired")), c.c.permalink)]).distinct()
             ).label("acq_ipos"),
         ]
     )
     .select_from(
-        c.join(
-            i, onclause=c.c.permalink == i.c.company_permalink, isouter=True
-        )
+        c.join(i, onclause=c.c.permalink == i.c.company_permalink, isouter=True)
     )
     .group_by(1)
     .order_by(sa.desc(2))
5 changes: 1 addition & 4 deletions gen_matrix.py
@@ -11,10 +11,7 @@ def get_backends():
     pyproject = tomli.loads(Path("pyproject.toml").read_text())
     backends = pyproject["tool"]["poetry"]["plugins"]["ibis.backends"]
     del backends["spark"]
-    return [
-        (backend, getattr(ibis, backend))
-        for backend in sorted(backends.keys())
-    ]
+    return [(backend, getattr(ibis, backend)) for backend in sorted(backends.keys())]


 def get_leaf_classes(op):
4 changes: 1 addition & 3 deletions ibis/__init__.py
@@ -43,9 +43,7 @@ def __getattr__(name: str) -> BaseBackend:
     the `ibis.backends` entrypoints. If successful, the `ibis.sqlite`
     attribute is "cached", so this function is only called the first time.
     """
-    entry_points = {
-        ep for ep in util.backend_entry_points() if ep.name == name
-    }
+    entry_points = {ep for ep in util.backend_entry_points() if ep.name == name}

     if not entry_points:
         msg = f"module 'ibis' has no attribute '{name}'. "
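The docstring in this hunk describes lazy backend loading via PEP 562's module-level __getattr__. A minimal sketch of the general technique, assuming Python 3.10+'s importlib.metadata.entry_points(group=...) signature; util.backend_entry_points is Ibis's own wrapper and is not reproduced here:

import sys
from importlib.metadata import entry_points


def __getattr__(name):
    # Look for a backend registered under the "ibis.backends" entry point group.
    eps = [ep for ep in entry_points(group="ibis.backends") if ep.name == name]
    if not eps:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    backend = eps[0].load()
    # Bind the loaded backend onto the module so this function only runs
    # the first time the attribute is accessed.
    setattr(sys.modules[__name__], name, backend)
    return backend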
16 changes: 4 additions & 12 deletions ibis/backends/base/__init__.py
@@ -266,9 +266,7 @@ def to_pyarrow(
             # so construct at one column table (if applicable)
             # then return the column _from_ the table
             table = pa.Table.from_batches(
-                self.to_pyarrow_batches(
-                    expr, params=params, limit=limit, **kwargs
-                )
+                self.to_pyarrow_batches(expr, params=params, limit=limit, **kwargs)
             )
         except ValueError:
             # The pyarrow batches iterator is empty so pass in an empty
@@ -430,9 +428,7 @@ def database(self, name: str | None = None) -> Database:
         Database
             A database object for the specified database.
         """
-        return self.database_class(
-            name=name or self.current_database, client=self
-        )
+        return self.database_class(name=name or self.current_database, client=self)

     @property
     @abc.abstractmethod
@@ -561,9 +557,7 @@ def register_options(cls) -> None:
             try:
                 setattr(options, backend_name, backend_options)
             except ValueError as e:
-                raise exc.BackendConfigurationNotRegistered(
-                    backend_name
-                ) from e
+                raise exc.BackendConfigurationNotRegistered(backend_name) from e

     def compile(
         self,
@@ -592,9 +586,7 @@ def add_operation(self, operation: ops.Node) -> Callable:
         ...     return 'NULL'
         """
         if not hasattr(self, 'compiler'):
-            raise RuntimeError(
-                'Only SQL-based backends support `add_operation`'
-            )
+            raise RuntimeError('Only SQL-based backends support `add_operation`')

         def decorator(translation_function: Callable) -> None:
             self.compiler.translator_class.add_operation(
23 changes: 6 additions & 17 deletions ibis/backends/base/sql/__init__.py
@@ -110,9 +110,7 @@ def sql(self, query: str) -> ir.Table:
         return ops.SQLQueryResult(query, schema, self).to_expr()

     def _get_schema_using_query(self, query):
-        raise NotImplementedError(
-            f"Backend {self.name} does not support .sql()"
-        )
+        raise NotImplementedError(f"Backend {self.name} does not support .sql()")

     def raw_sql(self, query: str) -> Any:
         """Execute a query string.
@@ -149,9 +147,7 @@ def _cursor_batches(
         limit: int | str | None = None,
         chunk_size: int = 1_000_000,
     ) -> Iterable[list]:
-        query_ast = self.compiler.to_ast_ensure_limit(
-            expr, limit, params=params
-        )
+        query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
         sql = query_ast.compile()

         with self._safe_raw_sql(sql) as cursor:
@@ -207,9 +203,7 @@ def _batches():
                 )
                 yield pa.RecordBatch.from_struct_array(struct_array)

-        return pa.RecordBatchReader.from_batches(
-            schema.to_pyarrow(), _batches()
-        )
+        return pa.RecordBatchReader.from_batches(schema.to_pyarrow(), _batches())

     def execute(
         self,
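The hunk above wraps a generator of record batches in a pa.RecordBatchReader, so results stream to the caller instead of being materialized up front. A standalone sketch of that wrapping, assuming pyarrow is installed; the schema and data are made up:

import pyarrow as pa

schema = pa.schema([("n", pa.int64())])


def batches():
    # Each chunk becomes one RecordBatch; a real backend would pull
    # chunks from a database cursor instead of this literal list.
    for chunk in ([1, 2, 3], [4, 5]):
        yield pa.RecordBatch.from_arrays([pa.array(chunk, pa.int64())], schema=schema)


reader = pa.RecordBatchReader.from_batches(schema, batches())
print(reader.read_all())  # a pyarrow.Table with a single int64 column "n"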
@@ -249,9 +243,7 @@ def execute(
         # feature than all this magic.
         # we don't want to pass `timecontext` to `raw_sql`
         kwargs.pop('timecontext', None)
-        query_ast = self.compiler.to_ast_ensure_limit(
-            expr, limit, params=params
-        )
+        query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
         sql = query_ast.compile()
         self._log(sql)

@@ -349,9 +341,7 @@ def compile(
             The output of compilation. The type of this value depends on the
             backend.
         """
-        return self.compiler.to_ast_ensure_limit(
-            expr, limit, params=params
-        ).compile()
+        return self.compiler.to_ast_ensure_limit(expr, limit, params=params).compile()

     def explain(
         self,
@@ -397,6 +387,5 @@ def has_operation(cls, operation: type[ops.Value]) -> bool:

     def _create_temp_view(self, view, definition):
         raise NotImplementedError(
-            f"The {self.name} backend does not implement temporary view "
-            "creation"
+            f"The {self.name} backend does not implement temporary view " "creation"
         )
22 changes: 6 additions & 16 deletions ibis/backends/base/sql/alchemy/__init__.py
@@ -15,10 +15,7 @@
 import ibis.expr.types as ir
 import ibis.util as util
 from ibis.backends.base.sql import BaseSQLBackend
-from ibis.backends.base.sql.alchemy.database import (
-    AlchemyDatabase,
-    AlchemyTable,
-)
+from ibis.backends.base.sql.alchemy.database import AlchemyDatabase, AlchemyTable
 from ibis.backends.base.sql.alchemy.datatypes import (
     schema_from_table,
     table_from_schema,
@@ -71,9 +68,7 @@ class BaseAlchemyBackend(BaseSQLBackend):
     table_class = AlchemyTable
     compiler = AlchemyCompiler

-    def _build_alchemy_url(
-        self, url, host, port, user, password, database, driver
-    ):
+    def _build_alchemy_url(self, url, host, port, user, password, database, driver):
         if url is not None:
             return sa.engine.url.make_url(url)

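_build_alchemy_url short-circuits to make_url when a complete URL is supplied and otherwise assembles one from its parts. A sketch of both paths using SQLAlchemy 1.4's public API; the credentials and database name are placeholders:

import sqlalchemy as sa

# Path 1: a complete URL string is parsed as-is.
url = sa.engine.url.make_url("postgresql+psycopg2://user:secret@localhost:5432/ibis")

# Path 2: the same URL assembled from individual components.
url = sa.engine.URL.create(
    drivername="postgresql+psycopg2",
    username="user",
    password="secret",
    host="localhost",
    port=5432,
    database="ibis",
)
print(url.render_as_string(hide_password=False))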
@@ -189,8 +184,7 @@ def create_table(

         if database is not None:
             raise NotImplementedError(
-                'Creating tables from a different database is not yet '
-                'implemented'
+                'Creating tables from a different database is not yet ' 'implemented'
             )

         if expr is None and schema is None:
@@ -236,9 +230,7 @@ def _get_insert_method(self, expr):

         return methodcaller("from_select", list(expr.columns), compiled)

-    def _columns_from_schema(
-        self, name: str, schema: sch.Schema
-    ) -> list[sa.Column]:
+    def _columns_from_schema(self, name: str, schema: sch.Schema) -> list[sa.Column]:
         return [
             sa.Column(colname, to_sqla_type(dtype), nullable=dtype.nullable)
             for colname, dtype in zip(schema.names, schema.types)
@@ -273,8 +265,7 @@ def drop_table(

         if database is not None:
             raise NotImplementedError(
-                'Dropping tables from a different database is not yet '
-                'implemented'
+                'Dropping tables from a different database is not yet ' 'implemented'
             )

         t = self._get_sqla_table(table_name, schema=database, autoload=False)
@@ -515,8 +506,7 @@ def _get_temp_view_definition(
         definition: sa.sql.compiler.Compiled,
     ) -> str:
         raise NotImplementedError(
-            f"The {self.name} backend does not implement temporary view "
-            "creation"
+            f"The {self.name} backend does not implement temporary view " "creation"
         )

     def _register_temp_view_cleanup(self, name: str, raw_name: str) -> None:
4 changes: 1 addition & 3 deletions ibis/backends/base/sql/alchemy/database.py
@@ -23,9 +23,7 @@ def __init__(self, source, sqla_table, name, schema):
             name = sqla_table.name
         if schema is None:
             schema = sch.infer(sqla_table, schema=schema)
-        super().__init__(
-            name=name, schema=schema, sqla_table=sqla_table, source=source
-        )
+        super().__init__(name=name, schema=schema, sqla_table=sqla_table, source=source)

     # TODO(kszucs): remove this
     def __equals__(self, other: AlchemyTable) -> bool:
8 changes: 2 additions & 6 deletions ibis/backends/base/sql/alchemy/datatypes.py
@@ -25,9 +25,7 @@ def __init__(
         self,
         pairs: Iterable[tuple[str, sa.types.TypeEngine]],
     ):
-        self.pairs = [
-            (name, sa.types.to_instance(type)) for name, type in pairs
-        ]
+        self.pairs = [(name, sa.types.to_instance(type)) for name, type in pairs]

     def get_col_spec(self, **_):
         pairs = ", ".join(f"{k} {v}" for k, v in self.pairs)
@@ -176,9 +174,7 @@ def sa_boolean(_, satype, nullable=True):
 @dt.dtype.register(MySQLDialect, sa.NUMERIC)
 def sa_mysql_numeric(_, satype, nullable=True):
     # https://dev.mysql.com/doc/refman/8.0/en/fixed-point-types.html
-    return dt.Decimal(
-        satype.precision or 10, satype.scale or 0, nullable=nullable
-    )
+    return dt.Decimal(satype.precision or 10, satype.scale or 0, nullable=nullable)


 @dt.dtype.register(MySQLDialect, mysql.TINYBLOB)
20 changes: 5 additions & 15 deletions ibis/backends/base/sql/alchemy/query_builder.py
@@ -60,13 +60,9 @@ def get_result(self):
             elif jtype is ops.OuterJoin:
                 result = result.outerjoin(table, onclause, full=True)
             elif jtype is ops.LeftSemiJoin:
-                result = result.select().where(
-                    sa.exists(sa.select(1).where(onclause))
-                )
+                result = result.select().where(sa.exists(sa.select(1).where(onclause)))
             elif jtype is ops.LeftAntiJoin:
-                result = result.select().where(
-                    ~sa.exists(sa.select(1).where(onclause))
-                )
+                result = result.select().where(~sa.exists(sa.select(1).where(onclause)))
             else:
                 raise NotImplementedError(jtype)
@@ -129,9 +125,7 @@ def _format_table(self, op):
         # hack
         if isinstance(op, ops.SelfReference):
             table = ctx.get_ref(ref_op)
-            self_ref = (
-                alias if hasattr(alias, "name") else table.alias(alias)
-            )
+            self_ref = alias if hasattr(alias, "name") else table.alias(alias)
             ctx.set_ref(op, self_ref)
             return self_ref
         return alias
@@ -304,9 +298,7 @@ def _add_where(self, fragment):
         if not len(self.where):
             return fragment

-        args = [
-            self._translate(pred, permit_subquery=True) for pred in self.where
-        ]
+        args = [self._translate(pred, permit_subquery=True) for pred in self.where]
         clause = functools.reduce(sql.and_, args)
         return fragment.where(clause)

@@ -359,9 +351,7 @@ def compile(self):

         def call(distinct, *args):
             return (
-                self.distinct_func(*args)
-                if distinct
-                else self.non_distinct_func(*args)
+                self.distinct_func(*args) if distinct else self.non_distinct_func(*args)
             )

         for table in self.tables:
20 changes: 5 additions & 15 deletions ibis/backends/base/sql/alchemy/registry.py
@@ -166,9 +166,7 @@ def _cast(t, op):
         return sa_arg

     # specialize going from an integer type to a timestamp
-    if isinstance(arg.output_dtype, dt.Integer) and isinstance(
-        sa_type, sa.DateTime
-    ):
+    if isinstance(arg.output_dtype, dt.Integer) and isinstance(sa_type, sa.DateTime):
         return t.integer_to_timestamp(sa_arg)

     if isinstance(arg.output_dtype, dt.Binary) and isinstance(typ, dt.String):
@@ -285,9 +283,7 @@ def _translate_case(t, cases, results, default):

 def _negate(t, op):
     arg = t.translate(op.arg)
-    return (
-        sa.not_(arg) if isinstance(op.arg.output_dtype, dt.Boolean) else -arg
-    )
+    return sa.not_(arg) if isinstance(op.arg.output_dtype, dt.Boolean) else -arg


 def unary(sa_func):
@@ -392,9 +388,7 @@ def _window(t, op):
         partition_by=partition_by, order_by=order_by, **additional_params
     )

-    if isinstance(
-        window_op, (ops.RowNumber, ops.DenseRank, ops.MinRank, ops.NTile)
-    ):
+    if isinstance(window_op, (ops.RowNumber, ops.DenseRank, ops.MinRank, ops.NTile)):
         return result - 1
     else:
         return result
@@ -427,9 +421,7 @@ def _sort_key(t, op):


 def _string_join(t, op):
-    return sa.func.concat_ws(
-        t.translate(op.sep), *map(t.translate, op.arg.values)
-    )
+    return sa.func.concat_ws(t.translate(op.sep), *map(t.translate, op.arg.values))


 def reduction(sa_func):
@@ -540,9 +532,7 @@ def _count_star(t, op):
     ops.BitXor: reduction(sa.func.bit_xor),
     ops.CountDistinct: reduction(lambda arg: sa.func.count(arg.distinct())),
     ops.HLLCardinality: reduction(lambda arg: sa.func.count(arg.distinct())),
-    ops.ApproxCountDistinct: reduction(
-        lambda arg: sa.func.count(arg.distinct())
-    ),
+    ops.ApproxCountDistinct: reduction(lambda arg: sa.func.count(arg.distinct())),
     ops.GroupConcat: _group_concat,
     ops.Between: fixed_arity(sa.between, 3),
     ops.IsNull: _is_null,