@@ -128,45 +128,6 @@ __global__ void act_and_mul_kernel_with_param(
128
128
}
129
129
}
130
130
131
- template <typename T>
132
- __device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
133
- float alpha, float limit) {
134
- // clamp gate: min=None, max=limit
135
- const float gate_f = (float)gate;
136
- const float clamped_gate = gate_f > limit ? limit : gate_f;
137
-
138
- // clamp up: min=-limit, max=limit
139
- const float up_f = (float)up;
140
- const float clamped_up =
141
- up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
142
-
143
- // glu = gate * sigmoid(gate * alpha)
144
- const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
145
- const float glu = clamped_gate * sigmoid_val;
146
-
147
- // (up + 1) * glu
148
- return (T)((clamped_up + 1.0f) * glu);
149
- }
150
-
151
- template <typename scalar_t,
152
- scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
153
- const float)>
154
- __global__ void swigluoai_and_mul_kernel(
155
- scalar_t * __restrict__ out, // [..., d]
156
- const scalar_t * __restrict__ input, // [..., 2, d]
157
- const int d, const float alpha, const float limit) {
158
- const int64_t token_idx = blockIdx.x;
159
- // TODO: Vectorize loads and stores.
160
- for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
161
- // gate = x[..., ::2] (even indices)
162
- const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
163
- // up = x[..., 1::2] (odd indices)
164
- const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
165
-
166
- out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
167
- }
168
- }
169
-
170
131
} // namespace vllm
171
132
172
133
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
@@ -184,31 +145,11 @@ __global__ void swigluoai_and_mul_kernel(
184
145
PARAM); \
185
146
});
186
147
187
- #define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \
188
- int d = input.size(-1) / 2; \
189
- int64_t num_tokens = input.numel() / input.size(-1); \
190
- dim3 grid(num_tokens); \
191
- dim3 block(std::min(d, 1024)); \
192
- const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
193
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
194
- VLLM_DISPATCH_FLOATING_TYPES( \
195
- input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \
196
- vllm::swigluoai_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
197
- <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
198
- input.data_ptr<scalar_t>(), d, ALPHA, \
199
- LIMIT); \
200
- });
201
-
202
148
void fatrelu_and_mul(torch::Tensor& out, // [..., d],
203
149
torch::Tensor& input, // [..., 2 * d]
204
150
double threshold) {
205
151
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
206
152
}
207
- void swigluoai_and_mul(torch::Tensor& out, // [..., d]
208
- torch::Tensor& input, // [..., 2 * d]
209
- double alpha, double limit) {
210
- LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit);
211
- }
212
153
namespace vllm {
213
154
214
155
// Element-wise activation kernel template.
0 commit comments