Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 89 additions & 5 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1488,6 +1488,8 @@ struct ConvertCIRToLLVMPass
llvm::StringMap<mlir::LLVM::GlobalOp> &argStringGlobalsMap,
llvm::MapVector<mlir::ArrayAttr, mlir::LLVM::GlobalOp> &argsVarMap);

void resolveBlockAddressOp(LLVMBlockAddressInfo &blockInfoAddr);

void processCIRAttrs(mlir::ModuleOp moduleOp);

StringRef getDescription() const override {
Expand Down Expand Up @@ -4496,16 +4498,74 @@ mlir::LogicalResult CIRToLLVMLinkerOptionsOpLowering::matchAndRewrite(
return mlir::success();
}

// Lowers a `cir.label` into an LLVM dialect BlockTagOp that carries a fresh
// tag index, and records that BlockTagOp in `blockInfoAddr` under the
// (enclosing function symbol name, label name) pair. The original label op
// is erased.
mlir::LogicalResult CIRToLLVMLabelOpLowering::matchAndRewrite(
    cir::LabelOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::Block *parentBlock = op->getBlock();

  // A BlockTagOp cannot reside in the entry block, because the address of the
  // entry block cannot be taken. Split the entry block at the label and
  // branch from the old block into the newly created one.
  if (parentBlock->isEntryBlock()) {
    mlir::Block *labelBlock =
        rewriter.splitBlock(parentBlock, mlir::Block::iterator(op));
    rewriter.setInsertionPointToEnd(parentBlock);
    mlir::LLVM::BrOp::create(rewriter, op.getLoc(), labelBlock);
  }

  mlir::LLVM::BlockTagAttr tag = mlir::LLVM::BlockTagAttr::get(
      rewriter.getContext(), blockInfoAddr.getTagIndex());

  // Emit the tag right where the label used to be.
  rewriter.setInsertionPoint(op);
  mlir::LLVM::BlockTagOp tagOp =
      mlir::LLVM::BlockTagOp::create(rewriter, op->getLoc(), tag);

  // Register the tag so block-address lowering can find it later.
  mlir::LLVM::LLVMFuncOp enclosingFunc =
      op->getParentOfType<mlir::LLVM::LLVMFuncOp>();
  blockInfoAddr.mapBlockTag(enclosingFunc.getSymName(), op.getLabel(), tagOp);

  rewriter.eraseOp(op);
  return mlir::success();
}

// Lowers a `cir.blockaddress` into an LLVM dialect BlockAddressOp. When the
// BlockTagOp for the referenced (function, label) pair is already known, its
// tag is used directly; otherwise the op is created with an empty tag and
// queued in `blockInfoAddr` so it can be patched later.
mlir::LogicalResult CIRToLLVMBlockAddressOpLowering::matchAndRewrite(
    cir::BlockAddressOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::MLIRContext *context = rewriter.getContext();

  mlir::LLVM::BlockTagOp knownTag =
      blockInfoAddr.lookupBlockTag(op.getFunc(), op.getLabel());
  const bool isResolved = static_cast<bool>(knownTag);

  // If the BlockTagOp has not been emitted yet, use a placeholder (empty)
  // tag. This will later be replaced with the correct tag index during
  // `resolveBlockAddressOp`.
  mlir::LLVM::BlockTagAttr tag =
      isResolved ? knownTag.getTag() : mlir::LLVM::BlockTagAttr();

  mlir::LLVM::BlockAddressAttr addrAttr =
      mlir::LLVM::BlockAddressAttr::get(context, op.getFuncAttr(), tag);

  rewriter.setInsertionPoint(op);
  mlir::LLVM::BlockAddressOp lowered = mlir::LLVM::BlockAddressOp::create(
      rewriter, op.getLoc(), mlir::LLVM::LLVMPointerType::get(context),
      addrAttr);

  // Remember placeholder ops so they can be fixed up once the tag exists.
  if (!isResolved)
    blockInfoAddr.addUnresolvedBlockAddress(lowered, op.getFunc(),
                                            op.getLabel());

  rewriter.replaceOp(op, lowered);
  return mlir::success();
}

void populateCIRToLLVMConversionPatterns(
mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter,
mlir::DataLayout &dataLayout, cir::LowerModule *lowerModule,
llvm::StringMap<mlir::LLVM::GlobalOp> &stringGlobalsMap,
llvm::StringMap<mlir::LLVM::GlobalOp> &argStringGlobalsMap,
llvm::MapVector<mlir::ArrayAttr, mlir::LLVM::GlobalOp> &argsVarMap) {
llvm::MapVector<mlir::ArrayAttr, mlir::LLVM::GlobalOp> &argsVarMap,
LLVMBlockAddressInfo &blockAddrInfo) {
patterns.add<CIRToLLVMReturnOpLowering>(patterns.getContext());
patterns.add<CIRToLLVMAllocaOpLowering>(converter, dataLayout,
stringGlobalsMap, argStringGlobalsMap,
argsVarMap, patterns.getContext());
patterns.add<CIRToLLVMBlockAddressOpLowering>(
converter, patterns.getContext(), blockAddrInfo);
patterns.add<CIRToLLVMLabelOpLowering>(converter, patterns.getContext(),
blockAddrInfo);
patterns.add<
// clang-format off
CIRToLLVMCastOpLowering,
Expand Down Expand Up @@ -4990,6 +5050,25 @@ void ConvertCIRToLLVMPass::buildGlobalAnnotationsVar(
}
}

// Patches every BlockAddressOp that was lowered before its BlockTagOp was
// available. For each pending entry, the BlockTagOp registered under the
// stored (function name, label name) pair is looked up and the op's
// block-address attribute is rebuilt with that tag. The pending map is
// cleared afterwards.
void ConvertCIRToLLVMPass::resolveBlockAddressOp(
    LLVMBlockAddressInfo &blockInfoAddr) {
  mlir::ModuleOp module = getOperation();
  // Only the context is needed to build attributes; no builder is required.
  mlir::MLIRContext *ctx = module.getContext();

  for (auto &[blockAddOp, blockInfo] :
       blockInfoAddr.getUnresolvedBlockAddress()) {
    // `blockInfo` is the (function name, label name) pair of the target.
    mlir::LLVM::BlockTagOp resolvedLabel =
        blockInfoAddr.lookupBlockTag(blockInfo.first, blockInfo.second);
    assert(resolvedLabel && "expected BlockTagOp to already be emitted");

    auto fnSym = mlir::FlatSymbolRefAttr::get(ctx, blockInfo.first);
    auto blkAddTag = mlir::LLVM::BlockAddressAttr::get(
        ctx, fnSym, resolvedLabel.getTagAttr());
    blockAddOp.setBlockAddrAttr(blkAddTag);
  }
  blockInfoAddr.clearUnresolvedMap();
}

void ConvertCIRToLLVMPass::processCIRAttrs(mlir::ModuleOp module) {
// Lower the module attributes to LLVM equivalents.
if (auto tripleAttr = module->getAttr(cir::CIRDialect::getTripleAttrName()))
Expand Down Expand Up @@ -5021,10 +5100,15 @@ void ConvertCIRToLLVMPass::runOnOperation() {
llvm::StringMap<mlir::LLVM::GlobalOp> argStringGlobalsMap;
// Track globals created for annotation args.
llvm::MapVector<mlir::ArrayAttr, mlir::LLVM::GlobalOp> argsVarMap;
/// Tracks the state required to lower CIR `LabelOp` and `BlockAddressOp`.
/// Maps labels to their corresponding `BlockTagOp` and keeps bookkeeping
/// of unresolved `BlockAddressOp`s until they are matched with the
/// corresponding `BlockTagOp` in `resolveBlockAddressOp`.
LLVMBlockAddressInfo blockInfoAddr;
populateCIRToLLVMConversionPatterns(
patterns, converter, dataLayout, lowerModule.get(), stringGlobalsMap,
argStringGlobalsMap, argsVarMap, blockInfoAddr);

populateCIRToLLVMConversionPatterns(patterns, converter, dataLayout,
lowerModule.get(), stringGlobalsMap,
argStringGlobalsMap, argsVarMap);
mlir::populateFuncToLLVMConversionPatterns(converter, patterns);

mlir::ConversionTarget target(getContext());
Expand Down Expand Up @@ -5078,7 +5162,7 @@ void ConvertCIRToLLVMPass::runOnOperation() {
dtorAttr.getPriority());
});
buildGlobalAnnotationsVar(stringGlobalsMap, argStringGlobalsMap, argsVarMap);

resolveBlockAddressOp(blockInfoAddr);
processCIRAttrs(module);
}

