-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 import pandas as pd
+import pyarrow as pa
+import ray
 
 from ray.data import Dataset
 
+class RemoteDatasetProxy:
+    """Proxy class that executes Ray Data operations remotely on cluster workers."""
+
+    def __init__(self, dataset_ref: Any):
+        """Initialize with a reference to the remote dataset."""
+        self._dataset_ref = dataset_ref
+
+    def map_batches(self, func, **kwargs) -> "RemoteDatasetProxy":
+        """Execute map_batches remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_map_batches(dataset, function, batch_kwargs):
+            return dataset.map_batches(function, **batch_kwargs)
+
+        new_ref = _remote_map_batches.remote(self._dataset_ref, func, kwargs)
+        return RemoteDatasetProxy(new_ref)
+
+    def filter(self, fn) -> "RemoteDatasetProxy":
+        """Execute filter remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_filter(dataset, filter_fn):
+            return dataset.filter(filter_fn)
+
+        new_ref = _remote_filter.remote(self._dataset_ref, fn)
+        return RemoteDatasetProxy(new_ref)
+
+    def to_pandas(self) -> pd.DataFrame:
+        """Execute to_pandas remotely and transfer result to client."""
+
+        @ray.remote
+        def _remote_to_pandas(dataset):
+            return dataset.to_pandas()
+
+        result_ref = _remote_to_pandas.remote(self._dataset_ref)
+        return ray.get(result_ref)
+
+    def to_arrow(self) -> pa.Table:
+        """Execute to_arrow remotely and transfer result to client."""
+
+        @ray.remote
+        def _remote_to_arrow(dataset):
+            return dataset.to_arrow()
+
+        result_ref = _remote_to_arrow.remote(self._dataset_ref)
+        return ray.get(result_ref)
+
+    def schema(self) -> Any:
+        """Get dataset schema."""
+
+        @ray.remote
+        def _remote_schema(dataset):
+            return dataset.schema()
+
+        schema_ref = _remote_schema.remote(self._dataset_ref)
+        return ray.get(schema_ref)
+
+    def sort(self, key, descending=False) -> "RemoteDatasetProxy":
+        """Execute sort remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_sort(dataset, sort_key, desc):
+            return dataset.sort(sort_key, descending=desc)
+
+        new_ref = _remote_sort.remote(self._dataset_ref, key, descending)
+        return RemoteDatasetProxy(new_ref)
+
+    def limit(self, count) -> "RemoteDatasetProxy":
+        """Execute limit remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_limit(dataset, limit_count):
+            return dataset.limit(limit_count)
+
+        new_ref = _remote_limit.remote(self._dataset_ref, count)
+        return RemoteDatasetProxy(new_ref)
+
+    def union(self, other) -> "RemoteDatasetProxy":
+        """Execute union remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_union(dataset1, dataset2):
+            return dataset1.union(dataset2)
+
+        new_ref = _remote_union.remote(self._dataset_ref, other._dataset_ref)
+        return RemoteDatasetProxy(new_ref)
+
+    def materialize(self) -> "RemoteDatasetProxy":
+        """Execute materialize remotely on cluster workers."""
+
+        @ray.remote
+        def _remote_materialize(dataset):
+            return dataset.materialize()
+
+        new_ref = _remote_materialize.remote(self._dataset_ref)
+        return RemoteDatasetProxy(new_ref)
+
+    def count(self) -> int:
+        """Execute count remotely and return result."""
+
+        @ray.remote
+        def _remote_count(dataset):
+            return dataset.count()
+
+        result_ref = _remote_count.remote(self._dataset_ref)
+        return ray.get(result_ref)
+
+    def take(self, n=20) -> list:
+        """Execute take remotely and return result."""
+
+        @ray.remote
+        def _remote_take(dataset, num):
+            return dataset.take(num)
+
+        result_ref = _remote_take.remote(self._dataset_ref, n)
+        return ray.get(result_ref)
+
+    def __getattr__(self, name):
+        """Catch any method calls that we haven't explicitly implemented."""
+        raise AttributeError(f"RemoteDatasetProxy has no attribute '{name}'")
+
+
+def is_ray_data(data: Any) -> bool:
+    """Check if data is a Ray Dataset or RemoteDatasetProxy."""
+    return isinstance(data, (Dataset, RemoteDatasetProxy))
+
+
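The proxy keeps the dataset on the cluster: each chained call submits one remote task and wraps the resulting object ref, and data only crosses to the client on terminal calls (`to_pandas`, `to_arrow`, `count`, `take`, `schema`). `is_ray_data` then lets the helpers below dispatch on either a local `Dataset` or this proxy. A minimal usage sketch — the Ray Client address, the `load_dataset` task, and the toy range dataset are illustrative, not part of this change:

```python
import ray
import ray.data

ray.init("ray://head-node:10001")  # hypothetical cluster address

@ray.remote
def load_dataset():
    # Build the dataset inside the cluster so its blocks stay on workers.
    return ray.data.range(1_000)

proxy = RemoteDatasetProxy(load_dataset.remote())

# Each chained call ships one remote task and returns a new proxy.
subset = proxy.filter(lambda row: row["id"] % 2 == 0).limit(10)

# Terminal calls are the only points where results reach the client.
print(subset.count())    # -> 10
df = subset.to_pandas()  # small pandas frame on the client
```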
 def normalize_timestamp_columns(
-    data: Union[pd.DataFrame, Dataset],
+    data: Union[pd.DataFrame, Dataset, Any],
     columns: Union[str, List[str]],
     inplace: bool = False,
     exclude_columns: Optional[List[str]] = None,
-) -> Union[pd.DataFrame, Dataset]:
+) -> Union[pd.DataFrame, Dataset, Any]:
     column_list = [columns] if isinstance(columns, str) else columns
     exclude_columns = exclude_columns or []
 
@@ -21,7 +150,7 @@ def apply_normalization(series: pd.Series) -> pd.Series:
             .astype("datetime64[ns, UTC]")
         )
 
-    if isinstance(data, Dataset):
+    if is_ray_data(data):
 
         def normalize_batch(batch: pd.DataFrame) -> pd.DataFrame:
             for column in column_list:
@@ -35,6 +164,7 @@ def normalize_batch(batch: pd.DataFrame) -> pd.DataFrame:
 
         return data.map_batches(normalize_batch, batch_format="pandas")
     else:
+        assert isinstance(data, pd.DataFrame)
         if not inplace:
             data = data.copy()
         for column in column_list:
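For context, the pandas branch in action. The frame is illustrative, and the expected dtype follows from the `datetime64[ns, UTC]` cast visible in the hunk above (the full body of `apply_normalization` is elided by this diff):

```python
import pandas as pd

df = pd.DataFrame(
    {
        "event_ts": pd.to_datetime(["2024-01-01 00:00:00", "2024-01-02 12:30:00"]),
        "value": [1, 2],
    }
)

normalized = normalize_timestamp_columns(df, "event_ts")
print(normalized["event_ts"].dtype)  # expected: datetime64[ns, UTC]
print(df["event_ts"].dtype)          # original unchanged: inplace defaults to False
```

A `Dataset` or `RemoteDatasetProxy` input takes the `normalize_batch` path through `map_batches` instead.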
@@ -44,13 +174,13 @@ def normalize_batch(batch: pd.DataFrame) -> pd.DataFrame:
 
 
 def ensure_timestamp_compatibility(
-    data: Union[pd.DataFrame, Dataset],
+    data: Union[pd.DataFrame, Dataset, Any],
     timestamp_fields: List[str],
     inplace: bool = False,
-) -> Union[pd.DataFrame, Dataset]:
+) -> Union[pd.DataFrame, Dataset, Any]:
     from feast.utils import make_df_tzaware
 
-    if isinstance(data, Dataset):
+    if is_ray_data(data):
 
         def ensure_compatibility(batch: pd.DataFrame) -> pd.DataFrame:
             batch = make_df_tzaware(batch)
@@ -65,6 +195,7 @@ def ensure_compatibility(batch: pd.DataFrame) -> pd.DataFrame:
 
         return data.map_batches(ensure_compatibility, batch_format="pandas")
     else:
+        assert isinstance(data, pd.DataFrame)
         if not inplace:
             data = data.copy()
         from feast.utils import make_df_tzaware
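Same dispatch pattern, here for timezone repair via `feast.utils.make_df_tzaware` (which makes naive timestamp columns tz-aware). A sketch with an illustrative frame:

```python
import pandas as pd

df = pd.DataFrame(
    {
        "created": pd.to_datetime(["2024-01-01", "2024-01-02"]),  # tz-naive
        "driver_id": [1, 2],
    }
)

# A DataFrame takes the pandas branch; a Dataset or RemoteDatasetProxy
# routes through ensure_compatibility via map_batches instead.
compatible = ensure_timestamp_compatibility(df, timestamp_fields=["created"])
```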
@@ -77,22 +208,24 @@ def ensure_compatibility(batch: pd.DataFrame) -> pd.DataFrame:
 
 
 def apply_field_mapping(
-    data: Union[pd.DataFrame, Dataset], field_mapping: Dict[str, str]
-) -> Union[pd.DataFrame, Dataset]:
+    data: Union[pd.DataFrame, Dataset, Any],
+    field_mapping: Dict[str, str],
+) -> Union[pd.DataFrame, Dataset, Any]:
     def rename_columns(df: pd.DataFrame) -> pd.DataFrame:
         return df.rename(columns=field_mapping)
 
-    if isinstance(data, Dataset):
+    if is_ray_data(data):
         return data.map_batches(rename_columns, batch_format="pandas")
     else:
+        assert isinstance(data, pd.DataFrame)
         return data.rename(columns=field_mapping)
 
 
 def deduplicate_by_keys_and_timestamp(
-    data: Union[pd.DataFrame, Dataset],
+    data: Union[pd.DataFrame, Dataset, Any],
     join_keys: List[str],
     timestamp_columns: List[str],
-) -> Union[pd.DataFrame, Dataset]:
+) -> Union[pd.DataFrame, Dataset, Any]:
     def deduplicate_batch(batch: pd.DataFrame) -> pd.DataFrame:
         if batch.empty:
             return batch
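`apply_field_mapping` above is the simplest of these dispatchers, which makes the branch pattern easy to see end to end (illustrative columns):

```python
import pandas as pd

df = pd.DataFrame({"driver": [1, 2], "trip_total": [9.9, 7.5]})

renamed = apply_field_mapping(df, {"trip_total": "fare"})
assert list(renamed.columns) == ["driver", "fare"]
```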
@@ -110,9 +243,10 @@ def deduplicate_batch(batch: pd.DataFrame) -> pd.DataFrame:
             return deduped_batch
         return batch
 
-    if isinstance(data, Dataset):
+    if is_ray_data(data):
         return data.map_batches(deduplicate_batch, batch_format="pandas")
     else:
+        assert isinstance(data, pd.DataFrame)
         return deduplicate_batch(data)
 
 
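Finally, a sketch of the deduplication entry point. The selection logic in the middle of `deduplicate_batch` is elided by this diff, so the comment below only states the intent suggested by the function's name and parameters:

```python
import pandas as pd

df = pd.DataFrame(
    {
        "driver_id": [1, 1, 2],
        "event_ts": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-01"]),
        "value": [10, 11, 20],
    }
)

# One row per set of join keys, chosen by timestamp; the exact
# ordering/tie-breaking rules live in the elided body above.
deduped = deduplicate_by_keys_and_timestamp(
    df, join_keys=["driver_id"], timestamp_columns=["event_ts"]
)
```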