[RLlib] Cleanup/simplification: Remove (new api stack) Model and TorchModel base classes. #55251
base: master
Conversation
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
…ve_new_api_stack_model_and_torchmodel_base_classes
Signed-off-by: sven1977 <[email protected]>
Conflicts:
- rllib/core/distribution/distribution.py
- rllib/core/distribution/torch/torch_distribution.py
- rllib/core/models/base.py
- rllib/core/models/tests/test_catalog.py
- rllib/core/models/torch/base.py
- rllib/core/rl_module/rl_module.py
Signed-off-by: sven1977 <[email protected]>
Summary of Changes
Hello @sven1977, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
I've focused on a significant cleanup and simplification within RLlib's new API stack. My primary goal was to remove the `Model` and `TorchModel` base classes, streamlining the model architecture. This change simplifies the inheritance hierarchy for various neural network components, allowing them to directly inherit from framework-specific modules (like PyTorch's `nn.Module`) or more specialized abstract base classes. This refactoring aims to reduce complexity and improve the maintainability of the model definitions.
Highlights
- **Base Class Removal:** I've completely removed the `Model` and `TorchModel` base classes from the new API stack, simplifying the core model hierarchy.
- **Encoder Refactoring:** The `Encoder` class now directly inherits from `abc.ABC`, and its primary forward pass method has been standardized to `forward`.
- **PyTorch Model Simplification:** All PyTorch-specific model components (encoders and heads) have been updated to inherit directly from `Encoder` or `nn.Module`, removing their dependency on the now-removed `TorchModel`.
- **RLModule Cleanup:** Several deprecated input and output specification methods have been cleaned up from the `RLModule` class.
- **ModelCatalog Refinement:** The `ModelCatalog` has been streamlined by removing verbose docstrings and marking it explicitly as part of the `OldAPIStack`.
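The hierarchy described above can be sketched roughly as follows. This is a minimal, illustrative example only: the class bodies, constructor signatures, and layer shapes are made up for the sketch and are not RLlib's actual code.

```python
import abc

import torch
import torch.nn as nn


class Encoder(abc.ABC):
    """Framework-agnostic encoder interface (illustrative sketch)."""

    @abc.abstractmethod
    def forward(self, inputs):
        ...


class TorchMLPEncoder(Encoder, nn.Module):
    """Torch encoder inheriting nn.Module directly, with no TorchModel layer."""

    def __init__(self, in_dim: int, out_dim: int) -> None:
        nn.Module.__init__(self)
        self.net = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU())

    def forward(self, inputs):
        return self.net(inputs)


# Usage: behaves like any nn.Module, so parameters register normally.
enc = TorchMLPEncoder(4, 8)
out = enc(torch.zeros(2, 4))
```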
Code Review
This pull request is a nice simplification of the RLlib model hierarchy, removing the `Model` and `TorchModel` base classes. The changes are mostly cleanups and refactoring. I've found a critical issue where the Torch encoders are no longer `nn.Module`s, which will prevent them from being trained. I also have a couple of medium-severity suggestions to improve code clarity and maintainability, regarding type hints and a hardcoded framework string in a base class.
```diff
 class TorchMLPEncoder(Encoder):
     def __init__(self, config: MLPEncoderConfig) -> None:
-        TorchModel.__init__(self, config)
-        Encoder.__init__(self, config)
+        super().__init__(config)
```
This encoder needs to be a `torch.nn.Module` to be correctly handled by PyTorch and have its parameters registered with the parent `RLModule`. Please inherit from `nn.Module` and call its constructor.

This issue applies to all other encoder classes in this file (`TorchCNNEncoder`, `TorchGRUEncoder`, `TorchLSTMEncoder`).
Suggested change:

```diff
-class TorchMLPEncoder(Encoder):
+class TorchMLPEncoder(Encoder, nn.Module):
     def __init__(self, config: MLPEncoderConfig) -> None:
+        nn.Module.__init__(self)
         super().__init__(config)
```
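The reviewer's point can be demonstrated standalone. The sketch below is not RLlib code (`PlainEncoder`, `ModuleEncoder`, and `RLModuleLike` are made-up stand-ins): submodules held by a plain Python object are invisible to the parent `nn.Module`'s parameter collection, so an optimizer built from `.parameters()` would never update them.

```python
import torch.nn as nn


class PlainEncoder:
    """Hypothetical encoder that is NOT an nn.Module."""

    def __init__(self) -> None:
        self.net = nn.Linear(4, 8)


class ModuleEncoder(nn.Module):
    """The same encoder, but a proper nn.Module."""

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Linear(4, 8)


class RLModuleLike(nn.Module):
    """Stand-in for a parent module holding an encoder attribute."""

    def __init__(self, encoder) -> None:
        super().__init__()
        self.encoder = encoder


# A non-Module attribute contributes no parameters to the parent;
# an nn.Module attribute registers its weight and bias tensors.
print(len(list(RLModuleLike(PlainEncoder()).parameters())))   # 0
print(len(list(RLModuleLike(ModuleEncoder()).parameters())))  # 2
```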
```diff
@@ -127,7 +127,7 @@ def build_encoder(self, framework: str) -> Encoder:
         )

     @OverrideToImplementCustomLogic
-    def build_pi_head(self, framework: str) -> Model:
+    def build_pi_head(self, framework: str):
```
The return type hint was removed. It's good practice to keep it for type checking and code clarity. Since this method builds a torch model when `framework="torch"`, you could add `-> "nn.Module"`. You'll need to add `from typing import TYPE_CHECKING` and `if TYPE_CHECKING: from torch import nn` to the file.

This also applies to `build_vf_head`.
```diff
-    def build_pi_head(self, framework: str):
+    def build_pi_head(self, framework: str) -> "nn.Module":
```
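In isolation, the `TYPE_CHECKING` pattern the reviewer suggests looks like this. `CatalogSketch` is a placeholder name, not the class in the diff; the point is that the quoted annotation is stored as a string and never evaluated at runtime, so the module imports cleanly even without torch installed.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; no torch import at runtime.
    from torch import nn


class CatalogSketch:
    """Hypothetical stand-in for the catalog class in the diff."""

    def build_pi_head(self, framework: str) -> "nn.Module":
        # The string annotation keeps type checkers informed while
        # avoiding a hard runtime dependency on torch in this module.
        raise NotImplementedError
```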
```diff
@@ -262,8 +88,13 @@ def build(self, framework: str):
     """

+    framework = "torch"
```
Hardcoding `framework = "torch"` in the `Encoder` base class makes it framework-specific, which contradicts its purpose as a framework-agnostic base class, as stated in the docstring. This could lead to confusion and issues if other frameworks are to be supported in the new API stack. It would be cleaner to define the `framework` attribute in the framework-specific encoder implementations (e.g., in `rllib/core/models/torch/encoder.py`), and have the base `Encoder` not define it, or set it to `None`.
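The reviewer's suggestion can be sketched as follows. The class names and bodies are illustrative only, not RLlib's actual code: the base class leaves `framework` unset (`None`), and each framework-specific subclass declares its own value.

```python
import abc
from typing import Optional


class Encoder(abc.ABC):
    """Framework-agnostic base class; deliberately picks no framework."""

    framework: Optional[str] = None  # set by framework-specific subclasses


class TorchEncoderSketch(Encoder):
    """Hypothetical torch-specific subclass declaring its framework."""

    framework = "torch"
```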
Cleanup/simplification: Remove (new api stack) Model and TorchModel base classes.
Why are these changes needed?
Related issue number
Checks
- I've signed off every commit (by using the `-s` flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I've added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.