Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions egs/swbd/s5c/local/xvector/prepare_perturbed_data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ stage=1
train_stage=-10
generate_alignments=true # false if doing ctc training
speed_perturb=true
mfcc_config=conf/mfcc_hires.conf
mfccdir=mfcc

. ./path.sh
. ./utils/parse_options.sh
Expand All @@ -27,13 +29,12 @@ if [ $stage -le 1 ]; then
if [ -f data/${datadir}_sp_hires/feats.scp ]; then
echo "$0: directory data/${datadir}_sp_hires/feats.scp already exists, skipping creating it."
else
mfccdir=mfcc
utils/copy_data_dir.sh data/${datadir}_sp data/${datadir}_sp_hires
steps/make_mfcc.sh --cmd "$train_cmd" --nj 50 \
steps/make_mfcc.sh --cmd "$train_cmd" --nj 50 --mfcc-config $mfcc_config \
data/${datadir}_sp_hires exp/make_mfcc/${datadir}_sp_hires $mfccdir || exit 1;
# we typically won't need the cmvn stats when using hires features-- it's
# mostly for neural nets.
utils/fix_data_dir.sh data/${dataset}_sp_hires # remove segments with problems
utils/fix_data_dir.sh data/${datadir}_sp_hires # remove segments with problems
fi
done
fi
Expand All @@ -50,7 +51,7 @@ if [ $stage -le 2 ]; then
echo "$0: data/${dataset}_hires/feats.scp already exists, skipping mfcc generation"
else
utils/copy_data_dir.sh data/$dataset data/${dataset}_hires
steps/make_mfcc.sh --cmd "$train_cmd" --nj 10 --mfcc-config conf/mfcc_hires.conf \
steps/make_mfcc.sh --cmd "$train_cmd" --nj 10 --mfcc-config $mfcc_config \
data/${dataset}_hires exp/make_hires/$dataset $mfccdir;
steps/compute_cmvn_stats.sh data/${dataset}_hires exp/make_hires/$dataset $mfccdir;
utils/fix_data_dir.sh data/${dataset}_hires # remove segments with problems
Expand Down
19 changes: 13 additions & 6 deletions egs/swbd/s5c/local/xvector/train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,21 @@
set -e

stage=1
train_stage=1
train_stage=-10
generate_alignments=true # false if doing ctc training
speed_perturb=true

init_lr=0.003
final_lr=0.0003
max_change=2.0
use_gpu=true
feat_dim=40 # this is the MFCC dim we use in the hires features. you can't change it
# unless you change local/xvector/prepare_perturbed_data.sh to use a different
# MFCC config with a different dimension.
data=data/train_nodup_sp_hires # you can't change this without changing
# local/xvector/prepare_perturbed_data.sh
xvector_dim=200 # dimension of the xVector. configurable.
xvector_dir=exp/xvector_a
egs_dir=exp/xvector_a/egs


. ./path.sh
Expand All @@ -40,18 +44,21 @@ if [ $stage -le 3 ]; then
$xvector_dir/nnet.config
fi

if [ $stage -le 4 ]; then
if [ $stage -le 4 ] && [ -z "$egs_dir" ]; then
# dump egs.
steps/nnet3/xvector/get_egs.sh --cmd "$train_cmd" \
"$data" $xvector_dir/egs
"$data" $egs_dir
fi

if [ $stage -le 5 ]; then
# training for 4 epochs * 3 shifts means we see each eg 12
# times (3 different frame-shifts of the same eg are counted as different).
steps/nnet3/xvector/train.sh --cmd "$train_cmd" \
--num-epochs 4 --num-shifts 3 \
--num-jobs-initial 2 --num-jobs-final 8 \
--num-epochs 4 --num-shifts 3 --use-gpu $use_gpu --stage $train_stage \
--initial-effective-lrate $init_lr --final-effective-lrate $final_lr \
--num-jobs-initial 1 --num-jobs-final 8 \
--max-param-change $max_change \
--egs-dir $egs_dir \
$xvector_dir
fi

