@@ -846,14 +846,17 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, hidden_st
     def forward(
         self,
         hidden_states: torch.Tensor,
-        output_shape: Tuple[int, int],
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_shape: torch.Tensor,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        return_dict: Optional[bool] = False,
     ) -> Union[Tuple, CausalLMOutputWithValue]:
         """Reference:
         https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L491
@@ -868,9 +871,20 @@ def forward(
         batch_size, seq_length = hidden_states.shape[:2]
         seq_length_with_past = seq_length
         past_key_values_length = 0
+
         if past_key_values is not None:
             past_key_values_length = past_key_values[0][0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        if position_ids is None:
+            device = hidden_states.device if hidden_states is not None else encoder_hidden_states.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
         # embed positions
         if attention_mask is None:
             attention_mask = torch.ones(
@@ -894,6 +908,7 @@ def forward(
             layer_outputs = decoder_layer(
                 hidden_states,
                 attention_mask=attention_mask,
+                position_ids=position_ids,
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
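
Taken together, the second and third hunks add explicit position handling: when the caller passes no position_ids, they default to a contiguous range that starts after the cached prefix, and the resulting tensor is then threaded into every decoder layer. A minimal, self-contained sketch of just that default behavior (the helper name default_position_ids and the example shapes are illustrative, not part of the diff):

import torch

def default_position_ids(seq_length: int, past_key_values_length: int, device=None) -> torch.Tensor:
    # Positions for the new tokens continue from the end of the cached prefix,
    # mirroring the branch added in the hunk above.
    position_ids = torch.arange(
        past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
    )
    return position_ids.unsqueeze(0).view(-1, seq_length)

# With a 3-token KV cache and 2 freshly decoded tokens, the new tokens occupy positions 3 and 4.
print(default_position_ids(seq_length=2, past_key_values_length=3))  # tensor([[3, 4]])
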
@@ -1253,6 +1268,7 @@ def hf_get_branch_class(
                 gpt_branch_supported_archs,
                 opt_branch_supported_archs,
                 bloom_branch_supported_archs,
+                llama_branch_supported_archs,
             ],
             [],
         )
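
The final hunk only registers the new llama_branch_supported_archs list next to the existing ones; the surrounding `],` / `[],` / `)` context lines suggest the per-branch lists are flattened with Python's sum(list_of_lists, []) idiom, presumably so every supported architecture can be reported in a single place. A quick illustration of that idiom, with made-up values:

# sum() starts from the empty list and concatenates each per-branch list in turn.
gpt_branch_supported_archs = ["GPT2LMHeadModel"]      # illustrative values only
opt_branch_supported_archs = ["OPTForCausalLM"]       # illustrative values only
llama_branch_supported_archs = ["LlamaForCausalLM"]   # illustrative values only

all_supported_archs = sum(
    [gpt_branch_supported_archs, opt_branch_supported_archs, llama_branch_supported_archs],
    [],
)
print(all_supported_archs)  # ['GPT2LMHeadModel', 'OPTForCausalLM', 'LlamaForCausalLM']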