68 changes: 46 additions & 22 deletions dlinfer/vendor/maca/maca_ops.py
@@ -9,9 +9,28 @@
 from dlinfer.vendor import vendor_ops_registry
 from dlinfer.utils.registry import register_ops
 from dlinfer.utils.type_annotation import Tensor, Optional, Sequence, Tuple

 from .fused_moe import fused_experts
 from .maca_extension import ops as maca_ext_ops
+from mcoplib import mcoplib_ops
+from mcoplib import op as op_origin
+
+# Environment variable to choose between mcoplib_ops and maca_ext_ops.
+# Default: false (use maca_ext_ops).
+# When MACA_LMDEPLOY_MCOPLIB_OPS=true, use mcoplib_ops.
+env_value = os.getenv("MACA_LMDEPLOY_MCOPLIB_OPS", "false")
+USE_MCOPLIB_OPS = env_value.lower() in ("true", "1", "yes", "on")
+
+# Select the ops library based on the environment variable.
+if USE_MCOPLIB_OPS:
+    ops = mcoplib_ops
+    ops_name = "mcoplib_ops"
+else:
+    ops = maca_ext_ops
+    ops_name = "maca_ext_ops"
+
+# Log the environment variable value and the selected ops library.
+print(f"[DLInfer] MACA_LMDEPLOY_MCOPLIB_OPS environment variable: {env_value}")
+print(f"[DLInfer] Using ops library: {ops_name}")

 __all__ = [
     "add_rms_norm",
@@ -58,7 +77,7 @@ def add_rms_norm(
     weight: Tensor,
     epsilon: float,
 ) -> Tuple[Tensor, Tensor]:
-    maca_ext_ops.fused_add_rms_norm(hidden_states, residual, weight, epsilon)
+    ops.fused_add_rms_norm(hidden_states, residual, weight, epsilon)
     return hidden_states, residual
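
The wrapper relies on the fused kernel updating hidden_states and residual in place and returns the same tensors. A hedged call sketch; the import path, shapes, and dtype are illustrative assumptions, not taken from the patch:

import torch
from dlinfer.vendor.maca.maca_ops import add_rms_norm  # assumed import path

hidden = torch.randn(8, 4096, dtype=torch.float16, device="cuda")  # [tokens, hidden]
residual = torch.randn_like(hidden)
weight = torch.ones(4096, dtype=torch.float16, device="cuda")
hidden, residual = add_rms_norm(hidden, residual, weight, 1e-6)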


@@ -77,15 +96,11 @@ def apply_rotary_pos_emb(
     key = key.flatten(-2, -1)
     rot_dim = cos.size(-1)

-    maca_ext_ops.rotary_embedding(
-        position_ids_1d,
-        query,
-        key,
-        head_size,
-        cos.view(-1, rot_dim),
-        sin.view(-1, rot_dim),
-        True,
-    )
+    if USE_MCOPLIB_OPS:
+        ops.lmdeploy_rotary_embedding(query, key, cos, sin, position_ids_1d, rot_dim)
+    else:
+        ops.rotary_embedding(query, key, cos, sin, position_ids_1d, rot_dim)

     return query, key


@@ -200,7 +215,7 @@ def fill_kv_cache(
     quant_bits: int,
 ) -> Tuple[Tensor, Tensor]:
     kv_indices = kv_indices.squeeze(-1)
-    maca_ext_ops.reshape_and_cache_new(
+    ops.reshape_and_cache_new(
         key, value, key_cache, value_cache, kv_indices, "auto", 1.0, 1.0
     )
     return key_cache, value_cache
@@ -239,7 +254,7 @@ def paged_decode_attention(
     value_cache = key_cache.transpose(2, 3).reshape(
         -1, num_kv_heads, 576, block_size
     )
-    maca_ext_ops.paged_attention_v1(
+    ops.paged_attention_v1(
         output,
         query,
         key_cache,
@@ -349,7 +364,10 @@ def rms_norm(
     hidden_states = hidden_states.to(torch.float32)
     weight = weight.to(torch.float32)
     output = torch.empty_like(hidden_states)
-    maca_ext_ops.rms_norm(output, hidden_states, weight, epsilon)
+    if USE_MCOPLIB_OPS:
+        op_origin.rms_norm(output, hidden_states, weight, epsilon, None, None, False)
+    else:
+        ops.rms_norm(output, hidden_states, weight, epsilon)

     return output.to(input_dtype)
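
For checking that both branches agree, a plain-PyTorch reference (a sketch, not part of the patch) that mirrors the float32 upcast done above:

import torch

def rms_norm_ref(hidden_states: torch.Tensor, weight: torch.Tensor, epsilon: float) -> torch.Tensor:
    # Compute RMSNorm in float32, as the wrapper does, then cast back.
    x = hidden_states.to(torch.float32)
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + epsilon)
    return (x * weight.to(torch.float32)).to(hidden_states.dtype)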

@@ -367,13 +385,19 @@ def moe_gating_topk_softmax(
     topk_ids = torch.empty(N, topk, dtype=torch.int32, device=router_logits.device)

     token_expert_indicies = torch.empty_like(topk_ids)

-    maca_ext_ops.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        router_logits.float(),
-    )
+    if USE_MCOPLIB_OPS:
+        op_origin.topk_softmax(
+            topk_weights,
+            topk_ids,
+            router_logits.float(),
+        )
+    else:
+        ops.topk_softmax(
+            topk_weights,
+            topk_ids,
+            token_expert_indicies,
+            router_logits.float(),
+        )

     del token_expert_indicies # Not used. Will be used in the future.
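
As a sanity check for either branch: the gating step reduces to a softmax over experts followed by a per-token top-k. A plain-PyTorch sketch (not part of the patch; whether weights are renormalized after top-k is left to the caller, matching the kernel):

import torch

def moe_gating_topk_softmax_ref(router_logits: torch.Tensor, topk: int):
    # Softmax over the expert dimension, then keep the top-k experts per token.
    scores = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    return topk_weights, topk_ids.to(torch.int32)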

@@ -390,7 +414,7 @@ def silu_and_mul(x: Tensor, dim: int = -1) -> Tensor:
     d = x.shape[-1] // 2
     output_shape = x.shape[:-1] + (d,)
     out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-    maca_ext_ops.silu_and_mul(out, x)
+    ops.silu_and_mul(out, x)
     return out
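
silu_and_mul splits the last dimension in half and gates SiLU of the first half with the second half, which is why the output has half the input's last dimension. A short reference sketch (not part of the patch):

import torch

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # SiLU on the first half of the last dim, multiplied by the second half.
    gate, up = x.chunk(2, dim=-1)
    return torch.nn.functional.silu(gate) * up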

