Skip to content

Commit cd7c853

Browse files
authored
dist2dense_pass: fix shape errors when sharding randomly sampled data (#68067)
* dist2dense_pass: fix shape errors when sharding randomly sampled data * add a unit test case * fix coverage issues
1 parent 754516d commit cd7c853

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed

paddle/fluid/pir/dialect/distributed/transforms/dist_to_dense_pass.cc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,35 @@ void ProcessDistBlock(pir::Block* block) {
103103
new_dims.push_back(pir::Int64Attribute::get(ctx, local_dims[index]));
104104
}
105105
prev_op->set_attribute("value", pir::ArrayAttribute::get(ctx, new_dims));
106+
} else if (op_item->isa<RandintOp>() || op_item->isa<GaussianOp>() ||
107+
op_item->isa<UniformOp>()) {
108+
auto local_dims =
109+
op_item->result_type(0).dyn_cast<pir::DenseTensorType>().dims();
110+
auto shape_value = op_item->operand_source(0);
111+
auto prev_op = shape_value.defining_op();
112+
PADDLE_ENFORCE((prev_op != nullptr),
113+
common::errors::PreconditionNotMet(
114+
"The shape of randint, gaussian and uniform mush be "
115+
"the result of "
116+
"FullIntArrayOp, not null"));
117+
PADDLE_ENFORCE_EQ(
118+
prev_op->isa<FullIntArrayOp>(),
119+
true,
120+
common::errors::PreconditionNotMet(
121+
"The shape of randint, gaussian and uniform mush be the result "
122+
"of FullIntArrayOp."));
123+
auto array_attr = prev_op->attribute<pir::ArrayAttribute>("value");
124+
PADDLE_ENFORCE_EQ(
125+
array_attr.size(),
126+
local_dims.size(),
127+
common::errors::PreconditionNotMet(
128+
"The shape of randint, gaussian and uniform element's size must "
129+
"equal to result's dim size."));
130+
std::vector<pir::Attribute> new_dims;
131+
for (int index = 0; index < local_dims.size(); ++index) {
132+
new_dims.push_back(pir::Int64Attribute::get(ctx, local_dims[index]));
133+
}
134+
prev_op->set_attribute("value", pir::ArrayAttribute::get(ctx, new_dims));
106135
}
107136
// TODO(2024-Q2) not all op are dist type
108137
// PADDLE_ENFORCE_EQ(

test/auto_parallel/pir/test_static_pir_program.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import paddle
1818
import paddle.distributed as dist
19+
from paddle.base.libpaddle.pir import apply_dist2dense_pass
1920
from paddle.distributed.auto_parallel.static.mix_to_dist_pass import (
2021
apply_mix2dist_pass,
2122
)
@@ -228,6 +229,67 @@ def test_build_with_apply_mix2dist_pass(self):
228229
self.assertEqual(input2_shape.dist_attr().process_mesh, mesh)
229230
self.assertEqual(input2_shape.dist_attr().dims_mapping, [-1])
230231

232+
def test_build_with_apply_dist2dense_pass(self):
    """After mix2dist + dist2dense, programs built from randomly sampled
    tensors (randint / randn) must carry per-rank LOCAL shapes: the
    full_int_array "shape" operand feeding each sampling op is rewritten
    from the global shape to the local shard shape, and no op keeps an
    ``op_dist_attr`` attribute.
    """
    paddle.enable_static()
    with paddle.pir_utils.IrGuard():
        main_program = paddle.base.Program()
        with paddle.base.program_guard(main_program):
            mesh = dist.ProcessMesh([0, 1], dim_names=['dp'])

            # Global [8, 4], sharded on dim 0 over 2 ranks -> local [4, 4].
            input1 = paddle.randint(low=0, high=1000, shape=[8, 4])
            output1 = dist.shard_tensor(input1, mesh, [dist.Shard(0)])

            # Global [4, 8], sharded on dim 1 over 2 ranks -> local [4, 4].
            input2 = paddle.randn([4, 8])
            output2 = dist.shard_tensor(input2, mesh, [dist.Shard(1)])

            # Before the passes, the sampled values are plain dense
            # tensors and only the shard_tensor results are dist types.
            self.assertTrue(input1.is_dense_tensor_type())
            self.assertTrue(input2.is_dense_tensor_type())

            # 6 ops before lowering — presumably one shape op plus one
            # sampling op per input, plus the two shard_tensor ops.
            self.assertEqual(main_program.num_ops(), 6)

            self.assertFalse(input1.use_empty())
            self.assertFalse(input2.use_empty())

            self.assertTrue(output1.use_empty())
            self.assertTrue(output2.use_empty())

            self.assertFalse(input1.get_defining_op().has_attr("op_dist_attr"))
            self.assertFalse(input2.get_defining_op().has_attr("op_dist_attr"))

            # Check dist type of the shard_tensor results.
            self.assertTrue(output1.is_dist_dense_tensor_type())
            self.assertTrue(output2.is_dist_dense_tensor_type())

            # Run apply_mix2dist_pass and apply_dist2dense_pass.
            apply_mix2dist_pass(main_program)
            apply_dist2dense_pass(main_program)

            # After both passes the shard_tensor ops are folded away and
            # no remaining op carries a dist attribute.
            self.assertEqual(main_program.num_ops(), 4)

            self.assertTrue(input1.is_dense_tensor_type())
            self.assertTrue(input2.is_dense_tensor_type())

            self.assertFalse(input1.get_defining_op().has_attr("op_dist_attr"))
            self.assertFalse(input2.get_defining_op().has_attr("op_dist_attr"))

            # The "value" attribute of each full_int_array shape op must
            # now hold the LOCAL shard shape, not the global shape.
            input1_shape = input1.get_defining_op().operand_source(0)
            input1_shape_op = input1_shape.get_defining_op()
            self.assertFalse(input1_shape_op.has_attr("op_dist_attr"))
            input1_shape_op_attr = input1_shape_op.attrs()
            self.assertEqual(input1_shape_op_attr['value'], [4, 4])

            input2_shape = input2.get_defining_op().operand_source(0)
            input2_shape_op = input2_shape.get_defining_op()
            self.assertFalse(input2_shape_op.has_attr("op_dist_attr"))
            input2_shape_op_attr = input2_shape_op.attrs()
            self.assertEqual(input2_shape_op_attr['value'], [4, 4])

            # Result types also report the local shapes.
            self.assertEqual(input1.shape, [4, 4])
            self.assertEqual(input2.shape, [4, 4])
231293

232294
if __name__ == "__main__":
233295
unittest.main()

0 commit comments

Comments
 (0)