Commit f1f0d2f

Revert "[Kernel] Add cuda kernel for gpt_oss activation" (vllm-project#22948)
1 parent 81f4b96 commit f1f0d2f

File tree: 8 files changed, +24 -150 lines changed


csrc/activation_kernels.cu

Lines changed: 0 additions & 59 deletions
@@ -128,45 +128,6 @@ __global__ void act_and_mul_kernel_with_param(
   }
 }
 
-template <typename T>
-__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
-                                               float alpha, float limit) {
-  // clamp gate: min=None, max=limit
-  const float gate_f = (float)gate;
-  const float clamped_gate = gate_f > limit ? limit : gate_f;
-
-  // clamp up: min=-limit, max=limit
-  const float up_f = (float)up;
-  const float clamped_up =
-      up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
-
-  // glu = gate * sigmoid(gate * alpha)
-  const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
-  const float glu = clamped_gate * sigmoid_val;
-
-  // (up + 1) * glu
-  return (T)((clamped_up + 1.0f) * glu);
-}
-
-template <typename scalar_t,
-          scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
-                             const float)>
-__global__ void swigluoai_and_mul_kernel(
-    scalar_t* __restrict__ out,          // [..., d]
-    const scalar_t* __restrict__ input,  // [..., 2, d]
-    const int d, const float alpha, const float limit) {
-  const int64_t token_idx = blockIdx.x;
-  // TODO: Vectorize loads and stores.
-  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-    // gate = x[..., ::2] (even indices)
-    const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
-    // up = x[..., 1::2] (odd indices)
-    const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
-
-    out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
-  }
-}
-
 } // namespace vllm
 
 #define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
@@ -184,31 +145,11 @@ __global__ void swigluoai_and_mul_kernel(
                                   PARAM);                                \
       });
 
-#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT)                    \
-  int d = input.size(-1) / 2;                                            \
-  int64_t num_tokens = input.numel() / input.size(-1);                   \
-  dim3 grid(num_tokens);                                                 \
-  dim3 block(std::min(d, 1024));                                         \
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));      \
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();          \
-  VLLM_DISPATCH_FLOATING_TYPES(                                          \
-      input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] {      \
-        vllm::swigluoai_and_mul_kernel<scalar_t, KERNEL<scalar_t>>       \
-            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),       \
-                                         input.data_ptr<scalar_t>(), d, ALPHA, \
-                                         LIMIT);                         \
-      });
-
 void fatrelu_and_mul(torch::Tensor& out,    // [..., d],
                      torch::Tensor& input,  // [..., 2 * d]
                      double threshold) {
   LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
 }
-void swigluoai_and_mul(torch::Tensor& out,    // [..., d]
-                       torch::Tensor& input,  // [..., 2 * d]
-                       double alpha, double limit) {
-  LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit);
-}
 namespace vllm {
 
 // Element-wise activation kernel template.
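
For reference while reading the revert (not part of the commit itself), here is a minimal PyTorch sketch of the element-wise math the removed swigluoai_and_mul device function computed: gate and up are interleaved in the input, gate is clamped from above at limit, up is clamped to [-limit, limit], and the result is (up + 1) * gate * sigmoid(alpha * gate). The function name and signature below are illustrative; the math mirrors the forward_native implementation that is also removed from activation.py further down.

import torch

def swigluoai_and_mul_ref(x: torch.Tensor,
                          alpha: float = 1.702,
                          limit: float = 7.0) -> torch.Tensor:
    # Input layout matches the kernel: gate in even columns, up in odd columns.
    gate, up = x[..., ::2], x[..., 1::2]
    gate = gate.clamp(max=limit)              # clamp gate: min=None, max=limit
    up = up.clamp(min=-limit, max=limit)      # clamp up to [-limit, limit]
    glu = gate * torch.sigmoid(gate * alpha)  # gate * sigmoid(alpha * gate)
    return (up + 1) * glu                     # output shape [..., d]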

csrc/ops.h

Lines changed: 0 additions & 2 deletions
@@ -138,8 +138,6 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
 
 void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
                      double threshold);
-void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input,
-                       double alpha = 1.702, double limit = 7.0);
 
 void gelu_new(torch::Tensor& out, torch::Tensor& input);

csrc/torch_bindings.cpp

Lines changed: 0 additions & 5 deletions
@@ -130,11 +130,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
   ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);
 
-  ops.def(
-      "swigluoai_and_mul(Tensor! out, Tensor input, float alpha, float limit) "
-      "-> ()");
-  ops.impl("swigluoai_and_mul", torch::kCUDA, &swigluoai_and_mul);
-
   // GELU implementation used in GPT-2.
   ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_new", torch::kCUDA, &gelu_new);

tests/kernels/core/test_activation.py

Lines changed: 6 additions & 39 deletions
@@ -11,7 +11,7 @@
 from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
                                                    GeluAndMul, MulAndSilu,
                                                    NewGELU, QuickGELU,
-                                                   SiluAndMul, SwigluOAIAndMul)
+                                                   SiluAndMul)
 from vllm.platforms import current_platform
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -25,15 +25,7 @@
 
 @pytest.mark.parametrize(
     "activation",
-    [
-        "silu_and_mul",
-        "mul_and_silu",
-        "gelu",
-        "gelu_tanh",
-        "fatrelu",
-        "swigluoai_and_mul",
-    ],
-)
+    ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -67,43 +59,18 @@ def test_act_and_mul(
         threshold = random.uniform(0, 1)
         layer = FatreluAndMul(threshold)
         fn = torch.ops._C.fatrelu_and_mul
-    elif activation == "swigluoai_and_mul":
-        layer = SwigluOAIAndMul()
-        fn = torch.ops._C.swigluoai_and_mul
     out = layer(x)
     ref_out = layer.forward_native(x)
-    if activation == "swigluoai_and_mul":
-
-        rtol = {
-            #For fp16, change the relative tolerance from 1e-3 to 2e-3
-            torch.float16:
-            2e-3,
-            torch.bfloat16:
-            2e-2,
-            torch.float:
-            1.3e-6
-        }
-
-        def _get_rtol(output) -> float:
-            return rtol[output.dtype]
-
-        torch.testing.assert_close(out,
-                                   ref_out,
-                                   atol=get_default_atol(out),
-                                   rtol=_get_rtol(out))
-    else:
-        # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
-        # equivalent to the native PyTorch implementations, so we can do exact
-        # comparison.
-        torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
+    # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
+    # equivalent to the native PyTorch implementations, so we can do exact
+    # comparison.
+    torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
 
     d = x.shape[-1] // 2
     output_shape = (x.shape[:-1] + (d, ))
     out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
     if activation == "fatrelu":
         opcheck(fn, (out, x, threshold))
-    elif activation == "swigluoai_and_mul":
-        opcheck(fn, (out, x, layer.alpha, layer.limit))
     else:
         opcheck(fn, (out, x))

vllm/model_executor/layers/activation.py

Lines changed: 3 additions & 38 deletions
@@ -239,35 +239,6 @@ def extra_repr(self) -> str:
         return f'approximate={repr(self.approximate)}'
 
 
-@CustomOp.register("swigluoai_and_mul")
-class SwigluOAIAndMul(CustomOp):
-    # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
-    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
-        super().__init__()
-        self.alpha = alpha
-        self.limit = limit
-
-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
-        """PyTorch-native implementation equivalent to forward()."""
-
-        gate, up = x[..., ::2], x[..., 1::2]
-        gate = gate.clamp(min=None, max=self.limit)
-        up = up.clamp(min=-self.limit, max=self.limit)
-        glu = gate * torch.sigmoid(gate * self.alpha)
-        gated_output = (up + 1) * glu
-        return gated_output
-
-    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        d = x.shape[-1] // 2
-        output_shape = (x.shape[:-1] + (d, ))
-        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit)
-        return out
-
-    def extra_repr(self) -> str:
-        return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}"
-
-
 @CustomOp.register("gelu_new")
 class NewGELU(CustomOp):
 
@@ -359,7 +330,6 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         return torch.square(F.relu(x))
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        #TODO : implement cuda kenrels
         return self.forward_native(x)
 
 
@@ -436,14 +406,9 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
 
 
 _ACTIVATION_AND_MUL_REGISTRY = LazyDict({
-    "gelu":
-    lambda: GeluAndMul(),
-    "silu":
-    lambda: SiluAndMul(),
-    "geglu":
-    lambda: GeluAndMul(),
-    "swigluoai_and_mul":
-    lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
+    "gelu": lambda: GeluAndMul(),
+    "silu": lambda: SiluAndMul(),
+    "geglu": lambda: GeluAndMul(),
 })
 

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 13 additions & 5 deletions
@@ -1633,23 +1633,31 @@ def fused_experts_impl(
                                 block_shape=block_shape,
                                 B_bias=w1_bias)
 
+        # TODO fused kernel
+        def swiglu_oai(gate_up):
+            alpha = 1.702
+            limit = 7.0
+            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+            gate = gate.clamp(min=None, max=limit)
+            up = up.clamp(min=-limit, max=limit)
+            glu = gate * torch.sigmoid(gate * alpha)
+            gated_output = (up + 1) * glu
+            return gated_output
+
         # Activation function with multiplication
         if activation == "silu" and is_act_and_mul:
            torch.ops._C.silu_and_mul(intermediate_cache2,
                                      intermediate_cache1.view(-1, N))
        elif activation == "gelu" and is_act_and_mul:
            torch.ops._C.gelu_and_mul(intermediate_cache2,
                                      intermediate_cache1.view(-1, N))
-        elif activation == "swigluoai" and is_act_and_mul:
-            # alpha = 1.702, limit = 7.0
-            torch.ops._C.swigluoai_and_mul(intermediate_cache2,
-                                           intermediate_cache1.view(-1, N))
         # Activation function without multiplication
         elif activation == "silu":
             intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
         elif activation == "gelu":
             intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
-
+        elif activation == "swiglu_oai":
+            intermediate_cache2 = swiglu_oai(intermediate_cache1.view(-1, N))
         else:
             raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
                              f"with is_act_and_mul={is_act_and_mul}.")

vllm/model_executor/layers/quantization/utils/mxfp4_utils.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ def _can_support_mxfp4(use_grouped_topk: bool = False,
     return not (use_grouped_topk or topk_group or num_expert_group
                 or expert_map or custom_routing_function
                 or e_score_correction_bias or apply_router_weight_on_input
-                or scoring_func != "softmax" or activation != "swigluoai"
+                or scoring_func != "softmax" or activation != "swiglu_oai"
                 or expert_load_view or logical_to_physical_map
                 or logical_replica_count)

vllm/model_executor/models/gpt_oss.py

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ def __init__(
             prefix=f"{prefix}.experts",
             apply_router_weight_on_input=False,
             has_bias=True,
-            activation="swigluoai")
+            activation="swiglu_oai")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         t = self.norm(x)
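
A side note on the last three files (a simplified sketch under stated assumptions, not code from the commit): the revert also changes the activation string from "swigluoai" to "swiglu_oai", so the value passed by gpt_oss.py has to agree with both the mxfp4 support check and the fused_experts_impl dispatch shown above. The helper functions below are hypothetical stand-ins that only mimic those conditions.

ACTIVATION = "swiglu_oai"  # value now passed by gpt_oss.py

def can_support_mxfp4(activation: str) -> bool:
    # Mirrors (in simplified form) the condition in _can_support_mxfp4 above.
    return activation == "swiglu_oai"

def pick_activation_impl(activation: str, is_act_and_mul: bool) -> str:
    # Simplified mirror of the fused_experts_impl dispatch above.
    if activation == "silu" and is_act_and_mul:
        return "torch.ops._C.silu_and_mul"
    if activation == "swiglu_oai":
        return "eager swiglu_oai helper"
    raise ValueError(f"Unsupported FusedMoe activation: {activation}")

assert can_support_mxfp4(ACTIVATION)
assert pick_activation_impl(ACTIVATION, True) == "eager swiglu_oai helper"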
