@@ -326,6 +326,198 @@ void ReshardOp::Build(pir::Builder& builder,
326326 ::pir::PassStopGradientsDefaultly (argument);
327327}
328328
329+ void DtensorFromLocalOp::Build (pir::Builder& builder,
330+ pir::OperationArgument& argument,
331+ pir::Value input,
332+ TensorDistAttribute tensor_dist_attr) {
333+ VLOG (4 ) << " Start build DtensorFromLocalOp" ;
334+
335+ paddle::dialect::DenseTensorType local_tensor_type;
336+ if (input.type ().isa <paddle::dialect::DenseTensorType>()) {
337+ local_tensor_type =
338+ input.type ().dyn_cast <paddle::dialect::DenseTensorType>();
339+ } else {
340+ PADDLE_THROW (common::errors::Unimplemented (
341+ " Only support paddle::dialect::DenseTensorType" ));
342+ }
343+
344+ VLOG (4 ) << " Builder construction inputs" ;
345+ argument.AddInput (input);
346+
347+ VLOG (4 ) << " Builder construction attributes" ;
348+
349+ VLOG (4 ) << " Builder construction outputs" ;
350+
351+ auto global_ddim =
352+ InferGlobalDDim (local_tensor_type.dims (), tensor_dist_attr);
353+ auto global_tensor =
354+ dialect::DenseTensorType::get (pir::IrContext::Instance (),
355+ local_tensor_type.dtype (),
356+ global_ddim,
357+ local_tensor_type.data_layout (),
358+ local_tensor_type.lod (),
359+ local_tensor_type.offset ());
360+
361+ pir::Type out_dist_tensor_type =
362+ paddle::dialect::DistDenseTensorType::get (pir::IrContext::Instance (),
363+ global_tensor,
364+ tensor_dist_attr,
365+ local_tensor_type.dims ());
366+ argument.AddOutput (out_dist_tensor_type);
367+ ::pir::PassStopGradientsDefaultly (argument);
368+ }
369+
370+ OpInfoTuple DtensorFromLocalOp::GetOpInfo () {
371+ return OpInfoTuple ({OpInputInfo ()},
372+ {},
373+ {OpOutputInfo ()},
374+ OpRunTimeInfo (),
375+ " dtensor_from_local" );
376+ }
377+ std::vector<std::vector<pir::Value>> DtensorFromLocalOp::Vjp (
378+ pir::Operation* op,
379+ const std::vector<std::vector<pir::Value>>& inputs,
380+ const std::vector<std::vector<pir::Value>>& outputs,
381+ const std::vector<std::vector<pir::Value>>& out_grads,
382+ const std::vector<std::vector<bool >>& stop_gradients) {
383+ VLOG (6 ) << " Start call vjp for dtensor_from_local op." ;
384+ PADDLE_ENFORCE_EQ (inputs.size (),
385+ 1 ,
386+ common::errors::InvalidArgument (
387+ " dtensor_from_local op's inputs' size should be 1" ));
388+ PADDLE_ENFORCE_EQ (
389+ inputs[0 ].size (),
390+ 1 ,
391+ common::errors::InvalidArgument (
392+ " dtensor_from_local op's inputs[0]'s size should be 1" ));
393+
394+ PADDLE_ENFORCE_EQ (outputs.size (),
395+ 1 ,
396+ common::errors::InvalidArgument (
397+ " dtensor_from_local op's outputs' size should be 1" ));
398+ PADDLE_ENFORCE_EQ (
399+ outputs[0 ].size (),
400+ 1 ,
401+ common::errors::InvalidArgument (
402+ " dtensor_from_local op's outputs[0]'s size should be 1" ));
403+ auto dist_type = outputs[0 ][0 ].type ().dyn_cast <DistTypeInterface>();
404+
405+ PADDLE_ENFORCE_NOT_NULL (
406+ dist_type,
407+ common::errors::InvalidArgument (" Currently, dtensor_from_local op's "
408+ " outputs type must be dist type." ));
409+
410+ PADDLE_ENFORCE_EQ (
411+ out_grads.size (),
412+ 1 ,
413+ common::errors::InvalidArgument (
414+ " dtensor_from_local op's outputs grad size should be 1" ));
415+
416+ PADDLE_ENFORCE_EQ (
417+ out_grads[0 ].size (),
418+ 1 ,
419+ common::errors::InvalidArgument (
420+ " dtensor_from_local op's outputs grad[0] size should be 1" ));
421+
422+ auto & builder = *ApiBuilder::Instance ().GetBuilder ();
423+
424+ auto out_grad = out_grads[0 ][0 ];
425+
426+ if (out_grad.type () != outputs[0 ][0 ].type ()) {
427+ out_grad = builder.Build <ReshardOp>(out_grad, dist_type.tensor_dist_attr ())
428+ ->result (0 );
429+ }
430+
431+ auto grad_op = builder.Build <DtensorToLocalOp>(out_grad);
432+
433+ VLOG (6 ) << " End call vjp for dtensor_from_local op." ;
434+
435+ return {std::vector<pir::Value>{grad_op->result (0 )}};
436+ }
437+
438+ void DtensorToLocalOp::Build (pir::Builder& builder,
439+ pir::OperationArgument& argument,
440+ pir::Value input) {
441+ VLOG (4 ) << " Start build DtensorToLocalOp" ;
442+
443+ VLOG (4 ) << " Builder construction inputs" ;
444+ argument.AddInput (input);
445+
446+ VLOG (4 ) << " Builder construction attributes" ;
447+
448+ VLOG (4 ) << " Builder construction outputs" ;
449+
450+ auto dist_type = input.type ().dyn_cast <DistTypeInterface>();
451+ if (!dist_type) {
452+ PADDLE_THROW (common::errors::Unimplemented (
453+ " The input of DtensorToLocalOp must be dist type." ));
454+ }
455+
456+ argument.AddOutput (dist_type.local_type ());
457+ ::pir::PassStopGradientsDefaultly (argument);
458+ }
459+
460+ OpInfoTuple DtensorToLocalOp::GetOpInfo () {
461+ return OpInfoTuple ({OpInputInfo ()},
462+ {},
463+ {OpOutputInfo ()},
464+ OpRunTimeInfo (),
465+ " dtensor_to_local" );
466+ }
467+
468+ std::vector<std::vector<pir::Value>> DtensorToLocalOp::Vjp (
469+ pir::Operation* op,
470+ const std::vector<std::vector<pir::Value>>& inputs,
471+ const std::vector<std::vector<pir::Value>>& outputs,
472+ const std::vector<std::vector<pir::Value>>& out_grads,
473+ const std::vector<std::vector<bool >>& stop_gradients) {
474+ VLOG (6 ) << " Start call vjp for dtensor_to_local op." ;
475+ PADDLE_ENFORCE_EQ (inputs.size (),
476+ 1 ,
477+ common::errors::InvalidArgument (
478+ " dtensor_to_local op's inputs' size should be 1" ));
479+ PADDLE_ENFORCE_EQ (inputs[0 ].size (),
480+ 1 ,
481+ common::errors::InvalidArgument (
482+ " dtensor_to_local op's inputs[0]'s size should be 1" ));
483+
484+ PADDLE_ENFORCE_EQ (outputs.size (),
485+ 1 ,
486+ common::errors::InvalidArgument (
487+ " dtensor_to_local op's outputs' size should be 1" ));
488+ PADDLE_ENFORCE_EQ (outputs[0 ].size (),
489+ 1 ,
490+ common::errors::InvalidArgument (
491+ " dtensor_to_local op's outputs[0]'s size should be 1" ));
492+ auto dist_type = inputs[0 ][0 ].type ().dyn_cast <DistTypeInterface>();
493+
494+ PADDLE_ENFORCE_NOT_NULL (
495+ dist_type,
496+ common::errors::InvalidArgument (
497+ " Currently, dtensor_to_local op's inputs type must be dist type." ));
498+
499+ PADDLE_ENFORCE_EQ (
500+ out_grads.size (),
501+ 1 ,
502+ common::errors::InvalidArgument (
503+ " dtensor_from_local op's outputs grad size should be 1" ));
504+
505+ PADDLE_ENFORCE_EQ (
506+ out_grads[0 ].size (),
507+ 1 ,
508+ common::errors::InvalidArgument (
509+ " dtensor_from_local op's outputs grad[0] size should be 1" ));
510+
511+ auto & builder = *ApiBuilder::Instance ().GetBuilder ();
512+
513+ auto grad_op = builder.Build <DtensorFromLocalOp>(
514+ out_grads[0 ][0 ], dist_type.tensor_dist_attr ());
515+
516+ VLOG (6 ) << " End call vjp for dtensor_from_local op." ;
517+
518+ return {std::vector<pir::Value>{grad_op->result (0 )}};
519+ }
520+
329521TEST_API void paddle::dialect::MoESubMeshTensorsOp::Build (
330522 pir::Builder& builder,
331523 pir::OperationArgument& argument,
@@ -699,5 +891,7 @@ std::vector<std::vector<pir::Value>> MoEGlobalMeshTensorOp::Vjp(
699891
700892IR_DEFINE_EXPLICIT_TYPE_ID (paddle::dialect::ShardTensorOp)
701893IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ReshardOp)
894+ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DtensorFromLocalOp)
895+ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DtensorToLocalOp)
702896IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::MoESubMeshTensorsOp)
703897IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::MoEGlobalMeshTensorOp)
0 commit comments