|
16 | 16 |
|
17 | 17 | import paddle |
18 | 18 | import paddle.distributed as dist |
| 19 | +from paddle.base.libpaddle.pir import apply_dist2dense_pass |
19 | 20 | from paddle.distributed.auto_parallel.static.mix_to_dist_pass import ( |
20 | 21 | apply_mix2dist_pass, |
21 | 22 | ) |
@@ -228,6 +229,67 @@ def test_build_with_apply_mix2dist_pass(self): |
228 | 229 | self.assertEqual(input2_shape.dist_attr().process_mesh, mesh) |
229 | 230 | self.assertEqual(input2_shape.dist_attr().dims_mapping, [-1]) |
230 | 231 |
|
| 232 | + def test_build_with_apply_dist2dense_pass(self): |
| 233 | + paddle.enable_static() |
| 234 | + with paddle.pir_utils.IrGuard(): |
| 235 | + main_program = paddle.base.Program() |
| 236 | + with paddle.base.program_guard(main_program): |
| 237 | + mesh = dist.ProcessMesh([0, 1], dim_names=['dp']) |
| 238 | + input1 = paddle.randint(low=0, high=1000, shape=[8, 4]) |
| 239 | + output1 = dist.shard_tensor(input1, mesh, [dist.Shard(0)]) |
| 240 | + |
| 241 | + input2 = paddle.randn([4, 8]) |
| 242 | + output2 = dist.shard_tensor(input2, mesh, [dist.Shard(1)]) |
| 243 | + |
| 244 | + self.assertTrue(input1.is_dense_tensor_type()) |
| 245 | + self.assertTrue(input2.is_dense_tensor_type()) |
| 246 | + |
| 247 | + self.assertTrue(main_program.num_ops() == 6) |
| 248 | + |
| 249 | + self.assertFalse(input1.use_empty()) |
| 250 | + self.assertFalse(input2.use_empty()) |
| 251 | + |
| 252 | + self.assertTrue(output1.use_empty()) |
| 253 | + self.assertTrue(output2.use_empty()) |
| 254 | + |
| 255 | + self.assertFalse(input1.get_defining_op().has_attr("op_dist_attr")) |
| 256 | + self.assertFalse(input2.get_defining_op().has_attr("op_dist_attr")) |
| 257 | + |
| 258 | + # check dist type |
| 259 | + self.assertTrue(output1.is_dist_dense_tensor_type()) |
| 260 | + self.assertTrue(output2.is_dist_dense_tensor_type()) |
| 261 | + |
| 262 | + # run apply_mix2dist_pass and apply_dist2dense_pass |
| 263 | + apply_mix2dist_pass(main_program) |
| 264 | + apply_dist2dense_pass(main_program) |
| 265 | + |
| 266 | + # after apply_mix2dist_pass, the program changed |
| 267 | + # and after apply_dist2dense_pass, the operator in program do not have dist_attr |
| 268 | + self.assertTrue(main_program.num_ops() == 4) |
| 269 | + |
| 270 | + self.assertTrue(input1.is_dense_tensor_type()) |
| 271 | + self.assertTrue(input2.is_dense_tensor_type()) |
| 272 | + |
| 273 | + self.assertFalse(input1.get_defining_op().has_attr("op_dist_attr")) |
| 274 | + self.assertFalse(input2.get_defining_op().has_attr("op_dist_attr")) |
| 275 | + |
| 276 | + # check shape attribute of full_int_array op |
| 277 | + input1_shape = input1.get_defining_op().operand_source(0) |
| 278 | + input1_shape_op = input1_shape.get_defining_op() |
| 279 | + self.assertFalse(input1_shape_op.has_attr("op_dist_attr")) |
| 280 | + input1_shape_op_attr = input1_shape_op.attrs() |
| 281 | + self.assertEqual(input1_shape_op_attr['value'], [4, 4]) |
| 282 | + |
| 283 | + input2_shape = input2.get_defining_op().operand_source(0) |
| 284 | + input2_shape_op = input2_shape.get_defining_op() |
| 285 | + self.assertFalse(input2_shape_op.has_attr("op_dist_attr")) |
| 286 | + input2_shape_op_attr = input2_shape_op.attrs() |
| 287 | + self.assertEqual(input2_shape_op_attr['value'], [4, 4]) |
| 288 | + |
| 289 | + # check shape of input1 and input2 |
| 290 | + self.assertEqual(input1.shape, [4, 4]) |
| 291 | + self.assertEqual(input2.shape, [4, 4]) |
| 292 | + |
231 | 293 |
|
# Standard unittest entry point: run all tests in this module.
if __name__ == "__main__":
    unittest.main()