Merged
32 commits
c7e931d  Adopt `PreTrainedModelWrapper` for Hugging Face models (jon-tow, Jan 18, 2023)
195bf01  Adopt `PreTrainedModelWrapper` for Hugging Face models (jon-tow, Jan 18, 2023)
dfe2b51  Merge branch 'update-pre-commit' of https://github.com/jon-tow/trlx i… (jon-tow, Jan 28, 2023)
55252e1  Update documentation (jon-tow, Jan 28, 2023)
79870bb  Merge branch 'update-save-pretrained' of https://github.com/jon-tow/t… (jon-tow, Jan 28, 2023)
41432ff  Fix up broken merge (jon-tow, Jan 28, 2023)
e7338f0  Run pre-commit (jon-tow, Jan 28, 2023)
056efcb  Revert dtype change to `ILQLHead` (jon-tow, Jan 28, 2023)
f7f5189  Fix isort (jon-tow, Jan 28, 2023)
3af2b73  Format again... (jon-tow, Jan 28, 2023)
a9913aa  Revert newline deletion (jon-tow, Jan 28, 2023)
162b213  Revert unrelated changes and update docs (jon-tow, Jan 28, 2023)
d19f538  Update `README.md` saving example (jon-tow, Jan 28, 2023)
2e7fa31  Revert unrelated changes (jon-tow, Jan 28, 2023)
f24d793  Fix `dtype` access and hydra `return_dict` (jon-tow, Jan 28, 2023)
840bfb2  Force ref models into eval mode (jon-tow, Jan 28, 2023)
871b24e  Add unit tests for `AutoModel...`s (jon-tow, Jan 29, 2023)
e120c13  Commit work on fixing `T5Branch` (jon-tow, Jan 30, 2023)
c0d4792  Merge branch 'main' of https://github.com/CarperAI/trlx into update-s… (jon-tow, Feb 10, 2023)
c8f7127  Merge branch 'main' of https://github.com/CarperAI/trlx into update-s… (jon-tow, Feb 11, 2023)
8811899  refactor(models): move models out of trainer dir (jon-tow, Feb 12, 2023)
7298764  refactor(sft): remove `save_pretrained` override (jon-tow, Feb 12, 2023)
2cf6966  Run pre-commit (jon-tow, Feb 12, 2023)
3d7c99b  Ignore line length for links (jon-tow, Feb 12, 2023)
3809322  Merge branch 'main' of https://github.com/CarperAI/trlx into update-s… (jon-tow, Feb 13, 2023)
7761b3d  Revert naming to `base_model` (jon-tow, Feb 13, 2023)
68393b0  Rename hydra models for clarity (jon-tow, Feb 13, 2023)
db9bb93  Add `from_config` support (jon-tow, Feb 14, 2023)
7a6b160  cleanup docstrings (jon-tow, Feb 14, 2023)
91eb155  Revert T5 branch changes (jon-tow, Feb 16, 2023)
32800b2  Merge branch 'main' of https://github.com/CarperAI/trlx into update-s… (jon-tow, Feb 21, 2023)
e4aff47  Remove variadic params (jon-tow, Feb 22, 2023)
6 changes: 2 additions & 4 deletions README.md
@@ -63,8 +63,6 @@ trainer.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'),
 trainer.save_pretrained('/path/to/output/folder/')
 ```
 
-🩹 Warning: Only the `AcceleratePPOTrainer` can write HuggingFace transformers to disk with `save_pretrained` at the moment, as ILQL trainers require inference behavior currently unsupported by available `transformers` architectures.
-
 #### Use 🤗 Accelerate to launch distributed training
 
 ```bash
@@ -74,13 +72,13 @@ accelerate launch examples/simulacra.py
 
 #### Use NeMo-Megatron to launch distributed training
 
-Follow the setup instructions in the [NeMo README](./trlx/trainer/nemo).
+Follow the setup instructions in the [NeMo README](./trlx/models/).
 
 ```bash
 python examples/nemo_ilql_sentiments.py
 ```
 
-For more usage see the [NeMo README](./trlx/trainer/nemo)
+For more usage see the [NeMo README](./trlx/models)
 
 #### Use Ray Tune to launch hyperparameter sweep
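Since this PR routes `trainer.save_pretrained` through the base model's own `save_pretrained`, the folder it writes should be loadable as a plain `transformers` checkpoint. A minimal sketch (the path is the placeholder used in the README above):

```python
import transformers

# "/path/to/output/folder/" is the directory passed to `trainer.save_pretrained` above.
model = transformers.AutoModelForCausalLM.from_pretrained("/path/to/output/folder/")
```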
15 changes: 0 additions & 15 deletions docs/source/trainer.rst
@@ -19,22 +19,7 @@ Note that new trainers must be registered with ``trlx.trainer.register_trainer``
 .. autoclass:: trlx.trainer.accelerate_ppo_trainer.AcceleratePPOTrainer
    :members:
 
-.. autoclass:: trlx.trainer.nn.ppo_models.CausalLMWithValueHead
-   :members:
-
-.. autoclass:: trlx.trainer.nn.ppo_models.GPTModelBranch
-   :members:
-
-.. autoclass:: trlx.trainer.nn.ppo_models.OPTModelBranch
-   :members:
-
-.. autoclass:: trlx.trainer.nn.ppo_models.CausalLMHydraWithValueHead
-   :members:
-
 **ILQL**
 
 .. autoclass:: trlx.trainer.accelerate_ilql_trainer.AccelerateILQLTrainer
    :members:
 
-.. autoclass:: trlx.trainer.nn.ilql_models.CausalLMWithValueHeads
-   :members:
340 changes: 340 additions & 0 deletions tests/test_models.py

Large diffs are not rendered by default.
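Since the 340-line test file is not rendered, here is a sketch of the kind of round-trip test such a change typically gets. The import path and class name are assumptions based on the "Add unit tests for `AutoModel...`s" commit message, not code taken from the PR:

```python
import tempfile
import unittest

import torch

# NOTE: Illustrative only; this class name and module path are assumed, not the PR's actual test code.
from trlx.models.modeling_ppo import AutoModelForCausalLMWithValueHead


class TestSaveLoadRoundTrip(unittest.TestCase):
    def test_save_and_reload(self):
        model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            reloaded = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir)
        # Every parameter (including any extra heads) should survive the round trip.
        state, restate = model.state_dict(), reloaded.state_dict()
        self.assertEqual(state.keys(), restate.keys())
        for name in state:
            self.assertTrue(torch.equal(state[name], restate[name]))
```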

84 changes: 0 additions & 84 deletions tests/test_ppo.py

This file was deleted.

28 changes: 25 additions & 3 deletions tests/test_utils.py
@@ -1,3 +1,5 @@
+import unittest
+
 import accelerate
 import pytest
 import torch
@@ -68,9 +70,9 @@ def test_hf_attr_getters(model_name: str):
     arch = transformers.AutoModelForCausalLM.from_config(config)
 
     arch_getters = [
-        modeling_utils.hf_get_causal_base_model,
-        modeling_utils.hf_get_causal_final_norm,
-        modeling_utils.hf_get_causal_hidden_layers,
+        modeling_utils.hf_get_decoder,
+        modeling_utils.hf_get_decoder_final_norm,
+        modeling_utils.hf_get_decoder_blocks,
         modeling_utils.hf_get_lm_head,
     ]
     for get in arch_getters:
@@ -125,3 +127,23 @@ def test_parse_delta_kwargs(model_name):
     )
     for kwarg_mod in delta_kwargs["modified_modules"]:
         assert kwarg_mod.endswith("a") or kwarg_mod.endswith("b"), "Parsed modified module should contain ['a', 'b']"
+
+
+class TestStatistics(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.m = modeling_utils.RunningMoments()
+        cls.a1 = torch.arange(100, dtype=float)
+        cls.a2 = torch.ones(100, dtype=float)
+        cls.a3 = torch.exp(torch.arange(10, dtype=float))
+        cls.a4 = torch.tensor([-10, -1, 0, 1, 10], dtype=float)
+
+    def test_running_moments(self):
+        assert torch.isclose(self.m.update(self.a1)[1], self.a1.std(unbiased=True), atol=1e-6)
+        assert torch.isclose(self.m.update(self.a2)[1], self.a2.std(unbiased=True), atol=1e-6)
+        assert torch.isclose(self.m.update(self.a3)[1], self.a3.std(unbiased=True), atol=1e-6)
+        assert torch.isclose(self.m.update(self.a4)[1], self.a4.std(unbiased=True), atol=1e-6)
+
+        a = torch.hstack((self.a1, self.a2, self.a3, self.a4))
+        assert torch.isclose(self.m.mean, a.mean(), atol=1e-6)
+        assert torch.isclose(self.m.std, a.std(unbiased=True), atol=1e-6)
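These tests imply that `RunningMoments.update` returns the current batch's own (mean, std) while `m.mean`/`m.std` track the aggregate over all samples seen so far. A minimal sketch of such a tracker, assuming the standard parallel-variance combination (Chan et al.) rather than trlx's exact implementation:

```python
import torch


class RunningMomentsSketch:
    """Illustrative running mean/std tracker (assumed behavior, not trlx's exact code)."""

    def __init__(self):
        self.mean = 0.0
        self.std = 1.0
        self.var = 1.0
        self.count = 1e-24  # avoids division by zero before the first update

    def update(self, xs: torch.Tensor):
        """Fold a batch into the aggregate moments; return the batch's own (mean, std)."""
        n = xs.numel()
        batch_var, batch_mean = torch.var_mean(xs, unbiased=False)

        delta = batch_mean - self.mean
        total = self.count + n
        # Chan et al.'s parallel update: combine the sums of squared deviations.
        tot_sum = self.var * self.count + batch_var * n + delta**2 * self.count * n / total

        self.mean += delta * n / total
        self.var = tot_sum / total
        self.std = (self.var * total / (total - 1)).sqrt()  # Bessel-corrected aggregate std
        self.count = total
        # Bessel-correct the batch variance to match `xs.std(unbiased=True)`.
        return batch_mean, (batch_var * n / (n - 1)).sqrt()
```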
File renamed without changes.
File renamed without changes.
223 changes: 223 additions & 0 deletions trlx/models/modeling_base.py
@@ -0,0 +1,223 @@
# Copyright 2022 CarperAI & The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# NOTE: This file contains a modified version of the `PreTrainedModelWrapper` class from
# HuggingFace's `trl` library. The original source code can be found here:
# https://github.com/lvwerra/trl/blob/78c13226bf8ea1ccd9b1c091f03a938098521f6c/trl/models/modeling_base.py

import inspect
import json
import os
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn
import transformers
from huggingface_hub import hf_hub_download


class PreTrainedModelWrapper(nn.Module, transformers.utils.PushToHubMixin):
    """A wrapper around `transformers.PreTrainedModel`.

    Reference: @younesbelkada's `PreTrainedModelWrapper`
    https://github.com/lvwerra/trl/blob/4f5c16fafde42d9aca971952bcdcc1f5a0a68cf0/trl/models/modeling_base.py#L2

    Attributes:
        _auto_model_parent_class (transformers.AutoModel): The `transformers.AutoModel`
            type to base the wrapping behavior off of, e.g. `transformers.AutoModelForCausalLM`.
        _supported_modules (List[str]): A list of attribute names for modules of
            the underlying architecture model. This is used, for example, to save
            and load any additional modules by manipulating the state dict.
        _supported_args (List[str]): A list of arguments specific to the underlying
            architecture to separate from arguments that are supported by the
            parent `AutoModel` class. Any arguments that are not supported by the
            underlying model will be passed to the parent `AutoModel` class.
    """

    _auto_model_parent_class: transformers.AutoModel = None
    _supported_modules: List[str] = None
    # TODO (jon-tow): Supported args should come from a `PretrainedConfig` of the
    # specific underlying type similar to how config instances can be used to instantiate
    # `transformers.PreTrainedModel`s.
    _supported_args: List[str] = None

    def __init__(self, base_model: Optional[transformers.PreTrainedModel] = None, **kwargs):
        super().__init__()
        self.base_model = base_model
        # cache `forward` args for general use (avoids incompatible args across architectures)
        self.forward_kwargs = inspect.getfullargspec(self.base_model.forward).args

    @classmethod
    def _split_kwargs(cls, kwargs: Dict[str, Any]):
        """Separates the kwargs from the supported arguments within `supported_args`
        and those that are not
        """
        supported_kwargs = {}
        unsupported_kwargs = {}
        for key, value in kwargs.items():
            if key in cls._supported_args:
                supported_kwargs[key] = value
            else:
                unsupported_kwargs[key] = value
        return supported_kwargs, unsupported_kwargs

    @classmethod
    def from_config(cls, config: transformers.PretrainedConfig, **kwargs):
        """Instantiate the pretrained pytorch model from a configuration.

        Args:
            config (transformers.PretrainedConfig): The configuration to use to
                instantiate the base model.

        NOTE: Loading a model from its configuration file does **not** load the
        model weights. It only affects the model's configuration. Use
        `~transformers.AutoModel.from_pretrained` to load the model weights.
        """
        if kwargs is not None:
            wrapped_model_kwargs, from_config_kwargs = cls._split_kwargs(kwargs)
        else:
            from_config_kwargs = {}
            wrapped_model_kwargs = {}
        base_model = cls._auto_model_parent_class.from_config(config, **from_config_kwargs)
        model = cls(base_model, **wrapped_model_kwargs)
        return model

    @classmethod
    def from_pretrained(  # noqa: max-complexity
        cls,
        pretrained_model_name_or_path: Union[str, transformers.PreTrainedModel],
        *model_args,
        **kwargs,
    ):
        """Instantiate a pretrained pytorch model from a pretrained model configuration.
        This method is a wrapper around `transformers.PreTrainedModel.from_pretrained`.
        Please refer to the documentation of `transformers.PreTrainedModel.from_pretrained`
        for more information.

        Args:
            pretrained_model_name_or_path (str or `transformers.PreTrainedModel`):
                The identifier of the pretrained model to load or the pretrained model itself.
            *model_args (sequence of positional arguments, *optional*):
                All remaining positional arguments will be passed to the `_auto_model_parent_class`.
            **kwargs (dict, *optional*):
                Dictionary of keyword arguments to pass to both the underlying `_auto_model_parent_class`
                call (e.g. `transformers.AutoModelForCausalLM.from_pretrained`) and the specific
                instance of the wrapped model.

        NOTE: You must pass in arguments specific to the wrapped model as keyword arguments.
        """
        if kwargs is not None:
            wrapped_model_kwargs, from_pretrained_kwargs = cls._split_kwargs(kwargs)
        else:
            from_pretrained_kwargs = {}
            wrapped_model_kwargs = {}

        if isinstance(pretrained_model_name_or_path, str):
            # Load the base model using the `transformers` AutoClass (e.g. `AutoModelForCausalLM`)
            base_model = cls._auto_model_parent_class.from_pretrained(
                pretrained_model_name_or_path, *model_args, **from_pretrained_kwargs
            )
        elif isinstance(pretrained_model_name_or_path, transformers.PreTrainedModel):
            base_model = pretrained_model_name_or_path
        else:
            raise ValueError(
                f"Invalid type for `base_model_name_or_path`: {type(pretrained_model_name_or_path)}. "
                "Expected `str` or `transformers.PreTrainedModel`."
            )

        model = cls(base_model, **wrapped_model_kwargs)

        if isinstance(pretrained_model_name_or_path, str):
            filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
            sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json")
            is_sharded = False

            if not os.path.exists(filename):
                try:
                    filename = hf_hub_download(pretrained_model_name_or_path, "pytorch_model.bin")
                # Sharded
                except Exception:
                    if os.path.exists(sharded_index_filename):
                        index_file_name = sharded_index_filename
                    else:
                        index_file_name = hf_hub_download(
                            pretrained_model_name_or_path,
                            "pytorch_model.bin.index.json",
                        )
                    with open(index_file_name, "r") as f:
                        index = json.load(f)
                    # Collect files containing weights from supported modules
                    files_to_download = set()
                    for k, v in index["weight_map"].items():
                        if any([module in k for module in cls._supported_modules]):
                            files_to_download.add(v)
                    is_sharded = True

            if is_sharded:
                # Merge each shard into a single state dict
                # TODO: Optimize this to avoid wasting RAM
                state_dict = {}
                for shard_file in files_to_download:
                    filename = os.path.join(pretrained_model_name_or_path, shard_file)
                    # Download if the shard file doesn't exist locally
                    if not os.path.exists(filename):
                        filename = hf_hub_download(pretrained_model_name_or_path, shard_file)
                    state_dict.update(torch.load(filename, map_location="cpu"))
            else:
                state_dict = torch.load(filename, map_location="cpu")
        else:
            state_dict = pretrained_model_name_or_path.state_dict()

        model.post_init(state_dict=state_dict)
        return model

    def save_pretrained(self, *args, **kwargs):
        """Save the pretrained model to a directory. This method is a wrapper
        around `transformers.PreTrainedModel.save_pretrained`. Please refer to
        the documentation of `transformers.PreTrainedModel.save_pretrained` for
        more information.

        Args:
            *args (`list`, *optional*):
                Positional arguments passed along to the underlying model's
                `save_pretrained` method.
            **kwargs (`dict`, *optional*):
                Keyword arguments passed along to the underlying model's
                `save_pretrained` method.
        """
        state_dict = kwargs.pop("state_dict", None)
        if state_dict is None:
            state_dict = self.state_dict()
        kwargs["state_dict"] = state_dict

        return self.base_model.save_pretrained(*args, **kwargs)

    def state_dict(self, *args, **kwargs):
        """Return the state_dict of the pretrained model."""
        raise NotImplementedError

    def post_init(self, *args, **kwargs):
        """Post initialization method. This method is called after the model is
        instantiated and loaded from a checkpoint. It can be used to perform
        additional operations such as loading the state_dict.
        """
        raise NotImplementedError

    def get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]:
        """Filter out arguments not supported by the specific instance of
        `base_model.transformer.forward`
        """
        # FIXME: This is a hack to get around the fact that the `transformers`
        # architectures we use don't have a consistent API for `forward` parameters.
        return {k: v for k, v in kwargs.items() if k in self.forward_kwargs}
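To make the contract above concrete, here is a minimal sketch of how a subclass might plug into this wrapper: declare the parent Auto class and the extra modules, then implement `state_dict` and `post_init` so saving and loading cover the extra weights. The class name and `v_head` module are hypothetical illustrations, not code from this PR:

```python
import torch.nn as nn
import transformers

from trlx.models.modeling_base import PreTrainedModelWrapper


class CausalLMWithScalarHead(PreTrainedModelWrapper):
    """Hypothetical subclass: a causal LM plus one extra module (`v_head`)."""

    _auto_model_parent_class = transformers.AutoModelForCausalLM
    _supported_modules = ["v_head"]
    _supported_args = []

    def __init__(self, base_model=None, **kwargs):
        super().__init__(base_model, **kwargs)
        self.v_head = nn.Linear(self.base_model.config.hidden_size, 1, bias=False)

    def state_dict(self, *args, **kwargs):
        # Fold the extra module's weights into the base model's state dict so
        # `save_pretrained` writes a single checkpoint containing both.
        state_dict = self.base_model.state_dict(*args, **kwargs)
        for k, v in self.v_head.state_dict().items():
            state_dict[f"v_head.{k}"] = v
        return state_dict

    def post_init(self, state_dict):
        # Called by `from_pretrained` after loading: restore any `v_head.*`
        # weights if the checkpoint has them (a vanilla LM checkpoint won't).
        head_state = {k[len("v_head.") :]: v for k, v in state_dict.items() if k.startswith("v_head.")}
        if head_state:
            self.v_head.load_state_dict(head_state)


# Usage (the `v_head` weights are freshly initialized when loading a plain LM):
# model = CausalLMWithScalarHead.from_pretrained("gpt2")
# model.save_pretrained("ckpt/")  # round-trips both the LM and `v_head`
```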