14 changes: 14 additions & 0 deletions python/ray/data/BUILD.bazel
@@ -913,6 +913,20 @@ py_test(
],
)

py_test(
name = "test_predicate_pushdown",
size = "small",
srcs = ["tests/test_predicate_pushdown.py"],
tags = [
"exclusive",
"team:data",
],
deps = [
":conftest",
"//:ray_lib",
],
)

py_test(
name = "test_path_util",
size = "small",
32 changes: 32 additions & 0 deletions python/ray/data/_internal/datasource/csv_datasource.py
@@ -2,6 +2,7 @@

from ray.data.block import Block
from ray.data.datasource.file_based_datasource import FileBasedDatasource
from ray.data.expressions import Expr

if TYPE_CHECKING:
import pyarrow
@@ -36,6 +37,29 @@ def __init__(
)
self.parse_options = arrow_csv_args.pop("parse_options", csv.ParseOptions())
self.arrow_csv_args = arrow_csv_args
self._predicate_expr: Optional[Expr] = None

def supports_predicate_pushdown(self) -> bool:
return True

def get_current_predicate(self) -> Optional[Expr]:
return self._predicate_expr

def apply_predicate(
self,
predicate_expr: Expr,
) -> "CSVDatasource":
import copy

clone = copy.copy(self)

# Combine with existing predicate using AND
clone._predicate_expr = (
predicate_expr
if clone._predicate_expr is None
else clone._predicate_expr & predicate_expr
)
return clone

def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
import pyarrow as pa
@@ -47,6 +71,12 @@ def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
self.parse_options.invalid_row_handler
)

filter_expr = (
self._predicate_expr.to_pyarrow()
if self._predicate_expr is not None
else None
)

try:
reader = csv.open_csv(
f,
Expand All @@ -61,6 +91,8 @@ def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
table = pa.Table.from_batches([batch], schema=schema)
if schema is None:
schema = table.schema
if filter_expr is not None:
table = table.filter(filter_expr)
yield table
except StopIteration:
return
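Illustrative sketch (not part of this PR): roughly what the pushed-down filter does at read time, assuming `to_pyarrow()` lowers a Ray Data expression to a `pyarrow.compute` expression. The column name and file path below are made up.

```python
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.csv as pacsv

# Hypothetical sample data; the "age" column is invented for illustration.
pacsv.write_csv(pa.table({"age": [10, 40, 70]}), "/tmp/ages.csv")

# Stand-in for self._predicate_expr.to_pyarrow(): a Ray Data expression such
# as col("age") > 30 is assumed to lower to a PyArrow compute expression.
filter_expr = pc.field("age") > 30

# Mirror of the streaming read above: filter each batch-backed table.
reader = pacsv.open_csv("/tmp/ages.csv")
for batch in reader:
    table = pa.Table.from_batches([batch])
    table = table.filter(filter_expr)
    print(table.to_pydict())  # {'age': [40, 70]}
```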
58 changes: 56 additions & 2 deletions python/ray/data/_internal/datasource/parquet_datasource.py
@@ -54,6 +54,7 @@
from ray.data.datasource.path_util import (
_resolve_paths_and_filesystem,
)
from ray.data.expressions import Expr
from ray.util.debug import log_once

if TYPE_CHECKING:
@@ -284,7 +285,7 @@ def __init__(
self._file_metadata_shuffler = None
self._include_paths = include_paths
self._partitioning = partitioning

self._predicate_expr: Optional[Expr] = None
if shuffle == "files":
self._file_metadata_shuffler = np.random.default_rng()
elif isinstance(shuffle, FileShuffleConfig):
@@ -352,6 +353,12 @@ def get_read_tasks(
)

read_tasks = []
filter_expr = (
self._predicate_expr.to_pyarrow()
if self._predicate_expr is not None
else None
)

Review thread on this hunk:
Reviewer: How much work would it actually be to push our expr all the way into the reader itself? If it's not a lot, let's do the right thing right away (otherwise do it in a stacked PR).
Author: I'm already pushing this into the reader. apply_predicate changes _predicate_expr, which then calls to_pyarrow() to convert to a pyarrow.dataset Expression that gets passed to fragment.to_batches().
Reviewer: I meant threading our own expressions instead of PyArrow ones.
Author: I'm not sure I follow. How would PyArrow accept Ray Data's expressions? At some point we have to convert before calling to_batches(), right?
Author: Discussed offline. As part of the next PR, I'll refactor the remaining two functions that are not managed by PyArrow to only pass in Ray Data's Expr.

for fragments, paths in zip(
np.array_split(pq_fragments, parallelism),
np.array_split(pq_paths, parallelism),
@@ -401,6 +408,7 @@ def get_read_tasks(
f,
include_paths,
partitioning,
filter_expr,
),
meta,
schema=target_schema,
@@ -424,6 +432,9 @@ def supports_distributed_reads(self) -> bool:
def supports_projection_pushdown(self) -> bool:
return True

def supports_predicate_pushdown(self) -> bool:
return True

def get_current_projection(self) -> Optional[List[str]]:
# NOTE: In case there's no projection both file and partition columns
# will be none
@@ -432,6 +443,9 @@ def get_current_projection(self) -> Optional[List[str]]:

return (self._data_columns or []) + (self._partition_columns or [])

def get_current_predicate(self) -> Optional[Expr]:
return self._predicate_expr

def apply_projection(
self,
columns: Optional[List[str]],
@@ -446,6 +460,37 @@ def get_current_projection(self) -> Optional[List[str]]:

return clone

# TODO: This should be moved to the Datasource class
def apply_predicate(
self,
predicate_expr: Expr,
) -> "ParquetDatasource":
from ray.data._internal.planner.plan_expression.expression_visitors import (
_ColumnSubstitutionVisitor,
)
from ray.data.expressions import col

clone = copy.copy(self)
# Handle column renaming for Ray Data expressions
if self._data_columns_rename_map:
# Create mapping from new column names to old column names
# It's new to old mapping because we need to visit the predicate expression (which has all the new cols)
# and map them to the old columns so that the filtering can be pushed into the read tasks.
column_mapping = {
new_col: col(old_col)
for old_col, new_col in self._data_columns_rename_map.items()
}
visitor = _ColumnSubstitutionVisitor(column_mapping)
predicate_expr = visitor.visit(predicate_expr)

# Combine with existing predicate using AND
clone._predicate_expr = (
predicate_expr
if clone._predicate_expr is None
else clone._predicate_expr & predicate_expr
)
return clone

Review thread on the column_mapping hunk:
Reviewer: It's not immediately clear to me why the mapping is new_col -> old_col. Shouldn't it be the other way around?
Author: It's new to old because we need to visit the predicate expression (which references the new column names) and map them back to the old columns so that the filtering can be pushed into the read.
Reviewer: Capture this in a comment.
Reviewer: Let's also abstract this rebinding, as it will also need to happen in the Rule itself. We should think about how to consolidate the rebinding in the Rule to avoid having it in two places.
Author: Done.
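Illustrative sketch (not part of this PR): the direction of the rename map made concrete with invented column names; only `col` is taken from this diff's imports, and the visitor's rewrite is described in a comment rather than invoked.

```python
from ray.data.expressions import col

# Datasource rename map as stored in _data_columns_rename_map:
# on-disk name -> user-facing name (names are made up).
rename_map = {"user_id": "id"}

# The predicate arrives in terms of the user-facing (new) names...
predicate = col("id") > 100

# ...so the substitution map has to run new -> old, letting
# _ColumnSubstitutionVisitor rewrite the predicate into on-disk terms.
column_mapping = {new: col(old) for old, new in rename_map.items()}
# column_mapping == {"id": col("user_id")}
# After visitor.visit(predicate), the pushed-down filter is effectively
# col("user_id") > 100, which the Parquet reader can evaluate against the
# file's actual columns.
```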
Automated review comment (Bug: Parquet Predicate Pushdown Inconsistency):
ParquetDatasource implements predicate pushdown and its supports_predicate_pushdown method returns True, but it doesn't override get_current_predicate(). As a result, get_current_predicate() always returns None even when a predicate has been applied, leading to an inconsistency. CSVDatasource implements this method correctly.


def _estimate_in_mem_size(self, fragments: List[_ParquetFragment]) -> int:
in_mem_size = sum([f.file_size for f in fragments]) * self._encoding_ratio

@@ -463,6 +508,7 @@ def read_fragments(
fragments: List[_ParquetFragment],
include_paths: bool,
partitioning: Partitioning,
filter_expr: Optional["pyarrow.dataset.Expression"] = None,
) -> Iterator["pyarrow.Table"]:
# This import is necessary to load the tensor extension type.
from ray.data.extensions.tensor_extension import ArrowTensorType # noqa
@@ -484,6 +530,7 @@ def read_fragments(
partition_columns=partition_columns,
partitioning=partitioning,
include_path=include_paths,
filter_expr=filter_expr,
batch_size=default_read_batch_size_rows,
to_batches_kwargs=to_batches_kwargs,
),
@@ -522,7 +569,14 @@ def _read_batches_from(
# NOTE: Passed in kwargs overrides always take precedence
# TODO deprecate to_batches_kwargs
use_threads = to_batches_kwargs.pop("use_threads", use_threads)
filter_expr = to_batches_kwargs.pop("filter", filter_expr)
# TODO: We should deprecate filter through the read_parquet API and only allow through dataset.filter()
filter_from_kwargs = to_batches_kwargs.pop("filter", None)
if filter_from_kwargs is not None:
filter_expr = (
filter_from_kwargs
if filter_expr is None
else filter_expr & filter_from_kwargs
)
# NOTE: Arrow's ``to_batches`` expects ``batch_size`` as an int
if batch_size is not None:
to_batches_kwargs.setdefault("batch_size", batch_size)
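Illustrative sketch (not part of this PR): how the combined filter is expected to reach PyArrow, assuming the fragment read ultimately calls `ParquetFileFragment.to_batches` with a dataset expression. The file path and column names are made up.

```python
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq

# Tiny illustrative Parquet file.
pq.write_table(pa.table({"price": [5, 20, 50], "qty": [0, 3, 7]}), "/tmp/demo.parquet")

dataset = ds.dataset("/tmp/demo.parquet", format="parquet")
fragment = next(iter(dataset.get_fragments()))

pushed_down = ds.field("price") > 10  # e.g. from predicate pushdown
legacy_kwarg = ds.field("qty") >= 1   # e.g. from read_parquet(..., filter=...) kwargs

# Mirrors the precedence logic above: keep the kwargs filter and AND it with
# the pushed-down predicate when both are present.
filter_expr = pushed_down & legacy_kwarg

for batch in fragment.to_batches(filter=filter_expr, batch_size=1024):
    print(batch.num_rows)  # only rows satisfying both predicates (here: 2)
```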
7 changes: 6 additions & 1 deletion python/ray/data/_internal/logical/interfaces/__init__.py
@@ -1,4 +1,8 @@
from .logical_operator import LogicalOperator, LogicalOperatorSupportsProjectionPushdown
from .logical_operator import (
LogicalOperator,
LogicalOperatorSupportsPredicatePushdown,
LogicalOperatorSupportsProjectionPushdown,
)
from .logical_plan import LogicalPlan
from .operator import Operator
from .optimizer import Optimizer, Rule
@@ -16,4 +20,5 @@
"Rule",
"SourceOperator",
"LogicalOperatorSupportsProjectionPushdown",
"LogicalOperatorSupportsPredicatePushdown",
]
python/ray/data/_internal/logical/interfaces/logical_operator.py
@@ -2,6 +2,7 @@

from .operator import Operator
from ray.data.block import BlockMetadata
from ray.data.expressions import Expr

if TYPE_CHECKING:
from ray.data.block import Schema
@@ -104,3 +105,19 @@ def apply_projection(
column_rename_map: Optional[Dict[str, str]],
) -> LogicalOperator:
return self


class LogicalOperatorSupportsPredicatePushdown(LogicalOperator):
"""Mixin for reading operators supporting predicate pushdown"""

def supports_predicate_pushdown(self) -> bool:
return False

def get_current_predicate(self) -> Optional[Expr]:
return None

def apply_predicate(
self,
predicate_expr: Expr,
) -> LogicalOperator:
return self
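Illustrative sketch (not part of this PR): how an optimizer rule might consume this mixin. The rule itself is not in this diff, so the helper below is hypothetical; it only uses the interface defined above.

```python
from ray.data._internal.logical.interfaces import (
    LogicalOperator,
    LogicalOperatorSupportsPredicatePushdown,
)
from ray.data.expressions import Expr


def try_push_predicate(op: LogicalOperator, predicate: Expr) -> LogicalOperator:
    """Hypothetical helper: push `predicate` into `op` when supported,
    otherwise return `op` unchanged so the caller keeps an explicit Filter."""
    if (
        isinstance(op, LogicalOperatorSupportsPredicatePushdown)
        and op.supports_predicate_pushdown()
    ):
        # apply_predicate returns a clone; an existing predicate is ANDed in.
        return op.apply_predicate(predicate)
    return op
```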
9 changes: 9 additions & 0 deletions python/ray/data/_internal/logical/operators/map_operator.py
@@ -268,6 +268,15 @@ def __init__(
def can_modify_num_rows(self) -> bool:
return True

def is_expression_based(self) -> bool:
return self._predicate_expr is not None

def _get_operator_name(self, op_name: str, fn: UserDefinedFunction):
if self.is_expression_based():
# TODO: Use a truncated expression prefix here instead of <expression>.
return f"{op_name}(<expression>)"
return super()._get_operator_name(op_name, fn)


class Project(AbstractMap):
"""Logical operator for all Projection Operations."""
52 changes: 50 additions & 2 deletions python/ray/data/_internal/logical/operators/n_ary_operator.py
@@ -1,6 +1,10 @@
from typing import Optional

from ray.data._internal.logical.interfaces import LogicalOperator
from ray.data._internal.logical.interfaces import (
LogicalOperator,
LogicalOperatorSupportsPredicatePushdown,
)
from ray.data.expressions import Expr


class NAry(LogicalOperator):
@@ -37,14 +41,15 @@ def estimated_num_outputs():
return total_num_outputs


class Union(NAry):
class Union(NAry, LogicalOperatorSupportsPredicatePushdown):
"""Logical operator for union."""

def __init__(
self,
*input_ops: LogicalOperator,
):
super().__init__(*input_ops)
self._predicate_expr: Optional[Expr] = None

def estimated_num_outputs(self):
total_num_outputs = 0
@@ -54,3 +59,46 @@ def estimated_num_outputs():
return None
total_num_outputs += num_outputs
return total_num_outputs

def supports_predicate_pushdown(self) -> bool:
"""Union supports predicate pushdown by applying predicates to all branches."""
return True

def get_current_predicate(self) -> Optional[Expr]:
"""Returns the current predicate expression applied to this Union."""
return self._predicate_expr

def apply_predicate(self, predicate_expr: Expr) -> "Union":
"""Apply a predicate by pushing it down to all input branches.

This creates a new Union with the predicate applied to each input operator
that supports predicate pushdown.
"""
import copy

from ray.data._internal.logical.operators.map_operator import Filter

clone = copy.copy(self)

# Combine with existing predicate using AND
clone._predicate_expr = (
predicate_expr
if clone._predicate_expr is None
else clone._predicate_expr & predicate_expr
)

# Apply predicate to each branch
new_inputs = []
for input_op in self._input_dependencies:
# If the branch supports predicate pushdown, use it
if (
isinstance(input_op, LogicalOperatorSupportsPredicatePushdown)
and input_op.supports_predicate_pushdown()
):
new_inputs.append(input_op.apply_predicate(predicate_expr))
else:
# Otherwise, wrap with a Filter operator
new_inputs.append(Filter(input_op, predicate_expr=predicate_expr))

clone._input_dependencies = new_inputs
return clone
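Illustrative sketch (not part of this PR): the intended Dataset-level effect. `filter(expr=...)` with a string expression is the existing public API; whether a given filter is actually routed through this Union pushdown depends on the accompanying optimizer rule, which is not in this diff.

```python
import ray

ds1 = ray.data.range(100)  # single column "id"
ds2 = ray.data.range(100)
combined = ds1.union(ds2)

# With pushdown, this predicate can be pushed through the Union into each
# branch (or wrapped in a per-branch Filter when a branch can't accept it).
filtered = combined.filter(expr="id > 90")
print(filtered.count())  # expected: 18 (9 rows from each branch)
```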
24 changes: 23 additions & 1 deletion python/ray/data/_internal/logical/operators/read_operator.py
@@ -4,6 +4,7 @@
from typing import Any, Dict, List, Optional, Union

from ray.data._internal.logical.interfaces import (
LogicalOperatorSupportsPredicatePushdown,
LogicalOperatorSupportsProjectionPushdown,
SourceOperator,
)
@@ -14,9 +15,15 @@
)
from ray.data.context import DataContext
from ray.data.datasource.datasource import Datasource, Reader
from ray.data.expressions import Expr


class Read(AbstractMap, SourceOperator, LogicalOperatorSupportsProjectionPushdown):
class Read(
AbstractMap,
SourceOperator,
LogicalOperatorSupportsProjectionPushdown,
LogicalOperatorSupportsPredicatePushdown,
):
"""Logical operator for read."""

# TODO: make this a frozen dataclass. https://github.com/ray-project/ray/issues/55747
Expand Down Expand Up @@ -173,6 +180,21 @@ def apply_projection(

return clone

def supports_predicate_pushdown(self) -> bool:
return self._datasource.supports_predicate_pushdown()

def get_current_predicate(self) -> Optional[Expr]:
return self._datasource.get_current_predicate()

def apply_predicate(self, predicate_expr: Expr) -> "Read":
clone = copy.copy(self)

predicated_datasource = self._datasource.apply_predicate(predicate_expr)
clone._datasource = predicated_datasource
clone._datasource_or_legacy_reader = predicated_datasource

return clone

def can_modify_num_rows(self) -> bool:
# NOTE: Returns true, since most of the readers expands its input
# and produce many rows for every single row of the input
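Illustrative sketch (not part of this PR): the datasource-level contract that Read.apply_predicate delegates to, shown with CSVDatasource since its changes are in this diff. Assumes CSVDatasource can be constructed from a single local path (as its FileBasedDatasource base suggests) and uses made-up data.

```python
import pyarrow as pa
import pyarrow.csv as pacsv

from ray.data._internal.datasource.csv_datasource import CSVDatasource
from ray.data.expressions import col

# Made-up sample file so path resolution succeeds.
pacsv.write_csv(pa.table({"age": [10, 40, 70]}), "/tmp/people.csv")

source = CSVDatasource("/tmp/people.csv")
assert source.supports_predicate_pushdown()
assert source.get_current_predicate() is None

# apply_predicate returns a clone carrying the predicate; repeated calls AND.
pushed = source.apply_predicate(col("age") > 30).apply_predicate(col("age") < 60)
print(pushed.get_current_predicate())  # roughly (age > 30) & (age < 60)
```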