Skip to content

Commit 44b2316

Browse files
authored
[qwen2-vl] fix vision attention scaling (#39043)
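The scale lost its `-` when refactoring: the vision attention scaling factor must be `head_dim**-0.5` (i.e. `1 / sqrt(head_dim)`), not `math.sqrt(head_dim)`.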
1 parent ae15715 commit 44b2316

File tree

4 files changed

+4
-6
lines changed

4 files changed

+4
-6
lines changed

src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -925,7 +925,7 @@ def __init__(self, config: Qwen2_5OmniVisionEncoderConfig = None) -> None:
925925
self.k = nn.Linear(self.dim, self.dim, bias=True)
926926
self.v = nn.Linear(self.dim, self.dim, bias=True)
927927
self.proj = nn.Linear(self.dim, self.dim)
928-
self.scaling = math.sqrt(self.head_dim)
928+
self.scaling = self.head_dim**-0.5
929929
self.num_key_value_groups = 1 # needed for eager attention
930930
self.config = config
931931

src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1903,7 +1903,7 @@ def __init__(self, config: Qwen2_5OmniVisionEncoderConfig = None) -> None:
19031903
self.k = nn.Linear(self.dim, self.dim, bias=True)
19041904
self.v = nn.Linear(self.dim, self.dim, bias=True)
19051905
self.proj = nn.Linear(self.dim, self.dim)
1906-
self.scaling = math.sqrt(self.head_dim)
1906+
self.scaling = self.head_dim**-0.5
19071907
self.num_key_value_groups = 1 # needed for eager attention
19081908
self.config = config
19091909

src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
# See the License for the specific language governing permissions and
2525
# limitations under the License.
2626

27-
import math
2827
from dataclasses import dataclass
2928
from typing import Any, Callable, Optional, Union
3029

@@ -205,7 +204,7 @@ def __init__(self, config: Qwen2_5_VLVisionConfig) -> None:
205204
self.num_key_value_groups = 1 # needed for eager attention
206205
self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
207206
self.proj = nn.Linear(self.dim, self.dim)
208-
self.scaling = math.sqrt(self.head_dim)
207+
self.scaling = self.head_dim**-0.5
209208
self.config = config
210209

211210
def forward(

src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
# limitations under the License.
2020
"""PyTorch Qwen2-VL model."""
2121

22-
import math
2322
from dataclasses import dataclass
2423
from typing import Any, Callable, Optional, Union
2524

@@ -323,7 +322,7 @@ def __init__(self, config: Qwen2VLVisionConfig) -> None:
323322
self.num_key_value_groups = 1 # needed for eager attention
324323
self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
325324
self.proj = nn.Linear(self.dim, self.dim)
326-
self.scaling = math.sqrt(self.head_dim)
325+
self.scaling = self.head_dim**-0.5
327326
self.config = config
328327

329328
def forward(

0 commit comments

Comments
 (0)