Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,12 @@ if(NOT CINN_ONLY)
op_dialect.cc
${cinn_op_source_file}
${cinn_op_info_file}
generate_shape_util.cc
manual_op.cc
op_attribute.cc
DEPS
op_dialect_vjp)
op_dialect_vjp
pir)

target_include_directories(cinn_op_dialect PRIVATE ${CINN_DIALECT_SOURCE_DIR})
endif()
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/pir/dialect/shape/utils/dim_expr_util.h"
#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h"
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/builtin_attribute.h"

namespace symbol {
namespace cinn::dialect {
using namespace symbol; // NOLINT

namespace {

Expand Down Expand Up @@ -58,71 +59,71 @@ std::string GetSerializedTag<Broadcast<DimExpr>>() {
return "Broadcast";
}

::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder,
::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
const std::int64_t& dim_expr) {
return builder->int64_attr(dim_expr);
return pir::Int64Attribute::get(ctx, dim_expr);
}

::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder,
::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
const std::string& dim_expr) {
return builder->str_attr(dim_expr);
return pir::StrAttribute::get(ctx, dim_expr);
}

template <typename T>
::pir::Attribute ConvertUnaryDimExprToAttributeImpl(::pir::Builder* builder,
::pir::Attribute ConvertUnaryDimExprToAttributeImpl(::pir::IrContext* ctx,
const T& dim_expr) {
std::vector<::pir::Attribute> attr_vecs{};
attr_vecs.push_back(builder->str_attr(GetSerializedTag<T>()));
attr_vecs.push_back(pir::StrAttribute::get(ctx, GetSerializedTag<T>()));
const auto& operand = dim_expr->data;
attr_vecs.push_back(ConvertDimExprToAttribute(builder, operand));
return builder->array_attr(attr_vecs);
attr_vecs.push_back(ConvertDimExprToAttribute(ctx, operand));
return pir::ArrayAttribute::get(ctx, attr_vecs);
}

::pir::Attribute ConvertDimExprToAttributeImpl(
::pir::Builder* builder, const Negative<DimExpr>& dim_expr) {
return ConvertUnaryDimExprToAttributeImpl(builder, dim_expr);
::pir::IrContext* ctx, const Negative<DimExpr>& dim_expr) {
return ConvertUnaryDimExprToAttributeImpl(ctx, dim_expr);
}

::pir::Attribute ConvertDimExprToAttributeImpl(
::pir::Builder* builder, const Reciprocal<DimExpr>& dim_expr) {
return ConvertUnaryDimExprToAttributeImpl(builder, dim_expr);
::pir::IrContext* ctx, const Reciprocal<DimExpr>& dim_expr) {
return ConvertUnaryDimExprToAttributeImpl(ctx, dim_expr);
}

template <typename T>
::pir::Attribute ConvertVariadicDimExprToAttribute(::pir::Builder* builder,
::pir::Attribute ConvertVariadicDimExprToAttribute(::pir::IrContext* ctx,
const T& dim_expr) {
std::vector<::pir::Attribute> attr_vecs{};
attr_vecs.push_back(builder->str_attr(GetSerializedTag<T>()));
attr_vecs.push_back(pir::StrAttribute::get(ctx, GetSerializedTag<T>()));
const auto& operands = *(dim_expr.operands);
for (const auto& operand : operands) {
attr_vecs.push_back(ConvertDimExprToAttribute(builder, operand));
attr_vecs.push_back(ConvertDimExprToAttribute(ctx, operand));
}
return builder->array_attr(attr_vecs);
return pir::ArrayAttribute::get(ctx, attr_vecs);
}

::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder,
::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
const Add<DimExpr>& dim_expr) {
return ConvertVariadicDimExprToAttribute(builder, dim_expr);
return ConvertVariadicDimExprToAttribute(ctx, dim_expr);
}

::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder,
::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
const Mul<DimExpr>& dim_expr) {
return ConvertVariadicDimExprToAttribute(builder, dim_expr);
return ConvertVariadicDimExprToAttribute(ctx, dim_expr);
}

::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder,
::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
const Max<DimExpr>& dim_expr) {
return ConvertVariadicDimExprToAttribute(builder, dim_expr);
return ConvertVariadicDimExprToAttribute(ctx, dim_expr);
}

::pir::Attribute ConvertDimExprToAttributeImpl(::pir::Builder* builder,
::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
const Min<DimExpr>& dim_expr) {
return ConvertVariadicDimExprToAttribute(builder, dim_expr);
return ConvertVariadicDimExprToAttribute(ctx, dim_expr);
}

::pir::Attribute ConvertDimExprToAttributeImpl(
::pir::Builder* builder, const Broadcast<DimExpr>& dim_expr) {
return ConvertVariadicDimExprToAttribute(builder, dim_expr);
::pir::IrContext* ctx, const Broadcast<DimExpr>& dim_expr) {
return ConvertVariadicDimExprToAttribute(ctx, dim_expr);
}

