 from ...modeling_flash_attention_utils import _flash_attention_forward
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
+    MoeCausalLMOutputWithPast,
+    MoeModelOutputWithPast,
 )
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
 _CONFIG_FOR_DOC = "GraniteMoeConfig"


+# Copied from transformers.models.granite.modeling_granite._prepare_4d_causal_attention_mask_with_cache_position with Granite->GraniteMoe
+def _prepare_4d_causal_attention_mask_with_cache_position(
+    attention_mask: torch.Tensor,
+    sequence_length: int,
+    target_length: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    min_dtype: float,
+    cache_position: torch.Tensor,
+    batch_size: int,
+):
+    """
+    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+    Args:
+        attention_mask (`torch.Tensor`):
+            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+        sequence_length (`int`):
+            The sequence length being processed.
+        target_length (`int`):
+            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
+        dtype (`torch.dtype`):
+            The dtype to use for the 4D attention mask.
+        device (`torch.device`):
+            The device to place the 4D attention mask on.
+        min_dtype (`float`):
+            The minimum value representable with the dtype `dtype`.
+        cache_position (`torch.Tensor`):
+            Indices depicting the position of the input sequence tokens in the sequence.
+        batch_size (`int`):
+            Batch size.
+    """
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+        causal_mask = attention_mask
+    else:
+        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+        if sequence_length != 1:
+            causal_mask = torch.triu(causal_mask, diagonal=1)
+        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
+
+    return causal_mask
+
+
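For context, a minimal sketch of how the new helper behaves. The shapes, padding pattern, and device below are made-up illustrations, not values from this PR or its tests:

# Illustrative only: toy shapes; the second sequence is pretended to have two padding slots.
import torch

batch_size, sequence_length, target_length = 2, 4, 8
dtype = torch.float32
attention_mask = torch.ones(batch_size, target_length, dtype=torch.long)
attention_mask[1, 6:] = 0  # padding at the tail of the second sequence
cache_position = torch.arange(4, 8)  # the 4 query tokens occupy cache slots 4..7

causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=attention_mask,
    sequence_length=sequence_length,
    target_length=target_length,
    dtype=dtype,
    device=torch.device("cpu"),
    min_dtype=torch.finfo(dtype).min,
    cache_position=cache_position,
    batch_size=batch_size,
)
print(causal_mask.shape)  # torch.Size([2, 1, 4, 8]); 0.0 where attending is allowed, dtype-min elsewhere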
 # Copied from transformers.models.jetmoe.modeling_jetmoe.load_balancing_loss_func
 def load_balancing_loss_func(
     gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
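The body of this helper is not shown in the hunk; going by the JetMoe implementation it is copied from, it accepts a tuple of per-layer router logits and returns a scalar auxiliary loss. A rough usage sketch, where all shapes and expert counts are assumptions:

# Illustrative only: random logits stand in for the router outputs gathered from each MoE layer.
import torch

num_layers, num_tokens, num_experts = 2, 8, 4
gate_logits = tuple(torch.randn(num_tokens, num_experts) for _ in range(num_layers))
aux_loss = load_balancing_loss_func(gate_logits, num_experts=num_experts, top_k=2)
print(aux_loss)  # scalar tensor used as the router auxiliary loss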
@@ -708,7 +763,6 @@ def forward(
 }


-# Copied from transformers.models.granite.modeling_granite.GraniteDecoderLayer with GRANITE->GRANITEMOE,Granite->GraniteMoe
 class GraniteMoeDecoderLayer(nn.Module):
     def __init__(self, config: GraniteMoeConfig, layer_idx: int):
         super().__init__()
@@ -821,7 +875,6 @@ def forward(
     "The bare GraniteMoe Model outputting raw hidden-states without any specific head on top.",
     GRANITEMOE_START_DOCSTRING,
 )
-# Copied from transformers.models.granite.modeling_granite.GranitePreTrainedModel with Granite->GraniteMoe
 class GraniteMoePreTrainedModel(PreTrainedModel):
     config_class = GraniteMoeConfig
     base_model_prefix = "model"
@@ -929,7 +982,6 @@ def _init_weights(self, module):
     "The bare GraniteMoe Model outputting raw hidden-states without any specific head on top.",
     GRANITEMOE_START_DOCSTRING,
 )
-# Copied from transformers.models.granite.modeling_granite.GraniteModel with GRANITE->GRANITEMOE,Granite->GraniteMoe
 class GraniteMoeModel(GraniteMoePreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GraniteMoeDecoderLayer`]
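For orientation, a tiny instantiation sketch of this class. The MoE-specific config field names (`num_local_experts`, `num_experts_per_tok`) and all sizes are assumptions chosen to keep the snippet fast, not values taken from this PR:

# Illustrative only: made-up, deliberately small hyperparameters; the MoE config
# field names are assumptions and may differ from the final GraniteMoeConfig.
import torch
from transformers import GraniteMoeConfig, GraniteMoeModel  # available once this PR is merged

config = GraniteMoeConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    num_local_experts=4,     # assumed name
    num_experts_per_tok=2,   # assumed name
)
model = GraniteMoeModel(config)
hidden_states = model(torch.randint(0, config.vocab_size, (1, 8))).last_hidden_state
print(hidden_states.shape)  # (1, 8, 64)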
@@ -1180,11 +1232,9 @@ def _update_causal_mask(
         return causal_mask


-# Copied from transformers.models.granite.modeling_granite.GraniteForCausalLM with GRANITE->GRANITEMOE,Granite->GraniteMoe,granite->granitemoe
 class GraniteMoeForCausalLM(GraniteMoePreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]

-    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->GraniteMoe
     def __init__(self, config: GraniteMoeConfig):
         super().__init__(config)
         self.model = GraniteMoeModel(config)
@@ -1217,7 +1267,7 @@ def get_decoder(self):
         return self.model

     @add_start_docstrings_to_model_forward(GRANITEMOE_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -1232,7 +1282,7 @@ def forward(
         output_router_logits: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
+    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
         r"""
         Args:
             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
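Switching the return type from `CausalLMOutputWithPast` to `MoeCausalLMOutputWithPast` exposes the router outputs alongside the usual language-modeling fields. A minimal sketch of what a caller would see, assuming a constructed `model` and `input_ids` (both placeholders here) and the standard field names from `transformers.modeling_outputs`:

# Illustrative only: `model` is assumed to be a GraniteMoeForCausalLM and `input_ids`
# a (batch, seq_len) LongTensor; field names follow MoeCausalLMOutputWithPast.
outputs = model(input_ids, output_router_logits=True, return_dict=True)
outputs.logits         # (batch, seq_len, vocab_size) next-token logits
outputs.router_logits  # tuple of per-layer router logits, consumable by load_balancing_loss_func
outputs.aux_loss       # load-balancing auxiliary loss when output_router_logits=True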