Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -428,12 +428,57 @@ bool ReadFileOpInferSymbolicShape(
return true;
}

// bool RecvV2OpInferSymbolicShape(pir::Operation *op,
// pir::InferSymbolicShapeContext
// *infer_context) {
// // pass
// return true;
// }
bool RecvV2OpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
const int ring_id = op->attribute<pir::Int32Attribute>("ring_id").data();
const bool dynamic_shape =
op->attribute<pir::BoolAttribute>("dynamic_shape").data();
const int peer = op->attribute<pir::Int32Attribute>("peer").data();

PADDLE_ENFORCE_GE(
peer,
0,
common::errors::InvalidArgument(
"The peer (%d) for recv_v2 op must be non-negative.", peer));

PADDLE_ENFORCE_GE(
ring_id,
0,
common::errors::InvalidArgument(
"The ring_id (%d) for recv_v2 op must be non-negative.", ring_id));

const std::vector<int> out_shape =
paddle::dialect::details::GetVectorAttr<int>(op, "out_shape");
if (!dynamic_shape) {
PADDLE_ENFORCE_GE(out_shape.size(),
1,
common::errors::InvalidArgument(
"The size of the output shape must be greater than 0 "
"but the value given is %d.",
out_shape.size()));

std::vector<symbol::DimExpr> output_shape;
for (size_t i = 0; i < out_shape.size(); ++i) {
PADDLE_ENFORCE_GE(out_shape[i],
1,
common::errors::InvalidArgument(
"The shape attribute for recv_v2 must be set "
"explicitly, but the %dth element is %d which "
"is less than 1. Or dynamic_shape should be set to "
"True for both send_v2 and recv_v2.",
i,
out_shape[i]));
output_shape.push_back(symbol::DimExpr(out_shape[i]));
}

infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(output_shape)});
}

return true;
}

bool SeedOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Randint)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Randperm)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReadFile)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Seed)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(RecvV2)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(RecvV2)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(TrilIndices)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(TriuIndices)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(TruncatedGaussianRandom)
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/inconsistent/static_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,7 @@
func : recv_v2
param : [ring_id, dynamic_shape, peer, out_shape, dtype, use_calc_stream]
data_type : dtype
# interfaces : paddle::dialect::InferSymbolicShapeInterface
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : remainder
args : (Tensor x, Tensor y)
Expand Down