Commit 7331d63

fix(ppo_gpt): prevent position_ids being None (#451)
* fix(ppo_gpt): prevent position_ids being None
* fix(ppo_modeling): pop `position_ids` argument if not required
* fix(ppo_modeling): add `device` argument for `OPTModelBranch`
* fix(modeling_ppo): de-complement if-condition
* fix(ppo_modeling): condition passing `device` in `OPTModelBranch`

Co-authored-by: reciprocated <[email protected]>
1 parent fa3e13e commit 7331d63

2 files changed: 38 additions & 23 deletions

examples/hh/README.md

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ Launch training of [GPT-J](https://huggingface.co/EleutherAI/gpt-j-6B) on 7 GPUs
 ```sh
 accelerate launch --num_processes 7 --config_file ../../configs/accelerate/zero2-bf16.yaml ppo_hh.py
 ```
-Or if you want to train a smaller model or start from a supervised checkpoint, you can use one of the [configs](./configs)
+Or if you want to train a smaller model or start from a supervised checkpoint, you can use one of the [configs](../../configs)
 ```sh
 CONFIG_NAME=125M accelerate launch --num_processes 7 --config_file ../../configs/accelerate/zero2-bf16.yaml ppo_hh.py
 ```

trlx/models/modeling_ppo.py

Lines changed: 37 additions & 22 deletions
@@ -106,7 +106,7 @@ class PPOConfig(MethodConfig):
     :param vf_coef: Value loss scale w.r.t policy loss
     :type vf_coef: float
 
-    :param gen_kwargs: Additioanl kwargs for the generation
+    :param gen_kwargs: Additional kwargs for the generation
     :type gen_kwargs: Dict[str, Any]
 
     :param gen_experience_kwargs: if this is not None, then the experience is generated using this
@@ -445,7 +445,7 @@ def forward( # noqa: max-complexity
         """Reference:
         https://github.com/huggingface/transformers/blob/2411f0e465e761790879e605a4256f3d4afb7f82/src/transformers/models/gpt2/modeling_gpt2.py#L743 # noqa: E501
         """
-        batch_size = hidden_states.size()[0]
+        batch_size, seq_length = hidden_states.shape[:2]
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -457,7 +457,16 @@ def forward( # noqa: max-complexity
         device = hidden_states.device
 
         if past_key_values is None:
+            past_length = 0
             past_key_values = tuple([None] * len(self.decoder_blocks))
+        else:
+            past_length = past_key_values[0][0].size(-2)
+
+        if position_ids is None:
+            position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length)
 
         if attention_mask is not None:
             if batch_size <= 0:
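The new block reproduces what the upstream GPT-2 model does internally when no `position_ids` are supplied: positions are counted from the length of the KV cache, so generation with `past_key_values` keeps advancing the position index instead of passing `None` through. A minimal standalone sketch of that derivation is below; the helper name `default_position_ids` and the toy cache shapes are illustrative, not part of the diff.

```python
import torch

def default_position_ids(past_key_values, seq_length, device):
    # Positions already covered by the KV cache (0 when there is no cache);
    # each cached key tensor has shape (batch, heads, past_length, head_dim).
    past_length = 0 if past_key_values is None else past_key_values[0][0].size(-2)
    position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
    return position_ids.unsqueeze(0).view(-1, seq_length)

# No cache: positions start at 0 for a fresh 4-token input.
print(default_position_ids(None, seq_length=4, device="cpu"))   # tensor([[0, 1, 2, 3]])

# Toy cache of length 6: the single new token gets position 6.
cache = ((torch.zeros(1, 2, 6, 8), torch.zeros(1, 2, 6, 8)),)
print(default_position_ids(cache, seq_length=1, device="cpu"))  # tensor([[6]])
```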
@@ -498,28 +507,27 @@ def forward( # noqa: max-complexity
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
+            kwargs = dict(
+                layer_past=layer_past,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                head_mask=head_mask[i],
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+            )
+
             # Assumes we are never training the branch
             block_params = inspect.getfullargspec(block.forward).args
-            if "encoder_hidden_states" in block_params:
-                outputs = block(
-                    hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=attention_mask,
-                    head_mask=head_mask[i],
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                )
-            else:
-                outputs = block(
-                    hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=attention_mask,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                )
+            if "encoder_hidden_states" not in block_params:
+                kwargs.pop("encoder_hidden_states")
+                kwargs.pop("encoder_attention_mask")
+            # Remove position_ids for GPT2Block
+            if "position_ids" not in block_params:
+                kwargs.pop("position_ids")
+
+            outputs = block(hidden_states, **kwargs)
 
             hidden_states = outputs[0]
             if use_cache is True:
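This hunk collapses the two near-identical `block(...)` calls into a single kwargs dict that is then pruned to whatever the block's `forward` signature actually declares, which is how `position_ids` gets dropped for `GPT2Block` but still reaches blocks that require it. A small standalone sketch of the same pattern follows; the two stand-in block functions are hypothetical, and the sketch filters by an allow-list rather than popping specific keys as the diff does.

```python
import inspect

# Stand-ins for two block signatures: GPT-2-style (no position_ids) and GPT-J/NeoX-style.
def gpt2_style_block(hidden_states, attention_mask=None, use_cache=False):
    return "gpt2-style", hidden_states

def gptj_style_block(hidden_states, position_ids=None, attention_mask=None, use_cache=False):
    return "gptj-style", hidden_states, position_ids

def call_block(block, hidden_states, **kwargs):
    # Keep only the keyword arguments the callee's signature declares.
    accepted = inspect.getfullargspec(block).args
    kwargs = {name: value for name, value in kwargs.items() if name in accepted}
    return block(hidden_states, **kwargs)

# One call site drives both blocks; position_ids is silently dropped for the GPT-2-style one.
print(call_block(gpt2_style_block, "h", position_ids=[0, 1], attention_mask=None))
print(call_block(gptj_style_block, "h", position_ids=[0, 1], attention_mask=None))
```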
@@ -594,10 +602,17 @@ def forward( # noqa: max-complexity
         input_shape = hidden_states.size()[:-1]
         combined_attention_mask = None
         if input_shape[-1] > 1:
+            # `modeling_opt._make_causal_mask` @ transformers==4.27.1 doesn't have the `device` argument
+            if "device" in inspect.getfullargspec(modeling_opt._make_causal_mask).args:
+                kwargs = dict(device=hidden_states.device)
+            else:
+                kwargs = {}
+
             combined_attention_mask = modeling_opt._make_causal_mask(
                 input_shape,
                 hidden_states.dtype,
                 past_key_values_length=past_key_values_length,
+                **kwargs,
             ).to(hidden_states.device)
 
         if attention_mask is not None:
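The OPT branch has the inverse compatibility problem: `_make_causal_mask` only gained a `device` parameter in newer `transformers` releases, so the diff probes the signature before deciding whether to pass it. A generic sketch of that version-tolerant call pattern is below; the two `make_causal_mask_*` functions are stand-ins for the old and new helper signatures, not the real `transformers` code.

```python
import inspect
import torch

# Stand-ins for the helper before and after it gained a `device` parameter.
def make_causal_mask_old(input_shape, dtype, past_key_values_length=0):
    return torch.zeros(*input_shape, dtype=dtype)

def make_causal_mask_new(input_shape, dtype, device=None, past_key_values_length=0):
    return torch.zeros(*input_shape, dtype=dtype, device=device)

def build_mask(make_causal_mask, input_shape, dtype, device):
    # Pass `device` only if the callee's signature declares it.
    kwargs = {"device": device} if "device" in inspect.getfullargspec(make_causal_mask).args else {}
    return make_causal_mask(input_shape, dtype, past_key_values_length=0, **kwargs).to(device)

print(build_mask(make_causal_mask_old, (2, 3), torch.float32, "cpu").shape)  # torch.Size([2, 3])
print(build_mask(make_causal_mask_new, (2, 3), torch.float32, "cpu").shape)  # torch.Size([2, 3])
```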