Expand Down
7 changes: 3 additions & 4 deletions egs/wsj/s5/steps/nnet3/xvector/train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
cmd=run.pl
num_epochs=4 # Number of epochs of training;
# the number of iterations is worked out from this.
diss_scale=1.0 # scale applied to the dissimilarity part of the objective function.
num_shifts=3
initial_effective_lrate=0.003
final_effective_lrate=0.0003
Expand Down Expand Up @@ -134,10 +133,10 @@ while [ $x -lt $num_iters ]; do
# Set off jobs doing some diagnostics, in the background.
# Use the egs dir from the previous iteration for the diagnostics
$cmd JOB=1:$num_diagnostic_archives $dir/log/compute_prob_valid.$x.JOB.log \
nnet3-xvector-compute-prob $dir/$x.raw \
nnet3-xvector-compute-prob --compute-accuracy=true $dir/$x.raw \
"ark:nnet3-merge-egs --measure-output-frames=false ark:$egs_dir/valid_diagnostic_egs.JOB.ark ark:- |" &
$cmd JOB=1:$num_diagnostic_archives $dir/log/compute_prob_train.$x.JOB.log \
nnet3-xvector-compute-prob $dir/$x.raw \
nnet3-xvector-compute-prob --compute-accuracy=true $dir/$x.raw \
"ark:nnet3-merge-egs --measure-output-frames=false ark:$egs_dir/train_diagnostic_egs.JOB.ark ark:- |" &

if [ $x -gt 0 ]; then
Expand Down Expand Up @@ -175,7 +174,7 @@ while [ $x -lt $num_iters ]; do

$cmd $train_queue_opt $dir/log/train.$x.$n.log \
nnet3-xvector-train $parallel_train_opts --print-interval=10 \
--max-param-change=$max_param_change --diss-scale=$diss_scale "$raw" \
--max-param-change=$max_param_change "$raw" \
"ark:nnet3-copy-egs ark:$egs_dir/egs.$archive.ark ark:- | nnet3-shuffle-egs --buffer-size=$shuffle_buffer_size --srand=$x ark:- ark:-| nnet3-merge-egs --measure-output-frames=false --minibatch-size=$minibatch_size --discard-partial-minibatches=true ark:- ark:- |" \
$dir/$[$x+1].$n.raw || touch $dir/.error &
done
Expand Down
130 changes: 116 additions & 14 deletions src/xvector/nnet-xvector-diagnostics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ NnetXvectorComputeProb::NnetXvectorComputeProb(const NnetComputeProbOptions &con
bool is_gradient = true; // force simple update
SetZero(is_gradient, deriv_nnet_);
}
if (config_.compute_accuracy)
need_eer_threshold_ = true;
else
need_eer_threshold_ = false;
}

