Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/deepseek_v2/train_deepseekv2.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#
# See LICENSE for license information.
#################################################################################
set -e
# set -e
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vidushi8 I think there is a better way to do this. You mentioned that there is a command to detect for multi-node that would fail on single node. You can change it to the following:

<command to detect multi-node> || true

This way, it won't fail and break the run.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved the multi-node NCCL_IB_HCA logic under the if condition NNODES>1


TOKENIZER_MODEL="deepseek-ai/DeepSeek-V2-Lite"

Expand Down
2 changes: 1 addition & 1 deletion examples/deepseek_v3/train_deepseekv3.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# See LICENSE for license information.
#################################################################################

set -e
# set -e

# exp
EXPERIMENT="deepseek_v3"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
# in torch_dist format. This is a workaround to fix the issue.
from torch.distributed.checkpoint.filesystem import _StorageWriterTransforms
from functools import partial
_write_item = partial(_write_item, _StorageWriterTransforms())
import inspect
if "serialization_format" in inspect.signature(_write_item).parameters:
from torch.distributed.checkpoint.filesystem import SerializationFormat
_write_item = partial(_write_item, _StorageWriterTransforms(), serialization_format=SerializationFormat.TORCH_SAVE)
else:
_write_item = partial(_write_item, _StorageWriterTransforms())
except ImportError:
pass

Expand Down