
Commit b7d3c3f: Some readme improvements (#44)
1 parent 5b00cd9

12 files changed: 69 additions, 32 deletions

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -13,5 +13,5 @@ repos:
 - repo: https://github.com/psf/black
   rev: 22.10.0
   hooks:
-  - id: black
+  - id: black
     files: ^(trlx|examples|unittests|setup.py)/

Makefile

Lines changed: 2 additions & 2 deletions
@@ -6,9 +6,9 @@ check_dirs := trlx/

 style:
 	black $(check_dirs)
-	isort $(check_dirs)
+	isort $(check_dirs) # see pyproject.toml for isort config
 	flake8 $(check_dirs) --ignore=$(IGNORE_PEP)

 quality:
-	isort --check-only $(check_dirs)
+	isort --check-only $(check_dirs) # see pyproject.toml for isort config
 	flake8 $(check_dirs)

README.md

Lines changed: 35 additions & 17 deletions
@@ -29,6 +29,7 @@ Adding a task for RLHF training depends on the desired training method and pre-e
 git clone https://github.com/CarperAI/trlx.git
 cd trlx
 pip install -e ".[dev]"
+pre-commit install # see .pre-commit-config.yaml
 ```

 ## Example: How to add a task
@@ -46,35 +47,52 @@ accelerate config
 ```python
 @register_datapipeline
 class PPOPipeline(BasePipeline):
-    def __init__(self, tokenizer, config, prompt_dataset_path = None):
+    def __init__(self, tokenizer, config, prompt_dataset_path=None):
         super().__init__()

-        ds = load_dataset('imdb', split='test')
-        ds = ds.rename_columns({'text': 'review', 'label': 'sentiment'})
-        ds = ds.filter(lambda x: len(x["review"])<500, batched=False)
-
-        self.tokens = [tokenizer(text,
-                                 truncation = True,
-                                 padding = 'max_length',
-                                 max_length = config.train.input_size,
-                                 return_tensors = "pt"
-                                 )['input_ids'].long().flatten() for text in ds['review']]
+        ds = load_dataset("imdb", split="test")
+        ds = ds.rename_columns({"text": "review", "label": "sentiment"})
+        ds = ds.filter(lambda x: len(x["review"]) < 500, batched=False)
+
+        self.tokens = [
+            tokenizer(
+                text,
+                truncation=True,
+                padding="max_length",
+                max_length=config.train.input_size,
+                return_tensors="pt",
+            )["input_ids"]
+            .long()
+            .flatten()
+            for text in ds["review"]
+        ]
         self.text = [tokenizer.decode(tokens.tolist()) for tokens in self.tokens]

-    def __getitem__(self, index : int) -> PromptElement:
+    def __getitem__(self, index: int) -> PromptElement:
         return PromptElement(self.text[index], self.tokens[index])

     def __len__(self) -> int:
         return len(self.text)

-    def create_loader(self, batch_size : int, shuffle : bool, prep_fn : Callable = None, num_workers : int = 0) -> DataLoader:
-        #TODO(dahoas): Decide how to support varying sizes of prompts without having to tokenize on fly
-        def collate_fn(elems : Iterable[PromptElement]) -> PromptElement:
+    def create_loader(
+        self,
+        batch_size: int,
+        shuffle: bool,
+        prep_fn: Callable = None,
+        num_workers: int = 0,
+    ) -> DataLoader:
+        # TODO(dahoas): Decide how to support varying sizes of prompts without having to tokenize on fly
+        def collate_fn(elems: Iterable[PromptElement]) -> PromptElement:
             return PromptBatch(
-                [elem.text for elem in elems], torch.stack([elem.tokens for elem in elems]) # Assumes token tensors all same size
+                [elem.text for elem in elems],
+                torch.stack(
+                    [elem.tokens for elem in elems]
+                ),  # Assumes token tensors all same size
             )

-        return DataLoader(self, batch_size, shuffle, collate_fn = collate_fn, num_workers = num_workers)
+        return DataLoader(
+            self, batch_size, shuffle, collate_fn=collate_fn, num_workers=num_workers
+        )
 ```

 ### Launch training
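
For context, a minimal usage sketch of the pipeline above. It assumes a GPT-2 tokenizer and a `TRLConfig` loaded from a hypothetical `configs/ppo_config.yml` path; the `batch.text` / `batch.tokens` field names follow the `PromptBatch` construction in `collate_fn`:

```python
from transformers import AutoTokenizer

from trlx.data.configs import TRLConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

# Hypothetical config path; point this at whatever config your setup uses
config = TRLConfig.load_yaml("configs/ppo_config.yml")

pipeline = PPOPipeline(tokenizer, config)
loader = pipeline.create_loader(batch_size=8, shuffle=True)

for batch in loader:
    # batch.text is a list of decoded prompts;
    # batch.tokens is a [batch_size, input_size] LongTensor
    break
```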

pyproject.toml

Lines changed: 3 additions & 0 deletions
@@ -1,3 +1,6 @@
 [build-system]
 requires = ["setuptools"]
 build-backend = "setuptools.build_meta"
+
+[tool.isort]
+multi_line_output = 3
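
For reference, `multi_line_output = 3` selects isort's "vertical hanging indent" style, which is what produces the parenthesized one-import-per-line blocks seen throughout this commit. A minimal before/after sketch:

```python
# Before: a single import line that exceeds the length limit
# from trlx.pipeline.offline_pipeline import OfflinePipeline, OfflineRolloutStorage

# After isort with multi_line_output = 3 (vertical hanging indent)
from trlx.pipeline.offline_pipeline import (
    OfflinePipeline,
    OfflineRolloutStorage
)
```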

setup.cfg

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ install_requires =
 [options.extras_require]
 dev =
     black
+    isort
     flake8
     pre-commit
     pytest

trlx/model/accelerate_base_model.py

Lines changed: 4 additions & 1 deletion
@@ -11,7 +11,10 @@
 from transformers import AutoConfig, AutoTokenizer

 from trlx.data import BatchElement, RLElement
-from trlx.data.accelerate_base_datatypes import AccelerateRLBatchElement, PromptBatch
+from trlx.data.accelerate_base_datatypes import (
+    AccelerateRLBatchElement,
+    PromptBatch
+)
 from trlx.data.configs import TRLConfig
 from trlx.model import BaseRLModel, register_model
 from trlx.pipeline.accelerate_base_pipeline import AccelerateRolloutStorage

trlx/model/accelerate_ilql_model.py

Lines changed: 4 additions & 1 deletion
@@ -11,7 +11,10 @@

 from trlx.model import BaseRLModel, register_model
 from trlx.model.nn.ilql_models import CausalLMWithValueHeads
-from trlx.pipeline.offline_pipeline import OfflinePipeline, OfflineRolloutStorage
+from trlx.pipeline.offline_pipeline import (
+    OfflinePipeline,
+    OfflineRolloutStorage
+)
 from trlx.utils import Clock, rampup_decay, safe_mkdir, topk_mask

 WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))

trlx/model/accelerate_ppo_model.py

Lines changed: 6 additions & 3 deletions
@@ -2,21 +2,24 @@
 from abc import abstractmethod
 from typing import Dict, Iterable, Tuple

+import numpy as np
 import torch
 import torch.nn.functional as F
+import wandb
 from accelerate import Accelerator
 from torch.utils.data import DataLoader
 from torchtyping import TensorType
 from tqdm import tqdm
 from transformers import AutoConfig, AutoTokenizer
-import numpy as np

-import wandb
 from trlx.data.accelerate_base_datatypes import PromptBatch
 from trlx.data.configs import TRLConfig
 from trlx.model import BaseRLModel, register_model
 from trlx.model.accelerate_base_model import AccelerateRLModel
-from trlx.model.nn.ppo_models import GPTHeadWithValueModel, GPTHydraHeadWithValueModel
+from trlx.model.nn.ppo_models import (
+    GPTHeadWithValueModel,
+    GPTHydraHeadWithValueModel
+)
 from trlx.pipeline.ppo_pipeline import PPORolloutStorage
 from trlx.utils import Clock, rampup_decay, safe_mkdir, topk_mask
 from trlx.utils.modeling import clip_by_value, logprobs_from_logits, whiten
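
The reshuffling above (moving `numpy` and `wandb` out of the first-party block) follows isort's default section ordering: standard library first, then third-party, then first-party, each group alphabetized and separated by a blank line. A minimal sketch of the convention:

```python
# Standard library
import os
from abc import abstractmethod

# Third-party, alphabetized
import numpy as np
import torch
import wandb

# First-party
from trlx.data.configs import TRLConfig
```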

trlx/model/nn/ppo_models.py

Lines changed: 3 additions & 3 deletions
@@ -1,3 +1,5 @@
+import inspect
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union

@@ -15,11 +17,9 @@
     GPTJModel,
     PretrainedConfig,
     PreTrainedModel,
-    top_k_top_p_filtering,
+    top_k_top_p_filtering
 )
 from transformers.modeling_outputs import ModelOutput
-from copy import deepcopy
-import inspect


 # Cell

trlx/orchestrator/offline_orchestrator.py

Lines changed: 4 additions & 1 deletion
@@ -2,7 +2,10 @@

 from trlx.model import BaseRLModel
 from trlx.orchestrator import Orchestrator, register_orchestrator
-from trlx.pipeline.offline_pipeline import OfflinePipeline, OfflineRolloutStorage
+from trlx.pipeline.offline_pipeline import (
+    OfflinePipeline,
+    OfflineRolloutStorage
+)


 @register_orchestrator
