Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 77 additions & 18 deletions modin/core/storage_formats/pandas/query_compiler_caster.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,16 @@ def _normalize_class_name(class_of_wrapped_fn: Optional[str]) -> str:

_AUTO_SWITCH_CLASS = defaultdict[BackendAndClassName, set[str]]

# For pre-op switch methods, we store method_name -> is_arg_based mapping
# where is_arg_based=True means switch only if parameters are unsupported
_AUTO_SWITCH_PRE_OP_CLASS = defaultdict[BackendAndClassName, dict[str, bool]]

_CLASS_AND_BACKEND_TO_POST_OP_SWITCH_METHODS: _AUTO_SWITCH_CLASS = _AUTO_SWITCH_CLASS(
set
)

_CLASS_AND_BACKEND_TO_PRE_OP_SWITCH_METHODS: _AUTO_SWITCH_CLASS = _AUTO_SWITCH_CLASS(
set
_CLASS_AND_BACKEND_TO_PRE_OP_SWITCH_METHODS: _AUTO_SWITCH_PRE_OP_CLASS = _AUTO_SWITCH_PRE_OP_CLASS(
dict
)


Expand Down Expand Up @@ -621,21 +625,45 @@ def _maybe_switch_backend_pre_op(
to the new query compiler type.
"""
input_backend = input_qc.get_backend()
if (
function_name
in _CLASS_AND_BACKEND_TO_PRE_OP_SWITCH_METHODS[
BackendAndClassName(
backend=input_qc.get_backend(), class_name=class_of_wrapped_fn
backend_class_key = BackendAndClassName(
backend=input_qc.get_backend(), class_name=class_of_wrapped_fn
)

# Check if this function is registered for pre-op switch
registered_methods = _CLASS_AND_BACKEND_TO_PRE_OP_SWITCH_METHODS[backend_class_key]

if function_name in registered_methods:
is_arg_based = registered_methods[function_name]

if is_arg_based:
# Arg-based switch: only switch if parameters are unsupported
stay_cost = input_qc.stay_cost(
api_cls_name=class_of_wrapped_fn,
operation=function_name,
arguments=arguments,
)

# Only trigger switch if parameters are unsupported (COST_IMPOSSIBLE)
if stay_cost is not None and stay_cost >= QCCoercionCost.COST_IMPOSSIBLE:
result_backend = _get_backend_for_auto_switch(
input_qc=input_qc,
class_of_wrapped_fn=class_of_wrapped_fn,
function_name=function_name,
arguments=arguments,
)
else:
# Parameters are supported, no need to switch
result_backend = input_backend
else:
# Non-arg-based switch: always consider switching
result_backend = _get_backend_for_auto_switch(
input_qc=input_qc,
class_of_wrapped_fn=class_of_wrapped_fn,
function_name=function_name,
arguments=arguments,
)
]
):
result_backend = _get_backend_for_auto_switch(
input_qc=input_qc,
class_of_wrapped_fn=class_of_wrapped_fn,
function_name=function_name,
arguments=arguments,
)
else:
# No registration found, stay on current backend
result_backend = input_backend

def cast_to_qc(arg: Any) -> Any:
Expand Down Expand Up @@ -773,12 +801,18 @@ def _get_backend_for_auto_switch(

min_move_stay_delta = None
best_backend = starting_backend

all_backends_impossible = True

stay_cost = input_qc.stay_cost(
api_cls_name=class_of_wrapped_fn,
operation=function_name,
arguments=arguments,
)

# Check if the current backend can handle the workload
if stay_cost is not None and stay_cost < QCCoercionCost.COST_IMPOSSIBLE:
all_backends_impossible = False

data_max_shape = input_qc._max_shape()
emit_metric(
f"hybrid.auto.api.{class_of_wrapped_fn}.{function_name}.group.{metrics_group}",
Expand Down Expand Up @@ -835,6 +869,12 @@ def _get_backend_for_auto_switch(
# We can execute this workload if we need to, consider
# move_to_cost/transfer time in our decision
move_stay_delta = (move_to_cost + other_execute_cost) - stay_cost

# Check if this backend can handle the workload (both execution and transfer must be possible)
if (other_execute_cost < QCCoercionCost.COST_IMPOSSIBLE and
move_to_cost < QCCoercionCost.COST_IMPOSSIBLE):
all_backends_impossible = False

if move_stay_delta < 0 and (
min_move_stay_delta is None or move_stay_delta < min_move_stay_delta
):
Expand All @@ -861,6 +901,20 @@ def _get_backend_for_auto_switch(
+ f"{move_stay_delta}"
)

# Check if all backends are impossible and raise exception
if all_backends_impossible:
emit_metric(f"hybrid.auto.decision.impossible.group.{metrics_group}", 1)
get_logger().error(
f"All backends impossible for {class_of_wrapped_fn}.{function_name}: "
f"starting_backend={starting_backend}, stay_cost={stay_cost}"
)
ErrorMessage.not_implemented(
f"No available backend can handle the workload for operation "
f"{class_of_wrapped_fn}.{function_name}. All backends returned COST_IMPOSSIBLE. "
f"Current backend: {starting_backend}, stay_cost: {stay_cost}. "
f"This operation cannot be executed due to memory or capability constraints across all backends."
)

if best_backend == starting_backend:
emit_metric(f"hybrid.auto.decision.{best_backend}.group.{metrics_group}", 0)
get_logger().info(
Expand Down Expand Up @@ -1228,7 +1282,7 @@ def register_function_for_post_op_switch(


def register_function_for_pre_op_switch(
class_name: Optional[str], backend: str, method: str
class_name: Optional[str], backend: str, method: str, *, arg_based: bool = False
) -> None:
"""
Register a function for pre-operation backend switch.
Expand All @@ -1242,7 +1296,12 @@ def register_function_for_pre_op_switch(
Only consider switching when the starting backend is this one.
method : str
The name of the method to register.
arg_based : bool, default: False
If True, the switch will only be triggered if unsupported parameters are detected
for the operation, avoiding unnecessary backend switching when parameters
are supported. If False, the switch will always be considered (existing behavior).
"""
_CLASS_AND_BACKEND_TO_PRE_OP_SWITCH_METHODS[
BackendAndClassName(backend=backend, class_name=class_name)
].add(method)
][method] = arg_based