Expand Down
84 changes: 83 additions & 1 deletion clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,60 @@ void buildCtorDtorList(
llvm::function_ref<std::pair<mlir::StringRef, int>(mlir::Attribute)>
createXtor);

struct LLVMBlockAddressInfo {
// Hand out the current tag index and advance the counter, so each call
// yields the value following the previous one.
uint32_t getTagIndex() {
  const uint32_t current = blockTagOpIndex;
  ++blockTagOpIndex;
  return current;
}

// Registers `tagOp` as the BlockTagOp for the (function name, label name)
// pair. Each pair may be mapped only once; a second registration trips the
// assertion in asserts-enabled builds and is otherwise ignored, because
// try_emplace leaves the existing entry untouched.
void mapBlockTag(llvm::StringRef func, llvm::StringRef label,
                 mlir::LLVM::BlockTagOp tagOp) {
  // `result` is only read by the assert below; mark it so builds with
  // asserts disabled do not warn about an unused variable.
  [[maybe_unused]] auto result =
      blockInfoToTagOp.try_emplace({func, label}, tagOp);
  assert(result.second &&
         "attempting to map a BlockTag operation that is already mapped");
}

// Find the BlockTagOp registered for the (function name, label name) pair.
// Yields a null BlockTagOp when nothing has been registered for it yet.
mlir::LLVM::BlockTagOp lookupBlockTag(llvm::StringRef func,
                                      llvm::StringRef label) const {
  const std::pair<llvm::StringRef, llvm::StringRef> key(func, label);
  return blockInfoToTagOp.lookup(key);
}

// Remember a BlockAddressOp whose BlockTagOp is not known yet, together with
// the (function name, label name) pair it refers to, so it can be patched
// later. An op that is already recorded keeps its existing entry.
void addUnresolvedBlockAddress(mlir::LLVM::BlockAddressOp op,
                               llvm::StringRef func, llvm::StringRef label) {
  std::pair<llvm::StringRef, llvm::StringRef> target(func, label);
  unresolvedBlockAddressOp.try_emplace(op, target);
}

void clearUnresolvedMap() { unresolvedBlockAddressOp.clear(); }

// Returns a mutable reference to the map of BlockAddressOps that still need
// patching. Each key is the pending op; each value is the
// (function name, label name) pair it refers to.
llvm::DenseMap<mlir::LLVM::BlockAddressOp,
std::pair<llvm::StringRef, llvm::StringRef>> &
getUnresolvedBlockAddress() {
return unresolvedBlockAddressOp;
}

private:
// Maps a (function name, label name) pair to the corresponding BlockTagOp.
// Used to resolve CIR LabelOps into their LLVM BlockTagOp.
llvm::DenseMap<std::pair<llvm::StringRef, llvm::StringRef>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of std::pair<llvm::StringRef, llvm::StringRef>, it might be cleaner and more idiomatic MLIR to create an attribute that wraps these two, just like the LLVM dialect does. Can you add the attribute? It's fine if you do that in a follow-up PR though, just let me know.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’d prefer to handle this in a future PR. Also, what do you think about changing LabelOp to always use this attribute? I think that would be better, because right now I’m looking at IndirectBrOp, and with what we have now, I imagine it could look something like this:

cir.blockaddress(@B, "B") -> !cir.ptr<!void> // here we just have a string, not the actual label

^bb2
  cir.label "A"
  ...

^bb3
  cir.label "B"
  ...

^indirectBr(%addr : !cir.ptr<!void>):  // this would act like a PHI node for each block address
  cir.indirectBr %addr : !cir.ptr<!void> [ ^bb2, ^bb3 ] // we don’t have enough info to get the block of the label

(This is just a rough idea for now.)

The problem is that each blockaddress only has a string, so we don’t actually know the block associated with the label. I think with an attribute we could link the label to the block—but I don’t fully understand how that would work yet.
(@andykaylor, I’d appreciate your input on this as well)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you think about changing LabelOp to always use this attribute?

Sure thing, once you introduce new attributes they should be used wherever it makes sense.

I imagine it could look something like this:

What's the C/C++ code that might lead to this? Note that in LLVM if you don't have the label hints it's undefined behavior (https://llvm.org/docs/LangRef.html#indirectbr-instruction) so I wouldn't worry about imaginary use cases.

The problem is that each blockaddress only has a string, so we don’t actually know the block associated with the label

For the case where this isn't undefined behavior: the new attribute has to be within the same function and part of one of the cir.labels used. (a) You can walk the blocks and find the cir.labels (the slow option), or (b) you can bookkeep the blocks as you go about emitting the "cir.label" instructions. If it's expensive to find all the block sources, we could just use the superset of all the ones we've seen cir.labels in (but I'd like to see a C/C++ source exercising that).

