Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions speechx/examples/aishell/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ wer=./aishell_wer
nj=40
export GLOG_logtostderr=1

./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
#./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj

data=$PWD/data
# 3. gen linear feat
Expand All @@ -72,10 +72,42 @@ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$lm_model_dir/vocab.txt \
--lm_path=$lm_model_dir/avg_1.jit.klm \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result

cat $data/split${nj}/*/result > $label_file
cat $data/split${nj}/*/result > ${label_file}
local/compute-wer.py --char=1 --v=1 ${label_file} $text > ${wer}

# 4. decode with lm
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_lm \
offline_decoder_sliding_chunk_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$aishell_online_model/avg_1.jit.pdmodel \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$lm_model_dir/vocab.txt \
--lm_path=$lm_model_dir/avg_1.jit.klm \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_lm

cat $data/split${nj}/*/result_lm > ${label_file}_lm
local/compute-wer.py --char=1 --v=1 ${label_file}_lm $text > ${wer}_lm

graph_dir=./aishell_graph
if [ ! -d $ ]; then
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip
unzip -d aishell_graph.zip
fi

# 5. test TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/log_tlg \
offline_wfst_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$aishell_online_model/avg_1.jit.pdmodel \
--param_path=$aishell_online_model/avg_1.jit.pdiparams \
--word_symbol_table=$graph_dir/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--graph_path=$graph_dir/TLG.fst --max_active=7500 \
--acoustic_scale=1.2 \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_tlg

local/compute-wer.py --char=1 --v=1 $label_file $text > $wer
tail $wer
cat $data/split${nj}/*/result_tlg > ${label_file}_tlg
local/compute-wer.py --char=1 --v=1 ${label_file}_tlg $text > ${wer}_tlg
4 changes: 4 additions & 0 deletions speechx/examples/decoder/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_
target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})

add_executable(offline_wfst_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_wfst_decoder_main.cc)
target_include_directories(offline_wfst_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_wfst_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})

add_executable(decoder_test_main ${CMAKE_CURRENT_SOURCE_DIR}/decoder_test_main.cc)
target_include_directories(decoder_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(decoder_test_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
Expand Down
14 changes: 9 additions & 5 deletions speechx/examples/decoder/offline_decoder_sliding_chunk_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(lm_path, "lm.klm", "language model");
DEFINE_string(lm_path, "", "language model");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=5) downsampling module.");
Expand All @@ -45,7 +45,6 @@ using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;


// test ds2 online decoder by feeding speech feature
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
Expand All @@ -63,7 +62,6 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path;


int32 num_done = 0, num_err = 0;

ppspeech::CTCBeamSearchOptions opts;
Expand Down Expand Up @@ -138,10 +136,16 @@ int main(int argc, char* argv[]) {
}
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
decodable->Reset();
decoder.Reset();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
KALDI_LOG << " the result of " << utt << " is empty";
continue;
}
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done;
}

Expand Down
158 changes: 158 additions & 0 deletions speechx/examples/decoder/offline_wfst_decoder_main.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// todo refactor, repalce with gtest

#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"

DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "decoder graph");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=5) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
DEFINE_string(model_output_names,
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");

using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;

// test TLG decoder by feeding speech feature.
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);

kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string word_symbol_table = FLAGS_word_symbol_table;
std::string graph_path = FLAGS_graph_path;
LOG(INFO) << "model path: " << model_graph;
LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "word symbol path: " << word_symbol_table;
LOG(INFO) << "graph path: " << graph_path;

int32 num_done = 0, num_err = 0;

ppspeech::TLGDecoderOptions opts;
opts.word_symbol_table = word_symbol_table;
opts.fst_path = graph_path;
opts.opts.max_active = FLAGS_max_active;
opts.opts.beam = 15.0;
opts.opts.lattice_beam = 7.5;
ppspeech::TLGDecoder decoder(opts);

ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph;
model_opts.params_path = model_params;
model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));

int32 chunk_size = FLAGS_receptive_field_length;
int32 chunk_stride = FLAGS_downsampling_rate;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
decoder.InitDecoder();

for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
raw_data->SetDim(feature.NumCols());
LOG(INFO) << "process utt: " << utt;
LOG(INFO) << "rows: " << feature.NumRows();
LOG(INFO) << "cols: " << feature.NumCols();

int32 row_idx = 0;
int32 padding_len = 0;
int32 ori_feature_len = feature.NumRows();
if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
padding_len =
chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
feature.Resize(feature.NumRows() + padding_len,
feature.NumCols(),
kaldi::kCopyData);
}
int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
int32 feature_chunk_size = 0;
if (ori_feature_len > chunk_idx * chunk_stride) {
feature_chunk_size = std::min(
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
}
if (feature_chunk_size < receptive_field_length) break;

int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(tmp);
++start;
}
raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
decoder.AdvanceDecode(decodable);
}
std::string result;
result = decoder.GetFinalBestPath();
decodable->Reset();
decoder.Reset();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
KALDI_LOG << " the result of " << utt << " is empty";
continue;
}
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done;
}

KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
3 changes: 2 additions & 1 deletion speechx/speechx/decoder/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ add_library(decoder STATIC
ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp
ctc_tlg_decoder.cc
)
target_link_libraries(decoder PUBLIC kenlm utils fst)
target_link_libraries(decoder PUBLIC kenlm utils fst)
2 changes: 1 addition & 1 deletion speechx/speechx/decoder/ctc_beam_search_decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ void CTCBeamSearch::AdvanceDecode(
vector<vector<BaseFloat>> likelihood;
vector<BaseFloat> frame_prob;
bool flag =
decodable->FrameLogLikelihood(num_frame_decoded_, &frame_prob);
decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
if (flag == false) break;
likelihood.push_back(frame_prob);
AdvanceDecoding(likelihood);
Expand Down
2 changes: 1 addition & 1 deletion speechx/speechx/decoder/ctc_beam_search_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "base/common.h"
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h"
#include "nnet/decodable-itf.h"
#include "kaldi/decoder/decodable-itf.h"
#include "util/parse-options.h"

#pragma once
Expand Down
66 changes: 66 additions & 0 deletions speechx/speechx/decoder/ctc_tlg_decoder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "decoder/ctc_tlg_decoder.h"
namespace ppspeech {

TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
fst_.reset(fst::Fst<fst::StdArc>::Read(opts.fst_path));
CHECK(fst_ != nullptr);
word_symbol_table_.reset(
fst::SymbolTable::ReadText(opts.word_symbol_table));
decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts));
decoder_->InitDecoding();
frame_decoded_size_ = 0;
}

void TLGDecoder::InitDecoder() {
decoder_->InitDecoding();
frame_decoded_size_ = 0;
}

void TLGDecoder::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (!decodable->IsLastFrame(frame_decoded_size_)) {
LOG(INFO) << "num frame decode: " << frame_decoded_size_;
AdvanceDecoding(decodable.get());
}
}

void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) {
decoder_->AdvanceDecoding(decodable, 1);
frame_decoded_size_++;
}

void TLGDecoder::Reset() {
InitDecoder();
return;
}

std::string TLGDecoder::GetFinalBestPath() {
decoder_->FinalizeDecoding();
kaldi::Lattice lat;
kaldi::LatticeWeight weight;
std::vector<int> alignment;
std::vector<int> words_id;
decoder_->GetBestPath(&lat, true);
fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight);
std::string words;
for (int32 idx = 0; idx < words_id.size(); ++idx) {
std::string word = word_symbol_table_->Find(words_id[idx]);
words += word;
}
return words;
}
}
Loading