Skip to content

Commit b18c4fe

Browse files
committed
use groupby overrides
1 parent 8e48bf3 commit b18c4fe

File tree

16 files changed

+1923
-1874
lines changed

16 files changed

+1923
-1874
lines changed

docs/source/modin/groupby.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
GroupBy
33
=============================
44

5-
.. currentmodule:: snowflake.snowpark.modin.plugin.extensions.groupby_overrides
5+
.. currentmodule:: modin.pandas.groupby
66
.. rubric:: :doc:`All supported groupby APIs <supported/groupby_supported>`
77

88
.. rubric:: Indexing, iteration

src/snowflake/snowpark/modin/plugin/__init__.py

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@
5757
import snowflake.snowpark.modin.plugin.extensions.dataframe_overrides # isort: skip # noqa: E402,F401
5858
import snowflake.snowpark.modin.plugin.extensions.series_extensions # isort: skip # noqa: E402,F401
5959
import snowflake.snowpark.modin.plugin.extensions.series_overrides # isort: skip # noqa: E402,F401
60+
import snowflake.snowpark.modin.plugin.extensions.dataframe_groupby_overrides # isort: skip # noqa: E402,F401
61+
import snowflake.snowpark.modin.plugin.extensions.series_groupby_overrides # isort: skip # noqa: E402,F401
6062

6163
# === INITIALIZE DOCSTRINGS ===
6264
# These imports also all need to occur after modin + pandas dependencies are validated.
@@ -72,8 +74,9 @@
7274
import modin.pandas.series_utils # type: ignore[import] # isort: skip # noqa: E402
7375

7476
# Hybrid Mode Imports
75-
from modin.core.storage_formats.pandas.query_compiler_caster import (
77+
from modin.core.storage_formats.pandas.query_compiler_caster import ( # isort: skip # noqa: E402
7678
_GENERAL_EXTENSIONS,
79+
_NON_EXTENDABLE_ATTRIBUTES,
7780
register_function_for_post_op_switch,
7881
register_function_for_pre_op_switch,
7982
)
@@ -182,18 +185,25 @@
182185
"cumsum",
183186
]
184187

185-
post_op_switch_points = [
186-
{"class_name": None, "method": "read_snowflake"},
187-
{"class_name": "Series", "method": "value_counts"},
188-
{"class_name": "DataFrame", "method": "value_counts"},
189-
# Series.agg can return a Series if a list of aggregations is provided
190-
{"class_name": "Series", "method": "agg"},
191-
{"class_name": "Series", "method": "aggregate"},
192-
] + [{"class_name": "DataFrame", "method": agg_method} for agg_method in aggregations] + [
193-
{"class_name": "DataFrameGroupBy", "method": agg_method} for agg_method in aggregations
194-
] + [
195-
{"class_name": "SeriesGroupBy", "method": agg_method} for agg_method in aggregations
196-
]
188+
post_op_switch_points = (
189+
[
190+
{"class_name": None, "method": "read_snowflake"},
191+
{"class_name": "Series", "method": "value_counts"},
192+
{"class_name": "DataFrame", "method": "value_counts"},
193+
# Series.agg can return a Series if a list of aggregations is provided
194+
{"class_name": "Series", "method": "agg"},
195+
{"class_name": "Series", "method": "aggregate"},
196+
]
197+
+ [{"class_name": "DataFrame", "method": agg_method} for agg_method in aggregations]
198+
+ [
199+
{"class_name": "DataFrameGroupBy", "method": agg_method}
200+
for agg_method in aggregations
201+
]
202+
+ [
203+
{"class_name": "SeriesGroupBy", "method": agg_method}
204+
for agg_method in aggregations
205+
]
206+
)
197207

198208
pre_op_points = []
199209
for point in pre_op_switch_points:
@@ -214,15 +224,14 @@
214224
)
215225

216226

