Skip to content

Commit 43a94bb

Browse files
committed
[CIR][Transforms] Introduce StdVectorCtorOp & StdVectorDtorOp
1 parent b647f4b commit 43a94bb

File tree

7 files changed

+142
-24
lines changed

7 files changed

+142
-24
lines changed

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,16 +1288,19 @@ def CIR_CXXCtorAttr : CIR_Attr<"CXXCtor", "cxx_ctor"> {
12881288
The `copy` kind is used if the constructor is a copy constructor.
12891289
}];
12901290
let parameters = (ins "mlir::Type":$type,
1291-
EnumParameter<CIR_CtorKind>:$ctorKind);
1291+
EnumParameter<CIR_CtorKind>:$ctorKind,
1292+
OptionalParameter<"std::optional<const clang::CXXRecordDecl *>">:$recordDecl);
12921293

12931294
let assemblyFormat = [{
12941295
`<` $type `,` $ctorKind `>`
12951296
}];
12961297

12971298
let builders = [
12981299
AttrBuilderWithInferredContext<(ins "mlir::Type":$type,
1299-
CArg<"CtorKind", "cir::CtorKind::Custom">:$ctorKind), [{
1300-
return $_get(type.getContext(), type, ctorKind);
1300+
CArg<"CtorKind", "cir::CtorKind::Custom">:$ctorKind,
1301+
CArg<"std::optional<const clang::CXXRecordDecl *>",
1302+
"std::nullopt">:$recordDecl), [{
1303+
return $_get(type.getContext(), type, ctorKind, recordDecl);
13011304
}]>
13021305
];
13031306
}
@@ -1307,15 +1310,18 @@ def CIR_CXXDtorAttr : CIR_Attr<"CXXDtor", "cxx_dtor"> {
13071310
let description = [{
13081311
Functions with this attribute are CXX destructors
13091312
}];
1310-
let parameters = (ins "mlir::Type":$type);
1313+
let parameters = (ins "mlir::Type":$type,
1314+
OptionalParameter<"std::optional<const clang::CXXRecordDecl *>">:$recordDecl);
13111315

13121316
let assemblyFormat = [{
13131317
`<` $type `>`
13141318
}];
13151319

13161320
let builders = [
1317-
AttrBuilderWithInferredContext<(ins "mlir::Type":$type), [{
1318-
return $_get(type.getContext(), type);
1321+
AttrBuilderWithInferredContext<(ins "mlir::Type":$type,
1322+
CArg<"std::optional<const clang::CXXRecordDecl *>",
1323+
"std::nullopt">:$recordDecl), [{
1324+
return $_get(type.getContext(), type, recordDecl);
13191325
}]>
13201326
];
13211327
}

clang/include/clang/CIR/Dialect/IR/CIRStdOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,5 +61,11 @@ def CIR_IterBeginOp: CIR_StdOp<"begin",
6161
def CIR_IterEndOp: CIR_StdOp<"end",
6262
(ins CIR_AnyType:$container),
6363
(outs CIR_AnyType:$result)>;
64+
def CIR_StdVectorCtorOp: CIR_StdOp<"vector_cxx_ctor",
65+
(ins CIR_AnyType:$first),
66+
(outs Optional<CIR_AnyType>:$result)>;
67+
def CIR_StdVectorDtorOp: CIR_StdOp<"vector_cxx_dtor",
68+
(ins CIR_AnyType:$first),
69+
(outs Optional<CIR_AnyType>:$result)>;
6470

6571
#endif

clang/lib/CIR/CodeGen/CIRGenFunction.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,8 @@ cir::FuncOp CIRGenFunction::generateCode(clang::GlobalDecl gd, cir::FuncOp fn,
766766
assert(!cir::MissingFeatures::shouldInstrumentFunction());
767767
if (auto dtor = dyn_cast<CXXDestructorDecl>(fd)) {
768768
auto cxxDtor = cir::CXXDtorAttr::get(
769-
convertType(getContext().getRecordType(dtor->getParent())));
769+
convertType(getContext().getRecordType(dtor->getParent())),
770+
dtor->getParent());
770771
fn.setCxxSpecialMemberAttr(cxxDtor);
771772

772773
emitDestructorBody(args);
@@ -778,7 +779,8 @@ cir::FuncOp CIRGenFunction::generateCode(clang::GlobalDecl gd, cir::FuncOp fn,
778779
ctorKind = cir::CtorKind::Copy;
779780

780781
auto cxxCtor = cir::CXXCtorAttr::get(
781-
convertType(getContext().getRecordType(ctor->getParent())), ctorKind);
782+
convertType(getContext().getRecordType(ctor->getParent())), ctorKind,
783+
ctor->getParent());
782784
fn.setCxxSpecialMemberAttr(cxxCtor);
783785

784786
emitConstructorBody(args);

clang/lib/CIR/CodeGen/CIRGenModule.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2738,7 +2738,8 @@ cir::FuncOp CIRGenModule::createCIRFunction(mlir::Location loc, StringRef name,
27382738
if (fd) {
27392739
if (auto dtor = dyn_cast<CXXDestructorDecl>(fd)) {
27402740
auto cxxDtor = cir::CXXDtorAttr::get(
2741-
convertType(getASTContext().getRecordType(dtor->getParent())));
2741+
convertType(getASTContext().getRecordType(dtor->getParent())),
2742+
dtor->getParent());
27422743
f.setCxxSpecialMemberAttr(cxxDtor);
27432744
}
27442745

@@ -2751,7 +2752,7 @@ cir::FuncOp CIRGenModule::createCIRFunction(mlir::Location loc, StringRef name,
27512752

27522753
auto cxxCtor = cir::CXXCtorAttr::get(
27532754
convertType(getASTContext().getRecordType(ctor->getParent())),
2754-
ctorKind);
2755+
ctorKind, ctor->getParent());
27552756
f.setCxxSpecialMemberAttr(cxxCtor);
27562757
}
27572758
}

clang/lib/CIR/Dialect/Transforms/IdiomRecognizer.cpp

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/ADT/StringMap.h"
2222
#include "llvm/ADT/StringRef.h"
2323
#include "llvm/ADT/Twine.h"
24+
#include "llvm/ADT/TypeSwitch.h"
2425
#include "llvm/Support/ErrorHandling.h"
2526
#include "llvm/Support/Path.h"
2627

@@ -42,24 +43,66 @@ template <typename TargetOp> class StdRecognizer {
4243
template <size_t... Indices>
4344
static TargetOp buildCall(CIRBaseBuilderTy &builder, CallOp call,
4445
std::index_sequence<Indices...>) {
45-
return builder.create<TargetOp>(call.getLoc(), call.getResult().getType(),
46-
call.getCalleeAttr(),
47-
call.getOperand(Indices)...);
46+
return builder.create<TargetOp>(
47+
call.getLoc(),
48+
(call.getResult() ? call.getResult().getType() : mlir::TypeRange{}),
49+
call.getCalleeAttr(), call.getOperand(Indices)...);
4850
}
4951

5052
public:
51-
static bool raise(CallOp call, mlir::MLIRContext &context, bool remark) {
53+
static FuncOp getCalleeFromSymbol(mlir::ModuleOp theModule,
54+
llvm::StringRef name) {
55+
auto global = mlir::SymbolTable::lookupSymbolIn(theModule, name);
56+
assert(global && "expected to find symbol for function");
57+
return dyn_cast<FuncOp>(global);
58+
}
59+
60+
static std::optional<StringRef>
61+
getRecordName(const clang::CXXRecordDecl *rd) {
62+
if (!rd || !rd->getDeclContext()->isStdNamespace())
63+
return std::nullopt;
64+
65+
if (rd->getDeclName().isIdentifier())
66+
return rd->getName();
67+
68+
return std::nullopt;
69+
}
70+
71+
static std::optional<std::string>
72+
resolveSpecialMember(mlir::Attribute specialMember) {
73+
return TypeSwitch<Attribute, std::optional<std::string>>(specialMember)
74+
.Case<CXXCtorAttr, CXXDtorAttr>(
75+
[](auto attr) -> std::optional<std::string> {
76+
if (!attr.getRecordDecl())
77+
return std::nullopt;
78+
if (auto recordName = getRecordName(*attr.getRecordDecl()))
79+
return recordName->str() + "_" + attr.getMnemonic().str();
80+
return std::nullopt;
81+
})
82+
.Default([](Attribute) { return std::nullopt; });
83+
}
84+
85+
static bool raise(mlir::ModuleOp theModule, CallOp call,
86+
mlir::MLIRContext &context, bool remark) {
5287
constexpr int numArgs = TargetOp::getNumArgs();
5388
if (call.getNumOperands() != numArgs)
5489
return false;
5590

56-
auto callExprAttr = call.getAstAttr();
5791
llvm::StringRef stdFuncName = TargetOp::getFunctionName();
58-
if (!callExprAttr || !callExprAttr.isStdFunctionCall(stdFuncName))
59-
return false;
60-
61-
if (!checkArguments(call.getArgOperands()))
62-
return false;
92+
auto calleeFunc = getCalleeFromSymbol(theModule, *call.getCallee());
93+
94+
if (auto specialMember = calleeFunc.getCxxSpecialMemberAttr()) {
95+
auto resolved = resolveSpecialMember(specialMember);
96+
if (!resolved || *resolved != stdFuncName.str())
97+
return false;
98+
} else {
99+
auto callExprAttr = call.getAstAttr();
100+
if (!callExprAttr || !callExprAttr.isStdFunctionCall(stdFuncName))
101+
return false;
102+
103+
if (!checkArguments(call.getArgOperands()))
104+
return false;
105+
}
63106

64107
if (remark)
65108
mlir::emitRemark(call.getLoc())
@@ -194,12 +237,16 @@ void IdiomRecognizerPass::recognizeCall(CallOp call) {
194237

195238
bool remark = opts.emitRemarkFoundCalls();
196239

197-
using StdFunctionsRecognizer = std::tuple<StdRecognizer<StdFindOp>>;
240+
using StdFunctionsRecognizer =
241+
std::tuple<StdRecognizer<StdFindOp>, StdRecognizer<StdVectorCtorOp>,
242+
StdRecognizer<StdVectorDtorOp>>;
198243

199244
// MSVC requires explicitly capturing these variables.
200245
std::apply(
201246
[&, call, remark, this](auto... recognizers) {
202-
(decltype(recognizers)::raise(call, this->getContext(), remark) || ...);
247+
(decltype(recognizers)::raise(theModule, call, this->getContext(),
248+
remark) ||
249+
...);
203250
},
204251
StdFunctionsRecognizer());
205252
}

clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
8484
void lowerGlobalOp(GlobalOp op);
8585
void lowerDynamicCastOp(DynamicCastOp op);
8686
void lowerStdFindOp(StdFindOp op);
87+
void lowerStdVectorCtorOp(StdVectorCtorOp op);
88+
void lowerStdVectorDtorOp(StdVectorDtorOp op);
8789
void lowerIterBeginOp(IterBeginOp op);
8890
void lowerIterEndOp(IterEndOp op);
8991
void lowerToMemCpy(StoreOp op);
@@ -1474,6 +1476,28 @@ void LoweringPreparePass::lowerStdFindOp(StdFindOp op) {
14741476
op.erase();
14751477
}
14761478

1479+
void LoweringPreparePass::lowerStdVectorCtorOp(StdVectorCtorOp op) {
1480+
CIRBaseBuilderTy builder(getContext());
1481+
builder.setInsertionPointAfter(op.getOperation());
1482+
auto call =
1483+
builder.createCallOp(op.getLoc(), op.getOriginalFnAttr(), mlir::Type{},
1484+
mlir::ValueRange{op.getOperand()});
1485+
1486+
op.replaceAllUsesWith(call);
1487+
op.erase();
1488+
}
1489+
1490+
void LoweringPreparePass::lowerStdVectorDtorOp(StdVectorDtorOp op) {
1491+
CIRBaseBuilderTy builder(getContext());
1492+
builder.setInsertionPointAfter(op.getOperation());
1493+
auto call =
1494+
builder.createCallOp(op.getLoc(), op.getOriginalFnAttr(), mlir::Type{},
1495+
mlir::ValueRange{op.getOperand()});
1496+
1497+
op.replaceAllUsesWith(call);
1498+
op.erase();
1499+
}
1500+
14771501
void LoweringPreparePass::lowerIterBeginOp(IterBeginOp op) {
14781502
CIRBaseBuilderTy builder(getContext());
14791503
builder.setInsertionPointAfter(op.getOperation());
@@ -1585,6 +1609,10 @@ void LoweringPreparePass::runOnOp(Operation *op) {
15851609
lowerDynamicCastOp(dynamicCast);
15861610
} else if (auto stdFind = dyn_cast<StdFindOp>(op)) {
15871611
lowerStdFindOp(stdFind);
1612+
} else if (auto stdVectorCtorOp = dyn_cast<StdVectorCtorOp>(op)) {
1613+
lowerStdVectorCtorOp(stdVectorCtorOp);
1614+
} else if (auto stdVectorDtorOp = dyn_cast<StdVectorDtorOp>(op)) {
1615+
lowerStdVectorDtorOp(stdVectorDtorOp);
15881616
} else if (auto iterBegin = dyn_cast<IterBeginOp>(op)) {
15891617
lowerIterBeginOp(iterBegin);
15901618
} else if (auto iterEnd = dyn_cast<IterEndOp>(op)) {
@@ -1630,8 +1658,9 @@ void LoweringPreparePass::runOnOperation() {
16301658

16311659
op->walk([&](Operation *op) {
16321660
if (isa<UnaryOp, BinOp, CastOp, ComplexBinOp, CmpThreeWayOp, VAArgOp,
1633-
GlobalOp, DynamicCastOp, StdFindOp, IterEndOp, IterBeginOp,
1634-
ArrayCtor, ArrayDtor, cir::FuncOp, StoreOp, ThrowOp, CallOp>(op))
1661+
GlobalOp, DynamicCastOp, StdFindOp, StdVectorCtorOp,
1662+
StdVectorDtorOp, IterEndOp, IterBeginOp, ArrayCtor, ArrayDtor,
1663+
cir::FuncOp, StoreOp, ThrowOp, CallOp>(op))
16351664
opsToTransform.push_back(op);
16361665
});
16371666

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fclangir-idiom-recognizer -emit-cir -I%S/../Inputs -mmlir --mlir-print-ir-after-all %s -o - 2>&1 | FileCheck %s -check-prefix=PASS_ENABLED
2+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -fclangir -emit-cir -I%S/../Inputs -fclangir-idiom-recognizer="remarks=found-calls" -clangir-verify-diagnostics %s -o %t.cir
3+
4+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -fclangir -fclangir-idiom-recognizer -emit-cir -I%S/../Inputs -mmlir --mlir-print-ir-before=cir-idiom-recognizer %s -o - 2>&1 | FileCheck %s -check-prefix=BEFORE-IDIOM
5+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -fclangir -fclangir-idiom-recognizer -emit-cir -I%S/../Inputs -mmlir --mlir-print-ir-after=cir-idiom-recognizer %s -o - 2>&1 | FileCheck %s -check-prefix=AFTER-IDIOM
6+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -fclangir -fclangir-idiom-recognizer -emit-cir -I%S/../Inputs -mmlir --mlir-print-ir-after=cir-lowering-prepare %s -o - 2>&1 | FileCheck %s -check-prefix=AFTER-LOWERING-PREPARE
7+
8+
// PASS_ENABLED: IR Dump After IdiomRecognizer (cir-idiom-recognizer)
9+
10+
namespace std {
11+
template <typename T> class vector {
12+
public:
13+
vector() {} // expected-remark {{found call to std::vector_cxx_ctor()}}
14+
~vector() {}; // expected-remark{{found call to std::vector_cxx_dtor()}}
15+
};
16+
}; // namespace std
17+
18+
void vector_test() {
19+
std::vector<int> v; // expected-remark {{found call to std::vector_cxx_ctor()}}
20+
21+
// BEFORE-IDIOM: cir.call @_ZNSt6vectorIiEC1Ev(
22+
// BEFORE-IDIOM: cir.call @_ZNSt6vectorIiED1Ev(
23+
// AFTER-IDIOM: cir.std.vector_cxx_ctor(
24+
// AFTER-IDIOM: cir.std.vector_cxx_dtor(
25+
// AFTER-LOWERING-PREPARE: cir.call @_ZNSt6vectorIiEC1Ev(
26+
// AFTER-LOWERING-PREPARE: cir.call @_ZNSt6vectorIiED1Ev(
27+
}

0 commit comments

Comments
 (0)