19
19
#include < string>
20
20
#include < vector>
21
21
#include " base/kaldi-common.h"
22
- #include " util/common-utils.h"
23
- #include " gmm/am-diag-gmm.h"
24
- #include " hmm/transition-model.h"
25
- #include " fstext/fstext-utils.h"
26
22
#include " decoder/decoder-wrappers.h"
23
+ #include " fstext/fstext-utils.h"
24
+ #include " gmm/am-diag-gmm.h"
27
25
#include " gmm/decodable-am-diag-gmm.h"
28
- #include " lat/kaldi-lattice.h"
29
26
#include " hmm/hmm-utils.h"
27
+ #include " hmm/transition-model.h"
28
+ #include " lat/kaldi-lattice.h"
29
+ #include " util/common-utils.h"
30
30
#include " gop/gmm-gop.h"
31
31
32
32
namespace kaldi {
@@ -35,21 +35,23 @@ typedef typename fst::StdArc Arc;
35
35
typedef typename Arc::StateId StateId;
36
36
typedef typename Arc::Weight Weight;
37
37
38
- void GmmGop::Init (std::string &tree_in_filename,
39
- std::string &model_in_filename,
40
- std::string &lex_in_filename) {
38
+ void GmmGop::Init (std::string &tree_in_filename, std::string &model_in_filename,
39
+ std::string &lex_in_filename) {
41
40
bool binary;
42
41
Input ki (model_in_filename, &binary);
43
42
tm_.Read (ki.Stream (), binary);
44
43
am_.Read (ki.Stream (), binary);
45
44
ReadKaldiObject (tree_in_filename, &ctx_dep_);
46
45
47
46
fst::VectorFst<fst::StdArc> *lex_fst = fst::ReadFstKaldi (lex_in_filename);
48
- std::vector<int32> disambig_syms;
47
+ std::vector<int32> disambig_syms;
49
48
TrainingGraphCompilerOptions gopts;
50
49
gc_ = new TrainingGraphCompiler (tm_, ctx_dep_, lex_fst, disambig_syms, gopts);
51
50
52
51
for (size_t i = 0 ; i < tm_.NumTransitionIds (); i++) {
52
+ // The transition-ids are only for building the denominator graph. Although
53
+ // one pdf-id may have multiple transition-ids, all those transitions-ids
54
+ // share the same HMM state (of course).
53
55
pdfid_to_tid[tm_.TransitionIdToPdf (i)] = i;
54
56
}
55
57
}
@@ -61,23 +63,22 @@ BaseFloat GmmGop::Decode(fst::VectorFst<fst::StdArc> &fst,
61
63
decode_opts.beam = 500 ;
62
64
FasterDecoder decoder (fst, decode_opts);
63
65
decoder.Decode (&decodable);
64
- if (! decoder.ReachedFinal ()) {
66
+ if (!decoder.ReachedFinal ()) {
65
67
KALDI_WARN << " Did not successfully decode." ;
66
68
}
67
69
fst::VectorFst<LatticeArc> decoded;
68
70
decoder.GetBestPath (&decoded);
69
71
std::vector<int32> osymbols;
70
72
LatticeWeight weight;
71
73
GetLinearSymbolSequence (decoded, align, &osymbols, &weight);
72
- BaseFloat likelihood = -(weight.Value1 ()+ weight.Value2 ());
74
+ BaseFloat likelihood = -(weight.Value1 () + weight.Value2 ());
73
75
74
76
return likelihood;
75
77
}
76
78
77
79
BaseFloat GmmGop::ComputeGopNumera (DecodableAmDiagGmmScaled &decodable,
78
80
std::vector<int32> &align,
79
- MatrixIndexT start_frame,
80
- int32 size) {
81
+ MatrixIndexT start_frame, int32 size) {
81
82
KALDI_ASSERT (start_frame + size <= align.size ());
82
83
BaseFloat likelihood = 0 ;
83
84
for (MatrixIndexT frame = start_frame; frame < start_frame + size; frame++) {
@@ -88,7 +89,8 @@ BaseFloat GmmGop::ComputeGopNumera(DecodableAmDiagGmmScaled &decodable,
88
89
}
89
90
90
91
BaseFloat GmmGop::ComputeGopNumeraViterbi (DecodableAmDiagGmmScaled &decodable,
91
- int32 phone_l, int32 phone, int32 phone_r) {
92
+ int32 phone_l, int32 phone,
93
+ int32 phone_r) {
92
94
KALDI_ASSERT (ctx_dep_.ContextWidth () == 3 );
93
95
KALDI_ASSERT (ctx_dep_.CentralPosition () == 1 );
94
96
std::vector<int32> phoneseq (3 );
@@ -101,7 +103,8 @@ BaseFloat GmmGop::ComputeGopNumeraViterbi(DecodableAmDiagGmmScaled &decodable,
101
103
fst.SetStart (cur_state);
102
104
for (size_t c = 0 ; c < tm_.GetTopo ().NumPdfClasses (phone); c++) {
103
105
int32 pdf_id;
104
- KALDI_ASSERT (ctx_dep_.Compute (phoneseq, c, &pdf_id));
106
+ if (!ctx_dep_.Compute (phoneseq, c, &pdf_id))
107
+ KALDI_ERR << " Failed to obtain pdf_id." ;
105
108
int32 tid = pdfid_to_tid[pdf_id];
106
109
107
110
StateId next_state = fst.AddState ();
@@ -137,7 +140,8 @@ BaseFloat GmmGop::ComputeGopDenomin(DecodableAmDiagGmmScaled &decodable,
137
140
StateId cur_state = start_state;
138
141
for (size_t c = 0 ; c < pdfclass_num; c++) {
139
142
int32 pdf_id;
140
- KALDI_ASSERT (ctx_dep_.Compute (phoneseq, c, &pdf_id));
143
+ if (!ctx_dep_.Compute (phoneseq, c, &pdf_id))
144
+ KALDI_ERR << " Failed to obtain pdf_id." ;
141
145
int32 tid = pdfid_to_tid[pdf_id];
142
146
143
147
StateId next_state = fst.AddState ();
@@ -157,11 +161,14 @@ BaseFloat GmmGop::ComputeGopDenomin(DecodableAmDiagGmmScaled &decodable,
157
161
}
158
162
159
163
void GmmGop::GetContextFromSplit (std::vector<std::vector<int32> > split,
160
- int32 index, int32 &phone_l, int32 &phone, int32 &phone_r) {
164
+ int32 index, int32 &phone_l, int32 &phone,
165
+ int32 &phone_r) {
161
166
KALDI_ASSERT (index < split.size ());
162
- phone_l = (index > 0 ) ? tm_.TransitionIdToPhone (split[index - 1 ][0 ]) : 1 ;
167
+ phone_l = (index > 0 ) ? tm_.TransitionIdToPhone (split[index - 1 ][0 ]) : 1 ;
163
168
phone = tm_.TransitionIdToPhone (split[index ][0 ]);
164
- phone_r = (index < split.size () - 1 ) ? tm_.TransitionIdToPhone (split[index +1 ][0 ]): 1 ;
169
+ phone_r = (index < split.size () - 1 )
170
+ ? tm_.TransitionIdToPhone (split[index + 1 ][0 ])
171
+ : 1 ;
165
172
}
166
173
167
174
void GmmGop::Compute (const Matrix<BaseFloat> &feats,
@@ -181,33 +188,31 @@ void GmmGop::Compute(const Matrix<BaseFloat> &feats,
181
188
phones_.resize (split.size ());
182
189
int32 frame_start_idx = 0 ;
183
190
for (MatrixIndexT i = 0 ; i < split.size (); i++) {
184
- SubMatrix<BaseFloat> feats_in_phone = feats. Range (frame_start_idx, split[i]. size (),
185
- 0 , feats.NumCols ());
191
+ SubMatrix<BaseFloat> feats_in_phone =
192
+ feats. Range (frame_start_idx, split[i]. size (), 0 , feats.NumCols ());
186
193
const Matrix<BaseFloat> features (feats_in_phone);
187
194
DecodableAmDiagGmmScaled split_decodable (am_, tm_, features, 1.0 );
188
195
189
196
int32 phone, phone_l, phone_r;
190
197
GetContextFromSplit (split, i, phone_l, phone, phone_r);
191
198
192
199
bool use_viterbi_numera = true ;
193
- BaseFloat gop_numerator = use_viterbi_numera ?
194
- ComputeGopNumeraViterbi (split_decodable, phone_l, phone, phone_r):
195
- ComputeGopNumera (ali_decodable, align,
196
- frame_start_idx, split[i].size ());
197
- BaseFloat gop_denominator = ComputeGopDenomin (split_decodable, phone_l, phone_r);
200
+ BaseFloat gop_numerator =
201
+ use_viterbi_numera
202
+ ? ComputeGopNumeraViterbi (split_decodable, phone_l, phone, phone_r)
203
+ : ComputeGopNumera (ali_decodable, align, frame_start_idx,
204
+ split[i].size ());
205
+ BaseFloat gop_denominator =
206
+ ComputeGopDenomin (split_decodable, phone_l, phone_r);
198
207
gop_result_ (i) = (gop_numerator - gop_denominator) / split[i].size ();
199
208
phones_[i] = phone;
200
209
201
210
frame_start_idx += split[i].size ();
202
211
}
203
212
}
204
213
205
- Vector<BaseFloat>& GmmGop::Result () {
206
- return gop_result_;
207
- }
214
+ Vector<BaseFloat> &GmmGop::Result () { return gop_result_; }
208
215
209
- std::vector<int32>& GmmGop::Phonemes () {
210
- return phones_;
211
- }
216
+ std::vector<int32> &GmmGop::Phonemes () { return phones_; }
212
217
213
218
} // End namespace kaldi
0 commit comments