33 changes: 18 additions & 15 deletions paddlespeech/s2t/models/u2/u2.py
@@ -190,7 +190,7 @@ def _calc_att_loss(self,
r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
loss_att = loss_att * (1 - reverse_weight) + r_loss_att * reverse_weight
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
decoder_out.reshape([-1, self.vocab_size]),
ys_out_pad,
ignore_label=self.ignore_id, )
return loss_att, acc_att
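
Note: the change above swaps the torch-style decoder_out.view(-1, self.vocab_size) for Paddle's Tensor.reshape, which takes the target shape as a single list (or tuple) rather than as separate integer arguments; -1 still means "infer this dimension". A minimal standalone sketch of the two calling conventions (the tensor x here is made up for illustration):

import paddle

x = paddle.rand([2, 3, 4])

# torch style, as in the old code:   y = x.view(-1, 4)
# Paddle style, as in the new code:  the shape is passed as one list.
y = x.reshape([-1, 4])            # shape [6, 4]
z = paddle.reshape(x, [2, 12])    # functional form, same semantics
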
@@ -271,11 +271,13 @@ def recognize(
maxlen = encoder_out.shape[1]
encoder_dim = encoder_out.shape[2]
running_size = batch_size * beam_size
encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim)
encoder_out = encoder_out.unsqueeze(1).repeat(
1, beam_size, 1, 1).reshape(
[running_size, maxlen,
encoder_dim]) # (B*N, maxlen, encoder_dim)
encoder_mask = encoder_mask.unsqueeze(1).repeat(
1, beam_size, 1, 1).view(running_size, 1,
maxlen) # (B*N, 1, max_len)
1, beam_size, 1, 1).reshape([running_size, 1,
maxlen]) # (B*N, 1, max_len)

hyps = paddle.ones(
[running_size, 1], dtype=paddle.long).fill_(self.sos) # (B*N, 1)
@@ -305,34 +307,35 @@ def recognize(

            # 2.3 Second beam prune: select topk score with history
scores = scores + top_k_logp # (B*N, N), broadcast add
scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N)
scores = scores.reshape(
[batch_size, beam_size * beam_size]) # (B, N*N)
scores, offset_k_index = scores.topk(k=beam_size) # (B, N)
scores = scores.view(-1, 1) # (B*N, 1)
scores = scores.reshape([-1, 1]) # (B*N, 1)

# 2.4. Compute base index in top_k_index,
            # regard top_k_index as (B*N*N), regard offset_k_index as (B*N),
# then find offset_k_index in top_k_index
base_k_index = paddle.arange(batch_size).view(-1, 1).repeat(
base_k_index = paddle.arange(batch_size).reshape([-1, 1]).repeat(
1, beam_size) # (B, N)
base_k_index = base_k_index * beam_size * beam_size
best_k_index = base_k_index.view(-1) + offset_k_index.view(
-1) # (B*N)
best_k_index = base_k_index.reshape([-1]) + offset_k_index.reshape(
[-1]) # (B*N)

# 2.5 Update best hyps
best_k_pred = paddle.index_select(
top_k_index.view(-1), index=best_k_index, axis=0) # (B*N)
top_k_index.reshape([-1]), index=best_k_index, axis=0) # (B*N)
best_hyps_index = best_k_index // beam_size
last_best_k_hyps = paddle.index_select(
hyps, index=best_hyps_index, axis=0) # (B*N, i)
hyps = paddle.cat(
(last_best_k_hyps, best_k_pred.view(-1, 1)),
(last_best_k_hyps, best_k_pred.reshape([-1, 1])),
dim=1) # (B*N, i+1)

# 2.6 Update end flag
end_flag = paddle.equal(hyps[:, -1], self.eos).view(-1, 1)
end_flag = paddle.equal(hyps[:, -1], self.eos).reshape([-1, 1])

# 3. Select best of best
scores = scores.view(batch_size, beam_size)
scores = scores.reshape([batch_size, beam_size])
# TODO: length normalization
best_index = paddle.argmax(scores, axis=-1).long() # (B)
best_hyps_index = best_index + paddle.arange(
@@ -379,7 +382,7 @@ def ctc_greedy_search(
ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size)

topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1)
topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen)
topk_index = topk_index.reshape([batch_size, maxlen]) # (B, maxlen)
pad_mask = make_pad_mask(encoder_out_lens) # (B, maxlen)
topk_index = topk_index.masked_fill_(pad_mask, self.eos) # (B, maxlen)

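Note: the recognize() hunks above tile encoder_out and encoder_mask from a batch of B utterances to B*N rows, one per beam hypothesis, and the PR only changes the final flattening from view(...) to reshape([...]). A rough standalone sketch of that tiling step in plain Paddle (using tile in place of the repeat call shown in the diff; all sizes are illustrative):

import paddle

batch_size, beam_size, maxlen, encoder_dim = 2, 4, 10, 8
encoder_out = paddle.rand([batch_size, maxlen, encoder_dim])

running_size = batch_size * beam_size
# (B, T, D) -> (B, 1, T, D) -> (B, N, T, D) -> (B*N, T, D): every utterance is
# copied once per beam, then the batch and beam axes are flattened together.
tiled = encoder_out.unsqueeze(1).tile([1, beam_size, 1, 1])
tiled = tiled.reshape([running_size, maxlen, encoder_dim])
print(tiled.shape)  # [8, 10, 8]
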
11 changes: 6 additions & 5 deletions paddlespeech/s2t/modules/attention.py
@@ -129,8 +129,8 @@ def forward_attention(

p_attn = self.dropout(attn)
x = paddle.matmul(p_attn, value) # (batch, head, time1, d_k)
x = x.transpose([0, 2, 1, 3]).reshape([n_batch, -1, self.h *
self.d_k]) # (batch, time1, d_model)
x = x.transpose([0, 2, 1, 3]).reshape(
[n_batch, -1, self.h * self.d_k]) # (batch, time1, d_model)

return self.linear_out(x) # (batch, time1, d_model)
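
Note: forward_attention above merges the per-head outputs back into the model dimension by permuting the head and time axes and then flattening heads into features; the PR only reflows that reshape call. A small self-contained sketch of the transpose-then-reshape step (names and sizes here are illustrative, not taken from the repo):

import paddle

n_batch, n_head, time1, d_k = 2, 4, 5, 16
x = paddle.rand([n_batch, n_head, time1, d_k])  # (batch, head, time1, d_k)

# (batch, head, time1, d_k) -> (batch, time1, head, d_k) -> (batch, time1, head * d_k)
x = x.transpose([0, 2, 1, 3]).reshape([n_batch, -1, n_head * d_k])
print(x.shape)  # [2, 5, 64]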

@@ -280,8 +280,8 @@ def rel_shift(self, x, zero_triu: bool=False):
(x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
x_padded = paddle.cat([zero_pad, x], dim=-1)

x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1,
x.shape[2])
x_padded = x_padded.reshape(
[x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]])
x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1]

if zero_triu:
@@ -349,7 +349,8 @@ def forward(self,
new_cache = paddle.concat((k, v), axis=-1)

n_batch_pos = pos_emb.shape[0]
p = self.linear_pos(pos_emb).reshape([n_batch_pos, -1, self.h, self.d_k])
p = self.linear_pos(pos_emb).reshape(
[n_batch_pos, -1, self.h, self.d_k])
p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)

# (batch, head, time1, d_k)
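Note: the rel_shift hunk above rewrites the fold step of the Transformer-XL style relative-position shift with reshape instead of view. A minimal standalone sketch of that trick, assuming the same (batch, head, time1, time2) score layout as in the original code:

import paddle

def rel_shift(x):
    # x holds attention scores over relative positions: (batch, head, time1, time2).
    b, h, t1, t2 = x.shape
    zero_pad = paddle.zeros([b, h, t1, 1], dtype=x.dtype)
    x_padded = paddle.concat([zero_pad, x], axis=-1)  # (b, h, t1, t2 + 1)
    # Folding to (b, h, t2 + 1, t1) and dropping the first row along axis 2
    # shifts every query's scores by one relative position.
    x_padded = x_padded.reshape([b, h, t2 + 1, t1])
    return x_padded[:, :, 1:].reshape([b, h, t1, t2])

x = paddle.rand([2, 4, 5, 5])
print(rel_shift(x).shape)  # [2, 4, 5, 5]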