Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,13 @@ def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
)
else:
return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))


def attention_softmax(attn_weights: torch.Tensor, training: bool):
if attn_weights.is_contiguous() and attn_weights.device.type == "xpu" and not training:
import xe_addons
xe_addons.attn_softmax_inplaced(attn_weights)
else:
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(attn_weights.dtype)
return attn_weights
4 changes: 2 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/minicpmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import torch
from typing import Optional
from ipex_llm.transformers.models.common import merge_qkv_base
from ipex_llm.transformers.models.common import attention_softmax
from transformers import AutoProcessor
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor

Expand Down Expand Up @@ -47,8 +48,7 @@ def siglip_attention_forward(
if attention_mask is not None:
attn_weights = attn_weights + attention_mask

import xe_addons
xe_addons.attn_softmax_inplaced(attn_weights)
attn_weights = attention_softmax(attn_weights, self.training)

attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
Expand Down
4 changes: 2 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import warnings
from torch import nn

from ipex_llm.transformers.models.common import attention_softmax
from ipex_llm.transformers.models.utils import should_use_fuse_rope, rotate_half
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
Expand Down Expand Up @@ -184,8 +185,7 @@ def attention_forward(
if attention_mask is not None:
attn_weights = attn_weights + attention_mask

import xe_addons
xe_addons.attn_softmax_inplaced(attn_weights)
attn_weights = attention_softmax(attn_weights, self.training)

attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
training=self.training)
Expand Down