Skip to content

Commit 642d59c

Browse files
authored
[GLUON] Fix getting layout from a SwizzledSharedLayout (#8003)
`layoutToGluon` segfaults when it is given a `SwizzledSharedLayout` attribute, because the swizzled branch read the CTA layout from `nvmma` instead of `swizzled`. This issue was found while using `permute` on shared memory that has this layout.
1 parent cfc0a9d commit 642d59c

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

python/src/gluon_ir.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,11 @@ py::object layoutToGluon(Attribute layout) {
191191
toStdVector(ctaLayout.getCTAOrder()));
192192
} else if (auto swizzled =
193193
dyn_cast<ttg::SwizzledSharedEncodingAttr>(layout)) {
194-
auto ctaLayout = nvmma.getCTALayout();
194+
auto ctaLayout = swizzled.getCTALayout();
195195
return layouts.SwizzledSharedLayout(
196196
swizzled.getVec(), swizzled.getPerPhase(), swizzled.getMaxPhase(),
197-
swizzled.getOrder(), toStdVector(ctaLayout.getCTAsPerCGA()),
197+
toStdVector(swizzled.getOrder()),
198+
toStdVector(ctaLayout.getCTAsPerCGA()),
198199
toStdVector(ctaLayout.getCTASplitNum()),
199200
toStdVector(ctaLayout.getCTAOrder()));
200201
} else if (auto autoEnc = dyn_cast<gluon::AutoEncodingAttr>(layout)) {

python/test/gluon/test_frontend.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,32 @@ def test_shared_memory_index(target):
317317
""")
318318

319319

320+
@gluon.jit
321+
def shared_memory_permute_kernel():
322+
layout: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0])
323+
smem = ttgl.allocate_shared_memory(ttgl.float16, [4, 128], layout)
324+
perm = smem.permute((1, 0))
325+
ttgl.static_assert(perm.layout == ttgl.SwizzledSharedLayout(1, 1, 1, [0, 1]))
326+
327+
328+
@pytest.mark.parametrize("target", ALL_TARGETS)
329+
def test_shared_memory_permute(target):
330+
mod = run_parser(shared_memory_permute_kernel, target=target)
331+
expecttest.assert_expected_inline(
332+
anonymize_ir(mod.str_nodebug()), """\
333+
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
334+
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
335+
#smem = #ttg.shared_memory
336+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
337+
tt.func public @shared_memory_permute_kernel() attributes {noinline = false} {
338+
%0 = ttg.local_alloc : () -> !ttg.memdesc<4x128xf16, #shared, #smem, mutable>
339+
%1 = ttg.memdesc_trans %0 {order = array<i32: 1, 0>} : !ttg.memdesc<4x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x4xf16, #shared1, #smem, mutable>
340+
tt.return
341+
}
342+
}
343+
""")
344+
345+
320346
@gluon.jit
321347
def shared_memory_cast_kernel():
322348
layout_a: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=64, transposed=False, element_bitwidth=8,

0 commit comments

Comments
 (0)