Skip to content

Commit cd2fc0d

Browse files
committed
[CIR][CIRGen] Fix intrinsic type signedness using frontend AST information
1 parent 3d04a3d commit cd2fc0d

File tree

2 files changed

+121
-9
lines changed

2 files changed

+121
-9
lines changed

clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -532,19 +532,56 @@ decodeFixedType(ArrayRef<llvm::Intrinsic::IITDescriptor> &infos,
532532
}
533533

534534
// llvm::Intrinsics accepts only LLVMContext. We need to reimplement it here.
535+
// Takes AST information to infer correct signedness when
536+
// IIT returns signed but AST shows unsigned in Function Decl (indicating
537+
// intrinsic expects unsigned)
535538
static cir::FuncType getIntrinsicType(mlir::MLIRContext *context,
536-
llvm::Intrinsic::ID id) {
539+
llvm::Intrinsic::ID id,
540+
const CallExpr *E) {
537541
using namespace llvm::Intrinsic;
538542

543+
// Get the FunctionDecl from the CallExpr
544+
const FunctionDecl *FD = nullptr;
545+
if (const auto *DRE =
546+
dyn_cast<DeclRefExpr>(E->getCallee()->IgnoreImpCasts())) {
547+
FD = dyn_cast<FunctionDecl>(DRE->getDecl());
548+
}
549+
539550
SmallVector<IITDescriptor, 8> table;
540551
getIntrinsicInfoTableEntries(id, table);
541552

542553
ArrayRef<IITDescriptor> tableRef = table;
543554
mlir::Type resultTy = decodeFixedType(tableRef, context);
544555

556+
// Use FunctionDecl return type if available
557+
if (auto intTy = dyn_cast<cir::IntType>(resultTy)) {
558+
if (FD && FD->getReturnType()->isUnsignedIntegerType()) {
559+
resultTy = IntType::get(context, intTy.getWidth(), /*signed=*/false);
560+
}
561+
// Otherwise keep IIT default (signed)
562+
}
563+
545564
SmallVector<mlir::Type, 8> argTypes;
546-
while (!tableRef.empty())
547-
argTypes.push_back(decodeFixedType(tableRef, context));
565+
unsigned argIndex = 0;
566+
while (!tableRef.empty()) {
567+
mlir::Type argTy = decodeFixedType(tableRef, context);
568+
569+
// Adjust argument type signedness based on FunctionDecl parameter
570+
// definition
571+
if (auto intTy = dyn_cast<cir::IntType>(argTy)) {
572+
if (FD && argIndex < FD->getNumParams()) {
573+
QualType paramType = FD->getParamDecl(argIndex)->getType();
574+
if (paramType->isUnsignedIntegerType()) {
575+
argTy = IntType::get(context, intTy.getWidth(), /*signed=*/false);
576+
}
577+
// Otherwise keep IIT default (signed)
578+
}
579+
// If no FunctionDecl, keep IIT default (signed)
580+
}
581+
582+
argTypes.push_back(argTy);
583+
argIndex++;
584+
}
548585

549586
return FuncType::get(argTypes, resultTy);
550587
}
@@ -2726,16 +2763,34 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
27262763
assert(false && "bad intrinsic name!");
27272764

27282765
cir::FuncType intrinsicType =
2729-
getIntrinsicType(&getMLIRContext(), intrinsicID);
2766+
getIntrinsicType(&getMLIRContext(), intrinsicID, E);
27302767

