Skip to content

Pipeline failed while executing TritonGPUHoistTMEMAlloc #7837

@yyinter

Description

@yyinter

Describe the bug

UNREACHABLE executed at lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp:418!
module {
  tt.func public @bwd_kv_kernel(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg9: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg10: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg11: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: f32) attributes {noinline = false} {
    %c64_i64 = arith.constant 64 : i64
    %cst = arith.constant dense<0.000000e+00> : tensor<64xf32>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x32xbf16>
    %cst_1 = arith.constant dense<0> : tensor<64x1xi64>
    %cst_2 = arith.constant dense<0> : tensor<64xi32>
    %cst_3 = arith.constant dense<0> : tensor<64xi64>
    %cst_4 = arith.constant dense<0.000000e+00> : tensor<32x64xbf16>
    %cst_5 = arith.constant dense<0> : tensor<1x64xi64>
    %c0_i64 = arith.constant 0 : i64
    %cst_6 = arith.constant dense<0> : tensor<64x64xi32>
    %cst_7 = arith.constant dense<1> : tensor<1x64xi32>
    %cst_8 = arith.constant dense<0.000000e+00> : tensor<64x32xf32>
    %cst_9 = arith.constant dense<1.44269502> : tensor<64x1xf32>
    %cst_10 = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
    %c32_i64 = arith.constant 32 : i64
    %c32_i32 = arith.constant 32 : i32
    %cst_11 = arith.constant 1.44269502 : f32
    %c0_i32 = arith.constant 0 : i32
    %c64_i32 = arith.constant 64 : i32
    %c1_i32 = arith.constant 1 : i32
    %0 = tt.get_program_id x : i32
    %1 = tt.get_program_id y : i32
    %2 = tt.get_program_id z : i32
    %3 = arith.divsi %arg12, %arg13 : i32
    %4 = arith.divsi %1, %3 : i32
    %5 = tt.addptr %arg11, %2 : !tt.ptr<i32>, i32
    %6 = tt.load %5 : !tt.ptr<i32>
    %7 = tt.addptr %5, %c1_i32 : !tt.ptr<i32>, i32
    %8 = tt.load %7 : !tt.ptr<i32>
    %9 = arith.subi %8, %6 : i32
    %10 = arith.muli %0, %c64_i32 : i32
    %11 = arith.cmpi sge, %10, %9 : i32
    cf.cond_br %11, ^bb1, ^bb2
  ^bb1:  // pred: ^bb0
    tt.return
  ^bb2:  // pred: ^bb0
    %12 = tt.addptr %arg10, %2 : !tt.ptr<i32>, i32
    %13 = tt.load %12 : !tt.ptr<i32>
    %14 = tt.addptr %12, %c1_i32 : !tt.ptr<i32>, i32
    %15 = tt.load %14 : !tt.ptr<i32>
    %16 = arith.subi %15, %13 : i32
    %17 = arith.mulf %arg14, %cst_11 : f32
    %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %19 = tt.splat %10 : i32 -> tensor<64xi32>
    %20 = arith.addi %19, %18 : tensor<64xi32>
    %21 = arith.extsi %13 : i32 to i64
    %22 = arith.extsi %6 : i32 to i64
    %23 = arith.extsi %arg12 : i32 to i64
    %24 = arith.muli %21, %23 : i64
    %25 = arith.muli %24, %c32_i64 : i64
    %26 = tt.addptr %arg0, %25 : !tt.ptr<bf16>, i64
    %27 = arith.muli %1, %c32_i32 : i32
    %28 = tt.addptr %26, %27 : !tt.ptr<bf16>, i32
    %29 = arith.muli %arg12, %c32_i32 : i32
    %30 = arith.extsi %16 : i32 to i64
    %31 = arith.extsi %29 : i32 to i64
    %32 = arith.extsi %arg13 : i32 to i64
    %33 = arith.muli %22, %32 : i64
    %34 = arith.muli %33, %c32_i64 : i64
    %35 = tt.addptr %arg1, %34 : !tt.ptr<bf16>, i64
    %36 = arith.muli %4, %c32_i32 : i32
    %37 = tt.addptr %35, %36 : !tt.ptr<bf16>, i32
    %38 = arith.muli %arg13, %c32_i32 : i32
    %39 = arith.extsi %9 : i32 to i64
    %40 = arith.extsi %38 : i32 to i64
    %41 = arith.extsi %10 : i32 to i64
    %42 = tt.addptr %arg2, %34 : !tt.ptr<bf16>, i64
    %43 = tt.addptr %42, %36 : !tt.ptr<bf16>, i32
    %44 = arith.muli %22, %23 : i64
    %45 = arith.muli %44, %c32_i64 : i64
    %46 = tt.addptr %arg3, %45 : !tt.ptr<bf16>, i64
    %47 = tt.addptr %46, %27 : !tt.ptr<bf16>, i32
    %48 = tt.addptr %arg4, %45 : !tt.ptr<bf16>, i64
    %49 = tt.addptr %48, %27 : !tt.ptr<bf16>, i32
    %50 = tt.addptr %arg5, %25 : !tt.ptr<bf16>, i64
    %51 = tt.addptr %50, %27 : !tt.ptr<bf16>, i32
    %52 = tt.addptr %arg6, %24 : !tt.ptr<f32>, i64
    %53 = tt.addptr %52, %1 : !tt.ptr<f32>, i32
    %54 = tt.addptr %arg7, %24 : !tt.ptr<f32>, i64
    %55 = tt.addptr %54, %1 : !tt.ptr<f32>, i32
    %56 = tt.addptr %arg8, %21 : !tt.ptr<i32>, i64
    %57 = tt.addptr %arg9, %22 : !tt.ptr<i32>, i64
    %58 = tt.splat %37 : !tt.ptr<bf16> -> tensor<32x64x!tt.ptr<bf16>>
    %59 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
    %60 = arith.extsi %59 : tensor<32xi32> to tensor<32xi64>
    %61 = tt.expand_dims %60 {axis = 1 : i32} : tensor<32xi64> -> tensor<32x1xi64>
    %62 = tt.broadcast %61 : tensor<32x1xi64> -> tensor<32x64xi64>
    %63 = tt.splat %41 : i64 -> tensor<64xi64>
    %64 = arith.extsi %18 : tensor<64xi32> to tensor<64xi64>
    %65 = arith.addi %63, %64 : tensor<64xi64>
    %66 = tt.expand_dims %65 {axis = 0 : i32} : tensor<64xi64> -> tensor<1x64xi64>
    %67 = tt.splat %40 : i64 -> tensor<1x64xi64>
    %68 = arith.muli %66, %67 : tensor<1x64xi64>
    %69 = tt.broadcast %68 : tensor<1x64xi64> -> tensor<32x64xi64>
    %70 = arith.addi %62, %69 : tensor<32x64xi64>
    %71 = tt.addptr %58, %70 : tensor<32x64x!tt.ptr<bf16>>, tensor<32x64xi64>
    %72 = arith.cmpi sge, %66, %cst_5 : tensor<1x64xi64>
    %73 = tt.splat %39 : i64 -> tensor<1x64xi64>
    %74 = arith.cmpi slt, %66, %73 : tensor<1x64xi64>
    %75 = arith.andi %72, %74 : tensor<1x64xi1>
    %76 = tt.broadcast %75 : tensor<1x64xi1> -> tensor<32x64xi1>
    %77 = tt.load %71, %76, %cst_4 : tensor<32x64x!tt.ptr<bf16>>
    %78 = tt.splat %43 : !tt.ptr<bf16> -> tensor<32x64x!tt.ptr<bf16>>
    %79 = tt.addptr %78, %70 : tensor<32x64x!tt.ptr<bf16>>, tensor<32x64xi64>
    %80 = tt.load %79, %76, %cst_4 : tensor<32x64x!tt.ptr<bf16>>
    %81 = tt.splat %57 : !tt.ptr<i32> -> tensor<64x!tt.ptr<i32>>
    %82 = tt.addptr %81, %65 : tensor<64x!tt.ptr<i32>>, tensor<64xi64>
    %83 = arith.cmpi sge, %65, %cst_3 : tensor<64xi64>
    %84 = tt.splat %39 : i64 -> tensor<64xi64>
    %85 = arith.cmpi slt, %65, %84 : tensor<64xi64>
    %86 = arith.andi %83, %85 : tensor<64xi1>
    %87 = tt.load %82, %86, %cst_2 : tensor<64x!tt.ptr<i32>>
    %88:7 = scf.for %arg15 = %c0_i32 to %16 step %c64_i32 iter_args(%arg16 = %c0_i64, %arg17 = %c0_i64, %arg18 = %c0_i64, %arg19 = %c0_i64, %arg20 = %c0_i64, %arg21 = %cst_8, %arg22 = %cst_8) -> (i64, i64, i64, i64, i64, tensor<64x32xf32>, tensor<64x32xf32>)  : i32 {
      %109 = tt.splat %56 : !tt.ptr<i32> -> tensor<64x!tt.ptr<i32>>
      %110 = tt.splat %arg20 : i64 -> tensor<64xi64>
      %111 = arith.addi %110, %64 : tensor<64xi64>
      %112 = tt.addptr %109, %111 : tensor<64x!tt.ptr<i32>>, tensor<64xi64>
      %113 = arith.cmpi sge, %111, %cst_3 : tensor<64xi64>
      %114 = tt.splat %30 : i64 -> tensor<64xi64>
      %115 = arith.cmpi slt, %111, %114 : tensor<64xi64>
      %116 = arith.andi %113, %115 : tensor<64xi1>
      %117 = tt.load %112, %116, %cst_2 : tensor<64x!tt.ptr<i32>>
      %118 = tt.expand_dims %117 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
      %119 = tt.expand_dims %20 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
      %120 = tt.broadcast %118 : tensor<64x1xi32> -> tensor<64x64xi32>
      %121 = tt.broadcast %119 : tensor<1x64xi32> -> tensor<64x64xi32>
      %122 = arith.subi %120, %121 : tensor<64x64xi32>
      %123 = arith.cmpi sgt, %122, %cst_6 : tensor<64x64xi32>
      %124 = tt.expand_dims %87 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
      %125 = arith.cmpi eq, %124, %cst_7 : tensor<1x64xi32>
      %126 = tt.broadcast %125 : tensor<1x64xi1> -> tensor<64x64xi1>
      %127 = arith.andi %123, %126 : tensor<64x64xi1>
      %128 = arith.extui %127 : tensor<64x64xi1> to tensor<64x64xi32>
      %129 = tt.reshape %128 allow_reorder : tensor<64x64xi32> -> tensor<4096xi32>
      %130 = "tt.reduce"(%129) <{axis = 0 : i32}> ({
      ^bb0(%arg23: i32, %arg24: i32):
        %138 = arith.addi %arg23, %arg24 : i32
        tt.reduce.return %138 : i32
      }) : (tensor<4096xi32>) -> i32
      %131 = arith.cmpi ne, %130, %c0_i32 : i32
      %132:2 = scf.if %131 -> (tensor<64x32xf32>, tensor<64x32xf32>) {
        %138 = tt.splat %28 : !tt.ptr<bf16> -> tensor<64x32x!tt.ptr<bf16>>
        %139 = tt.splat %arg16 : i64 -> tensor<64xi64>
        %140 = arith.addi %139, %64 : tensor<64xi64>
        %141 = tt.expand_dims %140 {axis = 1 : i32} : tensor<64xi64> -> tensor<64x1xi64>
        %142 = tt.splat %31 : i64 -> tensor<64x1xi64>
        %143 = arith.muli %141, %142 : tensor<64x1xi64>
        %144 = tt.broadcast %143 : tensor<64x1xi64> -> tensor<64x32xi64>
        %145 = tt.expand_dims %60 {axis = 0 : i32} : tensor<32xi64> -> tensor<1x32xi64>
        %146 = tt.broadcast %145 : tensor<1x32xi64> -> tensor<64x32xi64>
        %147 = arith.addi %144, %146 : tensor<64x32xi64>
        %148 = tt.addptr %138, %147 : tensor<64x32x!tt.ptr<bf16>>, tensor<64x32xi64>
        %149 = arith.cmpi sge, %141, %cst_1 : tensor<64x1xi64>
        %150 = tt.splat %30 : i64 -> tensor<64x1xi64>
        %151 = arith.cmpi slt, %141, %150 : tensor<64x1xi64>
        %152 = arith.andi %149, %151 : tensor<64x1xi1>
        %153 = tt.broadcast %152 : tensor<64x1xi1> -> tensor<64x32xi1>
        %154 = tt.load %148, %153, %cst_0 : tensor<64x32x!tt.ptr<bf16>>
        %155 = tt.dot %154, %77, %cst_10 : tensor<64x32xbf16> * tensor<32x64xbf16> -> tensor<64x64xf32>
        %156 = tt.splat %53 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
        %157 = tt.splat %arg18 : i64 -> tensor<64xi64>
        %158 = arith.addi %157, %64 : tensor<64xi64>
        %159 = tt.splat %23 : i64 -> tensor<64xi64>
        %160 = arith.muli %158, %159 : tensor<64xi64>
        %161 = tt.addptr %156, %160 : tensor<64x!tt.ptr<f32>>, tensor<64xi64>
        %162 = arith.cmpi sge, %158, %cst_3 : tensor<64xi64>
        %163 = arith.cmpi slt, %158, %114 : tensor<64xi64>
        %164 = arith.andi %162, %163 : tensor<64xi1>
        %165 = tt.load %161, %164, %cst : tensor<64x!tt.ptr<f32>>
        %166 = tt.splat %17 : f32 -> tensor<64x64xf32>
        %167 = arith.mulf %155, %166 : tensor<64x64xf32>
        %168 = tt.expand_dims %165 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32>
        %169 = arith.mulf %168, %cst_9 : tensor<64x1xf32>
        %170 = tt.broadcast %169 : tensor<64x1xf32> -> tensor<64x64xf32>
        %171 = arith.subf %167, %170 : tensor<64x64xf32>
        %172 = math.exp2 %171 : tensor<64x64xf32>
        %173 = arith.select %127, %172, %cst_10 : tensor<64x64xi1>, tensor<64x64xf32>
        %174 = tt.splat %51 : !tt.ptr<bf16> -> tensor<64x32x!tt.ptr<bf16>>
        %175 = tt.splat %arg17 : i64 -> tensor<64xi64>
        %176 = arith.addi %175, %64 : tensor<64xi64>
        %177 = tt.expand_dims %176 {axis = 1 : i32} : tensor<64xi64> -> tensor<64x1xi64>
        %178 = arith.muli %177, %142 : tensor<64x1xi64>
        %179 = tt.broadcast %178 : tensor<64x1xi64> -> tensor<64x32xi64>
        %180 = arith.addi %179, %146 : tensor<64x32xi64>
        %181 = tt.addptr %174, %180 : tensor<64x32x!tt.ptr<bf16>>, tensor<64x32xi64>
        %182 = arith.cmpi sge, %177, %cst_1 : tensor<64x1xi64>
        %183 = arith.cmpi slt, %177, %150 : tensor<64x1xi64>
        %184 = arith.andi %182, %183 : tensor<64x1xi1>
        %185 = tt.broadcast %184 : tensor<64x1xi1> -> tensor<64x32xi1>
        %186 = tt.load %181, %185, %cst_0 : tensor<64x32x!tt.ptr<bf16>>
        %187 = tt.trans %173 {order = array<i32: 1, 0>} : tensor<64x64xf32> -> tensor<64x64xf32>
        %188 = arith.truncf %187 : tensor<64x64xf32> to tensor<64x64xbf16>
        %189 = tt.dot %188, %186, %arg22 : tensor<64x64xbf16> * tensor<64x32xbf16> -> tensor<64x32xf32>
        %190 = tt.splat %55 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
        %191 = tt.splat %arg19 : i64 -> tensor<64xi64>
        %192 = arith.addi %191, %64 : tensor<64xi64>
        %193 = arith.muli %192, %159 : tensor<64xi64>
        %194 = tt.addptr %190, %193 : tensor<64x!tt.ptr<f32>>, tensor<64xi64>
        %195 = arith.cmpi sge, %192, %cst_3 : tensor<64xi64>
        %196 = arith.cmpi slt, %192, %114 : tensor<64xi64>
        %197 = arith.andi %195, %196 : tensor<64xi1>
        %198 = tt.load %194, %197, %cst : tensor<64x!tt.ptr<f32>>
        %199 = tt.dot %186, %80, %cst_10 : tensor<64x32xbf16> * tensor<32x64xbf16> -> tensor<64x64xf32>
        %200 = tt.expand_dims %198 {axis = 1 : i32} : tensor<64xf32> -> tensor<64x1xf32>
        %201 = tt.broadcast %200 : tensor<64x1xf32> -> tensor<64x64xf32>
        %202 = arith.subf %199, %201 : tensor<64x64xf32>
        %203 = arith.mulf %173, %202 : tensor<64x64xf32>
        %204 = arith.select %127, %203, %cst_10 : tensor<64x64xi1>, tensor<64x64xf32>
        %205 = tt.trans %204 {order = array<i32: 1, 0>} : tensor<64x64xf32> -> tensor<64x64xf32>
        %206 = arith.truncf %205 : tensor<64x64xf32> to tensor<64x64xbf16>
        %207 = tt.dot %206, %154, %arg21 : tensor<64x64xbf16> * tensor<64x32xbf16> -> tensor<64x32xf32>
        scf.yield %207, %189 : tensor<64x32xf32>, tensor<64x32xf32>
      } else {
        scf.yield %arg21, %arg22 : tensor<64x32xf32>, tensor<64x32xf32>
      }
      %133 = arith.addi %arg16, %c64_i64 : i64
      %134 = arith.addi %arg17, %c64_i64 : i64
      %135 = arith.addi %arg18, %c64_i64 : i64
      %136 = arith.addi %arg19, %c64_i64 : i64
      %137 = arith.addi %arg20, %c64_i64 : i64
      scf.yield %133, %134, %135, %136, %137, %132#0, %132#1 : i64, i64, i64, i64, i64, tensor<64x32xf32>, tensor<64x32xf32>
    }
    %89 = tt.splat %arg14 : f32 -> tensor<64x32xf32>
    %90 = arith.mulf %88#5, %89 : tensor<64x32xf32>
    %91 = arith.truncf %90 : tensor<64x32xf32> to tensor<64x32xbf16>
    %92 = tt.splat %47 : !tt.ptr<bf16> -> tensor<64x32x!tt.ptr<bf16>>
    %93 = tt.expand_dims %65 {axis = 1 : i32} : tensor<64xi64> -> tensor<64x1xi64>
    %94 = tt.splat %31 : i64 -> tensor<64x1xi64>
    %95 = arith.muli %93, %94 : tensor<64x1xi64>
    %96 = tt.broadcast %95 : tensor<64x1xi64> -> tensor<64x32xi64>
    %97 = tt.expand_dims %60 {axis = 0 : i32} : tensor<32xi64> -> tensor<1x32xi64>
    %98 = tt.broadcast %97 : tensor<1x32xi64> -> tensor<64x32xi64>
    %99 = arith.addi %96, %98 : tensor<64x32xi64>
    %100 = tt.addptr %92, %99 : tensor<64x32x!tt.ptr<bf16>>, tensor<64x32xi64>
    %101 = arith.cmpi sge, %93, %cst_1 : tensor<64x1xi64>
    %102 = tt.splat %39 : i64 -> tensor<64x1xi64>
    %103 = arith.cmpi slt, %93, %102 : tensor<64x1xi64>
    %104 = arith.andi %101, %103 : tensor<64x1xi1>
    %105 = tt.broadcast %104 : tensor<64x1xi1> -> tensor<64x32xi1>
    tt.store %100, %91, %105 : tensor<64x32x!tt.ptr<bf16>>
    %106 = arith.truncf %88#6 : tensor<64x32xf32> to tensor<64x32xbf16>
    %107 = tt.splat %49 : !tt.ptr<bf16> -> tensor<64x32x!tt.ptr<bf16>>
    %108 = tt.addptr %107, %99 : tensor<64x32x!tt.ptr<bf16>>, tensor<64x32xi64>
    tt.store %108, %106, %105 : tensor<64x32x!tt.ptr<bf16>>
    tt.return
  }
}

