-
Notifications
You must be signed in to change notification settings - Fork 5.9k
【auto_parallel】upgrade load_state_dict #67427
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【auto_parallel】upgrade load_state_dict #67427
Conversation
|
你的PR提交成功,感谢你对开源项目的贡献! |
| import copy | ||
| import os | ||
| from dataclasses import dataclass | ||
| from typing import TYPE_CHECKING |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么删除了 TYPE_CHECKING
| ] | ||
|
|
||
| if offload: | ||
| storage_local_tensor = storage_local_tensor.cuda() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里应该要考虑到各个平台的问题,建议通过 _current_expected_place 判断一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
| path, | ||
| process_group=None, | ||
| coordinator_rank=0, | ||
| offload=True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
默认为 False 吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
| return rank_to_files, missing_keys | ||
|
|
||
|
|
||
| def get_local_load_files_for_multiple_node( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个名字建议换一下,可以用 rank_to_file 表示 rank 可见的文件,rank_to_read_files 表示rank要读的文件
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for API, adding parameter with default value is compatibility upgrade
PR Category
Auto Parallel
PR Types
Improvements
Description
修改load_state_dict