34 changes: 9 additions & 25 deletions examples/other/1xt2x/src_deepspeech2x/models/ds2/deepspeech2.py
@@ -167,34 +167,18 @@ def decode(self, audio, audio_len, vocab_list, decoding_method,
cutoff_top_n, num_processes):
# init once
# decoders only accept string encoded in utf-8
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
batch_size = audio.shape[0]
self.decoder.init_decoder(batch_size, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta,
beam_size, cutoff_prob, cutoff_top_n,
num_processes)

eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts)
print("probs.shape", probs.shape)
return self.decoder.decode_probs(
probs.numpy(), eouts_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes)

def decode_probs_split(self, probs_split, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size,
cutoff_prob, cutoff_top_n, num_processes):
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
return self.decoder.decode_probs_split(
probs_split, vocab_list, decoding_method, lang_model_path,
beam_alpha, beam_beta, beam_size, cutoff_prob, cutoff_top_n,
num_processes)
self.decoder.next(probs, eouts_len)
trans_best, trans_beam = self.decoder.decode()
self.decoder.del_decoder()
return trans_best

@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):
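The hunk above replaces the old one-shot init_decode/decode_probs pair with a stateful decoder lifecycle: init_decoder once per batch, next to feed posteriors, decode to read back the best and beam transcripts, and del_decoder to release the native decoder. The sketch below restates that lifecycle outside the class as a reading aid; model and cfg are illustrative stand-ins for a DeepSpeech2-style model and its decoding config, not names from the PR.

    # Minimal sketch of the new stateful decoder API (assumes `model` exposes the
    # encoder/decoder pair changed in this PR and `cfg` carries the decoding options).
    def decode_batch(model, audio, audio_len, vocab_list, cfg):
        batch_size = audio.shape[0]
        # Configure the native decoder once for this batch.
        model.decoder.init_decoder(batch_size, vocab_list, cfg.decoding_method,
                                   cfg.lang_model_path, cfg.alpha, cfg.beta,
                                   cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
                                   cfg.num_proc_bsearch)
        eouts, eouts_len = model.encoder(audio, audio_len)
        probs = model.decoder.softmax(eouts)
        # Feed the whole utterance as a single chunk, then read the result.
        model.decoder.next(probs, eouts_len)
        trans_best, trans_beam = model.decoder.decode()
        model.decoder.del_decoder()
        return trans_best

Because the decoder now holds state between calls, the same object can also be fed chunk by chunk in the online path; see the streaming sketch after the paddlespeech/s2t/exps/deepspeech2/model.py hunks below.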
16 changes: 10 additions & 6 deletions paddlespeech/s2t/decoders/ctcdecoder/swig_wrapper.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper for various CTC decoders in SWIG."""
import swig_decoders
import paddlespeech_ctcdecoders


class Scorer(swig_decoders.Scorer):
class Scorer(paddlespeech_ctcdecoders.Scorer):
"""Wrapper for Scorer.

:param alpha: Parameter associated with language model. Don't use
@@ -29,7 +29,7 @@ class Scorer(swig_decoders.Scorer):
"""

def __init__(self, alpha, beta, model_path, vocabulary):
swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)
paddlespeech_ctcdecoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)


def ctc_greedy_decoder(probs_seq, vocabulary, blank_id):
@@ -44,7 +44,7 @@ def ctc_greedy_decoder(probs_seq, vocabulary, blank_id):
:return: Decoding result string.
:rtype: str
"""
result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary,
result = paddlespeech_ctcdecoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary,
blank_id)
return result

@@ -81,7 +81,7 @@ def ctc_beam_search_decoder(probs_seq,
results, in descending order of the probability.
:rtype: list
"""
beam_results = swig_decoders.ctc_beam_search_decoder(
beam_results = paddlespeech_ctcdecoders.ctc_beam_search_decoder(
probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n,
ext_scoring_func, blank_id)
beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results]
@@ -126,9 +126,13 @@ def ctc_beam_search_decoder_batch(probs_split,
"""
probs_split = [probs_seq.tolist() for probs_seq in probs_split]

batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch(
batch_beam_results = paddlespeech_ctcdecoders.ctc_beam_search_decoder_batch(
probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
cutoff_top_n, ext_scoring_func, blank_id)
batch_beam_results = [[(res[0], res[1]) for res in beam_results]
for beam_results in batch_beam_results]
return batch_beam_results


def get_ctc_beam_search_decoder_batch_class():
return paddlespeech_ctcdecoders.CtcBeamSearchDecoderBatch
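These wrapper functions keep their old signatures; only the backing extension module changes from swig_decoders to paddlespeech_ctcdecoders. A hedged usage sketch, assuming the extension is installed and using a toy posterior matrix (the blank-after-vocabulary convention here is an assumption, controlled by blank_id):

    import numpy as np

    from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import ctc_greedy_decoder

    vocabulary = ["a", "b", "c"]
    blank_id = 3  # assumed convention: blank occupies the extra, last column
    # One utterance: 3 time steps x (len(vocabulary) + 1) per-frame posteriors.
    probs_seq = np.array([[0.1, 0.7, 0.1, 0.1],
                          [0.1, 0.1, 0.1, 0.7],
                          [0.7, 0.1, 0.1, 0.1]])
    # Greedy path: "b", blank, "a" -> repeats and blanks collapse to "ba".
    print(ctc_greedy_decoder(probs_seq, vocabulary, blank_id))

Beam search follows the same pattern through ctc_beam_search_decoder, optionally passing a Scorer(alpha, beta, model_path, vocabulary) as the external scoring function.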
64 changes: 50 additions & 14 deletions paddlespeech/s2t/exps/deepspeech2/model.py
@@ -402,26 +402,42 @@ def test(self):
def compute_result_transcripts(self, audio, audio_len, vocab_list,
decode_cfg):
if self.args.model_type == "online":
output_probs, output_lens = self.static_forward_online(audio,
audio_len)
output_probs, output_lens, trans_process_list = self.static_forward_online(
audio,
audio_len,
vocab_list,
decode_cfg.decoding_method,
decode_cfg.lang_model_path,
decode_cfg.alpha,
decode_cfg.beta,
decode_cfg.beam_size,
decode_cfg.cutoff_prob,
decode_cfg.cutoff_top_n,
decode_cfg.num_proc_bsearch,
decoder_chunk_size=1)
result_transcripts = trans_process_list[-1:]
elif self.args.model_type == "offline":
output_probs, output_lens = self.static_forward_offline(audio,
audio_len)

self.model.decoder.init_decoder(
1, vocab_list, decode_cfg.decoding_method,
decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta,
decode_cfg.beam_size, decode_cfg.cutoff_prob,
decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch)

self.model.decoder.next(output_probs, output_lens)

trans_best, trans_beam = self.model.decoder.decode()
self.model.decoder.del_decoder()
result_transcripts = trans_best

else:
raise Exception("wrong model type")

self.predictor.clear_intermediate_tensor()
self.predictor.try_shrink_memory()

self.model.decoder.init_decode(decode_cfg.alpha, decode_cfg.beta,
decode_cfg.lang_model_path, vocab_list,
decode_cfg.decoding_method)

result_transcripts = self.model.decoder.decode_probs(
output_probs, output_lens, vocab_list, decode_cfg.decoding_method,
decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta,
decode_cfg.beam_size, decode_cfg.cutoff_prob,
decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch)
#replace the <space> with ' '
result_transcripts = [
self._text_featurizer.detokenize(sentence)
@@ -439,7 +455,18 @@ def run_test(self):
except KeyboardInterrupt:
exit(-1)

def static_forward_online(self, audio, audio_len,
def static_forward_online(self,
audio,
audio_len,
vocab_list,
decoding_method,
lang_model_path,
alpha,
beta,
beam_size,
cutoff_prob,
cutoff_top_n,
num_proc_bsearch,
decoder_chunk_size: int=1):
"""
Parameters
@@ -472,6 +499,7 @@ def static_forward_online(self, audio, audio_len,
x_list = np.split(x_batch, batch_size, axis=0)
x_len_list = np.split(x_len_batch, batch_size, axis=0)

trans_process_list = []
for x, x_len in zip(x_list, x_len_list):
if self.args.enable_auto_log is True:
self.autolog.times.start()
@@ -504,12 +532,16 @@
h_box_handle = self.predictor.get_input_handle(input_names[2])
c_box_handle = self.predictor.get_input_handle(input_names[3])

trans_chunk_list = []
probs_chunk_list = []
probs_chunk_lens_list = []
if self.args.enable_auto_log is True:
# record the model preprocessing time
self.autolog.times.stamp()

self.model.decoder.init_decoder(
1, vocab_list, decoding_method, lang_model_path, alpha, beta,
beam_size, cutoff_prob, cutoff_top_n, num_proc_bsearch)
for i in range(0, num_chunk):
start = i * chunk_stride
end = start + chunk_size
@@ -549,9 +581,13 @@
output_chunk_lens = output_lens_handle.copy_to_cpu()
chunk_state_h_box = output_state_h_handle.copy_to_cpu()
chunk_state_c_box = output_state_c_handle.copy_to_cpu()

self.model.decoder.next(output_chunk_probs, output_chunk_lens)
probs_chunk_list.append(output_chunk_probs)
probs_chunk_lens_list.append(output_chunk_lens)
trans_best, trans_beam = self.model.decoder.decode()
trans_process_list.append(trans_best[0])
self.model.decoder.del_decoder()

output_probs = np.concatenate(probs_chunk_list, axis=1)
output_lens = np.sum(probs_chunk_lens_list, axis=0)
vocab_size = output_probs.shape[2]
@@ -573,7 +609,7 @@
self.autolog.times.end()
output_probs = np.concatenate(output_probs_list, axis=0)
output_lens = np.concatenate(output_lens_list, axis=0)
return output_probs, output_lens
return output_probs, output_lens, trans_process_list

def static_forward_offline(self, audio, audio_len):
"""
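The per-chunk decoder calls added to static_forward_online boil down to the loop below. This is a condensed sketch under stated assumptions: a single-utterance batch, and an illustrative chunk_probs_iter that yields the (output_chunk_probs, output_chunk_lens) pairs the predictor produces; the real code interleaves this with predictor handles and autolog timing.

    def stream_decode(model, chunk_probs_iter, vocab_list, cfg):
        # One decoder instance per utterance (batch size 1), as in the PR.
        model.decoder.init_decoder(1, vocab_list, cfg.decoding_method,
                                   cfg.lang_model_path, cfg.alpha, cfg.beta,
                                   cfg.beam_size, cfg.cutoff_prob,
                                   cfg.cutoff_top_n, cfg.num_proc_bsearch)
        trans_process = []  # best transcript after each chunk
        for chunk_probs, chunk_lens in chunk_probs_iter:
            # The decoder accumulates state across next() calls, so each decode()
            # returns the best hypothesis over everything seen so far.
            model.decoder.next(chunk_probs, chunk_lens)
            trans_best, trans_beam = model.decoder.decode()
            trans_process.append(trans_best[0])
        model.decoder.del_decoder()
        return trans_process

compute_result_transcripts then takes the last entry of this running list (trans_process_list[-1:]) as the final transcript for the online model.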
2 changes: 1 addition & 1 deletion paddlespeech/s2t/models/ds2/__init__.py
@@ -16,7 +16,7 @@
from paddlespeech.s2t.utils import dynamic_pip_install

try:
import swig_decoders
import paddlespeech_ctcdecoders
except ImportError:
try:
package_name = 'paddlespeech_ctcdecoders'
19 changes: 9 additions & 10 deletions paddlespeech/s2t/models/ds2/deepspeech2.py
@@ -169,19 +169,18 @@ def decode(self, audio, audio_len, vocab_list, decoding_method,
cutoff_top_n, num_processes):
# init once
# decoders only accept string encoded in utf-8
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
batch_size = audio.shape[0]
self.decoder.init_decoder(batch_size, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta,
beam_size, cutoff_prob, cutoff_top_n,
num_processes)

eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts)
return self.decoder.decode_probs(
probs.numpy(), eouts_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes)
self.decoder.next(probs, eouts_len)
trans_best, trans_beam = self.decoder.decode()
self.decoder.del_decoder()
return trans_best

@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):
2 changes: 1 addition & 1 deletion paddlespeech/s2t/models/ds2_online/__init__.py
@@ -16,7 +16,7 @@
from paddlespeech.s2t.utils import dynamic_pip_install

try:
import swig_decoders
import paddlespeech_ctcdecoders
except ImportError:
try:
package_name = 'paddlespeech_ctcdecoders'
19 changes: 9 additions & 10 deletions paddlespeech/s2t/models/ds2_online/deepspeech2.py
@@ -298,20 +298,19 @@ def decode(self, audio, audio_len, vocab_list, decoding_method,
cutoff_top_n, num_processes):
# init once
# decoders only accept string encoded in utf-8
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
batch_size = audio.shape[0]
self.decoder.init_decoder(batch_size, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta,
beam_size, cutoff_prob, cutoff_top_n,
num_processes)

eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder(
audio, audio_len, None, None)
probs = self.decoder.softmax(eouts)
return self.decoder.decode_probs(
probs.numpy(), eouts_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes)
self.decoder.next(probs, eouts_len)
trans_best, trans_beam = self.decoder.decode()
self.decoder.del_decoder()
return trans_best

@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):
6 changes: 3 additions & 3 deletions paddlespeech/s2t/models/u2/u2.py
@@ -32,7 +32,7 @@
from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.models.asr_interface import ASRInterface
from paddlespeech.s2t.modules.cmvn import GlobalCMVN
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.modules.ctc import CTCDecoderBase
from paddlespeech.s2t.modules.decoder import TransformerDecoder
from paddlespeech.s2t.modules.encoder import ConformerEncoder
from paddlespeech.s2t.modules.encoder import TransformerEncoder
@@ -63,7 +63,7 @@ def __init__(self,
vocab_size: int,
encoder: TransformerEncoder,
decoder: TransformerDecoder,
ctc: CTCDecoder,
ctc: CTCDecoderBase,
ctc_weight: float=0.5,
ignore_id: int=IGNORE_ID,
lsm_weight: float=0.0,
@@ -840,7 +840,7 @@ def _init_from_config(cls, configs: dict):
model_conf = configs.get('model_conf', dict())
dropout_rate = model_conf.get('ctc_dropout_rate', 0.0)
grad_norm_type = model_conf.get('ctc_grad_norm_type', None)
ctc = CTCDecoder(
ctc = CTCDecoderBase(
odim=vocab_size,
enc_n_units=encoder.output_size(),
blank_id=0,
9 changes: 4 additions & 5 deletions paddlespeech/s2t/models/u2_st/u2_st.py
@@ -28,7 +28,7 @@
from paddlespeech.s2t.frontend.utility import IGNORE_ID
from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.modules.cmvn import GlobalCMVN
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.modules.ctc import CTCDecoderBase
from paddlespeech.s2t.modules.decoder import TransformerDecoder
from paddlespeech.s2t.modules.encoder import ConformerEncoder
from paddlespeech.s2t.modules.encoder import TransformerEncoder
@@ -56,7 +56,7 @@ def __init__(self,
encoder: TransformerEncoder,
st_decoder: TransformerDecoder,
decoder: TransformerDecoder=None,
ctc: CTCDecoder=None,
ctc: CTCDecoderBase=None,
ctc_weight: float=0.0,
asr_weight: float=0.0,
ignore_id: int=IGNORE_ID,
@@ -313,8 +313,7 @@ def translate(
cache = [
paddle.ones(
(len(hyps), i - 1, hyp_cache.shape[-1]),
dtype=paddle.float32)
for hyp_cache in hyps[0]["cache"]
dtype=paddle.float32) for hyp_cache in hyps[0]["cache"]
]
for j, hyp in enumerate(hyps):
ys[j, :] = paddle.to_tensor(hyp["yseq"])
@@ -596,7 +595,7 @@ def _init_from_config(cls, configs: dict):
model_conf = configs['model_conf']
dropout_rate = model_conf.get('ctc_dropout_rate', 0.0)
grad_norm_type = model_conf.get('ctc_grad_norm_type', None)
ctc = CTCDecoder(
ctc = CTCDecoderBase(
odim=vocab_size,
enc_n_units=encoder.output_size(),
blank_id=0,