@@ -142,9 +142,6 @@ class ModernBertConfig(PretrainedConfig):
             the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
             shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
             be faster in some scenarios.
-        repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
-            When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only
-            applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.
 
     Examples:
 
@@ -200,7 +197,6 @@ def __init__(
         sparse_prediction=False,
         sparse_pred_ignore_index=-100,
         reference_compile=None,
-        repad_logits_with_grad=False,
         **kwargs,
     ):
         super().__init__(
@@ -240,7 +236,6 @@ def __init__(
         self.sparse_prediction = sparse_prediction
         self.sparse_pred_ignore_index = sparse_pred_ignore_index
         self.reference_compile = reference_compile
-        self.repad_logits_with_grad = repad_logits_with_grad
 
         if self.classifier_pooling not in ["cls", "mean"]:
             raise ValueError(
@@ -1262,7 +1257,7 @@ def forward(
             loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)
 
         if self.config._attn_implementation == "flash_attention_2":
-            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
+            with nullcontext() if self.training or labels is None else torch.no_grad():
                 logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
 
         if not return_dict:
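
Behaviorally, the last hunk replaces the removed `repad_logits_with_grad` flag with the module's own training state: while training (or when no labels are passed), the repadded logits keep their autograd graph; otherwise repadding runs under `torch.no_grad()` to save memory. A minimal sketch of that condition in isolation (the `repad_context` helper and the toy tensors below are illustrative, not part of the diff):

```python
from contextlib import nullcontext

import torch


def repad_context(training: bool, labels):
    # Mirrors the new condition: keep the autograd graph while training or
    # when no labels were provided; otherwise detach via no_grad().
    return nullcontext() if training or labels is None else torch.no_grad()


logits = torch.randn(4, 8, requires_grad=True)
labels = torch.zeros(4, dtype=torch.long)

with repad_context(training=True, labels=labels):
    repadded = logits * 1.0  # stand-in for _pad_modernbert_output
print(repadded.requires_grad)  # True: gradients still flow during training

with repad_context(training=False, labels=labels):
    repadded = logits * 1.0
print(repadded.requires_grad)  # False: eval-time repadding is detached
```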