Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions speechx/examples/decoder/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(offline_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_decoder_main.cc)
target_include_directories(offline_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})

add_executable(offline_wfst_decoder_main ${CMAKE_CURRENT_SOURCE_DIR}/offline_wfst_decoder_main.cc)
target_include_directories(offline_wfst_decoder_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(offline_wfst_decoder_main PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
101 changes: 101 additions & 0 deletions speechx/examples/decoder/offline_wfst_decoder_main.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// todo refactor, repalce with gtest

#include "base/flags.h"
#include "base/log.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/raw_audio.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"

DEFINE_string(feature_respecifier, "", "test feature rspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "vocab.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");


using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;

int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);

kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_respecifier);
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string word_symbol_table = FLAGS_word_symbol_table;
std::string graph_path = FLAGS_graph_path;

int32 num_done = 0, num_err = 0;

ppspeech::TLGDecoderOptions opts;
opts.word_symbol_table = word_symbol_table;
opts.fst_path = graph_path;
ppspeech::TLGDecoder decoder(opts);

ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph;
model_opts.params_path = model_params;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::RawDataCache> raw_data(
new ppspeech::RawDataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data));

int32 chunk_size = 35;
decoder.InitDecoder();

for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
const kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
raw_data->SetDim(feature.NumCols());
int32 row_idx = 0;
int32 num_chunks = feature.NumRows() / chunk_size;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, row_idx);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(tmp);
row_idx++;
}
raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
decoder.AdvanceDecode(decodable);
}
std::string result;
result = decoder.GetFinalBestPath();
KALDI_LOG << " the result of " << utt << " is " << result;
decodable->Reset();
decoder.Reset();
++num_done;
}

KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
3 changes: 2 additions & 1 deletion speechx/speechx/decoder/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ add_library(decoder STATIC
ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp
ctc_tlg_decoder.cc
)
target_link_libraries(decoder PUBLIC kenlm utils fst)
target_link_libraries(decoder PUBLIC kenlm utils fst)
2 changes: 1 addition & 1 deletion speechx/speechx/decoder/ctc_beam_search_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "base/common.h"
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h"
#include "nnet/decodable-itf.h"
#include "kaldi/decoder/decodable-itf.h"
#include "util/parse-options.h"

#pragma once
Expand Down
50 changes: 50 additions & 0 deletions speechx/speechx/decoder/ctc_tlg_decoder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include "decoder/ctc_tlg_decoder.h"
namespace ppspeech {

TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
fst_.reset(fst::Fst<fst::StdArc>::Read(opts.fst_path));
CHECK(fst_ != nullptr);
word_symbol_table_.reset(fst::SymbolTable::ReadText(opts.word_symbol_table));
decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts));
decoder_->InitDecoding();
}

void TLGDecoder::InitDecoder() {
decoder_->InitDecoding();
}

void TLGDecoder::AdvanceDecode(const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (1) {
AdvanceDecoding(decodable.get());
if (decodable->IsLastFrame(num_frame_decoded_)) break;
}
}

void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) {
// skip blank frame?
decoder_->AdvanceDecoding(decodable, 1);
num_frame_decoded_++;
}

void TLGDecoder::Reset() {
decoder_->InitDecoding();
return;
}

std::string TLGDecoder::GetFinalBestPath() {
decoder_->FinalizeDecoding();
kaldi::Lattice lat;
kaldi::LatticeWeight weight;
std::vector<int> alignment;
std::vector<int> words_id;
decoder_->GetBestPath(&lat, true);
fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight);
std::string words;
for (int32 idx = 0; idx < words_id.size(); ++idx) {
std::string word = word_symbol_table_->Find(words_id[idx]);
words += word;
}
return words;
}

}
47 changes: 47 additions & 0 deletions speechx/speechx/decoder/ctc_tlg_decoder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#pragma once

#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "kaldi/decoder/decodable-itf.h"
#include "util/parse-options.h"
#include "base/basic_types.h"

