@@ -574,7 +574,7 @@ struct GemmFpAIntB {
     static constexpr bool compile_needed =
         platform::is_same<KernelArch, arch::Sm75>::value;
     KernelRunner<compile_needed>::run_kernel(params, shared_storage);
-#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
+#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 910)
     static constexpr bool compile_needed =
         platform::is_same<KernelArch, arch::Sm80>::value;
     KernelRunner<compile_needed>::run_kernel(params, shared_storage);
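The only change in this kernel is the upper bound of the architecture guard: `900` becomes `910`, so the device-code pass for SM 9.0 (Hopper) now falls into the branch that instantiates the `Sm80`-tagged kernel instead of dropping through to the unimplemented fallback. A minimal standalone sketch of the pattern (simplified stand-in names, not the PR's code):

```cpp
#include <type_traits>

struct Sm75 {};  // stand-ins for the cutlass::arch tags
struct Sm80 {};

template <bool CompileNeeded>
struct KernelRunnerSketch {
  __device__ static void run() { /* the real mainloop would live here */ }
};

template <>
struct KernelRunnerSketch<false> {
  __device__ static void run() { /* wrong arch tag for this pass: no-op */ }
};

template <typename KernelArch>
__global__ void gemm_sketch() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
  KernelRunnerSketch<std::is_same<KernelArch, Sm75>::value>::run();
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 910)
  // sm_80, sm_86 and, after this PR, sm_90 all compile the Sm80 body.
  KernelRunnerSketch<std::is_same<KernelArch, Sm80>::value>::run();
#endif
}
```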
@@ -1031,7 +1031,7 @@ struct GemmFpAIntBSplitK {
     gemm();
 #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
     gemm();
-#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
+#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 910)
     gemm();
 #else
     CUTLASS_NOT_IMPLEMENTED();
@@ -188,7 +188,8 @@ void dispatch_gemm_to_cutlass(const T* A,
   // fpA_intB. We also only instantiate configs here where threadblockShapeM ==
   // warpShapeM since those usually perform the best for mixed type gemms.
   switch (gemm_config.tile_config) {
-#if defined(USE_FPAINTB_GEMM_WITH_SM80) || defined(USE_FPAINTB_GEMM_WITH_SM86)
+#if defined(USE_FPAINTB_GEMM_WITH_SM80) || \
+    defined(USE_FPAINTB_GEMM_WITH_SM86) || defined(USE_FPAINTB_GEMM_WITH_SM90)
     case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64:
       dispatch_gemm_config<T,
                            WeightType,
@@ -259,7 +260,8 @@ void dispatch_gemm_to_cutlass(const T* A,
           stream,
           occupancy);
       break;
-#if defined(USE_FPAINTB_GEMM_WITH_SM80) || defined(USE_FPAINTB_GEMM_WITH_SM86)
+#if defined(USE_FPAINTB_GEMM_WITH_SM80) || \
+    defined(USE_FPAINTB_GEMM_WITH_SM86) || defined(USE_FPAINTB_GEMM_WITH_SM90)
     case CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64:
       dispatch_gemm_config<T,
                            WeightType,
@@ -519,8 +521,8 @@ void CutlassFpAIntBGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag,
         "[CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for "
         "CUTLASS mixed type GEMM");
 #endif
-  } else if (sm_ >= 80 && sm_ < 90) {
-#if defined(USE_FPAINTB_GEMM_WITH_SM80)
+  } else if (sm_ >= 80 && sm_ < 91) {
+#if defined(USE_FPAINTB_GEMM_WITH_SM80) || defined(USE_FPAINTB_GEMM_WITH_SM90)
     dispatch_gemm_to_cutlass<T,
                              WeightType,
                              cutlass::arch::Sm80,
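Host-side dispatch mirrors the device-side guard: `sm_ < 90` becomes `sm_ < 91`, so a detected compute capability of 9.0 is routed to the `cutlass::arch::Sm80` instantiations. A hedged sketch of the bucketing (the helper names are illustrative, not the runner's real API):

```cpp
#include <cuda_runtime.h>
#include <stdexcept>
#include <string>

enum class ArchTag { Sm70, Sm75, Sm80 };

// Maps a detected SM version onto the CUTLASS arch tag used for dispatch.
// After this change, 90 lands in the Sm80 bucket instead of being rejected.
inline ArchTag bucket_for_sm(int sm) {
  if (sm >= 70 && sm < 75) return ArchTag::Sm70;
  if (sm >= 75 && sm < 80) return ArchTag::Sm75;
  if (sm >= 80 && sm < 91) return ArchTag::Sm80;  // was `sm < 90` before the PR
  throw std::runtime_error("Arch unsupported for CUTLASS mixed type GEMM: " +
                           std::to_string(sm));
}

// e.g. an H100 reports major=9, minor=0, giving sm == 90.
inline int query_sm(int device) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device);
  return prop.major * 10 + prop.minor;
}
```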
@@ -450,7 +450,7 @@ struct dispatch_stages<T,
   }
 };

-#if defined(USE_FPAINTB_GEMM_WITH_SM80)
+#if defined(USE_FPAINTB_GEMM_WITH_SM80) || defined(USE_FPAINTB_GEMM_WITH_SM90)
 template <typename T,
           typename WeightType,
           typename EpilogueTag,
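`dispatch_stages` uses a common CUTLASS idiom: the primary template is the failure path, and guarded specializations carry the implementations. Widening this `#if` makes the multi-stage `Sm80` specialization exist in builds that only define `USE_FPAINTB_GEMM_WITH_SM90`. A simplified sketch of the idiom (stand-in types, far fewer template parameters than the real code):

```cpp
#include <stdexcept>
#include <string>
#include <type_traits>

// Primary template: reaching it means no matching specialization was built.
template <typename Arch, int Stages, typename Enable = void>
struct DispatchStagesSketch {
  static void dispatch() {
    throw std::runtime_error("not instantiated for arch " +
                             std::to_string(Arch::kMinComputeCapability) +
                             " with stages " + std::to_string(Stages));
  }
};

struct Sm80Tag { static constexpr int kMinComputeCapability = 80; };

#if defined(USE_FPAINTB_GEMM_WITH_SM80) || defined(USE_FPAINTB_GEMM_WITH_SM90)
// Multi-stage (cp.async) pipelines exist only for the Sm80 tag, and only in
// builds that generated kernels for an SM bucket that can run them.
template <int Stages>
struct DispatchStagesSketch<Sm80Tag, Stages, std::enable_if_t<(Stages > 2)>> {
  static void dispatch() { /* launch the Stages-deep pipelined kernel */ }
};
#endif
```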
@@ -81,7 +81,7 @@

 """

-DefaultArch = [70, 75, 80]
+DefaultArch = [70, 75, 80, 90]
 epilogue_tags = ["bias", "biasFtGelu", "biasReLU", "noBias"]

 WeightTypes = ["uint8_t", "cutlass::uint4b_t"]
@@ -110,13 +110,14 @@
     "cutlass::gemm::GemmShape<32, 32, 64>",
     "cutlass::gemm::GemmShape<64, 64, 64>",
 ]
-StagesList = {70: [2], 75: [2], 80: [2, 3, 4, 5]}
+StagesList = {70: [2], 75: [2], 80: [2, 3, 4, 5], 90: [2, 3, 4, 5]}

 ElementTypes = {"fp16": "half", "bf16": "__nv_bfloat16"}
 Archs = {
     70: "cutlass::arch::Sm70",
     75: "cutlass::arch::Sm75",
     80: "cutlass::arch::Sm80",
+    90: "cutlass::arch::Sm80",
 }
 EpilogueTags = {
     "bias": "EpilogueOpBias",
@@ -152,6 +153,8 @@ def find_arch_range(archs):
             compile_archs.append(75)
         elif arch >= 80 and arch < 90:
             compile_archs.append(80)
+        elif arch >= 90 and arch < 91:
+            compile_archs.append(90)
     compile_archs = list(set(compile_archs))
     compile_archs.sort()
     return compile_archs
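Two details worth noting in the generator: `Archs[90]` deliberately maps to `"cutlass::arch::Sm80"`, so the new 90 bucket emits Sm80-tagged instantiations (relying on the widened `__CUDA_ARCH__` guards to compile them in the sm_90 pass), and `find_arch_range` gains a `[90, 91)` bucket. The real logic is the Python above; this C++ restatement of the bucketing rule is only for illustration:

```cpp
#include <set>
#include <vector>

// Same bucketing rule as find_arch_range(): collapse the requested CUDA
// archs into the set of generated kernel buckets, deduplicated and sorted.
std::vector<int> find_arch_range_sketch(const std::vector<int>& archs) {
  std::set<int> buckets;
  for (int arch : archs) {
    if (arch >= 70 && arch < 75) buckets.insert(70);
    else if (arch >= 75 && arch < 80) buckets.insert(75);
    else if (arch >= 80 && arch < 90) buckets.insert(80);
    else if (arch >= 90 && arch < 91) buckets.insert(90);  // new Hopper bucket
  }
  return {buckets.begin(), buckets.end()};
}
```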
@@ -592,6 +592,10 @@ struct MoeFCGemm {
     static constexpr bool compile_needed =
         platform::is_same<KernelArch, arch::Sm80>::value;
     KernelRunner<compile_needed>::run_kernel(params, shared_storage);
+#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDA_ARCH__ < 910)
+    static constexpr bool compile_needed =
+        platform::is_same<KernelArch, arch::Sm80>::value;
+    KernelRunner<compile_needed>::run_kernel(params, shared_storage);
 #else
     CUTLASS_NOT_IMPLEMENTED();
 #endif
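Here, instead of widening the existing branch, an explicit `[900, 910)` branch is added that still routes to the `Sm80`-tagged kernel. Either way the point is the same: nvcc compiles the kernel body once per requested SM target, and on Hopper the sm_90 image is the one that runs, so that pass needs a real body rather than `CUTLASS_NOT_IMPLEMENTED()`. A small illustration (the build flags in the comment are illustrative, not taken from this PR):

```cpp
// Built with, e.g.:
//   nvcc -gencode arch=compute_80,code=sm_80 \
//        -gencode arch=compute_90,code=sm_90 which_pass.cu
#include <cstdio>
#include <cuda_runtime.h>

__global__ void which_pass() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDA_ARCH__ < 910)
  printf("running the body compiled by the sm_90 pass\n");
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  printf("running the body compiled by an sm_8x pass\n");
#endif
}

int main() {
  which_pass<<<1, 1>>>();  // the driver picks the image matching the GPU
  cudaDeviceSynchronize();
  return 0;
}
```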
@@ -71,7 +71,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
     cudaStream_t stream,
     int* kernel_occupancy = nullptr) {
   if (gemm_config.split_k_style != SplitKStyle::NO_SPLIT_K) {
-    throw std::runtime_error("[MoeGemm] Grouped gemm does not support split-k");
+    PADDLE_FATAL("[MoeGemm] Grouped gemm does not support split-k");
   }

 #ifdef PADDLE_CUDA_BF16
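The second theme of the diff starts here: host-side failures switch from `throw std::runtime_error(...)` to `PADDLE_FATAL(...)`, keeping error reporting on Paddle's own enforce machinery. The stand-in macro below only mimics the call shape used in this diff (it accepts either a string literal or a `std::string`); it is not Paddle's actual definition:

```cpp
#include <cstdio>
#include <cstdlib>
#include <string>

// Hypothetical stand-in for PADDLE_FATAL: print the message and terminate.
#define FATAL_SKETCH(msg)                                     \
  do {                                                        \
    std::fprintf(stderr, "%s\n", std::string(msg).c_str());   \
    std::abort();                                             \
  } while (0)

void check_split_k_sketch(bool uses_split_k) {
  if (uses_split_k) {
    // before the PR: throw std::runtime_error("[MoeGemm] ...");
    FATAL_SKETCH("[MoeGemm] Grouped gemm does not support split-k");
  }
}
```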
@@ -171,7 +171,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
   int occupancy = std::min(2, GemmGrouped::maximum_active_blocks());

   if (occupancy == 0) {
-    throw std::runtime_error(
+    PADDLE_FATAL(
         "[MoE Runner] GPU lacks the shared memory resources to run "
         "GroupedGEMM kernel");
   }
@@ -199,23 +199,23 @@ void generic_moe_gemm_kernelLauncher(const T* A,
   if (can_implement != cutlass::Status::kSuccess) {
     std::string err_msg = "MoEFC kernel will fail for params. Error: " +
                           std::string(cutlassGetStatusString(can_implement));
-    throw std::runtime_error("[MoE Runner] " + err_msg);
+    PADDLE_FATAL("[MoE Runner] " + err_msg);
   }

   auto init_status = gemm.initialize(args);
   if (init_status != cutlass::Status::kSuccess) {
     std::string err_msg =
         "Failed to initialize cutlass variable batched gemm. Error: " +
         std::string(cutlassGetStatusString(init_status));
-    throw std::runtime_error("[MoE Runner] " + err_msg);
+    PADDLE_FATAL("[MoE Runner] " + err_msg);
   }

   auto run_status = gemm.run(stream);
   if (run_status != cutlass::Status::kSuccess) {
     std::string err_msg =
         "Failed to run cutlass variable batched gemm. Error: " +
         std::string(cutlassGetStatusString(run_status));
-    throw std::runtime_error("[MoE Runner] " + err_msg);
+    PADDLE_FATAL("[MoE Runner] " + err_msg);
   }
 }
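The hunk above is CUTLASS's standard three-checkpoint host flow, `can_implement`, `initialize`, then `run`, each returning a `cutlass::Status`; the PR changes only how a bad status is reported. A condensed sketch of that flow (assuming a configured device-level `Gemm` type; `report_fatal` stands in for `PADDLE_FATAL`):

```cpp
#include <cuda_runtime.h>
#include <string>
#include "cutlass/cutlass.h"

template <typename Gemm>
void run_checked(typename Gemm::Arguments const& args,
                 cudaStream_t stream,
                 void (*report_fatal)(std::string const&)) {
  Gemm gemm;
  // 1. Can this kernel handle these problem sizes / alignments at all?
  cutlass::Status status = Gemm::can_implement(args);
  if (status != cutlass::Status::kSuccess) {
    report_fatal(std::string("can_implement failed: ") +
                 cutlass::cutlassGetStatusString(status));
  }
  // 2. Bind arguments and set up internal state.
  if ((status = gemm.initialize(args)) != cutlass::Status::kSuccess) {
    report_fatal(std::string("initialize failed: ") +
                 cutlass::cutlassGetStatusString(status));
  }
  // 3. Launch on the given stream.
  if ((status = gemm.run(stream)) != cutlass::Status::kSuccess) {
    report_fatal(std::string("run failed: ") +
                 cutlass::cutlassGetStatusString(status));
  }
}
```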

@@ -245,7 +245,7 @@ struct dispatch_stages {
     std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " +
                           std::to_string(arch::kMinComputeCapability) +
                           " with stages set to " + std::to_string(Stages);
-    throw std::runtime_error("[dispatch_stages::dispatch] " + err_msg);
+    PADDLE_FATAL("[dispatch_stages::dispatch] " + err_msg);
   }
 };

@@ -457,7 +457,7 @@ void dispatch_gemm_config(const T* A,
     default:
       std::string err_msg = "dispatch_gemm_config does not support stages " +
                             std::to_string(gemm_config.stages);
-      throw std::runtime_error("[MoE][dispatch_gemm_config] " + err_msg);
+      PADDLE_FATAL("[MoE][dispatch_gemm_config] " + err_msg);
       break;
   }
 }
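`dispatch_gemm_config`'s switch is the hop from the runtime `gemm_config.stages` value to the compile-time `Stages` template parameter, with the `default:` branch rejecting stage counts that were never instantiated. A minimal sketch of the pattern (simplified names; the stage set matches `StagesList[80]`/`[90]` in the generator):

```cpp
#include <stdexcept>
#include <string>

template <int Stages>
void launch_with_stages_sketch() { /* dispatch_stages<..., Stages> here */ }

// Runtime-to-compile-time dispatch: each case pins a template instantiation.
void dispatch_on_stages_sketch(int stages) {
  switch (stages) {
    case 2: launch_with_stages_sketch<2>(); break;
    case 3: launch_with_stages_sketch<3>(); break;
    case 4: launch_with_stages_sketch<4>(); break;
    case 5: launch_with_stages_sketch<5>(); break;
    default:
      throw std::runtime_error(
          "dispatch_gemm_config does not support stages " +
          std::to_string(stages));
  }
}
```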
@@ -572,16 +572,15 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
           occupancy);
       break;
     case CutlassTileConfig::Undefined:
-      throw std::runtime_error(
-          "[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
+      PADDLE_FATAL("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
       break;
     case CutlassTileConfig::ChooseWithHeuristic:
-      throw std::runtime_error(
+      PADDLE_FATAL(
           "[dispatch_moe_gemm_to_cutlass] gemm config should have "
           "already been set by heuristic.");
       break;
     default:
-      throw std::runtime_error(
+      PADDLE_FATAL(
           "[dispatch_moe_gemm_to_cutlass] Config is invalid for same "
           "type MoE tensorop GEMM.");
       break;
@@ -762,16 +761,15 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
           occupancy);
       break;
     case CutlassTileConfig::Undefined:
-      throw std::runtime_error(
-          "[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
+      PADDLE_FATAL("[dispatch_moe_gemm_to_cutlass] gemm config undefined.");
       break;
     case CutlassTileConfig::ChooseWithHeuristic:
-      throw std::runtime_error(
+      PADDLE_FATAL(
           "[dispatch_moe_gemm_to_cutlass] gemm config should have "
           "already been set by heuristic.");
       break;
     default:
-      throw std::runtime_error(
+      PADDLE_FATAL(
           "[dispatch_moe_gemm_to_cutlass] Config is invalid for "
           "mixed type tensorop GEMM.");
       break;
@@ -823,17 +821,17 @@ void dispatch_moe_gemm_to_cutlass(const T* A,
           occupancy);
       break;
     case CutlassTileConfig::Undefined:
-      throw std::runtime_error(
+      PADDLE_FATAL(
           "[dispatch_moe_gemm_to_cutlass][SIMT] gemm config "
           "undefined.");
       break;
     case CutlassTileConfig::ChooseWithHeuristic:
-      throw std::runtime_error(
+      PADDLE_FATAL(
           "[dispatch_moe_gemm_to_cutlass][SIMT] gemm config should "
           "have already been set by heuristic.");
       break;
     default:
-      throw std::runtime_error(
+      PADDLE_FATAL(
           "[dispatch_moe_gemm_to_cutlass][SIMT] Unsupported config "
           "for float MoE gemm.");
       break;
@@ -903,7 +901,7 @@ void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
         multi_processor_count_,
         stream,
         occupancy);
-  } else if (sm_ >= 80 && sm_ < 90) {
+  } else if (sm_ >= 80 && sm_ < 91) {
     dispatch_moe_gemm_to_cutlass<T,
                                  WeightType,
                                  cutlass::arch::Sm80,
@@ -923,8 +921,7 @@ void MoeGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
         stream,
         occupancy);
   } else {
-    throw std::runtime_error(
-        "[MoE][GEMM Dispatch] Arch unsupported for MoE GEMM");
+    PADDLE_FATAL("[MoE][GEMM Dispatch] Arch unsupported for MoE GEMM");
   }
 }
