
Commit b8a3340

refactor(dataset): return dict instead of tuple (#2106)
1 parent 0d8344c commit b8a3340

File tree

4 files changed, +19 -32 lines changed

wenet/bin/train.py
wenet/dataset/processor.py
wenet/utils/executor.py
wenet/utils/train_utils.py
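In short: processor.py::padding now yields each padded batch as a dict keyed by field name instead of a positional tuple, Executor.train / Executor.cv stop unpacking the batch and moving tensors to a device, and batch_forward in train_utils.py resolves the device from the LOCAL_RANK environment variable and moves tensors right before the forward pass. A minimal sketch of a consuming loop under the new format (illustrative only, not verbatim repo code; model and data_loader are assumed to exist):

import os

device = int(os.environ.get('LOCAL_RANK', 0))  # same convention batch_forward now uses

for batch_idx, batch_dict in enumerate(data_loader):
    # before this commit: key, feats, target, feats_lengths, target_lengths = batch
    if batch_dict["target_lengths"].size(0) == 0:  # skip empty batches, as executor.py does
        continue
    loss_dict = model(batch_dict["feats"].to(device),
                      batch_dict["feats_lengths"].to(device),
                      batch_dict["target"].to(device),
                      batch_dict["target_lengths"].to(device))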

wenet/bin/train.py

Lines changed: 2 additions & 5 deletions
@@ -122,19 +122,16 @@ def main():
         lr = optimizer.param_groups[0]['lr']
         logging.info('Epoch {} TRAIN info lr {} rank {}'.format(epoch, lr, rank))

-        device = model.local_rank if args.deepspeed else device
-
         # NOTE(xcsong): Why we need a new group? see `train_utils.py::wenet_join`
         group_join = dist.new_group(backend="gloo",
                                     timeout=datetime.timedelta(seconds=30))

         dist.barrier()  # NOTE(xcsong): Ensure all ranks start Train at the same time.
-        executor.train(model, optimizer, scheduler, train_data_loader, device,
+        executor.train(model, optimizer, scheduler, train_data_loader,
                        writer, configs, scaler, group_join)

         dist.barrier()  # NOTE(xcsong): Ensure all ranks start CV at the same time.
-        total_loss, num_seen_utts = executor.cv(model, cv_data_loader, device,
-                                                configs)
+        total_loss, num_seen_utts = executor.cv(model, cv_data_loader, configs)
         cv_loss = total_loss / num_seen_utts

         logging.info('Epoch {} CV info cv_loss {} rank {}'.format(epoch, cv_loss, rank))

wenet/dataset/processor.py

Lines changed: 2 additions & 2 deletions
@@ -641,5 +641,5 @@ def padding(data):
                                        batch_first=True,
                                        padding_value=-1)

-        yield (sorted_keys, padded_feats, padding_labels, feats_lengths,
-               label_lengths)
+        yield {"keys": sorted_keys, "feats": padded_feats, "target": padding_labels,
+               "feats_lengths": feats_lengths, "target_lengths": label_lengths}

wenet/utils/executor.py

Lines changed: 6 additions & 21 deletions
@@ -29,7 +29,7 @@ class Executor:
     def __init__(self):
         self.step = 0

-    def train(self, model, optimizer, scheduler, data_loader, device, writer,
+    def train(self, model, optimizer, scheduler, data_loader, writer,
               configs, scaler, group_join):
         ''' Train one epoch
         '''
@@ -48,21 +48,13 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer,
             model_context = nullcontext

         with model_context():
-            for batch_idx, batch in enumerate(data_loader):
+            for batch_idx, batch_dict in enumerate(data_loader):
                 info_dict["step"] = self.step
                 info_dict["batch_idx"] = batch_idx
                 if wenet_join(group_join, info_dict):
                     break

-                key, feats, target, feats_lengths, target_lengths = batch
-
-                batch_dict = {}
-                batch_dict["feats"] = feats.to(device)
-                batch_dict["target"] = target.to(device)
-                batch_dict["feats_lengths"] = feats_lengths.to(device)
-                batch_dict["target_lengths"] = target_lengths.to(device)
-
-                if target_lengths.size(0) == 0:
+                if batch_dict["target_lengths"].size(0) == 0:
                     continue

                 context = None
@@ -88,26 +80,19 @@ def train(self, model, optimizer, scheduler, data_loader, device, writer,
                 log_per_step(writer, info_dict)
                 self.step += 1

-    def cv(self, model, data_loader, device, configs):
+    def cv(self, model, data_loader, configs):
         ''' Cross validation on
         '''
         model.eval()
         info_dict = copy.deepcopy(configs)
         info_dict["tag"] = "CV"
         num_seen_utts, total_loss = 1, 0.0  # in order to avoid division by 0
         with torch.no_grad():
-            for batch_idx, batch in enumerate(data_loader):
+            for batch_idx, batch_dict in enumerate(data_loader):
                 info_dict["step"] = self.step
                 info_dict["batch_idx"] = batch_idx
-                key, feats, target, feats_lengths, target_lengths = batch
-
-                batch_dict = {}
-                batch_dict["feats"] = feats.to(device)
-                batch_dict["target"] = target.to(device)
-                batch_dict["feats_lengths"] = feats_lengths.to(device)
-                batch_dict["target_lengths"] = target_lengths.to(device)

-                num_utts = target_lengths.size(0)
+                num_utts = batch_dict["target_lengths"].size(0)
                 if num_utts == 0:
                     continue

wenet/utils/train_utils.py

Lines changed: 9 additions & 4 deletions
@@ -416,6 +416,7 @@ def wenet_join(group_join, info_dict):

 def batch_forward(model, batch, scaler, info_dict):
     train_engine = info_dict.get('train_engine', "torch_ddp")
+    device = int(os.environ.get('LOCAL_RANK', 0))
     accum_grad = info_dict.get('accum_grad', 1)

     dtype = info_dict.get("dtype", "fp32")
@@ -431,16 +432,20 @@ def batch_forward(model, batch, scaler, info_dict):
         with torch.cuda.amp.autocast(
                 enabled=dtype is not None, dtype=dtype, cache_enabled=False
         ):
-            loss_dict = model(batch["feats"], batch["feats_lengths"],
-                              batch["target"], batch["target_lengths"])
+            loss_dict = model(batch["feats"].to(device),
+                              batch["feats_lengths"].to(device),
+                              batch["target"].to(device),
+                              batch["target_lengths"].to(device))
     else:
         # torch_ddp
         # autocast context
         # The more details about amp can be found in
         # https://pytorch.org/docs/stable/notes/amp_examples.html
         with torch.cuda.amp.autocast(scaler is not None):
-            loss_dict = model(batch["feats"], batch["feats_lengths"],
-                              batch["target"], batch["target_lengths"])
+            loss_dict = model(batch["feats"].to(device),
+                              batch["feats_lengths"].to(device),
+                              batch["target"].to(device),
+                              batch["target_lengths"].to(device))
     info_dict['loss_dict'] = loss_dict

     return info_dict
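With this change batch_forward no longer receives a device argument: it reads the CUDA ordinal from the LOCAL_RANK environment variable that launchers such as torchrun and deepspeed set per process, falling back to 0, and PyTorch interprets an integer passed to .to() as a CUDA device index. A small sketch of that behaviour, guarded so it also runs on a CPU-only machine:

import os
import torch

# LOCAL_RANK is set per process by the distributed launcher; default to 0 otherwise.
device = int(os.environ.get('LOCAL_RANK', 0))

x = torch.zeros(2, 3)
if torch.cuda.is_available():
    x = x.to(device)   # an int device is treated as cuda:<device>
print(x.device)        # cuda:0 on a single-GPU machine, cpu otherwise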