27312768
SmallVector<mlir::Value> args;
27322769
for (unsigned i = 0; i < E->getNumArgs(); i++) {
2733-
mlir::Value arg = emitScalarOrConstFoldImmArg(iceArguments, i, E);
2734-
mlir::Type argType = arg.getType();
2735-
if (argType != intrinsicType.getInput(i))
2736-
llvm_unreachable("NYI");
2770+
mlir::Value argValue = emitScalarOrConstFoldImmArg(iceArguments, i, E);
2771+
// If the intrinsic arg type is different from the builtin arg type
2772+
// we need to do a bit cast.
2773+
mlir::Type argType = argValue.getType();
2774+
mlir::Type expectedTy = intrinsicType.getInput(i);
2775+
if (argType != expectedTy) {
2776+
// XXX - vector of pointers?
2777+
if (cir::PointerType expectedPtrTy =
2778+
dyn_cast<cir::PointerType>(expectedTy)) {
2779+
if (cir::PointerType argPtrTy = dyn_cast<cir::PointerType>(argType)) {
2780+
if (expectedPtrTy.getAddrSpace() != argPtrTy.getAddrSpace()) {
2781+
argValue = builder.createAddrSpaceCast(
2782+
getLoc(E->getExprLoc()), argValue,
2783+
cir::PointerType::get(expectedPtrTy,
2784+
expectedPtrTy.getAddrSpace()));
2785+
}
2786+
}
2787+
}
2788+
// TODO(cir): Cast vector type (e.g., v256i32) to x86_amx, this only
2789+
// happens in amx intrinsics.
2790+
argValue = builder.createBitcast(argValue, expectedTy);
2791+
}
27372792

2738-
args.push_back(arg);
2793+
args.push_back(argValue);
27392794
}
27402795

27412796
auto intrinsicCall = builder.create<cir::LLVMIntrinsicCallOp>(
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir -target-cpu sm_70 \
2+
// RUN: -fcuda-is-device -target-feature +ptx60 \
3+
// RUN: -emit-cir -o - -x cuda %s \
4+
// RUN: | FileCheck -check-prefix=CIR %s
5+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir -target-cpu sm_80 \
6+
// RUN: -fcuda-is-device -target-feature +ptx65 \
7+
// RUN: -emit-cir -o - -x cuda %s \
8+
// RUN: | FileCheck -check-prefix=CIR %s
9+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir -target-cpu sm_80 \
10+
// RUN: -fcuda-is-device -target-feature +ptx70 \
11+
// RUN: -emit-cir -o - -x cuda %s \
12+
// RUN: | FileCheck -check-prefix=CIR %s
13+
14+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir -target-cpu sm_70 \
15+
// RUN: -fcuda-is-device -target-feature +ptx60 \
16+
// RUN: -emit-llvm -o - -x cuda %s \
17+
// RUN: | FileCheck -check-prefix=LLVM %s
18+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir -target-cpu sm_80 \
19+
// RUN: -fcuda-is-device -target-feature +ptx65 \
20+
// RUN: -emit-llvm -o - -x cuda %s \
21+
// RUN: | FileCheck -check-prefix=LLVM %s
22+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -fclangir -target-cpu sm_80 \
23+
// RUN: -fcuda-is-device -target-feature +ptx70 \
24+
// RUN: -emit-llvm -o - -x cuda %s \
25+
// RUN: | FileCheck -check-prefix=LLVM %s
26+
27+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -target-cpu sm_70 \
28+
// RUN: -fcuda-is-device -target-feature +ptx60 \
29+
// RUN: -emit-llvm -o - -x cuda %s \
30+
// RUN: | FileCheck -check-prefix=OGCHECK %s
31+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -target-cpu sm_80 \
32+
// RUN: -fcuda-is-device -target-feature +ptx65 \
33+
// RUN: -emit-llvm -o - -x cuda %s \
34+
// RUN: | FileCheck -check-prefix=OGCHECK %s
35+
// RUN: %clang_cc1 -triple nvptx64-nvidia-cuda -target-cpu sm_80 \
36+
// RUN: -fcuda-is-device -target-feature +ptx70 \
37+
// RUN: -emit-llvm -o - -x cuda %s \
38+
// RUN: | FileCheck -check-prefix=OGCHECK %s
39+
40+
#define __device__ __attribute__((device))
41+
#define __global__ __attribute__((global))
42+
#define __shared__ __attribute__((shared))
43+
#define __constant__ __attribute__((constant))
44+
45+
typedef unsigned long long uint64_t;
46+
47+
__device__ void nvvm_sync(unsigned mask, int i, float f, int a, int b,
48+
bool pred, uint64_t i64) {
49+
50+
// CIR: cir.llvm.intrinsic "nvvm.bar.warp.sync" {{.*}} : (!u32i)
51+
// LLVM: call void @llvm.nvvm.bar.warp.sync(i32
52+
// OGCHECK: call void @llvm.nvvm.bar.warp.sync(i32
53+
__nvvm_bar_warp_sync(mask);
54+
55+
56+
// CHECK: ret void
57+
}

0 commit comments

Comments
 (0)