Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 57 additions & 31 deletions python/ray/data/grouped_data.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union

from ray.data._internal.compute import ComputeStrategy
from ray.data._internal.logical.interfaces import LogicalPlan
from ray.data._internal.logical.operators.all_to_all_operator import Aggregate
from ray.data.aggregate import AggregateFn, Count, Max, Mean, Min, Std, Sum
from ray.data.block import (
Block,
BlockAccessor,
CallableClass,
DataBatch,
Expand Down Expand Up @@ -97,6 +98,7 @@ def map_groups(
self,
fn: UserDefinedFunction[DataBatch, DataBatch],
*,
zero_copy_batch: bool = False,
compute: Union[str, ComputeStrategy] = None,
batch_format: Optional[str] = "default",
fn_args: Optional[Iterable[Any]] = None,
Expand Down Expand Up @@ -157,6 +159,8 @@ def map_groups(
that can be instantiated to create such a callable. It takes as
input a batch of all records from a single group, and returns a
batch of zero or more records, similar to map_batches().
zero_copy_batch: If True, each group of rows (batch) will be provided w/o
making an additional copy.
compute: This argument is deprecated. Use ``concurrency`` argument.
batch_format: Specify ``"default"`` to use the default block format
(NumPy), ``"pandas"`` to select ``pandas.DataFrame``, "pyarrow" to
Expand Down Expand Up @@ -240,47 +244,40 @@ def map_groups(

# The batch is the entire block, because we have batch_size=None for
# map_batches() below.
def _apply_udf_to_groups(udf, batch, *args, **kwargs):
block = BlockAccessor.batch_to_block(batch)
block_accessor = BlockAccessor.for_block(block)

if self._key is None:
keys = []
elif isinstance(self._key, str):
keys = [self._key]
elif isinstance(self._key, List):
keys = self._key
else:
raise ValueError(
f"Group-by keys are expected to either be a single column (str) "
f"or a list of columns (got '{self._key}')"
)

boundaries = block_accessor._get_group_boundaries_sorted(keys)

for start, end in zip(boundaries[:-1], boundaries[1:]):
group_block = block_accessor.slice(start, end, copy=False)
group_block_accessor = BlockAccessor.for_block(group_block)
# Convert block of each group to batch format here, because the
# block format here can be different from batch format
# (e.g. block is Arrow format, and batch is NumPy format).
group_batch = group_block_accessor.to_batch_format(batch_format)
applied = udf(group_batch, *args, **kwargs)
yield applied
if self._key is None:
keys = []
elif isinstance(self._key, str):
keys = [self._key]
elif isinstance(self._key, List):
keys = self._key
else:
raise ValueError(
f"Group-by keys are expected to either be a single column (str) "
f"or a list of columns (got '{self._key}')"
)

# NOTE: It's crucial to make sure that UDF isn't capturing `GroupedData`
# object in its closure to ensure its serializability
#
# See https://github.com/ray-project/ray/issues/54280 for more details
if isinstance(fn, CallableClass):

class wrapped_fn:
def __init__(self, *args, **kwargs):
self.fn = fn(*args, **kwargs)

def __call__(self, batch, *args, **kwargs):
yield from _apply_udf_to_groups(self.fn, batch, *args, **kwargs)
yield from _apply_udf_to_groups(
self.fn, batch, keys, batch_format, *args, **kwargs
)

else:

def wrapped_fn(batch, *args, **kwargs):
yield from _apply_udf_to_groups(fn, batch, *args, **kwargs)
yield from _apply_udf_to_groups(
fn, batch, keys, batch_format, *args, **kwargs
)

# Change the name of the wrapped function so that users see the name of their
# function rather than `wrapped_fn` in the progress bar.
Expand All @@ -295,8 +292,11 @@ def wrapped_fn(batch, *args, **kwargs):
wrapped_fn,
batch_size=None,
compute=compute,
batch_format=batch_format,
zero_copy_batch=False,
# NOTE: We set `batch_format` to `None` to avoid converting back and forth
# between batch and block formats (instead, we convert once per group
# inside `_apply_udf_to_groups`, which applies the UDF itself)
batch_format=None,
zero_copy_batch=zero_copy_batch,
fn_args=fn_args,
fn_kwargs=fn_kwargs,
fn_constructor_args=fn_constructor_args,
Expand Down Expand Up @@ -540,5 +540,31 @@ def std(
return self._aggregate_on(Std, on, ignore_nulls=ignore_nulls, ddof=ddof)


def _apply_udf_to_groups(
    # ``...`` is only valid as the *entire* argument list of ``Callable``;
    # the UDF takes a batch plus arbitrary extra arguments.
    udf: Callable[..., DataBatch],
    block: Block,
    keys: List[str],
    batch_format: Optional[str],
    *args: Any,
    **kwargs: Any,
) -> Iterator[DataBatch]:
    """Apply a UDF to each group of rows sharing the same values of ``keys``.

    NOTE: This function is defined at module level to avoid capturing closures
    and make it serializable.

    Args:
        udf: Callable invoked once per group as
            ``udf(group_batch, *args, **kwargs)``.
        block: Block holding the rows to split into groups. Presumably it is
            already sorted by ``keys`` (group boundaries are obtained from
            ``_get_group_boundaries_sorted``) -- confirm against the caller.
        keys: Names of the columns whose values define a group.
        batch_format: Batch format each group is converted to before it is
            handed to ``udf``.
        *args: Extra positional arguments forwarded to ``udf``.
        **kwargs: Extra keyword arguments forwarded to ``udf``.

    Yields:
        The value returned by ``udf`` for each group, in boundary order.
    """
    block_accessor = BlockAccessor.for_block(block)

    boundaries = block_accessor._get_group_boundaries_sorted(keys)

    # Consecutive boundary pairs delimit one group each: [start, end).
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        # Slice without copying; the group is a view into the parent block.
        group_block = block_accessor.slice(start, end, copy=False)
        group_block_accessor = BlockAccessor.for_block(group_block)

        # Convert corresponding block of each group to batch format here,
        # because the block format here can be different from batch format
        # (e.g. block is Arrow format, and batch is NumPy format).
        group_batch = group_block_accessor.to_batch_format(batch_format)
        yield udf(group_batch, *args, **kwargs)


# Backwards compatibility alias: `GroupedDataset` is bound to the same class
# object as `GroupedData`, so code referring to the former name keeps working.
GroupedDataset = GroupedData