Commit 7128024

feat: Added RemoteDatasetProxy that executes Ray Data operations remotely
Signed-off-by: ntkathole <[email protected]>
1 parent a24e06e commit 7128024

File tree: 2 files changed, +159 −21 lines
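
The pattern the new class relies on, in brief: the client holds only a Ray ObjectRef to a dataset that lives on the cluster, and each operation is shipped to a worker as a @ray.remote task that returns a new ref. A minimal standalone sketch of that idea follows; it is not code from this commit, the names _make_dataset, _remote_limit, and _remote_count are illustrative only, and a reachable Ray runtime is assumed:

    import ray
    import ray.data

    ray.init()  # assumes a local or already-running Ray cluster

    @ray.remote
    def _make_dataset():
        # Built on a worker; the client only ever holds an ObjectRef to it.
        return ray.data.range(100)

    @ray.remote
    def _remote_limit(dataset, n):
        # Ray resolves ObjectRef arguments, so `dataset` is a real Dataset here.
        return dataset.limit(n)

    @ray.remote
    def _remote_count(dataset):
        return dataset.count()

    dataset_ref = _make_dataset.remote()
    limited_ref = _remote_limit.remote(dataset_ref, 10)  # chained, still remote
    print(ray.get(_remote_count.remote(limited_ref)))    # -> 10; only an int crosses the wire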

sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py

Lines changed: 12 additions & 8 deletions
@@ -43,6 +43,7 @@
     _build_required_columns,
     apply_field_mapping,
     ensure_timestamp_compatibility,
+    is_ray_data,
     normalize_timestamp_columns,
 )
 from feast.infra.registry.base_registry import BaseRegistry
@@ -61,7 +62,7 @@


 def _get_data_schema_info(
-    data: Union[pd.DataFrame, Dataset],
+    data: Union[pd.DataFrame, Dataset, Any],
 ) -> Tuple[Dict[str, Any], List[str]]:
     """
     Extract schema information from DataFrame or Dataset.
@@ -70,7 +71,7 @@ def _get_data_schema_info(
     Returns:
         Tuple of (dtypes_dict, column_names)
     """
-    if isinstance(data, Dataset):
+    if is_ray_data(data):
         schema = data.schema()
         dtypes = {}
         for i, col in enumerate(schema.names):
@@ -84,16 +85,17 @@ def _get_data_schema_info(
             dtypes[col] = pd.api.types.pandas_dtype("object")
         columns = schema.names
     else:
+        assert isinstance(data, pd.DataFrame)
         dtypes = data.dtypes.to_dict()
         columns = list(data.columns)
     return dtypes, columns


 def _apply_to_data(
-    data: Union[pd.DataFrame, Dataset],
+    data: Union[pd.DataFrame, Dataset, Any],
     process_func: Callable[[pd.DataFrame], pd.DataFrame],
     inplace: bool = False,
-) -> Union[pd.DataFrame, Dataset]:
+) -> Union[pd.DataFrame, Dataset, Any]:
     """
     Apply a processing function to DataFrame or Dataset.
     Args:
@@ -103,9 +105,10 @@ def _apply_to_data(
     Returns:
         Processed DataFrame or Dataset
     """
-    if isinstance(data, Dataset):
+    if is_ray_data(data):
         return data.map_batches(process_func, batch_format="pandas")
     else:
+        assert isinstance(data, pd.DataFrame)
         if not inplace:
             data = data.copy()
         return process_func(data)
@@ -158,7 +161,7 @@ def _safe_infer_event_timestamp_column(


 def _safe_get_entity_timestamp_bounds(
-    data: Union[pd.DataFrame, Dataset], timestamp_column: str
+    data: Union[pd.DataFrame, Dataset, Any], timestamp_column: str
 ) -> Tuple[Optional[datetime], Optional[datetime]]:
     """
     Safely get entity timestamp bounds.
@@ -170,7 +173,7 @@ def _safe_get_entity_timestamp_bounds(
         Tuple of (min_timestamp, max_timestamp) or (None, None) if failed
     """
     try:
-        if isinstance(data, Dataset):
+        if is_ray_data(data):
             min_ts = data.min(timestamp_column)
             max_ts = data.max(timestamp_column)
         else:
@@ -192,7 +195,7 @@
             f"Timestamp bounds extraction failed: {e}, falling back to manual calculation"
         )
         try:
-            if isinstance(data, Dataset):
+            if is_ray_data(data):

                 def extract_bounds(batch: pd.DataFrame) -> pd.DataFrame:
                     if timestamp_column in batch.columns and not batch.empty:
@@ -212,6 +215,7 @@ def extract_bounds(batch: pd.DataFrame) -> pd.DataFrame:
                 if pd.notna(min_ts) and pd.notna(max_ts):
                     return min_ts.to_pydatetime(), max_ts.to_pydatetime()
             else:
+                assert isinstance(data, pd.DataFrame)
                 if timestamp_column in data.columns:
                     timestamps = pd.to_datetime(data[timestamp_column], utc=True)
                     return (
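
Net effect of the changes above: each dual-path helper now accepts a pandas DataFrame, a Ray Dataset, or the new proxy, and branches on is_ray_data() instead of isinstance(data, Dataset). A minimal sketch of that dispatch shape (hypothetical tag_rows/add_label names; assumes a Ray runtime and that is_ray_data is importable from feast.infra.ray_shared_utils, as the new import suggests):

    import pandas as pd
    import ray.data
    from feast.infra.ray_shared_utils import is_ray_data

    def tag_rows(data, label: str):
        # Same shape as _apply_to_data after this commit: one branch per backend.
        def add_label(df: pd.DataFrame) -> pd.DataFrame:
            return df.assign(label=label)

        if is_ray_data(data):
            # Dataset or RemoteDatasetProxy: stays lazy and distributed.
            return data.map_batches(add_label, batch_format="pandas")
        assert isinstance(data, pd.DataFrame)
        return add_label(data.copy())

    tag_rows(pd.DataFrame({"x": [1, 2]}), "local")            # eager pandas path
    tag_rows(ray.data.from_items([{"x": 1}]), "distributed")  # lazy Ray Data path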

sdk/python/feast/infra/ray_shared_utils.py

Lines changed: 147 additions & 13 deletions
@@ -1,16 +1,145 @@
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import numpy as np
 import pandas as pd
+import pyarrow as pa
+import ray
 from ray.data import Dataset


+class RemoteDatasetProxy:
+    """Proxy class that executes Ray Data operations remotely on cluster workers."""
+
+    def __init__(self, dataset_ref: Any):
+        """Initialize with a reference to the remote dataset."""
+        self._dataset_ref = dataset_ref
+
+    def map_batches(self, func, **kwargs) -> "RemoteDatasetProxy":
+        """Execute map_batches remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_map_batches(dataset, function, batch_kwargs):
+            return dataset.map_batches(function, **batch_kwargs)
+
+        new_ref = _remote_map_batches.remote(self._dataset_ref, func, kwargs)
+        return RemoteDatasetProxy(new_ref)
+
+    def filter(self, fn) -> "RemoteDatasetProxy":
+        """Execute filter remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_filter(dataset, filter_fn):
+            return dataset.filter(filter_fn)
+
+        new_ref = _remote_filter.remote(self._dataset_ref, fn)
+        return RemoteDatasetProxy(new_ref)
+
+    def to_pandas(self) -> pd.DataFrame:
+        """Execute to_pandas remotely and transfer result to client."""
+
+        @ray.remote
+        def _remote_to_pandas(dataset):
+            return dataset.to_pandas()
+
+        result_ref = _remote_to_pandas.remote(self._dataset_ref)
+        return ray.get(result_ref)
+
+    def to_arrow(self) -> pa.Table:
+        """Execute to_arrow remotely and transfer result to client."""
+
+        @ray.remote
+        def _remote_to_arrow(dataset):
+            return dataset.to_arrow()
+
+        result_ref = _remote_to_arrow.remote(self._dataset_ref)
+        return ray.get(result_ref)
+
+    def schema(self) -> Any:
+        """Get dataset schema."""
+
+        @ray.remote
+        def _remote_schema(dataset):
+            return dataset.schema()
+
+        schema_ref = _remote_schema.remote(self._dataset_ref)
+        return ray.get(schema_ref)
+
+    def sort(self, key, descending=False) -> "RemoteDatasetProxy":
+        """Execute sort remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_sort(dataset, sort_key, desc):
+            return dataset.sort(sort_key, descending=desc)
+
+        new_ref = _remote_sort.remote(self._dataset_ref, key, descending)
+        return RemoteDatasetProxy(new_ref)
+
+    def limit(self, count) -> "RemoteDatasetProxy":
+        """Execute limit remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_limit(dataset, limit_count):
+            return dataset.limit(limit_count)
+
+        new_ref = _remote_limit.remote(self._dataset_ref, count)
+        return RemoteDatasetProxy(new_ref)
+
+    def union(self, other) -> "RemoteDatasetProxy":
+        """Execute union remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_union(dataset1, dataset2):
+            return dataset1.union(dataset2)
+
+        new_ref = _remote_union.remote(self._dataset_ref, other._dataset_ref)
+        return RemoteDatasetProxy(new_ref)
+
+    def materialize(self) -> "RemoteDatasetProxy":
+        """Execute materialize remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_materialize(dataset):
+            return dataset.materialize()
+
+        new_ref = _remote_materialize.remote(self._dataset_ref)
+        return RemoteDatasetProxy(new_ref)
+
+    def count(self) -> int:
+        """Execute count remotely and return result."""
+
+        @ray.remote
+        def _remote_count(dataset):
+            return dataset.count()
+
+        result_ref = _remote_count.remote(self._dataset_ref)
+        return ray.get(result_ref)
+
+    def take(self, n=20) -> list:
+        """Execute take remotely and return result."""
+
+        @ray.remote
+        def _remote_take(dataset, num):
+            return dataset.take(num)
+
+        result_ref = _remote_take.remote(self._dataset_ref, n)
+        return ray.get(result_ref)
+
+    def __getattr__(self, name):
+        """Catch any method calls that we haven't explicitly implemented."""
+        raise AttributeError(f"RemoteDatasetProxy has no attribute '{name}'")
+
+
+def is_ray_data(data: Any) -> bool:
+    """Check if data is a Ray Dataset or RemoteDatasetProxy."""
+    return isinstance(data, (Dataset, RemoteDatasetProxy))
+
+
 def normalize_timestamp_columns(
-    data: Union[pd.DataFrame, Dataset],
+    data: Union[pd.DataFrame, Dataset, Any],
     columns: Union[str, List[str]],
     inplace: bool = False,
     exclude_columns: Optional[List[str]] = None,
-) -> Union[pd.DataFrame, Dataset]:
+) -> Union[pd.DataFrame, Dataset, Any]:
     column_list = [columns] if isinstance(columns, str) else columns
     exclude_columns = exclude_columns or []

@@ -21,7 +150,7 @@ def apply_normalization(series: pd.Series) -> pd.Series:
             .astype("datetime64[ns, UTC]")
         )

-    if isinstance(data, Dataset):
+    if is_ray_data(data):

         def normalize_batch(batch: pd.DataFrame) -> pd.DataFrame:
             for column in column_list:
@@ -35,6 +164,7 @@ def normalize_batch(batch: pd.DataFrame) -> pd.DataFrame:

         return data.map_batches(normalize_batch, batch_format="pandas")
     else:
+        assert isinstance(data, pd.DataFrame)
         if not inplace:
             data = data.copy()
         for column in column_list:
@@ -44,13 +174,13 @@ def normalize_batch(batch: pd.DataFrame) -> pd.DataFrame:


 def ensure_timestamp_compatibility(
-    data: Union[pd.DataFrame, Dataset],
+    data: Union[pd.DataFrame, Dataset, Any],
     timestamp_fields: List[str],
     inplace: bool = False,
-) -> Union[pd.DataFrame, Dataset]:
+) -> Union[pd.DataFrame, Dataset, Any]:
     from feast.utils import make_df_tzaware

-    if isinstance(data, Dataset):
+    if is_ray_data(data):

         def ensure_compatibility(batch: pd.DataFrame) -> pd.DataFrame:
             batch = make_df_tzaware(batch)
@@ -65,6 +195,7 @@ def ensure_compatibility(batch: pd.DataFrame) -> pd.DataFrame:

         return data.map_batches(ensure_compatibility, batch_format="pandas")
     else:
+        assert isinstance(data, pd.DataFrame)
         if not inplace:
             data = data.copy()
         from feast.utils import make_df_tzaware
@@ -77,22 +208,24 @@ def ensure_compatibility(batch: pd.DataFrame) -> pd.DataFrame:


 def apply_field_mapping(
-    data: Union[pd.DataFrame, Dataset], field_mapping: Dict[str, str]
-) -> Union[pd.DataFrame, Dataset]:
+    data: Union[pd.DataFrame, Dataset, Any],
+    field_mapping: Dict[str, str],
+) -> Union[pd.DataFrame, Dataset, Any]:
     def rename_columns(df: pd.DataFrame) -> pd.DataFrame:
         return df.rename(columns=field_mapping)

-    if isinstance(data, Dataset):
+    if is_ray_data(data):
         return data.map_batches(rename_columns, batch_format="pandas")
     else:
+        assert isinstance(data, pd.DataFrame)
         return data.rename(columns=field_mapping)


 def deduplicate_by_keys_and_timestamp(
-    data: Union[pd.DataFrame, Dataset],
+    data: Union[pd.DataFrame, Dataset, Any],
     join_keys: List[str],
     timestamp_columns: List[str],
-) -> Union[pd.DataFrame, Dataset]:
+) -> Union[pd.DataFrame, Dataset, Any]:
     def deduplicate_batch(batch: pd.DataFrame) -> pd.DataFrame:
         if batch.empty:
             return batch
@@ -110,9 +243,10 @@ def deduplicate_batch(batch: pd.DataFrame) -> pd.DataFrame:
             return deduped_batch
         return batch

-    if isinstance(data, Dataset):
+    if is_ray_data(data):
         return data.map_batches(deduplicate_batch, batch_format="pandas")
     else:
+        assert isinstance(data, pd.DataFrame)
         return deduplicate_batch(data)
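
A usage sketch for the new class (not taken from the commit; assumes a running Ray cluster and an ObjectRef that resolves to a ray.data.Dataset, produced here by a hypothetical _load task):

    import pandas as pd
    import ray
    import ray.data
    from feast.infra.ray_shared_utils import RemoteDatasetProxy, is_ray_data

    ray.init()

    @ray.remote
    def _load():
        # Hypothetical loader; in practice the offline store would hand back this ref.
        return ray.data.from_pandas(pd.DataFrame({"id": [1, 2, 3], "v": [1.0, 2.0, 3.0]}))

    proxy = RemoteDatasetProxy(_load.remote())
    assert is_ray_data(proxy)  # the shared helpers treat the proxy like a Dataset

    # Chained calls launch remote tasks and return new proxies; data stays on the cluster.
    small = proxy.map_batches(lambda b: b[b["v"] > 1.0], batch_format="pandas").limit(2)

    print(small.count())    # terminal call: only the integer is shipped to the client
    df = small.to_pandas()  # terminal call: materializes the final result locally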