24 changes: 18 additions & 6 deletions paddle/fluid/framework/data_layout_transform.cc
@@ -14,7 +14,7 @@

#include "paddle/fluid/framework/data_layout_transform.h"

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
@@ -61,6 +61,18 @@ void TransDataLayout(const phi::KernelKey& kernel_type_for_var,
platform::errors::PreconditionNotMet(
"TransDataLayout only support DataLayout transform on same place."));

TransDataLayout(kernel_type_for_var.layout(),
expected_kernel_type.layout(),
place,
in,
out);
}

void TransDataLayout(DataLayout from_layout,
DataLayout to_layout,
phi::Place place,
const phi::DenseTensor& in,
phi::DenseTensor* out) {
PADDLE_ENFORCE_EQ(
arity(in.dims()),
4,
@@ -73,8 +85,7 @@ void TransDataLayout(const phi::KernelKey& kernel_type_for_var,
auto src_dim = in.dims();
std::vector<int64_t> dst_dim;

auto axis =
GetAxis(kernel_type_for_var.layout(), expected_kernel_type.layout());
auto axis = GetAxis(from_layout, to_layout);
dst_dim.resize(axis.size());
for (size_t i = 0; i < axis.size(); i++) {
dst_dim[i] = src_dim[axis[i]];
@@ -83,10 +94,11 @@ void TransDataLayout(const phi::KernelKey& kernel_type_for_var,
out->Resize(phi::make_ddim(dst_dim));
out->mutable_data(place, in.dtype());

framework::VisitDataType(framework::TransToProtoVarType(in.dtype()),
CastDataLayout(pool.Get(place), axis, in, out));
framework::VisitDataType(
static_cast<proto::VarType::Type>(phi::TransToProtoVarType(in.dtype())),
CastDataLayout(pool.Get(place), axis, in, out));

out->set_layout(expected_kernel_type.layout());
out->set_layout(to_layout);
}

} // namespace framework
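As a reference for reviewers, the new layout-only overload can be called directly with explicit layouts; a minimal usage sketch follows (tensor name and dims are illustrative, assuming a CPU build; this snippet is not part of the PR):

// Hypothetical caller: transpose a 4-D NCHW tensor to NHWC on CPU.
phi::DenseTensor nchw_in;
nchw_in.Resize(phi::make_ddim({1, 3, 224, 224}));  // N, C, H, W
nchw_in.mutable_data<float>(phi::CPUPlace());       // legacy allocation API, as used in this file

phi::DenseTensor nhwc_out;
paddle::framework::TransDataLayout(phi::DataLayout::kNCHW,
                                   phi::DataLayout::kNHWC,
                                   phi::CPUPlace{},
                                   nchw_in,
                                   &nhwc_out);
// nhwc_out.dims() is expected to be {1, 224, 224, 3} and its layout tag kNHWC.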
21 changes: 10 additions & 11 deletions paddle/fluid/framework/data_layout_transform.h
@@ -14,21 +14,14 @@

#pragma once

#include <map>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/kernels/funcs/data_layout_transform.h"

namespace paddle {
namespace framework {
class OpKernelType;
} // namespace framework
} // namespace paddle

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/phi/backends/onednn/onednn_helper.h"
#endif
@@ -60,5 +53,11 @@ void TransDataLayout(const phi::KernelKey& kernel_type_for_var,
phi::DenseTensor* out,
const phi::Place& place);

void TransDataLayout(phi::DataLayout from_layout,
phi::DataLayout to_layout,
phi::Place place,
const phi::DenseTensor& in,
phi::DenseTensor* out);

} // namespace framework
} // namespace paddle
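The new overload reuses the existing GetAxis helper in data_layout_transform.cc to build the dimension permutation. Purely as an illustration of that mapping (a sketch of the values involved, not the helper's actual code):

// dst_dim[i] = src_dim[axis[i]] in the loop in data_layout_transform.cc, so:
std::vector<int> nchw_to_nhwc = {0, 2, 3, 1};  // N,C,H,W -> N,H,W,C
std::vector<int> nhwc_to_nchw = {0, 3, 1, 2};  // N,H,W,C -> N,C,H,W
// Worked example: src_dim = {8, 64, 32, 32} (NCHW) with axis {0, 2, 3, 1}
// gives dst_dim = {8, 32, 32, 64} (NHWC).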
1 change: 1 addition & 0 deletions paddle/fluid/framework/data_layout_transform_test.cc
@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/data_layout_transform.h"

#include "gtest/gtest.h"
#include "paddle/fluid/platform/bfloat16.h"

TEST(DataTransform, DataLayoutFunction) {
auto place = paddle::platform::CPUPlace();
4 changes: 1 addition & 3 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -139,11 +139,9 @@ if(WITH_TENSORRT)
pass_library(layernorm_shift_partition_fuse_pass inference)
pass_library(reverse_roll_fuse_pass inference)
pass_library(preln_layernorm_x_fuse_pass inference)
pass_library(trt_support_nhwc_pass inference)
pass_library(elementwise_groupnorm_act_pass inference)
pass_library(preln_elementwise_groupnorm_act_pass inference)
endif()

if(WITH_TENSORRT)
pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference)
pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
endif()
72 changes: 21 additions & 51 deletions paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc
@@ -18,7 +18,6 @@
#include <unordered_map>
#include <unordered_set>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/phi/common/layout.h"
@@ -30,43 +29,11 @@ namespace framework {
namespace ir {
namespace {

void TransDataLayout(DataLayout from_layout,
DataLayout to_layout,
const phi::DenseTensor &in,
phi::DenseTensor *out) {
PADDLE_ENFORCE_EQ(
arity(in.dims()),
4,
platform::errors::InvalidArgument(
"Input dimension arity only can be 4, the input dimension is %s.",
in.dims()));

auto &pool = platform::DeviceContextPool::Instance();

auto src_dim = in.dims();
std::vector<int64_t> dst_dim;

auto axis = GetAxis(from_layout, to_layout);
dst_dim.resize(axis.size());
for (size_t i = 0; i < axis.size(); i++) {
dst_dim[i] = src_dim[axis[i]];
}

out->Resize(phi::make_ddim(dst_dim));
out->mutable_data(phi::CPUPlace(), in.dtype());

framework::VisitDataType(
framework::TransToProtoVarType(in.dtype()),
CastDataLayout(pool.Get(phi::CPUPlace()), axis, in, out));

out->set_layout(to_layout);
}

void InsertLayoutTransOp(ir::Graph *graph,
ir::Node *prev_node,
ir::Node *next_node,
DataLayout from_layout,
DataLayout to_layout,
phi::DataLayout from_layout,
phi::DataLayout to_layout,
framework::BlockDesc *block_desc,
std::unordered_map<ir::Node *, ir::Node *> *cache) {
auto do_insert = [&](const std::string &in_var_name,
@@ -91,7 +58,7 @@ void InsertLayoutTransOp(ir::Graph *graph,
op_out_var_desc->SetPersistable(false);
op_out_var_desc->SetDataType(prev_node->Var()->GetDataType());
auto to_shape = prev_node->Var()->GetShape();
if (from_layout == DataLayout::kNCHW) {
if (from_layout == phi::DataLayout::kNCHW) {
auto n = to_shape[0];
auto c = to_shape[1];
auto h = to_shape[2];
@@ -117,12 +84,13 @@ void InsertLayoutTransOp(ir::Graph *graph,
IR_NODE_UNLINK(prev_node, next_node);
};

if (from_layout == DataLayout::kNCHW && to_layout == DataLayout::kNHWC) {
if (from_layout == phi::DataLayout::kNCHW &&
to_layout == phi::DataLayout::kNHWC) {
auto in_var_name = prev_node->Var()->Name();
auto out_var_name = in_var_name + "_nchw_to_nhwc";
do_insert(in_var_name, out_var_name);
} else if (from_layout == DataLayout::kNHWC &&
to_layout == DataLayout::kNCHW) {
} else if (from_layout == phi::DataLayout::kNHWC &&
to_layout == phi::DataLayout::kNCHW) {
auto in_var_name = prev_node->Var()->Name();
auto out_var_name = in_var_name + "_nhwc_to_nchw";
do_insert(in_var_name, out_var_name);
@@ -135,7 +103,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph,
platform::errors::PreconditionNotMet("graph should not be nullptr."));
FusePassBase::Init("data_layout_transfer", graph);
FusePassBase::Init("conv2d_fusion_layout_transfer", graph);
auto *scope = param_scope();

// only float16 compute precision need insert transfer_layout.
@@ -170,7 +138,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {

// Not support multiple block now.
std::unordered_map<ir::Node *, ir::Node *> cache;
auto op_nodes = ir::TopologySortOperations(*graph);
auto op_nodes = TopologySortOperations(*graph);
auto iter = op_nodes.cbegin();
auto *block_desc = (*iter)->Op()->Block();

@@ -186,7 +154,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
op_node->Op()->GetAttrIfExists<std::string>("data_format");
if (data_format != "NCHW") return false;
auto filter_names = op_node->Op()->Input("Filter");
constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
constexpr int NHWC_ALIGNMENT = 8;
// If filter's channel is not multiple of 8, conv2d_fusion not run at nhwc.
for (const auto &filter_name : filter_names) {
auto *filter_var = scope->FindLocalVar(filter_name);
@@ -195,7 +163,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
int oc = filter_tensor.dims()[0];
int ic = filter_tensor.dims()[1];
bool cutlass_can_support =
oc % CUTLASS_NHWC_ALIGNMENT == 0 && ic % CUTLASS_NHWC_ALIGNMENT == 0;
oc % NHWC_ALIGNMENT == 0 && ic % NHWC_ALIGNMENT == 0;
if (!cutlass_can_support) {
return false;
}
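The renamed NHWC_ALIGNMENT constant drives the eligibility test above; written out as a standalone predicate for clarity (a sketch with an illustrative function name, not code from this PR):

// Both output channels (oc) and input channels (ic) of the filter must be
// multiples of 8 for conv2d_fusion to run in NHWC via the cutlass kernel.
constexpr int NHWC_ALIGNMENT = 8;
inline bool ChannelsAlignedForNhwc(int oc, int ic) {
  return oc % NHWC_ALIGNMENT == 0 && ic % NHWC_ALIGNMENT == 0;
}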
@@ -229,8 +197,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
if (cuDNNIsValid(op_node)) {
valid_ops.insert(op_node);
auto *op_desc = op_node->Op();
auto nhwc_attr = framework::Attribute(std::string("NHWC"));
op_desc->SetAttr("data_format", nhwc_attr);
op_desc->SetAttr("data_format", std::string{"NHWC"});
if (cutlass_enable && CutlassIsValid(op_node)) {
op_desc->SetType("conv2d_fusion_cutlass");
}
@@ -244,8 +211,11 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
phi::DenseTensor temp_tensor = *filter_tensor;
filter_tensor->clear();

TransDataLayout(
DataLayout::kNCHW, DataLayout::kNHWC, temp_tensor, filter_tensor);
framework::TransDataLayout(phi::DataLayout::kNCHW,
phi::DataLayout::kNHWC,
phi::CPUPlace{},
temp_tensor,
filter_tensor);
}
auto op_inputs = op_node->inputs;
for (auto *in_var_node : op_inputs) {
Expand Down Expand Up @@ -290,8 +260,8 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
InsertLayoutTransOp(graph,
in_var_node,
op_node,
DataLayout::kNCHW,
DataLayout::kNHWC,
phi::DataLayout::kNCHW,
phi::DataLayout::kNHWC,
block_desc,
&cache);
}
Expand All @@ -304,8 +274,8 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
InsertLayoutTransOp(graph,
in_var_node,
op_node,
DataLayout::kNHWC,
DataLayout::kNCHW,
phi::DataLayout::kNHWC,
phi::DataLayout::kNCHW,
block_desc,
&cache);
}
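Taken together, the pass now delegates the filter transposition to the shared framework::TransDataLayout overload instead of the local helper it previously carried. A condensed sketch of that per-filter sequence (function and variable names are illustrative, not taken from the pass):

#include "paddle/fluid/framework/data_layout_transform.h"

// Mirror of the clear-then-transpose sequence applied to each conv2d_fusion
// filter tensor held in the scope.
void FilterNchwToNhwc(phi::DenseTensor *filter_tensor) {
  phi::DenseTensor temp_tensor = *filter_tensor;  // keep a copy of the NCHW weights
  filter_tensor->clear();                         // the variable is refilled in NHWC
  paddle::framework::TransDataLayout(phi::DataLayout::kNCHW,
                                     phi::DataLayout::kNHWC,
                                     phi::CPUPlace{},
                                     temp_tensor,
                                     filter_tensor);
}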