[Dist Dialect] Simple MoE training in PIR #66750
Conversation
Your PR has been submitted successfully. Thank you for your contribution to the open-source project!
Sorry to inform you that d1496fa's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
LGTM
Add the sub_to_global_reshard function; the loss of the PIR MoE demo is equal to that of dygraph auto parallel. Refine and rename the MoE APIs.
LGTM
PR Category: Auto Parallel
PR Types: New features
Description
Pcard-67164
Support simple MoE training in PIR:
- Replace the MoE-related dist ops, i.e. `moe_sub_mesh_tensors` and `moe_global_mesh_tensor` ([Dist Dialect] Add MoE-related api in PIR dist dialect #66462), with the `share_data` op in `remove_other_rank_op_pass`.
- Add a `sub_to_global_mesh` reshard function in PIR; it currently only supports replicated tensors. This is needed in `global_norm_clip`, where the L2 norm of each expert's parameters must be resharded to the other ranks, e.g. resharding the L2 norm of a parameter on ProcessMesh([0]) to ProcessMesh([0,1]); a minimal sketch follows this list.
  NOTE: this is a preliminary version and needs further refinement to make it consistent with the dygraph reshard function.
- Refine the MoE-related APIs for multi-dimensional process meshes, e.g. mesh=[[0,1],[2,3]], with one expert on ranks [0,2] and the other on ranks [1,3]; see the second sketch after this list.
- Additionally, add a unit test for `reshape_grad` when its output `x_grad` needs reshard, corresponding to #67729.
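To make the sub-mesh-to-global-mesh reshard concrete, below is a minimal sketch using the public `paddle.distributed` dygraph auto-parallel API (`ProcessMesh`, `shard_tensor`, `reshard`, `Replicate`), which this PR aims to be consistent with; the two-rank launch setup, the tensor value, and the variable names are illustrative assumptions, not part of this PR's internal pass wiring.

```python
# Minimal sketch: reshard a replicated tensor from an expert's sub-mesh to the
# global mesh, as needed by global_norm_clip. Assumes 2 ranks, launched with
# `python -m paddle.distributed.launch`. Values and names are hypothetical.
import paddle
import paddle.distributed as dist

sub_mesh = dist.ProcessMesh([0])        # the expert's sub-mesh
global_mesh = dist.ProcessMesh([0, 1])  # the global mesh

# per-expert L2 norm, replicated on the sub-mesh (hypothetical value)
local_norm = dist.shard_tensor(
    paddle.to_tensor([3.0]), sub_mesh, [dist.Replicate()]
)

# after the reshard, every rank of the global mesh holds the value
global_norm = dist.reshard(local_norm, global_mesh, [dist.Replicate()])
```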
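The multi-dimensional mesh layout can be sketched similarly; the `dp`/`ep` dimension names, the per-column expert placement, and the weight shapes below are assumptions for illustration only.

```python
# Minimal sketch: a 2-D process mesh [[0, 1], [2, 3]] with experts split along
# the second mesh dimension, so expert 0 lives on ranks [0, 2] and expert 1 on
# ranks [1, 3]. Dimension names and shapes are hypothetical.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "ep"])

# one sub-mesh per expert: a single column of the global mesh
expert_meshes = [
    dist.ProcessMesh([[0], [2]], dim_names=["dp", "ep"]),  # expert 0
    dist.ProcessMesh([[1], [3]], dim_names=["dp", "ep"]),  # expert 1
]

# each expert's weight is replicated over its own 2-D sub-mesh
weights = [
    dist.shard_tensor(paddle.randn([16, 16]), m, [dist.Replicate(), dist.Replicate()])
    for m in expert_meshes
]
```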