Commit 656f66c

fix(llama): add position_ids to LlamaModelBranch (#418)
1 parent 3021f5d

File tree

1 file changed: +21 -5 lines changed

trlx/models/modeling_ppo.py

Lines changed: 21 additions & 5 deletions

@@ -846,14 +846,17 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, hidden_st
     def forward(
         self,
         hidden_states: torch.Tensor,
-        output_shape: Tuple[int, int],
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_shape: torch.Tensor,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        return_dict: Optional[bool] = False,
     ) -> Union[Tuple, CausalLMOutputWithValue]:
         """Reference:
         https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L491
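The reworked signature appears to align the LLaMA branch with the other *ModelBranch forwards in modeling_ppo.py, with position_ids now threaded through. Assuming a constructed branch and precomputed hidden states (the call site is not part of this diff, so every name below is hypothetical), a call might look like:

    # Hypothetical call; `branch`, `hidden_states`, etc. are assumptions,
    # not shown anywhere in this diff.
    outputs = branch(
        hidden_states,                  # activations at the branch point
        output_shape=output_shape,
        attention_mask=attention_mask,
        position_ids=position_ids,      # newly accepted by this commit
        return_dict=True,
    )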
@@ -868,9 +871,20 @@ def forward(
         batch_size, seq_length = hidden_states.shape[:2]
         seq_length_with_past = seq_length
         past_key_values_length = 0
+
         if past_key_values is not None:
             past_key_values_length = past_key_values[0][0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        if position_ids is None:
+            device = hidden_states.device if hidden_states is not None else encoder_hidden_states.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
         # embed positions
         if attention_mask is None:
             attention_mask = torch.ones(
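For intuition, here is a minimal standalone sketch of the default position_ids construction added above (just torch; names mirror the diff, but this is not trlx code). The key point is that when decoding with a KV cache, positions for new tokens continue from past_key_values_length instead of restarting at zero:

    import torch

    def default_position_ids(seq_length, past_key_values_length, device="cpu"):
        # Positions for the current chunk start after any cached tokens.
        position_ids = torch.arange(
            past_key_values_length,
            seq_length + past_key_values_length,
            dtype=torch.long,
            device=device,
        )
        # Add a batch dimension: shape (1, seq_length), broadcastable over the batch.
        return position_ids.unsqueeze(0).view(-1, seq_length)

    print(default_position_ids(5, 0))  # tensor([[0, 1, 2, 3, 4]])  (prompt, no cache)
    print(default_position_ids(1, 5))  # tensor([[5]])              (one new token, 5 cached)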
@@ -894,6 +908,7 @@ def forward(
             layer_outputs = decoder_layer(
                 hidden_states,
                 attention_mask=attention_mask,
+                position_ids=position_ids,
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
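Passing position_ids into each decoder layer matters because LLaMA applies rotary position embeddings (RoPE) inside attention: queries and keys are rotated by an angle that depends on the absolute token position, so wrong positions during cached decoding corrupt attention. A minimal sketch of the idea (standard RoPE application, not trlx's or transformers' exact code):

    import torch

    def rotate_half(x):
        # Split the head dimension in two and swap the halves with a sign flip.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rope(q, cos, sin, position_ids):
        # cos/sin: precomputed tables of shape (max_positions, head_dim).
        # Indexing them by position_ids is why absolute positions must be
        # correct when decoding with a KV cache.
        cos = cos[position_ids].unsqueeze(1)  # (batch, 1, seq, head_dim)
        sin = sin[position_ids].unsqueeze(1)
        return q * cos + rotate_half(q) * sin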
@@ -1253,6 +1268,7 @@ def hf_get_branch_class(
             gpt_branch_supported_archs,
             opt_branch_supported_archs,
             bloom_branch_supported_archs,
+            llama_branch_supported_archs,
         ],
         [],
     )
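The trailing `], [], )` suggests the per-family architecture lists are flattened with `sum(lists, [])` so hf_get_branch_class can check whether a model's architecture has a branch implementation at all. A hypothetical sketch of that pattern (list contents abbreviated and assumed, not the actual trlx definitions):

    # Hypothetical illustration of flattening per-family arch lists.
    bloom_branch_supported_archs = ["BloomForCausalLM", "BloomModel"]
    llama_branch_supported_archs = ["LlamaForCausalLM"]

    all_supported_archs = sum(
        [
            bloom_branch_supported_archs,
            llama_branch_supported_archs,
        ],
        [],
    )
    # all_supported_archs == ["BloomForCausalLM", "BloomModel", "LlamaForCausalLM"]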