mlir::LLVM::BlockTagOp>
blockInfoToTagOp;
// Tracks BlockAddressOps that could not yet be fully resolved because
// their BlockTagOp was not available at the time of lowering. The map
// stores the unresolved BlockAddressOp along with its (function name, label
// name) pair so it can be patched later.
llvm::DenseMap<mlir::LLVM::BlockAddressOp,
std::pair<llvm::StringRef, llvm::StringRef>>
unresolvedBlockAddressOp;
// Next tag index to hand out from getTagIndex(). It starts at zero so a
// default-constructed LLVMBlockAddressInfo yields well-defined indices, and
// it is unsigned to match the uint32_t that getTagIndex() returns.
uint32_t blockTagOpIndex = 0;
};

void populateCIRToLLVMConversionPatterns(
mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter,
mlir::DataLayout &dataLayout,
llvm::StringMap<mlir::LLVM::GlobalOp> &stringGlobalsMap,
llvm::StringMap<mlir::LLVM::GlobalOp> &argStringGlobalsMap,
llvm::MapVector<mlir::ArrayAttr, mlir::LLVM::GlobalOp> &argsVarMap);
llvm::MapVector<mlir::ArrayAttr, mlir::LLVM::GlobalOp> &argsVarMap,
LLVMBlockAddressInfo &blockAddrInfo);