{-#
  external_resources: {
    mlir_reproducer: {
      pipeline: "builtin.module(convert-triton-to-tritongpu{enable-source-remat=false num-ctas=1 num-warps=4 target=cuda:100 threads-per-warp=32}, tritongpu-coalesce, tritongpu-F32DotTC, triton-nvidia-gpu-plan-cta, tritongpu-remove-layout-conversions, tritongpu-optimize-thread-locality, tritongpu-accelerate-matmul, tritongpu-remove-layout-conversions, tritongpu-optimize-dot-operands{hoist-layout-conversion=true}, triton-nvidia-optimize-descriptor-encoding, triton-loop-aware-cse, tritongpu-fuse-nested-loops, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, triton-licm, tritongpu-optimize-accumulator-init, tritongpu-hoist-tmem-alloc, tritongpu-promote-lhs-to-tmem, tritongpu-assign-latencies{num-stages=1}, tritongpu-schedule-loops, tritongpu-automatic-warp-specialization{num-stages=1}, tritongpu-pipeline{dump-intermediate-steps=false num-stages=1}, tritongpu-combine-tensor-select-and-if, triton-nvidia-gpu-remove-tmem-tokens, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, triton-loop-aware-cse, tritongpu-prefetch, tritongpu-optimize-dot-operands{hoist-layout-conversion=true}, tritongpu-coalesce-async-copy, triton-nvidia-optimize-tmem-layouts, tritongpu-remove-layout-conversions, triton-nvidia-interleave-tmem, tritongpu-reduce-data-duplication, tritongpu-reorder-instructions, triton-loop-aware-cse, symbol-dce, triton-nvidia-tma-lowering, triton-nvidia-gpu-fence-insertion{compute-capability=100}, sccp, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true})",
      disable_threading: false,
      verify_each: true
    }
  }
#-}
pipeline failed while executing [`TritonGPUHoistTMEMAlloc` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`

Environment details

Triton: 3.4.0
GPU: b200

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions