@@ -106,7 +106,7 @@ class PPOConfig(MethodConfig):
     :param vf_coef: Value loss scale w.r.t policy loss
     :type vf_coef: float

-    :param gen_kwargs: Additioanl kwargs for the generation
+    :param gen_kwargs: Additional kwargs for the generation
     :type gen_kwargs: Dict[str, Any]

     :param gen_experience_kwargs: if this is not None, then the experience is generated using this
@@ -445,7 +445,7 @@ def forward( # noqa: max-complexity
         """Reference:
         https://github.com/huggingface/transformers/blob/2411f0e465e761790879e605a4256f3d4afb7f82/src/transformers/models/gpt2/modeling_gpt2.py#L743  # noqa: E501
         """
-        batch_size = hidden_states.size()[0]
+        batch_size, seq_length = hidden_states.shape[:2]

         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -457,7 +457,16 @@ def forward( # noqa: max-complexity
         device = hidden_states.device

         if past_key_values is None:
+            past_length = 0
             past_key_values = tuple([None] * len(self.decoder_blocks))
+        else:
+            past_length = past_key_values[0][0].size(-2)
+
+        if position_ids is None:
+            position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length)

         if attention_mask is not None:
             if batch_size <= 0:
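The hunk above makes the branch cache-aware: when `past_key_values` is present, positions continue from the cache length instead of restarting at zero. A minimal standalone sketch of what the added lines compute (the cache and chunk sizes below are made-up values, not taken from the diff):

import torch

# Hypothetical sizes: a KV cache of 3 tokens, a new chunk of 2 tokens.
past_length = 3
seq_length = 2
device = torch.device("cpu")

# Same arithmetic as the added lines: positions continue from the cache.
position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
print(position_ids)  # tensor([[3, 4]])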
@@ -498,28 +507,27 @@ def forward( # noqa: max-complexity
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

+            kwargs = dict(
+                layer_past=layer_past,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                head_mask=head_mask[i],
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+            )
+
             # Assumes we are never training the branch
             block_params = inspect.getfullargspec(block.forward).args
-            if "encoder_hidden_states" in block_params:
-                outputs = block(
-                    hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=attention_mask,
-                    head_mask=head_mask[i],
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                )
-            else:
-                outputs = block(
-                    hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=attention_mask,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                )
+            if "encoder_hidden_states" not in block_params:
+                kwargs.pop("encoder_hidden_states")
+                kwargs.pop("encoder_attention_mask")
+            # Remove position_ids for GPT2Block
+            if "position_ids" not in block_params:
+                kwargs.pop("position_ids")
+
+            outputs = block(hidden_states, **kwargs)

             hidden_states = outputs[0]
             if use_cache is True:
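The refactor above builds one `kwargs` dict and prunes it with `inspect.getfullargspec` so a single `block(hidden_states, **kwargs)` call serves layers with different signatures. The same filtering pattern in isolation, using toy stand-in functions (all names here are hypothetical, not from the diff):

import inspect

def decoder_only_block(hidden_states, layer_past=None, attention_mask=None, use_cache=False):
    return (hidden_states,)

def cross_attention_block(hidden_states, layer_past=None, attention_mask=None,
                          encoder_hidden_states=None, encoder_attention_mask=None, use_cache=False):
    return (hidden_states,)

def call_block(block, hidden_states, **kwargs):
    # Keep only the keyword arguments the block's signature actually declares.
    accepted = inspect.getfullargspec(block).args
    kwargs = {name: value for name, value in kwargs.items() if name in accepted}
    return block(hidden_states, **kwargs)

# `encoder_hidden_states` is silently dropped for the decoder-only block
# and passed through for the cross-attention block.
call_block(decoder_only_block, "h", attention_mask=None, encoder_hidden_states="enc")
call_block(cross_attention_block, "h", attention_mask=None, encoder_hidden_states="enc")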
@@ -594,10 +602,17 @@ def forward( # noqa: max-complexity
         input_shape = hidden_states.size()[:-1]
         combined_attention_mask = None
         if input_shape[-1] > 1:
+            # `modeling_opt._make_causal_mask` @ transformers==4.27.1 doesn't have the `device` argument
+            if "device" in inspect.getfullargspec(modeling_opt._make_causal_mask).args:
+                kwargs = dict(device=hidden_states.device)
+            else:
+                kwargs = {}
+
             combined_attention_mask = modeling_opt._make_causal_mask(
                 input_shape,
                 hidden_states.dtype,
                 past_key_values_length=past_key_values_length,
+                **kwargs,
             ).to(hidden_states.device)

         if attention_mask is not None:
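The last hunk is a version-compatibility guard: the `device` argument is forwarded to `modeling_opt._make_causal_mask` only when the installed signature declares it, since older transformers releases (e.g. 4.27.1, per the comment) do not accept it. The same feature-detection pattern against a toy function (the function and its arguments are made up for illustration):

import inspect

def make_causal_mask(input_shape, dtype, past_key_values_length=0, device=None):
    # Stand-in for an API whose signature differs across library versions.
    return ("mask", input_shape, dtype, past_key_values_length, device)

extra = {}
if "device" in inspect.getfullargspec(make_causal_mask).args:
    extra["device"] = "cpu"

mask = make_causal_mask((1, 4), "float32", past_key_values_length=0, **extra)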