std::unique_ptr<cir::LowerModule> prepareLowerModule(mlir::ModuleOp module);

Expand Down Expand Up @@ -1329,6 +1377,40 @@ class CIRToLLVMLinkerOptionsOpLowering
mlir::ConversionPatternRewriter &rewriter) const override;
};

/// Conversion pattern for `cir::LabelOp`. It holds a reference to the shared
/// `LLVMBlockAddressInfo` so the lowering can hand out tag indices and
/// register the tag ops it creates.
class CIRToLLVMLabelOpLowering
    : public mlir::OpConversionPattern<cir::LabelOp> {
  // Shared label/block-address bookkeeping; not owned by the pattern.
  LLVMBlockAddressInfo &blockInfoAddr;

public:
  // NOTE(review): the inherited constructors cannot bind `blockInfoAddr`;
  // presumably only the three-argument constructor below is meant to be
  // used. Confirm before relying on this using-declaration.
  using mlir::OpConversionPattern<cir::LabelOp>::OpConversionPattern;

  CIRToLLVMLabelOpLowering(const mlir::TypeConverter &typeConverter,
                           mlir::MLIRContext *context,
                           LLVMBlockAddressInfo &blockInfoAddr)
      : mlir::OpConversionPattern<cir::LabelOp>(typeConverter, context),
        blockInfoAddr(blockInfoAddr) {}

  /// Rewrites one `cir::LabelOp`; defined in LowerToLLVM.cpp.
  mlir::LogicalResult
  matchAndRewrite(cir::LabelOp op, OpAdaptor,
                  mlir::ConversionPatternRewriter &) const override;
};

