
[RLlib] Cleanup/simplification: Remove (new api stack) Model and TorchModel base classes. #55251


Open
wants to merge 5 commits into base: master
6 changes: 3 additions & 3 deletions rllib/algorithms/ppo/ppo_catalog.py
@@ -7,7 +7,7 @@
MLPHeadConfig,
FreeLogStdMLPHeadConfig,
)
from ray.rllib.core.models.base import Encoder, ActorCriticEncoder, Model
from ray.rllib.core.models.base import Encoder, ActorCriticEncoder
from ray.rllib.utils import override
from ray.rllib.utils.annotations import OverrideToImplementCustomLogic

@@ -127,7 +127,7 @@ def build_encoder(self, framework: str) -> Encoder:
)

@OverrideToImplementCustomLogic
def build_pi_head(self, framework: str) -> Model:
def build_pi_head(self, framework: str):

medium

The return type hint was removed. It's good practice to keep it for type checking and code clarity. Since this method builds a torch model when `framework='torch'`, you could add `-> "nn.Module"`. You'll need to add `from typing import TYPE_CHECKING` and `if TYPE_CHECKING: from torch import nn` to the file.

This also applies to `build_vf_head`.

Suggested change
def build_pi_head(self, framework: str):
def build_pi_head(self, framework: str) -> "nn.Module":
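A minimal sketch of how the suggested TYPE_CHECKING guard could look; the class name below is an illustrative stand-in, not the actual PPOCatalog code from this PR:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by type checkers (mypy/pyright), never at runtime.
    from torch import nn


class MyCatalog:  # stand-in class, just to keep the sketch self-contained
    def build_pi_head(self, framework: str) -> "nn.Module":
        ...

    def build_vf_head(self, framework: str) -> "nn.Module":
        ...

Because the annotation is a string, it is never evaluated at runtime, so torch is not imported outside of type checking.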

"""Builds the policy head.

The default behavior is to build the head from the pi_head_config.
@@ -176,7 +176,7 @@ def build_pi_head(self, framework: str) -> Model:
return self.pi_head_config.build(framework=framework)

@OverrideToImplementCustomLogic
def build_vf_head(self, framework: str) -> Model:
def build_vf_head(self, framework: str):
"""Builds the value function head.

The default behavior is to build the head from the vf_head_config.
201 changes: 15 additions & 186 deletions rllib/core/models/base.py
@@ -1,13 +1,9 @@
import abc
from typing import List, Optional, Tuple, Union


from ray.rllib.core.columns import Columns
from ray.rllib.core.models.configs import ModelConfig
from ray.rllib.core.models.specs.specs_base import Spec
from ray.rllib.policy.rnn_sequencing import get_fold_unfold_fns
from ray.rllib.utils.annotations import ExperimentalAPI, override
from ray.rllib.utils.typing import TensorType
from ray.util.annotations import DeveloperAPI

# Top level keys that unify model i/o.
@@ -17,178 +13,8 @@
CRITIC: str = "critic"


@ExperimentalAPI
class Model(abc.ABC):
"""Framework-agnostic base class for RLlib models.

Models are low-level neural network components that offer input- and
output-specification, a forward method, and a get_initial_state method. Models
are composed in RLModules.

Usage Example together with ModelConfig:

.. testcode::

from ray.rllib.core.models.base import Model
from ray.rllib.core.models.configs import ModelConfig
from dataclasses import dataclass

class MyModel(Model):
def __init__(self, config):
super().__init__(config)
self.my_param = config.my_param * 2

def _forward(self, input_dict):
return input_dict["obs"] * self.my_param


@dataclass
class MyModelConfig(ModelConfig):
my_param: int = 42

def build(self, framework: str):
if framework == "bork":
return MyModel(self)


config = MyModelConfig(my_param=3)
model = config.build(framework="bork")
print(model._forward({"obs": 1}))

.. testoutput::

6

"""

def __init__(self, config: ModelConfig):
self.config = config

def __init_subclass__(cls, **kwargs):
# Automatically add a __post_init__ method to all subclasses of Model.
# This method is called after the __init__ method of the subclass.
def init_decorator(previous_init):
def new_init(self, *args, **kwargs):
previous_init(self, *args, **kwargs)
if type(self) is cls:
self.__post_init__()

return new_init

cls.__init__ = init_decorator(cls.__init__)

def __post_init__(self):
"""Called automatically after the __init__ method of the subclasses.

The module first calls the __init__ method of the subclass. Within that
__init__ you should call the super().__init__ method. After the __init__
method of the subclass has run, the __post_init__ method is called.

This is a good place to do any initialization that requires access to the
subclass's attributes.
"""
self._input_specs = self.get_input_specs()
self._output_specs = self.get_output_specs()

def get_input_specs(self) -> Optional[Spec]:
"""Returns the input specs of this model.

Override `get_input_specs` to define your own input specs.
This method should not be called often, e.g. every forward pass.
Instead, it should be called once at instantiation to define Model.input_specs.

Returns:
Spec: The input specs.
"""
return None

def get_output_specs(self) -> Optional[Spec]:
"""Returns the output specs of this model.

Override `get_output_specs` to define your own output specs.
This method should not be called often, e.g. every forward pass.
Instead, it should be called once at instantiation to define Model.output_specs.

Returns:
Spec: The output specs.
"""
return None

@property
def input_specs(self) -> Spec:
"""Returns the input spec of this model."""
return self._input_specs

@input_specs.setter
def input_specs(self, spec: Spec) -> None:
raise ValueError(
"`input_specs` cannot be set directly. Override "
"Model.get_input_specs() instead. Set Model._input_specs if "
"you want to override this behavior."
)

@property
def output_specs(self) -> Spec:
"""Returns the output specs of this model."""
return self._output_specs

@output_specs.setter
def output_specs(self, spec: Spec) -> None:
raise ValueError(
"`output_specs` cannot be set directly. Override "
"Model.get_output_specs() instead. Set Model._output_specs if "
"you want to override this behavior."
)

def get_initial_state(self) -> Union[dict, List[TensorType]]:
"""Returns the initial state of the Model.

It can be left empty if this Model is not stateful.
"""
return dict()

@abc.abstractmethod
def _forward(self, input_dict: dict, **kwargs) -> dict:
"""Returns the output of this model for the given input.

This method is called by the forwarding method of the respective framework
that is itself wrapped by RLlib in order to check model inputs and outputs.

Args:
input_dict: The input tensors.
**kwargs: Forward compatibility kwargs.

Returns:
dict: The output tensors.
"""

@abc.abstractmethod
def get_num_parameters(self) -> Tuple[int, int]:
"""Returns a tuple of (num trainable params, num non-trainable params)."""

@abc.abstractmethod
def _set_to_dummy_weights(self, value_sequence=(-0.02, -0.01, 0.01, 0.02)) -> None:
"""Helper method to set all weights to deterministic dummy values.

Calling this method on two `Models` that have the same architecture using
the exact same `value_sequence` arg should make both models output the exact
same values on arbitrary inputs. This will work, even if the two `Models`
are of different DL frameworks.

Args:
value_sequence: Looping through the list of all parameters (weight matrices,
bias tensors, etc..) of this model, in each iteration i, we set all
values in this parameter to `value_sequence[i % len(value_sequence)]`
(round robin).

Example:
TODO:
"""


@ExperimentalAPI
class Encoder(Model, abc.ABC):
@DeveloperAPI(stability="alpha")
class Encoder(abc.ABC):
"""The framework-agnostic base class for all RLlib encoders.

Encoders are used to transform observations to a latent space.
@@ -262,8 +88,13 @@ def build(self, framework: str):

"""

framework = "torch"

medium

Hardcoding `framework = "torch"` in the `Encoder` base class makes it framework-specific, which contradicts its purpose as a framework-agnostic base class, as stated in the docstring. This could lead to confusion and issues if other frameworks are to be supported in the new API stack. It would be cleaner to define the `framework` attribute in the framework-specific encoder implementations (e.g., in rllib/core/models/torch/encoder.py) and have the base `Encoder` not define it, or set it to `None`.
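A rough sketch of that alternative; `TorchEncoder` here is illustrative only (the actual torch encoder classes live in rllib/core/models/torch/encoder.py and may differ):

import abc


class Encoder(abc.ABC):
    # No framework hardcoded in the framework-agnostic base class.
    framework = None

    def __init__(self, config):
        self.config = config


class TorchEncoder(Encoder):
    # The framework is declared where it is actually known.
    framework = "torch"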


def __init__(self, config):
self.config = config

@abc.abstractmethod
def _forward(self, input_dict: dict, **kwargs) -> dict:
def forward(self, input_dict: dict, **kwargs) -> dict:
"""Returns the latent of the encoder for the given inputs.

This method is called by the forwarding method of the respective framework
@@ -287,6 +118,9 @@ def _forward(self, input_dict: dict, **kwargs) -> dict:
The output tensors. Must contain at a minimum the key ENCODER_OUT.
"""

def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)


@ExperimentalAPI
class ActorCriticEncoder(Encoder):
@@ -298,8 +132,6 @@ class ActorCriticEncoder(Encoder):
assumption that they have the same input and output specs.
"""

framework = None

def __init__(self, config: ModelConfig) -> None:
super().__init__(config)

@@ -315,8 +147,8 @@ def __init__(self, config: ModelConfig) -> None:
framework=self.framework
)

@override(Model)
def _forward(self, inputs: dict, **kwargs) -> dict:
@override(Encoder)
def forward(self, inputs: dict, **kwargs) -> dict:
if self.config.shared:
encoder_outs = self.encoder(inputs, **kwargs)
return {
@@ -363,8 +195,6 @@ class StatefulActorCriticEncoder(Encoder):
`(STATE_OUT, ACTOR)` and `(STATE_OUT, CRITIC)`, respectively.
"""

framework = None

def __init__(self, config: ModelConfig) -> None:
super().__init__(config)

@@ -378,7 +208,6 @@ def __init__(self, config: ModelConfig) -> None:
framework=self.framework
)

@override(Model)
def get_initial_state(self):
if self.config.shared:
return self.encoder.get_initial_state()
@@ -388,8 +217,8 @@ def get_initial_state(self):
CRITIC: self.critic_encoder.get_initial_state(),
}

@override(Model)
def _forward(self, inputs: dict, **kwargs) -> dict:
@override(Encoder)
def forward(self, inputs: dict, **kwargs) -> dict:
outputs = {}

if self.config.shared:
6 changes: 0 additions & 6 deletions rllib/core/models/configs.py
@@ -70,15 +70,9 @@ class ModelConfig(abc.ABC):

Attributes:
input_dims: The input dimensions of the network
always_check_shapes: Whether to always check the inputs and outputs of the
model for the specifications. Input specifications are checked on failed
forward passes of the model regardless of this flag. If this flag is set
to `True`, inputs and outputs are checked on every call. This leads to
a slow-down and should only be used for debugging.
"""

input_dims: Union[List[int], Tuple[int]] = None
always_check_shapes: bool = False

@abc.abstractmethod
def build(self, framework: str):