Merged (changes from 10 commits)
1 change: 1 addition & 0 deletions training/HelloDeepSpeed/run.sh
@@ -0,0 +1 @@
+python train_bert.py --checkpoint_dir ./experiment
1 change: 1 addition & 0 deletions training/HelloDeepSpeed/run_ds.sh
@@ -0,0 +1 @@
+deepspeed --bind_cores_to_rank train_bert_ds.py --checkpoint_dir experiment_deepspeed $@
6 changes: 4 additions & 2 deletions training/HelloDeepSpeed/train_bert.py
@@ -24,6 +24,8 @@
RobertaPreTrainedModel,
)

+from deepspeed.accelerator import get_accelerator

logger = loguru.logger

######################################################################
@@ -625,8 +627,8 @@ def train(
pathlib.Path: The final experiment directory

"""
-device = (torch.device("cuda", local_rank) if (local_rank > -1)
-and torch.cuda.is_available() else torch.device("cpu"))
+device = (torch.device(get_accelerator().device_name(), local_rank) if (local_rank > -1)
+and get_accelerator().is_available() else torch.device("cpu"))
################################
###### Create Exp. Dir #########
################################
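The hunk above swaps hard-coded CUDA calls for DeepSpeed's accelerator abstraction. A minimal sketch of the resulting pattern (standalone Python, not part of the diff, assuming DeepSpeed is installed):

```python
# Accelerator-agnostic device selection. get_accelerator() returns the
# active backend (e.g. CUDA on NVIDIA GPUs, XPU on Intel GPUs), so the
# same code runs unmodified on non-NVIDIA hardware.
import torch
from deepspeed.accelerator import get_accelerator

def select_device(local_rank: int = -1) -> torch.device:
    if local_rank > -1 and get_accelerator().is_available():
        # device_name() yields the torch device string, e.g. "cuda" or "xpu"
        return torch.device(get_accelerator().device_name(), local_rank)
    return torch.device("cpu")
```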
9 changes: 6 additions & 3 deletions training/HelloDeepSpeed/train_bert_ds.py
@@ -31,6 +31,7 @@
RobertaPreTrainedModel,
)

+from deepspeed.accelerator import get_accelerator

def is_rank_0() -> bool:
return int(os.environ.get("RANK", "0")) == 0
@@ -612,6 +613,7 @@ def train(
checkpoint_every: int = 1000,
log_every: int = 10,
local_rank: int = -1,
+dtype: str = "bf16",
) -> pathlib.Path:
"""Trains a [Bert style](https://arxiv.org/pdf/1810.04805.pdf)
(transformer encoder only) model for MLM Task
@@ -667,8 +669,8 @@
pathlib.Path: The final experiment directory

"""
-device = (torch.device("cuda", local_rank) if (local_rank > -1)
-and torch.cuda.is_available() else torch.device("cpu"))
+device = (torch.device(get_accelerator().device_name(), local_rank) if (local_rank > -1)
+and get_accelerator().is_available() else torch.device("cpu"))
################################
###### Create Exp. Dir #########
################################
@@ -777,6 +779,7 @@ def train(
###### DeepSpeed engine ########
################################
log_dist("Creating DeepSpeed engine", ranks=[0], level=logging.INFO)
+assert (dtype == 'fp16' or dtype == 'bf16')
ds_config = {
"train_micro_batch_size_per_gpu": batch_size,
"optimizer": {
@@ -785,7 +788,7 @@
"lr": 1e-4
}
},
-"fp16": {
+dtype: {
"enabled": True
},
"zero_optimization": {
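Because the precision section of a DeepSpeed config is keyed by its name ("fp16" or "bf16"), the dtype string can index the dict directly, which is what the hunk above does. A minimal sketch of the idea (not part of the diff; assumes only these two values, as the assert enforces):

```python
# Build a DeepSpeed config whose precision section is selected at runtime.
def make_ds_config(batch_size: int, dtype: str) -> dict:
    assert dtype in ("fp16", "bf16")
    return {
        "train_micro_batch_size_per_gpu": batch_size,
        "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
        dtype: {"enabled": True},  # expands to {"fp16": ...} or {"bf16": ...}
    }
```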
1 change: 1 addition & 0 deletions training/cifar/README.md
@@ -18,3 +18,4 @@ run_ds_moe.sh
* To run baseline CIFAR-10 model - "python cifar10_tutorial.py"
* To run DeepSpeed CIFAR-10 model - "bash run_ds.sh"
* To run DeepSpeed CIFAR-10 model with Mixture of Experts (MoE) - "bash run_ds_moe.sh"
+* To run with a different data type (default='fp16') or ZeRO stage (default=0) - "bash run_ds.sh --dtype={fp16|bf16} --stage={0|1|2|3}"
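(For example, "bash run_ds.sh --dtype=bf16 --stage=2" would train in bf16 with ZeRO stage 2.)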
109 changes: 89 additions & 20 deletions training/cifar/cifar10_deepspeed.py
@@ -3,6 +3,7 @@
import torchvision.transforms as transforms
import argparse
import deepspeed
+from deepspeed.accelerator import get_accelerator


def add_argument():
@@ -88,6 +89,22 @@ def add_argument():
help=
'(moe) create separate moe param groups, required when using ZeRO w. MoE'
)
+parser.add_argument(
+'--dtype',
+default='fp16',
+type=str,
+choices=['bf16', 'fp16'],
+help=
+'Datatype used for training'
+)
+parser.add_argument(
+'--stage',
+default=0,
+type=int,
+choices=[0, 1, 2, 3],
+help=
+'ZeRO optimization stage'
+)

# Include DeepSpeed configuration arguments
parser = deepspeed.add_config_arguments(parser)
@@ -243,11 +260,65 @@ def create_moe_param_groups(model):
# 1) Distributed model
# 2) Distributed data loader
# 3) DeepSpeed optimizer
+ds_config = {
+"train_batch_size": 16,
+"steps_per_print": 2000,
+"optimizer": {
+"type": "Adam",
+"params": {
+"lr": 0.001,
+"betas": [
+0.8,
+0.999
+],
+"eps": 1e-8,
+"weight_decay": 3e-7
+}
+},
+"scheduler": {
+"type": "WarmupLR",
+"params": {
+"warmup_min_lr": 0,
+"warmup_max_lr": 0.001,
+"warmup_num_steps": 1000
+}
+},
+"gradient_clipping": 1.0,
+"prescale_gradients": False,
+args.dtype: {
+"enabled": True,
+"fp16_master_weights_and_grads": False,
+"loss_scale": 0,
+"loss_scale_window": 500,
+"hysteresis": 2,
+"min_loss_scale": 1,
+"initial_scale_power": 15
+},
+"wall_clock_breakdown": False,
+"zero_optimization": {
+"stage": args.stage,
+"allgather_partitions": True,
+"reduce_scatter": True,
+"allgather_bucket_size": 50000000,
+"reduce_bucket_size": 50000000,
+"overlap_comm": True,
+"contiguous_gradients": True,
+"cpu_offload": False
+}
+}

model_engine, optimizer, trainloader, __ = deepspeed.initialize(
-args=args, model=net, model_parameters=parameters, training_data=trainset)
+args=args, model=net, model_parameters=parameters, training_data=trainset, config=ds_config)

+local_device = get_accelerator().device_name(model_engine.local_rank)
+local_rank = model_engine.local_rank
+
+fp16 = model_engine.fp16_enabled()
+print(f'fp16={fp16}')
+# For float32, target_dtype will be None so no datatype conversion needed
+target_dtype = None
+if model_engine.bfloat16_enabled():
+target_dtype=torch.bfloat16
+elif model_engine.fp16_enabled():
+target_dtype=torch.half

#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#net.to(device)
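The loops below repeat a cast-then-forward pattern driven by target_dtype. A small sketch of that pattern (the helper name to_engine is hypothetical, not in the diff):

```python
import torch

def to_engine(batch: torch.Tensor, device: str, target_dtype) -> torch.Tensor:
    # Move to the accelerator device first; cast only when training in
    # reduced precision (target_dtype is None for fp32, so no cast then).
    batch = batch.to(device)
    return batch if target_dtype is None else batch.to(target_dtype)
```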
@@ -274,10 +345,9 @@ def create_moe_param_groups(model):
running_loss = 0.0
for i, data in enumerate(trainloader):
# get the inputs; data is a list of [inputs, labels]
-inputs, labels = data[0].to(model_engine.local_rank), data[1].to(
-model_engine.local_rank)
-if fp16:
-inputs = inputs.half()
+inputs, labels = data[0].to(local_device), data[1].to(local_device)
+if target_dtype != None:
+inputs = inputs.to(target_dtype)
outputs = model_engine(inputs)
loss = criterion(outputs, labels)

@@ -286,7 +356,7 @@ def create_moe_param_groups(model):

# print statistics
running_loss += loss.item()
-if i % args.log_interval == (
+if local_rank == 0 and i % args.log_interval == (
args.log_interval -
1): # print every log_interval mini-batches
print('[%d, %5d] loss: %.3f' %
@@ -317,9 +387,9 @@ def create_moe_param_groups(model):

########################################################################
# Okay, now let us see what the neural network thinks these examples above are:
-if fp16:
-images = images.half()
-outputs = net(images.to(model_engine.local_rank))
+if target_dtype != None:
+images = images.to(target_dtype)
+outputs = net(images.to(local_device))

########################################################################
# The outputs are energies for the 10 classes.
@@ -340,13 +410,12 @@ def create_moe_param_groups(model):
with torch.no_grad():
for data in testloader:
images, labels = data
-if fp16:
-images = images.half()
-outputs = net(images.to(model_engine.local_rank))
+if target_dtype != None:
+images = images.to(target_dtype)
+outputs = net(images.to(local_device))
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
-correct += (predicted == labels.to(
-model_engine.local_rank)).sum().item()
+correct += (predicted == labels.to(local_device)).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' %
(100 * correct / total))
@@ -364,11 +433,11 @@ def create_moe_param_groups(model):
with torch.no_grad():
for data in testloader:
images, labels = data
-if fp16:
-images = images.half()
-outputs = net(images.to(model_engine.local_rank))
+if target_dtype != None:
+images = images.to(target_dtype)
+outputs = net(images.to(local_device))
_, predicted = torch.max(outputs, 1)
-c = (predicted == labels.to(model_engine.local_rank)).squeeze()
+c = (predicted == labels.to(local_device)).squeeze()
for i in range(4):
label = labels[i]
class_correct[label] += c[i].item()
46 changes: 0 additions & 46 deletions training/cifar/ds_config.json

This file was deleted.

2 changes: 1 addition & 1 deletion training/cifar/run_ds.sh
@@ -1,3 +1,3 @@
#!/bin/bash

-deepspeed cifar10_deepspeed.py --deepspeed --deepspeed_config ds_config.json $@
+deepspeed --bind_cores_to_rank cifar10_deepspeed.py --deepspeed $@
5 changes: 4 additions & 1 deletion training/cifar/run_ds_moe.sh
@@ -9,7 +9,10 @@ EP_SIZE=2
# Number of total experts
EXPERTS=2

-deepspeed --num_nodes=${NUM_NODES} --num_gpus=${NUM_GPUS} cifar10_deepspeed.py \
+deepspeed --num_nodes=${NUM_NODES}\
+--num_gpus=${NUM_GPUS} \
+--bind_cores_to_rank \
+cifar10_deepspeed.py \
--log-interval 100 \
--deepspeed \
--deepspeed_config ds_config.json \