Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,22 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
rev: v4.1.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-case-conflict
- id: check-json
- id: check-symlinks
- id: check-yaml
- id: destroyed-symlinks
- id: end-of-file-fixer
exclude: docs/CNAME
- id: fix-byte-order-marker
- id: fix-encoding-pragma
args: [--remove]
- id: mixed-line-ending
args: [--fix=lf]
- id: requirements-txt-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 22.10.0
hooks:
Expand All @@ -17,3 +28,7 @@ repos:
hooks:
- id: isort
name: isort (python)
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
hooks:
- id: flake8
70 changes: 35 additions & 35 deletions docs/make.bat
Original file line number Diff line number Diff line change
@@ -1,35 +1,35 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
if "%1" == "" goto help
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

if "%1" == "" goto help

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
6 changes: 3 additions & 3 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
sphinx==4.0.0
sphinx_rtd_theme
accelerate==0.12.0
datasets==2.4.0
deepspeed==0.7.3
einops==0.4.1
numpy==1.23.2
sphinx==4.0.0
sphinx_rtd_theme
torchtyping
tqdm==4.64.0
transformers==4.21.2
wandb==0.13.2
torchtyping
13 changes: 3 additions & 10 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import os
import sys

import sphinx_rtd_theme

sys.path.insert(0, os.path.abspath('../..'))


Expand All @@ -28,16 +30,7 @@
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.

import sphinx_rtd_theme

extensions = [
'sphinx_rtd_theme',
'sphinx.ext.todo',
'sphinx.ext.viewcode',
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.autosectionlabel'
]
extensions = ['sphinx_rtd_theme', 'sphinx.ext.todo', 'sphinx.ext.viewcode', 'sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.autosectionlabel']

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
Expand Down
1 change: 1 addition & 0 deletions examples/experiments/grounded_program_synthesis/lang.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# flake8: noqa
import copy
import json
import random
Expand Down
5 changes: 3 additions & 2 deletions examples/randomwalks/randomwalks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def randexclude(rng: np.random.RandomState, n: int, exclude: int) -> int:
return x


