Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ cc_library(
deps = [
":ir_visitor",
":module",
":storage",
"//src/common/utils:fold",
"//src/common/utils:types",
],
Expand All @@ -146,6 +147,8 @@ cc_test(
":ir_traversing_visitor",
":module",
"//src/common/testing:gtest",
"//src/ir/types",
"@com_google_absl//absl/types:span",
],
)

Expand Down Expand Up @@ -244,6 +247,7 @@ cc_library(
name = "storage",
hdrs = ["storage.h"],
deps = [
":ir_visitor",
":value",
"//src/ir/types",
"@com_google_absl//absl/container:flat_hash_set",
Expand Down
25 changes: 24 additions & 1 deletion src/ir/ir_traversing_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
#ifndef SRC_IR_IR_TRAVERSING_VISITOR_H_
#define SRC_IR_IR_TRAVERSING_VISITOR_H_

#include <memory>

#include "src/common/utils/fold.h"
#include "src/common/utils/types.h"
#include "src/ir/ir_visitor.h"
#include "src/ir/module.h"
#include "src/ir/storage.h"

namespace raksha::ir {

Expand Down Expand Up @@ -84,10 +87,25 @@ class IRTraversingVisitor
return in_order_result;
}

virtual Result PreVisit(const Storage& storage) { return GetDefaultValue(); }

virtual Result PostVisit(const Storage& storage, Result in_order_result) {
return in_order_result;
}

Result Visit(const Module& module) final override {
Result pre_visit_result = PreVisit(module);
// Visit storages before blocks
Result storage_fold_result = common::utils::fold(
module.named_storage_map(), std::move(pre_visit_result),
[this](Result acc,
const std::pair<const std::string, std::unique_ptr<Storage>>&
name_and_storage) {
return FoldResult(std::move(acc),
name_and_storage.second->Accept(*this));
});
Result fold_result = common::utils::fold(
module.blocks(), std::move(pre_visit_result),
module.blocks(), std::move(storage_fold_result),
[this](Result acc, const std::unique_ptr<Block>& block) {
return FoldResult(std::move(acc), block->Accept(*this));
});
Expand All @@ -108,6 +126,11 @@ class IRTraversingVisitor
Result pre_visit_result = PreVisit(operation);
return PostVisit(operation, std::move(pre_visit_result));
}

Result Visit(const Storage& storage) final override {
Result pre_visit_result = PreVisit(storage);
return PostVisit(storage, std::move(pre_visit_result));
}
};

} // namespace raksha::ir
Expand Down
194 changes: 138 additions & 56 deletions src/ir/ir_traversing_visitor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,20 @@
#include "src/ir/ir_traversing_visitor.h"

#include <memory>
#include <vector>

#include "absl/types/span.h"
#include "src/common/testing/gtest.h"
#include "src/ir/block_builder.h"
#include "src/ir/module.h"
#include "src/ir/types/type_factory.h"

