Skip to content

Commit d10b556

Browse files
authored
[Train] Remove the subclass relationship between RunConfig and RunConfigV1 (#54293)
In RunConfig, it subclasses RunConfigV1, whose `__post_init__` is using v1 configs to initialize. Now we remove the subclass relationship to avoid unexpected bugs. --------- Signed-off-by: xgui <[email protected]>
1 parent 263c7e1 commit d10b556

File tree

2 files changed

+49
-3
lines changed

2 files changed

+49
-3
lines changed

python/ray/train/v2/api/config.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
import logging
12
from dataclasses import dataclass
3+
from pathlib import Path
24
from typing import TYPE_CHECKING, List, Optional, Union
35

6+
import pyarrow.fs
7+
48
from ray.air.config import (
9+
CheckpointConfig,
510
FailureConfig as FailureConfigV1,
6-
RunConfig as RunConfigV1,
711
ScalingConfig as ScalingConfigV1,
812
)
913
from ray.runtime_env import RuntimeEnv
@@ -13,11 +17,15 @@
1317
TRAINER_RESOURCES_DEPRECATION_MESSAGE,
1418
)
1519
from ray.train.v2._internal.util import date_str
20+
from ray.util.annotations import PublicAPI
1621

1722
if TYPE_CHECKING:
1823
from ray.train import UserCallback
1924

2025

26+
logger = logging.getLogger(__name__)
27+
28+
2129
@dataclass
2230
class ScalingConfig(ScalingConfigV1):
2331
"""Configuration for scaling training.
@@ -97,7 +105,8 @@ def __post_init__(self):
97105

98106

99107
@dataclass
100-
class RunConfig(RunConfigV1):
108+
@PublicAPI(stability="stable")
109+
class RunConfig:
101110
"""Runtime configuration for training runs.
102111
103112
Args:
@@ -119,6 +128,11 @@ class RunConfig(RunConfigV1):
119128
for all Ray Train worker actors.
120129
"""
121130

131+
name: Optional[str] = None
132+
storage_path: Optional[str] = None
133+
storage_filesystem: Optional[pyarrow.fs.FileSystem] = None
134+
failure_config: Optional[FailureConfig] = None
135+
checkpoint_config: Optional[CheckpointConfig] = None
122136
callbacks: Optional[List["UserCallback"]] = None
123137
worker_runtime_env: Optional[Union[dict, RuntimeEnv]] = None
124138

@@ -129,7 +143,19 @@ class RunConfig(RunConfigV1):
129143
log_to_file: str = _DEPRECATED
130144

131145
def __post_init__(self):
132-
super().__post_init__()
146+
from ray.train.constants import DEFAULT_STORAGE_PATH
147+
148+
if self.storage_path is None:
149+
self.storage_path = DEFAULT_STORAGE_PATH
150+
151+
if not self.failure_config:
152+
self.failure_config = FailureConfig()
153+
154+
if not self.checkpoint_config:
155+
self.checkpoint_config = CheckpointConfig()
156+
157+
if isinstance(self.storage_path, Path):
158+
self.storage_path = self.storage_path.as_posix()
133159

134160
# TODO(justinvyu): Add link to migration guide.
135161
run_config_deprecation_message = (

python/ray/train/v2/tests/test_v2_api.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,26 @@ def test_api_configs(operation, raise_error):
3030
pytest.fail(f"Default Operation raised an exception: {e}")
3131

3232

33+
def test_run_config_default_failure_config():
34+
"""Test that RunConfig creates a default FailureConfig from v2 API, not v1."""
35+
# Import the v2 FailureConfig and v1 FailureConfig for comparison
36+
from ray.train.v2.api.config import FailureConfig as FailureConfigV2
37+
38+
# Create a RunConfig without specifying failure_config
39+
run_config = RunConfig()
40+
41+
# Verify that the default failure_config is the v2 version
42+
assert run_config.failure_config is not None
43+
assert isinstance(run_config.failure_config, FailureConfigV2)
44+
assert type(run_config.failure_config) is FailureConfigV2
45+
46+
# Verify that explicitly passing None also creates v2 FailureConfig
47+
run_config_explicit_none = RunConfig(failure_config=None)
48+
assert run_config_explicit_none.failure_config is not None
49+
assert isinstance(run_config_explicit_none.failure_config, FailureConfigV2)
50+
assert type(run_config_explicit_none.failure_config) is FailureConfigV2
51+
52+
3353
def test_scaling_config_total_resources():
3454
"""Test the patched scaling config total resources calculation."""
3555
num_workers = 2

0 commit comments

Comments
 (0)