def generate_random_walks(
def generate_random_walks( # noqa: max-complexity
n_nodes=21, max_length=10, n_walks=1000, p_edge=0.1, seed=1002, gpt2_tokenizer=False
):
rng = np.random.RandomState(seed)
Expand All @@ -30,6 +30,8 @@ def generate_random_walks(

goal = 0
sample_walks = []
delimiter = "|" if gpt2_tokenizer else ""

for _ in range(n_walks):
node = randexclude(rng, n_nodes, goal)
walk = [node]
Expand All @@ -43,7 +45,6 @@ def generate_random_walks(
# code each node by a letter
# for bpe tokenizer join them over | for a guaranteed split
walk = [node_to_char[ix] for ix in walk]
delimiter = "|" if gpt2_tokenizer else ""

sample_walks.append(delimiter.join(walk))

Expand Down
25 changes: 25 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,28 @@ dev =
exclude =
docs*
tests*

[flake8]
max-complexity = 10
max-line-length = 127
# flake8 error codes: https://flake8.pycqa.org/en/latest/user/error-codes.html
# pycodestyle codes: https://pycodestyle.pycqa.org/en/latest/intro.html#error-codes
# E203 # whitespace before ‘,’, ‘;’, or ‘:’
# E741 # do not use variables named ‘l’, ‘O’, or ‘I’
# F401 # module imported but unused
# F821 # undefined name name
# W503 # line break before binary operator
# W605 # invalid escape sequence ‘x’
ignore =
E203
E741
F821
W503
W605
per-file-ignores = __init__.py:F401,loading.py:F401
exclude =
.git
__pycache__
docs/source/conf.py
build
dist
2 changes: 1 addition & 1 deletion tests/test_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_lm_heads(self):
def test_frozen_head(self):
# Ensure that all parameters of the `hydra_model.frozen_head` are actually frozen
for parameter in TestHydraHead.hydra_model.frozen_head.parameters():
self.assertTrue(parameter.requires_grad == False)
self.assertTrue(parameter.requires_grad is False)

def test_forward(self):
with torch.no_grad():
Expand Down
5 changes: 3 additions & 2 deletions trlx/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import random
from dataclasses import dataclass
from typing import Any, Callable, Iterable, List
from typing import Any, Iterable

from torchtyping import TensorType

from . import configs


@dataclass
class GeneralElement:
Expand Down
15 changes: 10 additions & 5 deletions trlx/data/configs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional, Set
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Set, Tuple

import yaml

Expand Down Expand Up @@ -56,7 +56,7 @@ class OptimizerConfig:
"""

name: str
kwargs: Dict[str, Any] = None
kwargs: Dict[str, Any] = field(default_factory=dict)

@classmethod
def from_dict(cls, config: Dict[str, Any]):
Expand All @@ -76,7 +76,7 @@ class SchedulerConfig:
"""

name: str
kwargs: Dict[str, Any] = None
kwargs: Dict[str, Any] = field(default_factory=dict)

@classmethod
def from_dict(cls, config: Dict[str, Any]):
Expand All @@ -100,6 +100,9 @@ class TrainConfig:
:param batch_size: Batch size for training
:type batch_size: int

:param trackers: Tuple of trackers to use for logging. Default: ("wandb",)
:type trackers: Tuple[str]

:param checkpoint_interval: Save model every checkpoint_interval steps
:type checkpoint_interval: int

Expand All @@ -123,7 +126,8 @@ class TrainConfig:
:param checkpoint_dir: Directory to save checkpoints
:type checkpoint_dir: str

:param rollout_logging_dir: Directory to store generated rollouts for use in Algorithm Distillation. Only used by AcceleratePPOTrainer.
:param rollout_logging_dir: Directory to store generated rollouts for use in Algorithm Distillation.
Only used by AcceleratePPOTrainer.
:type rollout_logging_dir: Optional[str]

:param seed: Random seed
Expand All @@ -148,6 +152,7 @@ class TrainConfig:
checkpoint_dir: str = "ckpts"
rollout_logging_dir: Optional[str] = None

trackers: Tuple[str] = ("wandb",)
seed: int = 1000

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion trlx/data/method_configs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys
from dataclasses import dataclass
from typing import Any, Callable, Dict, List
from typing import Any, Dict

# specifies a dictionary of method configs
_METHODS: Dict[str, Any] = {} # registry
Expand Down
18 changes: 13 additions & 5 deletions trlx/data/ppo_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,27 @@
@dataclass
class PPORLElement:
"""
:param query_tensor: The query tensor i.e. the prompt tokens. Should be a long tensor.
:param query_tensor: The query tensor i.e. the prompt tokens.
Should be a long tensor.
:type query_tensor: torch.Tensor

:param response_tensor: The response tensor i.e. the output tokens. Should be a long tensor.
:param response_tensor: The response tensor i.e. the output tokens.
Should be a long tensor.
:type response_tensor: torch.Tensor

:param logprobs: The log probabilities over all tokens in the vocabulary for each token generated from the policy network (i.e. the autoregressive model). Should be a float tensor of same size as tokens, with a dimension across the vocabulary.
:param logprobs: The log probabilities over all tokens in the vocabulary for
each token generated from the policy network
(i.e. the autoregressive model).
Should be a float tensor of same size as tokens,
with a dimension across the vocabulary.
:type logprobs: torch.Tensor

:param values: The values for each token generated from the value network or value head. Should be a float tensor of same size as tokens.
:param values: The values for each token generated from the value network or value head.
Should be a float tensor of same size as tokens.
:type values: torch.Tensor

:param rewards: The rewards for each token outputted in response. Should be a float tensor of same size as tokens.
:param rewards: The rewards for each token outputted in response.
Should be a float tensor of same size as tokens.
:type rewards: torch.Tensor
"""

Expand Down
12 changes: 7 additions & 5 deletions trlx/orchestrator/ppo_orchestrator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from time import time
from typing import Callable
from typing import Callable, Optional

import ray
import torch
Expand All @@ -16,15 +16,16 @@
@register_orchestrator
class PPOOrchestrator(Orchestrator):
"""
Orchestrator that prepares data for PPO training: transforms samples from `pipeline` into `PPOBatch` and pushes them into trainer's `store`
Orchestrator prepares data for PPO training.
Transforms samples from `pipeline` into `PPOBatch` and pushes them into trainer's `store`
"""

def __init__(
self,
trainer: BaseRLTrainer,
pipeline: BasePipeline,
reward_fn: Callable,
metric_fn: Callable = None,
metric_fn: Optional[Callable] = None,
chunk_size: int = 512,
):
self.pipeline = pipeline
Expand Down Expand Up @@ -54,9 +55,10 @@ def score(self, samples):
"""
return self.trainer.reward_fn(samples)

def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noqa:
"""
Takes `num_rollouts` prompts from `pipeline`, samples model, computes KL againts a reference model appends PPOElements to trainer's `store`
Takes `num_rollouts` prompts from `pipeline`, samples the model and computes the
KL against a reference model. It then appends PPOElements to trainer's `store`
"""
ppo_rl_elements = []
stats = {}
Expand Down
8 changes: 5 additions & 3 deletions trlx/pipeline/ppo_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import json
import os
import time
from typing import Iterable, Optional
from typing import Iterable

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchtyping import TensorType

from trlx.data.ppo_types import PPORLBatch, PPORLElement
from trlx.pipeline import BaseRolloutStore
Expand All @@ -32,7 +31,10 @@ def export_history(self, location: str):
assert os.path.exists(location)

fpath = os.path.join(location, f"epoch-{str(time.time())}.json")
exp_to_dict = lambda exp: {k: v.cpu().tolist() for k, v in exp.__dict__.items()}

def exp_to_dict(exp):
    """Convert one stored experience element into a JSON-serializable dict.

    Every attribute found in ``exp.__dict__`` is passed through
    ``.cpu().tolist()`` and stored under its attribute name.

    :param exp: Object whose attributes each provide ``.cpu().tolist()``.
    :return: Dict mapping attribute names to the resulting lists.
    """
    # The lambda this replaced returned the dict implicitly; as a ``def`` it
    # needs an explicit ``return``, otherwise every exported entry is ``None``.
    return {k: v.cpu().tolist() for k, v in exp.__dict__.items()}

data = [exp_to_dict(exp) for exp in self.history]
with open(fpath, "w") as f:
f.write(json.dumps(data, indent=2))
Expand Down
2 changes: 1 addition & 1 deletion trlx/ray_tune/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ray import tune


def get_param_space(config: dict):
def get_param_space(config: dict): # noqa: C901
"""Get the param space from the config file."""

def get_strategy(value):
Expand Down
Loading