std::optional<DimExpr> ConvertInt64AttributeToDimExpr(
Expand Down Expand Up @@ -211,11 +212,11 @@ std::optional<DimExpr> ConvertArrayAttributeToDimExpr(

} // namespace

::pir::Attribute ConvertDimExprToAttribute(::pir::Builder* builder,
::pir::Attribute ConvertDimExprToAttribute(pir::IrContext* ctx,
const DimExpr& dim_expr) {
return std::visit(
[&](const auto& impl) {
return ConvertDimExprToAttributeImpl(builder, impl);
return ConvertDimExprToAttributeImpl(ctx, impl);
},
dim_expr.variant());
}
Expand Down Expand Up @@ -359,4 +360,66 @@ MakeGetterDimExpr4SymbolName(
};
}

} // namespace symbol
namespace {

std::optional<DimExpr> GetDimExprBySymbolBindingImpl(
const GenerateShapeOp::DataSymbolBinding& symbol_binding,
const std::function<const symbol::ShapeOrDataDimExprs&(int in_tensor_idx)>&
DimExpr4InputDim) {
const symbol::ShapeOrDataDimExprs& shape_or_data_dim_expr =
DimExpr4InputDim(symbol_binding.input_tensor_idx);
if (!shape_or_data_dim_expr.data().has_value()) return std::nullopt;
int dim_idx = symbol_binding.input_tensor_dim_idx;
if (dim_idx >= shape_or_data_dim_expr.data().value().size())
return std::nullopt;
return shape_or_data_dim_expr.data().value().at(dim_idx);
}

std::optional<DimExpr> GetDimExprBySymbolBindingImpl(
const GenerateShapeOp::ShapeSymbolBinding& symbol_binding,
const std::function<const symbol::ShapeOrDataDimExprs&(int in_tensor_idx)>&
DimExpr4InputDim) {
const symbol::ShapeOrDataDimExprs& shape_or_data_dim_expr =
DimExpr4InputDim(symbol_binding.input_tensor_idx);
int dim_idx = symbol_binding.input_tensor_dim_idx;
if (dim_idx >= shape_or_data_dim_expr.shape().size()) return std::nullopt;
return shape_or_data_dim_expr.shape().at(dim_idx);
}

} // namespace

std::function<std::optional<DimExpr>(const std::string& symbol_name)>
MakeGetterDimExpr4SymbolName(
const GenerateShapeOp::SymbolBindings& symbol_bindings,
const std::function<const symbol::ShapeOrDataDimExprs&(int in_tensor_idx)>&
DimExpr4InputDim) {
std::unordered_map<std::string, std::vector<GenerateShapeOp::SymbolBinding>>
symbol_name2symbol_bindins{};
const auto& GetDimExpr =
[&](const GenerateShapeOp::SymbolBinding& symbol_binding) {
return std::visit(
[&](const auto& impl) {
return GetDimExprBySymbolBindingImpl(impl, DimExpr4InputDim);
},
symbol_binding);
};
return [map = std::move(symbol_name2symbol_bindins), GetDimExpr](
const std::string& symbol_name) -> std::optional<DimExpr> {
const auto& iter = map.find(symbol_name);
if (iter == map.end()) return std::nullopt;
std::optional<DimExpr> ret = std::nullopt;
for (const auto& symbol_binding : iter->second) {
const auto& current = GetDimExpr(symbol_binding);
if (!current.has_value()) return std::nullopt;
if (ret.has_value()) {
// Same names, same DimExprs.
if (ret.value() != current.value()) return std::nullopt;
} else {
ret = current;
}
}
return ret;
};
}

} // namespace cinn::dialect
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,35 @@
#pragma once

#include <optional>
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/dll_decl.h"
#include "paddle/pir/dialect/shape/utils/dim_expr.h"

namespace symbol {
namespace cinn::dialect {

IR_API ::pir::Attribute ConvertDimExprToAttribute(::pir::Builder* builder,
const DimExpr& dim_expr);
IR_API std::optional<DimExpr> ConvertAttributeToDimExpr(
::pir::Attribute ConvertDimExprToAttribute(pir::IrContext* ctx,
const symbol::DimExpr& dim_expr);

std::optional<symbol::DimExpr> ConvertAttributeToDimExpr(
::pir::Attribute attribute);

IR_API std::optional<DimExpr> SubstituteDimExpr(
const DimExpr& dim_expr,
const std::function<std::optional<DimExpr>(const std::string& symbol_name)>&
DimExpr4SymbolName);
std::optional<symbol::DimExpr> SubstituteDimExpr(
const symbol::DimExpr& dim_expr,
const std::function<std::optional<symbol::DimExpr>(
const std::string& symbol_name)>& DimExpr4SymbolName);

IR_API std::function<std::optional<DimExpr>(const std::string& symbol_name)>
std::function<std::optional<symbol::DimExpr>(const std::string& symbol_name)>
MakeGetterDimExpr4SymbolName(
const std::vector<std::tuple<std::string /*symbol_name*/,
int /*in_tensor_idx*/,
int /*in_tensor_dim_idx*/>>& symbol_bindings,
const std::function<std::optional<DimExpr>(
const std::function<std::optional<symbol::DimExpr>(
int in_tensor_idx, int in_tensor_dim_idx)>& DimExpr4InputDim);

} // namespace symbol
std::function<std::optional<symbol::DimExpr>(const std::string& symbol_name)>
MakeGetterDimExpr4SymbolName(
const GenerateShapeOp::SymbolBindings& symbol_bindings,
const std::function<const symbol::ShapeOrDataDimExprs&(int in_tensor_idx)>&
DimExpr4InputDim);

} // namespace cinn::dialect
Loading