/// Conversion pattern for `cir::BlockAddressOp`. It holds a reference to the
/// shared `LLVMBlockAddressInfo` so the lowering can look up already emitted
/// tag ops and queue block addresses that cannot be resolved yet.
class CIRToLLVMBlockAddressOpLowering
    : public mlir::OpConversionPattern<cir::BlockAddressOp> {
  // Shared label/block-address bookkeeping; not owned by the pattern.
  LLVMBlockAddressInfo &blockInfoAddr;

public:
  // NOTE(review): the inherited constructors cannot bind `blockInfoAddr`;
  // presumably only the three-argument constructor below is meant to be
  // used. Confirm before relying on this using-declaration.
  using mlir::OpConversionPattern<cir::BlockAddressOp>::OpConversionPattern;

  CIRToLLVMBlockAddressOpLowering(const mlir::TypeConverter &typeConverter,
                                  mlir::MLIRContext *context,
                                  LLVMBlockAddressInfo &blockInfoAddr)
      : mlir::OpConversionPattern<cir::BlockAddressOp>(typeConverter,
                                                       context),
        blockInfoAddr(blockInfoAddr) {}

  /// Rewrites one `cir::BlockAddressOp`; defined in LowerToLLVM.cpp.
  mlir::LogicalResult
  matchAndRewrite(cir::BlockAddressOp op, OpAdaptor,
                  mlir::ConversionPatternRewriter &) const override;
};

mlir::ArrayAttr lowerCIRTBAAAttr(mlir::Attribute tbaa,
mlir::ConversionPatternRewriter &rewriter,
cir::LowerModule *lowerMod);
Expand Down
84 changes: 84 additions & 0 deletions clang/test/CIR/CodeGen/label-values.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t-cir.ll
// RUN: FileCheck --input-file=%t-cir.ll %s --check-prefix=LLVM
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=OGCG

