[RLlib] Cleanup/simplification: Remove (new api stack) Model and TorchModel base classes. #55251
@@ -1,13 +1,9 @@
import abc
from typing import List, Optional, Tuple, Union

from ray.rllib.core.columns import Columns
from ray.rllib.core.models.configs import ModelConfig
from ray.rllib.core.models.specs.specs_base import Spec
from ray.rllib.policy.rnn_sequencing import get_fold_unfold_fns
from ray.rllib.utils.annotations import ExperimentalAPI, override
from ray.rllib.utils.typing import TensorType
from ray.util.annotations import DeveloperAPI

# Top level keys that unify model i/o.

@@ -17,178 +13,8 @@
CRITIC: str = "critic"

-@ExperimentalAPI
-class Model(abc.ABC):
-    """Framework-agnostic base class for RLlib models.
-
-    Models are low-level neural network components that offer input- and
-    output-specification, a forward method, and a get_initial_state method. Models
-    are composed in RLModules.
-
-    Usage Example together with ModelConfig:
-
-    .. testcode::
-
-        from ray.rllib.core.models.base import Model
-        from ray.rllib.core.models.configs import ModelConfig
-        from dataclasses import dataclass
-
-        class MyModel(Model):
-            def __init__(self, config):
-                super().__init__(config)
-                self.my_param = config.my_param * 2
-
-            def _forward(self, input_dict):
-                return input_dict["obs"] * self.my_param
-
-        @dataclass
-        class MyModelConfig(ModelConfig):
-            my_param: int = 42
-
-            def build(self, framework: str):
-                if framework == "bork":
-                    return MyModel(self)
-
-        config = MyModelConfig(my_param=3)
-        model = config.build(framework="bork")
-        print(model._forward({"obs": 1}))
-
-    .. testoutput::
-
-        6
-
-    """
-
-    def __init__(self, config: ModelConfig):
-        self.config = config
-
-    def __init_subclass__(cls, **kwargs):
-        # Automatically add a __post_init__ method to all subclasses of Model.
-        # This method is called after the __init__ method of the subclass.
-        def init_decorator(previous_init):
-            def new_init(self, *args, **kwargs):
-                previous_init(self, *args, **kwargs)
-                if type(self) is cls:
-                    self.__post_init__()
-
-            return new_init
-
-        cls.__init__ = init_decorator(cls.__init__)
-
-    def __post_init__(self):
-        """Called automatically after the __init__ method of the subclasses.
-
-        The module first calls the __init__ method of the subclass, within which
-        super().__init__ should be called. Only after the subclass's __init__ has
-        finished is the __post_init__ method called.
-
-        This is a good place to do any initialization that requires access to the
-        subclass's attributes.
-        """
-        self._input_specs = self.get_input_specs()
-        self._output_specs = self.get_output_specs()
-
-    def get_input_specs(self) -> Optional[Spec]:
-        """Returns the input specs of this model.
-
-        Override `get_input_specs` to define your own input specs.
-        This method should not be called often, e.g. every forward pass.
-        Instead, it should be called once at instantiation to define Model.input_specs.
-
-        Returns:
-            Spec: The input specs.
-        """
-        return None
-
-    def get_output_specs(self) -> Optional[Spec]:
-        """Returns the output specs of this model.
-
-        Override `get_output_specs` to define your own output specs.
-        This method should not be called often, e.g. every forward pass.
-        Instead, it should be called once at instantiation to define Model.output_specs.
-
-        Returns:
-            Spec: The output specs.
-        """
-        return None
-
-    @property
-    def input_specs(self) -> Spec:
-        """Returns the input specs of this model."""
-        return self._input_specs
-
-    @input_specs.setter
-    def input_specs(self, spec: Spec) -> None:
-        raise ValueError(
-            "`input_specs` cannot be set directly. Override "
-            "Model.get_input_specs() instead. Set Model._input_specs if "
-            "you want to override this behavior."
-        )
-
-    @property
-    def output_specs(self) -> Spec:
-        """Returns the output specs of this model."""
-        return self._output_specs
-
-    @output_specs.setter
-    def output_specs(self, spec: Spec) -> None:
-        raise ValueError(
-            "`output_specs` cannot be set directly. Override "
-            "Model.get_output_specs() instead. Set Model._output_specs if "
-            "you want to override this behavior."
-        )
-
-    def get_initial_state(self) -> Union[dict, List[TensorType]]:
-        """Returns the initial state of the Model.
-
-        It can be left empty if this Model is not stateful.
-        """
-        return dict()
-
-    @abc.abstractmethod
-    def _forward(self, input_dict: dict, **kwargs) -> dict:
-        """Returns the output of this model for the given input.
-
-        This method is called by the forwarding method of the respective framework
-        that is itself wrapped by RLlib in order to check model inputs and outputs.
-
-        Args:
-            input_dict: The input tensors.
-            **kwargs: Forward compatibility kwargs.
-
-        Returns:
-            dict: The output tensors.
-        """
-
-    @abc.abstractmethod
-    def get_num_parameters(self) -> Tuple[int, int]:
-        """Returns a tuple of (num trainable params, num non-trainable params)."""
-
-    @abc.abstractmethod
-    def _set_to_dummy_weights(self, value_sequence=(-0.02, -0.01, 0.01, 0.02)) -> None:
-        """Helper method to set all weights to deterministic dummy values.
-
-        Calling this method on two `Models` that have the same architecture using
-        the exact same `value_sequence` arg should make both models output the exact
-        same values on arbitrary inputs. This will work, even if the two `Models`
-        are of different DL frameworks.
-
-        Args:
-            value_sequence: Looping through the list of all parameters (weight
-                matrices, bias tensors, etc.) of this model, in each iteration i, we
-                set all values in this parameter to
-                `value_sequence[i % len(value_sequence)]` (round robin).
-
-        Example:
-            TODO:
-        """
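The removed `__init_subclass__`/`__post_init__` hook is plain Python; a self-contained sketch (class names invented for illustration) of why it fires exactly once, after the most-derived `__init__` has run:

```python
class Base:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)

        def init_decorator(previous_init):
            def new_init(self, *args, **kwargs):
                previous_init(self, *args, **kwargs)
                # Fire only for the most-derived class, so chained
                # super().__init__() calls don't trigger it repeatedly.
                if type(self) is cls:
                    self.__post_init__()
            return new_init

        cls.__init__ = init_decorator(cls.__init__)

    def __post_init__(self):
        print("post_init sees:", self.my_param)

class Child(Base):
    def __init__(self, my_param):
        self.my_param = my_param

Child(3)  # prints "post_init sees: 3"
```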
-@ExperimentalAPI
-class Encoder(Model, abc.ABC):
+@DeveloperAPI(stability="alpha")
+class Encoder(abc.ABC):
    """The framework-agnostic base class for all RLlib encoders.

    Encoders are used to transform observations to a latent space.
@@ -262,8 +88,13 @@ def build(self, framework: str):

    """

+    framework = "torch"

Review comment: Hardcoding …

+    def __init__(self, config):
+        self.config = config

    @abc.abstractmethod
-    def _forward(self, input_dict: dict, **kwargs) -> dict:
+    def forward(self, input_dict: dict, **kwargs) -> dict:
        """Returns the latent of the encoder for the given inputs.

        This method is called by the forwarding method of the respective framework

@@ -287,6 +118,9 @@ def _forward(self, input_dict: dict, **kwargs) -> dict:
            The output tensors. Must contain at a minimum the key ENCODER_OUT.
        """

+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
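With `Model` gone, an encoder only needs a config, a `forward` that returns `ENCODER_OUT`, and the `__call__` above. A hypothetical torch subclass under that contract (`MyTorchEncoder` and its config fields are invented for illustration):

```python
import torch.nn as nn
from ray.rllib.core.models.base import ENCODER_OUT, Encoder

class MyTorchEncoder(Encoder, nn.Module):
    def __init__(self, config):
        nn.Module.__init__(self)
        Encoder.__init__(self, config)
        # `input_dim` and `hidden_dim` are assumed fields on the config.
        self.net = nn.Linear(config.input_dim, config.hidden_dim)

    def forward(self, input_dict: dict, **kwargs) -> dict:
        return {ENCODER_OUT: self.net(input_dict["obs"])}
```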

@ExperimentalAPI
class ActorCriticEncoder(Encoder):

@@ -298,8 +132,6 @@ class ActorCriticEncoder(Encoder):
    assumption that they have the same input and output specs.
    """

-    framework = None

    def __init__(self, config: ModelConfig) -> None:
        super().__init__(config)

@@ -315,8 +147,8 @@ def __init__(self, config: ModelConfig) -> None:
            framework=self.framework
        )

-    @override(Model)
-    def _forward(self, inputs: dict, **kwargs) -> dict:
+    @override(Encoder)
+    def forward(self, inputs: dict, **kwargs) -> dict:
        if self.config.shared:
            encoder_outs = self.encoder(inputs, **kwargs)
            return {
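The diff cuts off mid-method here. Based on the class docstring, the remainder plausibly dispatches the shared and non-shared cases along these lines; this is a hedged reconstruction, not the file's actual code:

```python
# Hedged reconstruction of how the truncated method likely continues:
def forward(self, inputs: dict, **kwargs) -> dict:
    if self.config.shared:
        # One shared encoder feeds both the actor and the critic.
        encoder_outs = self.encoder(inputs, **kwargs)
        return {ENCODER_OUT: {ACTOR: encoder_outs[ENCODER_OUT],
                              CRITIC: encoder_outs[ENCODER_OUT]}}
    # Separate encoders: key each latent under ACTOR / CRITIC.
    actor_out = self.actor_encoder(inputs, **kwargs)
    critic_out = self.critic_encoder(inputs, **kwargs)
    return {ENCODER_OUT: {ACTOR: actor_out[ENCODER_OUT],
                          CRITIC: critic_out[ENCODER_OUT]}}
```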
@@ -363,8 +195,6 @@ class StatefulActorCriticEncoder(Encoder):
    `(STATE_OUT, ACTOR)` and `(STATE_OUT, CRITIC)`, respectively.
    """

-    framework = None

    def __init__(self, config: ModelConfig) -> None:
        super().__init__(config)

@@ -378,7 +208,6 @@ def __init__(self, config: ModelConfig) -> None:
            framework=self.framework
        )

-    @override(Model)
    def get_initial_state(self):
        if self.config.shared:
            return self.encoder.get_initial_state()

@@ -388,8 +217,8 @@ def get_initial_state(self):
            CRITIC: self.critic_encoder.get_initial_state(),
        }

-    @override(Model)
-    def _forward(self, inputs: dict, **kwargs) -> dict:
+    @override(Encoder)
+    def forward(self, inputs: dict, **kwargs) -> dict:
        outputs = {}

        if self.config.shared:
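A small usage sketch of the state contract described in the docstring above; `config` and the variable names are illustrative:

```python
from ray.rllib.core.models.base import ACTOR, CRITIC

# `config` is an illustrative ModelConfig with a `shared` flag.
encoder = StatefulActorCriticEncoder(config)
state = encoder.get_initial_state()
if config.shared:
    # Shared encoder: one state object serves both heads.
    actor_state = critic_state = state
else:
    # Separate encoders: per-head initial states keyed by ACTOR / CRITIC.
    actor_state, critic_state = state[ACTOR], state[CRITIC]
```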
Review comment: The return type hint was removed. It's good practice to keep it for type checking and code clarity. Since this method builds a torch model when `framework='torch'`, you could add `-> "nn.Module"`. You'll need to add `from typing import TYPE_CHECKING` and `if TYPE_CHECKING: from torch import nn` to the file. This also applies to `build_vf_head`.
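A sketch of what the reviewer suggests, with the method body elided; `build_vf_head` is named in the comment, and the enclosing class is not shown in this excerpt:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for static type checkers; keeps torch out of
    # the module's runtime import path.
    from torch import nn

def build_vf_head(self, framework: str) -> "nn.Module":
    # Builds a torch model when framework == "torch", hence the
    # string-quoted forward reference in the annotation.
    ...
```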