@@ -532,19 +532,56 @@ decodeFixedType(ArrayRef<llvm::Intrinsic::IITDescriptor> &infos,
532532}
533533
534534// llvm::Intrinsics accepts only LLVMContext. We need to reimplement it here.
535+ // Takes AST information to infer correct signedness when
536+ // IIT returns signed but AST shows unsigned in Function Decl (indicating
537+ // intrinsic expects unsigned)
535538static cir::FuncType getIntrinsicType (mlir::MLIRContext *context,
536- llvm::Intrinsic::ID id) {
539+ llvm::Intrinsic::ID id,
540+ const CallExpr *E) {
537541 using namespace llvm ::Intrinsic;
538542
543+ // Get the FunctionDecl from the CallExpr
544+ const FunctionDecl *FD = nullptr ;
545+ if (const auto *DRE =
546+ dyn_cast<DeclRefExpr>(E->getCallee ()->IgnoreImpCasts ())) {
547+ FD = dyn_cast<FunctionDecl>(DRE->getDecl ());
548+ }
549+
539550 SmallVector<IITDescriptor, 8 > table;
540551 getIntrinsicInfoTableEntries (id, table);
541552
542553 ArrayRef<IITDescriptor> tableRef = table;
543554 mlir::Type resultTy = decodeFixedType (tableRef, context);
544555
556+ // Use FunctionDecl return type if available
557+ if (auto intTy = dyn_cast<cir::IntType>(resultTy)) {
558+ if (FD && FD->getReturnType ()->isUnsignedIntegerType ()) {
559+ resultTy = IntType::get (context, intTy.getWidth (), /* signed=*/ false );
560+ }
561+ // Otherwise keep IIT default (signed)
562+ }
563+
545564 SmallVector<mlir::Type, 8 > argTypes;
546- while (!tableRef.empty ())
547- argTypes.push_back (decodeFixedType (tableRef, context));
565+ unsigned argIndex = 0 ;
566+ while (!tableRef.empty ()) {
567+ mlir::Type argTy = decodeFixedType (tableRef, context);
568+
569+ // Adjust argument type signedness based on FunctionDecl parameter
570+ // definition
571+ if (auto intTy = dyn_cast<cir::IntType>(argTy)) {
572+ if (FD && argIndex < FD->getNumParams ()) {
573+ QualType paramType = FD->getParamDecl (argIndex)->getType ();
574+ if (paramType->isUnsignedIntegerType ()) {
575+ argTy = IntType::get (context, intTy.getWidth (), /* signed=*/ false );
576+ }
577+ // Otherwise keep IIT default (signed)
578+ }
579+ // If no FunctionDecl, keep IIT default (signed)
580+ }
581+
582+ argTypes.push_back (argTy);
583+ argIndex++;
584+ }
548585
549586 return FuncType::get (argTypes, resultTy);
550587}
@@ -2726,16 +2763,34 @@ RValue CIRGenFunction::emitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
27262763 assert (false && " bad intrinsic name!" );
27272764
27282765 cir::FuncType intrinsicType =
2729- getIntrinsicType (&getMLIRContext (), intrinsicID);
2766+ getIntrinsicType (&getMLIRContext (), intrinsicID, E );
27302767
27312768 SmallVector<mlir::Value> args;
27322769 for (unsigned i = 0 ; i < E->getNumArgs (); i++) {
2733- mlir::Value arg = emitScalarOrConstFoldImmArg (iceArguments, i, E);
2734- mlir::Type argType = arg.getType ();
2735- if (argType != intrinsicType.getInput (i))
2736- llvm_unreachable (" NYI" );
2770+ mlir::Value argValue = emitScalarOrConstFoldImmArg (iceArguments, i, E);
2771+ // If the intrinsic arg type is different from the builtin arg type
2772+ // we need to do a bit cast.
2773+ mlir::Type argType = argValue.getType ();
2774+ mlir::Type expectedTy = intrinsicType.getInput (i);
2775+ if (argType != expectedTy) {
2776+ // XXX - vector of pointers?
2777+ if (cir::PointerType expectedPtrTy =
2778+ dyn_cast<cir::PointerType>(expectedTy)) {
2779+ if (cir::PointerType argPtrTy = dyn_cast<cir::PointerType>(argType)) {
2780+ if (expectedPtrTy.getAddrSpace () != argPtrTy.getAddrSpace ()) {
2781+ argValue = builder.createAddrSpaceCast (
2782+ getLoc (E->getExprLoc ()), argValue,
2783+ cir::PointerType::get (expectedPtrTy,
2784+ expectedPtrTy.getAddrSpace ()));
2785+ }
2786+ }
2787+ }
2788+ // TODO(cir): Cast vector type (e.g., v256i32) to x86_amx, this only
2789+ // happens in amx intrinsics.
2790+ argValue = builder.createBitcast (argValue, expectedTy);
2791+ }
27372792
2738- args.push_back (arg );
2793+ args.push_back (argValue );
27392794 }
27402795
27412796 auto intrinsicCall = builder.create <cir::LLVMIntrinsicCallOp>(
0 commit comments