csrc/torch_bindings.cpp (45 changes: 30 additions & 15 deletions)
@@ -163,25 +163,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
       "Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
-      "-> Tensor");
+      "-> Tensor",
+      {at::Tag::needs_fixed_stride_order});
   ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);

   // Decompression method for AQLM.
   ops.def(
       "aqlm_dequant(Tensor codes, Tensor codebooks, "
-      "int[] codebook_partition_sizes) -> Tensor");
+      "int[] codebook_partition_sizes) -> Tensor",
+      {at::Tag::needs_fixed_stride_order});
   ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);

   // Quantized GEMM for AWQ.
   ops.def(
       "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, SymInt split_k_iters) -> Tensor");
+      "Tensor _zeros, SymInt split_k_iters) -> Tensor",
+      {at::Tag::needs_fixed_stride_order});
   ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);

   // Dequantization for AWQ.
   ops.def(
       "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
-      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor");
+      "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor",
+      {at::Tag::needs_fixed_stride_order});
   ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);

   // Note about marlin kernel 'workspace' arguments:
@@ -202,15 +206,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
       "Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
-      "Tensor");
+      "Tensor",
+      {at::Tag::needs_fixed_stride_order});
   // conditionally compiled so impl in source file

   // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
   ops.def(
       "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
       "Tensor b_scales, Tensor workspace, "
       "int b_q_type, "
-      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor");
+      "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor",
+      {at::Tag::needs_fixed_stride_order});
   // conditionally compiled so impl in source file

   // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
@@ -236,7 +242,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor? channel_scales,"
       " Tensor? token_scales,"
       " str? schedule"
-      ") -> Tensor");
+      ") -> Tensor",
+      {at::Tag::needs_fixed_stride_order});
   ops.def(
       "machete_prepack_B("
       " Tensor B,"
@@ -255,7 +262,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
       "int b_q_type, "
       "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
-      "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
+      "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
+      {at::Tag::needs_fixed_stride_order});
   // conditionally compiled so impl registration is in source file

   // gptq_marlin repack from GPTQ.
@@ -291,30 +299,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
       "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
-      "SymInt size_k) -> Tensor");
+      "SymInt size_k) -> Tensor",
+      {at::Tag::needs_fixed_stride_order});
   // conditionally compiled so impl registration is in source file

   // marlin_qqq_gemm for QQQ.
   ops.def(
       "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
       "Tensor s_tok, Tensor s_ch, Tensor s_group, "
       "Tensor! workspace, SymInt size_m, SymInt size_n, "
-      "SymInt size_k) -> Tensor");
+      "SymInt size_k) -> Tensor",
+      {at::Tag::needs_fixed_stride_order});
   // conditionally compiled so impl registration is in source file

   // CUTLASS nvfp4 block scaled GEMM
   ops.def(
       "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
       " Tensor block_scale_a, Tensor block_scale_b,"
-      " Tensor alpha) -> ()");
+      " Tensor alpha) -> ()",
+      {at::Tag::needs_fixed_stride_order});
   ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);

   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
   // quantization, as well as bias
   ops.def(
       "cutlass_scaled_mm(Tensor! out, Tensor a,"
       " Tensor b, Tensor a_scales,"
-      " Tensor b_scales, Tensor? bias) -> ()");
+      " Tensor b_scales, Tensor? bias) -> ()",
+      {at::Tag::needs_fixed_stride_order});
   ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);

   // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
@@ -323,7 +335,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
       " Tensor b, Tensor a_scales,"
       " Tensor b_scales, Tensor azp_adj,"
-      " Tensor? azp, Tensor? bias) -> ()");
+      " Tensor? azp, Tensor? bias) -> ()",
+      {at::Tag::needs_fixed_stride_order});
   ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);

   // Check if cutlass scaled_mm is supported for CUDA devices of the given
@@ -351,7 +364,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
       " Tensor bt_nzs,"
       " Tensor bt_meta, Tensor a_scales,"
-      " Tensor b_scales, Tensor? bias) -> ()");
+      " Tensor b_scales, Tensor? bias) -> ()",
+      {at::Tag::needs_fixed_stride_order});
   ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);

   // CUTLASS sparse matrix compressor
@@ -407,7 +421,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
       "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
-      "-> Tensor");
+      "-> Tensor",
+      {at::Tag::needs_fixed_stride_order});
   ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);

   // Post processing for GPTQ.
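
Every hunk in this diff applies the same change: the op schema string passed to ops.def gains a tag list containing at::Tag::needs_fixed_stride_order, which tells torch.compile to pass the op inputs with the same stride order eager mode would produce, rather than whatever layout Inductor happens to pick. Below is a minimal self-contained sketch of that registration pattern. It is illustrative only: my_ext and my_gemm are hypothetical names, not part of this PR, and the at::matmul body stands in for the custom CUDA kernels these bindings actually dispatch to. It assumes a PyTorch build whose Library::def accepts a tag list, as the code above already does.

#include <ATen/ATen.h>
#include <torch/library.h>

// Placeholder implementation for illustration; the real ops in this file
// call hand-written CUDA kernels instead.
at::Tensor my_gemm(const at::Tensor& a, const at::Tensor& b) {
  return at::matmul(a, b);
}

TORCH_LIBRARY(my_ext, ops) {
  // Tagging the schema with needs_fixed_stride_order makes torch.compile
  // feed this op tensors with eager-mode stride order, which kernels that
  // assume a particular memory layout rely on.
  ops.def("my_gemm(Tensor a, Tensor b) -> Tensor",
          {at::Tag::needs_fixed_stride_order});
  ops.impl("my_gemm", torch::kCUDA, &my_gemm);
}

Once registered, the op is callable from Python as torch.ops.my_ext.my_gemm, and the tag is honored whenever the call site is compiled.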