namespace ppspeech {

struct TLGDecoderOptions {
kaldi::LatticeFasterDecoderConfig opts;
// todo remove later, add into decode resource
std::string word_symbol_table;
std::string fst_path;

TLGDecoderOptions()
: word_symbol_table(""),
fst_path("") {}
};

class TLGDecoder {
public:
explicit TLGDecoder(TLGDecoderOptions opts);
void InitDecoder();
void Decode();
std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words);
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
void Reset();

private:
void AdvanceDecoding(kaldi::DecodableInterface* decodable);

std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
std::shared_ptr<fst::SymbolTable> word_symbol_table_;
int32 num_frame_decoded_;
};



} // namespace ppspeech
3 changes: 3 additions & 0 deletions speechx/speechx/kaldi/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@ add_subdirectory(base)
add_subdirectory(util)
add_subdirectory(feat)
add_subdirectory(matrix)
add_subdirectory(lat)
add_subdirectory(fstext)
add_subdirectory(decoder)
6 changes: 6 additions & 0 deletions speechx/speechx/kaldi/decoder/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

add_library(kaldi-decoder
lattice-faster-decoder.cc
lattice-faster-online-decoder.cc
)
target_link_libraries(kaldi-decoder PUBLIC kaldi-lat)
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class DecodableInterface {
/// decoding-from-matrix setting where we want to allow the last delta or
/// LDA
/// features to be flushed out for compatibility with the baseline setup.
virtual bool IsLastFrame(int32 frame) const = 0;
virtual bool IsLastFrame(int32 frame) = 0;

/// The call NumFramesReady() will return the number of frames currently
/// available
Expand Down
4 changes: 0 additions & 4 deletions speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1007,14 +1007,10 @@ template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc>, decoder::StdToken>
template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, decoder::StdToken >;
template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, decoder::StdToken >;

template class LatticeFasterDecoderTpl<fst::ConstGrammarFst, decoder::StdToken>;
template class LatticeFasterDecoderTpl<fst::VectorGrammarFst, decoder::StdToken>;

template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc> , decoder::BackpointerToken>;
template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, decoder::BackpointerToken >;
template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, decoder::BackpointerToken >;
template class LatticeFasterDecoderTpl<fst::ConstGrammarFst, decoder::BackpointerToken>;
template class LatticeFasterDecoderTpl<fst::VectorGrammarFst, decoder::BackpointerToken>;


} // end namespace kaldi.
3 changes: 1 addition & 2 deletions speechx/speechx/kaldi/decoder/lattice-faster-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@
#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_
#define KALDI_DECODER_LATTICE_FASTER_DECODER_H_

#include "decoder/grammar-fst.h"
#include "fst/fstlib.h"
#include "fst/memory.h"
#include "fstext/fstext-lib.h"
#include "itf/decodable-itf.h"
#include "decoder/decodable-itf.h"
#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"
#include "util/hash-list.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,8 @@ bool LatticeFasterOnlineDecoderTpl<FST>::GetRawLatticePruned(
template class LatticeFasterOnlineDecoderTpl<fst::Fst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::VectorFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstGrammarFst >;
template class LatticeFasterOnlineDecoderTpl<fst::VectorGrammarFst >;
//template class LatticeFasterOnlineDecoderTpl<fst::ConstGrammarFst >;
//template class LatticeFasterOnlineDecoderTpl<fst::VectorGrammarFst >;


} // end namespace kaldi.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#include "util/stl-utils.h"
#include "util/hash-list.h"
#include "fst/fstlib.h"
#include "itf/decodable-itf.h"
#include "decoder/decodable-itf.h"
#include "fstext/fstext-lib.h"
#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"
Expand Down
5 changes: 5 additions & 0 deletions speechx/speechx/kaldi/fstext/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

add_library(kaldi-fstext
kaldi-fst-io.cc
)
target_link_libraries(kaldi-fstext PUBLIC kaldi-util)
Loading