const Nnet &NnetXvectorComputeProb::GetDeriv() const {
Expand All @@ -51,6 +55,7 @@ NnetXvectorComputeProb::~NnetXvectorComputeProb() {
void NnetXvectorComputeProb::Reset() {
num_minibatches_processed_ = 0;
objf_info_.clear();
acc_info_.clear();
if (deriv_nnet_) {
bool is_gradient = true;
SetZero(is_gradient, deriv_nnet_);
Expand Down Expand Up @@ -80,46 +85,65 @@ void NnetXvectorComputeProb::ProcessOutputs(NnetComputer *computer) {
if (nnet_.IsOutputNode(node_index)) {
std::string xvector_name = nnet_.GetNodeName(node_index),
s_name = "s", b_name = "b";
if (nnet_.GetNodeIndex(s_name) == -1 || nnet_.GetNodeIndex(b_name) == -1)
KALDI_ERR << "The nnet expected to have two output nodes with name s and b.";
if (nnet_.GetNodeIndex(s_name) == -1
|| nnet_.GetNodeIndex(b_name) == -1)
KALDI_ERR << "Expected the nnet to have two output nodes with name "
<< "s and b.";

if (xvector_name != s_name && xvector_name != b_name) {
const CuMatrixBase<BaseFloat> &xvector_pairs = computer->GetOutput(xvector_name),
&xvec_s = computer->GetOutput(s_name),
&xvec_b = computer->GetOutput(b_name);
CuMatrix<BaseFloat> xvector_deriv(xvector_pairs.NumRows(), xvector_pairs.NumCols(),
kUndefined);
int32 s_dim = xvector_pairs.NumCols() * (xvector_pairs.NumCols() + 1) / 2;
const CuMatrixBase<BaseFloat> &xvector_pairs = computer->GetOutput(
xvector_name),
&xvec_s = computer->GetOutput(
s_name),
&xvec_b = computer->GetOutput(
b_name);
int32 num_rows = xvector_pairs.NumRows(),
dim_xvector = xvector_pairs.NumCols();
int32 s_dim = dim_xvector * (dim_xvector + 1) / 2;

CuMatrix<BaseFloat> xvector_deriv(num_rows, dim_xvector,
kUndefined),
raw_scores(num_rows, num_rows, kUndefined);

// convert CuVector to CuSpMatrix
CuSpMatrix<BaseFloat> xvec_s_sp(xvector_pairs.NumCols());
CuSpMatrix<BaseFloat> xvec_s_sp(dim_xvector);
xvec_s_sp.CopyFromVec(xvec_s.Row(0));

CuVector<BaseFloat> deriv_s(s_dim);
BaseFloat xvec_b_val = xvec_b(0,0), deriv_b;
BaseFloat tot_weight, tot_objf;
bool supply_deriv = config_.compute_deriv;
bool compute_accuracy = config_.compute_accuracy;
ComputeXvectorObjfAndDeriv(xvector_pairs, xvec_s_sp, xvec_b_val,
(supply_deriv ? &xvector_deriv : NULL),
(supply_deriv ? &deriv_s : NULL),
(supply_deriv ? &deriv_b : NULL),
(compute_accuracy ? &raw_scores : NULL),
&tot_objf,
&tot_weight);
if (supply_deriv) {
CuMatrix<BaseFloat> deriv_s_mat(1, s_dim),
deriv_b_mat(1,1);
deriv_b_mat(1,1);
deriv_b_mat(0,0) = deriv_b;
deriv_s_mat.CopyRowsFromVec(deriv_s);
computer->AcceptOutputDeriv(xvector_name, &xvector_deriv);
computer->AcceptOutputDeriv(s_name, &deriv_s_mat);
computer->AcceptOutputDeriv(b_name, &deriv_b_mat);

}

SimpleObjectiveInfo &totals = objf_info_[xvector_name];
totals.tot_weight += tot_weight;
totals.tot_objective += tot_objf;

if (compute_accuracy) {
BaseFloat tot_acc, tot_weight_acc;
SimpleObjectiveInfo &acc_totals = acc_info_[xvector_name];
ComputeAccuracy(raw_scores, &tot_weight_acc, &tot_acc);
acc_totals.tot_objective += tot_weight_acc * tot_acc;
acc_totals.tot_weight += tot_weight_acc;
}
num_minibatches_processed_++;
}
num_minibatches_processed_++;
}
}
}
Expand All @@ -140,15 +164,70 @@ bool NnetXvectorComputeProb::PrintTotalStats() const {
KALDI_LOG << "Overall "
<< (obj_type == kLinear ? "log-likelihood" : "objective")
<< " for '" << name << "' is "
<< (info.tot_objective / info.tot_weight) << " per frame"
<< ", over " << info.tot_weight << " frames.";
<< (info.tot_objective / info.tot_weight) << " per chunk"
<< ", over " << info.tot_weight << " chunks.";
if (info.tot_weight > 0)
ans = true;
}
}
if (config_.compute_accuracy) { // Now print the accuracy.
iter = acc_info_.begin();
end = acc_info_.end();
for (; iter != end; ++iter) {
const std::string &name = iter->first;
const SimpleObjectiveInfo &info = iter->second;
KALDI_LOG << "Overall accuracy for '" << name << "' is "
<< (info.tot_objective / info.tot_weight)
<< " per pair of chunks"
<< ", over " << info.tot_weight << " pairs of chunks.";
}
}
return ans;
}

// Computes the fraction of correctly classified pairs of chunks in this
// minibatch.  raw_scores is a square matrix of pairwise scores; the loops
// below treat entry (i, i+1) for even i as a "target" (matched) pair and
// every other upper-triangular entry (i < j) as a "nontarget" pair.
// NOTE(review): presumably target pairs are chunks from the same speaker --
// confirm against the egs-generation code.
// Outputs: *tot_weight_out is the number of pairs scored;
// *tot_accuracy_out is the per-pair accuracy in [0, 1].
void NnetXvectorComputeProb::ComputeAccuracy(
    const CuMatrixBase<BaseFloat> &raw_scores,
    BaseFloat *tot_weight_out,
    BaseFloat *tot_accuracy_out) {
  // raw_scores is square (scores between all pairs of chunks in the
  // minibatch), so NumCols() also gives the number of rows.
  int32 num_rows = raw_scores.NumCols();
  // The accuracy uses the EER threshold, which is calculated
  // on the first minibatch.
  if (need_eer_threshold_) {
    // First minibatch: no threshold yet.  Partition the scores into
    // target and nontarget lists; ComputeEer() sets eer_threshold_ as a
    // side effect, and at the EER operating point accuracy == 1 - EER.
    std::vector<BaseFloat> target_scores;
    std::vector<BaseFloat> nontarget_scores;
    for (int32 i = 0; i < num_rows; i++) {
      for (int32 j = 0; j < num_rows; j++) {
        if (i + 1 == j && i % 2 == 0) {
          target_scores.push_back(raw_scores(i, j));
        } else if (i < j) {
          nontarget_scores.push_back(raw_scores(i, j));
        }
      }
    }
    (*tot_accuracy_out) = 1.0 - ComputeEer(&target_scores, &nontarget_scores);
    (*tot_weight_out) = target_scores.size() + nontarget_scores.size();
    need_eer_threshold_ = false;  // subsequent minibatches reuse the threshold.
  } else {
    // Later minibatches: classify each pair against the stored threshold.
    // A target pair is an error if it scores below eer_threshold_; a
    // nontarget pair is an error if it scores at or above it.
    int32 count = 0,
        error = 0;
    for (int32 i = 0; i < num_rows; i++) {
      for (int32 j = 0; j < num_rows; j++) {
        if (i + 1 == j && i % 2 == 0) {
          if (raw_scores(i, j) < eer_threshold_)
            error++;
          count++;
        } else if (i < j) {
          if (raw_scores(i, j) >= eer_threshold_)
            error++;
          count++;
        }
      }
    }
    (*tot_accuracy_out) = 1.0 - static_cast<BaseFloat>(error) / count;
    (*tot_weight_out) = count;
  }
}

const SimpleObjectiveInfo* NnetXvectorComputeProb::GetObjective(
const std::string &output_name) const {
unordered_map<std::string, SimpleObjectiveInfo, StringHasher>::const_iterator
Expand All @@ -159,5 +238,28 @@ const SimpleObjectiveInfo* NnetXvectorComputeProb::GetObjective(
return NULL;
}

// Computes the Equal Error Rate (EER): the operating point at which the
// false-reject rate on target_scores equals the false-accept rate on
// nontarget_scores.  As a side effect, stores the target score at that
// operating point in eer_threshold_ (used later by ComputeAccuracy()).
// Both input vectors are sorted in place (ascending).
// Returns the EER as a fraction in [0, 1).
BaseFloat NnetXvectorComputeProb::ComputeEer(
    std::vector<BaseFloat> *target_scores,
    std::vector<BaseFloat> *nontarget_scores) {
  KALDI_ASSERT(!target_scores->empty() && !nontarget_scores->empty());
  std::sort(target_scores->begin(), target_scores->end());
  std::sort(nontarget_scores->begin(), nontarget_scores->end());
  // Sweep the candidate threshold up through the sorted target scores.
  // target_position / target_size is the false-reject rate at each step;
  // the matching nontarget rank (counted down from the top of the sorted
  // nontarget list) gives the same false-accept rate.  Stop at the
  // crossover, where the nontarget score at that rank drops below the
  // current target score.
  int32 target_position = 0,
      target_size = target_scores->size();
  for (; target_position + 1 < target_size; target_position++) {
    int32 nontarget_size = nontarget_scores->size(),
        nontarget_n = nontarget_size * target_position * 1.0 / target_size,
        nontarget_position = nontarget_size - 1 - nontarget_n;
    if (nontarget_position < 0)  // guard against underflowing the list.
      nontarget_position = 0;
    if ((*nontarget_scores)[nontarget_position] <
        (*target_scores)[target_position])
      break;
  }
  // The threshold is the target score at the crossover point.
  eer_threshold_ = (*target_scores)[target_position];
  BaseFloat eer = target_position * 1.0 / target_size;
  return eer;
}

} // namespace nnet3
} // namespace kaldi
13 changes: 10 additions & 3 deletions src/xvector/nnet-xvector-diagnostics.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ class NnetXvectorComputeProb {
~NnetXvectorComputeProb();
private:
void ProcessOutputs(NnetComputer *computer);
// Returns the Equal Error Rate (EER) and sets the threshold.
BaseFloat ComputeEer(std::vector<BaseFloat> *target_scores,
std::vector<BaseFloat> *nontarget_scores);
// Computes the accuracy for this minibatch.
void ComputeAccuracy(const CuMatrixBase<BaseFloat> &raw_scores,
BaseFloat *tot_weight_out,
BaseFloat *tot_accuracy_out);

NnetComputeProbOptions config_;
const Nnet &nnet_;
Expand All @@ -80,12 +87,12 @@ class NnetXvectorComputeProb {

// this is only for diagnostics.
int32 num_minibatches_processed_;

bool need_eer_threshold_;
BaseFloat eer_threshold_;
unordered_map<std::string, SimpleObjectiveInfo, StringHasher> objf_info_;

unordered_map<std::string, SimpleObjectiveInfo, StringHasher> acc_info_;
};


} // namespace nnet3
} // namespace kaldi

Expand Down
Loading