[RLlib] Cleanup/simplification: Remove (new api stack) Model and TorchModel base classes. #55251


Open · wants to merge 5 commits into base: master

Conversation

@sven1977 (Contributor) commented Aug 5, 2025

Cleanup/simplification: Remove (new api stack) Model and TorchModel base classes.

Why are these changes needed?

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

sven1977 added 5 commits July 30, 2025 10:11
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
Signed-off-by: sven1977 <[email protected]>
…ve_new_api_stack_model_and_torchmodel_base_classes

Signed-off-by: sven1977 <[email protected]>

# Conflicts:
#	rllib/core/distribution/distribution.py
#	rllib/core/distribution/torch/torch_distribution.py
#	rllib/core/models/base.py
#	rllib/core/models/tests/test_catalog.py
#	rllib/core/models/torch/base.py
#	rllib/core/rl_module/rl_module.py
Signed-off-by: sven1977 <[email protected]>
@sven1977 sven1977 requested a review from a team as a code owner August 5, 2025 14:45
@sven1977 sven1977 enabled auto-merge (squash) August 5, 2025 14:45
@github-actions github-actions bot added the go add ONLY when ready to merge, run all tests label Aug 5, 2025
@gemini-code-assist gemini-code-assist bot left a comment

Summary of Changes

Hello @sven1977, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

I've focused on a significant cleanup and simplification within RLlib's new API stack. My primary goal was to remove the Model and TorchModel base classes, streamlining the model architecture. This change simplifies the inheritance hierarchy for various neural network components, allowing them to directly inherit from framework-specific modules (like PyTorch's nn.Module) or more specialized abstract base classes. This refactoring aims to reduce complexity and improve maintainability of the model definitions.

Highlights

  • Base Class Removal: I've completely removed the Model and TorchModel base classes from the new API stack, simplifying the core model hierarchy.
  • Encoder Refactoring: The Encoder class now directly inherits from abc.ABC, and its primary forward pass method has been standardized to forward.
  • PyTorch Model Simplification: All PyTorch-specific model components (encoders and heads) have been updated to inherit directly from Encoder or nn.Module, removing their dependency on the now-removed TorchModel.
  • RLModule Cleanup: Several deprecated input and output specification methods have been cleaned up from the RLModule class.
  • ModelCatalog Refinement: The ModelCatalog has been streamlined by removing verbose docstrings and marking it explicitly as part of the OldAPIStack.
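
The highlights above can be sketched as a minimal class hierarchy. This is an illustrative reconstruction of the described design, not RLlib's exact post-PR code; the class names `Encoder`/`TorchMLPEncoder` match the PR, but the config handling and layer sizes are invented for the example.

```python
# Hypothetical sketch of the simplified hierarchy described in the PR
# highlights: `Encoder` inherits from abc.ABC directly, and torch encoders
# inherit nn.Module directly instead of the removed TorchModel base class.
import abc

from torch import nn


class Encoder(abc.ABC):
    """Framework-agnostic encoder base (no longer a `Model` subclass)."""

    def __init__(self, config) -> None:
        self.config = config

    @abc.abstractmethod
    def forward(self, batch):
        """Standardized forward-pass entry point."""


class TorchMLPEncoder(Encoder, nn.Module):
    """Torch encoder inheriting nn.Module directly (no TorchModel)."""

    def __init__(self, config) -> None:
        # Run nn.Module setup first so submodule registration works.
        nn.Module.__init__(self)
        Encoder.__init__(self, config)
        self.net = nn.Linear(config["in"], config["out"])

    def forward(self, batch):
        return self.net(batch)
```

Because `TorchMLPEncoder` is an `nn.Module`, its `nn.Linear` parameters are registered and visible to optimizers of any parent module.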
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request is a nice simplification of the RLlib model hierarchy by removing the Model and TorchModel base classes. The changes are mostly cleanups and refactoring. I've found a critical issue where the Torch encoders are no longer nn.Modules, which will prevent them from being trained. I also have a couple of medium-severity suggestions to improve code clarity and maintainability regarding type hints and a hardcoded framework string in a base class.

Comment on lines +28 to +30

```diff
 class TorchMLPEncoder(Encoder):
     def __init__(self, config: MLPEncoderConfig) -> None:
-        TorchModel.__init__(self, config)
-        Encoder.__init__(self, config)
+        super().__init__(config)
```
critical

This encoder needs to be a torch.nn.Module to be correctly handled by PyTorch and have its parameters registered with the parent RLModule. Please inherit from nn.Module and call its constructor.

This issue applies to all other encoder classes in this file (TorchCNNEncoder, TorchGRUEncoder, TorchLSTMEncoder).

Suggested change

```diff
-class TorchMLPEncoder(Encoder):
-    def __init__(self, config: MLPEncoderConfig) -> None:
-        TorchModel.__init__(self, config)
-        Encoder.__init__(self, config)
-        super().__init__(config)
+class TorchMLPEncoder(Encoder, nn.Module):
+    def __init__(self, config: MLPEncoderConfig) -> None:
+        nn.Module.__init__(self)
+        super().__init__(config)
```
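
The reviewer's point can be reproduced in a few lines: parameters held by a plain Python object are invisible to a parent `nn.Module`, while an `nn.Module` attribute is registered automatically. The class names below (`PlainEncoder`, `ModuleEncoder`, `RLModuleLike`) are illustrative, not RLlib code.

```python
# Minimal repro: submodule parameters are only registered with a parent
# nn.Module if the attribute is itself an nn.Module.
import torch
from torch import nn


class PlainEncoder:
    """Not an nn.Module -- mirrors the buggy case flagged by the review."""

    def __init__(self):
        self.net = nn.Linear(4, 2)


class ModuleEncoder(nn.Module):
    """The suggested fix: inherit from nn.Module."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 2)


class RLModuleLike(nn.Module):
    """Stand-in for the parent RLModule that holds an encoder."""

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder


plain = RLModuleLike(PlainEncoder())
fixed = RLModuleLike(ModuleEncoder())

print(len(list(plain.parameters())))  # 0 -- nothing for an optimizer to train
print(len(list(fixed.parameters())))  # 2 -- the Linear's weight and bias
```

An optimizer built from `plain.parameters()` would silently train nothing, which is exactly why the review marks this as critical.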

```diff
@@ -127,7 +127,7 @@ def build_encoder(self, framework: str) -> Encoder:
         )

     @OverrideToImplementCustomLogic
-    def build_pi_head(self, framework: str) -> Model:
+    def build_pi_head(self, framework: str):
```
medium

The return type hint was removed. It's good practice to keep it for type checking and code clarity. Since this method builds a torch model when framework='torch', you could add -> "nn.Module". You'll need to add from typing import TYPE_CHECKING and if TYPE_CHECKING: from torch import nn to the file.

This also applies to build_vf_head.

Suggested change

```diff
-    def build_pi_head(self, framework: str):
+    def build_pi_head(self, framework: str) -> "nn.Module":
```
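
The `TYPE_CHECKING` pattern the reviewer suggests keeps the annotation for type checkers without a hard torch import at runtime. This is a hedged sketch: `CatalogSketch` is a hypothetical stand-in for the catalog class, and the layer sizes are invented.

```python
# Sketch of the suggested typing approach: the `nn` name is only imported
# for static type checkers; at runtime the annotation stays a string.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from torch import nn  # never executed at runtime


class CatalogSketch:
    def build_pi_head(self, framework: str) -> "nn.Module":
        # Deferred runtime import, so importing this module needs no torch.
        from torch import nn as torch_nn

        assert framework == "torch", "this sketch only covers torch"
        return torch_nn.Linear(8, 4)
```

Because the return annotation is a string, the module imports cleanly even in environments where torch is absent, yet tools like mypy still check callers against `nn.Module`.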

```diff
@@ -262,8 +88,13 @@ def build(self, framework: str):

     """

+    framework = "torch"
```
medium

Hardcoding framework = "torch" in the Encoder base class makes it framework-specific, which contradicts its purpose as a framework-agnostic base class as stated in the docstring. This could lead to confusion and issues if other frameworks are to be supported in the new API stack. It would be cleaner to define the framework attribute in the framework-specific encoder implementations (e.g., in rllib/core/models/torch/encoder.py), and have the base Encoder not define it, or set it to None.
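
The reviewer's alternative can be sketched as follows; the names `Encoder` and `TorchEncoderSketch` are illustrative, with the base class declaring `framework = None` and each concrete subclass declaring its own value.

```python
# Sketch of the suggested design: keep the base class framework-agnostic
# and declare `framework` only in framework-specific subclasses.
import abc
from typing import Optional


class Encoder(abc.ABC):
    framework: Optional[str] = None  # base class takes no stance

    @abc.abstractmethod
    def forward(self, batch):
        """Forward pass, implemented by concrete encoders."""


class TorchEncoderSketch(Encoder):
    framework = "torch"  # declared where it is actually true

    def forward(self, batch):
        return batch  # identity placeholder for the sketch
```

With this layout, code that inspects `encoder.framework` still works, and a future non-torch encoder simply sets its own class attribute instead of fighting a hardcoded default.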

@sven1977 sven1977 disabled auto-merge August 5, 2025 14:57