Skip to content

Commit 70ca6c6

Browse files
Pad prompts to the right in T5 examples and add EOS token to seq2seq prompts (#422)
* Pad prompts to the right in T5 examples. Add EOS token to prompts for seq2seq models like T5
* style: satisfy black
* fix(offline_pipeline): default `add_special_tokens` to `False`
* style: satisfy flake

---------

Co-authored-by: reciprocated <[email protected]>
1 parent ec75e99 commit 70ca6c6

File tree

7 files changed

+33
-9
lines changed

7 files changed

+33
-9
lines changed

examples/ilql_sentiments_t5.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def get_positive_score(scores):
4141
),
4242
tokenizer=TokenizerConfig(
4343
tokenizer_path="lvwerra/t5-imdb",
44+
padding_side="right",
4445
truncation_side="right",
4546
),
4647
optimizer=OptimizerConfig(

examples/ppo_sentiments_t5.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def get_positive_score(scores):
4343
),
4444
tokenizer=TokenizerConfig(
4545
tokenizer_path="lvwerra/t5-imdb",
46+
padding_side="right",
4647
truncation_side="right",
4748
),
4849
optimizer=OptimizerConfig(

examples/ppo_translation_t5.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
),
5555
tokenizer=TokenizerConfig(
5656
tokenizer_path="t5-large",
57+
padding_side="right",
5758
truncation_side="right",
5859
),
5960
optimizer=OptimizerConfig(

trlx/pipeline/offline_pipeline.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,16 @@ class PromptPipeline(BasePipeline):
117117
max_prompt_length (`int`): max length of the prompt, if exceeded the prompt will be truncated according to
118118
tokenizer's truncation setting.
119119
tokenizer (`transformers.PreTrainedTokenizer`): a tokenizer to tokenize prompts with.
120+
add_special_tokens (`bool`): whether to encode prompts with tokenizer's special tokens (passed directly
121+
into `tokenizer.encode`)
120122
"""
121123

122124
def __init__(
123-
self, prompts: Union[Dict[str, Any], List[str]], max_prompt_length: int, tokenizer: PreTrainedTokenizer
125+
self,
126+
prompts: Union[Dict[str, Any], List[str]],
127+
max_prompt_length: int,
128+
tokenizer: PreTrainedTokenizer,
129+
add_special_tokens: bool = False,
124130
):
125131
super().__init__()
126132

@@ -131,7 +137,7 @@ def __init__(
131137
metadata = [{}] * len(prompts)
132138

133139
model_inputs = tokenizer(
134-
prompts, truncation=True, padding=False, max_length=max_prompt_length, add_special_tokens=False
140+
prompts, truncation=True, padding=False, max_length=max_prompt_length, add_special_tokens=add_special_tokens
135141
)
136142

137143
prompts_tokens = model_inputs["input_ids"]

trlx/pipeline/ppo_pipeline.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@ class PPORolloutStorage(BaseRolloutStore):
1515
Rollout storage for training PPO
1616
"""
1717

18-
def __init__(self, pad_token_id):
18+
def __init__(self, pad_token_id, padding_side):
1919
super().__init__()
2020

2121
self.pad_token_id = pad_token_id
22+
self.padding_side = padding_side
2223
self.history: Iterable[PPORLElement] = [None]
2324

2425
def push(self, exps: Iterable[PPORLElement]):
@@ -51,13 +52,23 @@ def create_loader(
5152
shuffle: bool,
5253
) -> DataLoader:
5354
def collate_fn(elems: Iterable[PPORLElement]):
54-
return PPORLBatch(
55+
if self.padding_side == "right":
56+
# Right padding of already right-padded queries
57+
query_tensors = pad_sequence(
58+
[elem.query_tensor for elem in elems],
59+
padding_value=self.pad_token_id,
60+
batch_first=True,
61+
)
62+
else:
5563
# Left padding of already left-padded queries
56-
pad_sequence(
64+
query_tensors = pad_sequence(
5765
[elem.query_tensor.flip(0) for elem in elems],
5866
padding_value=self.pad_token_id,
5967
batch_first=True,
60-
).flip(1),
68+
).flip(1)
69+
70+
return PPORLBatch(
71+
query_tensors,
6172
# Right pad the rest, to have a single horizontal query/response split
6273
pad_sequence(
6374
[elem.response_tensor for elem in elems],

trlx/trainer/accelerate_ppo_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(self, config: TRLConfig, **kwargs):
5454

5555
# Setup the rollout store
5656
# Rollouts contain the prompt & response, log probs, values and rewards - from each rollout
57-
self.store = PPORolloutStorage(self.tokenizer.pad_token_id)
57+
self.store = PPORolloutStorage(self.tokenizer.pad_token_id, self.tokenizer.padding_side)
5858

5959
# Create the rollout store dataloader (for batching up rollouts)
6060
# TODO (jon-tow): This is only used to satisfy to `accelerator.prepare` call constraint below - remove in future

trlx/trlx.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,9 @@ def train( # noqa: C901
9494
if eval_prompts is None:
9595
eval_prompts = prompts[:batch_size]
9696

97-
pipeline = get_pipeline(config.train.pipeline)(prompts, max_prompt_length, trainer.tokenizer)
97+
pipeline = get_pipeline(config.train.pipeline)(
98+
prompts, max_prompt_length, trainer.tokenizer, add_special_tokens=config.model.model_arch_type == "seq2seq"
99+
)
98100
trainer.add_prompt_pipeline(pipeline)
99101

100102
if eval_prompts is None:
@@ -118,7 +120,9 @@ def train( # noqa: C901
118120
else:
119121
raise ValueError("Either `samples` or `reward_fn` should be given for training")
120122

121-
eval_pipeline = get_pipeline(config.train.pipeline)(eval_prompts, max_prompt_length, trainer.tokenizer)
123+
eval_pipeline = get_pipeline(config.train.pipeline)(
124+
eval_prompts, max_prompt_length, trainer.tokenizer, add_special_tokens=config.model.model_arch_type == "seq2seq"
125+
)
122126
trainer.add_eval_pipeline(eval_pipeline)
123127

124128
trainer.learn()

0 commit comments

Comments (0)