Skip to content

Commit 7117a7e

Browse files
authored
patching nvfuser conv cudnn test numerics mismatch (#2048)
The tests fail upstream; the failure has not yet shown up in our devel branch. This change disables cuDNN TF32 in the test, because TF32 causes numerical mismatches when validating outputs.
1 parent 65af1a4 commit 7117a7e

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2976,6 +2976,7 @@ TEST_F(NVFuserTest, FusionConv2D_CUDA) {
29762976
TEST_F(NVFuserTest, FusionConv2DNoPadding_CUDA) {
29772977
Fusion fusion;
29782978
FusionGuard fg(&fusion);
2979+
ContextCudnnTF32Disabled disabling_tf32_cudnn;
29792980

29802981
// Input: [C, H, W]
29812982
auto inp = makeSymbolicTensor(3);

torch/csrc/jit/codegen/cuda/test/test_utils.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <torch/csrc/jit/codegen/cuda/lower_magic_zero.h>
99
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>
1010

11+
#include <ATen/Context.h>
1112
#include <ATen/cuda/CUDAContext.h>
1213
#include <c10/cuda/CUDACachingAllocator.h>
1314
#include <torch/torch.h>
@@ -340,6 +341,21 @@ struct TransformPropagatorWithCheck : public TransformPropagator {
340341

341342
} // namespace
342343

344+
class ContextCudnnTF32Disabled {
345+
public:
346+
ContextCudnnTF32Disabled() {
347+
flag_ = at::globalContext().allowTF32CuDNN();
348+
at::globalContext().setAllowTF32CuDNN(false);
349+
}
350+
351+
~ContextCudnnTF32Disabled() {
352+
at::globalContext().setAllowTF32CuDNN(flag_);
353+
}
354+
355+
private:
356+
bool flag_;
357+
};
358+
343359
// Fixture class must be uniquely identified, i.e., can't be in an
344360
// anonymous namespace
345361
class NVFuserTest : public ::testing::Test {

0 commit comments

Comments
 (0)