
Commit de02e91 ("cleanup")
Parent: 569a8c1

File tree: 1 file changed (+58, -8 lines)


src/transformers/models/granitemoe/modeling_granitemoe.py

Lines changed: 58 additions & 8 deletions
@@ -27,7 +27,8 @@
 from ...modeling_flash_attention_utils import _flash_attention_forward
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
+    MoeCausalLMOutputWithPast,
+    MoeModelOutputWithPast,
 )
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
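
For orientation (not part of the diff): the two MoE-aware output classes imported here extend the plain causal-LM output with router information. A minimal sketch of how the extra fields can be listed; it assumes a transformers version where both classes live in `transformers.modeling_outputs`:

```python
# Sketch: comparing the MoE output class with the plain causal-LM output.
# Both are dataclasses in transformers.modeling_outputs; the extra fields are
# expected to be `aux_loss` and `router_logits`, but verify on your version.
from transformers.modeling_outputs import CausalLMOutputWithPast, MoeCausalLMOutputWithPast

extra_fields = set(MoeCausalLMOutputWithPast.__dataclass_fields__) - set(
    CausalLMOutputWithPast.__dataclass_fields__
)
print(sorted(extra_fields))  # e.g. ['aux_loss', 'router_logits']
```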
@@ -47,6 +48,60 @@
 _CONFIG_FOR_DOC = "GraniteMoeConfig"
 
 
+# Copied from transformers.models.granite.modeling_granite._prepare_4d_causal_attention_mask_with_cache_position with Granite->GraniteMoe
+def _prepare_4d_causal_attention_mask_with_cache_position(
+    attention_mask: torch.Tensor,
+    sequence_length: int,
+    target_length: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    min_dtype: float,
+    cache_position: torch.Tensor,
+    batch_size: int,
+):
+    """
+    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+    Args:
+        attention_mask (`torch.Tensor`):
+            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+        sequence_length (`int`):
+            The sequence length being processed.
+        target_length (`int`):
+            The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, i.e. the part of the cache that is not filled yet.
+        dtype (`torch.dtype`):
+            The dtype to use for the 4D attention mask.
+        device (`torch.device`):
+            The device to place the 4D attention mask on.
+        min_dtype (`float`):
+            The minimum value representable with the dtype `dtype`.
+        cache_position (`torch.Tensor`):
+            Indices depicting the position of the input sequence tokens in the sequence.
+        batch_size (`int`):
+            Batch size.
+    """
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+        causal_mask = attention_mask
+    else:
+        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+        if sequence_length != 1:
+            causal_mask = torch.triu(causal_mask, diagonal=1)
+        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = padding_mask == 0
+            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                padding_mask, min_dtype
+            )
+
+    return causal_mask
+
+
 # Copied from transformers.models.jetmoe.modeling_jetmoe.load_balancing_loss_func
 def load_balancing_loss_func(
     gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
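
For context, a minimal standalone sketch of the new mask helper (not part of the commit). The shapes and padding pattern below are illustrative assumptions, and the import path assumes a transformers build that already contains this commit:

```python
# Sketch: build a 4D causal mask for a 5-token prefill against a static cache
# of length 8, with the second sequence left-padded by 3 positions (assumed setup).
import torch
from transformers.models.granitemoe.modeling_granitemoe import (
    _prepare_4d_causal_attention_mask_with_cache_position,
)

batch_size, sequence_length, target_length = 2, 5, 8
dtype = torch.float32
min_dtype = torch.finfo(dtype).min

attention_mask = torch.ones(batch_size, target_length, dtype=torch.long)
attention_mask[1, :3] = 0  # left padding on the second sequence
cache_position = torch.arange(sequence_length)  # positions written this step

causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=attention_mask,
    sequence_length=sequence_length,
    target_length=target_length,
    dtype=dtype,
    device=torch.device("cpu"),
    min_dtype=min_dtype,
    cache_position=cache_position,
    batch_size=batch_size,
)
print(causal_mask.shape)  # torch.Size([2, 1, 5, 8]); masked slots hold min_dtype, visible slots 0
```

Taking `min_dtype` as an argument, rather than deriving it from `dtype` inside the helper, lets callers reuse the constant they have typically already computed for the rest of their masking code.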
@@ -708,7 +763,6 @@ def forward(
 }
 
 
-# Copied from transformers.models.granite.modeling_granite.GraniteDecoderLayer with GRANITE->GRANITEMOE,Granite->GraniteMoe
 class GraniteMoeDecoderLayer(nn.Module):
     def __init__(self, config: GraniteMoeConfig, layer_idx: int):
         super().__init__()
@@ -821,7 +875,6 @@ def forward(
     "The bare GraniteMoe Model outputting raw hidden-states without any specific head on top.",
     GRANITEMOE_START_DOCSTRING,
 )
-# Copied from transformers.models.granite.modeling_granite.GranitePreTrainedModel with Granite->GraniteMoe
 class GraniteMoePreTrainedModel(PreTrainedModel):
     config_class = GraniteMoeConfig
     base_model_prefix = "model"
@@ -929,7 +982,6 @@ def _init_weights(self, module):
     "The bare GraniteMoe Model outputting raw hidden-states without any specific head on top.",
     GRANITEMOE_START_DOCSTRING,
 )
-# Copied from transformers.models.granite.modeling_granite.GraniteModel with GRANITE->GRANITEMOE,Granite->GraniteMoe
 class GraniteMoeModel(GraniteMoePreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GraniteMoeDecoderLayer`]
@@ -1180,11 +1232,9 @@ def _update_causal_mask(
         return causal_mask
 
 
-# Copied from transformers.models.granite.modeling_granite.GraniteForCausalLM with GRANITE->GRANITEMOE,Granite->GraniteMoe,granite->granitemoe
 class GraniteMoeForCausalLM(GraniteMoePreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
-    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->GraniteMoe
     def __init__(self, config: GraniteMoeConfig):
         super().__init__(config)
         self.model = GraniteMoeModel(config)
@@ -1217,7 +1267,7 @@ def get_decoder(self):
         return self.model
 
     @add_start_docstrings_to_model_forward(GRANITEMOE_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -1232,7 +1282,7 @@ def forward(
         output_router_logits: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
+    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
         r"""
         Args:
             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
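
For context (not part of the diff): with the return type switched to `MoeCausalLMOutputWithPast`, callers that pass `output_router_logits=True` can read per-layer router logits and the auxiliary load-balancing loss from the output. A hedged sketch using a tiny, randomly initialised config so nothing is downloaded; the config field names `num_local_experts` and `num_experts_per_tok` are assumptions modelled on other MoE configs and should be checked against `GraniteMoeConfig`:

```python
# Sketch: exercising GraniteMoeForCausalLM with router logits enabled.
# Tiny random config; the expert-related field names are assumptions (see lead-in).
import torch
from transformers import GraniteMoeConfig, GraniteMoeForCausalLM

config = GraniteMoeConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    num_local_experts=4,    # assumed name for the expert count
    num_experts_per_tok=2,  # assumed name for the top-k routing width
)
model = GraniteMoeForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 6))
with torch.no_grad():
    out = model(input_ids, output_router_logits=True, return_dict=True)

print(out.logits.shape)        # (1, 6, vocab_size)
print(len(out.router_logits))  # expected: one router-logit tensor per decoder layer
print(out.aux_loss)            # load-balancing aux loss; exact behaviour may vary by version
```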