217-
218227
# Remove print statements for the customer validation release
219-
#print("#################### HYBRID MODE #################")
220-
#print(f"######## Registered Pre-Operation Methods ########\n{', '.join(pre_op_points)}")
221-
#print("##################################################")
222-
#print(
228+
# print("#################### HYBRID MODE #################")
229+
# print(f"######## Registered Pre-Operation Methods ########\n{', '.join(pre_op_points)}")
230+
# print("##################################################")
231+
# print(
223232
# f"######## Registered_Post-Operation_Methods #######\n{', '.join(post_op_points)}"
224-
#)
225-
#print("##################################################\n")
233+
# )
234+
# print("##################################################\n")
226235

227236
Backend.set_active_backends(["Snowflake", "Pandas"])
228237

@@ -255,10 +264,6 @@
255264
register_base_accessor,
256265
)
257266
from modin.pandas.accessor import ModinAPI # isort: skip # noqa: E402,F401
258-
from modin.core.storage_formats.pandas.query_compiler_caster import ( # isort: skip # noqa: E402,F401
259-
_NON_EXTENDABLE_ATTRIBUTES,
260-
_GENERAL_EXTENSIONS,
261-
)
262267

263268
from snowflake.snowpark.modin.plugin._internal.telemetry import ( # isort: skip # noqa: E402,F401
264269
TELEMETRY_PRIVATE_METHODS,

src/snowflake/snowpark/modin/plugin/_internal/telemetry.py

Lines changed: 53 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,6 @@ class TelemetryMeta(type):
568568
def __new__(
569569
cls, name: str, bases: tuple, attrs: dict[str, Any]
570570
) -> Union[
571-
"snowflake.snowpark.modin.plugin.extensions.groupby_overrides.DataFrameGroupBy",
572571
"snowflake.snowpark.modin.plugin.extensions.resample_overrides.Resampler",
573572
"snowflake.snowpark.modin.plugin.extensions.window_overrides.Window",
574573
"snowflake.snowpark.modin.plugin.extensions.window_overrides.Rolling",
@@ -581,7 +580,6 @@ def __new__(
581580
with ``snowpark_pandas_telemetry_api_usage`` telemetry decorator.
582581
Method arguments returned by _get_kwargs_telemetry are collected; otherwise telemetry_args is set to list().
583582
TelemetryMeta is only set as the metaclass of:
584-
snowflake.snowpark.modin.plugin.extensions.groupby_overrides.DataFrameGroupBy,
585583
snowflake.snowpark.modin.plugin.extensions.resample_overrides.Resampler,
586584
snowflake.snowpark.modin.plugin.extensions.window_overrides.Window,
587585
snowflake.snowpark.modin.plugin.extensions.window_overrides.Rolling, and their subclasses.
@@ -593,7 +591,6 @@ def __new__(
593591
attrs (Dict[str, Any]): The attributes of the class.
594592
595593
Returns:
596-
Union[snowflake.snowpark.modin.plugin.extensions.groupby_overrides.DataFrameGroupBy,
597594
snowflake.snowpark.modin.plugin.extensions.resample_overrides.Resampler,
598595
snowflake.snowpark.modin.plugin.extensions.window_overrides.Window,
599596
snowflake.snowpark.modin.plugin.extensions.window_overrides.Rolling]:
@@ -620,13 +617,13 @@ def snowpark_pandas_api_watcher(api_name: str, _time: Union[int, float]) -> None
620617
if len(tokens) >= 2 and tokens[0] == "pandas-api":
621618
modin_api_call_history.append(tokens[1])
622619

623-
hybrid_switch_log = native_pd.DataFrame(
624-
{}
625-
)
620+
621+
hybrid_switch_log = native_pd.DataFrame({})
622+
626623

627624
@cached(cache={})
628625
def get_user_source_location(mode, group) -> str:
629-
626+
630627
import inspect
631628

632629
stack = inspect.stack()
@@ -644,33 +641,42 @@ def get_user_source_location(mode, group) -> str:
644641
and frame_before_snowpandas.code_context is not None
645642
):
646643
location = frame_before_snowpandas.code_context[0].replace("\n", "")
647-
return {'mode': mode, 'group': group, 'location': location }
644+
return {"mode": mode, "group": group, "location": location}
645+
648646

649647
def get_hybrid_switch_log():
650648
global hybrid_switch_log
651649
return hybrid_switch_log.copy()
652650

651+
653652
def add_to_hybrid_switch_log(metrics: dict):
654653
global hybrid_switch_log
655654
try:
656-
mode = metrics['mode']
657-
source = get_user_source_location(mode, metrics['group'])['location']
655+
mode = metrics["mode"]
656+
source = get_user_source_location(mode, metrics["group"])["location"]
658657
if len(source) > 40:
659658
source = source[0:17] + "..." + source[-20:-1] + source[-1]
660-
hybrid_switch_log = native_pd.concat([hybrid_switch_log,
661-
native_pd.DataFrame({'source': [source],
662-
'mode': [metrics['mode']],
663-
'group': [metrics['group']],
664-
'metric': [metrics['metric']],
665-
'submetric': [metrics['submetric'] or None],
666-
'value': [metrics['value']],
667-
'from': [metrics['from'] if 'from' in metrics else None],
668-
'to': [metrics['to'] if 'to' in metrics else None],
669-
})])
659+
hybrid_switch_log = native_pd.concat(
660+
[
661+
hybrid_switch_log,
662+
native_pd.DataFrame(
663+
{
664+
"source": [source],
665+
"mode": [metrics["mode"]],
666+
"group": [metrics["group"]],
667+
"metric": [metrics["metric"]],
668+
"submetric": [metrics["submetric"] or None],
669+
"value": [metrics["value"]],
670+
"from": [metrics["from"] if "from" in metrics else None],
671+
"to": [metrics["to"] if "to" in metrics else None],
672+
}
673+
),
674+
]
675+
)
670676
except Exception as e:
671677
print(f"Exception: {type(e).__name__} - {e}")
672-
673-
678+
679+
674680
def hybrid_metrics_watcher(metric_name: str, value: Union[int, float]) -> None:
675681
if metric_name.startswith("modin.hybrid.auto"):
676682
tokens = metric_name.split(".")
@@ -688,35 +694,41 @@ def hybrid_metrics_watcher(metric_name: str, value: Union[int, float]) -> None:
688694
if len(tokens) == 10:
689695
submetric = tokens[8]
690696
group = tokens[9]
691-
add_to_hybrid_switch_log({'mode': 'single',
692-
'from': from_engine,
693-
'to': to_engine,
694-
'metric': metric,
695-
'submetric': submetric,
696-
'group': group,
697-
'value': value})
697+
add_to_hybrid_switch_log(
698+
{
699+
"mode": "single",
700+
"from": from_engine,
701+
"to": to_engine,
702+
"metric": metric,
703+
"submetric": submetric,
704+
"group": group,
705+
"value": value,
706+
}
707+
)
698708
if metric_name.startswith("modin.hybrid.cast"):
699709
tokens = metric_name.split(".")
700710
to_engine = None
701711
metric = None
702712
submetric = None
703713
group = None
704-
if len(tokens) == 7 and tokens[3] == 'to' and tokens[5] == 'cost':
714+
if len(tokens) == 7 and tokens[3] == "to" and tokens[5] == "cost":
705715
to_engine = tokens[4]
706716
group = tokens[6]
707-
metric = 'cost'
708-
if len(tokens) == 6 and tokens[3] == 'decision':
717+
metric = "cost"
718+
if len(tokens) == 6 and tokens[3] == "decision":
709719
submetric = tokens[4]
710720
group = tokens[5]
711-
metric = 'decision'
712-
add_to_hybrid_switch_log({'mode': 'merge',
713-
'to': to_engine,
714-
'metric': metric,
715-
'submetric': submetric,
716-
'group': group,
717-
'value': value})
718-
719-
721+
metric = "decision"
722+
add_to_hybrid_switch_log(
723+
{
724+
"mode": "merge",
725+
"to": to_engine,
726+
"metric": metric,
727+
"submetric": submetric,
728+
"group": group,
729+
"value": value,
730+
}
731+
)
720732

721733

722734
def connect_modin_telemetry() -> None:

src/snowflake/snowpark/modin/plugin/_internal/utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,13 +1187,13 @@ def create_ordered_dataframe_from_pandas(
11871187
]
11881188
),
11891189
)
1190-
ordered_df = cache_result(
1191-
OrderedDataFrame(
1192-
DataFrameReference(snowpark_df, snowflake_quoted_identifiers),
1193-
projected_column_snowflake_quoted_identifiers=snowflake_quoted_identifiers,
1194-
ordering_columns=ordering_columns,
1195-
row_position_snowflake_quoted_identifier=row_position_snowflake_quoted_identifier,
1196-
)
1190+
# TODO hybrid wraps this in cache_result, but this messes with query counts everywhere
1191+
# cache_result is temporarily removed here for the sake of testing
1192+
ordered_df = OrderedDataFrame(
1193+
DataFrameReference(snowpark_df, snowflake_quoted_identifiers),
1194+
projected_column_snowflake_quoted_identifiers=snowflake_quoted_identifiers,
1195+
ordering_columns=ordering_columns,
1196+
row_position_snowflake_quoted_identifier=row_position_snowflake_quoted_identifier,
11971197
)
11981198
# Set the materialized row count
11991199
ordered_df.row_count = df.shape[0]

src/snowflake/snowpark/modin/plugin/extensions/base_extensions.py

Lines changed: 0 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -43,61 +43,3 @@ def __array_function__(self, func: callable, types: tuple, args: tuple, kwargs:
4343
else:
4444
# per NEP18 we raise NotImplementedError so that numpy can intercept
4545
return NotImplemented # pragma: no cover
46-
47-
'''
48-
@register_base_override(name="__switcheroo__")
49-
def __switcheroo__(self, inplace=False, operation=""):
50-
if not is_autoswitch_enabled():
51-
return self
52-
from modin.core.storage_formats.pandas.native_query_compiler import (
53-
NativeQueryCompiler,
54-
)
55-
56-
cost_to_move = self._get_query_compiler().move_to_cost(
57-
NativeQueryCompiler, "", operation
58-
)
59-
60-
# figure out if this needs to be a standard API
61-
cost_to_stay = self._get_query_compiler().stay_cost(
62-
NativeQueryCompiler, "", operation
63-
)
64-
65-
# prototype explain
66-
import modin.pandas as pd
67-
68-
row_estimate = SnowflakeQueryCompiler._get_rows(self._get_query_compiler())
69-
import inspect
70-
71-
stack = inspect.stack()
72-
frame_before_snowpandas = None
73-
location = "<unknown>"
74-
for _i, f in enumerate(reversed(stack)):
75-
if f.filename is None:
76-
continue
77-
if "snowpark" in f.filename or "modin" in f.filename:
78-
break
79-
else:
80-
frame_before_snowpandas = f
81-
if (
82-
frame_before_snowpandas is not None
83-
and frame_before_snowpandas.code_context is not None
84-
):
85-
location = frame_before_snowpandas.code_context[0].replace("\n", "")
86-
pd.add_switcheroo_log(
87-
location,
88-
operation,
89-
"Snowflake",
90-
row_estimate,
91-
cost_to_stay,
92-
cost_to_move,
93-
"Pandas" if cost_to_move < cost_to_stay else "Snowflake",
94-
)
95-
96-
if cost_to_move < cost_to_stay:
97-
the_new_me_maybe = self.move_to("Pandas", inplace=inplace)
98-
if inplace:
99-
return self
100-
else:
101-
return the_new_me_maybe
102-
return self
103-
'''

src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2096,10 +2096,6 @@ def describe(
20962096
"""
20972097
Generate descriptive statistics.
20982098
"""
2099-
# TODO Remove Switcheroo
2100-
#self = self.__switcheroo__(inplace=True, operation="describe")
2101-
#if self.get_backend() != "Snowflake":
2102-
# return self.describe(percentiles, include, exclude)
21032099
# TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
21042100
percentiles = _refine_percentiles(percentiles)
21052101
data = self

0 commit comments

Comments
 (0)