Skip to content

Commit 1f629f2

Browse files
authored
[Bug Fixes] fix batch_norm default stream bug && apply igemm sp_conv pass to trt (#67443)
1 parent bca78dd commit 1f629f2

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed

paddle/fluid/inference/api/paddle_pass_builder.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,11 @@ const std::vector<std::string> kTRTSubgraphPasses({
108108
"trt_multihead_matmul_fuse_pass_v2", //
109109
"trt_multihead_matmul_fuse_pass_v3", //
110110
"multihead_matmul_roformer_fuse_pass", //
111-
"constant_folding_pass", //
111+
#if defined _WIN32 // Windows does not support sparse_conv3d_implicit_gemm
112+
#else
113+
"sparse_conv_optim_pass", //
114+
#endif
115+
"constant_folding_pass", //
112116
#ifdef PADDLE_WITH_TENSORRT
113117
#if !IS_TRT_VERSION_GE(8610)
114118
"trt_flash_multihead_matmul_fuse_pass", //

paddle/phi/kernels/gpu/batch_norm_kernel.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,7 @@ void BatchNormKernel(const Context &ctx,
825825
auto *inv_var_ptr = inv_var.data<BatchNormParamType<T>>();
826826
const int threads = 512 > C ? C : 512;
827827
const int blocks = (C + 511) / 512;
828-
InverseVariance<T><<<blocks, threads>>>(
828+
InverseVariance<T><<<blocks, threads, 0, ctx.stream()>>>(
829829
est_var->template data<BatchNormParamType<T>>(),
830830
epsilon,
831831
C,

paddle/phi/kernels/sparse/gpu/conv_kernel_igemm.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ void Conv3dImplicitGemmKernel(const Context& dev_ctx,
186186
SparseCooTensor* out) {
187187
#ifdef PADDLE_WITH_CUDA
188188
PD_VISIT_BASE_INTEGRAL_TYPES(
189-
x.indices().dtype(), "Conv3dImplicitGemmGPUKernel", ([&] {
189+
x.indices().dtype(), "Conv3dImplicitGemmGPUKernel's indices", ([&] {
190190
// Conv3dImplicitGemmGPUKernel<T, data_t>(dev_ctx,
191191
Conv3dImplicitGemmGPUKernel<T, int64_t>(dev_ctx,
192192
x,

0 commit comments

Comments
 (0)