Commit 90c6f1d

fix gpt with paddle.matmul (PaddlePaddle#3483)
1 parent f43cfd0 commit 90c6f1d

1 file changed: 2 additions, 3 deletions

paddlenlp/transformers/gpt/modeling.py (2 additions, 3 deletions)

@@ -198,10 +198,9 @@ def forward(self,
         q, k, v, cache = self._prepare_qkv(query, key, value, use_cache,
                                            cache)
         # scale dot product attention
-        product = layers.matmul(x=q,
+        product = paddle.matmul(x=q * (self.head_dim**-0.5),
                                 y=k,
-                                transpose_y=True,
-                                alpha=self.head_dim**-0.5)
+                                transpose_y=True)

         if attn_mask is not None:
             product = product + attn_mask
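The commit swaps the legacy fluid-style `layers.matmul(..., alpha=s)` for `paddle.matmul` on a pre-scaled query. The two forms are numerically equivalent because a scalar factor commutes with matrix multiplication: (s·q)·kᵀ = s·(q·kᵀ). A minimal sketch of that identity using NumPy (Paddle itself is not required; the names `q`, `k`, and `head_dim` mirror the diff but this is an illustration, not PaddleNLP code):

```python
import numpy as np

head_dim = 64
rng = np.random.default_rng(0)
q = rng.standard_normal((2, 8, head_dim))  # (batch, seq_len, head_dim)
k = rng.standard_normal((2, 8, head_dim))

scale = head_dim ** -0.5  # same factor as self.head_dim**-0.5 in the diff

# Old behaviour: scale applied to the product (what alpha= did).
old = scale * (q @ k.transpose(0, 2, 1))

# New behaviour: query scaled before the matmul, as in the commit.
new = (q * scale) @ k.transpose(0, 2, 1)

assert np.allclose(old, new)
```

Pre-scaling the query is a common way to express scaled dot-product attention when the matmul API offers no built-in scaling argument.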

0 commit comments
