Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions python/ray/train/v2/api/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Union

from ray.air.config import (
CheckpointConfig,
FailureConfig as FailureConfigV1,
RunConfig as RunConfigV1,
ScalingConfig as ScalingConfigV1,
)
from ray.runtime_env import RuntimeEnv
Expand All @@ -13,10 +15,15 @@
TRAINER_RESOURCES_DEPRECATION_MESSAGE,
)
from ray.train.v2._internal.util import date_str
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
from ray.train import UserCallback

import pyarrow.fs

logger = logging.getLogger(__name__)


@dataclass
class ScalingConfig(ScalingConfigV1):
Expand Down Expand Up @@ -97,7 +104,8 @@ def __post_init__(self):


@dataclass
class RunConfig(RunConfigV1):
@PublicAPI(stability="stable")
class RunConfig:
"""Runtime configuration for training runs.

Args:
Expand All @@ -119,6 +127,11 @@ class RunConfig(RunConfigV1):
for all Ray Train worker actors.
"""

name: Optional[str] = None
storage_path: Optional[str] = None
storage_filesystem: Optional[pyarrow.fs.FileSystem] = None
failure_config: Optional[FailureConfig] = None
checkpoint_config: Optional[CheckpointConfig] = None
callbacks: Optional[List["UserCallback"]] = None
worker_runtime_env: Optional[Union[dict, RuntimeEnv]] = None

Expand All @@ -129,7 +142,19 @@ class RunConfig(RunConfigV1):
log_to_file: str = _DEPRECATED

def __post_init__(self):
super().__post_init__()
from ray.train.constants import DEFAULT_STORAGE_PATH

if self.storage_path is None:
self.storage_path = DEFAULT_STORAGE_PATH

if not self.failure_config:
self.failure_config = FailureConfig()

if not self.checkpoint_config:
self.checkpoint_config = CheckpointConfig()

if isinstance(self.storage_path, Path):
self.storage_path = self.storage_path.as_posix()

# TODO(justinvyu): Add link to migration guide.
run_config_deprecation_message = (
Expand Down