Commit c78e320

Error (also in the original Facebook XLM model): scaling only the q matrix, not the qk.T dot product (qk.T / sqrt(dim_per_head))

As per Vaswani et al., 2017, p. 4 (https://arxiv.org/pdf/1912.05372.pdf), the scaling should be

torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim_per_head)

not

q / math.sqrt(dim_per_head)

The existing code effectively scales the queries only, rather than the query-key dot product as it should.

Mentioned in:
- original facebookresearch/XLM#357
- dependent original FlauBERT getalp/Flaubert@6d17688
- dependent Huggingface FlauBERT huggingface#21627
1 parent 22888d3 commit c78e320

1 file changed: +1 -1 lines changed

src/transformers/models/xlm/modeling_xlm.py

Lines changed: 1 addition & 1 deletion
@@ -176,8 +176,8 @@ def unshape(x):
                 k, v = cache[self.layer_id]
             cache[self.layer_id] = (k, v)

-        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, dim_per_head)
         scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, qlen, klen)
+        scores = scores / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, klen)
         mask = (mask == 0).view(mask_reshape).expand_as(scores)  # (bs, n_heads, qlen, klen)
         scores.masked_fill_(mask, torch.finfo(scores.dtype).min)  # (bs, n_heads, qlen, klen)
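
To make the two orderings in this diff concrete, here is a minimal sketch in plain Python, with nested lists standing in for a single attention head's tensors. The helper names (`matmul_qkt`, `scores_scale_q_first`, `scores_scale_after`) and the toy inputs are hypothetical and do not come from modeling_xlm.py. Because dividing by a scalar is linear, the two orderings should agree in exact arithmetic, so any difference between them would be limited to floating-point rounding; the sketch computes that difference rather than assuming it.

```python
import math


def matmul_qkt(q, k):
    """Return q @ k^T for two lists of row vectors (qlen x d and klen x d)."""
    return [[sum(a * b for a, b in zip(q_row, k_row)) for k_row in k] for q_row in q]


def scores_scale_q_first(q, k, dim_per_head):
    """Removed line's ordering: divide q by sqrt(dim_per_head), then take q @ k^T."""
    scale = math.sqrt(dim_per_head)
    q_scaled = [[x / scale for x in row] for row in q]
    return matmul_qkt(q_scaled, k)


def scores_scale_after(q, k, dim_per_head):
    """Added line's ordering: take q @ k^T, then divide by sqrt(dim_per_head)."""
    scale = math.sqrt(dim_per_head)
    return [[s / scale for s in row] for row in matmul_qkt(q, k)]


# Hypothetical toy inputs for one head: qlen = 2, klen = 3, dim_per_head = 4.
q = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.0, 0.0]]
k = [[1.0, 0.0, 1.0, 0.0], [2.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0]]

before = scores_scale_q_first(q, k, dim_per_head=4)
after = scores_scale_after(q, k, dim_per_head=4)

# Largest element-wise gap between the two orderings.
max_abs_diff = max(
    abs(a - b) for row_a, row_b in zip(before, after) for a, b in zip(row_a, row_b)
)
print(after)
print(max_abs_diff)
```

In half-precision settings the order of operations can still matter for overflow and rounding of intermediate values, which this float64 sketch does not exercise.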
