File tree Expand file tree Collapse file tree 2 files changed +17
-0
lines changed
torch/csrc/jit/codegen/cuda/test Expand file tree Collapse file tree 2 files changed +17
-0
lines changed Original file line number Diff line number Diff line change @@ -2976,6 +2976,7 @@ TEST_F(NVFuserTest, FusionConv2D_CUDA) {
29762976TEST_F (NVFuserTest, FusionConv2DNoPadding_CUDA) {
29772977 Fusion fusion;
29782978 FusionGuard fg (&fusion);
2979+ ContextCudnnTF32Disabled disabling_tf32_cudnn;
29792980
29802981 // Input: [C, H, W]
29812982 auto inp = makeSymbolicTensor (3 );
Original file line number Diff line number Diff line change 88#include < torch/csrc/jit/codegen/cuda/lower_magic_zero.h>
99#include < torch/csrc/jit/codegen/cuda/transform_replay.h>
1010
11+ #include < ATen/Context.h>
1112#include < ATen/cuda/CUDAContext.h>
1213#include < c10/cuda/CUDACachingAllocator.h>
1314#include < torch/torch.h>
@@ -340,6 +341,21 @@ struct TransformPropagatorWithCheck : public TransformPropagator {
340341
341342} // namespace
342343
344+ class ContextCudnnTF32Disabled {
345+ public:
346+ ContextCudnnTF32Disabled () {
347+ flag_ = at::globalContext ().allowTF32CuDNN ();
348+ at::globalContext ().setAllowTF32CuDNN (false );
349+ }
350+
351+ ~ContextCudnnTF32Disabled () {
352+ at::globalContext ().setAllowTF32CuDNN (flag_);
353+ }
354+
355+ private:
356+ bool flag_;
357+ };
358+
343359// Fixture class must be uniquely identified, i.e., can't be in an
344360// anonymous namespace
345361class NVFuserTest : public ::testing::Test {
You can’t perform that action at this time.
0 commit comments