void A(void) {
void *ptr = &&A;
Expand All @@ -14,6 +18,22 @@ void A(void) {
// CIR: ^bb1: // pred: ^bb0
// CIR: cir.label "A"
// CIR: cir.return
//
// LLVM: define dso_local void @A()
// LLVM: [[PTR:%.*]] = alloca ptr, i64 1, align 8
// LLVM: store ptr blockaddress(@A, %[[A:.*]]), ptr [[PTR]], align 8
// LLVM: br label %[[A]]
// LLVM: [[A]]: ; preds = %0
// LLVM: ret void

// OGCG: define dso_local void @A()
// OGCG: [[PTR:%.*]] = alloca ptr, align 8
// OGCG: store ptr blockaddress(@A, %A), ptr [[PTR]], align 8
// OGCG: br label %A
// OGCG: A: ; preds = %entry, %indirectgoto
// OGCG: ret void
// OGCG: indirectgoto: ; No predecessors!
// OGCG: indirectbr ptr poison, [label %A]

void B(void) {
B:
Expand All @@ -27,6 +47,22 @@ void B(void) {
// CIR: cir.store align(8) [[BLOCK]], [[PTR]] : !cir.ptr<!void>, !cir.ptr<!cir.ptr<!void>>
// CIR: cir.return

// LLVM: define dso_local void @B
// LLVM: br label %[[B:.*]]
// LLVM: [[B]]:
// LLVM: %[[PTR:.*]] = alloca ptr, i64 1, align 8
// LLVM: store ptr blockaddress(@B, %[[B]]), ptr %[[PTR]], align 8
// LLVM: ret void

// OGCG: define dso_local void @B
// OGCG: [[PTR:%.*]] = alloca ptr, align 8
// OGCG: br label %B
// OGCG: B: ; preds = %indirectgoto, %entry
// OGCG: store ptr blockaddress(@B, %B), ptr [[PTR]], align 8
// OGCG: ret void
// OGCG: indirectgoto: ; No predecessors!
// OGCG: indirectbr ptr poison, [label %B]

void C(int x) {
void *ptr = (x == 0) ? &&A : &&B;
A:
Expand All @@ -50,6 +86,30 @@ void C(int x) {
// CIR: cir.label "B"
// CIR: cir.br ^bb1

// LLVM: define dso_local void @C(i32 %0)
// LLVM: [[COND:%.*]] = select i1 [[CMP:%.*]], ptr blockaddress(@C, %[[A:.*]]), ptr blockaddress(@C, %[[B:.*]])
// LLVM: store ptr [[COND]], ptr [[PTR:%.*]], align 8
// LLVM: br label %[[A]]
// LLVM: [[RET:.*]]:
// LLVM: ret void
// LLVM: [[A]]:
// LLVM: br label %[[RET]]
// LLVM: [[B]]:
// LLVM: br label %[[RET]]

// OGCG: define dso_local void @C
// OGCG: [[COND:%.*]] = select i1 [[CMP:%.*]], ptr blockaddress(@C, %A), ptr blockaddress(@C, %B)
// OGCG: store ptr [[COND]], ptr [[PTR:%.*]], align 8
// OGCG: br label %A
// OGCG: A: ; preds = %entry, %indirectgoto
// OGCG: br label %return
// OGCG: B: ; preds = %indirectgoto
// OGCG: br label %return
// OGCG: return: ; preds = %B, %A
// OGCG: ret void
// OGCG: indirectgoto: ; No predecessors!
// OGCG: indirectbr ptr poison, [label %A, label %B]

void D(void) {
void *ptr = &&A;
void *ptr2 = &&A;
Expand All @@ -72,3 +132,27 @@ void D(void) {
// CIR: %[[BLK3:.*]] = cir.blockaddress(@D, "A") -> !cir.ptr<!void>
// CIR: cir.store align(8) %[[BLK3]], %[[PTR3]] : !cir.ptr<!void>, !cir.ptr<!cir.ptr<!void>>
// CIR: cir.return

// LLVM: define dso_local void @D
// LLVM: %[[PTR:.*]] = alloca ptr, i64 1, align 8
// LLVM: %[[PTR2:.*]] = alloca ptr, i64 1, align 8
// LLVM: %[[PTR3:.*]] = alloca ptr, i64 1, align 8
// LLVM: store ptr blockaddress(@D, %[[A:.*]]), ptr %[[PTR]], align 8
// LLVM: store ptr blockaddress(@D, %[[A]]), ptr %[[PTR2]], align 8
// LLVM: br label %[[A]]
// LLVM: [[A]]:
// LLVM: store ptr blockaddress(@D, %[[A]]), ptr %[[PTR3]], align 8
// LLVM: ret void

// OGCG: define dso_local void @D
// OGCG: %[[PTR:.*]] = alloca ptr, align 8
// OGCG: %[[PTR2:.*]] = alloca ptr, align 8
// OGCG: %[[PTR3:.*]] = alloca ptr, align 8
// OGCG: store ptr blockaddress(@D, %A), ptr %[[PTR]], align 8
// OGCG: store ptr blockaddress(@D, %A), ptr %[[PTR2]], align 8
// OGCG: br label %A
// OGCG: A:
// OGCG: store ptr blockaddress(@D, %A), ptr %[[PTR3]], align 8
// OGCG: ret void
// OGCG: indirectgoto:
// OGCG: indirectbr ptr poison, [label %A, label %A, label %A]
Loading