namespace raksha::ir {
namespace {

using testing::ElementsAre;
using testing::UnorderedElementsAre;

class WrappedInt {
public:
WrappedInt() = delete;
Expand Down Expand Up @@ -62,9 +68,13 @@ TEST(IRTraversingVisitorTest,
global_module.AddBlock(std::make_unique<Block>());
global_module.AddBlock(std::make_unique<Block>());

types::TypeFactory type_factory;
global_module.CreateStorage("storage1", type_factory.MakePrimitiveType());
global_module.CreateStorage("storage2", type_factory.MakePrimitiveType());

NonDefaultConstructorVisitor counting_visitor;
const WrappedInt result = global_module.Accept(counting_visitor);
EXPECT_EQ(*result, 3);
EXPECT_EQ(*result, 5);
}

enum class TraversalType { kPre = 0x1, kPost = 0x2, kBoth = 0x3 };
Expand Down Expand Up @@ -107,6 +117,16 @@ class CollectingVisitor : public IRTraversingVisitor<CollectingVisitor> {
return result;
}

Unit PreVisit(const Storage& storage) override {
if (pre_visits_) nodes_.push_back(std::addressof(storage));
return Unit();
}

Unit PostVisit(const Storage& storage, Unit result) override {
if (post_visits_) nodes_.push_back(std::addressof(storage));
return result;
}

const std::vector<const void*>& nodes() const { return nodes_; }

private:
Expand All @@ -117,29 +137,56 @@ class CollectingVisitor : public IRTraversingVisitor<CollectingVisitor> {

TEST(IRTraversingVisitorTest, TraversesModuleAsExpected) {
Module global_module;
types::TypeFactory type_factory;
const Storage* storage1 = std::addressof(global_module.CreateStorage(
"storage1", type_factory.MakePrimitiveType()));
const Storage* storage2 = std::addressof(global_module.CreateStorage(
"storage2", type_factory.MakePrimitiveType()));
const Block* block1 =
std::addressof(global_module.AddBlock(std::make_unique<Block>()));
const Block* block2 =
std::addressof(global_module.AddBlock(std::make_unique<Block>()));

CollectingVisitor preorder_visitor(TraversalType::kPre);
global_module.Accept(preorder_visitor);
EXPECT_THAT(
preorder_visitor.nodes(),
testing::ElementsAre(std::addressof(global_module), block1, block2));
{
CollectingVisitor preorder_visitor(TraversalType::kPre);
global_module.Accept(preorder_visitor);

CollectingVisitor postorder_visitor(TraversalType::kPost);
global_module.Accept(postorder_visitor);
EXPECT_THAT(
postorder_visitor.nodes(),
testing::ElementsAre(block1, block2, std::addressof(global_module)));
absl::Span<const void* const> nodes = preorder_visitor.nodes();
EXPECT_EQ(nodes.size(), 5);
EXPECT_EQ(nodes.at(0), std::addressof(global_module));
EXPECT_THAT(nodes.subspan(1, 2), UnorderedElementsAre(storage1, storage2));
EXPECT_THAT(nodes.subspan(3, 2), ElementsAre(block1, block2));
}

CollectingVisitor all_order_visitor(TraversalType::kBoth);
global_module.Accept(all_order_visitor);
EXPECT_THAT(
all_order_visitor.nodes(),
testing::ElementsAre(std::addressof(global_module), block1, block1,
block2, block2, std::addressof(global_module)));
{
CollectingVisitor postorder_visitor(TraversalType::kPost);
global_module.Accept(postorder_visitor);

absl::Span<const void* const> nodes = postorder_visitor.nodes();
EXPECT_EQ(nodes.size(), 5);
EXPECT_THAT(nodes.subspan(0, 2), UnorderedElementsAre(storage1, storage2));
EXPECT_THAT(nodes.subspan(2, 2), ElementsAre(block1, block2));
EXPECT_EQ(nodes.at(4), std::addressof(global_module));
}

{
CollectingVisitor all_order_visitor(TraversalType::kBoth);
global_module.Accept(all_order_visitor);

absl::Span<const void* const> nodes = all_order_visitor.nodes();
EXPECT_EQ(nodes.size(), 10);
EXPECT_EQ(nodes.at(0), std::addressof(global_module));
EXPECT_EQ(nodes.at(9), std::addressof(global_module));
std::vector<const void*> deduped_nodes_vec;
for (auto iter = nodes.begin() + 1; iter + 1 != nodes.end(); iter += 2) {
EXPECT_EQ(*iter, *(iter + 1));
deduped_nodes_vec.push_back(*iter);
}
absl::Span<const void* const> deduped_nodes_span = deduped_nodes_vec;
EXPECT_THAT(deduped_nodes_span.subspan(0, 2),
UnorderedElementsAre(storage1, storage2));
EXPECT_THAT(deduped_nodes_span.subspan(2, 2), ElementsAre(block1, block2));
}
}

TEST(IRTraversingVisitorTest, TraversesBlockAsExpected) {
Expand All @@ -154,22 +201,19 @@ TEST(IRTraversingVisitorTest, TraversesBlockAsExpected) {

CollectingVisitor preorder_visitor(TraversalType::kPre);
block->Accept(preorder_visitor);
EXPECT_THAT(
preorder_visitor.nodes(),
testing::ElementsAre(block.get(), plus_op_instance, minus_op_instance));
EXPECT_THAT(preorder_visitor.nodes(),
ElementsAre(block.get(), plus_op_instance, minus_op_instance));

CollectingVisitor postorder_visitor(TraversalType::kPost);
block->Accept(postorder_visitor);
EXPECT_THAT(
postorder_visitor.nodes(),
testing::ElementsAre(plus_op_instance, minus_op_instance, block.get()));
EXPECT_THAT(postorder_visitor.nodes(),
ElementsAre(plus_op_instance, minus_op_instance, block.get()));

CollectingVisitor all_order_visitor(TraversalType::kBoth);
block->Accept(all_order_visitor);
EXPECT_THAT(
all_order_visitor.nodes(),
testing::ElementsAre(block.get(), plus_op_instance, plus_op_instance,
minus_op_instance, minus_op_instance, block.get()));
EXPECT_THAT(all_order_visitor.nodes(),
ElementsAre(block.get(), plus_op_instance, plus_op_instance,
minus_op_instance, minus_op_instance, block.get()));
}

TEST(IRTraversingVisitorTest, TraversesOperationAsExpected) {
Expand All @@ -182,18 +226,18 @@ TEST(IRTraversingVisitorTest, TraversesOperationAsExpected) {
CollectingVisitor preorder_visitor(TraversalType::kPre);
plus_op_instance.Accept(preorder_visitor);
EXPECT_THAT(preorder_visitor.nodes(),
testing::ElementsAre(std::addressof(plus_op_instance)));
ElementsAre(std::addressof(plus_op_instance)));

CollectingVisitor postorder_visitor(TraversalType::kPost);
plus_op_instance.Accept(postorder_visitor);
EXPECT_THAT(postorder_visitor.nodes(),
testing::ElementsAre(std::addressof(plus_op_instance)));
ElementsAre(std::addressof(plus_op_instance)));

CollectingVisitor all_order_visitor(TraversalType::kBoth);
plus_op_instance.Accept(all_order_visitor);
EXPECT_THAT(all_order_visitor.nodes(),
testing::ElementsAre(std::addressof(plus_op_instance),
std::addressof(plus_op_instance)));
ElementsAre(std::addressof(plus_op_instance),
std::addressof(plus_op_instance)));
}

using ResultType = std::vector<void*>;
Expand Down Expand Up @@ -246,6 +290,17 @@ class ReturningVisitor
return result;
}

ResultType PreVisit(const Storage& storage) override {
ResultType result;
if (pre_visits_) result.push_back((void*)std::addressof(storage));
return result;
}

ResultType PostVisit(const Storage& storage, ResultType result) override {
if (post_visits_) result.push_back((void*)std::addressof(storage));
return result;
}

private:
const bool pre_visits_;
const bool post_visits_;
Expand All @@ -258,22 +313,49 @@ TEST(IRTraversingVisitorTest, TraversesModuleAsExpectedUsingReturns) {
const Block* block2 =
std::addressof(global_module.AddBlock(std::make_unique<Block>()));

ReturningVisitor preorder_visitor(TraversalType::kPre);
ResultType nodes1 =
global_module.Accept<ReturningVisitor, ResultType>(preorder_visitor);
EXPECT_THAT(nodes1, testing::ElementsAre(std::addressof(global_module),
block1, block2));
types::TypeFactory type_factory;
const Storage* storage1 = std::addressof(global_module.CreateStorage(
"storage1", type_factory.MakePrimitiveType()));
const Storage* storage2 = std::addressof(global_module.CreateStorage(
"storage2", type_factory.MakePrimitiveType()));

{
ReturningVisitor preorder_visitor(TraversalType::kPre);
std::vector<void*> nodes_vec = global_module.Accept(preorder_visitor);
absl::Span<const void* const> nodes = nodes_vec;
EXPECT_EQ(nodes.size(), 5);
EXPECT_EQ(nodes.at(0), std::addressof(global_module));
EXPECT_THAT(nodes.subspan(1, 2), UnorderedElementsAre(storage1, storage2));
EXPECT_THAT(nodes.subspan(3, 2), ElementsAre(block1, block2));
}

ReturningVisitor postorder_visitor(TraversalType::kPost);
ResultType nodes2 = global_module.Accept(postorder_visitor);
EXPECT_THAT(nodes2, testing::ElementsAre(block1, block2,
std::addressof(global_module)));
{
ReturningVisitor postorder_visitor(TraversalType::kPost);
std::vector<void*> nodes_vec = global_module.Accept(postorder_visitor);
absl::Span<const void* const> nodes = nodes_vec;
EXPECT_EQ(nodes.size(), 5);
EXPECT_THAT(nodes.subspan(0, 2), UnorderedElementsAre(storage1, storage2));
EXPECT_THAT(nodes.subspan(2, 2), ElementsAre(block1, block2));
EXPECT_EQ(nodes.at(4), std::addressof(global_module));
}

ReturningVisitor all_order_visitor(TraversalType::kBoth);
ResultType nodes3 = global_module.Accept(all_order_visitor);
EXPECT_THAT(nodes3, testing::ElementsAre(std::addressof(global_module),
block1, block1, block2, block2,
std::addressof(global_module)));
{
ReturningVisitor all_order_visitor(TraversalType::kBoth);
std::vector<void*> nodes_vec = global_module.Accept(all_order_visitor);
absl::Span<const void* const> nodes = nodes_vec;
EXPECT_EQ(nodes.size(), 10);
EXPECT_EQ(nodes.at(0), std::addressof(global_module));
EXPECT_EQ(nodes.at(9), std::addressof(global_module));
std::vector<const void*> deduped_nodes_vec;
for (auto iter = nodes.begin() + 1; iter + 1 != nodes.end(); iter += 2) {
EXPECT_EQ(*iter, *(iter + 1));
deduped_nodes_vec.push_back(*iter);
}
absl::Span<const void* const> deduped_nodes_span = deduped_nodes_vec;
EXPECT_THAT(deduped_nodes_span.subspan(0, 2),
UnorderedElementsAre(storage1, storage2));
EXPECT_THAT(deduped_nodes_span.subspan(2, 2), ElementsAre(block1, block2));
}
}

TEST(IRTraversingVisitorTest, TraversesBlockAsExpectedUsingReturns) {
Expand All @@ -289,19 +371,19 @@ TEST(IRTraversingVisitorTest, TraversesBlockAsExpectedUsingReturns) {
ReturningVisitor preorder_visitor(TraversalType::kPre);
ResultType nodes1 =
block->Accept<ReturningVisitor, ResultType>(preorder_visitor);
EXPECT_THAT(nodes1, testing::ElementsAre(block.get(), plus_op_instance,
minus_op_instance));
EXPECT_THAT(nodes1,
ElementsAre(block.get(), plus_op_instance, minus_op_instance));

ReturningVisitor postorder_visitor(TraversalType::kPost);
ResultType nodes2 = block->Accept(postorder_visitor);
EXPECT_THAT(nodes2, testing::ElementsAre(plus_op_instance, minus_op_instance,
block.get()));
EXPECT_THAT(nodes2,
ElementsAre(plus_op_instance, minus_op_instance, block.get()));

ReturningVisitor all_order_visitor(TraversalType::kBoth);
ResultType nodes3 = block->Accept(all_order_visitor);
EXPECT_THAT(nodes3, testing::ElementsAre(block.get(), plus_op_instance,
plus_op_instance, minus_op_instance,
minus_op_instance, block.get()));
EXPECT_THAT(nodes3,
ElementsAre(block.get(), plus_op_instance, plus_op_instance,
minus_op_instance, minus_op_instance, block.get()));
}

TEST(IRTraversingVisitorTest, TraversesOperationAsExpectedUsingReturns) {
Expand All @@ -313,16 +395,16 @@ TEST(IRTraversingVisitorTest, TraversesOperationAsExpectedUsingReturns) {

ReturningVisitor preorder_visitor(TraversalType::kPre);
ResultType nodes1 = plus_op_instance.Accept(preorder_visitor);
EXPECT_THAT(nodes1, testing::ElementsAre(std::addressof(plus_op_instance)));
EXPECT_THAT(nodes1, ElementsAre(std::addressof(plus_op_instance)));

ReturningVisitor postorder_visitor(TraversalType::kPost);
ResultType nodes2 = plus_op_instance.Accept(postorder_visitor);
EXPECT_THAT(nodes2, testing::ElementsAre(std::addressof(plus_op_instance)));
EXPECT_THAT(nodes2, ElementsAre(std::addressof(plus_op_instance)));

ReturningVisitor all_order_visitor(TraversalType::kBoth);
ResultType nodes3 = plus_op_instance.Accept(all_order_visitor);
EXPECT_THAT(nodes3, testing::ElementsAre(std::addressof(plus_op_instance),
std::addressof(plus_op_instance)));
EXPECT_THAT(nodes3, ElementsAre(std::addressof(plus_op_instance),
std::addressof(plus_op_instance)));
}

} // namespace
Expand Down
Loading