diff --git a/egs/ami/s5b/local/prepare_parallel_train_data.sh b/egs/ami/s5b/local/prepare_parallel_train_data.sh index b049c906c3b..b551bacfb92 100755 --- a/egs/ami/s5b/local/prepare_parallel_train_data.sh +++ b/egs/ami/s5b/local/prepare_parallel_train_data.sh @@ -5,6 +5,10 @@ # but the wav data is copied from data/ihm. This is a little tricky because the # utterance ids are different between the different mics +train_set=train + +. utils/parse_options.sh + if [ $# != 1 ]; then echo "Usage: $0 [sdm1|mdm8]" @@ -18,12 +22,10 @@ if [ $mic == "ihm" ]; then exit 1; fi -train_set=train - . cmd.sh . ./path.sh -for f in data/ihm/train/utt2spk data/$mic/train/utt2spk; do +for f in data/ihm/${train_set}/utt2spk data/$mic/${train_set}/utt2spk; do if [ ! -f $f ]; then echo "$0: expected file $f to exist" exit 1 @@ -32,12 +34,12 @@ done set -e -o pipefail -mkdir -p data/$mic/train_ihmdata +mkdir -p data/$mic/${train_set}_ihmdata # the utterance-ids and speaker ids will be from the SDM or MDM data -cp data/$mic/train/{spk2utt,text,utt2spk} data/$mic/train_ihmdata/ +cp data/$mic/${train_set}/{spk2utt,text,utt2spk} data/$mic/${train_set}_ihmdata/ # the recording-ids will be from the IHM data. -cp data/ihm/train/{wav.scp,reco2file_and_channel} data/$mic/train_ihmdata/ +cp data/ihm/${train_set}/{wav.scp,reco2file_and_channel} data/$mic/${train_set}_ihmdata/ # map sdm/mdm segments to the ihm segments @@ -47,19 +49,17 @@ mic_base_upcase=$(echo $mic | sed 's/[0-9]//g' | tr 'a-z' 'A-Z') # It has lines like: # AMI_EN2001a_H02_FEO065_0021133_0021442 AMI_EN2001a_SDM_FEO065_0021133_0021442 -tmpdir=data/$mic/train_ihmdata/ +tmpdir=data/$mic/${train_set}_ihmdata/ -awk '{print $1, $1}' $tmpdir/ihmutt2utt # Map the 1st field of the segments file from the ihm data (the 1st field being # the utterance-id) to the corresponding SDM or MDM utterance-id. The other # fields remain the same (e.g. we want the recording-ids from the IHM data). -utils/apply_map.pl -f 1 $tmpdir/ihmutt2utt data/$mic/train_ihmdata/segments - -utils/fix_data_dir.sh data/$mic/train_ihmdata +utils/apply_map.pl -f 1 $tmpdir/ihmutt2utt data/$mic/${train_set}_ihmdata/segments -rm $tmpdir/ihmutt2utt +utils/fix_data_dir.sh data/$mic/${train_set}_ihmdata exit 0; diff --git a/egs/aspire/s5/conf/mfcc_hires_bp.conf b/egs/aspire/s5/conf/mfcc_hires_bp.conf new file mode 100644 index 00000000000..64292e8b489 --- /dev/null +++ b/egs/aspire/s5/conf/mfcc_hires_bp.conf @@ -0,0 +1,13 @@ +# config for high-resolution MFCC features, intended for neural network training. +# Note: we keep all cepstra, so it has the same info as filterbank features, +# but MFCC is more easily compressible (because less correlated) which is why +# we prefer this method. +--use-energy=false # use average of log energy, not energy. +--sample-frequency=8000 # Switchboard is sampled at 8kHz +--num-mel-bins=28 +--num-ceps=28 +--cepstral-lifter=0 +--low-freq=330 # low cutoff frequency for mel bins +--high-freq=-1000 # high cutoff frequently, relative to Nyquist of 4000 (=3000) + + diff --git a/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh b/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh new file mode 100755 index 00000000000..1bfa08370e7 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir.sh @@ -0,0 +1,136 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +set -e +set -u +set -o pipefail + +. path.sh + +stage=0 +corruption_stage=-10 +corrupt_only=false + +# Data options +data_dir=data/train_si284 # Expecting whole data directory. 
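+                          # Here "whole" means a recording-level directory (one
+                          # utterance spanning each recording), e.g. as produced by
+                          # utils/data/convert_data_dir_to_whole.sh in
+                          # local/segmentation/prepare_unsad_data.sh; the path above
+                          # is only a default and is normally overridden with
+                          # --data-dir (see local/segmentation/prepare_fisher_data.sh).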
+speed_perturb=true
+num_data_reps=5  # Number of corrupted versions
+snrs="20:10:15:5:0:-5"
+foreground_snrs="20:10:15:5:0:-5"
+background_snrs="20:10:15:5:0:-5"
+base_rirs=simulated
+
+# Parallel options
+reco_nj=40
+cmd=queue.pl
+
+# Options for feature extraction
+mfcc_config=conf/mfcc_hires_bp.conf
+feat_suffix=hires_bp
+
+reco_vad_dir=   # Output of prepare_unsad_data.sh.
+                # If provided, the speech labels and deriv weights will be
+                # copied into the output data directory.
+
+. utils/parse_options.sh
+
+if [ $# -ne 0 ]; then
+  echo "Usage: $0"
+  exit 1
+fi
+
+data_id=`basename ${data_dir}`
+
+rvb_opts=()
+if [ "$base_rirs" == "simulated" ]; then
+  # This is the config for the system using simulated RIRs and point-source noises
+  rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/smallroom/rir_list")
+  rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/mediumroom/rir_list")
+  rvb_opts+=(--noise-set-parameters RIRS_NOISES/pointsource_noises/noise_list)
+else
+  # This is the config for the JHU ASpIRE submission system
+  rvb_opts+=(--rir-set-parameters "1.0, RIRS_NOISES/real_rirs_isotropic_noises/rir_list")
+  rvb_opts+=(--noise-set-parameters RIRS_NOISES/real_rirs_isotropic_noises/noise_list)
+fi
+
+corrupted_data_id=${data_id}_corrupted
+
+if [ $stage -le 1 ]; then
+  python steps/data/reverberate_data_dir.py \
+    "${rvb_opts[@]}" \
+    --prefix="rev" \
+    --foreground-snrs=$foreground_snrs \
+    --background-snrs=$background_snrs \
+    --speech-rvb-probability=1 \
+    --pointsource-noise-addition-probability=1 \
+    --isotropic-noise-addition-probability=1 \
+    --num-replications=$num_data_reps \
+    --max-noises-per-minute=1 \
+    data/${data_id} data/${corrupted_data_id}
+fi
+
+corrupted_data_dir=data/${corrupted_data_id}
+
+if $speed_perturb; then
+  if [ $stage -le 2 ]; then
+    ## Assuming whole data directories
+    for x in $corrupted_data_dir; do
+      cp $x/reco2dur $x/utt2dur
+      utils/data/perturb_data_dir_speed_3way.sh $x ${x}_sp
+    done
+  fi
+
+  corrupted_data_dir=${corrupted_data_dir}_sp
+  corrupted_data_id=${corrupted_data_id}_sp
+
+  if [ $stage -le 3 ]; then
+    utils/data/perturb_data_dir_volume.sh --scale-low 0.03125 --scale-high 2 \
+      ${corrupted_data_dir}
+  fi
+fi
+
+if $corrupt_only; then
+  echo "$0: Got corrupted data directory in ${corrupted_data_dir}"
+  exit 0
+fi
+
+mfccdir=`basename $mfcc_config`
+mfccdir=${mfccdir%%.conf}
+
+if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then
+  utils/create_split_dir.pl \
+    /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage
+fi
+
+if [ $stage -le 4 ]; then
+  utils/copy_data_dir.sh $corrupted_data_dir ${corrupted_data_dir}_$feat_suffix
+  corrupted_data_dir=${corrupted_data_dir}_$feat_suffix
+  steps/make_mfcc.sh --mfcc-config $mfcc_config \
+    --cmd "$cmd" --nj $reco_nj \
+    $corrupted_data_dir exp/make_${feat_suffix}/${corrupted_data_id} $mfccdir
+  steps/compute_cmvn_stats.sh --fake \
+    $corrupted_data_dir exp/make_${feat_suffix}/${corrupted_data_id} $mfccdir
+else
+  corrupted_data_dir=${corrupted_data_dir}_$feat_suffix
+fi
+
+if [ $stage -le 8 ]; then
+  if [ ! -z "$reco_vad_dir" ]; then
+    if [ !
-f $reco_vad_dir/speech_feat.scp ]; then + echo "$0: Could not find file $reco_vad_dir/speech_feat.scp" + exit 1 + fi + + cat $reco_vad_dir/speech_feat.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps | \ + sort -k1,1 > ${corrupted_data_dir}/speech_feat.scp + + cat $reco_vad_dir/deriv_weights.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps | \ + sort -k1,1 > ${corrupted_data_dir}/deriv_weights.scp + fi +fi + +exit 0 diff --git a/egs/aspire/s5/local/segmentation/do_corruption_data_dir_music.sh b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_music.sh new file mode 100755 index 00000000000..214cba347da --- /dev/null +++ b/egs/aspire/s5/local/segmentation/do_corruption_data_dir_music.sh @@ -0,0 +1,203 @@ +#!/bin/bash +set -e +set -u +set -o pipefail + +. path.sh +. cmd.sh + +num_data_reps=5 +data_dir=data/train_si284 + +nj=40 +reco_nj=40 + +stage=0 +corruption_stage=-10 + +pad_silence=false + +mfcc_config=conf/mfcc_hires_bp_vh.conf +feat_suffix=hires_bp_vh +mfcc_irm_config=conf/mfcc_hires_bp.conf + +dry_run=false +corrupt_only=false +speed_perturb=true + +reco_vad_dir= + +max_jobs_run=20 + +foreground_snrs="5:2:1:0:-2:-5:-10:-20" +background_snrs="5:2:1:0:-2:-5:-10:-20" + +. utils/parse_options.sh + +if [ $# -ne 0 ]; then + echo "Usage: $0" + exit 1 +fi + +data_id=`basename ${data_dir}` + +rvb_opts=() +# This is the config for the system using simulated RIRs and point-source noises +rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/smallroom/rir_list") +rvb_opts+=(--rir-set-parameters "0.5, RIRS_NOISES/simulated_rirs/mediumroom/rir_list") +rvb_opts+=(--noise-set-parameters RIRS_NOISES/music/music_list) + +music_utt2num_frames=RIRS_NOISES/music/split_utt2num_frames + +corrupted_data_id=${data_id}_music_corrupted +orig_corrupted_data_id=$corrupted_data_id + +if [ $stage -le 1 ]; then + python steps/data/reverberate_data_dir.py \ + "${rvb_opts[@]}" \ + --prefix="music" \ + --foreground-snrs=$foreground_snrs \ + --background-snrs=$background_snrs \ + --speech-rvb-probability=1 \ + --pointsource-noise-addition-probability=1 \ + --isotropic-noise-addition-probability=1 \ + --num-replications=$num_data_reps \ + --max-noises-per-minute=5 \ + data/${data_id} data/${corrupted_data_id} +fi + +if $dry_run; then + exit 0 +fi + +corrupted_data_dir=data/${corrupted_data_id} +orig_corrupted_data_dir=$corrupted_data_dir + +if $speed_perturb; then + if [ $stage -le 2 ]; then + ## Assuming whole data directories + for x in $corrupted_data_dir; do + cp $x/reco2dur $x/utt2dur + utils/data/perturb_data_dir_speed_3way.sh $x ${x}_sp + done + fi + + corrupted_data_dir=${corrupted_data_dir}_sp + corrupted_data_id=${corrupted_data_id}_sp + + if [ $stage -le 3 ]; then + utils/data/perturb_data_dir_volume.sh --scale-low 0.03125 --scale-high 2 \ + ${corrupted_data_dir} + fi +fi + +if $corrupt_only; then + echo "$0: Got corrupted data directory in ${corrupted_data_dir}" + exit 0 +fi + +mfccdir=`basename $mfcc_config` +mfccdir=${mfccdir%%.conf} + +if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage +fi + +if [ $stage -le 4 ]; then + if [ ! 
-z $feat_suffix ]; then + utils/copy_data_dir.sh $corrupted_data_dir ${corrupted_data_dir}_$feat_suffix + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix + fi + steps/make_mfcc.sh --mfcc-config $mfcc_config \ + --cmd "$train_cmd" --nj $reco_nj \ + $corrupted_data_dir exp/make_${mfccdir}/${corrupted_data_id} $mfccdir + steps/compute_cmvn_stats.sh --fake \ + $corrupted_data_dir exp/make_${mfccdir}/${corrupted_data_id} $mfccdir +else + if [ ! -z $feat_suffix ]; then + corrupted_data_dir=${corrupted_data_dir}_$feat_suffix + fi +fi + +if [ $stage -le 8 ]; then + if [ ! -z "$reco_vad_dir" ]; then + if [ ! -f $reco_vad_dir/speech_feat.scp ]; then + echo "$0: Could not find file $reco_vad_dir/speech_feat.scp" + exit 1 + fi + + cat $reco_vad_dir/speech_feat.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps "music" | \ + sort -k1,1 > ${corrupted_data_dir}/speech_feat.scp + + cat $reco_vad_dir/deriv_weights.scp | \ + steps/segmentation/get_reverb_scp.pl -f 1 $num_data_reps "music" | \ + sort -k1,1 > ${corrupted_data_dir}/deriv_weights.scp + fi +fi + +# music_dir is without speed perturbation +music_dir=exp/make_music_labels/${orig_corrupted_data_id} +music_data_dir=$music_dir/music_data + +mkdir -p $music_data_dir + +if [ $stage -le 10 ]; then + utils/data/get_utt2num_frames.sh $corrupted_data_dir + utils/split_data.sh --per-reco ${orig_corrupted_data_dir} $reco_nj + + cp $orig_corrupted_data_dir/wav.scp $music_data_dir + + # Combine the VAD from the base recording and the VAD from the overlapping segments + # to create per-frame labels of the number of overlapping speech segments + # Unreliable segments are regions where no VAD labels were available for the + # overlapping segments. These can be later removed by setting deriv weights to 0. 
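+# For the music case, the pipeline below roughly does the following:
+# segmentation-init-from-additive-signals-info lays each added music clip
+# (listed in additive_signals_info.txt, with clip lengths taken from
+# $music_utt2num_frames) onto the timeline of the corrupted recording;
+# adjacent regions are then merged and converted into a segments file and
+# utt2spk in $music_data_dir, i.e. per-recording music segments.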
+ $train_cmd JOB=1:$reco_nj $music_dir/log/get_music_seg.JOB.log \ + segmentation-init-from-additive-signals-info --lengths-rspecifier=ark,t:$corrupted_data_dir/utt2num_frames \ + --additive-signals-segmentation-rspecifier="ark:segmentation-init-from-lengths ark:$music_utt2num_frames ark:- |" \ + "ark:utils/filter_scp.pl ${orig_corrupted_data_dir}/split${reco_nj}reco/JOB/utt2spk $corrupted_data_dir/utt2num_frames | segmentation-init-from-lengths --label=1 ark:- ark:- | segmentation-post-process --remove-labels=1 ark:- ark:- |" \ + ark,t:$orig_corrupted_data_dir/additive_signals_info.txt \ + ark:- \| \ + segmentation-post-process --merge-adjacent-segments ark:- \ + ark:- \| \ + segmentation-to-segments ark:- ark:$music_data_dir/utt2spk.JOB \ + $music_data_dir/segments.JOB + + for n in `seq $reco_nj`; do cat $music_data_dir/utt2spk.$n; done > $music_data_dir/utt2spk + for n in `seq $reco_nj`; do cat $music_data_dir/segments.$n; done > $music_data_dir/segments + + utils/fix_data_dir.sh $music_data_dir + + if $speed_perturb; then + utils/data/perturb_data_dir_speed_3way.sh $music_data_dir ${music_data_dir}_sp + fi +fi + +if $speed_perturb; then + music_data_dir=${music_data_dir}_sp +fi + +label_dir=music_labels + +mkdir -p $label_dir +label_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $label_dir ${PWD}` + +if [ $stage -le 11 ]; then + utils/split_data.sh --per-reco ${music_data_dir} $reco_nj + + $train_cmd JOB=1:$reco_nj $music_dir/log/get_music_labels.JOB.log \ + utils/data/get_reco2utt.sh ${music_data_dir}/split${reco_nj}reco/JOB '&&' \ + segmentation-init-from-segments --shift-to-zero=false \ + ${music_data_dir}/split${reco_nj}reco/JOB/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- ark,t:${music_data_dir}/split${reco_nj}reco/JOB/reco2utt \ + ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${corrupted_data_dir}/utt2num_frames ark:- \ + ark,scp:$label_dir/music_labels_${corrupted_data_id}.JOB.ark,$label_dir/music_labels_${corrupted_data_id}.JOB.scp +fi + +for n in `seq $reco_nj`; do + cat $label_dir/music_labels_${corrupted_data_id}.$n.scp +done > ${corrupted_data_dir}/music_labels.scp + +exit 0 diff --git a/egs/aspire/s5/local/segmentation/make_musan_music.py b/egs/aspire/s5/local/segmentation/make_musan_music.py new file mode 100755 index 00000000000..5d13078de63 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/make_musan_music.py @@ -0,0 +1,69 @@ +#! 
/usr/bin/env python + +from __future__ import print_function +import argparse +import os + + +def _get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--use-vocals", type=str, default="false", + choices=["true", "false"], + help="If true, also add music with vocals in the " + "output music-set-parameters") + parser.add_argument("root_dir", type=str, + help="Root directory of MUSAN corpus") + parser.add_argument("music_list", type=argparse.FileType('w'), + help="Convert music list into noise-set-paramters " + "for steps/data/reverberate_data_dir.py") + + args = parser.parse_args() + + args.use_vocals = True if args.use_vocals == "true" else False + return args + + +def read_vocals(annotations): + vocals = {} + for line in open(annotations): + parts = line.strip().split() + if parts[2] == "Y": + vocals[parts[0]] = True + return vocals + + +def write_music(utt, file_path, music_list): + print ('{utt} {file_path}'.format( + utt=utt, file_path=file_path), file=music_list) + + +def prepare_music_set(root_dir, use_vocals, music_list): + vocals = {} + music_dir = os.path.join(root_dir, "music") + for root, dirs, files in os.walk(music_dir): + if os.path.exists(os.path.join(root, "ANNOTATIONS")): + vocals = read_vocals(os.path.join(root, "ANNOTATIONS")) + + for f in files: + file_path = os.path.join(root, f) + if f.endswith(".wav"): + utt = str(f).replace(".wav", "") + if not use_vocals and utt in vocals: + continue + write_music(utt, file_path, music_list) + music_list.close() + + +def main(): + args = _get_args() + + try: + prepare_music_set(args.root_dir, args.use_vocals, + args.music_list) + finally: + args.music_list.close() + + +if __name__ == '__main__': + main() diff --git a/egs/aspire/s5/local/segmentation/make_sad_tdnn_configs.py b/egs/aspire/s5/local/segmentation/make_sad_tdnn_configs.py new file mode 100755 index 00000000000..e859a3593ce --- /dev/null +++ b/egs/aspire/s5/local/segmentation/make_sad_tdnn_configs.py @@ -0,0 +1,616 @@ +#!/usr/bin/env python + +# we're using python 3.x style print but want it to work in python 2.x, +from __future__ import print_function +import os +import argparse +import shlex +import sys +import warnings +import copy +import imp +import ast + +nodes = imp.load_source('', 'steps/nnet3/components.py') +import libs.common as common_lib + +def GetArgs(): + # we add compulsary arguments as named arguments for readability + parser = argparse.ArgumentParser(description="Writes config files and variables " + "for TDNNs creation and training", + epilog="See steps/nnet3/tdnn/train.sh for example.") + + # Only one of these arguments can be specified, and one of them has to + # be compulsarily specified + feat_group = parser.add_mutually_exclusive_group(required = True) + feat_group.add_argument("--feat-dim", type=int, + help="Raw feature dimension, e.g. 13") + feat_group.add_argument("--feat-dir", type=str, + help="Feature directory, from which we derive the feat-dim") + + # only one of these arguments can be specified + ivector_group = parser.add_mutually_exclusive_group(required = False) + ivector_group.add_argument("--ivector-dim", type=int, + help="iVector dimension, e.g. 100", default=0) + ivector_group.add_argument("--ivector-dir", type=str, + help="iVector dir, which will be used to derive the ivector-dim ", default=None) + + num_target_group = parser.add_mutually_exclusive_group(required = True) + num_target_group.add_argument("--num-targets", type=int, + help="number of network targets (e.g. 
num-pdf-ids/num-leaves)") + num_target_group.add_argument("--ali-dir", type=str, + help="alignment directory, from which we derive the num-targets") + num_target_group.add_argument("--tree-dir", type=str, + help="directory with final.mdl, from which we derive the num-targets") + num_target_group.add_argument("--output-node-parameters", type=str, action='append', + dest='output_node_para_array', + help = "Define output nodes' and their parameters like output-suffix, dim, objective-type etc") + # CNN options + parser.add_argument('--cnn.layer', type=str, action='append', dest = "cnn_layer", + help="CNN parameters at each CNN layer, e.g. --filt-x-dim=3 --filt-y-dim=8 " + "--filt-x-step=1 --filt-y-step=1 --num-filters=256 --pool-x-size=1 --pool-y-size=3 " + "--pool-z-size=1 --pool-x-step=1 --pool-y-step=3 --pool-z-step=1, " + "when CNN layers are used, no LDA will be added", default = None) + parser.add_argument("--cnn.bottleneck-dim", type=int, dest = "cnn_bottleneck_dim", + help="Output dimension of the linear layer at the CNN output " + "for dimension reduction, e.g. 256." + "The default zero means this layer is not needed.", default=0) + + # General neural network options + parser.add_argument("--splice-indexes", type=str, required = True, + help="Splice indexes at each layer, e.g. '-3,-2,-1,0,1,2,3' " + "If CNN layers are used the first set of splice indexes will be used as input " + "to the first CNN layer and later splice indexes will be interpreted as indexes " + "for the TDNNs.") + parser.add_argument("--add-lda", type=str, action=common_lib.StrToBoolAction, + help="If \"true\" an LDA matrix computed from the input features " + "(spliced according to the first set of splice-indexes) will be used as " + "the first Affine layer. This affine layer's parameters are fixed during training. " + "This variable needs to be set to \"false\" when using dense-targets.\n" + "If --cnn.layer is specified this option will be forced to \"false\".", + default=True, choices = ["false", "true"]) + + parser.add_argument("--include-log-softmax", type=str, action=common_lib.StrToBoolAction, + help="add the final softmax layer ", default=True, choices = ["false", "true"]) + parser.add_argument("--add-final-sigmoid", type=str, action=common_lib.StrToBoolAction, + help="add a final sigmoid layer as alternate to log-softmax-layer. " + "Can only be used if include-log-softmax is false. " + "This is useful in cases where you want the output to be " + "like probabilities between 0 and 1. Typically the nnet " + "is trained with an objective such as quadratic", + default=False, choices = ["false", "true"]) + + parser.add_argument("--objective-type", type=str, + help = "the type of objective; i.e. 
quadratic or linear", + default="linear", choices = ["linear", "quadratic"]) + parser.add_argument("--xent-regularize", type=float, + help="For chain models, if nonzero, add a separate output for cross-entropy " + "regularization (with learning-rate-factor equal to the inverse of this)", + default=0.0) + parser.add_argument("--final-layer-normalize-target", type=float, + help="RMS target for final layer (set to <1 if final layer learns too fast", + default=1.0) + parser.add_argument("--subset-dim", type=int, default=0, + help="dimension of the subset of units to be sent to the central frame") + parser.add_argument("--pnorm-input-dim", type=int, + help="input dimension to p-norm nonlinearities") + parser.add_argument("--pnorm-output-dim", type=int, + help="output dimension of p-norm nonlinearities") + relu_dim_group = parser.add_mutually_exclusive_group(required = False) + relu_dim_group.add_argument("--relu-dim", type=int, + help="dimension of all ReLU nonlinearity layers") + relu_dim_group.add_argument("--relu-dim-final", type=int, + help="dimension of the last ReLU nonlinearity layer. Dimensions increase geometrically from the first through the last ReLU layer.", default=None) + parser.add_argument("--relu-dim-init", type=int, + help="dimension of the first ReLU nonlinearity layer. Dimensions increase geometrically from the first through the last ReLU layer.", default=None) + + parser.add_argument("--self-repair-scale-nonlinearity", type=float, + help="A non-zero value activates the self-repair mechanism in the sigmoid and tanh non-linearities of the LSTM", default=None) + + + parser.add_argument("--use-presoftmax-prior-scale", type=str, action=common_lib.StrToBoolAction, + help="if true, a presoftmax-prior-scale is added", + choices=['true', 'false'], default = True) + + # Options to convert input MFCC into Fbank features. This is useful when a + # LDA layer is not added (such as when using dense targets) + parser.add_argument("--cnn.cepstral-lifter", type=float, dest = "cepstral_lifter", + help="The factor used for determining the liftering vector in the production of MFCC. " + "User has to ensure that it matches the lifter used in MFCC generation, " + "e.g. 22.0", default=22.0) + + parser.add_argument("config_dir", + help="Directory to write config files and variables") + + print(' '.join(sys.argv)) + + args = parser.parse_args() + args = CheckArgs(args) + + return args + +def CheckArgs(args): + if not os.path.exists(args.config_dir): + os.makedirs(args.config_dir) + + ## Check arguments. 
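+    # (feat_dim / ivector_dim are derived from --feat-dir / --ivector-dir when not
+    # given directly; if no --output-node-parameters were supplied, a single default
+    # output node is constructed below from --num-targets (or from --ali-dir /
+    # --tree-dir) together with --objective-type, --include-log-softmax,
+    # --add-final-sigmoid and --xent-regularize.)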
+ if args.feat_dir is not None: + args.feat_dim = common_lib.get_feat_dim(args.feat_dir) + + if args.ivector_dir is not None: + args.ivector_dim = common_lib.get_ivector_dim(args.ivector_dir) + + if not args.feat_dim > 0: + raise Exception("feat-dim has to be postive") + + if len(args.output_node_para_array) == 0: + if args.ali_dir is not None: + args.num_targets = common_lib.get_number_of_leaves_from_tree(args.ali_dir) + elif args.tree_dir is not None: + args.num_targets = common_lib.get_number_of_leaves_from_tree(args.tree_dir) + if not args.num_targets > 0: + print(args.num_targets) + raise Exception("num_targets has to be positive") + args.output_node_para_array.append( + "--dim={0} --objective-type={1} --include-log-softmax={2} --add-final-sigmoid={3} --xent-regularize={4}".format( + args.num_targets, args.objective_type, + "true" if args.include_log_softmax else "false", + "true" if args.add_final_sigmoid else "false", + args.xent_regularize)) + + if not args.ivector_dim >= 0: + raise Exception("ivector-dim has to be non-negative") + + if (args.subset_dim < 0): + raise Exception("--subset-dim has to be non-negative") + + if not args.relu_dim is None: + if not args.pnorm_input_dim is None or not args.pnorm_output_dim is None or not args.relu_dim_init is None: + raise Exception("--relu-dim argument not compatible with " + "--pnorm-input-dim or --pnorm-output-dim or --relu-dim-init options"); + args.nonlin_input_dim = args.relu_dim + args.nonlin_output_dim = args.relu_dim + args.nonlin_output_dim_final = None + args.nonlin_output_dim_init = None + args.nonlin_type = 'relu' + + elif not args.relu_dim_final is None: + if not args.pnorm_input_dim is None or not args.pnorm_output_dim is None: + raise Exception("--relu-dim-final argument not compatible with " + "--pnorm-input-dim or --pnorm-output-dim options") + if args.relu_dim_init is None: + raise Exception("--relu-dim-init argument should also be provided with --relu-dim-final") + if args.relu_dim_init > args.relu_dim_final: + raise Exception("--relu-dim-init has to be no larger than --relu-dim-final") + args.nonlin_input_dim = None + args.nonlin_output_dim = None + args.nonlin_output_dim_final = args.relu_dim_final + args.nonlin_output_dim_init = args.relu_dim_init + args.nonlin_type = 'relu' + + else: + if not args.relu_dim_init is None: + raise Exception("--relu-dim-final argument not compatible with " + "--pnorm-input-dim or --pnorm-output-dim options") + if not args.pnorm_input_dim > 0 or not args.pnorm_output_dim > 0: + raise Exception("--relu-dim not set, so expected --pnorm-input-dim and " + "--pnorm-output-dim to be provided."); + args.nonlin_input_dim = args.pnorm_input_dim + args.nonlin_output_dim = args.pnorm_output_dim + if (args.nonlin_input_dim < args.nonlin_output_dim) or (args.nonlin_input_dim % args.nonlin_output_dim != 0): + raise Exception("Invalid --pnorm-input-dim {0} and --pnorm-output-dim {1}".format(args.nonlin_input_dim, args.nonlin_output_dim)) + args.nonlin_output_dim_final = None + args.nonlin_output_dim_init = None + args.nonlin_type = 'pnorm' + + if args.add_lda and args.cnn_layer is not None: + args.add_lda = False + warnings.warn("--add-lda is set to false as CNN layers are used.") + + return args + +def AddConvMaxpLayer(config_lines, name, input, args): + if '3d-dim' not in input: + raise Exception("The input to AddConvMaxpLayer() needs '3d-dim' parameters.") + + input = nodes.AddConvolutionLayer(config_lines, name, input, + input['3d-dim'][0], input['3d-dim'][1], input['3d-dim'][2], + args.filt_x_dim, 
args.filt_y_dim, + args.filt_x_step, args.filt_y_step, + args.num_filters, input['vectorization']) + + if args.pool_x_size > 1 or args.pool_y_size > 1 or args.pool_z_size > 1: + input = nodes.AddMaxpoolingLayer(config_lines, name, input, + input['3d-dim'][0], input['3d-dim'][1], input['3d-dim'][2], + args.pool_x_size, args.pool_y_size, args.pool_z_size, + args.pool_x_step, args.pool_y_step, args.pool_z_step) + + return input + +# The ivectors are processed through an affine layer parallel to the CNN layers, +# then concatenated with the CNN output and passed to the deeper part of the network. +def AddCnnLayers(config_lines, cnn_layer, cnn_bottleneck_dim, cepstral_lifter, config_dir, feat_dim, splice_indexes=[0], ivector_dim=0): + cnn_args = ParseCnnString(cnn_layer) + num_cnn_layers = len(cnn_args) + # We use an Idct layer here to convert MFCC to FBANK features + common_lib.write_idct_matrix(feat_dim, cepstral_lifter, config_dir.strip() + "/idct.mat") + prev_layer_output = {'descriptor': "input", + 'dimension': feat_dim} + prev_layer_output = nodes.AddFixedAffineLayer(config_lines, "Idct", prev_layer_output, config_dir.strip() + '/idct.mat') + + list = [('Offset({0}, {1})'.format(prev_layer_output['descriptor'],n) if n != 0 else prev_layer_output['descriptor']) for n in splice_indexes] + splice_descriptor = "Append({0})".format(", ".join(list)) + cnn_input_dim = len(splice_indexes) * feat_dim + prev_layer_output = {'descriptor': splice_descriptor, + 'dimension': cnn_input_dim, + '3d-dim': [len(splice_indexes), feat_dim, 1], + 'vectorization': 'yzx'} + + for cl in range(0, num_cnn_layers): + prev_layer_output = AddConvMaxpLayer(config_lines, "L{0}".format(cl), prev_layer_output, cnn_args[cl]) + + if cnn_bottleneck_dim > 0: + prev_layer_output = nodes.AddAffineLayer(config_lines, "cnn-bottleneck", prev_layer_output, cnn_bottleneck_dim, "") + + if ivector_dim > 0: + iv_layer_output = {'descriptor': 'ReplaceIndex(ivector, t, 0)', + 'dimension': ivector_dim} + iv_layer_output = nodes.AddAffineLayer(config_lines, "ivector", iv_layer_output, ivector_dim, "") + prev_layer_output['descriptor'] = 'Append({0}, {1})'.format(prev_layer_output['descriptor'], iv_layer_output['descriptor']) + prev_layer_output['dimension'] = prev_layer_output['dimension'] + iv_layer_output['dimension'] + + return prev_layer_output + +def PrintConfig(file_name, config_lines): + f = open(file_name, 'w') + f.write("\n".join(config_lines['components'])+"\n") + f.write("\n#Component nodes\n") + f.write("\n".join(config_lines['component-nodes'])+"\n") + f.close() + +def ParseCnnString(cnn_param_string_list): + cnn_parser = argparse.ArgumentParser(description="cnn argument parser") + + cnn_parser.add_argument("--filt-x-dim", required=True, type=int) + cnn_parser.add_argument("--filt-y-dim", required=True, type=int) + cnn_parser.add_argument("--filt-x-step", type=int, default = 1) + cnn_parser.add_argument("--filt-y-step", type=int, default = 1) + cnn_parser.add_argument("--num-filters", required=True, type=int) + cnn_parser.add_argument("--pool-x-size", type=int, default = 1) + cnn_parser.add_argument("--pool-y-size", type=int, default = 1) + cnn_parser.add_argument("--pool-z-size", type=int, default = 1) + cnn_parser.add_argument("--pool-x-step", type=int, default = 1) + cnn_parser.add_argument("--pool-y-step", type=int, default = 1) + cnn_parser.add_argument("--pool-z-step", type=int, default = 1) + + cnn_args = [] + for cl in range(0, len(cnn_param_string_list)): + 
cnn_args.append(cnn_parser.parse_args(shlex.split(cnn_param_string_list[cl]))) + + return cnn_args + +def ParseSpliceString(splice_indexes): + splice_array = [] + left_context = 0 + right_context = 0 + split_on_spaces = splice_indexes.split(); # we already checked the string is nonempty. + if len(split_on_spaces) < 1: + raise Exception("invalid splice-indexes argument, too short: " + + splice_indexes) + try: + for string in split_on_spaces: + this_splices = string.split(",") + if len(this_splices) < 1: + raise Exception("invalid splice-indexes argument, too-short element: " + + splice_indexes) + # the rest of this block updates left_context and right_context, and + # does some checking. + leftmost_splice = 10000 + rightmost_splice = -10000 + + int_list = [] + for s in this_splices: + try: + n = int(s) + if n < leftmost_splice: + leftmost_splice = n + if n > rightmost_splice: + rightmost_splice = n + int_list.append(n) + except ValueError: + #if len(splice_array) == 0: + # raise Exception("First dimension of splicing array must not have averaging [yet]") + try: + x = nodes.StatisticsConfig(s, { 'dimension':100, + 'descriptor': 'foo'} ) + int_list.append(s) + except Exception as e: + raise Exception("The following element of the splicing array is not a valid specifier " + "of statistics: {0}\nGot {1}".format(s, str(e))) + splice_array.append(int_list) + + if leftmost_splice == 10000 or rightmost_splice == -10000: + raise Exception("invalid element of --splice-indexes: " + string) + left_context += -leftmost_splice + right_context += rightmost_splice + except ValueError as e: + raise Exception("invalid --splice-indexes argument " + args.splice_indexes + " " + str(e)) + + left_context = max(0, left_context) + right_context = max(0, right_context) + + return {'left_context':left_context, + 'right_context':right_context, + 'splice_indexes':splice_array, + 'num_hidden_layers':len(splice_array) + } + +def AddPriorsAccumulator(config_lines, name, input): + components = config_lines['components'] + component_nodes = config_lines['component-nodes'] + + components.append("component name={0}_softmax type=SoftmaxComponent dim={1}".format(name, input['dimension'])) + component_nodes.append("component-node name={0}_softmax component={0}_softmax input={1}".format(name, input['descriptor'])) + + return {'descriptor': '{0}_softmax'.format(name), + 'dimension': input['dimension']} + +def AddFinalLayer(config_lines, input, output_dim, + ng_affine_options = " param-stddev=0 bias-stddev=0 ", + label_delay=None, + use_presoftmax_prior_scale = False, + prior_scale_file = None, + include_log_softmax = True, + add_final_sigmoid = False, + name_affix = None, + objective_type = "linear", + objective_scale = 1.0, + objective_scales_vec = None): + components = config_lines['components'] + component_nodes = config_lines['component-nodes'] + + if name_affix is not None: + final_node_prefix = 'Final-' + str(name_affix) + else: + final_node_prefix = 'Final' + + prev_layer_output = nodes.AddAffineLayer(config_lines, + final_node_prefix , input, output_dim, + ng_affine_options) + if include_log_softmax: + if use_presoftmax_prior_scale : + components.append('component name={0}-fixed-scale type=FixedScaleComponent scales={1}'.format(final_node_prefix, prior_scale_file)) + component_nodes.append('component-node name={0}-fixed-scale component={0}-fixed-scale input={1}'.format(final_node_prefix, + prev_layer_output['descriptor'])) + prev_layer_output['descriptor'] = "{0}-fixed-scale".format(final_node_prefix) + prev_layer_output = 
nodes.AddSoftmaxLayer(config_lines, final_node_prefix, prev_layer_output) + + elif add_final_sigmoid: + # Useful when you need the final outputs to be probabilities + # between 0 and 1. + # Usually used with an objective-type such as "quadratic" + prev_layer_output = nodes.AddSigmoidLayer(config_lines, final_node_prefix, prev_layer_output) + + # we use the same name_affix as a prefix in for affine/scale nodes but as a + # suffix for output node + if (objective_scale != 1.0 or objective_scales_vec is not None): + prev_layer_output = nodes.AddGradientScaleLayer(config_lines, final_node_prefix, prev_layer_output, objective_scale, objective_scales_vec) + + nodes.AddOutputLayer(config_lines, prev_layer_output, label_delay, suffix = name_affix, objective_type = objective_type) + +def AddOutputLayers(config_lines, prev_layer_output, output_nodes, + ng_affine_options = "", label_delay = 0): + + for o in output_nodes: + # make the intermediate config file for layerwise discriminative + # training + AddFinalLayer(config_lines, prev_layer_output, o.dim, + ng_affine_options, label_delay = label_delay, + include_log_softmax = o.include_log_softmax, + add_final_sigmoid = o.add_final_sigmoid, + objective_type = o.objective_type, + name_affix = o.output_suffix) + + if o.xent_regularize != 0.0: + nodes.AddFinalLayer(config_lines, prev_layer_output, o.dim, + include_log_softmax = True, + label_delay = label_delay, + name_affix = o.output_suffix + '_xent') + +# The function signature of MakeConfigs is changed frequently as it is intended for local use in this script. +def MakeConfigs(config_dir, splice_indexes_string, + cnn_layer, cnn_bottleneck_dim, cepstral_lifter, + feat_dim, ivector_dim, add_lda, + nonlin_type, nonlin_input_dim, nonlin_output_dim, subset_dim, + nonlin_output_dim_init, nonlin_output_dim_final, + use_presoftmax_prior_scale, final_layer_normalize_target, + output_nodes, self_repair_scale): + + parsed_splice_output = ParseSpliceString(splice_indexes_string.strip()) + + left_context = parsed_splice_output['left_context'] + right_context = parsed_splice_output['right_context'] + num_hidden_layers = parsed_splice_output['num_hidden_layers'] + splice_indexes = parsed_splice_output['splice_indexes'] + input_dim = len(parsed_splice_output['splice_indexes'][0]) + feat_dim + ivector_dim + + prior_scale_file = '{0}/presoftmax_prior_scale.vec'.format(config_dir) + + config_lines = {'components':[], 'component-nodes':[]} + + config_files={} + prev_layer_output = nodes.AddInputLayer(config_lines, feat_dim, splice_indexes[0], + ivector_dim) + + # Add the init config lines for estimating the preconditioning matrices + init_config_lines = copy.deepcopy(config_lines) + init_config_lines['components'].insert(0, '# Config file for initializing neural network prior to') + init_config_lines['components'].insert(0, '# preconditioning matrix computation') + + for o in output_nodes: + nodes.AddOutputLayer(init_config_lines, prev_layer_output, + objective_type = o.objective_type, suffix = o.output_suffix) + + config_files[config_dir + '/init.config'] = init_config_lines + + if cnn_layer is not None: + prev_layer_output = AddCnnLayers(config_lines, cnn_layer, cnn_bottleneck_dim, cepstral_lifter, config_dir, + feat_dim, splice_indexes[0], ivector_dim) + + # add_lda needs to be set "false" when using dense targets, + # or if the task is not a simple classification task + # (e.g. 
regression, multi-task) + if add_lda: + prev_layer_output = nodes.AddLdaLayer(config_lines, "L0", prev_layer_output, config_dir + '/lda.mat') + + left_context = 0 + right_context = 0 + # we moved the first splice layer to before the LDA.. + # so the input to the first affine layer is going to [0] index + splice_indexes[0] = [0] + + if not nonlin_output_dim is None: + nonlin_output_dims = [nonlin_output_dim] * num_hidden_layers + elif nonlin_output_dim_init < nonlin_output_dim_final and num_hidden_layers == 1: + raise Exception("num-hidden-layers has to be greater than 1 if relu-dim-init and relu-dim-final is different.") + else: + # computes relu-dim for each hidden layer. They increase geometrically across layers + factor = pow(float(nonlin_output_dim_final) / nonlin_output_dim_init, 1.0 / (num_hidden_layers - 1)) if num_hidden_layers > 1 else 1 + nonlin_output_dims = [int(round(nonlin_output_dim_init * pow(factor, i))) for i in range(0, num_hidden_layers)] + assert(nonlin_output_dims[-1] >= nonlin_output_dim_final - 1 and nonlin_output_dims[-1] <= nonlin_output_dim_final + 1) # due to rounding error + nonlin_output_dims[-1] = nonlin_output_dim_final # It ensures that the dim of the last hidden layer is exactly the same as what is specified + + for i in range(0, num_hidden_layers): + # make the intermediate config file for layerwise discriminative training + + # prepare the spliced input + if not (len(splice_indexes[i]) == 1 and splice_indexes[i][0] == 0): + try: + zero_index = splice_indexes[i].index(0) + except ValueError: + zero_index = None + # I just assume the prev_layer_output_descriptor is a simple forwarding descriptor + prev_layer_output_descriptor = prev_layer_output['descriptor'] + subset_output = prev_layer_output + if subset_dim > 0: + # if subset_dim is specified the script expects a zero in the splice indexes + assert(zero_index is not None) + subset_node_config = ("dim-range-node name=Tdnn_input_{0} " + "input-node={1} dim-offset={2} dim={3}".format( + i, prev_layer_output_descriptor, 0, subset_dim)) + subset_output = {'descriptor' : 'Tdnn_input_{0}'.format(i), + 'dimension' : subset_dim} + config_lines['component-nodes'].append(subset_node_config) + appended_descriptors = [] + appended_dimension = 0 + for j in range(len(splice_indexes[i])): + if j == zero_index: + appended_descriptors.append(prev_layer_output['descriptor']) + appended_dimension += prev_layer_output['dimension'] + continue + try: + offset = int(splice_indexes[i][j]) + # it's an integer offset. + appended_descriptors.append('Offset({0}, {1})'.format( + subset_output['descriptor'], splice_indexes[i][j])) + appended_dimension += subset_output['dimension'] + except ValueError: + # it's not an integer offset, so assume it specifies the + # statistics-extraction. 
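+                    # (A specifier of roughly the form "mean+stddev(-99:3:9:99)" is
+                    # expected here; the exact syntax is whatever
+                    # nodes.StatisticsConfig in steps/nnet3/components.py accepts,
+                    # so that is the authoritative reference, not this comment.)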
+ stats = nodes.StatisticsConfig(splice_indexes[i][j], prev_layer_output) + stats_layer = stats.AddLayer(config_lines, "Tdnn_stats_{0}".format(i)) + appended_descriptors.append(stats_layer['descriptor']) + appended_dimension += stats_layer['dimension'] + + prev_layer_output = {'descriptor' : "Append({0})".format(" , ".join(appended_descriptors)), + 'dimension' : appended_dimension} + else: + # this is a normal affine node + pass + + if nonlin_type == "relu": + prev_layer_output = nodes.AddAffRelNormLayer(config_lines, "Tdnn_{0}".format(i), + prev_layer_output, nonlin_output_dims[i], + self_repair_scale=self_repair_scale, + norm_target_rms=1.0 if i < num_hidden_layers -1 else final_layer_normalize_target) + elif nonlin_type == "pnorm": + prev_layer_output = nodes.AddAffPnormLayer(config_lines, "Tdnn_{0}".format(i), + prev_layer_output, nonlin_input_dim, nonlin_output_dim, + norm_target_rms=1.0 if i < num_hidden_layers -1 else final_layer_normalize_target) + else: + raise Exception("Unknown nonlinearity type") + # a final layer is added after each new layer as we are generating + # configs for layer-wise discriminative training + + AddOutputLayers(config_lines, prev_layer_output, output_nodes) + + config_files['{0}/layer{1}.config'.format(config_dir, i + 1)] = config_lines + config_lines = {'components':[], 'component-nodes':[]} + + left_context += int(parsed_splice_output['left_context']) + right_context += int(parsed_splice_output['right_context']) + + # write the files used by other scripts like steps/nnet3/get_egs.sh + f = open(config_dir + "/vars", "w") + print('model_left_context=' + str(left_context), file=f) + print('model_right_context=' + str(right_context), file=f) + print('num_hidden_layers=' + str(num_hidden_layers), file=f) + print('add_lda=' + ('true' if add_lda else 'false'), file=f) + f.close() + + # printing out the configs + # init.config used to train lda-mllt train + for key in config_files.keys(): + PrintConfig(key, config_files[key]) + +def ParseOutputNodesParameters(para_array): + output_parser = argparse.ArgumentParser() + output_parser.add_argument('--output-suffix', type=str, action=common_lib.NullstrToNoneAction, + help = "Name of the output node. e.g. output-xent") + output_parser.add_argument('--dim', type=int, required=True, + help = "Dimension of the output node") + output_parser.add_argument("--include-log-softmax", type=str, action=common_lib.StrToBoolAction, + help="add the final softmax layer ", + default=True, choices = ["false", "true"]) + output_parser.add_argument("--add-final-sigmoid", type=str, action=common_lib.StrToBoolAction, + help="add a sigmoid layer as the final layer. Applicable only if skip-final-softmax is true.", + choices=['true', 'false'], default = False) + output_parser.add_argument("--objective-type", type=str, default="linear", + choices = ["linear", "quadratic","xent-per-dim"], + help = "the type of objective; i.e. 
quadratic or linear") + output_parser.add_argument("--xent-regularize", type=float, + help="For chain models, if nonzero, add a separate output for cross-entropy " + "regularization (with learning-rate-factor equal to the inverse of this)", + default=0.0) + + output_nodes = [ output_parser.parse_args(shlex.split(x)) for x in para_array ] + + return output_nodes + +def Main(): + args = GetArgs() + + output_nodes = ParseOutputNodesParameters(args.output_node_para_array) + + MakeConfigs(config_dir = args.config_dir, + feat_dim = args.feat_dim, ivector_dim = args.ivector_dim, + add_lda = args.add_lda, + cepstral_lifter = args.cepstral_lifter, + splice_indexes_string = args.splice_indexes, + cnn_layer = args.cnn_layer, + cnn_bottleneck_dim = args.cnn_bottleneck_dim, + nonlin_type = args.nonlin_type, + nonlin_input_dim = args.nonlin_input_dim, + nonlin_output_dim = args.nonlin_output_dim, + subset_dim = args.subset_dim, + nonlin_output_dim_init = args.nonlin_output_dim_init, + nonlin_output_dim_final = args.nonlin_output_dim_final, + use_presoftmax_prior_scale = args.use_presoftmax_prior_scale, + final_layer_normalize_target = args.final_layer_normalize_target, + output_nodes = output_nodes, + self_repair_scale = args.self_repair_scale_nonlinearity) + +if __name__ == "__main__": + Main() + + diff --git a/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh b/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh new file mode 100644 index 00000000000..1344e185a02 --- /dev/null +++ b/egs/aspire/s5/local/segmentation/prepare_fisher_data.sh @@ -0,0 +1,88 @@ +#! /bin/bash + +# This script prepares Fisher data for training a speech activity detection +# and music detection system + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +. path.sh +. cmd.sh + +if [ $# -ne 0 ]; then + echo "Usage: $0" + echo "This script is to serve as an example recipe." + echo "Edit the script to change variables if needed." + exit 1 +fi + +dir=exp/unsad/make_unsad_fisher_train_100k # Work dir +subset=150 + +# All the paths below can be modified to any absolute path. + +# The original data directory which will be converted to a whole (recording-level) directory. +train_data_dir=data/fisher_train_100k + +model_dir=exp/tri3a # Model directory used for decoding +sat_model_dir=exp/tri4a # Model directory used for getting alignments +lang=data/lang # Language directory +lang_test=data/lang_test # Language directory used to build graph + +# Hard code the mapping from phones to SAD labels +# 0 for silence, 1 for speech, 2 for noise, 3 for unk +cat < $dir/fisher_sad.map +sil 0 +sil_B 0 +sil_E 0 +sil_I 0 +sil_S 0 +laughter 2 +laughter_B 2 +laughter_E 2 +laughter_I 2 +laughter_S 2 +noise 2 +noise_B 2 +noise_E 2 +noise_I 2 +noise_S 2 +oov 3 +oov_B 3 +oov_E 3 +oov_I 3 +oov_S 3 +EOF + +# Expecting the user to have done run.sh to have $model_dir, +# $sat_model_dir, $lang, $lang_test, $train_data_dir +local/segmentation/prepare_unsad_data.sh \ + --sad-map $dir/fisher_sad.map \ + --config-dir conf \ + --reco-nj 40 --nj 100 --cmd "$train_cmd" \ + --sat-model $sat_model_dir \ + --lang-test $lang_test \ + $train_data_dir $lang $model_dir $dir + +data_dir=${train_data_dir}_whole + +if [ ! 
-z $subset ]; then
+  # Work on a subset
+  utils/subset_data_dir.sh ${data_dir} $subset \
+    ${data_dir}_$subset
+  data_dir=${data_dir}_$subset
+fi
+
+reco_vad_dir=$dir/`basename $model_dir`_reco_vad_`basename $train_data_dir`_sp
+
+# Add noise from MUSAN corpus to data directory and create a new data directory
+local/segmentation/do_corruption_data_dir.sh \
+  --data-dir $data_dir \
+  --reco-vad-dir $reco_vad_dir \
+  --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf
+
+# Add music from MUSAN corpus to data directory and create a new data directory
+local/segmentation/do_corruption_data_dir_music.sh \
+  --data-dir $data_dir \
+  --reco-vad-dir $reco_vad_dir \
+  --feat-suffix hires_bp --mfcc-config conf/mfcc_hires_bp.conf
diff --git a/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh b/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh
new file mode 100755
index 00000000000..12097811ec9
--- /dev/null
+++ b/egs/aspire/s5/local/segmentation/prepare_unsad_data.sh
@@ -0,0 +1,537 @@
+#!/bin/bash
+
+# This script prepares speech labels and deriv weights for
+# training an unsad network for speech activity detection and music detection.
+
+set -u
+set -o pipefail
+set -e
+
+. path.sh
+
+stage=-2
+cmd=queue.pl
+reco_nj=40
+nj=100
+
+# Options to be passed to get_sad_map.py
+map_noise_to_sil=true   # Map noise phones to silence label (0)
+map_unk_to_speech=true  # Map unk phones to speech label (1)
+sad_map=    # Initial mapping from phones to speech/non-speech labels.
+            # Overrides the default mapping using phones/silence.txt
+            # and phones/nonsilence.txt
+
+# Options for feature extraction
+feat_type=mfcc    # mfcc or plp
+add_pitch=false   # Add pitch features
+
+config_dir=conf
+feat_config=
+pitch_config=
+
+mfccdir=mfcc
+plpdir=plp
+
+speed_perturb=true
+
+sat_model_dir=  # Model directory used for getting alignments
+lang_test=      # Language directory used to build graph.
+                # If it's not provided, $lang will be used instead.
+
+. utils/parse_options.sh
+
+if [ $# -ne 4 ]; then
+  echo "This script takes a data directory and creates a new data directory "
+  echo "and speech activity labels"
+  echo "for the purpose of training a Universal Speech Activity Detector."
+  echo "Usage: $0 [options] <data-dir> <lang> <model-dir> <dir>"
+  echo " e.g.: $0 data/train_100k data/lang exp/tri4a exp/vad_data_prep"
+  echo ""
+  echo "Main options (for others, see top of script file)"
+  echo "  --config <config-file>     # config file containing options"
+  echo "  --cmd (run.pl|queue.pl)    # how to run jobs."
+  echo "  --reco-nj <#njobs|4>       # Split a whole data directory into these many pieces"
+  echo "  --nj <#njobs|4>            # Split a segmented data directory into these many pieces"
+  exit 1
+fi
+
+data_dir=$1
+lang=$2
+model_dir=$3
+dir=$4
+
+if [ $feat_type != "plp" ] && [ $feat_type != "mfcc" ]; then
+  echo "$0: --feat-type must be plp or mfcc. Must match the model_dir used."
+  exit 1
+fi
+
+[ -z "$feat_config" ] && feat_config=$config_dir/$feat_type.conf
+[ -z "$pitch_config" ] && pitch_config=$config_dir/pitch.conf
+
+extra_files=
+
+if $add_pitch; then
+  extra_files="$extra_files $pitch_config"
+fi
+
+for f in $feat_config $extra_files; do
+  if [ !
-f $f ]; then + echo "$f could not be found" + exit 1 + fi +done + +mkdir -p $dir + +function make_mfcc { + local nj=$nj + local mfcc_config=$feat_config + local add_pitch=$add_pitch + local cmd=$cmd + local pitch_config=$pitch_config + + while [ $# -gt 0 ]; do + if [ $1 == "--nj" ]; then + nj=$2 + shift; shift; + elif [ $1 == "--mfcc-config" ]; then + mfcc_config=$2 + shift; shift; + elif [ $1 == "--add-pitch" ]; then + add_pitch=$2 + shift; shift; + elif [ $1 == "--cmd" ]; then + cmd=$2 + shift; shift; + elif [ $1 == "--pitch-config" ]; then + pitch_config=$2 + shift; shift; + else + break + fi + done + + if [ $# -ne 3 ]; then + echo "Usage: make_mfcc " + exit 1 + fi + + if $add_pitch; then + steps/make_mfcc_pitch.sh --cmd "$cmd" --nj $nj \ + --mfcc-config $mfcc_config --pitch-config $pitch_config $1 $2 $3 || exit 1 + else + steps/make_mfcc.sh --cmd "$cmd" --nj $nj \ + --mfcc-config $mfcc_config $1 $2 $3 || exit 1 + fi + +} + +function make_plp { + local nj=$nj + local mfcc_config=$feat_config + local add_pitch=$add_pitch + local cmd=$cmd + local pitch_config=$pitch_config + + while [ $# -gt 0 ]; do + if [ $1 == "--nj" ]; then + nj=$2 + shift; shift; + elif [ $1 == "--plp-config" ]; then + plp_config=$2 + shift; shift; + elif [ $1 == "--add-pitch" ]; then + add_pitch=$2 + shift; shift; + elif [ $1 == "--cmd" ]; then + cmd=$2 + shift; shift; + elif [ $1 == "--pitch-config" ]; then + pitch_config=$2 + shift; shift; + else + break + fi + done + + if [ $# -ne 3 ]; then + echo "Usage: make_plp " + exit 1 + fi + + if $add_pitch; then + steps/make_plp_pitch.sh --cmd "$cmd" --nj $nj \ + --plp-config $plp_config --pitch-config $pitch_config $1 $2 $3 || exit 1 + else + steps/make_plp.sh --cmd "$cmd" --nj $nj \ + --plp-config $plp_config $1 $2 $3 || exit 1 + fi +} + +frame_shift_info=`cat $feat_config | steps/segmentation/get_frame_shift_info_from_config.pl` || exit 1 + +frame_shift=`echo $frame_shift_info | awk '{print $1}'` +frame_overlap=`echo $frame_shift_info | awk '{print $2}'` + +data_id=$(basename $data_dir) +whole_data_dir=${data_dir}_whole +whole_data_id=${data_id}_whole + +if [ $stage -le -2 ]; then + steps/segmentation/get_sad_map.py \ + --init-sad-map="$sad_map" \ + --map-noise-to-sil=$map_noise_to_sil \ + --map-unk-to-speech=$map_unk_to_speech \ + $lang | utils/sym2int.pl -f 1 $lang/phones.txt > $dir/sad_map + + utils/data/convert_data_dir_to_whole.sh ${data_dir} ${whole_data_dir} + utils/data/get_utt2dur.sh ${whole_data_dir} +fi + +if $speed_perturb; then + plpdir=${plpdir}_sp + mfccdir=${mfccdir}_sp + + + if [ $stage -le -1 ]; then + utils/data/perturb_data_dir_speed_3way.sh ${whole_data_dir} ${whole_data_dir}_sp + utils/data/perturb_data_dir_speed_3way.sh ${data_dir} ${data_dir}_sp + + if [ $feat_type == "mfcc" ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $mfccdir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$mfccdir/storage $mfccdir/storage + fi + make_mfcc --cmd "$cmd --max-jobs-run 40" --nj $nj \ + --mfcc-config $feat_config \ + --add-pitch $add_pitch --pitch-config $pitch_config \ + ${whole_data_dir}_sp exp/make_mfcc $mfccdir || exit 1 + steps/compute_cmvn_stats.sh \ + ${whole_data_dir}_sp exp/make_mfcc $mfccdir || exit 1 + elif [ $feat_type == "plp" ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $plpdir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$plpdir/storage $plpdir/storage + fi + + make_plp --cmd "$cmd --max-jobs-run 40" --nj $nj \ + --plp-config $feat_config \ + --add-pitch $add_pitch --pitch-config $pitch_config \ + ${whole_data_dir}_sp exp/make_plp $plpdir || exit 1 + steps/compute_cmvn_stats.sh \ + ${whole_data_dir}_sp exp/make_plp $plpdir || exit 1 + else + echo "$0: Unknown feat-type $feat_type. Must be mfcc or plp." + exit 1 + fi + + utils/fix_data_dir.sh ${whole_data_dir}_sp + fi + + data_dir=${data_dir}_sp + whole_data_dir=${whole_data_dir}_sp + data_id=${data_id}_sp +fi + + +############################################################################### +# Compute length of recording +############################################################################### + +utils/data/get_reco2utt.sh $data_dir + +if [ $stage -le 0 ]; then + steps/segmentation/get_utt2num_frames.sh \ + --frame-shift $frame_shift --frame-overlap $frame_overlap \ + --cmd "$cmd" --nj $reco_nj $whole_data_dir + + awk '{print $1" "$2}' ${data_dir}/segments | utils/apply_map.pl -f 2 ${whole_data_dir}/utt2num_frames > $data_dir/utt2max_frames + utils/data/subsegment_feats.sh ${whole_data_dir}/feats.scp \ + $frame_shift $frame_overlap ${data_dir}/segments | \ + utils/data/fix_subsegmented_feats.pl $data_dir/utt2max_frames \ + > ${data_dir}/feats.scp + + if [ $feat_type == mfcc ]; then + steps/compute_cmvn_stats.sh ${data_dir} exp/make_mfcc/${data_id} $mfccdir + else + steps/compute_cmvn_stats.sh ${data_dir} exp/make_plp/${data_id} $plpdir + fi + + utils/fix_data_dir.sh $data_dir +fi + +if [ -z "$sat_model_dir" ]; then + ali_dir=${model_dir}_ali_${data_id} + if [ $stage -le 2 ]; then + steps/align_si.sh --nj $nj --cmd "$cmd" \ + ${data_dir} ${lang} ${model_dir} ${model_dir}_ali_${data_id} || exit 1 + fi +else + ali_dir=${sat_model_dir}_ali_${data_id} + #obtain the alignment of the perturbed data + if [ $stage -le 2 ]; then + steps/align_fmllr.sh --nj $nj --cmd "$cmd" \ + ${data_dir} ${lang} ${sat_model_dir} ${sat_model_dir}_ali_${data_id} || exit 1 + fi +fi + + +# All the data from this point is speed perturbed. + +data_id=$(basename $data_dir) +utils/split_data.sh $data_dir $nj + +############################################################################### +# Convert alignment for the provided segments into +# initial SAD labels at utterance-level in segmentation format +############################################################################### + +vad_dir=$dir/`basename ${ali_dir}`_vad_${data_id} +if [ $stage -le 3 ]; then + steps/segmentation/internal/convert_ali_to_vad.sh --cmd "$cmd" \ + $data_dir $ali_dir \ + $dir/sad_map $vad_dir +fi + +[ ! 
-s $vad_dir/sad_seg.scp ] && echo "$0: $vad_dir/vad.scp is empty" && exit 1 + +if [ $stage -le 4 ]; then + utils/copy_data_dir.sh $data_dir $dir/${data_id}_manual_segments + + awk '{print $1" "$2}' $dir/${data_id}_manual_segments/segments | sort -k1,1 > $dir/${data_id}_manual_segments/utt2spk + utils/utt2spk_to_spk2utt.pl $dir/${data_id}_manual_segments/utt2spk | sort -k1,1 > $dir/${data_id}_manual_segments/spk2utt + + if [ $feat_type == mfcc ]; then + steps/compute_cmvn_stats.sh $dir/${data_id}_manual_segments exp/make_mfcc/${data_id}_manual_segments $mfccdir + else + steps/compute_cmvn_stats.sh $dir/${data_id}_manual_segments exp/make_plp/${data_id}_manual_segments $plpdir + fi + + utils/fix_data_dir.sh $dir/${data_id}_manual_segments || true # Might fail because utt2spk will be not sorted on both utts and spks +fi + + +#utils/split_data.sh --per-reco $data_dir $reco_nj +#segmentation-combine-segments ark,s:$vad_dir/sad_seg.scp +# "ark,s:segmentation-init-from-segments --shift-to-zero=false --frame-shift=$ali_frame_shift --frame-overlap=$ali_frame_overlap ${data}/split${reco_nj}reco/JOB/segments ark:- |" \ +# "ark:cat ${data}/split${reco_nj}reco/JOB/segments | cut -d ' ' -f 1,2 | utils/utt2spk_to_spk2utt.pl | sort -k1,1 |" ark:- + +############################################################################### + + +# Create extended data directory that consists of the provided +# segments along with the segments outside it. +# This is basically dividing the whole recording into pieces +# consisting of pieces corresponding to the provided segments +# and outside the provided segments. + +############################################################################### +# Create segments outside of the manual segments +############################################################################### + +outside_data_dir=$dir/${data_id}_outside +if [ $stage -le 5 ]; then + rm -rf $outside_data_dir + mkdir -p $outside_data_dir/split${reco_nj}reco + + for f in wav.scp reco2file_and_channel stm glm; do + [ -f ${data_dir}/$f ] && cp ${data_dir}/$f $outside_data_dir + done + + steps/segmentation/split_data_on_reco.sh $data_dir $whole_data_dir $reco_nj + + for n in `seq $reco_nj`; do + dsn=$whole_data_dir/split${reco_nj}reco/$n + awk '{print $2}' $dsn/segments | \ + utils/filter_scp.pl /dev/stdin $whole_data_dir/utt2num_frames > \ + $dsn/utt2num_frames + mkdir -p $outside_data_dir/split${reco_nj}reco/$n + done + + $cmd JOB=1:$reco_nj $outside_data_dir/log/get_empty_segments.JOB.log \ + segmentation-init-from-segments --frame-shift=$frame_shift \ + --frame-overlap=$frame_overlap --shift-to-zero=false \ + ${data_dir}/split${reco_nj}reco/JOB/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- \ + "ark,t:cut -d ' ' -f 1,2 ${data_dir}/split${reco_nj}reco/JOB/segments | utils/utt2spk_to_spk2utt.pl |" ark:- \| \ + segmentation-create-subsegments --filter-label=1 --subsegment-label=0 \ + "ark:segmentation-init-from-lengths --label=1 ark,t:${whole_data_dir}/split${reco_nj}reco/JOB/utt2num_frames ark:- |" \ + ark:- ark:- \| \ + segmentation-post-process --remove-labels=0 --max-segment-length=1000 \ + --post-process-label=1 --overlap-length=50 \ + ark:- ark:- \| segmentation-to-segments --single-speaker=true \ + --frame-shift=$frame_shift --frame-overlap=$frame_overlap \ + ark:- ark,t:$outside_data_dir/split${reco_nj}reco/JOB/utt2spk \ + $outside_data_dir/split${reco_nj}reco/JOB/segments || exit 1 + + for n in `seq $reco_nj`; do + cat $outside_data_dir/split${reco_nj}reco/$n/utt2spk + done | 
+
+  for n in `seq $reco_nj`; do
+    cat $outside_data_dir/split${reco_nj}reco/$n/segments
+  done | sort -k1,1 > $outside_data_dir/segments
+
+  utils/fix_data_dir.sh $outside_data_dir
+
+fi
+
+
+if [ $stage -le 6 ]; then
+  utils/data/get_reco2utt.sh $outside_data_dir
+  awk '{print $1" "$2}' $outside_data_dir/segments | utils/apply_map.pl -f 2 $whole_data_dir/utt2num_frames > $outside_data_dir/utt2max_frames
+
+  utils/data/subsegment_feats.sh ${whole_data_dir}/feats.scp \
+    $frame_shift $frame_overlap ${outside_data_dir}/segments | \
+    utils/data/fix_subsegmented_feats.pl $outside_data_dir/utt2max_frames \
+    > ${outside_data_dir}/feats.scp
+
+fi
+
+extended_data_dir=$dir/${data_id}_extended
+if [ $stage -le 7 ]; then
+  cp $dir/${data_id}_manual_segments/cmvn.scp ${outside_data_dir} || exit 1
+  utils/fix_data_dir.sh $outside_data_dir
+
+  utils/combine_data.sh $extended_data_dir $data_dir $outside_data_dir
+
+  steps/segmentation/split_data_on_reco.sh $data_dir $extended_data_dir $reco_nj
+fi
+
+###############################################################################
+# Create graph for decoding
+###############################################################################
+
+# TODO: By default, we use a word LM. If required, we could
+# consider using a phone LM instead.
+graph_dir=$model_dir/graph
+if [ $stage -le 8 ]; then
+  if [ ! -d $graph_dir ]; then
+    utils/mkgraph.sh ${lang_test} $model_dir $graph_dir || exit 1
+  fi
+fi
+
+###############################################################################
+# Decode extended data directory
+###############################################################################
+
+
+# Decode without lattice (get only best path)
+if [ $stage -le 8 ]; then
+  steps/decode_nolats.sh --cmd "$cmd --mem 2G" --nj $nj \
+    --max-active 1000 --beam 10.0 --write-words false \
+    --write-alignments true \
+    $graph_dir ${extended_data_dir} \
+    ${model_dir}/decode_${data_id}_extended || exit 1
+  cp ${model_dir}/final.mdl ${model_dir}/decode_${data_id}_extended
+fi
+
+model_id=`basename $model_dir`
+
+# Get VAD based on the decoded best path
+decode_vad_dir=$dir/${model_id}_decode_vad_${data_id}
+if [ $stage -le 9 ]; then
+  steps/segmentation/internal/convert_ali_to_vad.sh --cmd "$cmd" \
+    $extended_data_dir ${model_dir}/decode_${data_id}_extended \
+    $dir/sad_map $decode_vad_dir
+fi
+
+[ ! -s $decode_vad_dir/sad_seg.scp ] && echo "$0: $decode_vad_dir/sad_seg.scp is empty" && exit 1
+
+vad_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $vad_dir ${PWD}`
+
+if [ $stage -le 10 ]; then
+  segmentation-init-from-segments --frame-shift=$frame_shift \
+    --frame-overlap=$frame_overlap --segment-label=0 \
+    $outside_data_dir/segments \
+    ark,scp:$vad_dir/outside_sad_seg.ark,$vad_dir/outside_sad_seg.scp
+fi
+
+reco_vad_dir=$dir/${model_id}_reco_vad_${data_id}
+mkdir -p $reco_vad_dir
+if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! 
-d $reco_vad_dir/storage ]; then + utils/create_split_dir.pl \ + /export/b0{3,4,5,6}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$reco_vad_dir/storage $reco_vad_dir/storage +fi + +reco_vad_dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $reco_vad_dir ${PWD}` + +echo $reco_nj > $reco_vad_dir/num_jobs + +if [ $stage -le 11 ]; then + $cmd JOB=1:$reco_nj $reco_vad_dir/log/intersect_vad.JOB.log \ + segmentation-intersect-segments --mismatch-label=10 \ + "scp:cat $vad_dir/sad_seg.scp $vad_dir/outside_sad_seg.scp | sort -k1,1 | utils/filter_scp.pl $extended_data_dir/split${reco_nj}reco/JOB/utt2spk |" \ + "scp:utils/filter_scp.pl $extended_data_dir/split${reco_nj}reco/JOB/utt2spk $decode_vad_dir/sad_seg.scp |" \ + ark:- \| segmentation-post-process --remove-labels=10 \ + --merge-adjacent-segments --max-intersegment-length=10 ark:- ark:- \| \ + segmentation-combine-segments ark:- "ark:segmentation-init-from-segments --shift-to-zero=false $extended_data_dir/split${reco_nj}reco/JOB/segments ark:- |" \ + ark,t:$extended_data_dir/split${reco_nj}reco/JOB/reco2utt \ + ark,scp:$reco_vad_dir/sad_seg.JOB.ark,$reco_vad_dir/sad_seg.JOB.scp + for n in `seq $reco_nj`; do + cat $reco_vad_dir/sad_seg.$n.scp + done > $reco_vad_dir/sad_seg.scp +fi + +set +e +for n in `seq $reco_nj`; do + utils/create_data_link.pl $reco_vad_dir/deriv_weights.$n.ark + utils/create_data_link.pl $reco_vad_dir/deriv_weights_for_uncorrupted.$n.ark + utils/create_data_link.pl $reco_vad_dir/speech_feat.$n.ark +done +set -e + +if [ $stage -le 12 ]; then + $cmd JOB=1:$reco_nj $reco_vad_dir/log/get_deriv_weights.JOB.log \ + segmentation-post-process --merge-labels=0:1:2:3 --merge-dst-label=1 \ + scp:$reco_vad_dir/sad_seg.JOB.scp ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${whole_data_dir}/utt2num_frames ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \ + ark,scp:$reco_vad_dir/deriv_weights.JOB.ark,$reco_vad_dir/deriv_weights.JOB.scp + + for n in `seq $reco_nj`; do + cat $reco_vad_dir/deriv_weights.$n.scp + done > $reco_vad_dir/deriv_weights.scp +fi + +if [ $stage -le 13 ]; then + $cmd JOB=1:$reco_nj $reco_vad_dir/log/get_deriv_weights_for_uncorrupted.JOB.log \ + segmentation-post-process --remove-labels=1:2:3 scp:$reco_vad_dir/sad_seg.JOB.scp \ + ark:- \| segmentation-post-process --merge-labels=0 --merge-dst-label=1 ark:- ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${whole_data_dir}/utt2num_frames ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \ + ark,scp:$reco_vad_dir/deriv_weights_for_uncorrupted.JOB.ark,$reco_vad_dir/deriv_weights_for_uncorrupted.JOB.scp + for n in `seq $reco_nj`; do + cat $reco_vad_dir/deriv_weights_for_uncorrupted.$n.scp + done > $reco_vad_dir/deriv_weights_for_uncorrupted.scp +fi + +if [ $stage -le 14 ]; then + $cmd JOB=1:$reco_nj $reco_vad_dir/log/get_speech_labels.JOB.log \ + segmentation-post-process --keep-label=1 scp:$reco_vad_dir/sad_seg.JOB.scp ark:- \| \ + segmentation-to-ali --lengths-rspecifier=ark,t:${whole_data_dir}/utt2num_frames \ + ark:- ark,t:- \| \ + steps/segmentation/convert_ali_to_vec.pl \| vector-to-feat ark:- ark:- \| copy-feats --compress \ + ark:- ark,scp:$reco_vad_dir/speech_feat.JOB.ark,$reco_vad_dir/speech_feat.JOB.scp + for n in `seq $reco_nj`; do + cat $reco_vad_dir/speech_feat.$n.scp + done > $reco_vad_dir/speech_feat.scp +fi + +if [ $stage -le 15 ]; then + $cmd JOB=1:$reco_nj 
$reco_vad_dir/log/convert_manual_segments_to_deriv_weights.JOB.log \
+    segmentation-init-from-segments --shift-to-zero=false \
+    $data_dir/split${reco_nj}reco/JOB/segments ark:- \| \
+    segmentation-combine-segments-to-recordings ark:- \
+    ark:$data_dir/split${reco_nj}reco/JOB/reco2utt ark:- \| \
+    segmentation-to-ali --lengths-rspecifier=ark,t:${whole_data_dir}/utt2num_frames \
+    ark:- ark,t:- \| \
+    steps/segmentation/convert_ali_to_vec.pl \| copy-vector ark,t:- \
+    ark,scp:$reco_vad_dir/deriv_weights_manual_seg.JOB.ark,$reco_vad_dir/deriv_weights_manual_seg.JOB.scp
+
+  for n in `seq $reco_nj`; do
+    cat $reco_vad_dir/deriv_weights_manual_seg.$n.scp
+  done > $reco_vad_dir/deriv_weights_manual_seg.scp
+fi
+
+echo "$0: Finished creating corpus for training Universal SAD with data in $whole_data_dir and labels in $reco_vad_dir"
diff --git a/egs/aspire/s5/local/segmentation/run_fisher.sh b/egs/aspire/s5/local/segmentation/run_fisher.sh
new file mode 100644
index 00000000000..e39ef5f3a91
--- /dev/null
+++ b/egs/aspire/s5/local/segmentation/run_fisher.sh
@@ -0,0 +1,23 @@
+#! /bin/bash
+
+# Copyright 2016 Vimal Manohar
+# Apache 2.0.
+
+local/segmentation/prepare_fisher_data.sh
+
+utils/combine_data.sh --extra-files "speech_feat.scp deriv_weights.scp deriv_weights_manual_seg.scp music_labels.scp" \
+  data/fisher_train_100k_whole_all_corrupted_sp_hires_bp \
+  data/fisher_train_100k_whole_corrupted_sp_hires_bp \
+  data/fisher_train_100k_whole_music_corrupted_sp_hires_bp
+
+local/segmentation/train_stats_sad_music.sh \
+  --train-data-dir data/fisher_train_100k_whole_all_corrupted_sp_hires_bp \
+  --speech-feat-scp data/fisher_train_100k_whole_corrupted_sp_hires_bp/speech_feat.scp \
+  --deriv-weights-scp data/fisher_train_100k_whole_corrupted_sp_hires_bp/deriv_weights.scp \
+  --music-labels-scp data/fisher_train_100k_whole_music_corrupted_sp_hires_bp/music_labels.scp \
+  --max-param-change 0.2 \
+  --num-epochs 2 --affix k \
+  --splice-indexes "-3,-2,-1,0,1,2,3 -6,0,mean+count(-99:3:9:99) -9,0,3 0"
+
+local/segmentation/run_segmentation_ami.sh \
+  --nnet-dir exp/nnet3_sad_snr/nnet_tdnn_k_n4
diff --git a/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh b/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh
new file mode 100755
index 00000000000..46ebf013b82
--- /dev/null
+++ b/egs/aspire/s5/local/segmentation/run_segmentation_ami.sh
@@ -0,0 +1,128 @@
+#! /bin/bash
+
+# Copyright 2016 Vimal Manohar
+# Apache 2.0.
+
+. cmd.sh
+. path.sh
+
+set -e
+set -o pipefail
+set -u
+
+stage=-1
+nnet_dir=exp/nnet3_sad_snr/nnet_tdnn_k_n4
+
+. 
utils/parse_options.sh + +export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH + +src_dir=/export/a09/vmanoha1/workspace_asr_diarization/egs/ami/s5b # AMI src_dir +dir=exp/sad_ami_sdm1_dev/ref + +mkdir -p $dir + +# Expecting user to have done run.sh to run the AMI recipe in $src_dir for +# both sdm and ihm microphone conditions + +if [ $stage -le 1 ]; then + ( + cd $src_dir + local/prepare_parallel_train_data.sh --train-set dev sdm1 + + awk '{print $1" "$2}' $src_dir/data/ihm/dev/segments > \ + $src_dir/data/ihm/dev/utt2reco + awk '{print $1" "$2}' $src_dir/data/sdm1/dev/segments > \ + $src_dir/data/sdm1/dev/utt2reco + + cat $src_dir/data/sdm1/dev_ihmdata/ihmutt2utt | \ + utils/apply_map.pl -f 1 $src_dir/data/ihm/dev/utt2reco | \ + utils/apply_map.pl -f 2 $src_dir/data/sdm1/dev/utt2reco | \ + sort -u > $src_dir/data/sdm1/dev_ihmdata/ihm2sdm_reco + ) +fi + +if [ $stage -le 2 ]; then + ( + cd $src_dir + utils/data/get_reco2utt.sh $src_dir/data/sdm1/dev + ) + + phone_map=$dir/phone_map + steps/segmentation/get_sad_map.py \ + $src_dir/data/lang | utils/sym2int.pl -f 1 $src_dir/data/lang/phones.txt > \ + $phone_map +fi + +if [ $stage -le 3 ]; then + # Expecting user to have run local/run_cleanup_segmentation.sh in $src_dir + ( + cd $src_dir + steps/align_fmllr.sh --nj 18 --cmd "$train_cmd" \ + data/sdm1/dev_ihmdata data/lang \ + exp/ihm/tri3_cleaned \ + exp/sdm1/tri3_cleaned_dev_ihmdata + ) +fi + +if [ $stage -le 4 ]; then + steps/segmentation/internal/convert_ali_to_vad.sh --cmd "$train_cmd" \ + $src_dir/exp/sdm1/tri3_cleaned_dev_ihmdata $phone_map $dir +fi + +echo "A 1" > $dir/channel_map +cat $src_dir/data/sdm1/dev/reco2file_and_channel | \ + utils/apply_map.pl -f 3 $dir/channel_map > $dir/reco2file_and_channel + +if [ $stage -le 5 ]; then + $train_cmd $dir/log/get_ref_rttm.log \ + segmentation-combine-segments scp:$dir/sad_seg.scp \ + "ark:segmentation-init-from-segments --shift-to-zero=false $src_dir/data/sdm1/dev_ihmdata/segments ark:- |" \ + ark,t:$src_dir/data/sdm1/dev_ihmdata/reco2utt ark:- \| \ + segmentation-merge-recordings \ + "ark,t:utils/utt2spk_to_spk2utt.pl $src_dir/data/sdm1/dev_ihmdata/ihm2sdm_reco |" \ + ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + ark:- $dir/ref.rttm +fi + +if [ $stage -le 6 ]; then + $train_cmd $dir/log/get_uem.log \ + segmentation-init-from-segments --shift-to-zero=false $src_dir/data/sdm1/dev/segments ark:- \| \ + segmentation-combine-segments-to-recordings ark:- ark,t:$src_dir/data/sdm1/dev/reco2utt ark:- \| \ + segmentation-post-process --remove-labels=0 --merge-adjacent-segments \ + --max-intersegment-length=10000 ark:- ark:- \| \ + segmentation-to-rttm --reco2file-and-channel=$dir/reco2file_and_channel \ + ark:- - \| grep SPEECH \| grep SPEAKER \| \ + rttmSmooth.pl -s 0 \| awk '{ print $2" "$3" "$4" "$5+$4 }' '>' $dir/uem +fi + +hyp_dir=$nnet_dir/segmentation_ami_sdm1_dev_whole_bp + +if [ $stage -le 7 ]; then + steps/segmentation/do_segmentation_data_dir.sh --reco-nj 18 \ + --mfcc-config conf/mfcc_hires_bp.conf --feat-affix bp --do-downsampling true \ + --extra-left-context 100 --extra-right-context 20 \ + --output-name output-speech --frame-subsampling-factor 6 \ + $src_dir/data/sdm1/dev data/ami_sdm1_dev $nnet_dir +fi + + +if [ $stage -le 8 ]; then + utils/data/get_reco2utt.sh $src_dir/data/sdm1/dev_ihmdata + + steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \ + $hyp_dir/ami_sdm1_dev_seg/utt2spk \ + $hyp_dir/ami_sdm1_dev_seg/segments \ + $dir/reco2file_and_channel \ + /dev/stdout | spkr2sad.pl > 
$hyp_dir/sys.rttm +fi + +if [ $stage -le 9 ]; then + md-eval.pl -s <(cat $hyp_dir/sys.rttm | grep speech | rttmSmooth.pl -s 0) \ + -r <(cat $dir/ref.rttm | grep SPEECH | rttmSmooth.pl -s 0 ) \ + -u $dir/uem -c 0.25 +fi + +#md-eval.pl -s <( segmentation-init-from-segments --shift-to-zero=false exp/nnet3_sad_snr/nnet_tdnn_j_n4/segmentation_ami_sdm1_dev_whole_bp/ami_sdm1_dev_seg/segments ark:- | segmentation-combine-segments-to-recordings ark:- ark,t:exp/nnet3_sad_snr/nnet_tdnn_j_n4/segmentation_ami_sdm1_dev_whole_bp/ami_sdm1_dev_seg/reco2utt ark:- | segmentation-to-ali --length-tolerance=1000 --lengths-rspecifier=ark,t:data/ami_sdm1_dev_whole_bp_hires/utt2num_frames ark:- ark:- | +#segmentation-init-from-ali ark:- ark:- | segmentation-to-rttm ark:- - | grep SPEECH | rttmSmooth.pl -s 0) diff --git a/egs/aspire/s5/local/segmentation/run_train_sad.sh b/egs/aspire/s5/local/segmentation/run_train_sad.sh new file mode 100755 index 00000000000..9b1f104939a --- /dev/null +++ b/egs/aspire/s5/local/segmentation/run_train_sad.sh @@ -0,0 +1,150 @@ +#!/bin/bash + +# this is the standard "tdnn" system, built in nnet3; it's what we use to +# call multi-splice. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= + +splice_indexes="-3,-2,-1,0,1,2,3 -6,0 -9,0,3 0" +relu_dim=256 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=1 +extra_egs_copy_cmd= + +num_utts_subset_valid=40 +num_utts_subset_train=40 +add_idct=true + +# target options +train_data_dir=data/train_azteec_whole_sp_corrupted_hires + +snr_scp= +speech_feat_scp= + +deriv_weights_scp= +deriv_weights_for_irm_scp= + +egs_dir= +nj=40 +feat_type=raw +config_dir= +compute_objf_opts= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +num_hidden_layers=`echo $splice_indexes | perl -ane 'print scalar @F'` || exit 1 +if [ -z "$dir" ]; then + dir=exp/nnet3_sad_snr/nnet_tdnn +fi + +dir=$dir${affix:+_$affix}_n${num_hidden_layers} + +if ! cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$train_data_dir/feats.scp -` name=input + + # please note that it is important to have input layer with the name=input + # as the layer immediately preceding the fixed-affine-layer to enable + # the use of short notation for the descriptor + # This is disabled for now. 
+ # fixed-affine-layer name=lda input=Append(-3,-2,-1,0,1,2,3) affine-transform-file=$dir/configs/lda.mat + # the first splicing is moved before the lda layer, so no splicing here + # relu-renorm-layer name=tdnn1 dim=625 + + relu-renorm-layer name=tdnn1 input=Append(-3,-2,-1,0,1,2,3) dim=256 + stats-layer name=tdnn2.stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(Offset(tdnn1, -6), tdnn1, tdnn2.stats) dim=256 + relu-renorm-layer name=tdnn3 input=Append(-9,0,3) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-speech include-log-softmax=true dim=2 input=tdnn4 + output-layer name=output-music include-log-softmax=true dim=2 input=tdnn4 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ +fi + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$train_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=20000 \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$speech_feat_scp --deriv-weights-scp=$deriv_weights_scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_labels_scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --dir=$dir/egs + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=20 \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=64 \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + --use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$train_data_dir \ + --targets-scp="$speech_feat_scp" \ + --dir=$dir || exit 1 +fi + +if [ $stage -le 6 ]; then + $train_cmd JOB=1:100 $dir/log/compute_post_output-speech.JOB.log \ + 
extract-column "scp:utils/split_scp.pl -j 100 \$[JOB-1] $speech_feat_scp |" ark,t:- \| \ + steps/segmentation/quantize_vector.pl \| \ + ali-to-post ark,t:- ark:- \| \ + weight-post ark:- scp:$deriv_weights_scp ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-speech.vec.JOB + eval vector-sum $dir/post_output-speech.vec.{`seq -s, 100`} $dir/post_output-speech.vec + + $train_cmd JOB=1:100 $dir/log/compute_post_output-music.JOB.log \ + ali-to-post "scp:utils/split_scp.pl -j 100 \$[JOB-1] $music_labels_scp |" ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-music.vec.JOB + eval vector-sum $dir/post_output-music.vec.{`seq -s, 100`} $dir/post_output-music.vec +fi diff --git a/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1c.sh b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1c.sh new file mode 100644 index 00000000000..163ea6df14d --- /dev/null +++ b/egs/aspire/s5/local/segmentation/tuning/train_stats_sad_music_1c.sh @@ -0,0 +1,185 @@ +#!/bin/bash + +# This is a script to train a time-delay neural network for speech activity detection (SAD) and +# music-id using statistic pooling component for long-context information. + +set -o pipefail +set -e +set -u + +. cmd.sh + +# At this script level we don't support not running on GPU, as it would be painfully slow. +# If you want to run without GPU you'd have to call train_tdnn.sh with --gpu false, +# --num-threads 16 and --minibatch-size 128. + +stage=0 +train_stage=-10 +get_egs_stage=-10 +egs_opts= # Directly passed to get_egs_multiple_targets.py + +# TDNN options +relu_dim=256 +chunk_width=20 # We use chunk training for training TDNN +extra_left_context=100 # Maximum left context in egs apart from TDNN's left context +extra_right_context=20 # Maximum right context in egs apart from TDNN's right context + +# We randomly select an extra {left,right} context for each job between +# min_extra_*_context and extra_*_context so that the network can get used +# to different contexts used to compute statistics. +min_extra_left_context=20 +min_extra_right_context=0 + +# training options +num_epochs=2 +initial_effective_lrate=0.0003 +final_effective_lrate=0.00003 +num_jobs_initial=3 +num_jobs_final=8 +remove_egs=false +max_param_change=0.2 # Small max-param change for small network +extra_egs_copy_cmd= # Used if you want to do some weird stuff to egs + # such as removing one of the targets + +num_utts_subset_valid=50 # "utts" is actually recording. So this is prettly small. +num_utts_subset_train=50 + +# target options +train_data_dir=data/train_azteec_whole_sp_corrupted_hires + +speech_feat_scp= +music_labels_scp= + +deriv_weights_scp= + +egs_dir= +nj=40 +feat_type=raw +config_dir= + +dir= +affix=a + +. cmd.sh +. ./path.sh +. ./utils/parse_options.sh + +if [ -z "$dir" ]; then + dir=exp/nnet3_stats_sad_music/nnet_tdnn +fi + +dir=$dir${affix:+_$affix} + +if ! 
cuda-compiled; then + cat < $dir/configs/network.xconfig + input dim=`feat-to-dim scp:$train_data_dir/feats.scp -` name=input + output name=output-temp input=Append(-3,-2,-1,0,1,2,3) + + relu-renorm-layer name=tdnn1 input=Append(input@-3, input@-2, input@-1, input, input@1, input@2, input@3) dim=256 + stats-layer name=tdnn2_stats config=mean+count(-99:3:9:99) + relu-renorm-layer name=tdnn2 input=Append(tdnn1@-9, tdnn1@-3, tdnn1, tdnn1@3, tdnn2_stats) dim=256 + stats-layer name=tdnn3_stats config=mean+count(-108:9:27:108) + relu-renorm-layer name=tdnn3 input=Append(tdnn2@-27, tdnn2@-9, tdnn2, tdnn2@9, tdnn3_stats) dim=256 + relu-renorm-layer name=tdnn4 dim=256 + + output-layer name=output-speech include-log-softmax=true dim=2 input=tdnn4 + output-layer name=output-music include-log-softmax=true dim=2 input=tdnn4 +EOF + steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig \ + --config-dir $dir/configs/ \ + --nnet-edits="rename-node old-name=output-speech new-name=output" + + cat <> $dir/configs/vars +add_lda=false +EOF +fi + +if [ -z "$egs_dir" ]; then + egs_dir=$dir/egs + if [ $stage -le 4 ]; then + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then + utils/create_split_dir.pl \ + /export/b{03,04,05,06}/$USER/kaldi-data/egs/aspire-$(date +'%m_%d_%H_%M')/s5/$dir/egs/storage $dir/egs/storage + fi + + . $dir/configs/vars + + steps/nnet3/get_egs_multiple_targets.py --cmd="$decode_cmd" \ + $egs_opts \ + --feat.dir="$train_data_dir" \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --frames-per-eg=$chunk_width \ + --left-context=$[model_left_context + extra_left_context] \ + --right-context=$[model_right_context + extra_right_context] \ + --num-utts-subset-train=$num_utts_subset_train \ + --num-utts-subset-valid=$num_utts_subset_valid \ + --samples-per-iter=20000 \ + --stage=$get_egs_stage \ + --targets-parameters="--output-name=output-speech --target-type=sparse --dim=2 --targets-scp=$speech_feat_scp --deriv-weights-scp=$deriv_weights_scp --scp2ark-cmd=\"extract-column --column-index=0 scp:- ark,t:- | steps/segmentation/quantize_vector.pl | ali-to-post ark,t:- ark:- |\" --compress=true" \ + --targets-parameters="--output-name=output-music --target-type=sparse --dim=2 --targets-scp=$music_labels_scp --scp2ark-cmd=\"ali-to-post scp:- ark:- |\" --compress=true" \ + --dir=$dir/egs + fi +fi + +if [ $stage -le 5 ]; then + steps/nnet3/train_raw_rnn.py --stage=$train_stage \ + --feat.cmvn-opts="--norm-means=false --norm-vars=false" \ + --egs.chunk-width=20 \ + --egs.dir="$egs_dir" --egs.stage=$get_egs_stage \ + --egs.chunk-left-context=$extra_left_context \ + --egs.chunk-right-context=$extra_right_context \ + ${extra_egs_copy_cmd:+--egs.extra-copy-cmd="$extra_egs_copy_cmd"} \ + --trainer.min-chunk-left-context=$min_extra_left_context \ + --trainer.min-chunk-right-context=$min_extra_right_context \ + --trainer.num-epochs=$num_epochs \ + --trainer.samples-per-iter=20000 \ + --trainer.optimization.num-jobs-initial=$num_jobs_initial \ + --trainer.optimization.num-jobs-final=$num_jobs_final \ + --trainer.optimization.initial-effective-lrate=$initial_effective_lrate \ + --trainer.optimization.final-effective-lrate=$final_effective_lrate \ + --trainer.optimization.shrink-value=1.0 \ + --trainer.rnn.num-chunk-per-minibatch=64 \ + --trainer.deriv-truncate-margin=8 \ + --trainer.max-param-change=$max_param_change \ + --cmd="$decode_cmd" --nj 40 \ + --cleanup=true \ + --cleanup.remove-egs=$remove_egs \ + --cleanup.preserve-model-interval=10 \ + 
--use-gpu=true \ + --use-dense-targets=false \ + --feat-dir=$train_data_dir \ + --targets-scp="$speech_feat_scp" \ + --dir=$dir || exit 1 +fi + +if [ $stage -le 6 ]; then + $train_cmd JOB=1:100 $dir/log/compute_post_output-speech.JOB.log \ + extract-column "scp:utils/split_scp.pl -j 100 \$[JOB-1] $speech_feat_scp |" ark,t:- \| \ + steps/segmentation/quantize_vector.pl \| \ + ali-to-post ark,t:- ark:- \| \ + weight-post ark:- scp:$deriv_weights_scp ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-speech.vec.JOB + eval vector-sum $dir/post_output-speech.vec.{`seq -s, 100`} $dir/post_output-speech.vec + + $train_cmd JOB=1:100 $dir/log/compute_post_output-music.JOB.log \ + ali-to-post "scp:utils/split_scp.pl -j 100 \$[JOB-1] $music_labels_scp |" ark:- \| \ + post-to-feats --post-dim=2 ark:- ark:- \| \ + matrix-sum-rows ark:- ark:- \| \ + vector-sum ark:- $dir/post_output-music.vec.JOB + eval vector-sum $dir/post_output-music.vec.{`seq -s, 100`} $dir/post_output-music.vec +fi diff --git a/egs/aspire/s5/path.sh b/egs/aspire/s5/path.sh index 1a6fb5f891b..5c0d3a92f19 100755 --- a/egs/aspire/s5/path.sh +++ b/egs/aspire/s5/path.sh @@ -2,4 +2,8 @@ export KALDI_ROOT=`pwd`/../../.. export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH [ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present -> Exit!" && exit 1 . $KALDI_ROOT/tools/config/common_path.sh +export PATH=/home/vmanoha1/kaldi-raw-signal/src/segmenterbin:$PATH +export PATH=$KALDI_ROOT/tools/sph2pipe_v2.5:$PATH +export PATH=$KALDI_ROOT/tools/sctk/bin:$PATH +export PYTHONPATH=steps:${PYTHONPATH} export LC_ALL=C diff --git a/egs/babel/s5c/local/resegment/evaluate_segmentation.pl b/egs/babel/s5c/local/resegment/evaluate_segmentation.pl deleted file mode 100755 index 06a762d7762..00000000000 --- a/egs/babel/s5c/local/resegment/evaluate_segmentation.pl +++ /dev/null @@ -1,198 +0,0 @@ -#!/usr/bin/env perl - -# Copyright 2014 Johns Hopkins University (Author: Sanjeev Khudanpur), Vimal Manohar -# Apache 2.0 - -################################################################################ -# -# This script was written to check the goodness of automatic segmentation tools -# It assumes input in the form of two Kaldi segments files, i.e. a file each of -# whose lines contain four space-separated values: -# -# UtteranceID FileID StartTime EndTime -# -# It computes # missed frames, # false positives and # overlapping frames. -# -################################################################################ - -if ($#ARGV == 1) { - $ReferenceSegmentation = $ARGV[0]; - $HypothesizedSegmentation = $ARGV[1]; - printf STDERR ("Comparing reference segmentation\n\t%s\nwith proposed segmentation\n\t%s\n", - $ReferenceSegmentation, - $HypothesizedSegmentation); -} else { - printf STDERR "This program compares the reference segmenation with the proposted segmentation\n"; - printf STDERR "Usage: $0 reference_segments_filename proposed_segments_filename\n"; - printf STDERR "e.g. $0 data/dev10h/segments data/dev10h.seg/segments\n"; - exit (0); -} - -################################################################################ -# First read the reference segmentation, and -# store the start- and end-times of all segments in each file. 
-################################################################################ - -open (SEGMENTS, "cat $ReferenceSegmentation | sort -k2,2 -k3n,3 -k4n,4 |") - || die "Unable to open $ReferenceSegmentation"; -$numLines = 0; -while ($line=) { - chomp $line; - @field = split("[ \t]+", $line); - unless ($#field == 3) { - exit (1); - printf STDERR "Skipping unparseable line in file $ReferenceSegmentation\n\t$line\n"; - next; - } - $fileID = $field[1]; - unless (exists $firstSeg{$fileID}) { - $firstSeg{$fileID} = $numLines; - $actualSpeech{$fileID} = 0.0; - $hypothesizedSpeech{$fileID} = 0.0; - $foundSpeech{$fileID} = 0.0; - $falseAlarm{$fileID} = 0.0; - $minStartTime{$fileID} = 0.0; - $maxEndTime{$fileID} = 0.0; - } - $refSegName[$numLines] = $field[0]; - $refSegStart[$numLines] = $field[2]; - $refSegEnd[$numLines] = $field[3]; - $actualSpeech{$fileID} += ($field[3]-$field[2]); - $minStartTime{$fileID} = $field[2] if ($minStartTime{$fileID}>$field[2]); - $maxEndTime{$fileID} = $field[3] if ($maxEndTime{$fileID}<$field[3]); - $lastSeg{$fileID} = $numLines; - ++$numLines; -} -close(SEGMENTS); -print STDERR "Read $numLines segments from $ReferenceSegmentation\n"; - -################################################################################ -# Process hypothesized segments sequentially, and gather speech/nonspeech stats -################################################################################ - -open (SEGMENTS, "cat $HypothesizedSegmentation | sort -k2,2 -k1,1 |") - # Kaldi segments files are sorted by UtteranceID, but we re-sort them here - # so that all segments of a file are read together, sorted by start-time. - || die "Unable to open $HypothesizedSegmentation"; -$numLines = 0; -$totalHypSpeech = 0.0; -$totalFoundSpeech = 0.0; -$totalFalseAlarm = 0.0; -$numShortSegs = 0; -$numLongSegs = 0; -while ($line=) { - chomp $line; - @field = split("[ \t]+", $line); - unless ($#field == 3) { - exit (1); - printf STDERR "Skipping unparseable line in file $HypothesizedSegmentation\n\t$line\n"; - next; - } - $fileID = $field[1]; - $segStart = $field[2]; - $segEnd = $field[3]; - if (exists $firstSeg{$fileID}) { - # This FileID exists in the reference segmentation - # So gather statistics for this UtteranceID - $hypothesizedSpeech{$fileID} += ($segEnd-$segStart); - $totalHypSpeech += ($segEnd-$segStart); - if (($segStart>=$maxEndTime{$fileID}) || ($segEnd<=$minStartTime{$fileID})) { - # This entire segment is a false alarm - $falseAlarm{$fileID} += ($segEnd-$segStart); - $totalFalseAlarm += ($segEnd-$segStart); - } else { - # This segment may overlap one or more reference segments - $p = $firstSeg{$fileID}; - while ($refSegEnd[$p]<=$segStart) { - ++$p; - } - # The overlap, if any, begins at the reference segment p - $q = $lastSeg{$fileID}; - while ($refSegStart[$q]>=$segEnd) { - --$q; - } - # The overlap, if any, ends at the reference segment q - if ($q<$p) { - # This segment sits entirely in the nonspeech region - # between the two reference speech segments q and p - $falseAlarm{$fileID} += ($segEnd-$segStart); - $totalFalseAlarm += ($segEnd-$segStart); - } else { - if (($segEnd-$segStart)<0.20) { - # For diagnosing Pascal's VAD segmentation - print STDOUT "Found short speech region $line\n"; - ++$numShortSegs; - } elsif (($segEnd-$segStart)>60.0) { - ++$numLongSegs; - # For diagnosing Pascal's VAD segmentation - print STDOUT "Found long speech region $line\n"; - } - # There is some overlap with segments p through q - for ($s=$p; $s<=$q; ++$s) { - if ($segStart<$refSegStart[$s]) { - # There 
is a leading false alarm portion before s - $falseAlarm{$fileID} += ($refSegStart[$s]-$segStart); - $totalFalseAlarm += ($refSegStart[$s]-$segStart); - $segStart=$refSegStart[$s]; - } - $speechPortion = ($refSegEnd[$s]<$segEnd) ? - ($refSegEnd[$s]-$segStart) : ($segEnd-$segStart); - $foundSpeech{$fileID} += $speechPortion; - $totalFoundSpeech += $speechPortion; - $segStart=$refSegEnd[$s]; - } - if ($segEnd>$segStart) { - # There is a trailing false alarm portion after q - $falseAlarm{$fileID} += ($segEnd-$segStart); - $totalFalseAlarm += ($segEnd-$segStart); - } - } - } - } else { - # This FileID does not exist in the reference segmentation - # So all this speech counts as a false alarm - exit (1); - printf STDERR ("Unexpected fileID in hypothesized segments: %s", $fileID); - $totalFalseAlarm += ($segEnd-$segStart); - } - ++$numLines; -} -close(SEGMENTS); -print STDERR "Read $numLines segments from $HypothesizedSegmentation\n"; - -################################################################################ -# Now that all hypothesized segments have been processed, compute needed stats -################################################################################ - -$totalActualSpeech = 0.0; -$totalNonSpeechEst = 0.0; # This is just a crude estimate of total nonspeech. -foreach $fileID (sort keys %actualSpeech) { - $totalActualSpeech += $actualSpeech{$fileID}; - $totalNonSpeechEst += $maxEndTime{$fileID} - $actualSpeech{$fileID}; - ####################################################################### - # Print file-wise statistics to STDOUT; can pipe to /dev/null is needed - ####################################################################### - printf STDOUT ("%s: %.2f min actual speech, %.2f min hypothesized: %.2f min overlap (%d\%), %.2f min false alarm (~%d\%)\n", - $fileID, - ($actualSpeech{$fileID}/60.0), - ($hypothesizedSpeech{$fileID}/60.0), - ($foundSpeech{$fileID}/60.0), - ($foundSpeech{$fileID}*100/($actualSpeech{$fileID}+0.01)), - ($falseAlarm{$fileID}/60.0), - ($falseAlarm{$fileID}*100/($maxEndTime{$fileID}-$actualSpeech{$fileID}+0.01))); -} - -################################################################################ -# Finally, we have everything needed to report the segmentation statistics. 
-################################################################################ - -printf STDERR ("------------------------------------------------------------------------\n"); -printf STDERR ("TOTAL: %.2f hrs actual speech, %.2f hrs hypothesized: %.2f hrs overlap (%d\%), %.2f hrs false alarm (~%d\%)\n", - ($totalActualSpeech/3600.0), - ($totalHypSpeech/3600.0), - ($totalFoundSpeech/3600.0), - ($totalFoundSpeech*100/($totalActualSpeech+0.000001)), - ($totalFalseAlarm/3600.0), - ($totalFalseAlarm*100/($totalNonSpeechEst+0.000001))); -printf STDERR ("\t$numShortSegs segments < 0.2 sec and $numLongSegs segments > 60.0 sec\n"); -printf STDERR ("------------------------------------------------------------------------\n"); diff --git a/egs/babel/s5c/local/resegment/evaluate_segmentation.pl b/egs/babel/s5c/local/resegment/evaluate_segmentation.pl new file mode 120000 index 00000000000..09276466c2b --- /dev/null +++ b/egs/babel/s5c/local/resegment/evaluate_segmentation.pl @@ -0,0 +1 @@ +../../steps/segmentation/evaluate_segmentation.py \ No newline at end of file diff --git a/egs/bn_music_speech/v1/local/run_nnet3_music_id.sh b/egs/bn_music_speech/v1/local/run_nnet3_music_id.sh new file mode 100644 index 00000000000..d96acdabaaa --- /dev/null +++ b/egs/bn_music_speech/v1/local/run_nnet3_music_id.sh @@ -0,0 +1,217 @@ +#!/bin/bash + +set -e +set -o pipefail +set -u + +. path.sh +. cmd.sh + +feat_affix=bp_vh +affix= +reco_nj=32 + +stage=-1 + +# SAD network config +iter=final +extra_left_context=100 # Set to some large value +extra_right_context=20 + + +# Configs +frame_subsampling_factor=1 + +min_silence_duration=3 # minimum number of frames for silence +min_speech_duration=3 # minimum number of frames for speech +min_music_duration=3 # minimum number of frames for music +music_transition_probability=0.1 +sil_transition_probability=0.1 +speech_transition_probability=0.1 +sil_prior=0.3 +speech_prior=0.4 +music_prior=0.3 + +# Decoding options +acwt=1 +beam=10 +max_active=7000 + +mfcc_config=conf/mfcc_hires_bp.conf + +echo $* + +. utils/parse_options.sh + +if [ $# -ne 3 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/bn exp/nnet3_sad_snr/tdnn_j_n4 exp/dnn_music_id" + exit 1 +fi + +# Set to true if the test data has > 8kHz sampling frequency. +do_downsampling=true + +data_dir=$1 +sad_nnet_dir=$2 +dir=$3 + +data_id=`basename $data_dir` + +export PATH="$KALDI_ROOT/tools/sph2pipe_v2.5/:$PATH" +[ ! -z `which sph2pipe` ] + +for f in $sad_nnet_dir/$iter.raw $sad_nnet_dir/post_output-speech.vec $sad_nnet_dir/post_output-music.vec; do + if [ ! -f $f ]; then + echo "$0: Could not find $f. 
See the local/segmentation/run_train_sad.sh" + exit 1 + fi +done + +mkdir -p $dir + +new_data_dir=$dir/${data_id} +if [ $stage -le 0 ]; then + utils/data/convert_data_dir_to_whole.sh $data_dir ${new_data_dir}_whole + + freq=`cat $mfcc_config | perl -pe 's/\s*#.*//g' | grep "sample-frequency=" | awk -F'=' '{if (NF == 0) print 16000; else print $2}'` + sox=`which sox` + + cat $data_dir/wav.scp | python -c "import sys +for line in sys.stdin.readlines(): + splits = line.strip().split() + if splits[-1] == '|': + out_line = line.strip() + ' $sox -t wav - -r $freq -c 1 -b 16 -t wav - downsample |' + else: + out_line = 'cat {0} {1} | $sox -t wav - -r $freq -c 1 -b 16 -t wav - downsample |'.format(splits[0], ' '.join(splits[1:])) + print (out_line)" > ${new_data_dir}_whole/wav.scp + + utils/copy_data_dir.sh ${new_data_dir}_whole ${new_data_dir}_whole_bp_hires +fi + +test_data_dir=${new_data_dir}_whole_bp_hires + +if [ $stage -le 1 ]; then + steps/make_mfcc.sh --mfcc-config $mfcc_config --nj $reco_nj --cmd "$train_cmd" \ + ${new_data_dir}_whole_bp_hires exp/make_hires/${data_id}_whole_bp mfcc_hires + steps/compute_cmvn_stats.sh ${new_data_dir}_whole_bp_hires exp/make_hires/${data_id}_whole_bp mfcc_hires +fi + +if [ $stage -le 2 ]; then + output_name=output-speech + post_vec=$sad_nnet_dir/post_${output_name}.vec + steps/nnet3/compute_output.sh --nj $reco_nj --cmd "$train_cmd" \ + --post-vec "$post_vec" \ + --iter $iter \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --frames-per-chunk 150 \ + --output-name $output_name \ + --frame-subsampling-factor $frame_subsampling_factor \ + --get-raw-nnet-from-am false ${test_data_dir} $sad_nnet_dir $dir/sad_${data_id}_whole_bp +fi + +if [ $stage -le 3 ]; then + output_name=output-music + post_vec=$sad_nnet_dir/post_${output_name}.vec + steps/nnet3/compute_output.sh --nj $reco_nj --cmd "$train_cmd" \ + --post-vec "$post_vec" \ + --iter $iter \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --frames-per-chunk 150 \ + --output-name $output_name \ + --frame-subsampling-factor $frame_subsampling_factor \ + --get-raw-nnet-from-am false ${test_data_dir} $sad_nnet_dir $dir/music_${data_id}_whole_bp +fi + +if [ $stage -le 4 ]; then + $train_cmd JOB=1:$reco_nj $dir/get_average_likes.JOB.log \ + paste-feats \ + "ark:gunzip -c $dir/sad_${data_id}_whole_bp/log_likes.JOB.gz | extract-feature-segments ark:- 'utils/filter_scp.pl -f 2 ${test_data_dir}/split$reco_nj/JOB/utt2spk $data_dir/segments |' ark:- |" \ + "ark:gunzip -c $dir/music_${data_id}_whole_bp/log_likes.JOB.gz | select-feats 1 ark:- ark:- | extract-feature-segments ark:- 'utils/filter_scp.pl -f 2 ${test_data_dir}/split$reco_nj/JOB/utt2spk $data_dir/segments |' ark:- |" \ + ark:- \| \ + matrix-sum-rows --do-average ark:- ark,t:$dir/average_likes.JOB.ark + + for n in `seq $reco_nj`; do + cat $dir/average_likes.$n.ark + done | awk '{print $1" "( exp($3) + exp($5) + 0.01) / (exp($4) + 0.01)}' | \ + local/print_scores.py /dev/stdin | compute-eer - +fi + +lang=$dir/lang + +if [ $stage -le 5 ]; then + mkdir -p $lang + + # Create a lang directory with phones.txt and topo with + # silence, music and speech phones. 
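+  # Each "phone" here is just a class label (1 = silence, 2 = speech, 3 = music).
+  # The min-duration and end-transition-probability options passed below set a
+  # minimum number of frames per class and control how readily the decoder can
+  # leave a class, i.e. the expected duration of each class.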
+ steps/segmentation/internal/prepare_sad_lang.py \ + --phone-transition-parameters="--phone-list=1 --min-duration=$min_silence_duration --end-transition-probability=$sil_transition_probability" \ + --phone-transition-parameters="--phone-list=2 --min-duration=$min_speech_duration --end-transition-probability=$speech_transition_probability" \ + --phone-transition-parameters="--phone-list=3 --min-duration=$min_music_duration --end-transition-probability=$music_transition_probability" \ + $lang + + cp $lang/phones.txt $lang/words.txt +fi + +feat_dim=2 # dummy. We don't need this. +if [ $stage -le 6 ]; then + $train_cmd $dir/log/create_transition_model.log gmm-init-mono \ + $lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 +fi + +# Make unigram G.fst +if [ $stage -le 7 ]; then + cat > $lang/word2prior < $lang/G.fst +fi + +graph_dir=$dir/graph_test + +if [ $stage -le 8 ]; then + $train_cmd $dir/log/make_vad_graph.log \ + steps/segmentation/internal/make_sad_graph.sh --iter trans \ + $lang $dir $dir/graph_test || exit 1 +fi + +seg_dir=$dir/segmentation_${data_id}_whole_bp +mkdir -p $seg_dir + +if [ $stage -le 9 ]; then + decoder_opts+=(--acoustic-scale=$acwt --beam=$beam --max-active=$max_active) + $train_cmd JOB=1:$reco_nj $dir/decode.JOB.log \ + paste-feats \ + "ark:gunzip -c $dir/sad_${data_id}_whole_bp/log_likes.JOB.gz | extract-feature-segments ark:- 'utils/filter_scp.pl -f 2 ${test_data_dir}/split$reco_nj/JOB/utt2spk $data_dir/segments |' ark:- |" \ + "ark:gunzip -c $dir/music_${data_id}_whole_bp/log_likes.JOB.gz | select-feats 1 ark:- ark:- | extract-feature-segments ark:- 'utils/filter_scp.pl -f 2 ${test_data_dir}/split$reco_nj/JOB/utt2spk $data_dir/segments |' ark:- |" \ + ark:- \| decode-faster-mapped ${decoder_opts[@]} \ + $dir/trans.mdl $graph_dir/HCLG.fst ark:- \ + ark:/dev/null ark:- \| \ + ali-to-phones --per-frame $dir/trans.mdl ark:- \ + "ark:|gzip -c > $seg_dir/ali.JOB.gz" +fi + +include_silence=true +if [ $stage -le 10 ]; then + $train_cmd JOB=1:$reco_nj $dir/log/get_class_id.JOB.log \ + ali-to-post "ark:gunzip -c $seg_dir/ali.JOB.gz |" ark:- \| \ + post-to-feats --post-dim=4 ark:- ark:- \| \ + matrix-sum-rows --do-average ark:- ark,t:- \| \ + sid/vector_to_music_labels.pl ${include_silence:+--include-silence-in-music} '>' $dir/ratio.JOB + + for n in `seq $reco_nj`; do + cat $dir/ratio.$n + done > $dir/ratio + + cat $dir/ratio | local/print_scores.py /dev/stdin | compute-eer - +fi + +# LOG (compute-eer:main():compute-eer.cc:136) Equal error rate is 0.860585%, at threshold 1.99361 diff --git a/egs/wsj/s5/steps/data/data_dir_manipulation_lib.py b/egs/wsj/s5/steps/data/data_dir_manipulation_lib.py index 1f7253d4891..7f1a5f74fe2 100644 --- a/egs/wsj/s5/steps/data/data_dir_manipulation_lib.py +++ b/egs/wsj/s5/steps/data/data_dir_manipulation_lib.py @@ -1,4 +1,10 @@ -import subprocess +#!/usr/bin/env python +# Copyright 2016 Tom Ko +# 2016 Vimal Manohar +# Apache 2.0 + +from __future__ import print_function +import subprocess, random, argparse, os, shlex, warnings def RunKaldiCommand(command, wait = True): """ Runs commands frequently seen in Kaldi scripts. 
These are usually a @@ -16,3 +22,415 @@ def RunKaldiCommand(command, wait = True): else: return p +class list_cyclic_iterator: + def __init__(self, list): + self.list_index = 0 + self.list = list + random.shuffle(self.list) + + def next(self): + item = self.list[self.list_index] + self.list_index = (self.list_index + 1) % len(self.list) + return item + +# This functions picks an item from the collection according to the associated probability distribution. +# The probability estimate of each item in the collection is stored in the "probability" field of +# the particular item. x : a collection (list or dictionary) where the values contain a field called probability +def PickItemWithProbability(x): + if isinstance(x, dict): + plist = list(set(x.values())) + else: + plist = x + total_p = sum(item.probability for item in plist) + p = random.uniform(0, total_p) + accumulate_p = 0 + for item in plist: + if accumulate_p + item.probability >= p: + return item + accumulate_p += item.probability + assert False, "Shouldn't get here as the accumulated probability should always equal to 1" + +# This function smooths the probability distribution in the list +def SmoothProbabilityDistribution(list, smoothing_weight=0.0, target_sum=1.0): + if len(list) > 0: + num_unspecified = 0 + accumulated_prob = 0 + for item in list: + if item.probability is None: + num_unspecified += 1 + else: + accumulated_prob += item.probability + + # Compute the probability for the items without specifying their probability + uniform_probability = 0 + if num_unspecified > 0 and accumulated_prob < 1: + uniform_probability = (1 - accumulated_prob) / float(num_unspecified) + elif num_unspecified > 0 and accumulate_prob >= 1: + warnings.warn("The sum of probabilities specified by user is larger than or equal to 1. 
" + "The items without probabilities specified will be given zero to their probabilities.") + + for item in list: + if item.probability is None: + item.probability = uniform_probability + else: + # smooth the probability + item.probability = (1 - smoothing_weight) * item.probability + smoothing_weight * uniform_probability + + # Normalize the probability + sum_p = sum(item.probability for item in list) + for item in list: + item.probability = item.probability / sum_p * target_sum + + return list + +# This function parses a file and pack the data into a dictionary +# It is useful for parsing file like wav.scp, utt2spk, text...etc +def ParseFileToDict(file, assert2fields = False, value_processor = None): + if value_processor is None: + value_processor = lambda x: x[0] + + dict = {} + for line in open(file, 'r'): + parts = line.split() + if assert2fields: + assert(len(parts) == 2) + + dict[parts[0]] = value_processor(parts[1:]) + return dict + +# This function creates a file and write the content of a dictionary into it +def WriteDictToFile(dict, file_name): + file = open(file_name, 'w') + keys = dict.keys() + keys.sort() + for key in keys: + value = dict[key] + if type(value) in [list, tuple] : + if type(value) is tuple: + value = list(value) + value.sort() + value = ' '.join([ str(x) for x in value ]) + file.write('{0} {1}\n'.format(key, value)) + file.close() + + +# This function creates the utt2uniq file from the utterance id in utt2spk file +def CreateCorruptedUtt2uniq(input_dir, output_dir, num_replicas, include_original, prefix): + corrupted_utt2uniq = {} + # Parse the utt2spk to get the utterance id + utt2spk = ParseFileToDict(input_dir + "/utt2spk", value_processor = lambda x: " ".join(x)) + keys = utt2spk.keys() + keys.sort() + if include_original: + start_index = 0 + else: + start_index = 1 + + for i in range(start_index, num_replicas+1): + for utt_id in keys: + new_utt_id = GetNewId(utt_id, prefix, i) + corrupted_utt2uniq[new_utt_id] = utt_id + + WriteDictToFile(corrupted_utt2uniq, output_dir + "/utt2uniq") + +# This function generates a new id from the input id +# This is needed when we have to create multiple copies of the original data +# E.g. 
GetNewId("swb0035", prefix="rvb", copy=1) returns a string "rvb1_swb0035" +def GetNewId(id, prefix=None, copy=0): + if prefix is not None: + new_id = prefix + str(copy) + "_" + id + else: + new_id = id + + return new_id + +# This function replicate the entries in files like segments, utt2spk, text +def AddPrefixToFields(input_file, output_file, num_replicas, include_original, prefix, field = [0]): + list = map(lambda x: x.strip(), open(input_file)) + f = open(output_file, "w") + if include_original: + start_index = 0 + else: + start_index = 1 + + for i in range(start_index, num_replicas+1): + for line in list: + if len(line) > 0 and line[0] != ';': + split1 = line.split() + for j in field: + split1[j] = GetNewId(split1[j], prefix, i) + print(" ".join(split1), file=f) + else: + print(line, file=f) + f.close() + +def CopyDataDirFiles(input_dir, output_dir, num_replicas, include_original, prefix): + if not os.path.isfile(output_dir + "/wav.scp"): + raise Exception("CopyDataDirFiles function expects output_dir to contain wav.scp already") + + AddPrefixToFields(input_dir + "/utt2spk", output_dir + "/utt2spk", num_replicas, include_original, prefix, field = [0,1]) + RunKaldiCommand("utils/utt2spk_to_spk2utt.pl <{output_dir}/utt2spk >{output_dir}/spk2utt" + .format(output_dir = output_dir)) + + if os.path.isfile(input_dir + "/utt2uniq"): + AddPrefixToFields(input_dir + "/utt2uniq", output_dir + "/utt2uniq", num_replicas, include_original, prefix, field =[0]) + else: + # Create the utt2uniq file + CreateCorruptedUtt2uniq(input_dir, output_dir, num_replicas, include_original, prefix) + + if os.path.isfile(input_dir + "/text"): + AddPrefixToFields(input_dir + "/text", output_dir + "/text", num_replicas, prefix, include_original, field =[0]) + if os.path.isfile(input_dir + "/segments"): + AddPrefixToFields(input_dir + "/segments", output_dir + "/segments", num_replicas, prefix, include_original, field = [0,1]) + if os.path.isfile(input_dir + "/reco2file_and_channel"): + AddPrefixToFields(input_dir + "/reco2file_and_channel", output_dir + "/reco2file_and_channel", num_replicas, include_original, prefix, field = [0,1]) + + AddPrefixToFields(input_dir + "/reco2dur", output_dir + "/reco2dur", num_replicas, include_original, prefix, field = [0]) + + RunKaldiCommand("utils/validate_data_dir.sh --no-feats {output_dir}" + .format(output_dir = output_dir)) + + +# This function parse the array of rir set parameter strings. +# It will assign probabilities to those rir sets which don't have a probability +# It will also check the existence of the rir list files. 
+def ParseSetParameterStrings(set_para_array): + set_list = [] + for set_para in set_para_array: + set = lambda: None + setattr(set, "filename", None) + setattr(set, "probability", None) + parts = set_para.split(',') + if len(parts) == 2: + set.probability = float(parts[0]) + set.filename = parts[1].strip() + else: + set.filename = parts[0].strip() + if not os.path.isfile(set.filename): + raise Exception(set.filename + " not found") + set_list.append(set) + + return SmoothProbabilityDistribution(set_list) + + +# This function creates the RIR list +# Each rir object in the list contains the following attributes: +# rir_id, room_id, receiver_position_id, source_position_id, rt60, drr, probability +# Please refer to the help messages in the parser for the meaning of these attributes +def ParseRirList(rir_set_para_array, smoothing_weight, sampling_rate = None): + rir_parser = argparse.ArgumentParser() + rir_parser.add_argument('--rir-id', type=str, required=True, help='This id is unique for each RIR and the noise may associate with a particular RIR by refering to this id') + rir_parser.add_argument('--room-id', type=str, required=True, help='This is the room that where the RIR is generated') + rir_parser.add_argument('--receiver-position-id', type=str, default=None, help='receiver position id') + rir_parser.add_argument('--source-position-id', type=str, default=None, help='source position id') + rir_parser.add_argument('--rt60', type=float, default=None, help='RT60 is the time required for reflections of a direct sound to decay 60 dB.') + rir_parser.add_argument('--drr', type=float, default=None, help='Direct-to-reverberant-ratio of the impulse response.') + rir_parser.add_argument('--cte', type=float, default=None, help='Early-to-late index of the impulse response.') + rir_parser.add_argument('--probability', type=float, default=None, help='probability of the impulse response.') + rir_parser.add_argument('rir_rspecifier', type=str, help="""rir rspecifier, it can be either a filename or a piped command. + E.g. data/impulses/Room001-00001.wav or "sox data/impulses/Room001-00001.wav -t wav - |" """) + + set_list = ParseSetParameterStrings(rir_set_para_array) + + rir_list = [] + for rir_set in set_list: + current_rir_list = map(lambda x: rir_parser.parse_args(shlex.split(x.strip())),open(rir_set.filename)) + for rir in current_rir_list: + if sampling_rate is not None: + # check if the rspecifier is a pipe or not + if len(rir.rir_rspecifier.split()) == 1: + rir.rir_rspecifier = "sox {0} -r {1} -t wav - |".format(rir.rir_rspecifier, sampling_rate) + else: + rir.rir_rspecifier = "{0} sox -t wav - -r {1} -t wav - |".format(rir.rir_rspecifier, sampling_rate) + + rir_list += SmoothProbabilityDistribution(current_rir_list, smoothing_weight, rir_set.probability) + + return rir_list + + +# This dunction checks if the inputs are approximately equal assuming they are floats. +def almost_equal(value_1, value_2, accuracy = 10**-8): + return abs(value_1 - value_2) < accuracy + +# This function converts a list of RIRs into a dictionary of RIRs indexed by the room-id. 
+# Its values are objects with two attributes: a local RIR list +# and the probability of the corresponding room +# Please look at the comments at ParseRirList() for the attributes that a RIR object contains +def MakeRoomDict(rir_list): + room_dict = {} + for rir in rir_list: + if rir.room_id not in room_dict: + # add new room + room_dict[rir.room_id] = lambda: None + setattr(room_dict[rir.room_id], "rir_list", []) + setattr(room_dict[rir.room_id], "probability", 0) + room_dict[rir.room_id].rir_list.append(rir) + + # the probability of the room is the sum of probabilities of its RIR + for key in room_dict.keys(): + room_dict[key].probability = sum(rir.probability for rir in room_dict[key].rir_list) + + assert almost_equal(sum(room_dict[key].probability for key in room_dict.keys()), 1.0) + + return room_dict + + +# This function creates the point-source noise list +# and the isotropic noise dictionary from the noise information file +# The isotropic noise dictionary is indexed by the room +# and its value is the corrresponding isotropic noise list +# Each noise object in the list contains the following attributes: +# noise_id, noise_type, bg_fg_type, room_linkage, probability, noise_rspecifier +# Please refer to the help messages in the parser for the meaning of these attributes +def ParseNoiseList(noise_set_para_array, smoothing_weight, sampling_rate = None): + noise_parser = argparse.ArgumentParser() + noise_parser.add_argument('--noise-id', type=str, required=True, help='noise id') + noise_parser.add_argument('--noise-type', type=str, required=True, help='the type of noise; i.e. isotropic or point-source', choices = ["isotropic", "point-source"]) + noise_parser.add_argument('--bg-fg-type', type=str, default="background", help='background or foreground noise, for background noises, ' + 'they will be extended before addition to cover the whole speech; for foreground noise, they will be kept ' + 'to their original duration and added at a random point of the speech.', choices = ["background", "foreground"]) + noise_parser.add_argument('--room-linkage', type=str, default=None, help='required if isotropic, should not be specified if point-source.') + noise_parser.add_argument('--probability', type=float, default=None, help='probability of the noise.') + noise_parser.add_argument('noise_rspecifier', type=str, help="""noise rspecifier, it can be either a filename or a piped command. + E.g. 
type5_noise_cirline_ofc_ambient1.wav or "sox type5_noise_cirline_ofc_ambient1.wav -t wav - |" """) + + set_list = ParseSetParameterStrings(noise_set_para_array) + + pointsource_noise_list = [] + iso_noise_dict = {} + for noise_set in set_list: + current_noise_list = map(lambda x: noise_parser.parse_args(shlex.split(x.strip())),open(noise_set.filename)) + current_pointsource_noise_list = [] + for noise in current_noise_list: + if sampling_rate is not None: + # check if the rspecifier is a pipe or not + if len(noise.noise_rspecifier.split()) == 1: + noise.noise_rspecifier = "sox {0} -r {1} -t wav - |".format(noise.noise_rspecifier, sampling_rate) + else: + noise.noise_rspecifier = "{0} sox -t wav - -r {1} -t wav - |".format(noise.noise_rspecifier, sampling_rate) + + if noise.noise_type == "isotropic": + if noise.room_linkage is None: + raise Exception("--room-linkage must be specified if --noise-type is isotropic") + else: + if noise.room_linkage not in iso_noise_dict: + iso_noise_dict[noise.room_linkage] = [] + iso_noise_dict[noise.room_linkage].append(noise) + else: + current_pointsource_noise_list.append(noise) + + pointsource_noise_list += SmoothProbabilityDistribution(current_pointsource_noise_list, smoothing_weight, noise_set.probability) + + # ensure the point-source noise probabilities sum to 1 + pointsource_noise_list = SmoothProbabilityDistribution(pointsource_noise_list, smoothing_weight, 1.0) + if len(pointsource_noise_list) > 0: + assert almost_equal(sum(noise.probability for noise in pointsource_noise_list), 1.0) + + # ensure the isotropic noise source probabilities for a given room sum to 1 + for key in iso_noise_dict.keys(): + iso_noise_dict[key] = SmoothProbabilityDistribution(iso_noise_dict[key]) + assert almost_equal(sum(noise.probability for noise in iso_noise_dict[key]), 1.0) + + return (pointsource_noise_list, iso_noise_dict) + +def AddPointSourceNoise(room, # the room selected + pointsource_noise_list, # the point source noise list + pointsource_noise_addition_probability, # Probability of adding point-source noises + foreground_snrs, # the SNR for adding the foreground noises + background_snrs, # the SNR for adding the background noises + speech_dur, # duration of the recording + max_noises_recording, # Maximum number of point-source noises that can be added + noise_addition_descriptor # descriptor to store the information of the noise added + ): + num_noises_added = 0 + if len(pointsource_noise_list) > 0 and random.random() < pointsource_noise_addition_probability and max_noises_recording >= 1: + for k in range(random.randint(1, max_noises_recording)): + num_noises_added = num_noises_added + 1 + # pick the RIR to reverberate the point-source noise + noise = PickItemWithProbability(pointsource_noise_list) + noise_rir = PickItemWithProbability(room.rir_list) + # If it is a background noise, the noise will be extended and be added to the whole speech + # if it is a foreground noise, the noise will not extended and be added at a random time of the speech + if noise.bg_fg_type == "background": + noise_rvb_command = """wav-reverberate --impulse-response="{0}" --duration={1}""".format(noise_rir.rir_rspecifier, speech_dur) + noise_addition_descriptor['start_times'].append(0) + noise_addition_descriptor['snrs'].append(background_snrs.next()) + noise_addition_descriptor['durations'].append(speech_dur) + noise_addition_descriptor['noise_ids'].append(noise.noise_id) + else: + noise_rvb_command = """wav-reverberate --impulse-response="{0}" """.format(noise_rir.rir_rspecifier) + 
noise_addition_descriptor['start_times'].append(round(random.random() * speech_dur, 2))
+                noise_addition_descriptor['snrs'].append(foreground_snrs.next())
+                noise_addition_descriptor['durations'].append(-1)
+                noise_addition_descriptor['noise_ids'].append(noise.noise_id)
+
+            # check if the rspecifier is a pipe or not
+            if len(noise.noise_rspecifier.split()) == 1:
+                noise_addition_descriptor['noise_io'].append("{1} {0} - |".format(noise.noise_rspecifier, noise_rvb_command))
+            else:
+                noise_addition_descriptor['noise_io'].append("{0} {1} - - |".format(noise.noise_rspecifier, noise_rvb_command))
+
+# This function randomly decides whether to reverberate the speech, and samples an RIR if it does.
+# It also decides whether to add the appropriate noises.
+# It returns the impulse-response option string for the wav-reverberate binary
+# together with a descriptor of the additive noises to be added.
+def GenerateReverberationOpts(room_dict,  # the room dictionary, please refer to MakeRoomDict() for the format
+                              pointsource_noise_list,  # the point source noise list
+                              iso_noise_dict,  # the isotropic noise dictionary
+                              foreground_snrs,  # the SNR for adding the foreground noises
+                              background_snrs,  # the SNR for adding the background noises
+                              speech_rvb_probability,  # Probability of reverberating a speech signal
+                              isotropic_noise_addition_probability,  # Probability of adding isotropic noises
+                              pointsource_noise_addition_probability,  # Probability of adding point-source noises
+                              speech_dur,  # duration of the recording
+                              max_noises_recording  # Maximum number of point-source noises that can be added
+                              ):
+    impulse_response_opts = ""
+    additive_noise_opts = ""
+
+    noise_addition_descriptor = {'noise_io': [],
+                                 'start_times': [],
+                                 'snrs': [],
+                                 'noise_ids': [],
+                                 'durations': []
+                                 }
+    # Randomly select the room
+    # Here the room probability is a sum of the probabilities of the RIRs recorded in the room.
+ room = PickItemWithProbability(room_dict) + # Randomly select the RIR in the room + speech_rir = PickItemWithProbability(room.rir_list) + if random.random() < speech_rvb_probability: + # pick the RIR to reverberate the speech + impulse_response_opts = """--impulse-response="{0}" """.format(speech_rir.rir_rspecifier) + + rir_iso_noise_list = [] + if speech_rir.room_id in iso_noise_dict: + rir_iso_noise_list = iso_noise_dict[speech_rir.room_id] + # Add the corresponding isotropic noise associated with the selected RIR + if len(rir_iso_noise_list) > 0 and random.random() < isotropic_noise_addition_probability: + isotropic_noise = PickItemWithProbability(rir_iso_noise_list) + # extend the isotropic noise to the length of the speech waveform + # check if the rspecifier is really a pipe + if len(isotropic_noise.noise_rspecifier.split()) == 1: + noise_addition_descriptor['noise_io'].append("wav-reverberate --duration={1} {0} - |".format(isotropic_noise.noise_rspecifier, speech_dur)) + else: + noise_addition_descriptor['noise_io'].append("{0} wav-reverberate --duration={1} - - |".format(isotropic_noise.noise_rspecifier, speech_dur)) + noise_addition_descriptor['start_times'].append(0) + noise_addition_descriptor['snrs'].append(background_snrs.next()) + noise_addition_descriptor['noise_ids'].append(isotropic_noise.noise_id) + noise_addition_descriptor['durations'].append(speech_dur) + + AddPointSourceNoise(room, # the room selected + pointsource_noise_list, # the point source noise list + pointsource_noise_addition_probability, # Probability of adding point-source noises + foreground_snrs, # the SNR for adding the foreground noises + background_snrs, # the SNR for adding the background noises + speech_dur, # duration of the recording + max_noises_recording, # Maximum number of point-source noises that can be added + noise_addition_descriptor # descriptor to store the information of the noise added + ) + + assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['start_times']) + assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['snrs']) + + return [impulse_response_opts, noise_addition_descriptor] + diff --git a/egs/wsj/s5/steps/data/reverberate_data_dir.py b/egs/wsj/s5/steps/data/reverberate_data_dir.py index 0083efa4939..69bc5e08b3b 100755 --- a/egs/wsj/s5/steps/data/reverberate_data_dir.py +++ b/egs/wsj/s5/steps/data/reverberate_data_dir.py @@ -5,7 +5,7 @@ # we're using python 3.x style print but want it to work in python 2.x, from __future__ import print_function -import argparse, shlex, glob, math, os, random, sys, warnings, copy, imp, ast +import argparse, glob, math, os, random, sys, warnings, copy, imp, ast data_lib = imp.load_source('dml', 'steps/data/data_dir_manipulation_lib.py') @@ -20,7 +20,7 @@ def GetArgs(): "--random-seed 1 data/train data/train_rvb", formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument("--rir-set-parameters", type=str, action='append', required = True, dest = "rir_set_para_array", + parser.add_argument("--rir-set-parameters", type=str, action='append', required = True, dest = "rir_set_para_array", help="Specifies the parameters of an RIR set. " "Supports the specification of mixture_weight and rir_list_file_name. The mixture weight is optional. 
" "The default mixture weight is the probability mass remaining after adding the mixture weights " @@ -71,6 +71,9 @@ def GetArgs(): "the RIRs/noises will be resampled to the rate of the source data.") parser.add_argument("--include-original-data", type=str, help="If true, the output data includes one copy of the original data", choices=['true', 'false'], default = "false") + parser.add_argument("--output-additive-noise-dir", type=str, help="Output directory corresponding to the additive noise part of the data corruption") + parser.add_argument("--output-reverb-dir", type=str, help="Output directory corresponding to the reverberated signal part of the data corruption") + parser.add_argument("input_dir", help="Input data directory") parser.add_argument("output_dir", @@ -87,12 +90,29 @@ def CheckArgs(args): if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) - ## Check arguments + ## Check arguments. + if args.prefix is None: if args.num_replicas > 1 or args.include_original_data == "true": args.prefix = "rvb" warnings.warn("--prefix is set to 'rvb' as more than one copy of data is generated") + if args.output_reverb_dir is not None: + if args.output_reverb_dir == "": + args.output_reverb_dir = None + + if args.output_reverb_dir is not None: + if not os.path.exists(args.output_reverb_dir): + os.makedirs(args.output_reverb_dir) + + if args.output_additive_noise_dir is not None: + if args.output_additive_noise_dir == "": + args.output_additive_noise_dir = None + + if args.output_additive_noise_dir is not None: + if not os.path.exists(args.output_additive_noise_dir): + os.makedirs(args.output_additive_noise_dir) + if not args.num_replicas > 0: raise Exception("--num-replications cannot be non-positive") @@ -104,7 +124,7 @@ def CheckArgs(args): if args.isotropic_noise_addition_probability < 0 or args.isotropic_noise_addition_probability > 1: raise Exception("--isotropic-noise-addition-probability must be between 0 and 1") - + if args.rir_smoothing_weight < 0 or args.rir_smoothing_weight > 1: raise Exception("--rir-smoothing-weight must be between 0 and 1") @@ -113,208 +133,20 @@ def CheckArgs(args): if args.max_noises_per_minute < 0: raise Exception("--max-noises-per-minute cannot be negative") - + if args.source_sampling_rate is not None and args.source_sampling_rate <= 0: raise Exception("--source-sampling-rate cannot be non-positive") return args -class list_cyclic_iterator: - def __init__(self, list): - self.list_index = 0 - self.list = list - random.shuffle(self.list) - - def next(self): - item = self.list[self.list_index] - self.list_index = (self.list_index + 1) % len(self.list) - return item - - -# This functions picks an item from the collection according to the associated probability distribution. -# The probability estimate of each item in the collection is stored in the "probability" field of -# the particular item. 
x : a collection (list or dictionary) where the values contain a field called probability -def PickItemWithProbability(x): - if isinstance(x, dict): - plist = list(set(x.values())) - else: - plist = x - total_p = sum(item.probability for item in plist) - p = random.uniform(0, total_p) - accumulate_p = 0 - for item in plist: - if accumulate_p + item.probability >= p: - return item - accumulate_p += item.probability - assert False, "Shouldn't get here as the accumulated probability should always equal to 1" - - -# This function parses a file and pack the data into a dictionary -# It is useful for parsing file like wav.scp, utt2spk, text...etc -def ParseFileToDict(file, assert2fields = False, value_processor = None): - if value_processor is None: - value_processor = lambda x: x[0] - - dict = {} - for line in open(file, 'r'): - parts = line.split() - if assert2fields: - assert(len(parts) == 2) - - dict[parts[0]] = value_processor(parts[1:]) - return dict - -# This function creates a file and write the content of a dictionary into it -def WriteDictToFile(dict, file_name): - file = open(file_name, 'w') - keys = dict.keys() - keys.sort() - for key in keys: - value = dict[key] - if type(value) in [list, tuple] : - if type(value) is tuple: - value = list(value) - value.sort() - value = ' '.join(str(value)) - file.write('{0} {1}\n'.format(key, value)) - file.close() - - -# This function creates the utt2uniq file from the utterance id in utt2spk file -def CreateCorruptedUtt2uniq(input_dir, output_dir, num_replicas, include_original, prefix): - corrupted_utt2uniq = {} - # Parse the utt2spk to get the utterance id - utt2spk = ParseFileToDict(input_dir + "/utt2spk", value_processor = lambda x: " ".join(x)) - keys = utt2spk.keys() - keys.sort() - if include_original: - start_index = 0 - else: - start_index = 1 - - for i in range(start_index, num_replicas+1): - for utt_id in keys: - new_utt_id = GetNewId(utt_id, prefix, i) - corrupted_utt2uniq[new_utt_id] = utt_id - - WriteDictToFile(corrupted_utt2uniq, output_dir + "/utt2uniq") - - -def AddPointSourceNoise(noise_addition_descriptor, # descriptor to store the information of the noise added - room, # the room selected - pointsource_noise_list, # the point source noise list - pointsource_noise_addition_probability, # Probability of adding point-source noises - foreground_snrs, # the SNR for adding the foreground noises - background_snrs, # the SNR for adding the background noises - speech_dur, # duration of the recording - max_noises_recording # Maximum number of point-source noises that can be added - ): - if len(pointsource_noise_list) > 0 and random.random() < pointsource_noise_addition_probability and max_noises_recording >= 1: - for k in range(random.randint(1, max_noises_recording)): - # pick the RIR to reverberate the point-source noise - noise = PickItemWithProbability(pointsource_noise_list) - noise_rir = PickItemWithProbability(room.rir_list) - # If it is a background noise, the noise will be extended and be added to the whole speech - # if it is a foreground noise, the noise will not extended and be added at a random time of the speech - if noise.bg_fg_type == "background": - noise_rvb_command = """wav-reverberate --impulse-response="{0}" --duration={1}""".format(noise_rir.rir_rspecifier, speech_dur) - noise_addition_descriptor['start_times'].append(0) - noise_addition_descriptor['snrs'].append(background_snrs.next()) - else: - noise_rvb_command = """wav-reverberate --impulse-response="{0}" """.format(noise_rir.rir_rspecifier) - 
noise_addition_descriptor['start_times'].append(round(random.random() * speech_dur, 2)) - noise_addition_descriptor['snrs'].append(foreground_snrs.next()) - - # check if the rspecifier is a pipe or not - if len(noise.noise_rspecifier.split()) == 1: - noise_addition_descriptor['noise_io'].append("{1} {0} - |".format(noise.noise_rspecifier, noise_rvb_command)) - else: - noise_addition_descriptor['noise_io'].append("{0} {1} - - |".format(noise.noise_rspecifier, noise_rvb_command)) - - return noise_addition_descriptor - - -# This function randomly decides whether to reverberate, and sample a RIR if it does -# It also decides whether to add the appropriate noises -# This function return the string of options to the binary wav-reverberate -def GenerateReverberationOpts(room_dict, # the room dictionary, please refer to MakeRoomDict() for the format - pointsource_noise_list, # the point source noise list - iso_noise_dict, # the isotropic noise dictionary - foreground_snrs, # the SNR for adding the foreground noises - background_snrs, # the SNR for adding the background noises - speech_rvb_probability, # Probability of reverberating a speech signal - isotropic_noise_addition_probability, # Probability of adding isotropic noises - pointsource_noise_addition_probability, # Probability of adding point-source noises - speech_dur, # duration of the recording - max_noises_recording # Maximum number of point-source noises that can be added - ): - reverberate_opts = "" - noise_addition_descriptor = {'noise_io': [], - 'start_times': [], - 'snrs': []} - # Randomly select the room - # Here the room probability is a sum of the probabilities of the RIRs recorded in the room. - room = PickItemWithProbability(room_dict) - # Randomly select the RIR in the room - speech_rir = PickItemWithProbability(room.rir_list) - if random.random() < speech_rvb_probability: - # pick the RIR to reverberate the speech - reverberate_opts += """--impulse-response="{0}" """.format(speech_rir.rir_rspecifier) - - rir_iso_noise_list = [] - if speech_rir.room_id in iso_noise_dict: - rir_iso_noise_list = iso_noise_dict[speech_rir.room_id] - # Add the corresponding isotropic noise associated with the selected RIR - if len(rir_iso_noise_list) > 0 and random.random() < isotropic_noise_addition_probability: - isotropic_noise = PickItemWithProbability(rir_iso_noise_list) - # extend the isotropic noise to the length of the speech waveform - # check if the rspecifier is a pipe or not - if len(isotropic_noise.noise_rspecifier.split()) == 1: - noise_addition_descriptor['noise_io'].append("wav-reverberate --duration={1} {0} - |".format(isotropic_noise.noise_rspecifier, speech_dur)) - else: - noise_addition_descriptor['noise_io'].append("{0} wav-reverberate --duration={1} - - |".format(isotropic_noise.noise_rspecifier, speech_dur)) - noise_addition_descriptor['start_times'].append(0) - noise_addition_descriptor['snrs'].append(background_snrs.next()) - - noise_addition_descriptor = AddPointSourceNoise(noise_addition_descriptor, # descriptor to store the information of the noise added - room, # the room selected - pointsource_noise_list, # the point source noise list - pointsource_noise_addition_probability, # Probability of adding point-source noises - foreground_snrs, # the SNR for adding the foreground noises - background_snrs, # the SNR for adding the background noises - speech_dur, # duration of the recording - max_noises_recording # Maximum number of point-source noises that can be added - ) - - assert len(noise_addition_descriptor['noise_io']) 
== len(noise_addition_descriptor['start_times']) - assert len(noise_addition_descriptor['noise_io']) == len(noise_addition_descriptor['snrs']) - if len(noise_addition_descriptor['noise_io']) > 0: - reverberate_opts += "--additive-signals='{0}' ".format(','.join(noise_addition_descriptor['noise_io'])) - reverberate_opts += "--start-times='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['start_times']))) - reverberate_opts += "--snrs='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['snrs']))) - - return reverberate_opts - -# This function generates a new id from the input id -# This is needed when we have to create multiple copies of the original data -# E.g. GetNewId("swb0035", prefix="rvb", copy=1) returns a string "rvb1_swb0035" -def GetNewId(id, prefix=None, copy=0): - if prefix is not None: - new_id = prefix + str(copy) + "_" + id - else: - new_id = id - - return new_id - - # This is the main function to generate pipeline command for the corruption # The generic command of wav-reverberate will be like: -# wav-reverberate --duration=t --impulse-response=rir.wav +# wav-reverberate --duration=t --impulse-response=rir.wav # --additive-signals='noise1.wav,noise2.wav' --snrs='snr1,snr2' --start-times='s1,s2' input.wav output.wav def GenerateReverberatedWavScp(wav_scp, # a dictionary whose values are the Kaldi-IO strings of the speech recordings durations, # a dictionary whose values are the duration (in sec) of the speech recordings - output_dir, # output directory to write the corrupted wav.scp + output_dir, # output directory to write the corrupted wav.scp room_dict, # the room dictionary, please refer to MakeRoomDict() for the format pointsource_noise_list, # the point source noise list iso_noise_dict, # the isotropic noise dictionary @@ -327,13 +159,20 @@ def GenerateReverberatedWavScp(wav_scp, # a dictionary whose values are the Kal shift_output, # option whether to shift the output waveform isotropic_noise_addition_probability, # Probability of adding isotropic noises pointsource_noise_addition_probability, # Probability of adding point-source noises - max_noises_per_minute # maximum number of point-source noises that can be added to a recording according to its duration + max_noises_per_minute, # maximum number of point-source noises that can be added to a recording according to its duration + output_reverb_dir = None, + output_additive_noise_dir = None ): - foreground_snrs = list_cyclic_iterator(foreground_snr_array) - background_snrs = list_cyclic_iterator(background_snr_array) + foreground_snrs = data_lib.list_cyclic_iterator(foreground_snr_array) + background_snrs = data_lib.list_cyclic_iterator(background_snr_array) corrupted_wav_scp = {} + reverb_wav_scp = {} + additive_noise_wav_scp = {} keys = wav_scp.keys() keys.sort() + + additive_signals_info = {} + if include_original: start_index = 0 else: @@ -346,51 +185,71 @@ def GenerateReverberatedWavScp(wav_scp, # a dictionary whose values are the Kal if len(wav_original_pipe.split()) == 1: wav_original_pipe = "cat {0} |".format(wav_original_pipe) speech_dur = durations[recording_id] - max_noises_recording = math.floor(max_noises_per_minute * speech_dur / 60) - - reverberate_opts = GenerateReverberationOpts(room_dict, # the room dictionary, please refer to MakeRoomDict() for the format - pointsource_noise_list, # the point source noise list - iso_noise_dict, # the isotropic noise dictionary - foreground_snrs, # the SNR for adding the foreground noises - background_snrs, # the SNR for adding 
the background noises - speech_rvb_probability, # Probability of reverberating a speech signal - isotropic_noise_addition_probability, # Probability of adding isotropic noises - pointsource_noise_addition_probability, # Probability of adding point-source noises - speech_dur, # duration of the recording - max_noises_recording # Maximum number of point-source noises that can be added - ) + max_noises_recording = math.ceil(max_noises_per_minute * speech_dur / 60) + + [impulse_response_opts, noise_addition_descriptor] = data_lib.GenerateReverberationOpts(room_dict, # the room dictionary, please refer to MakeRoomDict() for the format + pointsource_noise_list, # the point source noise list + iso_noise_dict, # the isotropic noise dictionary + foreground_snrs, # the SNR for adding the foreground noises + background_snrs, # the SNR for adding the background noises + speech_rvb_probability, # Probability of reverberating a speech signal + isotropic_noise_addition_probability, # Probability of adding isotropic noises + pointsource_noise_addition_probability, # Probability of adding point-source noises + speech_dur, # duration of the recording + max_noises_recording # Maximum number of point-source noises that can be added + ) + additive_noise_opts = "" + + if len(noise_addition_descriptor['noise_io']) > 0: + additive_noise_opts += "--additive-signals='{0}' ".format(','.join(noise_addition_descriptor['noise_io'])) + additive_noise_opts += "--start-times='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['start_times']))) + additive_noise_opts += "--snrs='{0}' ".format(','.join(map(lambda x:str(x), noise_addition_descriptor['snrs']))) + + reverberate_opts = impulse_response_opts + additive_noise_opts + + new_recording_id = data_lib.GetNewId(recording_id, prefix, i) # prefix using index 0 is reserved for original data e.g. 
rvb0_swb0035 corresponds to the swb0035 recording in original data if reverberate_opts == "" or i == 0: - wav_corrupted_pipe = "{0}".format(wav_original_pipe) + wav_corrupted_pipe = "{0}".format(wav_original_pipe) else: wav_corrupted_pipe = "{0} wav-reverberate --shift-output={1} {2} - - |".format(wav_original_pipe, shift_output, reverberate_opts) - new_recording_id = GetNewId(recording_id, prefix, i) corrupted_wav_scp[new_recording_id] = wav_corrupted_pipe - WriteDictToFile(corrupted_wav_scp, output_dir + "/wav.scp") + if output_reverb_dir is not None: + if impulse_response_opts == "": + wav_reverb_pipe = "{0}".format(wav_original_pipe) + else: + wav_reverb_pipe = "{0} wav-reverberate --shift-output={1} --reverb-out-wxfilename=- {2} - /dev/null |".format(wav_original_pipe, shift_output, reverberate_opts) + reverb_wav_scp[new_recording_id] = wav_reverb_pipe + if output_additive_noise_dir is not None: + if additive_noise_opts != "": + wav_additive_noise_pipe = "{0} wav-reverberate --shift-output={1} --additive-noise-out-wxfilename=- {2} - /dev/null |".format(wav_original_pipe, shift_output, reverberate_opts) + additive_noise_wav_scp[new_recording_id] = wav_additive_noise_pipe -# This function replicate the entries in files like segments, utt2spk, text -def AddPrefixToFields(input_file, output_file, num_replicas, include_original, prefix, field = [0]): - list = map(lambda x: x.strip(), open(input_file)) - f = open(output_file, "w") - if include_original: - start_index = 0 - else: - start_index = 1 - - for i in range(start_index, num_replicas+1): - for line in list: - if len(line) > 0 and line[0] != ';': - split1 = line.split() - for j in field: - split1[j] = GetNewId(split1[j], prefix, i) - print(" ".join(split1), file=f) - else: - print(line, file=f) - f.close() + if additive_noise_opts != "": + additive_signals_info[new_recording_id] = [ + ':'.join(x) + for x in zip(noise_addition_descriptor['noise_ids'], + [ str(x) for x in noise_addition_descriptor['start_times'] ], + [ str(x) for x in noise_addition_descriptor['durations'] ]) + ] + + # Write for each new recording, the id, start time and durations + # of the signals. Duration is -1 for the foreground noise and needs to + # be extracted separately if required by determining the durations + # using the wav file + data_lib.WriteDictToFile(additive_signals_info, output_dir + "/additive_signals_info.txt") + + data_lib.WriteDictToFile(corrupted_wav_scp, output_dir + "/wav.scp") + + if output_reverb_dir is not None: + data_lib.WriteDictToFile(reverb_wav_scp, output_reverb_dir + "/wav.scp") + + if output_additive_noise_dir is not None: + data_lib.WriteDictToFile(additive_noise_wav_scp, output_additive_noise_dir + "/wav.scp") # This function creates multiple copies of the necessary files, e.g. utt2spk, wav.scp ... 
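
The hunk above turns the lists in noise_addition_descriptor into wav-reverberate options for the corrupted wav.scp and writes one line per corrupted recording to additive_signals_info.txt. The following is an illustrative sketch only, not part of the patch; the noise file, RIR file, SNR, times and recording id are made-up examples (the id follows the rvb1_swb0035 naming convention used above).

    # How the descriptor lists become wav-reverberate options (illustrative values).
    noise_addition_descriptor = {
        'noise_io': ['wav-reverberate --duration=10.0 noise1.wav - |'],
        'start_times': [0],
        'snrs': [15],
        'noise_ids': ['noise1'],
        'durations': [10.0],
    }
    additive_noise_opts = (
        "--additive-signals='{0}' ".format(','.join(noise_addition_descriptor['noise_io']))
        + "--start-times='{0}' ".format(','.join(str(t) for t in noise_addition_descriptor['start_times']))
        + "--snrs='{0}' ".format(','.join(str(s) for s in noise_addition_descriptor['snrs'])))

    # The corrupted wav.scp entry then looks roughly like (one line):
    #   rvb1_swb0035 cat swb0035.wav | wav-reverberate --shift-output=true \
    #       --impulse-response="rir1.wav" --additive-signals='...' --start-times='0' --snrs='15' - - |
    # and the matching additive_signals_info.txt entry (noise_id:start_time:duration,
    # with duration -1 for foreground noises) is:
    #   rvb1_swb0035 noise1:0:10.0
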
@@ -408,10 +267,12 @@ def CreateReverberatedCopy(input_dir, shift_output, # option whether to shift the output waveform isotropic_noise_addition_probability, # Probability of adding isotropic noises pointsource_noise_addition_probability, # Probability of adding point-source noises - max_noises_per_minute # maximum number of point-source noises that can be added to a recording according to its duration + max_noises_per_minute, # maximum number of point-source noises that can be added to a recording according to its duration + output_reverb_dir = None, + output_additive_noise_dir = None ): - - wav_scp = ParseFileToDict(input_dir + "/wav.scp", value_processor = lambda x: " ".join(x)) + + wav_scp = data_lib.ParseFileToDict(input_dir + "/wav.scp", value_processor = lambda x: " ".join(x)) if not os.path.isfile(input_dir + "/reco2dur"): print("Getting the duration of the recordings..."); read_entire_file="false" @@ -421,225 +282,38 @@ def CreateReverberatedCopy(input_dir, read_entire_file="true" break data_lib.RunKaldiCommand("wav-to-duration --read-entire-file={1} scp:{0}/wav.scp ark,t:{0}/reco2dur".format(input_dir, read_entire_file)) - durations = ParseFileToDict(input_dir + "/reco2dur", value_processor = lambda x: float(x[0])) + durations = data_lib.ParseFileToDict(input_dir + "/reco2dur", value_processor = lambda x: float(x[0])) foreground_snr_array = map(lambda x: float(x), foreground_snr_string.split(':')) background_snr_array = map(lambda x: float(x), background_snr_string.split(':')) GenerateReverberatedWavScp(wav_scp, durations, output_dir, room_dict, pointsource_noise_list, iso_noise_dict, - foreground_snr_array, background_snr_array, num_replicas, include_original, prefix, - speech_rvb_probability, shift_output, isotropic_noise_addition_probability, - pointsource_noise_addition_probability, max_noises_per_minute) + foreground_snr_array, background_snr_array, num_replicas, include_original, prefix, + speech_rvb_probability, shift_output, isotropic_noise_addition_probability, + pointsource_noise_addition_probability, max_noises_per_minute, + output_reverb_dir = output_reverb_dir, + output_additive_noise_dir = output_additive_noise_dir) - AddPrefixToFields(input_dir + "/utt2spk", output_dir + "/utt2spk", num_replicas, include_original, prefix, field = [0,1]) - data_lib.RunKaldiCommand("utils/utt2spk_to_spk2utt.pl <{output_dir}/utt2spk >{output_dir}/spk2utt" - .format(output_dir = output_dir)) + data_lib.CopyDataDirFiles(input_dir, output_dir, num_replicas, include_original, prefix) - if os.path.isfile(input_dir + "/utt2uniq"): - AddPrefixToFields(input_dir + "/utt2uniq", output_dir + "/utt2uniq", num_replicas, include_original, prefix, field =[0]) - else: - # Create the utt2uniq file - CreateCorruptedUtt2uniq(input_dir, output_dir, num_replicas, include_original, prefix) - - if os.path.isfile(input_dir + "/text"): - AddPrefixToFields(input_dir + "/text", output_dir + "/text", num_replicas, include_original, prefix, field =[0]) - if os.path.isfile(input_dir + "/segments"): - AddPrefixToFields(input_dir + "/segments", output_dir + "/segments", num_replicas, include_original, prefix, field = [0,1]) - if os.path.isfile(input_dir + "/reco2file_and_channel"): - AddPrefixToFields(input_dir + "/reco2file_and_channel", output_dir + "/reco2file_and_channel", num_replicas, include_original, prefix, field = [0,1]) - - data_lib.RunKaldiCommand("utils/validate_data_dir.sh --no-feats {output_dir}" - .format(output_dir = output_dir)) - - -# This function smooths the probability distribution in the list 
-def SmoothProbabilityDistribution(list, smoothing_weight=0.0, target_sum=1.0): - if len(list) > 0: - num_unspecified = 0 - accumulated_prob = 0 - for item in list: - if item.probability is None: - num_unspecified += 1 - else: - accumulated_prob += item.probability - - # Compute the probability for the items without specifying their probability - uniform_probability = 0 - if num_unspecified > 0 and accumulated_prob < 1: - uniform_probability = (1 - accumulated_prob) / float(num_unspecified) - elif num_unspecified > 0 and accumulate_prob >= 1: - warnings.warn("The sum of probabilities specified by user is larger than or equal to 1. " - "The items without probabilities specified will be given zero to their probabilities.") - - for item in list: - if item.probability is None: - item.probability = uniform_probability - else: - # smooth the probability - item.probability = (1 - smoothing_weight) * item.probability + smoothing_weight * uniform_probability - - # Normalize the probability - sum_p = sum(item.probability for item in list) - for item in list: - item.probability = item.probability / sum_p * target_sum - - return list - - -# This function parse the array of rir set parameter strings. -# It will assign probabilities to those rir sets which don't have a probability -# It will also check the existence of the rir list files. -def ParseSetParameterStrings(set_para_array): - set_list = [] - for set_para in set_para_array: - set = lambda: None - setattr(set, "filename", None) - setattr(set, "probability", None) - parts = set_para.split(',') - if len(parts) == 2: - set.probability = float(parts[0]) - set.filename = parts[1].strip() - else: - set.filename = parts[0].strip() - if not os.path.isfile(set.filename): - raise Exception(set.filename + " not found") - set_list.append(set) - - return SmoothProbabilityDistribution(set_list) - - -# This function creates the RIR list -# Each rir object in the list contains the following attributes: -# rir_id, room_id, receiver_position_id, source_position_id, rt60, drr, probability -# Please refer to the help messages in the parser for the meaning of these attributes -def ParseRirList(rir_set_para_array, smoothing_weight, sampling_rate = None): - rir_parser = argparse.ArgumentParser() - rir_parser.add_argument('--rir-id', type=str, required=True, help='This id is unique for each RIR and the noise may associate with a particular RIR by refering to this id') - rir_parser.add_argument('--room-id', type=str, required=True, help='This is the room that where the RIR is generated') - rir_parser.add_argument('--receiver-position-id', type=str, default=None, help='receiver position id') - rir_parser.add_argument('--source-position-id', type=str, default=None, help='source position id') - rir_parser.add_argument('--rt60', type=float, default=None, help='RT60 is the time required for reflections of a direct sound to decay 60 dB.') - rir_parser.add_argument('--drr', type=float, default=None, help='Direct-to-reverberant-ratio of the impulse response.') - rir_parser.add_argument('--cte', type=float, default=None, help='Early-to-late index of the impulse response.') - rir_parser.add_argument('--probability', type=float, default=None, help='probability of the impulse response.') - rir_parser.add_argument('rir_rspecifier', type=str, help="""rir rspecifier, it can be either a filename or a piped command. - E.g. 
data/impulses/Room001-00001.wav or "sox data/impulses/Room001-00001.wav -t wav - |" """) - - set_list = ParseSetParameterStrings(rir_set_para_array) - - rir_list = [] - for rir_set in set_list: - current_rir_list = map(lambda x: rir_parser.parse_args(shlex.split(x.strip())),open(rir_set.filename)) - for rir in current_rir_list: - if sampling_rate is not None: - # check if the rspecifier is a pipe or not - if len(rir.rir_rspecifier.split()) == 1: - rir.rir_rspecifier = "sox {0} -r {1} -t wav - |".format(rir.rir_rspecifier, sampling_rate) - else: - rir.rir_rspecifier = "{0} sox -t wav - -r {1} -t wav - |".format(rir.rir_rspecifier, sampling_rate) - - rir_list += SmoothProbabilityDistribution(current_rir_list, smoothing_weight, rir_set.probability) - - return rir_list - - -# This dunction checks if the inputs are approximately equal assuming they are floats. -def almost_equal(value_1, value_2, accuracy = 10**-8): - return abs(value_1 - value_2) < accuracy - -# This function converts a list of RIRs into a dictionary of RIRs indexed by the room-id. -# Its values are objects with two attributes: a local RIR list -# and the probability of the corresponding room -# Please look at the comments at ParseRirList() for the attributes that a RIR object contains -def MakeRoomDict(rir_list): - room_dict = {} - for rir in rir_list: - if rir.room_id not in room_dict: - # add new room - room_dict[rir.room_id] = lambda: None - setattr(room_dict[rir.room_id], "rir_list", []) - setattr(room_dict[rir.room_id], "probability", 0) - room_dict[rir.room_id].rir_list.append(rir) - - # the probability of the room is the sum of probabilities of its RIR - for key in room_dict.keys(): - room_dict[key].probability = sum(rir.probability for rir in room_dict[key].rir_list) - - assert almost_equal(sum(room_dict[key].probability for key in room_dict.keys()), 1.0) - - return room_dict - - -# This function creates the point-source noise list -# and the isotropic noise dictionary from the noise information file -# The isotropic noise dictionary is indexed by the room -# and its value is the corrresponding isotropic noise list -# Each noise object in the list contains the following attributes: -# noise_id, noise_type, bg_fg_type, room_linkage, probability, noise_rspecifier -# Please refer to the help messages in the parser for the meaning of these attributes -def ParseNoiseList(noise_set_para_array, smoothing_weight, sampling_rate = None): - noise_parser = argparse.ArgumentParser() - noise_parser.add_argument('--noise-id', type=str, required=True, help='noise id') - noise_parser.add_argument('--noise-type', type=str, required=True, help='the type of noise; i.e. isotropic or point-source', choices = ["isotropic", "point-source"]) - noise_parser.add_argument('--bg-fg-type', type=str, default="background", help='background or foreground noise, for background noises, ' - 'they will be extended before addition to cover the whole speech; for foreground noise, they will be kept ' - 'to their original duration and added at a random point of the speech.', choices = ["background", "foreground"]) - noise_parser.add_argument('--room-linkage', type=str, default=None, help='required if isotropic, should not be specified if point-source.') - noise_parser.add_argument('--probability', type=float, default=None, help='probability of the noise.') - noise_parser.add_argument('noise_rspecifier', type=str, help="""noise rspecifier, it can be either a filename or a piped command. - E.g. 
type5_noise_cirline_ofc_ambient1.wav or "sox type5_noise_cirline_ofc_ambient1.wav -t wav - |" """) - - set_list = ParseSetParameterStrings(noise_set_para_array) - - pointsource_noise_list = [] - iso_noise_dict = {} - for noise_set in set_list: - current_noise_list = map(lambda x: noise_parser.parse_args(shlex.split(x.strip())),open(noise_set.filename)) - current_pointsource_noise_list = [] - for noise in current_noise_list: - if sampling_rate is not None: - # check if the rspecifier is a pipe or not - if len(noise.noise_rspecifier.split()) == 1: - noise.noise_rspecifier = "sox {0} -r {1} -t wav - |".format(noise.noise_rspecifier, sampling_rate) - else: - noise.noise_rspecifier = "{0} sox -t wav - -r {1} -t wav - |".format(noise.noise_rspecifier, sampling_rate) + if output_reverb_dir is not None: + data_lib.CopyDataDirFiles(input_dir, output_reverb_dir, num_replicas, include_original, prefix) - if noise.noise_type == "isotropic": - if noise.room_linkage is None: - raise Exception("--room-linkage must be specified if --noise-type is isotropic") - else: - if noise.room_linkage not in iso_noise_dict: - iso_noise_dict[noise.room_linkage] = [] - iso_noise_dict[noise.room_linkage].append(noise) - else: - current_pointsource_noise_list.append(noise) - - pointsource_noise_list += SmoothProbabilityDistribution(current_pointsource_noise_list, smoothing_weight, noise_set.probability) - - # ensure the point-source noise probabilities sum to 1 - pointsource_noise_list = SmoothProbabilityDistribution(pointsource_noise_list, smoothing_weight, 1.0) - if len(pointsource_noise_list) > 0: - assert almost_equal(sum(noise.probability for noise in pointsource_noise_list), 1.0) - - # ensure the isotropic noise source probabilities for a given room sum to 1 - for key in iso_noise_dict.keys(): - iso_noise_dict[key] = SmoothProbabilityDistribution(iso_noise_dict[key]) - assert almost_equal(sum(noise.probability for noise in iso_noise_dict[key]), 1.0) - - return (pointsource_noise_list, iso_noise_dict) + if output_additive_noise_dir is not None: + data_lib.CopyDataDirFiles(input_dir, output_additive_noise_dir, num_replicas, include_original, prefix) def Main(): args = GetArgs() random.seed(args.random_seed) - rir_list = ParseRirList(args.rir_set_para_array, args.rir_smoothing_weight, args.source_sampling_rate) + rir_list = data_lib.ParseRirList(args.rir_set_para_array, args.rir_smoothing_weight, args.source_sampling_rate) print("Number of RIRs is {0}".format(len(rir_list))) pointsource_noise_list = [] iso_noise_dict = {} if args.noise_set_para_array is not None: - pointsource_noise_list, iso_noise_dict = ParseNoiseList(args.noise_set_para_array, args.noise_smoothing_weight, args.source_sampling_rate) + pointsource_noise_list, iso_noise_dict = data_lib.ParseNoiseList(args.noise_set_para_array, args.noise_smoothing_weight, args.source_sampling_rate) print("Number of point-source noises is {0}".format(len(pointsource_noise_list))) print("Number of isotropic noises is {0}".format(sum(len(iso_noise_dict[key]) for key in iso_noise_dict.keys()))) - room_dict = MakeRoomDict(rir_list) + room_dict = data_lib.MakeRoomDict(rir_list) if args.include_original_data == "true": include_original = True @@ -660,8 +334,11 @@ def Main(): shift_output = args.shift_output, isotropic_noise_addition_probability = args.isotropic_noise_addition_probability, pointsource_noise_addition_probability = args.pointsource_noise_addition_probability, - max_noises_per_minute = args.max_noises_per_minute) + max_noises_per_minute = 
args.max_noises_per_minute, + output_reverb_dir = args.output_reverb_dir, + output_additive_noise_dir = args.output_additive_noise_dir) if __name__ == "__main__": Main() + diff --git a/egs/wsj/s5/steps/libs/common.py b/egs/wsj/s5/steps/libs/common.py index 1e0608525ba..f2a336cd640 100644 --- a/egs/wsj/s5/steps/libs/common.py +++ b/egs/wsj/s5/steps/libs/common.py @@ -315,6 +315,7 @@ def split_data(data, num_jobs): run_kaldi_command("utils/split_data.sh {data} {num_jobs}".format( data=data, num_jobs=num_jobs)) + return "{0}/split{1}".format(data, num_jobs) def read_kaldi_matrix(matrix_file): diff --git a/egs/wsj/s5/steps/libs/data.py b/egs/wsj/s5/steps/libs/data.py new file mode 100644 index 00000000000..44895cae1a4 --- /dev/null +++ b/egs/wsj/s5/steps/libs/data.py @@ -0,0 +1,57 @@ +import os + +import libs.common as common_lib + +def get_frame_shift(data_dir): + frame_shift = common_lib.run_kaldi_command("utils/data/get_frame_shift.sh {0}".format(data_dir))[0] + return float(frame_shift.strip()) + +def generate_utt2dur(data_dir): + common_lib.run_kaldi_command("utils/data/get_utt2dur.sh {0}".format(data_dir)) + +def get_utt2dur(data_dir): + generate_utt2dur(data_dir) + utt2dur = {} + for line in open('{0}/utt2dur'.format(data_dir), 'r').readlines(): + parts = line.split() + utt2dur[parts[0]] = float(parts[1]) + return utt2dur + +def get_utt2uniq(data_dir): + utt2uniq_file = '{0}/utt2uniq'.format(data_dir) + if not os.path.exists(utt2uniq_file): + return None, None + utt2uniq = {} + uniq2utt = {} + for line in open(utt2uniq_file, 'r').readlines(): + parts = line.split() + utt2uniq[parts[0]] = parts[1] + if uniq2utt.has_key(parts[1]): + uniq2utt[parts[1]].append(parts[0]) + else: + uniq2utt[parts[1]] = [parts[0]] + return utt2uniq, uniq2utt + +def get_num_frames(data_dir, utts = None): + generate_utt2dur(data_dir) + frame_shift = get_frame_shift(data_dir) + total_duration = 0 + utt2dur = get_utt2dur(data_dir) + if utts is None: + utts = utt2dur.keys() + for utt in utts: + total_duration = total_duration + utt2dur[utt] + return int(float(total_duration)/frame_shift) + +def create_data_links(file_names): + # if file_names already exist create_data_link.pl returns with code 1 + # so we just delete them before calling create_data_link.pl + for file_name in file_names: + try_to_delete(file_name) + common_lib.run_kaldi_command(" utils/create_data_link.pl {0}".format(" ".join(file_names))) + +def try_to_delete(file_name): + try: + os.remove(file_name) + except OSError: + pass diff --git a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py index 87cae801e90..55c13799b0f 100644 --- a/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py +++ b/egs/wsj/s5/steps/libs/nnet3/train/frame_level_objf/common.py @@ -30,7 +30,9 @@ def train_new_models(dir, iter, srand, num_jobs, shuffle_buffer_size, minibatch_size, cache_read_opt, run_opts, frames_per_eg=-1, - min_deriv_time=None, max_deriv_time=None): + min_deriv_time=None, max_deriv_time=None, + min_left_context=None, min_right_context=None, + extra_egs_copy_cmd=""): """ Called from train_one_iteration(), this model does one iteration of training with 'num_jobs' jobs, and writes files like exp/tdnn_a/24.{1,2,3,..}.raw @@ -63,6 +65,17 @@ def train_new_models(dir, iter, srand, num_jobs, deriv_time_opts.append("--optimization.max-deriv-time={0}".format( max_deriv_time)) + this_random = random.Random(srand + iter) + + if min_left_context is not None: + left_context = 
this_random.randint(min_left_context, left_context) + + if min_right_context is not None: + right_context = this_random.randint(min_right_context, right_context) + + logger.info("On iteration %d, left-context=%d and right-context=%s", + iter, left_context, right_context) + context_opts = "--left-context={0} --right-context={1}".format( left_context, right_context) @@ -92,7 +105,7 @@ def train_new_models(dir, iter, srand, num_jobs, --max-param-change={max_param_change} \ {deriv_time_opts} "{raw_model}" \ "ark,bg:nnet3-copy-egs {frame_opts} {context_opts} """ - """ark:{egs_dir}/egs.{archive_index}.ark ark:- |""" + """ark:{egs_dir}/egs.{archive_index}.ark ark:- |{extra_egs_copy_cmd}""" """nnet3-shuffle-egs --buffer-size={shuffle_buffer_size} """ """--srand={srand} ark:- ark:- | """ """nnet3-merge-egs --minibatch-size={minibatch_size} """ @@ -115,7 +128,9 @@ def train_new_models(dir, iter, srand, num_jobs, raw_model=raw_model_string, context_opts=context_opts, egs_dir=egs_dir, archive_index=archive_index, shuffle_buffer_size=shuffle_buffer_size, - minibatch_size=minibatch_size), wait=False) + minibatch_size=minibatch_size, + extra_egs_copy_cmd=extra_egs_copy_cmd), + wait=False) processes.append(process_handle) @@ -141,9 +156,11 @@ def train_one_iteration(dir, iter, srand, egs_dir, run_opts, cv_minibatch_size=256, frames_per_eg=-1, min_deriv_time=None, max_deriv_time=None, + min_left_context=None, min_right_context=None, shrinkage_value=1.0, get_raw_nnet_from_am=True, - background_process_handler=None): + background_process_handler=None, + extra_egs_copy_cmd=""): """ Called from steps/nnet3/train_*.py scripts for one iteration of neural network training @@ -192,7 +209,8 @@ def train_one_iteration(dir, iter, srand, egs_dir, run_opts=run_opts, mb_size=cv_minibatch_size, get_raw_nnet_from_am=get_raw_nnet_from_am, wait=False, - background_process_handler=background_process_handler) + background_process_handler=background_process_handler, + extra_egs_copy_cmd=extra_egs_copy_cmd) if iter > 0: # Runs in the background @@ -202,7 +220,8 @@ def train_one_iteration(dir, iter, srand, egs_dir, run_opts=run_opts, mb_size=cv_minibatch_size, wait=False, get_raw_nnet_from_am=get_raw_nnet_from_am, - background_process_handler=background_process_handler) + background_process_handler=background_process_handler, + extra_egs_copy_cmd=extra_egs_copy_cmd) # an option for writing cache (storing pairs of nnet-computations # and computation-requests) during training. 
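
For reference, here is a minimal self-contained sketch, not part of the patch, of the per-iteration context sampling that the hunk above adds to train_new_models(); the seed, iteration number and context ranges are arbitrary example values.

    import random

    srand, iter = 0, 24                          # training seed and iteration number
    min_left_context, left_context = 10, 40      # sample left context in [10, 40]
    min_right_context, right_context = 10, 40    # sample right context in [10, 40]

    # Seeding with srand + iter makes the sampled contexts reproducible if the
    # same iteration is re-run.
    this_random = random.Random(srand + iter)
    left_context = this_random.randint(min_left_context, left_context)
    right_context = this_random.randint(min_right_context, right_context)

    context_opts = "--left-context={0} --right-context={1}".format(
        left_context, right_context)
    print(context_opts)
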
@@ -276,7 +295,10 @@ def train_one_iteration(dir, iter, srand, egs_dir, cache_read_opt=cache_read_opt, run_opts=run_opts, frames_per_eg=frames_per_eg, min_deriv_time=min_deriv_time, - max_deriv_time=max_deriv_time) + max_deriv_time=max_deriv_time, + min_left_context=min_left_context, + min_right_context=min_right_context, + extra_egs_copy_cmd=extra_egs_copy_cmd) [models_to_average, best_model] = common_train_lib.get_successful_models( num_jobs, '{0}/log/train.{1}.%.log'.format(dir, iter)) @@ -375,7 +397,8 @@ def compute_preconditioning_matrix(dir, egs_dir, num_lda_jobs, run_opts, def compute_train_cv_probabilities(dir, iter, egs_dir, left_context, right_context, run_opts, mb_size=256, wait=False, background_process_handler=None, - get_raw_nnet_from_am=True): + get_raw_nnet_from_am=True, + extra_egs_copy_cmd=""): if get_raw_nnet_from_am: model = "nnet3-am-copy --raw=true {dir}/{iter}.mdl - |".format( dir=dir, iter=iter) @@ -389,7 +412,7 @@ def compute_train_cv_probabilities(dir, iter, egs_dir, left_context, """ {command} {dir}/log/compute_prob_valid.{iter}.log \ nnet3-compute-prob "{model}" \ "ark,bg:nnet3-copy-egs {context_opts} \ - ark:{egs_dir}/valid_diagnostic.egs ark:- | \ + ark:{egs_dir}/valid_diagnostic.egs ark:- |{extra_egs_copy_cmd} \ nnet3-merge-egs --minibatch-size={mb_size} ark:- \ ark:- |" """.format(command=run_opts.command, dir=dir, @@ -397,14 +420,15 @@ def compute_train_cv_probabilities(dir, iter, egs_dir, left_context, context_opts=context_opts, mb_size=mb_size, model=model, - egs_dir=egs_dir), + egs_dir=egs_dir, + extra_egs_copy_cmd=extra_egs_copy_cmd), wait=wait, background_process_handler=background_process_handler) common_lib.run_job( """{command} {dir}/log/compute_prob_train.{iter}.log \ nnet3-compute-prob "{model}" \ "ark,bg:nnet3-copy-egs {context_opts} \ - ark:{egs_dir}/train_diagnostic.egs ark:- | \ + ark:{egs_dir}/train_diagnostic.egs ark:- |{extra_egs_copy_cmd} \ nnet3-merge-egs --minibatch-size={mb_size} ark:- \ ark:- |" """.format(command=run_opts.command, dir=dir, @@ -412,14 +436,16 @@ def compute_train_cv_probabilities(dir, iter, egs_dir, left_context, context_opts=context_opts, mb_size=mb_size, model=model, - egs_dir=egs_dir), + egs_dir=egs_dir, + extra_egs_copy_cmd=extra_egs_copy_cmd), wait=wait, background_process_handler=background_process_handler) def compute_progress(dir, iter, egs_dir, left_context, right_context, run_opts, mb_size=256, background_process_handler=None, wait=False, - get_raw_nnet_from_am=True): + get_raw_nnet_from_am=True, + extra_egs_copy_cmd=""): if get_raw_nnet_from_am: prev_model = "nnet3-am-copy --raw=true {0}/{1}.mdl - |".format( dir, iter - 1) @@ -436,7 +462,7 @@ def compute_progress(dir, iter, egs_dir, left_context, right_context, nnet3-info "{model}" '&&' \ nnet3-show-progress --use-gpu=no "{prev_model}" "{model}" \ "ark,bg:nnet3-copy-egs {context_opts} \ - ark:{egs_dir}/train_diagnostic.egs ark:- | \ + ark:{egs_dir}/train_diagnostic.egs ark:- |{extra_egs_copy_cmd} \ nnet3-merge-egs --minibatch-size={mb_size} ark:- \ ark:- |" """.format(command=run_opts.command, dir=dir, @@ -445,14 +471,16 @@ def compute_progress(dir, iter, egs_dir, left_context, right_context, context_opts=context_opts, mb_size=mb_size, prev_model=prev_model, - egs_dir=egs_dir), + egs_dir=egs_dir, + extra_egs_copy_cmd=extra_egs_copy_cmd), wait=wait, background_process_handler=background_process_handler) def combine_models(dir, num_iters, models_to_combine, egs_dir, left_context, right_context, run_opts, background_process_handler=None, - chunk_width=None, 
get_raw_nnet_from_am=True): + chunk_width=None, get_raw_nnet_from_am=True, + extra_egs_copy_cmd=""): """ Function to do model combination In the nnet3 setup, the logic @@ -499,7 +527,7 @@ def combine_models(dir, num_iters, models_to_combine, egs_dir, --enforce-sum-to-one=true --enforce-positive-weights=true \ --verbose=3 {raw_models} \ "ark,bg:nnet3-copy-egs {context_opts} \ - ark:{egs_dir}/combine.egs ark:- | \ + ark:{egs_dir}/combine.egs ark:- |{extra_egs_copy_cmd} \ nnet3-merge-egs --measure-output-frames=false \ --minibatch-size={mbsize} ark:- ark:- |" \ "{out_model}" @@ -509,7 +537,8 @@ def combine_models(dir, num_iters, models_to_combine, egs_dir, context_opts=context_opts, mbsize=mbsize, out_model=out_model, - egs_dir=egs_dir)) + egs_dir=egs_dir, + extra_egs_copy_cmd=extra_egs_copy_cmd)) # Compute the probability of the final, combined model with # the same subset we used for the previous compute_probs, as the @@ -519,14 +548,16 @@ def combine_models(dir, num_iters, models_to_combine, egs_dir, dir=dir, iter='combined', egs_dir=egs_dir, left_context=left_context, right_context=right_context, run_opts=run_opts, wait=False, - background_process_handler=background_process_handler) + background_process_handler=background_process_handler, + extra_egs_copy_cmd=extra_egs_copy_cmd) else: compute_train_cv_probabilities( dir=dir, iter='final', egs_dir=egs_dir, left_context=left_context, right_context=right_context, run_opts=run_opts, wait=False, background_process_handler=background_process_handler, - get_raw_nnet_from_am=False) + get_raw_nnet_from_am=False, + extra_egs_copy_cmd=extra_egs_copy_cmd) def get_realign_iters(realign_times, num_iters, @@ -639,7 +670,8 @@ def adjust_am_priors(dir, input_model, avg_posterior_vector, output_model, def compute_average_posterior(dir, iter, egs_dir, num_archives, prior_subset_size, left_context, right_context, - run_opts, get_raw_nnet_from_am=True): + run_opts, get_raw_nnet_from_am=True, + extra_egs_copy_cmd=""): """ Computes the average posterior of the network Note: this just uses CPUs, using a smallish subset of data. """ @@ -663,7 +695,7 @@ def compute_average_posterior(dir, iter, egs_dir, num_archives, """{command} JOB=1:{num_jobs_compute_prior} {prior_queue_opt} \ {dir}/log/get_post.{iter}.JOB.log \ nnet3-copy-egs {context_opts} \ - ark:{egs_dir}/egs.{egs_part}.ark ark:- \| \ + ark:{egs_dir}/egs.{egs_part}.ark ark:- \| {extra_egs_copy_cmd}\ nnet3-subset-egs --srand=JOB --n={prior_subset_size} \ ark:- ark:- \| \ nnet3-merge-egs --measure-output-frames=true \ @@ -679,7 +711,8 @@ def compute_average_posterior(dir, iter, egs_dir, num_archives, iter=iter, prior_subset_size=prior_subset_size, egs_dir=egs_dir, egs_part=egs_part, context_opts=context_opts, - prior_gpu_opt=run_opts.prior_gpu_opt)) + prior_gpu_opt=run_opts.prior_gpu_opt, + extra_egs_copy_cmd=extra_egs_copy_cmd)) # make sure there is time for $dir/post.{iter}.*.vec to appear. time.sleep(5) diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py index 24eea922968..c612af984b1 100644 --- a/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py +++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/basic_layers.py @@ -349,7 +349,8 @@ def set_default_configs(self): # note: self.config['input'] is a descriptor, '[-1]' means output # the most recent layer. 
-        self.config = { 'input':'[-1]' }
+        self.config = {'input': '[-1]',
+                       'dim': -1}
 
     def check_configs(self):
diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/layers.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/layers.py
index 353b9d3bba4..1092be572b4 100644
--- a/egs/wsj/s5/steps/libs/nnet3/xconfig/layers.py
+++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/layers.py
@@ -6,3 +6,4 @@
 from basic_layers import *
 from lstm import *
 from tdnn import *
+from stats_layer import *
diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py
index 7ccab2f6c6f..7b34481993b 100644
--- a/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py
+++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py
@@ -29,7 +29,8 @@
         'lstmp-layer' : xlayers.XconfigLstmpLayer,
         'lstmpc-layer' : xlayers.XconfigLstmpcLayer,
         'fast-lstm-layer' : xlayers.XconfigFastLstmLayer,
-        'fast-lstmp-layer' : xlayers.XconfigFastLstmpLayer
+        'fast-lstmp-layer' : xlayers.XconfigFastLstmpLayer,
+        'stats-layer': xlayers.XconfigStatsLayer
 }
 
 # Converts a line as parsed by ParseConfigLine() into a first
diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/stats_layer.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/stats_layer.py
new file mode 100644
index 00000000000..beaf7c8923a
--- /dev/null
+++ b/egs/wsj/s5/steps/libs/nnet3/xconfig/stats_layer.py
@@ -0,0 +1,142 @@
+# Copyright 2016 Vimal Manohar
+# Apache 2.0.
+
+""" This module contains the statistics extraction and pooling layer.
+"""
+
+from __future__ import print_function
+import re
+from libs.nnet3.xconfig.utils import XconfigParserError as xparser_error
+from libs.nnet3.xconfig.basic_layers import XconfigLayerBase
+
+
+class XconfigStatsLayer(XconfigLayerBase):
+    """This class is for parsing lines like
+     stats-layer name=tdnn1-stats config=mean+stddev(-99:3:9:99) input=tdnn1
+
+    This adds statistics-pooling and statistics-extraction components.  An
+    example string is 'mean(-99:3:9:99)', which means: compute the mean of
+    data within a window of -99 to +99, with distinct means computed every 9
+    frames (we round to get the appropriate one), and with the input extracted
+    on multiples of 3 frames (so this will force the input to this layer to be
+    evaluated every 3 frames).  Another example string is
+    'mean+stddev(-99:3:9:99)', which will also cause the standard deviation to
+    be computed.
+
+    The dimension is worked out from the input.  mean and stddev each add a
+    dimension of input_dim to the output dimension.  If count is specified,
+    an additional dimension is added to the output to store log counts.
+
+    Parameters of the class, and their defaults:
+        input='[-1]'    [Descriptor giving the input of the layer.]
+        dim=-1          [Output dimension of layer. If provided, must match the
+                         dimension computed from input]
+        config=''       [Required. Defines what stats must be computed.]
+ """ + def __init__(self, first_token, key_to_value, prev_names=None): + assert first_token in ['stats-layer'] + XconfigLayerBase.__init__(self, first_token, key_to_value, prev_names) + + def set_default_configs(self): + self.config = {'input': '[-1]', + 'dim': -1, + 'config': ''} + + def set_derived_configs(self): + config_string = self.config['config'] + if config_string == '': + raise xparser_error("config has to be non-empty", + self.str()) + m = re.search("(mean|mean\+stddev|mean\+count|mean\+stddev\+count)" + "\((-?\d+):(-?\d+):(-?\d+):(-?\d+)\)", + config_string) + if m is None: + raise xparser_error("Invalid statistic-config string: {0}".format( + config_string), self) + + self._output_stddev = (m.group(1) in ['mean+stddev', + 'mean+stddev+count']) + self._output_log_counts = (m.group(1) in ['mean+count', + 'mean+stddev+count']) + self._left_context = -int(m.group(2)) + self._input_period = int(m.group(3)) + self._stats_period = int(m.group(4)) + self._right_context = int(m.group(5)) + + output_dim = (self.descriptors['input']['dim'] + * (2 if self._output_stddev else 1) + + 1 if self._output_log_counts else 0) + + if self.config['dim'] > 0 and self.config['dim'] != output_dim: + raise xparser_error( + "Invalid dim supplied {0:d} != " + "actual output dim {1:d}".format( + self.config['dim'], output_dim)) + self.config['dim'] = output_dim + + def check_configs(self): + if not (self._left_context > 0 and self._right_context > 0 + and self._input_period > 0 and self._stats_period > 0 + and self._left_context % self._stats_period == 0 + and self._right_context % self._stats_period == 0 + and self._stats_period % self._input_period == 0): + raise xparser_error( + "Invalid configuration of statistics-extraction: {0}".format( + self.config['config']), self) + super(XconfigStatsLayer, self).check_configs() + + def _generate_config(self): + input_desc = self.descriptors['input']['final-string'] + input_dim = self.descriptors['input']['dim'] + + configs = [] + configs.append( + 'component name={name}-extraction-{lc}-{rc} ' + 'type=StatisticsExtractionComponent input-dim={dim} ' + 'input-period={input_period} output-period={output_period} ' + 'include-variance={var} '.format( + name=self.name, lc=self._left_context, rc=self._right_context, + dim=input_dim, input_period=self._input_period, + output_period=self._stats_period, + var='true' if self._output_stddev else 'false')) + configs.append( + 'component-node name={name}-extraction-{lc}-{rc} ' + 'component={name}-extraction-{lc}-{rc} input={input} '.format( + name=self.name, lc=self._left_context, rc=self._right_context, + input=input_desc)) + + stats_dim = 1 + input_dim * (2 if self._output_stddev else 1) + configs.append( + 'component name={name}-pooling-{lc}-{rc} ' + 'type=StatisticsPoolingComponent input-dim={dim} ' + 'input-period={input_period} left-context={lc} right-context={rc} ' + 'num-log-count-features={count} output-stddevs={var} '.format( + name=self.name, lc=self._left_context, rc=self._right_context, + dim=stats_dim, input_period=self._stats_period, + count=1 if self._output_log_counts else 0, + var='true' if self._output_stddev else 'false')) + configs.append( + 'component-node name={name}-pooling-{lc}-{rc} ' + 'component={name}-pooling-{lc}-{rc} ' + 'input={name}-extraction-{lc}-{rc} '.format( + name=self.name, lc=self._left_context, rc=self._right_context)) + return configs + + def output_name(self, auxiliary_output=None): + return 'Round({name}-pooling-{lc}-{rc}, {period})'.format( + name=self.name, lc=self._left_context, 
+ rc=self._right_context, period=self._stats_period) + + def output_dim(self, auxiliary_outputs=None): + return self.config['dim'] + + def get_full_config(self): + ans = [] + config_lines = self._generate_config() + + for line in config_lines: + for config_name in ['ref', 'final']: + ans.append((config_name, line)) + + return ans diff --git a/egs/wsj/s5/steps/nnet3/components.py b/egs/wsj/s5/steps/nnet3/components.py index 3fb92117d78..c811297cda8 100644 --- a/egs/wsj/s5/steps/nnet3/components.py +++ b/egs/wsj/s5/steps/nnet3/components.py @@ -6,6 +6,7 @@ import sys import warnings import copy +import re from operator import itemgetter def GetSumDescriptor(inputs): @@ -30,17 +31,33 @@ def AddInputLayer(config_lines, feat_dim, splice_indexes=[0], ivector_dim=0): components = config_lines['components'] component_nodes = config_lines['component-nodes'] output_dim = 0 - components.append('input-node name=input dim=' + str(feat_dim)) - list = [('Offset(input, {0})'.format(n) if n != 0 else 'input') for n in splice_indexes] - output_dim += len(splice_indexes) * feat_dim + components.append('input-node name=input dim={0}'.format(feat_dim)) + prev_layer_output = {'descriptor': "input", + 'dimension': feat_dim} + inputs = [] + for n in splice_indexes: + try: + offset = int(n) + if offset == 0: + inputs.append(prev_layer_output['descriptor']) + else: + inputs.append('Offset({0}, {1})'.format( + prev_layer_output['descriptor'], offset)) + output_dim += prev_layer_output['dimension'] + except ValueError: + stats = StatisticsConfig(n, prev_layer_output) + stats_layer = stats.AddLayer(config_lines, "Tdnn_stats_{0}".format(0)) + inputs.append(stats_layer['descriptor']) + output_dim += stats_layer['dimension'] + if ivector_dim > 0: - components.append('input-node name=ivector dim=' + str(ivector_dim)) - list.append('ReplaceIndex(ivector, t, 0)') + components.append('input-node name=ivector dim={0}'.format(ivector_dim)) + inputs.append('ReplaceIndex(ivector, t, 0)') output_dim += ivector_dim - if len(list) > 1: - splice_descriptor = "Append({0})".format(", ".join(list)) + if len(inputs) > 1: + splice_descriptor = "Append({0})".format(", ".join(inputs)) else: - splice_descriptor = list[0] + splice_descriptor = inputs[0] print(splice_descriptor) return {'descriptor': splice_descriptor, 'dimension': output_dim} @@ -55,6 +72,35 @@ def AddNoOpLayer(config_lines, name, input): return {'descriptor': '{0}_noop'.format(name), 'dimension': input['dimension']} +def AddGradientScaleLayer(config_lines, name, input, scale = 1.0, scales_vec = None): + components = config_lines['components'] + component_nodes = config_lines['component-nodes'] + + if scales_vec is None: + components.append('component name={0}_gradient_scale type=ScaleGradientComponent dim={1} scale={2}'.format(name, input['dimension'], scale)) + else: + components.append('component name={0}_gradient_scale type=ScaleGradientComponent scales={2}'.format(name, scales_vec)) + + component_nodes.append('component-node name={0}_gradient_scale component={0}_gradient_scale input={1}'.format(name, input['descriptor'])) + + return {'descriptor': '{0}_gradient_scale'.format(name), + 'dimension': input['dimension']} + +def AddFixedScaleLayer(config_lines, name, input, + scale = 1.0, scales_vec = None): + components = config_lines['components'] + component_nodes = config_lines['component-nodes'] + + if scales_vec is None: + components.append('component name={0}-fixed-scale type=FixedScaleComponent dim={1} scale={2}'.format(name, input['dimension'], scale)) + else: + 
components.append('component name={0}-fixed-scale type=FixedScaleComponent scales={1}'.format(name, scales_vec))
+
+    component_nodes.append('component-node name={0}-fixed-scale component={0}-fixed-scale input={1}'.format(name, input['descriptor']))
+
+    return {'descriptor': '{0}-fixed-scale'.format(name),
+            'dimension': input['dimension']}
+
 def AddLdaLayer(config_lines, name, input, lda_file):
     return AddFixedAffineLayer(config_lines, name, input, lda_file)

@@ -257,7 +303,9 @@ def AddFinalLayer(config_lines, input, output_dim,
                   include_log_softmax = True,
                   add_final_sigmoid = False,
                   name_affix = None,
-                  objective_type = "linear"):
+                  objective_type = "linear",
+                  objective_scale = 1.0,
+                  objective_scales_vec = None):
     components = config_lines['components']
     component_nodes = config_lines['component-nodes']

@@ -283,6 +331,9 @@ def AddFinalLayer(config_lines, input, output_dim,
         prev_layer_output = AddSigmoidLayer(config_lines, final_node_prefix, prev_layer_output)
     # we use the same name_affix as a prefix in for affine/scale nodes but as a
     # suffix for output node
+    if (objective_scale != 1.0 or objective_scales_vec is not None):
+        prev_layer_output = AddGradientScaleLayer(config_lines, final_node_prefix, prev_layer_output, objective_scale, objective_scales_vec)
+
     AddOutputLayer(config_lines, prev_layer_output, label_delay, suffix = name_affix, objective_type = objective_type)

 def AddLstmLayer(config_lines,
@@ -485,3 +536,82 @@ def AddBLstmLayer(config_lines,
             'dimension':output_dim
            }

+# this is a bit like a struct, initialized from a string, which describes how to
+# set up the statistics-pooling and statistics-extraction components.
+# An example string is 'mean(-99:3:9:99)', which means, compute the mean of
+# data within a window of -99 to +99, with distinct means computed every 9 frames
+# (we round to get the appropriate one), and with the input extracted on multiples
+# of 3 frames (so this will force the input to this layer to be evaluated
+# every 3 frames).  Another example string is 'mean+stddev(-99:3:9:99)',
+# which will also cause the standard deviation to be computed.
+class StatisticsConfig:
+    # e.g. c = StatisticsConfig('mean+stddev(-99:3:9:99)',
+    #                           {'dimension': 400, 'descriptor': 'jesus1-forward-output-affine'})
+    def __init__(self, config_string, input):
+
+        self.input_dim = input['dimension']
+        self.input_descriptor = input['descriptor']
+
+        m = re.search("(mean|mean\+stddev|mean\+count|mean\+stddev\+count)\((-?\d+):(-?\d+):(-?\d+):(-?\d+)\)",
+                      config_string)
+        if m == None:
+            raise Exception("Invalid splice-index or statistics-config string: " + config_string)
+        self.output_stddev = (m.group(1) in ['mean+stddev', 'mean+stddev+count'])
+        self.output_log_counts = (m.group(1) in ['mean+count', 'mean+stddev+count'])
+        self.left_context = -int(m.group(2))
+        self.input_period = int(m.group(3))
+        self.stats_period = int(m.group(4))
+        self.right_context = int(m.group(5))
+        if not (self.left_context > 0 and self.right_context > 0 and
+                self.input_period > 0 and self.stats_period > 0 and
+                self.left_context % self.stats_period == 0 and
+                self.right_context % self.stats_period == 0 and
+                self.stats_period % self.input_period == 0):
+            raise Exception("Invalid configuration of statistics-extraction: " + config_string)
+
+    # OutputDim() returns the output dimension of the node that this produces.
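+    # As an illustrative example (assuming an input dim of 40): 'mean' -> 40,
+    # 'mean+stddev' -> 80, 'mean+count' -> 41, 'mean+stddev+count' -> 81.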
+    def OutputDim(self):
+        return (self.input_dim * (2 if self.output_stddev else 1)
+                + (1 if self.output_log_counts else 0))
+
+    # OutputDims() returns an array of output dimensions, consisting of
+    # [ input-dim ] if just "mean" was specified, otherwise
+    # [ input-dim input-dim ]; an extra [ 1 ] is appended if log-counts
+    # are also output.
+    def OutputDims(self):
+        output_dims = [ self.input_dim ]
+        if self.output_stddev:
+            output_dims.append(self.input_dim)
+        if self.output_log_counts:
+            output_dims.append(1)
+        return output_dims
+
+    # Descriptor() returns the textual form of the descriptor by which the
+    # output of this node is to be accessed.
+    def Descriptor(self, name):
+        return 'Round({0}-pooling-{1}-{2}, {3})'.format(name, self.left_context, self.right_context,
+                                                        self.stats_period)
+
+    def AddLayer(self, config_lines, name):
+        components = config_lines['components']
+        component_nodes = config_lines['component-nodes']
+
+        components.append('component name={name}-extraction-{lc}-{rc} type=StatisticsExtractionComponent input-dim={dim} '
+                          'input-period={input_period} output-period={output_period} include-variance={var} '.format(
+                              name = name, lc = self.left_context, rc = self.right_context,
+                              dim = self.input_dim, input_period = self.input_period, output_period = self.stats_period,
+                              var = ('true' if self.output_stddev else 'false')))
+        component_nodes.append('component-node name={name}-extraction-{lc}-{rc} component={name}-extraction-{lc}-{rc} input={input} '.format(
+                              name = name, lc = self.left_context, rc = self.right_context, input = self.input_descriptor))
+        stats_dim = 1 + self.input_dim * (2 if self.output_stddev else 1)
+        components.append('component name={name}-pooling-{lc}-{rc} type=StatisticsPoolingComponent input-dim={dim} '
+                          'input-period={input_period} left-context={lc} right-context={rc} num-log-count-features={count} '
+                          'output-stddevs={var} '.format(name = name, lc = self.left_context, rc = self.right_context,
+                                                         dim = stats_dim, input_period = self.stats_period,
+                                                         count = 1 if self.output_log_counts else 0,
+                                                         var = ('true' if self.output_stddev else 'false')))
+        component_nodes.append('component-node name={name}-pooling-{lc}-{rc} component={name}-pooling-{lc}-{rc} input={name}-extraction-{lc}-{rc} '.format(
+                              name = name, lc = self.left_context, rc = self.right_context))
+
+        return { 'dimension': self.OutputDim(),
+                 'descriptor': self.Descriptor(name),
+                 'dimensions': self.OutputDims()
+               }
diff --git a/egs/wsj/s5/steps/nnet3/compute_output.sh b/egs/wsj/s5/steps/nnet3/compute_output.sh
new file mode 100755
index 00000000000..f49790bc578
--- /dev/null
+++ b/egs/wsj/s5/steps/nnet3/compute_output.sh
@@ -0,0 +1,179 @@
+#!/bin/bash

+# Copyright 2012-2015  Johns Hopkins University (Author: Daniel Povey).
+#                2016  Vimal Manohar
+# Apache 2.0.

+# This script computes the output of a neural net (e.g. per-frame posteriors,
+# or pseudo log-likelihoods if --post-vec is supplied).  If the neural net was
+# built on top of fMLLR transforms from a conventional system, you should
+# provide the --transform-dir option.

+# Begin configuration section.
+stage=1
+transform_dir=    # dir to find fMLLR transforms.
+nj=4 # number of jobs.  If --transform-dir set, must match that number!
+cmd=run.pl
+use_gpu=false
+frames_per_chunk=50
+ivector_scale=1.0
+iter=final
+extra_left_context=0
+extra_right_context=0
+extra_left_context_initial=-1
+extra_right_context_final=-1
+frame_subsampling_factor=1
+feat_type=
+compress=false
+online_ivector_dir=
+post_vec=
+output_name=
+get_raw_nnet_from_am=true
+# End configuration section.

+echo "$0 $@"  # Print the command line for logging

+[ -f ./path.sh ] && . ./path.sh; # source the path.
+. 
parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: $0 [options] " + echo "e.g.: steps/nnet3/compute_output.sh --nj 8 \\" + echo "--online-ivector-dir exp/nnet3/ivectors_test_eval92 \\" + echo " data/test_eval92_hires exp/nnet3/tdnn exp/nnet3/tdnn/output" + echo "main options (for others, see top of script file)" + echo " --transform-dir # directory of previous decoding" + echo " # where we can find transforms for SAT systems." + echo " --config # config containing options" + echo " --nj # number of parallel jobs" + echo " --cmd # Command to run in parallel with" + echo " --iter # Iteration of model to decode; default is final." + exit 1; +fi + +data=$1 +srcdir=$2 +dir=$3 + +if $get_raw_nnet_from_am; then + [ ! -f $srcdir/$iter.mdl ] && echo "$0: no such file $srcdir/$iter.mdl" && exit 1 + model="nnet3-am-copy --raw=true $srcdir/$iter.mdl - |" +else + [ ! -f $srcdir/$iter.raw ] && echo "$0: no such file $srcdir/$iter.raw" && exit 1 + model="nnet3-copy $srcdir/$iter.raw - |" +fi + +mkdir -p $dir/log +echo "rename-node old-name=$output_name new-name=output" > $dir/edits.config + +if [ ! -z "$output_name" ]; then + model="$model nnet3-copy --edits-config=$dir/edits.config - - |" +else + output_name=output +fi + +[ ! -z "$online_ivector_dir" ] && \ + extra_files="$online_ivector_dir/ivector_online.scp $online_ivector_dir/ivector_period" + +for f in $data/feats.scp $extra_files; do + [ ! -f $f ] && echo "$0: no such file $f" && exit 1; +done + +sdata=$data/split$nj; +cmvn_opts=`cat $srcdir/cmvn_opts` || exit 1; + +[[ -d $sdata && $data/feats.scp -ot $sdata ]] || split_data.sh $data $nj || exit 1; +echo $nj > $dir/num_jobs + + +## Set up features. +if [ -z "$feat_type" ]; then + if [ -f $srcdir/final.mat ]; then feat_type=lda; else feat_type=raw; fi + echo "$0: feature type is $feat_type" +fi + +splice_opts=`cat $srcdir/splice_opts 2>/dev/null` + +case $feat_type in + raw) feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- |";; + lda) feats="ark,s,cs:apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:$sdata/JOB/feats.scp ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $srcdir/final.mat ark:- ark:- |" + ;; + *) echo "$0: invalid feature type $feat_type" && exit 1; +esac +if [ ! -z "$transform_dir" ]; then + echo "$0: using transforms from $transform_dir" + [ ! -s $transform_dir/num_jobs ] && \ + echo "$0: expected $transform_dir/num_jobs to contain the number of jobs." && exit 1; + nj_orig=$(cat $transform_dir/num_jobs) + + if [ $feat_type == "raw" ]; then trans=raw_trans; + else trans=trans; fi + if [ $feat_type == "lda" ] && \ + ! cmp $transform_dir/../final.mat $srcdir/final.mat && \ + ! cmp $transform_dir/final.mat $srcdir/final.mat; then + echo "$0: LDA transforms differ between $srcdir and $transform_dir" + exit 1; + fi + if [ ! -f $transform_dir/$trans.1 ]; then + echo "$0: expected $transform_dir/$trans.1 to exist (--transform-dir option)" + exit 1; + fi + if [ $nj -ne $nj_orig ]; then + # Copy the transforms into an archive with an index. 
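+    # (When the number of decoding jobs differs from the number of jobs that
+    # produced the transforms, we cannot index $trans.JOB directly, so we
+    # concatenate all the transforms into a single indexed archive and look
+    # them up per utterance via the scp file.)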
+    for n in $(seq $nj_orig); do cat $transform_dir/$trans.$n; done | \
+      copy-feats ark:- ark,scp:$dir/$trans.ark,$dir/$trans.scp || exit 1;
+    feats="$feats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk scp:$dir/$trans.scp ark:- ark:- |"
+  else
+    feats="$feats transform-feats --utt2spk=ark:$sdata/JOB/utt2spk ark:$transform_dir/$trans.JOB ark:- ark:- |"
+  fi
+elif grep 'transform-feats --utt2spk' $srcdir/log/train.1.log >&/dev/null; then
+  echo "$0: **WARNING**: you seem to be using a neural net system trained with transforms,"
+  echo "  but you are not providing the --transform-dir option at test time."
+fi
+##

+if [ ! -z "$online_ivector_dir" ]; then
+  ivector_period=$(cat $online_ivector_dir/ivector_period) || exit 1;
+  ivector_opts="--online-ivectors=scp:$online_ivector_dir/ivector_online.scp --online-ivector-period=$ivector_period"
+fi

+frame_subsampling_opt=
+if [ $frame_subsampling_factor -ne 1 ]; then
+  # e.g. for 'chain' systems
+  frame_subsampling_opt="--frame-subsampling-factor=$frame_subsampling_factor"
+fi

+output_wspecifier="ark:| copy-feats --compress=$compress ark:- ark:- | gzip -c > $dir/nnet_output.JOB.gz"

+if [ ! -z "$post_vec" ]; then
+  if [ $stage -le 1 ]; then
+    # Convert the vector of posterior counts into a vector of log-priors.
+    copy-vector --binary=false $post_vec - | \
+      awk '{for (i = 2; i < NF; i++) { sum += $i; };
+            printf ("[");
+            for (i = 2; i < NF; i++) { printf " "log($i/sum); };
+            print (" ]");}' > $dir/log_priors.vec
+  fi

+  # Subtract the log-priors from the network output to get pseudo log-likelihoods.
+  output_wspecifier="ark:| matrix-add-offset ark:- 'vector-scale --scale=-1.0 $dir/log_priors.vec - |' ark:- | copy-feats --compress=$compress ark:- ark:- | gzip -c > $dir/log_likes.JOB.gz"
+fi

+gpu_opt="--use-gpu=no"
+gpu_queue_opt=

+if $use_gpu; then
+  gpu_queue_opt="--gpu 1"
+  gpu_opt="--use-gpu=yes"
+fi

+if [ $stage -le 2 ]; then
+  $cmd $gpu_queue_opt JOB=1:$nj $dir/log/compute_output.JOB.log \
+    nnet3-compute $gpu_opt $ivector_opts $frame_subsampling_opt \
+     --frames-per-chunk=$frames_per_chunk \
+     --extra-left-context=$extra_left_context \
+     --extra-right-context=$extra_right_context \
+     --extra-left-context-initial=$extra_left_context_initial \
+     --extra-right-context-final=$extra_right_context_final \
+     "$model" "$feats" "$output_wspecifier" || exit 1;
+fi

+exit 0;
+
diff --git a/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py b/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py
new file mode 100755
index 00000000000..16e1f98a019
--- /dev/null
+++ b/egs/wsj/s5/steps/nnet3/get_egs_multiple_targets.py
@@ -0,0 +1,910 @@
+#!/usr/bin/env python

+# Copyright 2016    Vijayaditya Peddinti
+#           2016    Vimal Manohar
+# Apache 2.0.
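+
+# This script generates nnet3 training examples ('egs') for setups with one or
+# more output targets per frame (the targets may be dense matrices or sparse
+# posteriors/alignments), along with the validation and train-subset examples
+# used for diagnostics.  A sketch of an invocation (hypothetical paths; the
+# options themselves are defined in get_args() below):
+#
+#   steps/nnet3/get_egs_multiple_targets.py --cmd="run.pl" \
+#     --feat.dir=data/train_hires \
+#     --targets-parameters="--output-name=output --target-type=dense \
+#         --targets-scp=exp/targets/targets.scp" \
+#     --dir=exp/nnet3/egs_multiple_targets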
+ +from __future__ import print_function +import os +import argparse +import sys +import logging +import shlex +import random +import math +import glob + +import libs.data as data_lib +import libs.common as common_lib + +logger = logging.getLogger('libs') +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(asctime)s [%(filename)s:%(lineno)s - " + "%(funcName)s - %(levelname)s ] %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) +logger.info('Getting egs for training') + + +def get_args(): + # we add compulsary arguments as named arguments for readability + parser = argparse.ArgumentParser( + description="""Generates training examples used to train the 'nnet3' + network (and also the validation examples used for diagnostics), + and puts them in separate archives.""", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("--cmd", type=str, default="run.pl", + help="Specifies the script to launch jobs." + " e.g. queue.pl for launching on SGE cluster run.pl" + " for launching on local machine") + # feat options + parser.add_argument("--feat.dir", type=str, dest='feat_dir', required=True, + help="Directory with features used for training " + "the neural network.") + parser.add_argument("--feat.online-ivector-dir", type=str, + dest='online_ivector_dir', + default=None, action=common_lib.NullstrToNoneAction, + help="directory with the ivectors extracted in an " + "online fashion.") + parser.add_argument("--feat.cmvn-opts", type=str, dest='cmvn_opts', + default=None, action=common_lib.NullstrToNoneAction, + help="A string specifying '--norm-means' and " + "'--norm-vars' values") + parser.add_argument("--feat.apply-cmvn-sliding", type=str, + dest='apply_cmvn_sliding', + default=False, action=common_lib.StrToBoolAction, + help="Apply CMVN sliding, instead of per-utteance " + "or speakers") + + # egs extraction options + parser.add_argument("--frames-per-eg", type=int, default=8, + help="""Number of frames of labels per example. + more->less disk space and less time preparing egs, but + more I/O during training. + note: the script may reduce this if + reduce-frames-per-eg is true.""") + parser.add_argument("--left-context", type=int, default=4, + help="""Amount of left-context per eg (i.e. extra + frames of input features not present in the output + supervision).""") + parser.add_argument("--right-context", type=int, default=4, + help="Amount of right-context per eg") + parser.add_argument("--valid-left-context", type=int, default=None, + help="""Amount of left-context for validation egs, + typically used in recurrent architectures to ensure + matched condition with training egs""") + parser.add_argument("--valid-right-context", type=int, default=None, + help="""Amount of right-context for validation egs, + typically used in recurrent architectures to ensure + matched condition with training egs""") + parser.add_argument("--compress-input", type=str, default=True, + action=common_lib.StrToBoolAction, + choices=["true", "false"], + help="If false, disables compression. 
Might be " + "necessary to check if results will be affected.") + parser.add_argument("--input-compress-format", type=int, default=0, + help="Format used for compressing the input features") + + parser.add_argument("--reduce-frames-per-eg", type=str, default=True, + action=common_lib.StrToBoolAction, + choices=["true", "false"], + help="""If true, this script may reduce the + frames-per-eg if there is only one archive and even + with the reduced frames-per-eg, the number of + samples-per-iter that would result is less than or + equal to the user-specified value.""") + + parser.add_argument("--num-utts-subset", type=int, default=300, + help="Number of utterances in validation and training" + " subsets used for shrinkage and diagnostics") + parser.add_argument("--num-utts-subset-valid", type=int, + help="Number of utterances in validation" + " subset used for diagnostics") + parser.add_argument("--num-utts-subset-train", type=int, + help="Number of utterances in training" + " subset used for shrinkage and diagnostics") + parser.add_argument("--num-train-egs-combine", type=int, default=10000, + help="Training examples for combination weights at the" + " very end.") + parser.add_argument("--num-valid-egs-combine", type=int, default=0, + help="Validation examples for combination weights at " + "the very end.") + parser.add_argument("--num-egs-diagnostic", type=int, default=4000, + help="Numer of frames for 'compute-probs' jobs") + + parser.add_argument("--samples-per-iter", type=int, default=400000, + help="""This is the target number of egs in each + archive of egs (prior to merging egs). We probably + should have called it egs_per_iter. This is just a + guideline; it will pick a number that divides the + number of samples in the entire data.""") + + parser.add_argument("--stage", type=int, default=0, + help="Stage to start running script from") + parser.add_argument("--num-jobs", type=int, default=6, + help="""This should be set to the maximum number of + jobs you are comfortable to run in parallel; you can + increase it if your disk speed is greater and you have + more machines.""") + parser.add_argument("--srand", type=int, default=0, + help="Rand seed for nnet3-copy-egs and " + "nnet3-shuffle-egs") + + parser.add_argument("--targets-parameters", type=str, action='append', + required=True, dest='targets_para_array', + help="""Parameters for targets. Each set of parameters + corresponds to a separate output node of the neural + network. The targets can be sparse or dense. + The parameters used are: + --targets-rspecifier= + # rspecifier for the targets, can be alignment or + # matrix. + --num-targets= + # targets dimension. required for sparse feats. 
+ --target-type=""") + + parser.add_argument("--dir", type=str, required=True, + help="Directory to store the examples") + + print(' '.join(sys.argv)) + print(sys.argv) + + args = parser.parse_args() + + args = process_args(args) + + return args + + +def process_args(args): + # process the options + if args.num_utts_subset_valid is None: + args.num_utts_subset_valid = args.num_utts_subset + + if args.num_utts_subset_train is None: + args.num_utts_subset_train = args.num_utts_subset + + if args.valid_left_context is None: + args.valid_left_context = args.left_context + if args.valid_right_context is None: + args.valid_right_context = args.right_context + + if (args.left_context < 0 or args.right_context < 0 + or args.valid_left_context < 0 or args.valid_right_context < 0): + raise Exception( + "--{,valid-}{left,right}-context should be non-negative") + + return args + + +def check_for_required_files(feat_dir, targets_scps, online_ivector_dir=None): + required_files = ['{0}/feats.scp'.format(feat_dir), + '{0}/cmvn.scp'.format(feat_dir)] + if online_ivector_dir is not None: + required_files.append('{0}/ivector_online.scp'.format( + online_ivector_dir)) + required_files.append('{0}/ivector_period'.format( + online_ivector_dir)) + + for file in required_files: + if not os.path.isfile(file): + raise Exception('Expected {0} to exist.'.format(file)) + + +def parse_targets_parameters_array(para_array): + targets_parser = argparse.ArgumentParser() + targets_parser.add_argument("--output-name", type=str, required=True, + help="Name of the output. e.g. output-xent") + targets_parser.add_argument("--dim", type=int, default=-1, + help="Target dimension (required for sparse " + "targets") + targets_parser.add_argument("--target-type", type=str, default="dense", + choices=["dense", "sparse"], + help="Dense for matrix format") + targets_parser.add_argument("--targets-scp", type=str, required=True, + help="Scp file of targets; can be posteriors " + "or matrices") + targets_parser.add_argument("--compress", type=str, default=True, + action=common_lib.StrToBoolAction, + help="Specifies whether the output must be " + "compressed") + targets_parser.add_argument("--compress-format", type=int, default=0, + help="Format for compressing target") + targets_parser.add_argument("--deriv-weights-scp", type=str, default="", + help="Per-frame deriv weights for this output") + targets_parser.add_argument("--scp2ark-cmd", type=str, default="", + help="""The command that is used to convert + targets scp to archive. e.g. 
An scp of + alignments can be converted to posteriors using + ali-to-post""") + + targets_parameters = [targets_parser.parse_args(shlex.split(x)) + for x in para_array] + + for t in targets_parameters: + if not os.path.isfile(t.targets_scp): + raise Exception("Expected {0} to exist.".format(t.targets_scp)) + + if (t.target_type == "dense"): + dim = common_lib.get_feat_dim_from_scp(t.targets_scp) + if (t.dim != -1 and t.dim != dim): + raise Exception('Mismatch in --dim provided and feat dim for ' + 'file {0}; {1} vs {2}'.format(t.targets_scp, + t.dim, dim)) + t.dim = -dim + + return targets_parameters + + +def sample_utts(feat_dir, num_utts_subset, min_duration, exclude_list=None): + utt2durs_dict = data_lib.get_utt2dur(feat_dir) + utt2durs = utt2durs_dict.items() + utt2uniq, uniq2utt = data_lib.get_utt2uniq(feat_dir) + if num_utts_subset is None: + num_utts_subset = len(utt2durs) + if exclude_list is not None: + num_utts_subset = num_utts_subset - len(exclude_list) + + random.shuffle(utt2durs) + sampled_utts = [] + + index = 0 + num_trials = 0 + while (len(sampled_utts) < num_utts_subset + and num_trials <= len(utt2durs)): + if utt2durs[index][-1] >= min_duration: + if utt2uniq is not None: + uniq_id = utt2uniq[utt2durs[index][0]] + utts2add = uniq2utt[uniq_id] + else: + utts2add = [utt2durs[index][0]] + exclude_utt = False + if exclude_list is not None: + for utt in utts2add: + if utt in exclude_list: + exclude_utt = True + break + if not exclude_utt: + for utt in utts2add: + sampled_utts.append(utt) + + index = index + 1 + num_trials = num_trials + 1 + if exclude_list is not None: + assert(len(set(exclude_list).intersection(sampled_utts)) == 0) + if len(sampled_utts) < num_utts_subset: + raise Exception( + """Number of utterances which have duration of at least {md} + seconds is really low (required={rl}, available={al}). 
Please + check your data.""".format( + md=min_duration, al=len(sampled_utts), rl=num_utts_subset)) + + sampled_utts_durs = [] + for utt in sampled_utts: + sampled_utts_durs.append([utt, utt2durs_dict[utt]]) + return sampled_utts, sampled_utts_durs + + +def write_list(listd, file_name): + file_handle = open(file_name, 'w') + assert(type(listd) == list) + for item in listd: + file_handle.write(str(item)+"\n") + file_handle.close() + + +def get_max_open_files(): + stdout, stderr = common_lib.run_kaldi_command("ulimit -n") + return int(stdout) + + +def get_feat_ivector_strings(dir, feat_dir, split_feat_dir, + cmvn_opt_string, ivector_dir=None, + apply_cmvn_sliding=False): + + if not apply_cmvn_sliding: + train_feats = ("ark,s,cs:utils/filter_scp.pl --exclude " + "{dir}/valid_uttlist {sdir}/JOB/feats.scp | " + "apply-cmvn {cmvn} --utt2spk=ark:{sdir}/JOB/utt2spk " + "scp:{sdir}/JOB/cmvn.scp scp:- ark:- |".format( + dir=dir, sdir=split_feat_dir, + cmvn=cmvn_opt_string)) + valid_feats = ("ark,s,cs:utils/filter_scp.pl {dir}/valid_uttlist " + "{fdir}/feats.scp | " + "apply-cmvn {cmvn} --utt2spk=ark:{fdir}/utt2spk " + "scp:{fdir}/cmvn.scp scp:- ark:- |".format( + dir=dir, fdir=feat_dir, cmvn=cmvn_opt_string)) + train_subset_feats = ("ark,s,cs:utils/filter_scp.pl " + "{dir}/train_subset_uttlist {fdir}/feats.scp | " + "apply-cmvn {cmvn} --utt2spk=ark:{fdir}/utt2spk " + "scp:{fdir}/cmvn.scp scp:- ark:- |".format( + dir=dir, fdir=feat_dir, + cmvn=cmvn_opt_string)) + + def feats_subset_func(subset_list): + return ("ark,s,cs:utils/filter_scp.pl {subset_list} " + "{fdir}/feats.scp | " + "apply-cmvn {cmvn} --utt2spk=ark:{fdir}/utt2spk " + "scp:{fdir}/cmvn.scp scp:- ark:- |".format( + dir=dir, subset_list=subset_list, + fdir=feat_dir, cmvn=cmvn_opt_string)) + + else: + train_feats = ("ark,s,cs:utils/filter_scp.pl --exclude " + "{dir}/valid_uttlist {sdir}/JOB/feats.scp | " + "apply-cmvn-sliding scp:{sdir}/JOB/cmvn.scp scp:- " + "ark:- |".format(dir=dir, sdir=split_feat_dir, + cmvn=cmvn_opt_string)) + + def feats_subset_func(subset_list): + return ("ark,s,cs:utils/filter_scp.pl {subset_list} " + "{fdir}/feats.scp | " + "apply-cmvn-sliding {cmvn} scp:{fdir}/cmvn.scp scp:- " + "ark:- |".format(dir=dir, subset_list=subset_list, + fdir=feat_dir, cmvn=cmvn_opt_string)) + + train_subset_feats = feats_subset_func( + "{0}/train_subset_uttlist".format(dir)) + valid_feats = feats_subset_func("{0}/valid_uttlist".format(dir)) + + if ivector_dir is not None: + ivector_period = common_lib.GetIvectorPeriod(ivector_dir) + ivector_opt = ("--ivectors='ark,s,cs:utils/filter_scp.pl " + "{sdir}/JOB/utt2spk {idir}/ivector_online.scp | " + "subsample-feats --n=-{period} scp:- ark:- |'".format( + sdir=split_feat_dir, idir=ivector_dir, + period=ivector_period)) + valid_ivector_opt = ("--ivectors='ark,s,cs:utils/filter_scp.pl " + "{dir}/valid_uttlist {idir}/ivector_online.scp | " + "subsample-feats --n=-{period} " + "scp:- ark:- |'".format( + dir=dir, idir=ivector_dir, + period=ivector_period)) + train_subset_ivector_opt = ( + "--ivectors='ark,s,cs:utils/filter_scp.pl " + "{dir}/train_subset_uttlist {idir}/ivector_online.scp | " + "subsample-feats --n=-{period} scp:- ark:- |'".format( + dir=dir, idir=ivector_dir, period=ivector_period)) + else: + ivector_opt = '' + valid_ivector_opt = '' + train_subset_ivector_opt = '' + + return {'train_feats': train_feats, + 'valid_feats': valid_feats, + 'train_subset_feats': train_subset_feats, + 'feats_subset_func': feats_subset_func, + 'ivector_opts': ivector_opt, + 'valid_ivector_opts': 
valid_ivector_opt, + 'train_subset_ivector_opts': train_subset_ivector_opt, + 'feat_dim': common_lib.get_feat_dim(feat_dir), + 'ivector_dim': common_lib.get_ivector_dim(ivector_dir)} + + +def get_egs_options(targets_parameters, frames_per_eg, + left_context, right_context, + valid_left_context, valid_right_context, + compress_input, + input_compress_format=0, length_tolerance=0): + + train_egs_opts = [] + train_egs_opts.append("--left-context={0}".format(left_context)) + train_egs_opts.append("--right-context={0}".format(right_context)) + train_egs_opts.append("--num-frames={0}".format(frames_per_eg)) + train_egs_opts.append("--compress-input={0}".format(compress_input)) + train_egs_opts.append("--input-compress-format={0}".format( + input_compress_format)) + train_egs_opts.append("--compress-targets={0}".format( + ':'.join(["true" if t.compress else "false" + for t in targets_parameters]))) + train_egs_opts.append("--targets-compress-formats={0}".format( + ':'.join([str(t.compress_format) + for t in targets_parameters]))) + train_egs_opts.append("--length-tolerance={0}".format(length_tolerance)) + train_egs_opts.append("--output-names={0}".format( + ':'.join([t.output_name + for t in targets_parameters]))) + train_egs_opts.append("--output-dims={0}".format( + ':'.join([str(t.dim) + for t in targets_parameters]))) + + valid_egs_opts = ( + "--left-context={vlc} --right-context={vrc} " + "--num-frames={n} --compress-input={comp} " + "--input-compress-format={icf} --compress-targets={ct} " + "--targets-compress-formats={tcf} --length-tolerance={tol} " + "--output-names={names} --output-dims={dims}".format( + vlc=valid_left_context, vrc=valid_right_context, n=frames_per_eg, + comp=compress_input, icf=input_compress_format, + ct=':'.join(["true" if t.compress else "false" + for t in targets_parameters]), + tcf=':'.join([str(t.compress_format) + for t in targets_parameters]), + tol=length_tolerance, + names=':'.join([t.output_name + for t in targets_parameters]), + dims=':'.join([str(t.dim) for t in targets_parameters]))) + + return {'train_egs_opts': " ".join(train_egs_opts), + 'valid_egs_opts': valid_egs_opts} + + +def get_targets_list(targets_parameters, subset_list): + targets_list = [] + for t in targets_parameters: + rspecifier = "ark,s,cs:" if t.scp2ark_cmd != "" else "scp,s,cs:" + rspecifier += get_subset_rspecifier(t.targets_scp, subset_list) + rspecifier += t.scp2ark_cmd + deriv_weights_rspecifier = "" + if t.deriv_weights_scp != "": + deriv_weights_rspecifier = "scp,s,cs:{0}".format( + get_subset_rspecifier(t.deriv_weights_scp, subset_list)) + this_targets = '''"{rspecifier}" "{dw}"'''.format( + rspecifier=rspecifier, dw=deriv_weights_rspecifier) + + targets_list.append(this_targets) + return " ".join(targets_list) + + +def get_subset_rspecifier(scp_file, subset_list): + if scp_file == "": + return "" + return "utils/filter_scp.pl {subset} {scp} |".format(subset=subset_list, + scp=scp_file) + + +def split_scp(scp_file, num_jobs): + out_scps = ["{0}.{1}".format(scp_file, n) for n in range(1, num_jobs + 1)] + common_lib.run_kaldi_command("utils/split_scp.pl {scp} {oscps}".format( + scp=scp_file, + oscps=' '.join(out_scps))) + return out_scps + + +def generate_valid_train_subset_egs(dir, targets_parameters, + feat_ivector_strings, egs_opts, + num_train_egs_combine, + num_valid_egs_combine, + num_egs_diagnostic, cmd, + num_jobs=1): + wait_pids = [] + + logger.info("Creating validation and train subset examples.") + + split_scp('{0}/valid_uttlist'.format(dir), num_jobs) + 
split_scp('{0}/train_subset_uttlist'.format(dir), num_jobs) + + valid_pid = common_lib.run_kaldi_command( + """{cmd} JOB=1:{nj} {dir}/log/create_valid_subset.JOB.log \ + nnet3-get-egs-multiple-targets {v_iv_opt} {v_egs_opt} "{v_feats}" \ + {targets} ark:{dir}/valid_all.JOB.egs""".format( + cmd=cmd, nj=num_jobs, dir=dir, + v_egs_opt=egs_opts['valid_egs_opts'], + v_iv_opt=feat_ivector_strings['valid_ivector_opts'], + v_feats=feat_ivector_strings['feats_subset_func']( + '{dir}/valid_uttlist.JOB'.format(dir=dir)), + targets=get_targets_list( + targets_parameters, + '{dir}/valid_uttlist.JOB'.format(dir=dir))), + wait=False) + + train_pid = common_lib.run_kaldi_command( + """{cmd} JOB=1:{nj} {dir}/log/create_train_subset.JOB.log \ + nnet3-get-egs-multiple-targets {t_iv_opt} {v_egs_opt} "{t_feats}" \ + {targets} ark:{dir}/train_subset_all.JOB.egs""".format( + cmd=cmd, nj=num_jobs, dir=dir, + v_egs_opt=egs_opts['valid_egs_opts'], + t_iv_opt=feat_ivector_strings['train_subset_ivector_opts'], + t_feats=feat_ivector_strings['feats_subset_func']( + '{dir}/train_subset_uttlist.JOB'.format(dir=dir)), + targets=get_targets_list( + targets_parameters, + '{dir}/train_subset_uttlist.JOB'.format(dir=dir))), + wait=False) + + wait_pids.append(valid_pid) + wait_pids.append(train_pid) + + for pid in wait_pids: + stdout, stderr = pid.communicate() + if pid.returncode != 0: + raise Exception(stderr) + + valid_egs_all = ' '.join(['{dir}/valid_all.{n}.egs'.format(dir=dir, n=n) + for n in range(1, num_jobs + 1)]) + train_subset_egs_all = ' '.join(['{dir}/train_subset_all.{n}.egs'.format( + dir=dir, n=n) + for n in range(1, num_jobs + 1)]) + + wait_pids = [] + logger.info("... Getting subsets of validation examples for diagnostics " + " and combination.") + pid = common_lib.run_kaldi_command( + """{cmd} {dir}/log/create_valid_subset_combine.log \ + cat {valid_egs_all} \| nnet3-subset-egs --n={nve_combine} ark:- \ + ark:{dir}/valid_combine.egs""".format( + cmd=cmd, dir=dir, valid_egs_all=valid_egs_all, + nve_combine=num_valid_egs_combine), + wait=False) + wait_pids.append(pid) + + pid = common_lib.run_kaldi_command( + """{cmd} {dir}/log/create_valid_subset_diagnostic.log \ + cat {valid_egs_all} \| nnet3-subset-egs --n={ne_diagnostic} ark:- \ + ark:{dir}/valid_diagnostic.egs""".format( + cmd=cmd, dir=dir, valid_egs_all=valid_egs_all, + ne_diagnostic=num_egs_diagnostic), + wait=False) + wait_pids.append(pid) + + pid = common_lib.run_kaldi_command( + """{cmd} {dir}/log/create_train_subset_combine.log \ + cat {train_subset_egs_all} \| \ + nnet3-subset-egs --n={nte_combine} ark:- \ + ark:{dir}/train_combine.egs""".format( + cmd=cmd, dir=dir, train_subset_egs_all=train_subset_egs_all, + nte_combine=num_train_egs_combine), + wait=False) + wait_pids.append(pid) + + pid = common_lib.run_kaldi_command( + """{cmd} {dir}/log/create_train_subset_diagnostic.log \ + cat {train_subset_egs_all} \| \ + nnet3-subset-egs --n={ne_diagnostic} ark:- \ + ark:{dir}/train_diagnostic.egs""".format( + cmd=cmd, dir=dir, train_subset_egs_all=train_subset_egs_all, + ne_diagnostic=num_egs_diagnostic), wait=False) + wait_pids.append(pid) + + for pid in wait_pids: + stdout, stderr = pid.communicate() + if pid.returncode != 0: + raise Exception(stderr) + + common_lib.run_kaldi_command( + """cat {dir}/valid_combine.egs {dir}/train_combine.egs > \ + {dir}/combine.egs""".format(dir=dir)) + + # perform checks + for file_name in ('{0}/combine.egs {0}/train_diagnostic.egs ' + '{0}/valid_diagnostic.egs'.format(dir).split()): + if os.path.getsize(file_name) == 
0: + raise Exception("No examples in {0}".format(file_name)) + + # clean-up + for x in ('{0}/valid_all.*.egs {0}/train_subset_all.*.egs ' + '{0}/train_combine.egs ' + '{0}/valid_combine.egs'.format(dir).split()): + for file_name in glob.glob(x): + os.remove(file_name) + + +def generate_training_examples_internal(dir, targets_parameters, feat_dir, + train_feats_string, + train_egs_opts_string, + ivector_opts, + num_jobs, frames_per_eg, + samples_per_iter, cmd, srand=0, + reduce_frames_per_eg=True, + only_shuffle=False, + dry_run=False): + + # The examples will go round-robin to egs_list. Note: we omit the + # 'normalization.fst' argument while creating temporary egs: the phase of + # egs preparation that involves the normalization FST is quite + # CPU-intensive and it's more convenient to do it later, in the 'shuffle' + # stage. Otherwise to make it efficient we need to use a large 'nj', like + # 40, and in that case there can be too many small files to deal with, + # because the total number of files is the product of 'nj' by + # 'num_archives_intermediate', which might be quite large. + num_frames = data_lib.get_num_frames(feat_dir) + num_archives = (num_frames) / (frames_per_eg * samples_per_iter) + 1 + + reduced = False + while (reduce_frames_per_eg and frames_per_eg > 1 and + num_frames / ((frames_per_eg-1)*samples_per_iter) == 0): + frames_per_eg -= 1 + num_archives = 1 + reduced = True + + if reduced: + logger.info("Reduced frames-per-eg to {0} " + "because amount of data is small".format(frames_per_eg)) + + max_open_files = get_max_open_files() + num_archives_intermediate = num_archives + archives_multiple = 1 + while (num_archives_intermediate+4) > max_open_files: + archives_multiple = archives_multiple + 1 + num_archives_intermediate = int(math.ceil(float(num_archives) + / archives_multiple)) + num_archives = num_archives_intermediate * archives_multiple + egs_per_archive = num_frames/(frames_per_eg * num_archives) + + if egs_per_archive > samples_per_iter: + raise Exception( + """egs_per_archive({epa}) > samples_per_iter({fpi}). 
+ This is an error in the logic for determining + egs_per_archive""".format(epa=egs_per_archive, + fpi=samples_per_iter)) + + if dry_run: + cleanup(dir, archives_multiple) + return {'num_frames': num_frames, + 'num_archives': num_archives, + 'egs_per_archive': egs_per_archive} + + logger.info("Splitting a total of {nf} frames into {na} archives, " + "each with {epa} egs.".format(nf=num_frames, na=num_archives, + epa=egs_per_archive)) + + if os.path.isdir('{0}/storage'.format(dir)): + # this is a striped directory, so create the softlinks + data_lib.create_data_links(["{dir}/egs.{x}.ark".format(dir=dir, x=x) + for x in range(1, num_archives + 1)]) + for x in range(1, num_archives_intermediate + 1): + data_lib.create_data_links( + ["{dir}/egs_orig.{y}.{x}.ark".format(dir=dir, x=x, y=y) + for y in range(1, num_jobs + 1)]) + + split_feat_dir = "{0}/split{1}".format(feat_dir, num_jobs) + egs_list = ' '.join(['ark:{dir}/egs_orig.JOB.{ark_num}.ark'.format( + dir=dir, ark_num=x) + for x in range(1, num_archives_intermediate + 1)]) + + if not only_shuffle: + common_lib.run_kaldi_command( + """{cmd} JOB=1:{nj} {dir}/log/get_egs.JOB.log \ + nnet3-get-egs-multiple-targets {iv_opts} {egs_opts} \ + "{feats}" {targets} ark:- \| \ + nnet3-copy-egs --random=true --srand=$[JOB+{srand}] \ + ark:- {egs_list}""".format( + cmd=cmd, nj=num_jobs, dir=dir, srand=srand, + iv_opts=ivector_opts, egs_opts=train_egs_opts_string, + feats=train_feats_string, + targets=get_targets_list(targets_parameters, + '{sdir}/JOB/utt2spk'.format( + sdir=split_feat_dir)), + egs_list=egs_list)) + + logger.info("Recombining and shuffling order of archives on disk") + egs_list = ' '.join(['{dir}/egs_orig.{n}.JOB.ark'.format(dir=dir, n=x) + for x in range(1, num_jobs + 1)]) + + if archives_multiple == 1: + # there are no intermediate archives so just shuffle egs across + # jobs and dump them into a single output + common_lib.run_kaldi_command( + """{cmd} --max-jobs-run {msjr} JOB=1:{nai} \ + {dir}/log/shuffle.JOB.log \ + nnet3-shuffle-egs --srand=$[JOB+{srand}] \ + "ark:cat {egs_list}|" ark:{dir}/egs.JOB.ark""".format( + cmd=cmd, msjr=num_jobs, + nai=num_archives_intermediate, srand=srand, + dir=dir, egs_list=egs_list)) + else: + # there are intermediate archives so we shuffle egs across jobs + # and split them into archives_multiple output archives + output_archives = ' '.join(["ark:{dir}/egs.JOB.{ark_num}.ark".format( + dir=dir, ark_num=x) + for x in range(1, archives_multiple + 1)]) + # archives were created as egs.x.y.ark + # linking them to egs.i.ark format which is expected by the training + # scripts + for i in range(1, num_archives_intermediate + 1): + for j in range(1, archives_multiple + 1): + archive_index = (i-1) * archives_multiple + j + common_lib.force_sym_link( + "egs.{0}.ark".format(archive_index), + "{dir}/egs.{i}.{j}.ark".format(dir=dir, i=i, j=j)) + + common_lib.run_kaldi_command( + """{cmd} --max-jobs-run {msjr} JOB=1:{nai} \ + {dir}/log/shuffle.JOB.log \ + nnet3-shuffle-egs --srand=$[JOB+{srand}] \ + "ark:cat {egs_list}|" ark:- \| \ + nnet3-copy-egs ark:- {oarks}""".format( + cmd=cmd, msjr=num_jobs, + nai=num_archives_intermediate, srand=srand, + dir=dir, egs_list=egs_list, oarks=output_archives)) + + cleanup(dir, archives_multiple) + return {'num_frames': num_frames, + 'num_archives': num_archives, + 'egs_per_archive': egs_per_archive} + + +def cleanup(dir, archives_multiple): + logger.info("Removing temporary archives in {0}.".format(dir)) + for file_name in glob.glob("{0}/egs_orig*".format(dir)): + real_path = 
os.path.realpath(file_name) + data_lib.try_to_delete(real_path) + data_lib.try_to_delete(file_name) + + if archives_multiple > 1: + # there will be some extra soft links we want to delete + for file_name in glob.glob('{0}/egs.*.*.ark'.format(dir)): + os.remove(file_name) + + +def create_directory(dir): + import errno + try: + os.makedirs(dir) + except OSError, e: + if e.errno == errno.EEXIST: + pass + + +def generate_training_examples(dir, targets_parameters, feat_dir, + feat_ivector_strings, egs_opts, + frame_shift, frames_per_eg, samples_per_iter, + cmd, num_jobs, srand=0, + only_shuffle=False, dry_run=False): + + # generate the training options string with the given chunk_width + train_egs_opts = egs_opts['train_egs_opts'] + # generate the feature vector string with the utt list for the + # current chunk width + train_feats = feat_ivector_strings['train_feats'] + + if os.path.isdir('{0}/storage'.format(dir)): + real_paths = [os.path.realpath(x).strip("/") + for x in glob.glob('{0}/storage/*'.format(dir))] + common_lib.run_kaldi_command( + """utils/create_split_dir.pl {target_dirs} \ + {dir}/storage""".format( + target_dirs=" ".join(real_paths), dir=dir)) + + info = generate_training_examples_internal( + dir=dir, targets_parameters=targets_parameters, + feat_dir=feat_dir, train_feats_string=train_feats, + train_egs_opts_string=train_egs_opts, + ivector_opts=feat_ivector_strings['ivector_opts'], + num_jobs=num_jobs, frames_per_eg=frames_per_eg, + samples_per_iter=samples_per_iter, cmd=cmd, + srand=srand, + only_shuffle=only_shuffle, + dry_run=dry_run) + + return info + + +def write_egs_info(info, info_dir): + for x in ['num_frames', 'num_archives', 'egs_per_archive', + 'feat_dim', 'ivector_dim', + 'left_context', 'right_context', 'frames_per_eg']: + write_list([info['{0}'.format(x)]], '{0}/{1}'.format(info_dir, x)) + + +def generate_egs(egs_dir, feat_dir, targets_para_array, + online_ivector_dir=None, + frames_per_eg=8, + left_context=4, + right_context=4, + valid_left_context=None, + valid_right_context=None, + cmd="run.pl", stage=0, + cmvn_opts=None, apply_cmvn_sliding=False, + compress_input=True, + input_compress_format=0, + num_utts_subset=300, + num_train_egs_combine=1000, + num_valid_egs_combine=0, + num_egs_diagnostic=4000, + samples_per_iter=400000, + num_jobs=6, + srand=0): + + for directory in '{0}/log {0}/info'.format(egs_dir).split(): + create_directory(directory) + + print (cmvn_opts if cmvn_opts is not None else '', + file=open('{0}/cmvn_opts'.format(egs_dir), 'w')) + print ("true" if apply_cmvn_sliding else "false", + file=open('{0}/apply_cmvn_sliding'.format(egs_dir), 'w')) + + targets_parameters = parse_targets_parameters_array(targets_para_array) + + # Check files + check_for_required_files(feat_dir, + [t.targets_scp for t in targets_parameters], + online_ivector_dir) + + frame_shift = data_lib.get_frame_shift(feat_dir) + min_duration = frames_per_eg * frame_shift + valid_utts = sample_utts(feat_dir, num_utts_subset, min_duration)[0] + train_subset_utts = sample_utts(feat_dir, num_utts_subset, min_duration, + exclude_list=valid_utts)[0] + train_utts, train_utts_durs = sample_utts(feat_dir, None, -1, + exclude_list=valid_utts) + + write_list(valid_utts, '{0}/valid_uttlist'.format(egs_dir)) + write_list(train_subset_utts, '{0}/train_subset_uttlist'.format(egs_dir)) + write_list(train_utts, '{0}/train_uttlist'.format(egs_dir)) + + # split the training data into parts for individual jobs + # we will use the same number of jobs as that used for alignment + split_feat_dir = 
common_lib.split_data(feat_dir, num_jobs) + feat_ivector_strings = get_feat_ivector_strings( + dir=egs_dir, feat_dir=feat_dir, split_feat_dir=split_feat_dir, + cmvn_opt_string=cmvn_opts, + ivector_dir=online_ivector_dir, + apply_cmvn_sliding=apply_cmvn_sliding) + + egs_opts = get_egs_options(targets_parameters=targets_parameters, + frames_per_eg=frames_per_eg, + left_context=left_context, + right_context=right_context, + valid_left_context=valid_left_context, + valid_right_context=valid_right_context, + compress_input=compress_input, + input_compress_format=input_compress_format) + + if stage <= 2: + logger.info("Generating validation and training subset examples") + + generate_valid_train_subset_egs( + dir=egs_dir, + targets_parameters=targets_parameters, + feat_ivector_strings=feat_ivector_strings, + egs_opts=egs_opts, + num_train_egs_combine=num_train_egs_combine, + num_valid_egs_combine=num_valid_egs_combine, + num_egs_diagnostic=num_egs_diagnostic, + cmd=cmd, + num_jobs=num_jobs) + + logger.info("Generating training examples on disk.") + info = generate_training_examples( + dir=egs_dir, + targets_parameters=targets_parameters, + feat_dir=feat_dir, + feat_ivector_strings=feat_ivector_strings, + egs_opts=egs_opts, + frame_shift=frame_shift, + frames_per_eg=frames_per_eg, + samples_per_iter=samples_per_iter, + cmd=cmd, + num_jobs=num_jobs, + srand=srand, + only_shuffle=True if stage > 3 else False, + dry_run=True if stage > 4 else False) + + info['feat_dim'] = feat_ivector_strings['feat_dim'] + info['ivector_dim'] = feat_ivector_strings['ivector_dim'] + info['left_context'] = left_context + info['right_context'] = right_context + info['frames_per_eg'] = frames_per_eg + + write_egs_info(info, '{dir}/info'.format(dir=egs_dir)) + + +def main(): + args = get_args() + generate_egs(args.dir, args.feat_dir, args.targets_para_array, + online_ivector_dir=args.online_ivector_dir, + frames_per_eg=args.frames_per_eg, + left_context=args.left_context, + right_context=args.right_context, + valid_left_context=args.valid_left_context, + valid_right_context=args.valid_right_context, + cmd=args.cmd, stage=args.stage, + cmvn_opts=args.cmvn_opts, + apply_cmvn_sliding=args.apply_cmvn_sliding, + compress_input=args.compress_input, + input_compress_format=args.input_compress_format, + num_utts_subset=args.num_utts_subset, + num_train_egs_combine=args.num_train_egs_combine, + num_valid_egs_combine=args.num_valid_egs_combine, + num_egs_diagnostic=args.num_egs_diagnostic, + samples_per_iter=args.samples_per_iter, + num_jobs=args.num_jobs, + srand=args.srand) + + +if __name__ == "__main__": + main() diff --git a/egs/wsj/s5/steps/nnet3/get_egs_targets.sh b/egs/wsj/s5/steps/nnet3/get_egs_targets.sh index 7fbc24858b5..cfecf88df38 100755 --- a/egs/wsj/s5/steps/nnet3/get_egs_targets.sh +++ b/egs/wsj/s5/steps/nnet3/get_egs_targets.sh @@ -24,6 +24,8 @@ feat_type=raw # set it to 'lda' to use LDA features. target_type=sparse # dense to have dense targets, # sparse to have posteriors targets num_targets= # required for target-type=sparse with raw nnet +deriv_weights_scp= +l2_regularizer_targets= frames_per_eg=8 # number of frames of labels per example. more->less disk space and # less time preparing egs, but more I/O during training. # note: the script may reduce this if reduce_frames_per_eg is true. @@ -44,6 +46,12 @@ reduce_frames_per_eg=true # If true, this script may reduce the frames_per_eg # equal to the user-specified value. 
num_utts_subset=300 # number of utterances in validation and training # subsets used for shrinkage and diagnostics. +num_utts_subset_valid= # number of utterances in validation + # subsets used for shrinkage and diagnostics + # if provided, overrides num-utts-subset +num_utts_subset_train= # number of utterances in training + # subsets used for shrinkage and diagnostics. + # if provided, overrides num-utts-subset num_valid_frames_combine=0 # #valid frames for combination weights at the very end. num_train_frames_combine=10000 # # train frames for the above. num_frames_diagnostic=4000 # number of frames for "compute_prob" jobs @@ -59,6 +67,7 @@ stage=0 nj=6 # This should be set to the maximum number of jobs you are # comfortable to run in parallel; you can increase it if your disk # speed is greater and you have more machines. +srand=0 # rand seed for nnet3-copy-egs and nnet3-shuffle-egs online_ivector_dir= # can be used if we are including speaker information as iVectors. cmvn_opts= # can be used for specifying CMVN options, if feature type is not lda (if lda, # it doesn't make sense to use different options than were used as input to the @@ -111,9 +120,18 @@ utils/split_data.sh $data $nj mkdir -p $dir/log $dir/info +[ -z "$num_utts_subset_valid" ] && num_utts_subset_valid=$num_utts_subset +[ -z "$num_utts_subset_train" ] && num_utts_subset_train=$num_utts_subset + +num_utts=$(cat $data/utt2spk | wc -l) +if ! [ $num_utts -gt $[$num_utts_subset_valid*4] ]; then + echo "$0: number of utterances $num_utts in your training data is too small versus --num-utts-subset=$num_utts_subset" + echo "... you probably have so little data that it doesn't make sense to train a neural net." + exit 1 +fi # Get list of validation utterances. -awk '{print $1}' $data/utt2spk | utils/shuffle_list.pl | head -$num_utts_subset | sort \ +awk '{print $1}' $data/utt2spk | utils/shuffle_list.pl | head -$num_utts_subset_valid | sort \ > $dir/valid_uttlist || exit 1; if [ -f $data/utt2uniq ]; then # this matters if you use data augmentation. @@ -128,7 +146,7 @@ if [ -f $data/utt2uniq ]; then # this matters if you use data augmentation. fi awk '{print $1}' $data/utt2spk | utils/filter_scp.pl --exclude $dir/valid_uttlist | \ - utils/shuffle_list.pl | head -$num_utts_subset | sort > $dir/train_subset_uttlist || exit 1; + utils/shuffle_list.pl | head -$num_utts_subset_train > $dir/train_subset_uttlist || exit 1; if [ ! -z "$transform_dir" ] && [ -f $transform_dir/trans.1 ] && [ $feat_type != "raw" ]; then echo "$0: using transforms from $transform_dir" @@ -145,15 +163,33 @@ if [ -f $transform_dir/raw_trans.1 ] && [ $feat_type == "raw" ]; then fi fi +nj_subset=$nj +if [ $nj_subset -gt `cat $dir/train_subset_uttlist | wc -l` ]; then + nj_subset=`cat $dir/train_subset_uttlist | wc -l` +fi + +if [ $nj_subset -gt `cat $dir/valid_uttlist | wc -l` ]; then + nj_subset=`cat $dir/valid_uttlist | wc -l` +fi + +valid_uttlist_all= +train_subset_uttlist_all= +for n in `seq $nj_subset`; do + valid_uttlist_all="$valid_uttlist_all $dir/valid_uttlist.$n" + train_subset_uttlist_all="$train_subset_uttlist_all $dir/train_subset_uttlist.$n" +done + +utils/split_scp.pl $dir/valid_uttlist $valid_uttlist_all +utils/split_scp.pl $dir/train_subset_uttlist $train_subset_uttlist_all ## Set up features. 
echo "$0: feature type is $feat_type" case $feat_type in raw) feats="ark,s,cs:utils/filter_scp.pl --exclude $dir/valid_uttlist $sdata/JOB/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:- ark:- |" - valid_feats="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- |" - train_subset_feats="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- |" + valid_feats="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist.JOB $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- |" + train_subset_feats="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist.JOB $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- |" echo $cmvn_opts >$dir/cmvn_opts # caution: the top-level nnet training script should copy this to its own dir now. ;; lda) @@ -164,8 +200,8 @@ case $feat_type in echo "You cannot supply --cmvn-opts option if feature type is LDA." && exit 1; cmvn_opts=$(cat $dir/cmvn_opts) feats="ark,s,cs:utils/filter_scp.pl --exclude $dir/valid_uttlist $sdata/JOB/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$sdata/JOB/utt2spk scp:$sdata/JOB/cmvn.scp scp:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |" - valid_feats="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |" - train_subset_feats="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |" + valid_feats="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist.JOB $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |" + train_subset_feats="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist.JOB $data/feats.scp | apply-cmvn $cmvn_opts --utt2spk=ark:$data/utt2spk scp:$data/cmvn.scp scp:- ark:- | splice-feats $splice_opts ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |" ;; *) echo "$0: invalid feature type --feat-type '$feat_type'" && exit 1; esac @@ -182,8 +218,8 @@ if [ ! 
-z "$online_ivector_dir" ]; then ivector_period=$(cat $online_ivector_dir/ivector_period) || exit 1; ivector_opt="--ivectors='ark,s,cs:utils/filter_scp.pl $sdata/JOB/utt2spk $online_ivector_dir/ivector_online.scp | subsample-feats --n=-$ivector_period scp:- ark:- |'" - valid_ivector_opt="--ivectors='ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist $online_ivector_dir/ivector_online.scp | subsample-feats --n=-$ivector_period scp:- ark:- |'" - train_subset_ivector_opt="--ivectors='ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist $online_ivector_dir/ivector_online.scp | subsample-feats --n=-$ivector_period scp:- ark:- |'" + valid_ivector_opt="--ivectors='ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist.JOB $online_ivector_dir/ivector_online.scp | subsample-feats --n=-$ivector_period scp:- ark:- |'" + train_subset_ivector_opt="--ivectors='ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist.JOB $online_ivector_dir/ivector_online.scp | subsample-feats --n=-$ivector_period scp:- ark:- |'" else echo 0 >$dir/info/ivector_dim fi @@ -255,9 +291,13 @@ fi egs_opts="--left-context=$left_context --right-context=$right_context --compress=$compress" +[ ! -z "$deriv_weights_scp" ] && egs_opts="$egs_opts --deriv-weights-rspecifier=scp:$deriv_weights_scp" +[ ! -z "$l2_regularizer_targets" ] && egs_opts="$egs_opts --l2reg-targets-rspecifier=scp:$l2_regularizer_targets" + [ -z $valid_left_context ] && valid_left_context=$left_context; [ -z $valid_right_context ] && valid_right_context=$right_context; valid_egs_opts="--left-context=$valid_left_context --right-context=$valid_right_context --compress=$compress" +[ ! -z "$deriv_weights_scp" ] && valid_egs_opts="$valid_egs_opts --deriv-weights-rspecifier=scp:$deriv_weights_scp" echo $left_context > $dir/info/left_context echo $right_context > $dir/info/right_context @@ -281,15 +321,15 @@ case $target_type in "dense") get_egs_program="nnet3-get-egs-dense-targets --num-targets=$num_targets" - targets="ark:utils/filter_scp.pl --exclude $dir/valid_uttlist $targets_scp_split | copy-feats scp:- ark:- |" - valid_targets="ark:utils/filter_scp.pl $dir/valid_uttlist $targets_scp | copy-feats scp:- ark:- |" - train_subset_targets="ark:utils/filter_scp.pl $dir/train_subset_uttlist $targets_scp | copy-feats scp:- ark:- |" + targets="ark,s,cs:utils/filter_scp.pl --exclude $dir/valid_uttlist $targets_scp_split | copy-feats scp:- ark:- |" + valid_targets="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist.JOB $targets_scp | copy-feats scp:- ark:- |" + train_subset_targets="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist.JOB $targets_scp | copy-feats scp:- ark:- |" ;; "sparse") get_egs_program="nnet3-get-egs --num-pdfs=$num_targets" - targets="ark:utils/filter_scp.pl --exclude $dir/valid_uttlist $targets_scp_split | ali-to-post scp:- ark:- |" - valid_targets="ark:utils/filter_scp.pl $dir/valid_uttlist $targets_scp | ali-to-post scp:- ark:- |" - train_subset_targets="ark:utils/filter_scp.pl $dir/train_subset_uttlist $targets_scp | ali-to-post scp:- ark:- |" + targets="ark,s,cs:utils/filter_scp.pl --exclude $dir/valid_uttlist $targets_scp_split | ali-to-post scp:- ark:- |" + valid_targets="ark,s,cs:utils/filter_scp.pl $dir/valid_uttlist.JOB $targets_scp | ali-to-post scp:- ark:- |" + train_subset_targets="ark,s,cs:utils/filter_scp.pl $dir/train_subset_uttlist.JOB $targets_scp | ali-to-post scp:- ark:- |" ;; default) echo "$0: Unknown --target-type $target_type. 
Choices are dense and sparse" @@ -299,31 +339,43 @@ esac if [ $stage -le 3 ]; then echo "$0: Getting validation and training subset examples." rm -f $dir/.error 2>/dev/null - $cmd $dir/log/create_valid_subset.log \ + $cmd JOB=1:$nj_subset $dir/log/create_valid_subset.JOB.log \ $get_egs_program \ $valid_ivector_opt $valid_egs_opts "$valid_feats" \ "$valid_targets" \ - "ark:$dir/valid_all.egs" || touch $dir/.error & - $cmd $dir/log/create_train_subset.log \ + "ark:$dir/valid_all.JOB.egs" || touch $dir/.error & + $cmd JOB=1:$nj_subset $dir/log/create_train_subset.JOB.log \ $get_egs_program \ $train_subset_ivector_opt $valid_egs_opts "$train_subset_feats" \ "$train_subset_targets" \ - "ark:$dir/train_subset_all.egs" || touch $dir/.error & + "ark:$dir/train_subset_all.JOB.egs" || touch $dir/.error & wait; + + valid_egs_all= + train_subset_egs_all= + for n in `seq $nj_subset`; do + valid_egs_all="$valid_egs_all $dir/valid_all.$n.egs" + train_subset_egs_all="$train_subset_egs_all $dir/train_subset_all.$n.egs" + done + [ -f $dir/.error ] && echo "Error detected while creating train/valid egs" && exit 1 echo "... Getting subsets of validation examples for diagnostics and combination." $cmd $dir/log/create_valid_subset_combine.log \ - nnet3-subset-egs --n=$num_valid_frames_combine ark:$dir/valid_all.egs \ + cat $valid_egs_all \| \ + nnet3-subset-egs --n=$num_valid_frames_combine ark:- \ ark:$dir/valid_combine.egs || touch $dir/.error & $cmd $dir/log/create_valid_subset_diagnostic.log \ - nnet3-subset-egs --n=$num_frames_diagnostic ark:$dir/valid_all.egs \ + cat $valid_egs_all \| \ + nnet3-subset-egs --n=$num_frames_diagnostic ark:- \ ark:$dir/valid_diagnostic.egs || touch $dir/.error & $cmd $dir/log/create_train_subset_combine.log \ - nnet3-subset-egs --n=$num_train_frames_combine ark:$dir/train_subset_all.egs \ + cat $train_subset_egs_all \| \ + nnet3-subset-egs --n=$num_train_frames_combine ark:- \ ark:$dir/train_combine.egs || touch $dir/.error & $cmd $dir/log/create_train_subset_diagnostic.log \ - nnet3-subset-egs --n=$num_frames_diagnostic ark:$dir/train_subset_all.egs \ + cat $train_subset_egs_all \| \ + nnet3-subset-egs --n=$num_frames_diagnostic ark:- \ ark:$dir/train_diagnostic.egs || touch $dir/.error & wait sleep 5 # wait for file system to sync. @@ -332,7 +384,7 @@ if [ $stage -le 3 ]; then for f in $dir/{combine,train_diagnostic,valid_diagnostic}.egs; do [ ! -s $f ] && echo "No examples in file $f" && exit 1; done - rm -f $dir/valid_all.egs $dir/train_subset_all.egs $dir/{train,valid}_combine.egs + rm $dir/valid_all.*.egs $dir/train_subset_all.*.egs $dir/{train,valid}_combine.egs fi if [ $stage -le 4 ]; then @@ -349,7 +401,7 @@ if [ $stage -le 4 ]; then $get_egs_program \ $ivector_opt $egs_opts --num-frames=$frames_per_eg "$feats" "$targets" \ ark:- \| \ - nnet3-copy-egs --random=true --srand=JOB ark:- $egs_list || exit 1; + nnet3-copy-egs --random=true --srand=\$[JOB+$srand] ark:- $egs_list || exit 1; fi if [ $stage -le 5 ]; then @@ -365,7 +417,7 @@ if [ $stage -le 5 ]; then if [ $archives_multiple == 1 ]; then # normal case. $cmd --max-jobs-run $nj JOB=1:$num_archives_intermediate $dir/log/shuffle.JOB.log \ - nnet3-shuffle-egs --srand=JOB "ark:cat $egs_list|" ark:$dir/egs.JOB.ark || exit 1; + nnet3-shuffle-egs --srand=\$[JOB+$srand] "ark:cat $egs_list|" ark:$dir/egs.JOB.ark || exit 1; else # we need to shuffle the 'intermediate archives' and then split into the # final archives. 
we create soft links to manage this splitting, because @@ -381,12 +433,14 @@ if [ $stage -le 5 ]; then done done $cmd --max-jobs-run $nj JOB=1:$num_archives_intermediate $dir/log/shuffle.JOB.log \ - nnet3-shuffle-egs --srand=JOB "ark:cat $egs_list|" ark:- \| \ + nnet3-shuffle-egs --srand=\$[JOB+$srand] "ark:cat $egs_list|" ark:- \| \ nnet3-copy-egs ark:- $output_archives || exit 1; fi fi +wait + if [ $stage -le 6 ]; then echo "$0: removing temporary archives" for x in $(seq $nj); do @@ -400,10 +454,11 @@ if [ $stage -le 6 ]; then # there are some extra soft links that we should delete. for f in $dir/egs.*.*.ark; do rm $f; done fi - echo "$0: removing temporary" + echo "$0: removing temporary stuff" # Ignore errors below because trans.* might not exist. rm -f $dir/trans.{ark,scp} $dir/targets.*.scp 2>/dev/null fi -echo "$0: Finished preparing training examples" +wait +echo "$0: Finished preparing training examples" diff --git a/egs/wsj/s5/steps/nnet3/lstm/make_configs.py b/egs/wsj/s5/steps/nnet3/lstm/make_configs.py index 205b6034fad..9fb9fad1d0c 100755 --- a/egs/wsj/s5/steps/nnet3/lstm/make_configs.py +++ b/egs/wsj/s5/steps/nnet3/lstm/make_configs.py @@ -56,6 +56,18 @@ def GetArgs(): parser.add_argument("--max-change-per-component-final", type=float, help="Enforces per-component max change for the final affine layer. " "if 0 it would not be enforced.", default=1.5) + parser.add_argument("--add-lda", type=str, action=nnet3_train_lib.StrToBoolAction, + help="If \"true\" an LDA matrix computed from the input features " + "(spliced according to the first set of splice-indexes) will be used as " + "the first Affine layer. This affine layer's parameters are fixed during training. " + "This variable needs to be set to \"false\" when using dense-targets.", + default=True, choices = ["false", "true"]) + parser.add_argument("--add-final-sigmoid", type=str, action=nnet3_train_lib.StrToBoolAction, + help="add a sigmoid layer as the final layer. Applicable only if skip-final-softmax is true.", + choices=['true', 'false'], default = False) + parser.add_argument("--objective-type", type=str, default="linear", + choices = ["linear", "quadratic", "xent"], + help = "the type of objective; i.e. quadratic or linear or cross-entropy per dim") # LSTM options parser.add_argument("--num-lstm-layers", type=int, @@ -217,7 +229,9 @@ def ParseLstmDelayString(lstm_delay): raise ValueError("invalid --lstm-delay argument, too-short element: " + lstm_delay) elif len(indexes) == 2 and indexes[0] * indexes[1] >= 0: - raise ValueError('Warning: ' + str(indexes) + ' is not a standard BLSTM mode. There should be a negative delay for the forward, and a postive delay for the backward.') + raise ValueError('Warning: ' + str(indexes) + + ' is not a standard BLSTM mode. 
' + + 'There should be a negative delay for the forward, and a postive delay for the backward.') if len(indexes) == 2 and indexes[0] > 0: # always a negative delay followed by a postive delay indexes[0], indexes[1] = indexes[1], indexes[0] lstm_delay_array.append(indexes) @@ -227,29 +241,35 @@ def ParseLstmDelayString(lstm_delay): return lstm_delay_array -def MakeConfigs(config_dir, feat_dim, ivector_dim, num_targets, +def MakeConfigs(config_dir, feat_dim, ivector_dim, num_targets, add_lda, splice_indexes, lstm_delay, cell_dim, hidden_dim, recurrent_projection_dim, non_recurrent_projection_dim, num_lstm_layers, num_hidden_layers, norm_based_clipping, clipping_threshold, zeroing_threshold, zeroing_interval, ng_per_element_scale_options, ng_affine_options, - label_delay, include_log_softmax, xent_regularize, + label_delay, include_log_softmax, add_final_sigmoid, + objective_type, xent_regularize, self_repair_scale_nonlinearity, self_repair_scale_clipgradient, max_change_per_component, max_change_per_component_final): config_lines = {'components':[], 'component-nodes':[]} config_files={} - prev_layer_output = nodes.AddInputLayer(config_lines, feat_dim, splice_indexes[0], ivector_dim) + prev_layer_output = nodes.AddInputLayer(config_lines, feat_dim, splice_indexes[0], + ivector_dim) # Add the init config lines for estimating the preconditioning matrices init_config_lines = copy.deepcopy(config_lines) init_config_lines['components'].insert(0, '# Config file for initializing neural network prior to') init_config_lines['components'].insert(0, '# preconditioning matrix computation') - nodes.AddOutputLayer(init_config_lines, prev_layer_output) + nodes.AddOutputLayer(init_config_lines, prev_layer_output, label_delay = label_delay, objective_type = objective_type) config_files[config_dir + '/init.config'] = init_config_lines - prev_layer_output = nodes.AddLdaLayer(config_lines, "L0", prev_layer_output, config_dir + '/lda.mat') + # add_lda needs to be set "false" when using dense targets, + # or if the task is not a simple classification task + # (e.g. 
regression, multi-task) + if add_lda: + prev_layer_output = nodes.AddLdaLayer(config_lines, "L0", prev_layer_output, args.config_dir + '/lda.mat') for i in range(num_lstm_layers): if len(lstm_delay[i]) == 2: # add a bi-directional LSTM layer @@ -284,7 +304,7 @@ def MakeConfigs(config_dir, feat_dim, ivector_dim, num_targets, max_change_per_component = max_change_per_component) # make the intermediate config file for layerwise discriminative # training - nodes.AddFinalLayer(config_lines, prev_layer_output, num_targets, ng_affine_options, max_change_per_component = max_change_per_component_final, label_delay = label_delay, include_log_softmax = include_log_softmax) + nodes.AddFinalLayer(config_lines, prev_layer_output, num_targets, ng_affine_options, max_change_per_component = max_change_per_component_final, label_delay = label_delay, include_log_softmax = include_log_softmax, add_final_sigmoid = add_final_sigmoid, objective_type = objective_type) if xent_regularize != 0.0: @@ -302,7 +322,7 @@ def MakeConfigs(config_dir, feat_dim, ivector_dim, num_targets, ng_affine_options, self_repair_scale = self_repair_scale_nonlinearity, max_change_per_component = max_change_per_component) # make the intermediate config file for layerwise discriminative # training - nodes.AddFinalLayer(config_lines, prev_layer_output, num_targets, ng_affine_options, max_change_per_component = max_change_per_component_final, label_delay = label_delay, include_log_softmax = include_log_softmax) + nodes.AddFinalLayer(config_lines, prev_layer_output, num_targets, ng_affine_options, max_change_per_component = max_change_per_component_final, label_delay = label_delay, include_log_softmax = include_log_softmax, add_final_sigmoid = add_final_sigmoid, objective_type = objective_type) if xent_regularize != 0.0: nodes.AddFinalLayer(config_lines, prev_layer_output, num_targets, @@ -331,24 +351,30 @@ def ProcessSpliceIndexes(config_dir, splice_indexes, label_delay, num_lstm_layer if (num_hidden_layers < num_lstm_layers): raise Exception("num-lstm-layers : number of lstm layers has to be greater than number of layers, decided based on splice-indexes") - # write the files used by other scripts like steps/nnet3/get_egs.sh - f = open(config_dir + "/vars", "w") - print('model_left_context=' + str(left_context), file=f) - print('model_right_context=' + str(right_context), file=f) - print('num_hidden_layers=' + str(num_hidden_layers), file=f) - # print('initial_right_context=' + str(splice_array[0][-1]), file=f) - f.close() - return [left_context, right_context, num_hidden_layers, splice_indexes] def Main(): args = GetArgs() - [left_context, right_context, num_hidden_layers, splice_indexes] = ProcessSpliceIndexes(args.config_dir, args.splice_indexes, args.label_delay, args.num_lstm_layers) + [left_context, right_context, + num_hidden_layers, splice_indexes] = ProcessSpliceIndexes(args.config_dir, args.splice_indexes, + args.label_delay, args.num_lstm_layers) + + # write the files used by other scripts like steps/nnet3/get_egs.sh + f = open(args.config_dir + "/vars", "w") + print('model_left_context=' + str(left_context), file=f) + print('model_right_context=' + str(right_context), file=f) + print('num_hidden_layers=' + str(num_hidden_layers), file=f) + print('num_targets=' + str(args.num_targets), file=f) + print('objective_type=' + str(args.objective_type), file=f) + print('add_lda=' + ("true" if args.add_lda else "false"), file=f) + print('include_log_softmax=' + ("true" if args.include_log_softmax else "false"), file=f) + f.close() 
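Editor's note: the Main() change above moves the writing of the configs/vars file into the top-level script and adds num_targets, objective_type, add_lda and include_log_softmax to it; train_raw_rnn.py (later in this patch) reads those keys back and falls back to a default when include_log_softmax is missing. A minimal sketch of such a reader, assuming the file keeps the key=value format shown above; the path and helper name are illustrative only.

import os

def read_vars(path):
    # Parse key=value lines of the 'vars' file into a dict, skipping blanks
    # and comments.
    variables = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            key, value = line.split("=", 1)
            variables[key] = value
    return variables

path = "exp/nnet3/lstm/configs/vars"  # illustrative path
if os.path.exists(path):
    v = read_vars(path)
    add_lda = v.get("add_lda", "true") == "true"
    # Missing keys fall back to a default, analogous to the
    # include_log_softmax handling added to train_raw_rnn.py below.
    include_log_softmax = v.get("include_log_softmax", "false") == "true"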
MakeConfigs(config_dir = args.config_dir, feat_dim = args.feat_dim, ivector_dim = args.ivector_dim, num_targets = args.num_targets, + add_lda = args.add_lda, splice_indexes = splice_indexes, lstm_delay = args.lstm_delay, cell_dim = args.cell_dim, hidden_dim = args.hidden_dim, @@ -364,6 +390,8 @@ def Main(): ng_affine_options = args.ng_affine_options, label_delay = args.label_delay, include_log_softmax = args.include_log_softmax, + add_final_sigmoid = args.add_final_sigmoid, + objective_type = args.objective_type, xent_regularize = args.xent_regularize, self_repair_scale_nonlinearity = args.self_repair_scale_nonlinearity, self_repair_scale_clipgradient = args.self_repair_scale_clipgradient, diff --git a/egs/wsj/s5/steps/nnet3/train_raw_dnn.py b/egs/wsj/s5/steps/nnet3/train_raw_dnn.py index b67ba8792a8..d7651889d83 100755 --- a/egs/wsj/s5/steps/nnet3/train_raw_dnn.py +++ b/egs/wsj/s5/steps/nnet3/train_raw_dnn.py @@ -53,6 +53,9 @@ def get_args(): parser.add_argument("--egs.frames-per-eg", type=int, dest='frames_per_eg', default=8, help="Number of output labels per example") + parser.add_argument("--egs.extra-copy-cmd", type=str, + dest='extra_egs_copy_cmd', default = "", + help="""Modify egs before passing it to training"""); # trainer options parser.add_argument("--trainer.prior-subset-size", type=int, @@ -322,7 +325,8 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): shuffle_buffer_size=args.shuffle_buffer_size, run_opts=run_opts, get_raw_nnet_from_am=False, - background_process_handler=background_process_handler) + background_process_handler=background_process_handler, + extra_egs_copy_cmd=args.extra_egs_copy_cmd) if args.cleanup: # do a clean up everythin but the last 2 models, under certain @@ -353,7 +357,8 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): left_context=left_context, right_context=right_context, run_opts=run_opts, background_process_handler=background_process_handler, - get_raw_nnet_from_am=False) + get_raw_nnet_from_am=False, + extra_egs_copy_cmd=args.extra_egs_copy_cmd) if include_log_softmax and args.stage <= num_iters + 1: logger.info("Getting average posterior for purposes of " @@ -363,7 +368,8 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): num_archives=num_archives, left_context=left_context, right_context=right_context, prior_subset_size=args.prior_subset_size, run_opts=run_opts, - get_raw_nnet_from_am=False) + get_raw_nnet_from_am=False, + extra_egs_copy_cmd=args.extra_egs_copy_cmd) if args.cleanup: logger.info("Cleaning up the experiment directory " diff --git a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py index 29df61ab546..e4af318fb57 100755 --- a/egs/wsj/s5/steps/nnet3/train_raw_rnn.py +++ b/egs/wsj/s5/steps/nnet3/train_raw_rnn.py @@ -69,6 +69,21 @@ def get_args(): help="""Number of left steps used in the estimation of LSTM state before prediction of the first label. 
Overrides the default value in CommonParser""") + parser.add_argument("--egs.extra-copy-cmd", type=str, + dest='extra_egs_copy_cmd', default = "", + help="""Modify egs before passing it to training"""); + parser.add_argument("--trainer.min-chunk-left-context", type=int, + dest='min_chunk_left_context', default=None, + help="""If provided and is less than + --egs.chunk-left-context, then the chunk left context + is randomized between egs.chunk-left-context and + this value.""") + parser.add_argument("--trainer.min-chunk-right-context", type=int, + dest='min_chunk_right_context', default=None, + help="""If provided and is less than + --egs.chunk-right-context, then the chunk right context + is randomized between egs.chunk-right-context and + this value.""") # trainer options parser.add_argument("--trainer.samples-per-iter", type=int, @@ -181,6 +196,12 @@ def process_args(args): "--trainer.deriv-truncate-margin.".format( args.deriv_truncate_margin)) + if args.min_chunk_left_context is None: + args.min_chunk_left_context = args.chunk_left_context + + if args.min_chunk_right_context is None: + args.min_chunk_right_context = args.chunk_right_context + if (not os.path.exists(args.dir) or not os.path.exists(args.dir+"/configs")): raise Exception("This scripts expects {0} to exist and have a configs " @@ -251,12 +272,18 @@ def train(args, run_opts, background_process_handler): # discriminative pretraining num_hidden_layers = variables['num_hidden_layers'] add_lda = common_lib.str_to_bool(variables['add_lda']) - include_log_softmax = common_lib.str_to_bool( - variables['include_log_softmax']) except KeyError as e: raise Exception("KeyError {0}: Variables need to be defined in " "{1}".format(str(e), '{0}/configs'.format(args.dir))) + try: + include_log_softmax = common_lib.str_to_bool( + variables['include_log_softmax']) + except KeyError as e: + logger.warning("KeyError {0}: Using default include-log-softmax value " + "as False.".format(str(e))) + include_log_softmax = False + left_context = args.chunk_left_context + model_left_context right_context = args.chunk_right_context + model_right_context @@ -416,6 +443,10 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): add_layers_period=args.add_layers_period, left_context=left_context, right_context=right_context, + min_left_context=args.min_chunk_left_context + + model_left_context, + min_right_context=args.min_chunk_right_context + + model_right_context, min_deriv_time=min_deriv_time, max_deriv_time=max_deriv_time, momentum=args.momentum, @@ -424,7 +455,8 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): cv_minibatch_size=args.cv_minibatch_size, run_opts=run_opts, get_raw_nnet_from_am=False, - background_process_handler=background_process_handler) + background_process_handler=background_process_handler, + extra_egs_copy_cmd=args.extra_egs_copy_cmd) if args.cleanup: # do a clean up everythin but the last 2 models, under certain @@ -455,7 +487,8 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): left_context=left_context, right_context=right_context, run_opts=run_opts, chunk_width=args.chunk_width, background_process_handler=background_process_handler, - get_raw_nnet_from_am=False) + get_raw_nnet_from_am=False, + extra_egs_copy_cmd=args.extra_egs_copy_cmd) if include_log_softmax and args.stage <= num_iters + 1: logger.info("Getting average posterior for purposes of " @@ -465,7 +498,8 @@ def learning_rate(iter, current_num_jobs, num_archives_processed): num_archives=num_archives, 
left_context=left_context, right_context=right_context, prior_subset_size=args.prior_subset_size, run_opts=run_opts, - get_raw_nnet_from_am=False) + get_raw_nnet_from_am=False, + extra_egs_copy_cmd=args.extra_egs_copy_cmd) if args.cleanup: logger.info("Cleaning up the experiment directory " diff --git a/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py b/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py index c55dae18b19..5edd3303942 100755 --- a/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py +++ b/egs/wsj/s5/steps/nnet3/xconfig_to_configs.py @@ -25,6 +25,9 @@ def get_args(): help='Filename of input xconfig file') parser.add_argument('--config-dir', required=True, help='Directory to write config files and variables') + parser.add_argument('--nnet-edits', type=str, default=None, + action=common_lib.NullstrToNoneAction, + help="Edit network before getting nnet3-info") print(' '.join(sys.argv)) @@ -187,13 +190,19 @@ def write_config_files(config_dir, all_layers): raise -def add_back_compatibility_info(config_dir): +def add_back_compatibility_info(config_dir, nnet_edits=None): """This will be removed when python script refactoring is done.""" common_lib.run_kaldi_command("nnet3-init {0}/ref.config " "{0}/ref.raw".format(config_dir)) - out, err = common_lib.run_kaldi_command("nnet3-info {0}/ref.raw | " - "head -4".format(config_dir)) + model = "{0}/ref.raw".format(config_dir) + if nnet_edits is not None: + model = """nnet3-copy --edits='{0}' {1} - |""".format(nnet_edits, + model) + + print("""nnet3-info "{0}" | head -4""".format(model), file=sys.stderr) + out, err = common_lib.run_kaldi_command("""nnet3-info "{0}" | """ + """head -4""".format(model)) # out looks like this # left-context: 7 # right-context: 0 @@ -226,7 +235,7 @@ def main(): all_layers = xparser.read_xconfig_file(args.xconfig_file) write_expanded_xconfig_files(args.config_dir, all_layers) write_config_files(args.config_dir, all_layers) - add_back_compatibility_info(args.config_dir) + add_back_compatibility_info(args.config_dir, args.nnet_edits) if __name__ == '__main__': diff --git a/egs/wsj/s5/steps/online/nnet2/extract_ivectors.sh b/egs/wsj/s5/steps/online/nnet2/extract_ivectors.sh index f27baecd673..2f55053efd5 100755 --- a/egs/wsj/s5/steps/online/nnet2/extract_ivectors.sh +++ b/egs/wsj/s5/steps/online/nnet2/extract_ivectors.sh @@ -172,8 +172,8 @@ if [ $sub_speaker_frames -gt 0 ]; then feat-to-len scp:$data/feats.scp ark,t:- > $dir/utt_counts || exit 1; fi if ! [ $(wc -l <$dir/utt_counts) -eq $(wc -l <$data/feats.scp) ]; then - echo "$0: error getting per-utterance counts." - exit 0; + echo "$0: error getting per-utterance counts. Number of lines in $dir/utt_counts differs from $data/feats.scp" + exit 1; fi cat $data/spk2utt | python -c " import sys @@ -229,8 +229,8 @@ if [ $stage -le 2 ]; then if [ ! 
-z "$ali_or_decode_dir" ]; then $cmd JOB=1:$nj $dir/log/extract_ivectors.JOB.log \ gmm-global-get-post --n=$num_gselect --min-post=$min_post $srcdir/final.dubm "$gmm_feats" ark:- \| \ - weight-post ark:- "ark,s,cs:gunzip -c $dir/weights.gz|" ark:- \| \ - ivector-extract --acoustic-weight=$posterior_scale --compute-objf-change=true \ + weight-post --length-tolerance=1 ark:- "ark,s,cs:gunzip -c $dir/weights.gz|" ark:- \| \ + ivector-extract --length-tolerance=1 --acoustic-weight=$posterior_scale --compute-objf-change=true \ --max-count=$max_count --spk2utt=ark:$this_sdata/JOB/spk2utt \ $srcdir/final.ie "$feats" ark,s,cs:- ark,t:$dir/ivectors_spk.JOB.ark || exit 1; else diff --git a/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh b/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh index b52de1f516b..f1edd874fa6 100755 --- a/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh +++ b/egs/wsj/s5/steps/online/nnet2/extract_ivectors_online.sh @@ -42,6 +42,9 @@ max_count=0 # The use of this option (e.g. --max-count 100) can make # posterior-scaling, so assuming the posterior-scale is 0.1, # --max-count 100 starts having effect after 1000 frames, or # 10 seconds of data. +weights= +use_most_recent_ivector=true +max_remembered_frames=1000 # End configuration section. @@ -89,6 +92,8 @@ splice_opts=$(cat $srcdir/splice_opts) # involved in online decoding. We need to create a config file for iVector # extraction. +absdir=$(readlink -f $dir) + ieconf=$dir/conf/ivector_extractor.conf echo -n >$ieconf cp $srcdir/online_cmvn.conf $dir/conf/ || exit 1; @@ -103,12 +108,19 @@ echo "--ivector-extractor=$srcdir/final.ie" >>$ieconf echo "--num-gselect=$num_gselect" >>$ieconf echo "--min-post=$min_post" >>$ieconf echo "--posterior-scale=$posterior_scale" >>$ieconf -echo "--max-remembered-frames=1000" >>$ieconf # the default +echo "--max-remembered-frames=$max_remembered_frames" >>$ieconf # the default echo "--max-count=$max_count" >>$ieconf +echo "--use-most-recent-ivector=$use_most_recent_ivector" >>$use_most_recent_ivector +if [ ! -z "$weights" ]; then + if [ -f $weights ] && gunzip -c $weights > /dev/null; then + cp -f $weights $absdir/weights.gz || exit 1 + else + echo "Could not open file $weights" + exit 1 + fi +fi -absdir=$(readlink -f $dir) - for n in $(seq $nj); do # This will do nothing unless the directory $dir/storage exists; # it can be used to distribute the data among multiple machines. @@ -117,10 +129,21 @@ done if [ $stage -le 0 ]; then echo "$0: extracting iVectors" - $cmd JOB=1:$nj $dir/log/extract_ivectors.JOB.log \ - ivector-extract-online2 --config=$ieconf ark:$sdata/JOB/spk2utt scp:$sdata/JOB/feats.scp ark:- \| \ - copy-feats --compress=$compress ark:- \ + if [ ! 
-z "$weights" ]; then + $cmd JOB=1:$nj $dir/log/extract_ivectors.JOB.log \ + ivector-extract-online2 --config=$ieconf \ + --frame-weights-rspecifier="ark:gunzip -c $absdir/weights.gz |" \ + --length-tolerance=1 \ + ark:$sdata/JOB/spk2utt scp:$sdata/JOB/feats.scp ark:- \| \ + copy-feats --compress=$compress ark:- \ + ark,scp:$absdir/ivector_online.JOB.ark,$absdir/ivector_online.JOB.scp || exit 1; + else + $cmd JOB=1:$nj $dir/log/extract_ivectors.JOB.log \ + ivector-extract-online2 --config=$ieconf \ + ark:$sdata/JOB/spk2utt scp:$sdata/JOB/feats.scp ark:- \| \ + copy-feats --compress=$compress ark:- \ ark,scp:$absdir/ivector_online.JOB.ark,$absdir/ivector_online.JOB.scp || exit 1; + fi fi if [ $stage -le 1 ]; then diff --git a/egs/wsj/s5/steps/segmentation/convert_ali_to_vec.pl b/egs/wsj/s5/steps/segmentation/convert_ali_to_vec.pl new file mode 100755 index 00000000000..c0d1a9eeae2 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/convert_ali_to_vec.pl @@ -0,0 +1,17 @@ +#! /usr/bin/perl + +# Converts a kaldi integer vector in text format to +# a kaldi vector in text format by adding a pair +# of square brackets around the data. +# Assumes the first column to be the utterance id. + +while (<>) { + chomp; + my @F = split; + + printf ("$F[0] [ "); + for (my $i = 1; $i <= $#F; $i++) { + printf ("$F[$i] "); + } + print ("]"); +} diff --git a/egs/wsj/s5/steps/segmentation/convert_rttm_to_utt2spk_and_segments.py b/egs/wsj/s5/steps/segmentation/convert_rttm_to_utt2spk_and_segments.py new file mode 100755 index 00000000000..23dc5a14f09 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/convert_rttm_to_utt2spk_and_segments.py @@ -0,0 +1,79 @@ +#! /usr/bin/env python + +"""This script converts an RTTM with +speaker info into kaldi utt2spk and segments""" + +import argparse + +def get_args(): + parser = argparse.ArgumentParser( + description="""This script converts an RTTM with + speaker info into kaldi utt2spk and segments""") + parser.add_argument("--use-reco-id-as-spkr", type=str, + choices=["true", "false"], + help="Use the recording ID based on RTTM and " + "reco2file_and_channel as the speaker") + parser.add_argument("rttm_file", type=str, + help="""Input RTTM file. + The format of the RTTM file is + """ + """ """) + parser.add_argument("reco2file_and_channel", type=str, + help="""Input reco2file_and_channel. 
+ The format is .""") + parser.add_argument("utt2spk", type=str, + help="Output utt2spk file") + parser.add_argument("segments", type=str, + help="Output segments file") + + args = parser.parse_args() + + args.use_reco_id_as_spkr = bool(args.use_reco_id_as_spkr == "true") + + return args + +def main(): + args = get_args() + + file_and_channel2reco = {} + for line in open(args.reco2file_and_channel): + parts = line.strip().split() + file_and_channel2reco[(parts[1], parts[2])] = parts[0] + + utt2spk_writer = open(args.utt2spk, 'w') + segments_writer = open(args.segments, 'w') + for line in open(args.rttm_file): + parts = line.strip().split() + if parts[0] != "SPEAKER": + continue + + file_id = parts[1] + channel = parts[2] + + try: + reco = file_and_channel2reco[(file_id, channel)] + except KeyError as e: + raise Exception("Could not find recording with " + "(file_id, channel) " + "= ({0},{1}) in {2}: {3}\n".format( + file_id, channel, + args.reco2file_and_channel, str(e))) + + start_time = float(parts[3]) + end_time = start_time + float(parts[4]) + + if args.use_reco_id_as_spkr: + spkr = reco + else: + spkr = parts[7] + + st = int(start_time * 100) + end = int(end_time * 100) + utt = "{0}-{1:06d}-{2:06d}".format(spkr, st, end) + + utt2spk_writer.write("{0} {1}\n".format(utt, spkr)) + segments_writer.write("{0} {1} {2:7.2f} {3:7.2f}\n".format( + utt, reco, start_time, end_time)) + +if __name__ == '__main__': + main() diff --git a/egs/wsj/s5/steps/segmentation/convert_utt2spk_and_segments_to_rttm.py b/egs/wsj/s5/steps/segmentation/convert_utt2spk_and_segments_to_rttm.py new file mode 100755 index 00000000000..1443259286b --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/convert_utt2spk_and_segments_to_rttm.py @@ -0,0 +1,65 @@ +#! /usr/bin/env python + +"""This script converts kaldi-style utt2spk and segments to an RTTM""" + +import argparse + +def get_args(): + parser = argparse.ArgumentParser( + description="""This script converts kaldi-style utt2spk and + segments to an RTTM""") + + parser.add_argument("utt2spk", type=str, + help="Input utt2spk file") + parser.add_argument("segments", type=str, + help="Input segments file") + parser.add_argument("reco2file_and_channel", type=str, + help="""Input reco2file_and_channel. 
+ The format is .""") + parser.add_argument("rttm_file", type=str, + help="Output RTTM file") + + args = parser.parse_args() + return args + +def main(): + args = get_args() + + reco2file_and_channel = {} + for line in open(args.reco2file_and_channel): + parts = line.strip().split() + reco2file_and_channel[parts[0]] = (parts[1], parts[2]) + + utt2spk = {} + with open(args.utt2spk, 'r') as utt2spk_reader: + for line in utt2spk_reader: + parts = line.strip().split() + utt2spk[parts[0]] = parts[1] + + with open(args.rttm_file, 'w') as rttm_writer: + for line in open(args.segments, 'r'): + parts = line.strip().split() + + utt = parts[0] + spkr = utt2spk[utt] + + reco = parts[1] + + try: + file_id, channel = reco2file_and_channel[reco] + except KeyError as e: + raise Exception("Could not find recording {0} in {1}: " + "{2}\n".format(reco, + args.reco2file_and_channel, + str(e))) + + start_time = float(parts[2]) + duration = float(parts[3]) - start_time + + rttm_writer.write("SPEAKER {0} {1} {2:7.2f} {3:7.2f} " + " {4} \n".format( + file_id, channel, start_time, + duration, spkr)) + +if __name__ == '__main__': + main() diff --git a/egs/wsj/s5/steps/segmentation/decode_sad.sh b/egs/wsj/s5/steps/segmentation/decode_sad.sh new file mode 100755 index 00000000000..9758d36e24e --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/decode_sad.sh @@ -0,0 +1,42 @@ +#! /bin/bash + +set -e +set -o pipefail + +cmd=run.pl +acwt=0.1 +beam=8 +max_active=1000 + +. path.sh + +. utils/parse_options.sh + +if [ $# -ne 3 ]; then + echo "Usage: $0 " + echo " e.g.: $0 " + exit 1 +fi + +graph_dir=$1 +log_likes_dir=$2 +dir=$3 + +nj=`cat $log_likes_dir/num_jobs` +echo $nj > $dir/num_jobs + +for f in $dir/trans.mdl $log_likes_dir/log_likes.1.gz $graph_dir/HCLG.fst; do + if [ ! -f $f ]; then + echo "$0: Could not find file $f" + fi +done + +decoder_opts+=(--acoustic-scale=$acwt --beam=$beam --max-active=$max_active) + +$cmd JOB=1:$nj $dir/log/decode.JOB.log \ + decode-faster-mapped ${decoder_opts[@]} \ + $dir/trans.mdl \ + $graph_dir/HCLG.fst "ark:gunzip -c $log_likes_dir/log_likes.JOB.gz |" \ + ark:/dev/null ark:- \| \ + ali-to-phones --per-frame $dir/trans.mdl ark:- \ + "ark:|gzip -c > $dir/ali.JOB.gz" diff --git a/egs/wsj/s5/steps/segmentation/decode_sad_to_segments.sh b/egs/wsj/s5/steps/segmentation/decode_sad_to_segments.sh new file mode 100755 index 00000000000..8f4ed60dfda --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/decode_sad_to_segments.sh @@ -0,0 +1,97 @@ +#! /bin/bash + +set -e +set -o pipefail +set -u + +stage=-1 +segmentation_config=conf/segmentation.conf +cmd=run.pl + +# Viterbi options +min_silence_duration=30 # minimum number of frames for silence +min_speech_duration=30 # minimum number of frames for speech +frame_subsampling_factor=1 +nonsil_transition_probability=0.1 +sil_transition_probability=0.1 +sil_prior=0.5 +speech_prior=0.5 + +# Decoding options +acwt=1 +beam=10 +max_active=7000 + +. 
utils/parse_options.sh + +if [ $# -ne 4 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/babel_bengali_dev10h exp/nnet3_sad_snr/tdnn_b_n4/sad_babel_bengali_dev10h exp/nnet3_sad_snr/tdnn_b_n4/segmentation_babel_bengali_dev10h exp/nnet3_sad_snr/tdnn_b_n4/segmentation_babel_bengali_dev10h/babel_bengali_dev10h.seg" + exit 1 +fi + +data=$1 +sad_likes_dir=$2 +dir=$3 +out_data=$4 + +t=sil${sil_prior}_sp${speech_prior} +lang=$dir/lang_test_${t} + +min_silence_duration=`perl -e "print (int($min_silence_duration / $frame_subsampling_factor))"` +min_speech_duration=`perl -e "print (int($min_speech_duration / $frame_subsampling_factor))"` + +if [ $stage -le 1 ]; then + mkdir -p $lang + + steps/segmentation/internal/prepare_sad_lang.py \ + --phone-transition-parameters="--phone-list=1 --min-duration=$min_silence_duration --end-transition-probability=$sil_transition_probability" \ + --phone-transition-parameters="--phone-list=2 --min-duration=$min_speech_duration --end-transition-probability=$nonsil_transition_probability" $lang + + cp $lang/phones.txt $lang/words.txt +fi + +feat_dim=2 # dummy. We don't need this. +if [ $stage -le 2 ]; then + $cmd $dir/log/create_transition_model.log gmm-init-mono \ + $lang/topo $feat_dim - $dir/tree \| \ + copy-transition-model --binary=false - $dir/trans.mdl || exit 1 +fi + +if [ $stage -le 3 ]; then + cat > $lang/word2prior < $lang/G.fst +fi + +graph_dir=$dir/graph_test_${t} + +if [ $stage -le 4 ]; then + $cmd $dir/log/make_vad_graph.log \ + steps/segmentation/internal/make_sad_graph.sh --iter trans \ + $lang $dir $dir/graph_test_${t} || exit 1 +fi + +if [ $stage -le 5 ]; then + steps/segmentation/decode_sad.sh \ + --acwt $acwt --beam $beam --max-active $max_active \ + $graph_dir $sad_likes_dir $dir +fi + +if [ $stage -le 6 ]; then + cat > $lang/phone2sad_map < 8kHz sampling frequency. +do_downsampling=false + +# Segmentation configs +min_silence_duration=30 +min_speech_duration=30 +segmentation_config=conf/segmentation_speech.conf + +echo $* + +. utils/parse_options.sh + +if [ $# -ne 3 ]; then + echo "Usage: $0 " + echo " e.g.: $0 ~/workspace/egs/ami/s5b/data/sdm1/dev data/ami_sdm1_dev exp/nnet3_sad_snr/nnet_tdnn_j_n4" + exit 1 +fi + +src_data_dir=$1 +data_dir=$2 +sad_nnet_dir=$3 + +affix=${affix:+_$affix} +feat_affix=${feat_affix:+_$feat_affix} + +data_id=`basename $data_dir` +sad_dir=${sad_nnet_dir}/${sad_name}${affix}_${data_id}_whole${feat_affix} +seg_dir=${sad_nnet_dir}/${segmentation_name}${affix}_${data_id}_whole${feat_affix} + +export PATH="$KALDI_ROOT/tools/sph2pipe_v2.5/:$PATH" +[ ! 
-z `which sph2pipe` ] + +if [ $stage -le 0 ]; then + utils/data/convert_data_dir_to_whole.sh $src_data_dir ${data_dir}_whole + + if $do_downsampling; then + freq=`cat $mfcc_config | perl -pe 's/\s*#.*//g' | grep "sample-frequency=" | awk -F'=' '{if (NF == 0) print 16000; else print $2}'` + sox=`which sox` + + cat $src_data_dir/wav.scp | python -c "import sys +for line in sys.stdin.readlines(): + splits = line.strip().split() + if splits[-1] == '|': + out_line = line.strip() + ' $sox -t wav - -r $freq -c 1 -b 16 -t wav - downsample |' + else: + out_line = 'cat {0} {1} | $sox -t wav - -r $freq -c 1 -b 16 -t wav - downsample |'.format(splits[0], ' '.join(splits[1:])) + print (out_line)" > ${data_dir}_whole/wav.scp + fi + + utils/copy_data_dir.sh ${data_dir}_whole ${data_dir}_whole${feat_affix}_hires +fi + +test_data_dir=${data_dir}_whole${feat_affix}_hires + +if [ $stage -le 1 ]; then + steps/make_mfcc.sh --mfcc-config $mfcc_config --nj $reco_nj --cmd "$train_cmd" \ + ${data_dir}_whole${feat_affix}_hires exp/make_hires/${data_id}_whole${feat_affix} mfcc_hires + steps/compute_cmvn_stats.sh ${data_dir}_whole${feat_affix}_hires exp/make_hires/${data_id}_whole${feat_affix} mfcc_hires +fi + +post_vec=$sad_nnet_dir/post_${output_name}.vec +if [ ! -f $sad_nnet_dir/post_${output_name}.vec ]; then + echo "$0: Could not find $sad_nnet_dir/post_${output_name}.vec. See the last stage of local/segmentation/run_train_sad.sh" + exit 1 +fi + +if [ $stage -le 2 ]; then + steps/nnet3/compute_output.sh --nj $reco_nj --cmd "$train_cmd" \ + --post-vec "$post_vec" \ + --iter $iter \ + --extra-left-context $extra_left_context \ + --extra-right-context $extra_right_context \ + --frames-per-chunk 150 \ + --stage $sad_stage --output-name $output_name \ + --frame-subsampling-factor $frame_subsampling_factor \ + --get-raw-nnet-from-am false ${test_data_dir} $sad_nnet_dir $sad_dir +fi + +if [ $stage -le 3 ]; then + steps/segmentation/decode_sad_to_segments.sh \ + --frame-subsampling-factor $frame_subsampling_factor \ + --min-silence-duration $min_silence_duration \ + --min-speech-duration $min_speech_duration \ + --segmentation-config $segmentation_config --cmd "$train_cmd" \ + ${test_data_dir} $sad_dir $seg_dir $seg_dir/${data_id}_seg +fi + +# Subsegment data directory +if [ $stage -le 4 ]; then + rm $seg_dir/${data_id}_seg/feats.scp || true + utils/data/get_reco2num_frames.sh ${test_data_dir} + awk '{print $1" "$2}' ${seg_dir}/${data_id}_seg/segments | \ + utils/apply_map.pl -f 2 ${test_data_dir}/reco2num_frames > \ + $seg_dir/${data_id}_seg/utt2max_frames + + frame_shift_info=`cat $mfcc_config | steps/segmentation/get_frame_shift_info_from_config.pl` + utils/data/get_subsegment_feats.sh ${test_data_dir}/feats.scp \ + $frame_shift_info $seg_dir/${data_id}_seg/segments | \ + utils/data/fix_subsegmented_feats.pl ${seg_dir}/${data_id}_seg/utt2max_frames > \ + $seg_dir/${data_id}_seg/feats.scp + steps/compute_cmvn_stats.sh --fake $seg_dir/${data_id}_seg +fi diff --git a/egs/wsj/s5/steps/segmentation/evaluate_segmentation.pl b/egs/wsj/s5/steps/segmentation/evaluate_segmentation.pl new file mode 100755 index 00000000000..06a762d7762 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/evaluate_segmentation.pl @@ -0,0 +1,198 @@ +#!/usr/bin/env perl + +# Copyright 2014 Johns Hopkins University (Author: Sanjeev Khudanpur), Vimal Manohar +# Apache 2.0 + +################################################################################ +# +# This script was written to check the goodness of automatic segmentation tools +# It assumes 
input in the form of two Kaldi segments files, i.e. a file each of +# whose lines contain four space-separated values: +# +# UtteranceID FileID StartTime EndTime +# +# It computes # missed frames, # false positives and # overlapping frames. +# +################################################################################ + +if ($#ARGV == 1) { + $ReferenceSegmentation = $ARGV[0]; + $HypothesizedSegmentation = $ARGV[1]; + printf STDERR ("Comparing reference segmentation\n\t%s\nwith proposed segmentation\n\t%s\n", + $ReferenceSegmentation, + $HypothesizedSegmentation); +} else { + printf STDERR "This program compares the reference segmenation with the proposted segmentation\n"; + printf STDERR "Usage: $0 reference_segments_filename proposed_segments_filename\n"; + printf STDERR "e.g. $0 data/dev10h/segments data/dev10h.seg/segments\n"; + exit (0); +} + +################################################################################ +# First read the reference segmentation, and +# store the start- and end-times of all segments in each file. +################################################################################ + +open (SEGMENTS, "cat $ReferenceSegmentation | sort -k2,2 -k3n,3 -k4n,4 |") + || die "Unable to open $ReferenceSegmentation"; +$numLines = 0; +while ($line=) { + chomp $line; + @field = split("[ \t]+", $line); + unless ($#field == 3) { + exit (1); + printf STDERR "Skipping unparseable line in file $ReferenceSegmentation\n\t$line\n"; + next; + } + $fileID = $field[1]; + unless (exists $firstSeg{$fileID}) { + $firstSeg{$fileID} = $numLines; + $actualSpeech{$fileID} = 0.0; + $hypothesizedSpeech{$fileID} = 0.0; + $foundSpeech{$fileID} = 0.0; + $falseAlarm{$fileID} = 0.0; + $minStartTime{$fileID} = 0.0; + $maxEndTime{$fileID} = 0.0; + } + $refSegName[$numLines] = $field[0]; + $refSegStart[$numLines] = $field[2]; + $refSegEnd[$numLines] = $field[3]; + $actualSpeech{$fileID} += ($field[3]-$field[2]); + $minStartTime{$fileID} = $field[2] if ($minStartTime{$fileID}>$field[2]); + $maxEndTime{$fileID} = $field[3] if ($maxEndTime{$fileID}<$field[3]); + $lastSeg{$fileID} = $numLines; + ++$numLines; +} +close(SEGMENTS); +print STDERR "Read $numLines segments from $ReferenceSegmentation\n"; + +################################################################################ +# Process hypothesized segments sequentially, and gather speech/nonspeech stats +################################################################################ + +open (SEGMENTS, "cat $HypothesizedSegmentation | sort -k2,2 -k1,1 |") + # Kaldi segments files are sorted by UtteranceID, but we re-sort them here + # so that all segments of a file are read together, sorted by start-time. 
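Editor's note: the reference-reading loop above indexes each recording's segments by start time and accumulates the total speech per file. A condensed Python equivalent of part of that bookkeeping, assuming the four-column "UtteranceID FileID StartTime EndTime" layout described in the header comment; function and variable names are mine.

from collections import defaultdict

def load_segments(path):
    # Read a Kaldi segments file into per-recording (start, end) lists sorted
    # by start time, plus total speech duration per recording.
    segs = defaultdict(list)
    speech = defaultdict(float)
    with open(path) as f:
        for line in f:
            fields = line.split()
            if len(fields) != 4:
                continue  # skip unparseable lines
            _utt, reco, start, end = fields[0], fields[1], float(fields[2]), float(fields[3])
            segs[reco].append((start, end))
            speech[reco] += end - start
    for reco in segs:
        segs[reco].sort()
    return segs, speech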
+ || die "Unable to open $HypothesizedSegmentation"; +$numLines = 0; +$totalHypSpeech = 0.0; +$totalFoundSpeech = 0.0; +$totalFalseAlarm = 0.0; +$numShortSegs = 0; +$numLongSegs = 0; +while ($line=) { + chomp $line; + @field = split("[ \t]+", $line); + unless ($#field == 3) { + exit (1); + printf STDERR "Skipping unparseable line in file $HypothesizedSegmentation\n\t$line\n"; + next; + } + $fileID = $field[1]; + $segStart = $field[2]; + $segEnd = $field[3]; + if (exists $firstSeg{$fileID}) { + # This FileID exists in the reference segmentation + # So gather statistics for this UtteranceID + $hypothesizedSpeech{$fileID} += ($segEnd-$segStart); + $totalHypSpeech += ($segEnd-$segStart); + if (($segStart>=$maxEndTime{$fileID}) || ($segEnd<=$minStartTime{$fileID})) { + # This entire segment is a false alarm + $falseAlarm{$fileID} += ($segEnd-$segStart); + $totalFalseAlarm += ($segEnd-$segStart); + } else { + # This segment may overlap one or more reference segments + $p = $firstSeg{$fileID}; + while ($refSegEnd[$p]<=$segStart) { + ++$p; + } + # The overlap, if any, begins at the reference segment p + $q = $lastSeg{$fileID}; + while ($refSegStart[$q]>=$segEnd) { + --$q; + } + # The overlap, if any, ends at the reference segment q + if ($q<$p) { + # This segment sits entirely in the nonspeech region + # between the two reference speech segments q and p + $falseAlarm{$fileID} += ($segEnd-$segStart); + $totalFalseAlarm += ($segEnd-$segStart); + } else { + if (($segEnd-$segStart)<0.20) { + # For diagnosing Pascal's VAD segmentation + print STDOUT "Found short speech region $line\n"; + ++$numShortSegs; + } elsif (($segEnd-$segStart)>60.0) { + ++$numLongSegs; + # For diagnosing Pascal's VAD segmentation + print STDOUT "Found long speech region $line\n"; + } + # There is some overlap with segments p through q + for ($s=$p; $s<=$q; ++$s) { + if ($segStart<$refSegStart[$s]) { + # There is a leading false alarm portion before s + $falseAlarm{$fileID} += ($refSegStart[$s]-$segStart); + $totalFalseAlarm += ($refSegStart[$s]-$segStart); + $segStart=$refSegStart[$s]; + } + $speechPortion = ($refSegEnd[$s]<$segEnd) ? + ($refSegEnd[$s]-$segStart) : ($segEnd-$segStart); + $foundSpeech{$fileID} += $speechPortion; + $totalFoundSpeech += $speechPortion; + $segStart=$refSegEnd[$s]; + } + if ($segEnd>$segStart) { + # There is a trailing false alarm portion after q + $falseAlarm{$fileID} += ($segEnd-$segStart); + $totalFalseAlarm += ($segEnd-$segStart); + } + } + } + } else { + # This FileID does not exist in the reference segmentation + # So all this speech counts as a false alarm + exit (1); + printf STDERR ("Unexpected fileID in hypothesized segments: %s", $fileID); + $totalFalseAlarm += ($segEnd-$segStart); + } + ++$numLines; +} +close(SEGMENTS); +print STDERR "Read $numLines segments from $HypothesizedSegmentation\n"; + +################################################################################ +# Now that all hypothesized segments have been processed, compute needed stats +################################################################################ + +$totalActualSpeech = 0.0; +$totalNonSpeechEst = 0.0; # This is just a crude estimate of total nonspeech. 
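Editor's note: the per-segment scoring above walks the overlapping reference segments, counting time inside them as found speech and the rest of the hypothesized segment as false alarm. A small Python sketch of the same interval arithmetic on one hypothesized segment, assuming the reference segments are sorted and non-overlapping; names are illustrative.

def score_segment(hyp_start, hyp_end, ref_segs):
    # Split one hypothesized segment into overlapped speech vs. false alarm,
    # given a sorted list of non-overlapping reference (start, end) pairs.
    found = 0.0
    cursor = hyp_start
    for ref_start, ref_end in ref_segs:
        if ref_end <= cursor:
            continue
        if ref_start >= hyp_end:
            break
        overlap = min(hyp_end, ref_end) - max(cursor, ref_start)
        if overlap > 0:
            found += overlap
            cursor = min(hyp_end, ref_end)
    false_alarm = (hyp_end - hyp_start) - found
    return found, false_alarm

print(score_segment(1.0, 4.0, [(0.0, 2.0), (3.0, 5.0)]))  # -> (2.0, 1.0)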
+foreach $fileID (sort keys %actualSpeech) { + $totalActualSpeech += $actualSpeech{$fileID}; + $totalNonSpeechEst += $maxEndTime{$fileID} - $actualSpeech{$fileID}; + ####################################################################### + # Print file-wise statistics to STDOUT; can pipe to /dev/null is needed + ####################################################################### + printf STDOUT ("%s: %.2f min actual speech, %.2f min hypothesized: %.2f min overlap (%d\%), %.2f min false alarm (~%d\%)\n", + $fileID, + ($actualSpeech{$fileID}/60.0), + ($hypothesizedSpeech{$fileID}/60.0), + ($foundSpeech{$fileID}/60.0), + ($foundSpeech{$fileID}*100/($actualSpeech{$fileID}+0.01)), + ($falseAlarm{$fileID}/60.0), + ($falseAlarm{$fileID}*100/($maxEndTime{$fileID}-$actualSpeech{$fileID}+0.01))); +} + +################################################################################ +# Finally, we have everything needed to report the segmentation statistics. +################################################################################ + +printf STDERR ("------------------------------------------------------------------------\n"); +printf STDERR ("TOTAL: %.2f hrs actual speech, %.2f hrs hypothesized: %.2f hrs overlap (%d\%), %.2f hrs false alarm (~%d\%)\n", + ($totalActualSpeech/3600.0), + ($totalHypSpeech/3600.0), + ($totalFoundSpeech/3600.0), + ($totalFoundSpeech*100/($totalActualSpeech+0.000001)), + ($totalFalseAlarm/3600.0), + ($totalFalseAlarm*100/($totalNonSpeechEst+0.000001))); +printf STDERR ("\t$numShortSegs segments < 0.2 sec and $numLongSegs segments > 60.0 sec\n"); +printf STDERR ("------------------------------------------------------------------------\n"); diff --git a/egs/wsj/s5/steps/segmentation/get_frame_shift_info_from_config.pl b/egs/wsj/s5/steps/segmentation/get_frame_shift_info_from_config.pl new file mode 100755 index 00000000000..79a42aa9852 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/get_frame_shift_info_from_config.pl @@ -0,0 +1,21 @@ +#! /usr/bin/perl +use strict; +use warnings; + +# This script parses a features config file such as conf/mfcc.conf +# and returns the pair of values frame_shift and frame_overlap in seconds. + +my $frame_shift = 0.01; +my $frame_overlap = 0.015; + +while (<>) { + if (m/--frame-length=(\d+)/) { + $frame_shift = $1 / 1000; + } + + if (m/--window-length=(\d+)/) { + $frame_overlap = $1 / 1000 - $frame_shift; + } +} + +print "$frame_shift $frame_overlap\n"; diff --git a/egs/wsj/s5/steps/segmentation/get_reverb_scp.pl b/egs/wsj/s5/steps/segmentation/get_reverb_scp.pl new file mode 100755 index 00000000000..57f63b517f2 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/get_reverb_scp.pl @@ -0,0 +1,58 @@ +#! /usr/bin/perl +use strict; +use warnings; + +my $field_begin = -1; +my $field_end = -1; + +if ($ARGV[0] eq "-f") { + shift @ARGV; + my $field_spec = shift @ARGV; + if ($field_spec =~ m/^\d+$/) { + $field_begin = $field_spec - 1; $field_end = $field_spec - 1; + } + if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesty (properly, 1-10) + if ($1 ne "") { + $field_begin = $1 - 1; # Change to zero-based indexing. + } + if ($2 ne "") { + $field_end = $2 - 1; # Change to zero-based indexing. 
+ } + } + if (!defined $field_begin && !defined $field_end) { + die "Bad argument to -f option: $field_spec"; + } +} + +if (scalar @ARGV != 1 && scalar @ARGV != 2 ) { + print "Usage: get_reverb_scp.pl [-f -] [] < input_scp > output_scp\n"; + exit(1); +} + +my $num_reps = $ARGV[0]; +my $prefix = "rev"; + +if (scalar @ARGV == 2) { + $prefix = $ARGV[1]; +} + +while () { + chomp; + my @A = split; + + for (my $i = 1; $i <= $num_reps; $i++) { + for (my $pos = 0; $pos <= $#A; $pos++) { + my $a = $A[$pos]; + if ( ($field_begin < 0 || $pos >= $field_begin) + && ($field_end < 0 || $pos <= $field_end) ) { + if ($a =~ m/^(sp[0-9.]+-)(.+)$/) { + $a = $1 . "$prefix" . $i . "_" . $2; + } else { + $a = "$prefix" . $i . "_" . $a; + } + } + print $a . " "; + } + print "\n"; + } +} diff --git a/egs/wsj/s5/steps/segmentation/get_sad_map.py b/egs/wsj/s5/steps/segmentation/get_sad_map.py new file mode 100755 index 00000000000..9160503c7ad --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/get_sad_map.py @@ -0,0 +1,156 @@ +#! /usr/bin/env python + +"""This script prints a mapping from phones to speech +activity labels +0 for silence, 1 for speech, 2 for noise and 3 for OOV. +Other labels can be optionally defined. +e.g. If 1, 2 and 3 are silence phones, 4, 5 and 6 are speech phones, +the SAD map would be +1 0 +2 0 +3 0 +4 1 +5 1 +6 1. +The silence and speech are read from the phones/silence.txt and +phones/nonsilence.txt from the lang directory. +An initial SAD map can be provided using --init-sad-map to override +the above default mapping of phones. This is useful to say map + or noise phones to separate SAD labels. +""" + +import argparse + + +class StrToBoolAction(argparse.Action): + """ A custom action to convert bools from shell format i.e., true/false + to python format i.e., True/False """ + def __call__(self, parser, namespace, values, option_string=None): + try: + if values == "true": + setattr(namespace, self.dest, True) + elif values == "true": + setattr(namespace, self.dest, False) + else: + raise ValueError + except ValueError: + raise Exception("Unknown value {0} for --{1}".format(values, + self.dest)) + + +class NullstrToNoneAction(argparse.Action): + """ A custom action to convert empty strings passed by shell + to None in python. This is necessary as shell scripts print null + strings when a variable is not specified. We could use the more apt + None in python. """ + def __call__(self, parser, namespace, values, option_string=None): + if values.strip() == "": + setattr(namespace, self.dest, None) + else: + setattr(namespace, self.dest, values) + + +def get_args(): + parser = argparse.ArgumentParser( + description="""This script prints a mapping from phones to speech + activity labels + 0 for silence, 1 for speech, 2 for noise and 3 for OOV. + Other labels can be optionally defined. + e.g. If 1, 2 and 3 are silence phones, 4, 5 and 6 are speech phones, + the SAD map would be + 1 0 + 2 0 + 3 0 + 4 1 + 5 1 + 6 1. + The silence and speech are read from the phones/silence.txt and + phones/nonsilence.txt from the lang directory. + An initial SAD map can be provided using --init-sad-map to override + the above default mapping of phones. This is useful to say map + or noise phones to separate SAD labels. + """) + + parser.add_argument("--init-sad-map", type=str, action=NullstrToNoneAction, + help="""Initial SAD map that will be used to override + the default mapping using phones/silence.txt and + phones/nonsilence.txt. Does not need to specify labels + for all the phones. + e.g. 
+ 3 + 2""") + + noise_group = parser.add_mutually_exclusive_group() + noise_group.add_argument("--noise-phones-file", type=str, + action=NullstrToNoneAction, + help="Map noise phones from file to label 2") + noise_group.add_argument("--noise-phones-list", type=str, + action=NullstrToNoneAction, + help="A colon-separated list of noise phones to " + "map to label 2") + parser.add_argument("--unk", type=str, action=NullstrToNoneAction, + help="""UNK phone, if provided will be mapped to + label 3""") + + parser.add_argument("--map-noise-to-sil", type=str, + action=StrToBoolAction, + choices=["true", "false"], default=False, + help="""Map noise phones to silence before writing the + map. i.e. anything with label 2 is mapped to + label 0.""") + parser.add_argument("--map-unk-to-speech", type=str, + action=StrToBoolAction, + choices=["true", "false"], default=False, + help="""Map UNK phone to speech before writing the map + i.e. anything with label 3 is mapped to label 1.""") + + parser.add_argument("lang_dir") + + args = parser.parse_args() + + return args + + +def main(): + args = get_args() + + sad_map = {} + + for line in open('{0}/phones/nonsilence.txt'.format(args.lang_dir)): + parts = line.strip().split() + sad_map[parts[0]] = 1 + + for line in open('{0}/phones/silence.txt'.format(args.lang_dir)): + parts = line.strip().split() + sad_map[parts[0]] = 0 + + if args.init_sad_map is not None: + for line in open(args.init_sad_map): + parts = line.strip().split() + try: + sad_map[parts[0]] = int(parts[1]) + except Exception: + raise Exception("Invalid line " + line) + + if args.unk is not None: + sad_map[args.unk] = 3 + + noise_phones = {} + if args.noise_phones_file is not None: + for line in open(args.noise_phones_file): + parts = line.strip().split() + noise_phones[parts[0]] = 1 + + if args.noise_phones_list is not None: + for x in args.noise_phones_list.split(":"): + noise_phones[x] = 1 + + for x, l in sad_map.iteritems(): + if l == 2 and args.map_noise_to_sil: + l = 0 + if l == 3 and args.map_unk_to_speech: + l = 1 + print ("{0} {1}".format(x, l)) + +if __name__ == "__main__": + main() diff --git a/egs/wsj/s5/steps/segmentation/internal/convert_ali_to_vad.sh b/egs/wsj/s5/steps/segmentation/internal/convert_ali_to_vad.sh new file mode 100755 index 00000000000..353e6d4664e --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/internal/convert_ali_to_vad.sh @@ -0,0 +1,54 @@ +#! /bin/bash + +set -o pipefail +set -e +set -u + +. path.sh + +cmd=run.pl + +frame_shift=0.01 +frame_subsampling_factor=1 + +. parse_options.sh + +if [ $# -ne 4 ]; then + echo "This script converts the alignment in the alignment directory " + echo "to speech activity segments based on the provided phone-map." + echo "Usage: $0 exp/tri3_ali data/lang/phones/sad.map exp/tri3_ali_vad" + exit 1 +fi + +ali_dir=$1 +phone_map=$2 +dir=$3 + +for f in $phone_map $ali_dir/ali.1.gz; do + [ ! 
-f $f ] && echo "$0: Could not find $f" && exit 1 +done + +mkdir -p $dir + +nj=`cat $ali_dir/num_jobs` || exit 1 +echo $nj > $dir/num_jobs + +if [ -f $ali_dir/frame_subsampling_factor ]; then + frame_subsampling_factor=`cat $ali_dir/frame_subsampling_factor` +fi + +ali_frame_shift=`perl -e "print ($frame_shift * $frame_subsampling_factor);"` +ali_frame_overlap=`perl -e "print ($ali_frame_shift * 1.5);"` + +dir=`perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' $dir ${PWD}` + +$cmd JOB=1:$nj $dir/log/get_sad.JOB.log \ + segmentation-init-from-ali \ + "ark:gunzip -c ${ali_dir}/ali.JOB.gz | ali-to-phones --per-frame ${ali_dir}/final.mdl ark:- ark:- |" \ + ark:- \| segmentation-copy --label-map=$phone_map ark:- ark:- \| \ + segmentation-post-process --merge-adjacent-segments ark:- \ + ark,scp:$dir/sad_seg.JOB.ark,$dir/sad_seg.JOB.scp + +for n in `seq $nj`; do + cat $dir/sad_seg.$n.scp +done | sort -k1,1 > $dir/sad_seg.scp diff --git a/egs/wsj/s5/steps/segmentation/internal/make_G_fst.py b/egs/wsj/s5/steps/segmentation/internal/make_G_fst.py new file mode 100755 index 00000000000..5ad7e867d10 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/internal/make_G_fst.py @@ -0,0 +1,52 @@ +#! /usr/bin/env python + +from __future__ import print_function +import argparse, math + +def ParseArgs(): + parser = argparse.ArgumentParser("""Make a simple unigram FST for +decoding for segmentation purpose.""") + + parser.add_argument("--word2prior-map", type=str, required=True, + help = "A file with priors for different words") + parser.add_argument("--end-probability", type=float, default=0.01, + help = "Ending probability") + + args = parser.parse_args() + + return args + +def ReadMap(map_file): + out_map = {} + sum_prob = 0 + for line in open(map_file): + parts = line.strip().split() + if len(parts) == 0: + continue + if len(parts) != 2: + raise Exception("Invalid line {0} in {1}".format(line.strip(), map_file)) + + if parts[0] in out_map: + raise Exception("Duplicate entry of {0} in {1}".format(parts[0], map_file)) + + prob = float(parts[1]) + out_map[parts[0]] = prob + + sum_prob += prob + + return (out_map, sum_prob) + +def Main(): + args = ParseArgs() + + word2prior, sum_prob = ReadMap(args.word2prior_map) + sum_prob += args.end_probability + + for w,p in word2prior.iteritems(): + print ("0 0 {word} {word} {log_p}".format(word = w, + log_p = -math.log(p / sum_prob))) + print ("0 {log_p}".format(word = w, + log_p = -math.log(args.end_probability / sum_prob))) + +if __name__ == '__main__': + Main() diff --git a/egs/wsj/s5/steps/segmentation/internal/make_sad_graph.sh b/egs/wsj/s5/steps/segmentation/internal/make_sad_graph.sh new file mode 100755 index 00000000000..5edb3eb2bb6 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/internal/make_sad_graph.sh @@ -0,0 +1,83 @@ +#!/bin/bash + +# Copyright 2016 Vimal Manohar + +# Begin configuration section. +stage=0 +cmd=run.pl +iter=final # use $iter.mdl from $model_dir +tree=tree +tscale=1.0 # transition scale. +loopscale=0.1 # scale for self-loops. +# End configuration section. + +echo "$0 $@" # Print the command line for logging + +[ -f ./path.sh ] && . ./path.sh; # source the path. +. parse_options.sh || exit 1; + +if [ $# -ne 3 ]; then + echo "Usage: $0 [options] " + echo " e.g.: $0 exp/vad_dev/lang exp/vad_dev exp/vad_dev/graph" + echo "Makes the graph in \$dir, corresponding to the model in \$model_dir" + exit 1; +fi + +lang=$1 +model=$2/$iter.mdl +tree=$2/$tree +dir=$3 + +for f in $lang/G.fst $model $tree; do + if [ ! 
-f $f ]; then + echo "$0: expected $f to exist" + exit 1; + fi +done + +mkdir -p $dir $lang/tmp + +clg=$lang/tmp/CLG.fst + +if [[ ! -s $clg || $clg -ot $lang/G.fst ]]; then + echo "$0: creating CLG." + + fstcomposecontext --context-size=1 --central-position=0 \ + $lang/tmp/ilabels < $lang/G.fst | \ + fstarcsort --sort_type=ilabel > $clg + fstisstochastic $clg || echo "[info]: CLG not stochastic." +fi + +if [[ ! -s $dir/Ha.fst || $dir/Ha.fst -ot $model || $dir/Ha.fst -ot $lang/tmp/ilabels ]]; then + make-h-transducer --disambig-syms-out=$dir/disambig_tid.int \ + --transition-scale=$tscale $lang/tmp/ilabels $tree $model \ + > $dir/Ha.fst || exit 1; +fi + +if [[ ! -s $dir/HCLGa.fst || $dir/HCLGa.fst -ot $dir/Ha.fst || $dir/HCLGa.fst -ot $clg ]]; then + fsttablecompose $dir/Ha.fst $clg | fstdeterminizestar --use-log=true \ + | fstrmsymbols $dir/disambig_tid.int | fstrmepslocal | \ + fstminimizeencoded > $dir/HCLGa.fst || exit 1; + fstisstochastic $dir/HCLGa.fst || echo "HCLGa is not stochastic" +fi + +if [[ ! -s $dir/HCLG.fst || $dir/HCLG.fst -ot $dir/HCLGa.fst ]]; then + add-self-loops --self-loop-scale=$loopscale --reorder=true \ + $model < $dir/HCLGa.fst > $dir/HCLG.fst || exit 1; + + if [ $tscale == 1.0 -a $loopscale == 1.0 ]; then + # No point doing this test if transition-scale not 1, as it is bound to fail. + fstisstochastic $dir/HCLG.fst || echo "[info]: final HCLG is not stochastic." + fi +fi + +# keep a copy of the lexicon and a list of silence phones with HCLG... +# this means we can decode without reference to the $lang directory. + +cp $lang/words.txt $dir/ || exit 1; +cp $lang/phones.txt $dir/ 2> /dev/null # ignore the error if it's not there. + +# to make const fst: +# fstconvert --fst_type=const $dir/HCLG.fst $dir/HCLG_c.fst +am-info --print-args=false $model | grep pdfs | awk '{print $NF}' > $dir/num_pdfs + diff --git a/egs/wsj/s5/steps/segmentation/internal/post_process_segments.sh b/egs/wsj/s5/steps/segmentation/internal/post_process_segments.sh new file mode 100755 index 00000000000..e37d5dc2f62 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/internal/post_process_segments.sh @@ -0,0 +1,100 @@ +#! /bin/bash + +# Copyright 2015-16 Vimal Manohar +# Apache 2.0. + +set -e +set -o pipefail +set -u + +. path.sh + +cmd=run.pl +stage=-10 + +# General segmentation options +pad_length=50 # Pad speech segments by this many frames on either side +max_blend_length=10 # Maximum duration of speech that will be removed as part + # of smoothing process. This is only if there are no other + # speech segments nearby. +max_intersegment_length=50 # Merge nearby speech segments if the silence + # between them is less than this many frames. +post_pad_length=50 # Pad speech segments by this many frames on either side + # after the merging process using max_intersegment_length +max_segment_length=1000 # Segments that are longer than this are split into + # overlapping frames. +overlap_length=100 # Overlapping frames when segments are split. + # See the above option. +min_silence_length=30 # Min silence length at which to split very long segments + +frame_shift=0.01 + +. utils/parse_options.sh + +if [ $# -ne 3 ]; then + echo "This script post-processes a speech activity segmentation to create " + echo "a kaldi-style data directory." + echo "See the comments for the kind of post-processing options." 
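Editor's note: the options at the top of this script pad speech segments by pad_length frames and merge segments separated by at most max_intersegment_length frames of silence; later steps split over-long segments. A simplified Python sketch of just the pad-and-merge step on (start, end) frame intervals, not the full segmentation-post-process pipeline; function and argument names are mine.

def pad_and_merge(segments, pad, max_gap, num_frames):
    # Pad each (start, end) frame interval by `pad` frames on both sides,
    # clip to [0, num_frames], then merge intervals whose gap is <= max_gap.
    padded = [(max(0, s - pad), min(num_frames, e + pad)) for s, e in sorted(segments)]
    merged = []
    for s, e in padded:
        if merged and s - merged[-1][1] <= max_gap:
            merged[-1] = (merged[-1][0], max(merged[-1][1], e))
        else:
            merged.append((s, e))
    return merged

print(pad_and_merge([(100, 220), (300, 420), (1000, 1100)],
                    pad=50, max_gap=50, num_frames=2000))
# -> [(50, 470), (950, 1150)]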
+ echo "Usage: $0 " + echo " e.g.: $0 data/dev_aspire_whole exp/vad_dev_aspire data/dev_aspire_seg" + exit 1 +fi + +data_dir=$1 +dir=$2 +segmented_data_dir=$3 + +for f in $dir/orig_segmentation.1.gz $data_dir/segments; do + if [ ! -f $f ]; then + echo "$0: Could not find $f" + exit 1 + fi +done + +nj=`cat $dir/num_jobs` || exit 1 + +[ $pad_length -eq -1 ] && pad_length= +[ $post_pad_length -eq -1 ] && post_pad_length= +[ $max_blend_length -eq -1 ] && max_blend_length= + +if [ $stage -le 2 ]; then + # Post-process the orignal SAD segmentation using the following steps: + # 1) blend short speech segments of less than $max_blend_length frames + # into silence + # 2) Remove all silence frames and widen speech segments by padding + # $pad_length frames + # 3) Merge adjacent segments that have an intersegment length of less than + # $max_intersegment_length frames + # 4) Widen speech segments again after merging + # 5) Split segments into segments of $max_segment_length at the point where + # the original segmentation had silence + # 6) Split segments into overlapping segments of max length + # $max_segment_length and overlap $overlap_length + # 7) Convert segmentation to kaldi segments and utt2spk + $cmd JOB=1:$nj $dir/log/post_process_segmentation.JOB.log \ + gunzip -c $dir/orig_segmentation.JOB.gz \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=0 ark:- ark:- \| \ + segmentation-post-process ${max_blend_length:+--max-blend-length=$max_blend_length --blend-short-segments-class=1} ark:- ark:- \| \ + segmentation-post-process --remove-labels=0 ${pad_length:+--pad-label=1 --pad-length=$pad_length} ark:- ark:- \| \ + segmentation-post-process --merge-adjacent-segments --max-intersegment-length=$max_intersegment_length ark:- ark:- \| \ + segmentation-post-process ${post_pad_length:+--pad-label=1 --pad-length=$post_pad_length} ark:- ark:- \| \ + segmentation-split-segments --alignments="ark,s,cs:gunzip -c $dir/orig_segmentation.JOB.gz | segmentation-to-ali ark:- ark:- |" \ + --max-segment-length=$max_segment_length --min-alignment-chunk-length=$min_silence_length --ali-label=0 ark:- ark:- \| \ + segmentation-split-segments \ + --max-segment-length=$max_segment_length --overlap-length=$overlap_length ark:- ark:- \| \ + segmentation-to-segments --frame-shift=$frame_shift ark:- \ + ark,t:$dir/utt2spk.JOB $dir/segments.JOB || exit 1 +fi + +for n in `seq $nj`; do + cat $dir/utt2spk.$n +done > $segmented_data_dir/utt2spk + +for n in `seq $nj`; do + cat $dir/segments.$n +done > $segmented_data_dir/segments + +if [ ! -s $segmented_data_dir/utt2spk ] || [ ! -s $segmented_data_dir/segments ]; then + echo "$0: Segmentation failed to generate segments or utt2spk!" + exit 1 +fi diff --git a/egs/wsj/s5/steps/segmentation/internal/prepare_sad_lang.py b/egs/wsj/s5/steps/segmentation/internal/prepare_sad_lang.py new file mode 100755 index 00000000000..17b039015d2 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/internal/prepare_sad_lang.py @@ -0,0 +1,94 @@ +#! /usr/bin/env python + +from __future__ import print_function +import argparse, shlex + +def GetArgs(): + parser = argparse.ArgumentParser(description="""This script generates a lang +directory for purpose of segmentation. It takes as arguments the list of phones, +the corresponding min durations and end transition probability.""") + + parser.add_argument("--phone-transition-parameters", dest='phone_transition_para_array', + type=str, action='append', required = True, + help = "Options to build topology. 
\n" + "--phone-list= # Colon-separated list of phones\n" + "--min-duration= # Min duration for the phones\n" + "--end-transition-probability= # Probability of the end transition after the minimum duration\n") + parser.add_argument("dir", type=str, + help = "Output lang directory") + args = parser.parse_args() + return args + + +def ParsePhoneTransitionParameters(para_array): + parser = argparse.ArgumentParser() + + parser.add_argument("--phone-list", type=str, required=True, + help="Colon-separated list of phones") + parser.add_argument("--min-duration", type=int, default=3, + help="Minimum number of states for the phone") + parser.add_argument("--end-transition-probability", type=float, default=0.1, + help="Probability of the end transition after the minimum duration") + + phone_transition_parameters = [ parser.parse_args(shlex.split(x)) for x in para_array ] + + for t in phone_transition_parameters: + if (t.end_transition_probability > 1.0 or + t.end_transition_probability < 0.0): + raise ValueError("Expected --end-transition-probability to be " + "between 0 and 1, got {0} for phones {1}".format( + t.end_transition_probability, t.phone_list)) + if t.min_duration > 100 or t.min_duration < 1: + raise ValueError("Expected --min-duration to be " + "between 1 and 100, got {0} for phones {1}".format( + t.min_duration, t.phone_list)) + + t.phone_list = t.phone_list.split(":") + + return phone_transition_parameters + +def GetPhoneMap(phone_transition_parameters): + phone2int = {} + n = 1 + for t in phone_transition_parameters: + for p in t.phone_list: + if p in phone2int: + raise Exception("Phone {0} found in multiple topologies".format(p)) + phone2int[p] = n + n += 1 + + return phone2int + +def Main(): + args = GetArgs() + phone_transition_parameters = ParsePhoneTransitionParameters(args.phone_transition_para_array) + + phone2int = GetPhoneMap(phone_transition_parameters) + + topo = open("{0}/topo".format(args.dir), 'w') + + print ("", file = topo) + + for t in phone_transition_parameters: + print ("", file = topo) + print ("", file = topo) + print ("{0}".format(" ".join([str(phone2int[p]) for p in t.phone_list])), file = topo) + print ("", file = topo) + + for state in range(0, t.min_duration-1): + print(" {0} 0 {1} 1.0 ".format(state, state + 1), file = topo) + print(" {state} 0 {state} {self_prob} {next_state} {next_prob} ".format( + state = t.min_duration - 1, next_state = t.min_duration, + self_prob = 1 - t.end_transition_probability, + next_prob = t.end_transition_probability), file = topo) + print(" {state} ".format(state = t.min_duration), file = topo) # Final state + print ("", file = topo) + print ("", file = topo) + + phones_file = open("{0}/phones.txt".format(args.dir), 'w') + + for p,n in sorted(list(phone2int.items()), key = lambda x:x[1]): + print ("{0} {1}".format(p, n), file = phones_file) + +if __name__ == '__main__': + Main() diff --git a/egs/wsj/s5/steps/segmentation/invert_vector.pl b/egs/wsj/s5/steps/segmentation/invert_vector.pl new file mode 100755 index 00000000000..c16243a0b93 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/invert_vector.pl @@ -0,0 +1,20 @@ +#! /usr/bin/perl +use strict; +use warnings; + +while () { + chomp; + my @F = split; + my $utt = shift @F; + shift @F; + + print "$utt [ "; + for (my $i = 0; $i < $#F; $i++) { + if ($F[$i] == 0) { + print "1 "; + } else { + print 1.0/$F[$i] . 
" "; + } + } + print "]\n"; +} diff --git a/egs/wsj/s5/steps/segmentation/make_snr_targets.sh b/egs/wsj/s5/steps/segmentation/make_snr_targets.sh new file mode 100755 index 00000000000..71f603a690e --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/make_snr_targets.sh @@ -0,0 +1,104 @@ +#!/bin/bash + +# Copyright 2015-16 Vimal Manohar +# Apache 2.0 +set -e +set -o pipefail + +nj=4 +cmd=run.pl +stage=0 + +data_id= + +compress=true +target_type=Irm +apply_exp=false + +ali_rspecifier= +silence_phones_str=0 + +ignore_noise_dir=false + +ceiling=inf +floor=-inf + +length_tolerance=2 +transform_matrix= + +echo "$0 $@" # Print the command line for logging + +if [ -f path.sh ]; then . ./path.sh; fi +. parse_options.sh || exit 1; + +if [ $# != 5 ]; then + echo "Usage: $0 [options] --target-type (Irm|Snr) "; + echo " or : $0 [options] --target-type FbankMask "; + echo "e.g.: $0 data/train_clean_fbank data/train_noise_fbank data/train_corrupted_hires exp/make_snr_targets/train snr_targets" + echo "options: " + echo " --nj # number of parallel jobs" + echo " --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs." + exit 1; +fi + +clean_data=$1 +noise_or_noisy_data=$2 +data=$3 +tmpdir=$4 +targets_dir=$5 + +mkdir -p $targets_dir + +[ -z "$data_id" ] && data_id=`basename $data` + +utils/split_data.sh $clean_data $nj + +for n in `seq $nj`; do + utils/subset_data_dir.sh --utt-list $clean_data/split$nj/$n/utt2spk $noise_or_noisy_data $noise_or_noisy_data/subset${nj}/$n +done + +$ignore_noise_dir && utils/split_data.sh $data $nj + +targets_dir=`perl -e '($data,$pwd)= @ARGV; if($data!~m:^/:) { $data = "$pwd/$data"; } print $data; ' $targets_dir ${PWD}` + +for n in `seq $nj`; do + utils/create_data_link.pl $targets_dir/${data_id}.$n.ark +done + +apply_exp_opts= +if $apply_exp; then + apply_exp_opts=" copy-matrix --apply-exp=true ark:- ark:- |" +fi + +copy_feats_opts="copy-feats" +if [ ! -z "$transform_matrix" ]; then + copy_feats_opts="transform-feats $transform_matrix" +fi + +if [ $stage -le 1 ]; then + if ! 
$ignore_noise_dir; then + $cmd JOB=1:$nj $tmpdir/make_`basename $targets_dir`_${data_id}.JOB.log \ + compute-snr-targets --length-tolerance=$length_tolerance --target-type=$target_type \ + ${ali_rspecifier:+--ali-rspecifier="$ali_rspecifier" --silence-phones=$silence_phones_str} \ + --floor=$floor --ceiling=$ceiling \ + "ark:$copy_feats_opts scp:$clean_data/split$nj/JOB/feats.scp ark:- |" \ + "ark,s,cs:$copy_feats_opts scp:$noise_or_noisy_data/subset$nj/JOB/feats.scp ark:- |" \ + ark:- \|$apply_exp_opts \ + copy-feats --compress=$compress ark:- \ + ark,scp:$targets_dir/${data_id}.JOB.ark,$targets_dir/${data_id}.JOB.scp || exit 1 + else + feat_dim=$(feat-to-dim scp:$data/feats.scp -) || exit 1 + $cmd JOB=1:$nj $tmpdir/make_`basename $targets_dir`_${data_id}.JOB.log \ + compute-snr-targets --length-tolerance=$length_tolerance --target-type=$target_type \ + ${ali_rspecifier:+--ali-rspecifier="$ali_rspecifier" --silence-phones=$silence_phones_str} \ + --floor=$floor --ceiling=$ceiling --binary-targets --target-dim=$feat_dim \ + scp:$data/split$nj/JOB/feats.scp \ + ark:- \|$apply_exp_opts \ + copy-feats --compress=$compress ark:- \ + ark,scp:$targets_dir/${data_id}.JOB.ark,$targets_dir/${data_id}.JOB.scp || exit 1 + fi +fi + +for n in `seq $nj`; do + cat $targets_dir/${data_id}.$n.scp +done > $data/`basename $targets_dir`.scp diff --git a/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh b/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh new file mode 100755 index 00000000000..c1006d09678 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/post_process_sad_to_segments.sh @@ -0,0 +1,130 @@ +#! /bin/bash + +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -e -o pipefail -u +. path.sh + +cmd=run.pl +stage=-10 + +segmentation_config=conf/segmentation.conf +nj=18 + +frame_shift=0.01 +weight_threshold=0.5 +ali_suffix=_acwt0.1 + +frame_subsampling_factor=1 + +phone2sad_map= + +. utils/parse_options.sh + +if [ $# -ne 5 ] && [ $# -ne 4 ]; then + echo "This script converts an alignment directory containing per-frame SAD " + echo "labels or per-frame speech probabilities into kaldi-style " + echo "segmented data directory. " + echo "This script first converts the per-frame labels or weights into " + echo "segmentation and then calls " + echo "steps/segmentation/internal/post_process_sad_to_segments.sh, " + echo "which does the actual post-processing step." + echo "Usage: $0 ( |) " + echo " e.g.: $0 data/dev_aspire_whole exp/vad_dev_aspire data/dev_aspire_seg" + exit 1 +fi + +data_dir=$1 +vad_dir= + +if [ $# -eq 5 ]; then + lang=$2 + vad_dir=$3 + shift; shift; shift +else + weights_scp=$2 + shift; shift +fi + +dir=$1 +segmented_data_dir=$2 + +utils/data/get_reco2utt.sh $data_dir + +mkdir -p $dir + +if [ ! 
-z "$vad_dir" ]; then + nj=`cat $vad_dir/num_jobs` || exit 1 + + utils/split_data.sh $data_dir $nj + + for n in `seq $nj`; do + cat $data_dir/split$nj/$n/segments | awk '{print $1" "$2}' | \ + utils/utt2spk_to_spk2utt.pl > $data_dir/split$nj/$n/reco2utt + done + + if [ -z "$phone2sad_map" ]; then + phone2sad_map=$dir/phone2sad_map + + { + cat $lang/phones/silence.int | awk '{print $1" 0"}'; + cat $lang/phones/nonsilence.int | awk '{print $1" 1"}'; + } | sort -k1,1 -n > $dir/phone2sad_map + fi + + frame_shift_subsampled=`perl -e "print ($frame_subsampling_factor * $frame_shift)"` + + if [ $stage -le 0 ]; then + # Convert the original SAD into segmentation + $cmd JOB=1:$nj $dir/log/segmentation.JOB.log \ + segmentation-init-from-ali \ + "ark:gunzip -c $vad_dir/ali${ali_suffix}.JOB.gz |" ark:- \| \ + segmentation-combine-segments ark:- \ + "ark:segmentation-init-from-segments --shift-to-zero=false --frame-shift=$frame_shift_subsampled $data_dir/split$nj/JOB/segments ark:- |" \ + "ark,t:$data_dir/split$nj/JOB/reco2utt" ark:- \| \ + segmentation-copy --label-map=$phone2sad_map \ + --frame-subsampling-factor=$frame_subsampling_factor ark:- \ + "ark:| gzip -c > $dir/orig_segmentation.JOB.gz" + fi +else + utils/split_data.sh $data_dir $nj + + for n in `seq $nj`; do + utils/data/get_reco2utt.sh $data_dir/split$nj/$n + utils/filter_scp.pl $data_dir/split$nj/$n/reco2utt $weights_scp > \ + $dir/weights.$n.scp + done + + $cmd JOB=1:$nj $dir/log/weights_to_segments.JOB.log \ + copy-vector scp:$dir/weights.JOB.scp ark,t:- \| \ + awk -v t=$weight_threshold '{printf $1; for (i=3; i < NF; i++) { if ($i >= t) printf (" 1"); else printf (" 0"); }; print "";}' \| \ + segmentation-init-from-ali \ + ark,t:- ark:- \| segmentation-combine-segments ark:- \ + "ark:segmentation-init-from-segments --shift-to-zero=false --frame-shift=$frame_shift_subsampled $data_dir/split$nj/JOB/segments ark:- |" \ + "ark,t:$data_dir/split$nj/JOB/reco2utt" ark:- \| \ + segmentation-copy --frame-subsampling-factor=$frame_subsampling_factor \ + ark:- "ark:| gzip -c > $dir/orig_segmentation.JOB.gz" +fi + +echo $nj > $dir/num_jobs + +if [ $stage -le 1 ]; then + rm -r $segmented_data_dir || true + utils/data/convert_data_dir_to_whole.sh $data_dir $segmented_data_dir || exit 1 + rm $segmented_data_dir/text || true +fi + +steps/segmentation/internal/post_process_segments.sh \ + --stage $stage --cmd "$cmd" \ + --config $segmentation_config --frame-shift $frame_shift \ + $data_dir $dir $segmented_data_dir + +utils/utt2spk_to_spk2utt.pl $segmented_data_dir/utt2spk > $segmented_data_dir/spk2utt || exit 1 +utils/fix_data_dir.sh $segmented_data_dir + +if [ ! -s $segmented_data_dir/utt2spk ] || [ ! -s $segmented_data_dir/segments ]; then + echo "$0: Segmentation failed to generate segments or utt2spk!" + exit 1 +fi + diff --git a/egs/wsj/s5/steps/segmentation/post_process_sad_to_subsegments.sh b/egs/wsj/s5/steps/segmentation/post_process_sad_to_subsegments.sh new file mode 100644 index 00000000000..8cfcaa40cda --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/post_process_sad_to_subsegments.sh @@ -0,0 +1,69 @@ +#! /bin/bash + +# Copyright 2015 Vimal Manohar +# Apache 2.0. + +set -e -o pipefail -u +. path.sh + +cmd=run.pl +stage=-10 + +segmentation_config=conf/segmentation.conf +nj=18 + +frame_shift=0.01 + +. 
utils/parse_options.sh + +if [ $# -ne 5 ]; then + echo "Usage: $0 " + echo " e.g.: $0 data/dev_aspire_whole exp/vad_dev_aspire data/dev_aspire_seg" + exit 1 +fi + +data_dir=$1 +phone2sad_map=$2 +vad_dir=$3 +dir=$4 +segmented_data_dir=$5 + +mkdir -p $dir + +nj=`cat $vad_dir/num_jobs` || exit 1 + +utils/split_data.sh $data_dir $nj + +if [ $stage -le 0 ]; then + # Convert the original SAD into segmentation + $cmd JOB=1:$nj $dir/log/segmentation.JOB.log \ + segmentation-init-from-ali \ + "ark:gunzip -c $vad_dir/ali.JOB.gz |" ark:- \| \ + segmentation-copy --label-map=$phone2sad_map ark:- \ + "ark:| gzip -c > $dir/orig_segmentation.JOB.gz" +fi + +echo $nj > $dir/num_jobs + +if [ $stage -le 1 ]; then + rm -r $segmented_data_dir || true + utils/data/convert_data_dir_to_whole.sh $data_dir $segmented_data_dir || exit 1 + rm $segmented_data_dir/text || true +fi + +steps/segmentation/internal/post_process_segments.sh \ + --stage $stage --cmd "$cmd" \ + --config $segmentation_config --frame-shift $frame_shift \ + $data_dir $dir $segmented_data_dir + +mv $segmented_data_dir/segments $segmented_data_dir/sub_segments +utils/data/subsegment_data_dir.sh $data_dir $segmented_data_dir/sub_segments $segmented_data_dir + +utils/utt2spk_to_spk2utt.pl $segmented_data_dir/utt2spk > $segmented_data_dir/spk2utt || exit 1 +utils/fix_data_dir.sh $segmented_data_dir + +if [ ! -s $segmented_data_dir/utt2spk ] || [ ! -s $segmented_data_dir/segments ]; then + echo "$0: Segmentation failed to generate segments or utt2spk!" + exit 1 +fi + diff --git a/egs/wsj/s5/steps/segmentation/quantize_vector.pl b/egs/wsj/s5/steps/segmentation/quantize_vector.pl new file mode 100755 index 00000000000..0bccebade4c --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/quantize_vector.pl @@ -0,0 +1,28 @@ +#!/usr/bin/perl + +# This script convert per-frame speech probabilities into +# 0-1 labels. + +@ARGV <= 1 or die "Usage: quantize_vector.pl [threshold]"; + +my $t = 0.5; + +if (scalar @ARGV == 1) { + $t = $ARGV[0]; +} + +while () { + chomp; + my @F = split; + + my $str = "$F[0]"; + for (my $i = 2; $i < $#F; $i++) { + if ($F[$i] >= $t) { + $str = "$str 1"; + } else { + $str = "$str 0"; + } + } + + print ("$str\n"); +} diff --git a/egs/wsj/s5/steps/segmentation/split_data_on_reco.sh b/egs/wsj/s5/steps/segmentation/split_data_on_reco.sh new file mode 100755 index 00000000000..4c167d99a1e --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/split_data_on_reco.sh @@ -0,0 +1,29 @@ +#! /bin/bash + +set -e + +if [ $# -ne 3 ]; then + echo "Usage: split_data_on_reco.sh " + exit 1 +fi + +ref_data=$1 +data=$2 +nj=$3 + +utils/data/get_reco2utt.sh $ref_data +utils/data/get_reco2utt.sh $data + +utils/split_data.sh --per-reco $ref_data $nj + +for n in `seq $nj`; do + srn=$ref_data/split${nj}reco/$n + dsn=$data/split${nj}reco/$n + + mkdir -p $dsn + + utils/data/get_reco2utt.sh $srn + utils/filter_scp.pl $srn/reco2utt $data/reco2utt > $dsn/reco2utt + utils/spk2utt_to_utt2spk.pl $dsn/reco2utt > $dsn/utt2reco + utils/subset_data_dir.sh --utt-list $dsn/utt2reco $data $dsn +done diff --git a/egs/wsj/s5/steps/segmentation/vector_get_max.pl b/egs/wsj/s5/steps/segmentation/vector_get_max.pl new file mode 100644 index 00000000000..abb8ea977a2 --- /dev/null +++ b/egs/wsj/s5/steps/segmentation/vector_get_max.pl @@ -0,0 +1,26 @@ +#! 
/usr/bin/perl + +use warnings; +use strict; + +while (<>) { + chomp; + if (m/^\S+\s+\[.+\]\s*$/) { + my @F = split; + my $utt = shift @F; + shift; + + my $max_id = 0; + my $max = $F[0]; + for (my $i = 1; $i < $#F; $i++) { + if ($F[$i] > $max) { + $max_id = $i; + $max = $F[$i]; + } + } + + print "$utt $max_id\n"; + } else { + die "Invalid line $_\n"; + } +} diff --git a/egs/wsj/s5/utils/copy_data_dir.sh b/egs/wsj/s5/utils/copy_data_dir.sh index 008233daf62..222bc708527 100755 --- a/egs/wsj/s5/utils/copy_data_dir.sh +++ b/egs/wsj/s5/utils/copy_data_dir.sh @@ -83,15 +83,16 @@ fi if [ -f $srcdir/segments ]; then utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments cp $srcdir/wav.scp $destdir - if [ -f $srcdir/reco2file_and_channel ]; then - cp $srcdir/reco2file_and_channel $destdir/ - fi else # no segments->wav indexed by utt. if [ -f $srcdir/wav.scp ]; then utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp fi fi +if [ -f $srcdir/reco2file_and_channel ]; then + cp $srcdir/reco2file_and_channel $destdir/ +fi + if [ -f $srcdir/text ]; then utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text fi diff --git a/egs/wsj/s5/utils/data/convert_data_dir_to_whole.sh b/egs/wsj/s5/utils/data/convert_data_dir_to_whole.sh new file mode 100755 index 00000000000..f55f60c4774 --- /dev/null +++ b/egs/wsj/s5/utils/data/convert_data_dir_to_whole.sh @@ -0,0 +1,108 @@ +#! /bin/bash + +# This scripts converts a data directory into a "whole" data directory +# by removing the segments and using the recordings themselves as +# utterances + +set -o pipefail + +. path.sh + +cmd=run.pl +stage=-1 + +. parse_options.sh + +if [ $# -ne 2 ]; then + echo "Usage: convert_data_dir_to_whole.sh " + echo " e.g.: convert_data_dir_to_whole.sh data/dev data/dev_whole" + exit 1 +fi + +data=$1 +dir=$2 + +if [ ! -f $data/segments ]; then + # Data directory already does not contain segments. So just copy it. + utils/copy_data_dir.sh $data $dir + exit 0 +fi + +mkdir -p $dir +cp $data/wav.scp $dir +cp $data/reco2file_and_channel $dir +rm -f $dir/{utt2spk,text} || true + +[ -f $data/stm ] && cp $data/stm $dir +[ -f $data/glm ] && cp $data/glm $dir + +text_files= +[ -f $data/text ] && text_files="$data/text $dir/text" + +# Combine utt2spk and text from the segments into utt2spk and text for the whole +# recording. +cat $data/segments | perl -e ' +if (scalar @ARGV == 4) { + ($utt2spk_in, $utt2spk_out, $text_in, $text_out) = @ARGV; +} elsif (scalar @ARGV == 2) { + ($utt2spk_in, $utt2spk_out) = @ARGV; +} else { + die "Unexpected number of arguments"; +} + +if (defined $text_in) { + open(TI, "<$text_in") || die "Error: fail to open $text_in\n"; + open(TO, ">$text_out") || die "Error: fail to open $text_out\n"; +} +open(UI, "<$utt2spk_in") || die "Error: fail to open $utt2spk_in\n"; +open(UO, ">$utt2spk_out") || die "Error: fail to open $utt2spk_out\n"; + +my %file2utt = (); +while () { + chomp; + my @col = split; + @col >= 4 or die "bad line $_\n"; + + if (! defined $file2utt{$col[1]}) { + $file2utt{$col[1]} = []; + } + push @{$file2utt{$col[1]}}, $col[0]; +} + +my %text = (); +my %utt2spk = (); + +while () { + chomp; + my @col = split; + $utt2spk{$col[0]} = $col[1]; +} + +if (defined $text_in) { + while () { + chomp; + my @col = split; + @col >= 1 or die "bad line $_\n"; + + my $utt = shift @col; + $text{$utt} = join(" ", @col); + } +} + +foreach $file (keys %file2utt) { + my @utts = @{$file2utt{$file}}; + #print STDERR $file . " " . join(" ", @utts) . 
"\n"; + print UO "$file $file\n"; + + if (defined $text_in) { + $text_line = ""; + print TO "$file $text_line\n"; + } +} +' $data/utt2spk $dir/utt2spk $text_files + +sort -u $dir/utt2spk > $dir/utt2spk.tmp +mv $dir/utt2spk.tmp $dir/utt2spk +utils/utt2spk_to_spk2utt.pl $dir/utt2spk > $dir/spk2utt + +utils/fix_data_dir.sh $dir diff --git a/egs/wsj/s5/utils/data/fix_subsegmented_feats.pl b/egs/wsj/s5/utils/data/fix_subsegmented_feats.pl new file mode 100755 index 00000000000..bd8aeb8e409 --- /dev/null +++ b/egs/wsj/s5/utils/data/fix_subsegmented_feats.pl @@ -0,0 +1,79 @@ +#!/usr/bin/env perl + +# Copyright 2016 Vimal Manohar +# Apache 2.0. + +use warnings; + +# This script modifies the feats ranges and ensures that they don't +# exceed the max number of frames supplied in utt2max_frames. +# utt2max_frames can be computed by using +# steps/segmentation/get_reco2num_frames.sh +# cut -d ' ' -f 1,2 /segments | utils/apply_map.pl -f 2 /reco2num_frames > /utt2max_frames + +(scalar @ARGV == 1) or die "Usage: fix_subsegmented_feats.pl "; + +my $utt2max_frames_file = $ARGV[0]; + +open MAX_FRAMES, $utt2max_frames_file or die "fix_subsegmented_feats.pl: Could not open file $utt2max_frames_file"; + +my %utt2max_frames; + +while () { + chomp; + my @F = split; + + (scalar @F == 2) or die "fix_subsegmented_feats.pl: Invalid line $_ in $utt2max_frames_file"; + + $utt2max_frames{$F[0]} = $F[1]; +} + +while () { + my $line = $_; + + if (m/\[([^][]*)\]\[([^][]*)\]\s*$/) { + print ("fix_subsegmented_feats.pl: this script only supports single indices"); + exit(1); + } + + my $before_range = ""; + my $range = ""; + + if (m/^(.*)\[([^][]*)\]\s*$/) { + $before_range = $1; + $range = $2; + } else { + print; + next; + } + + my @F = split(/ /, $before_range); + my $utt = shift @F; + defined $utt2max_frames{$utt} or die "fix_subsegmented_feats.pl: Could not find key $utt in $utt2num_frames_file.\nError with line $line"; + + if ($range !~ m/^(\d*):(\d*)([,]?.*)$/) { + print STDERR "fix_subsegmented_feats.pl: could not make sense of input line $_"; + exit(1); + } + + my $row_start = $1; + my $row_end = $2; + my $col_range = $3; + + if ($row_end >= $utt2max_frames{$utt}) { + print STDERR "Fixed row_end for $utt from $row_end to $utt2max_frames{$utt}-1\n"; + $row_end = $utt2max_frames{$utt} - 1; + } + + if ($row_start ne "") { + $range = "$row_start:$row_end"; + } else { + $range = ""; + } + + if ($col_range ne "") { + $range .= ",$col_range"; + } + + print ("$utt " . join(" ", @F) . "[" . $range . 
"]\n"); +} diff --git a/egs/wsj/s5/utils/data/get_dct_matrix.py b/egs/wsj/s5/utils/data/get_dct_matrix.py new file mode 100755 index 00000000000..88b28b5dd5c --- /dev/null +++ b/egs/wsj/s5/utils/data/get_dct_matrix.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python + +# we're using python 3.x style print but want it to work in python 2.x, +from __future__ import print_function +import os, argparse, sys, math, warnings + +import numpy as np + +def ComputeLifterCoeffs(Q, dim): + coeffs = np.zeros((dim)) + for i in range(0, dim): + coeffs[i] = 1.0 + 0.5 * Q * math.sin(math.pi * i / Q); + + return coeffs + +def ComputeIDctMatrix(K, N, cepstral_lifter=0): + matrix = np.zeros((K, N)) + # normalizer for X_0 + normalizer = math.sqrt(1.0 / N); + for j in range(0, N): + matrix[0, j] = normalizer; + # normalizer for other elements + normalizer = math.sqrt(2.0 / N); + for k in range(1, K): + for n in range(0, N): + matrix[k, n] = normalizer * math.cos(math.pi/N * (n + 0.5) * k); + + if cepstral_lifter != 0: + lifter_coeffs = ComputeLifterCoeffs(cepstral_lifter, K) + for k in range(0, K): + matrix[k, :] = matrix[k, :] / lifter_coeffs[k]; + + return matrix.T + +def ComputeDctMatrix(K, N, cepstral_lifter=0): + matrix = np.zeros((K, N)) + # normalizer for X_0 + normalizer = math.sqrt(1.0 / N); + for j in range(0, N): + matrix[0, j] = normalizer; + # normalizer for other elements + normalizer = math.sqrt(2.0 / N); + for k in range(1, K): + for n in range(0, N): + matrix[k, n] = normalizer * math.cos(math.pi/N * (n + 0.5) * k); + + if cepstral_lifter != 0: + lifter_coeffs = ComputeLifterCoeffs(cepstral_lifter, K) + for k in range(0, K): + matrix[k, :] = matrix[k, :] * lifter_coeffs[k]; + + return matrix + +def GetArgs(): + parser = argparse.ArgumentParser(description="Write DCT/IDCT matrix") + parser.add_argument("--cepstral-lifter", type=float, + help="Here we need the scaling factor on cepstra in the production of MFCC" + "to cancel out the effect of lifter, e.g. 
22.0", default=22.0) + parser.add_argument("--num-ceps", type=int, + default=13, + help="Number of cepstral dimensions") + parser.add_argument("--num-filters", type=int, + default=23, + help="Number of mel filters") + parser.add_argument("--get-idct-matrix", type=str, default="false", + choices=["true","false"], + help="Get IDCT matrix instead of DCT matrix") + parser.add_argument("--add-zero-column", type=str, default="true", + choices=["true","false"], + help="Add a column to convert the matrix from a linear transform to affine transform") + parser.add_argument("out_file", type=str, + help="Output file") + + args = parser.parse_args() + + return args + +def CheckArgs(args): + if args.num_ceps > args.num_filters: + raise Exception("num-ceps must not be larger than num-filters") + + args.out_file_handle = open(args.out_file, 'w') + + return args + +def Main(): + args = GetArgs() + args = CheckArgs(args) + + if args.get_idct_matrix == "false": + matrix = ComputeDctMatrix(args.num_ceps, args.num_filters, + args.cepstral_lifter) + if args.add_zero_column == "true": + matrix = np.append(matrix, np.zeros((args.num_ceps,1)), 1) + else: + matrix = ComputeIDctMatrix(args.num_ceps, args.num_filters, + args.cepstral_lifter) + + if args.add_zero_column == "true": + matrix = np.append(matrix, np.zeros((args.num_filters,1)), 1) + + print('[ ', file=args.out_file_handle) + np.savetxt(args.out_file_handle, matrix, fmt='%.6e') + print(' ]', file=args.out_file_handle) + +if __name__ == "__main__": + Main() + diff --git a/egs/wsj/s5/utils/data/get_frame_shift.sh b/egs/wsj/s5/utils/data/get_frame_shift.sh index d032c9c17fa..f5a3bac9009 100755 --- a/egs/wsj/s5/utils/data/get_frame_shift.sh +++ b/egs/wsj/s5/utils/data/get_frame_shift.sh @@ -38,23 +38,27 @@ if [ ! -s $dir/utt2dur ]; then utils/data/get_utt2dur.sh $dir 1>&2 fi -if [ ! -f $dir/feats.scp ]; then - echo "$0: $dir/feats.scp does not exist" 1>&2 - exit 1 -fi +if [ ! -f $dir/frame_shift ]; then + if [ ! -f $dir/feats.scp ]; then + echo "$0: $dir/feats.scp does not exist" 1>&2 + exit 1 + fi -temp=$(mktemp /tmp/tmp.XXXX) + temp=$(mktemp /tmp/tmp.XXXX) -feat-to-len "scp:head -n 10 $dir/feats.scp|" ark,t:- > $temp + feat-to-len "scp:head -n 10 $dir/feats.scp|" ark,t:- > $temp -if [ -z $temp ]; then - echo "$0: error running feat-to-len" 1>&2 - exit 1 -fi + if [ -z $temp ]; then + echo "$0: error running feat-to-len" 1>&2 + exit 1 + fi -head -n 10 $dir/utt2dur | paste - $temp | \ - awk '{ dur += $2; frames += $4; } END { shift = dur / frames; if (shift > 0.01 && shift < 0.0102) shift = 0.01; print shift; }' || exit 1; + frame_shift=$(head -n 10 $dir/utt2dur | paste - $temp | awk '{ dur += $2; frames += $4; } END { shift = dur / frames; if (shift > 0.01 && shift < 0.0102) shift = 0.01; print shift; }') || exit 1; + + echo $frame_shift > $dir/frame_shift + rm $temp +fi -rm $temp +cat $dir/frame_shift exit 0 diff --git a/egs/wsj/s5/utils/data/get_reco2dur.sh b/egs/wsj/s5/utils/data/get_reco2dur.sh new file mode 100755 index 00000000000..7d2ccb71769 --- /dev/null +++ b/egs/wsj/s5/utils/data/get_reco2dur.sh @@ -0,0 +1,87 @@ +#!/bin/bash + +# Copyright 2016 Johns Hopkins University (author: Daniel Povey) +# Apache 2.0 + +# This script operates on a data directory, such as in data/train/, and adds the +# reco2dur file if it does not already exist. The file 'reco2dur' maps from +# utterance to the duration of the utterance in seconds. 
This script works it
+# out from the 'segments' file, or, if not present, from the wav.scp file (it
+# first tries interrogating the headers, and if this fails, it reads the wave
+# files in entirely.)
+
+frame_shift=0.01
+
+. utils/parse_options.sh
+. ./path.sh
+
+if [ $# != 1 ]; then
+  echo "Usage: $0 [options] <datadir>"
+  echo "e.g.:"
+  echo " $0 data/train"
+  echo " Options:"
+  echo " --frame-shift      # frame shift in seconds. Only relevant when we are"
+  echo "                    # getting duration from feats.scp (default: 0.01). "
+  exit 1
+fi
+
+export LC_ALL=C
+
+data=$1
+
+if [ -s $data/reco2dur ] && \
+  [ $(cat $data/wav.scp | wc -l) -eq $(cat $data/reco2dur | wc -l) ]; then
+  echo "$0: $data/reco2dur already exists with the expected length. We won't recompute it."
+  exit 0;
+fi
+
+# if the wav.scp contains only lines of the form
+# utt1 /foo/bar/sph2pipe -f wav /baz/foo.sph |
+if cat $data/wav.scp | perl -e '
+   while (<>) { s/\|\s*$/ |/; # make sure final | is preceded by space.
+            @A = split; if (!($#A == 5 && $A[1] =~ m/sph2pipe$/ &&
+                              $A[2] eq "-f" && $A[3] eq "wav" && $A[5] eq "|")) { exit(1); }
+            $utt = $A[0]; $sphere_file = $A[4];
+
+            if (!open(F, "<$sphere_file")) { die "Error opening sphere file $sphere_file"; }
+            $sample_rate = -1; $sample_count = -1;
+            for ($n = 0; $n <= 30; $n++) {
+               $line = <F>;
+               if ($line =~ m/sample_rate -i (\d+)/) { $sample_rate = $1; }
+               if ($line =~ m/sample_count -i (\d+)/) { $sample_count = $1; }
+               if ($line =~ m/end_head/) { last; }
+            }
+            close(F);
+            if ($sample_rate == -1 || $sample_count == -1) {
+              die "could not parse sphere header from $sphere_file";
+            }
+            $duration = $sample_count * 1.0 / $sample_rate;
+            print "$utt $duration\n";
+   } ' > $data/reco2dur; then
+  echo "$0: successfully obtained recording lengths from sphere-file headers"
+else
+  echo "$0: could not get recording lengths from sphere-file headers, using wav-to-duration"
+  if ! command -v wav-to-duration >/dev/null; then
+    echo "$0: wav-to-duration is not on your path"
+    exit 1;
+  fi
+
+  read_entire_file=false
+  if cat $data/wav.scp | grep -q 'sox.*speed'; then
+    read_entire_file=true
+    echo "$0: reading from the entire wav file to fix the problem caused by sox commands with speed perturbation. It is going to be slow."
+    echo "... It is much faster if you call get_reco2dur.sh *before* doing the speed perturbation via e.g. perturb_data_dir_speed.sh or "
+    echo "... perturb_data_dir_speed_3way.sh."
+  fi
+
+  if ! wav-to-duration --read-entire-file=$read_entire_file scp:$data/wav.scp ark,t:$data/reco2dur 2>&1 | grep -v 'nonzero return status'; then
+    echo "$0: there was a problem getting the durations; moving $data/reco2dur to $data/.backup/"
+    mkdir -p $data/.backup/
+    mv $data/reco2dur $data/.backup/
+  fi
+fi
+
+echo "$0: computed $data/reco2dur"
+
+exit 0
+
diff --git a/egs/wsj/s5/utils/data/get_reco2num_frames.sh b/egs/wsj/s5/utils/data/get_reco2num_frames.sh
new file mode 100755
index 00000000000..03ab7b40616
--- /dev/null
+++ b/egs/wsj/s5/utils/data/get_reco2num_frames.sh
@@ -0,0 +1,28 @@
+#! /bin/bash
+
+cmd=run.pl
+nj=4
+
+frame_shift=0.01
+frame_overlap=0.015
+
+. utils/parse_options.sh
+
+if [ $# -ne 1 ]; then
+  echo "Usage: $0 <data-dir>"
+  exit 1
+fi
+
+data=$1
+
+if [ -f $data/reco2num_frames ]; then
+  echo "$0: $data/reco2num_frames already present!"
+ exit 0; +fi + +utils/data/get_reco2dur.sh $data +awk -v fs=$frame_shift -v fovlp=$frame_overlap \ + '{print $1" "int( ($2 - fovlp) / fs)}' $data/reco2dur > $data/reco2num_frames + +echo "$0: Computed and wrote $data/reco2num_frames" + diff --git a/egs/wsj/s5/utils/data/get_reco2utt.sh b/egs/wsj/s5/utils/data/get_reco2utt.sh new file mode 100755 index 00000000000..6c30f812cfe --- /dev/null +++ b/egs/wsj/s5/utils/data/get_reco2utt.sh @@ -0,0 +1,21 @@ +#! /bin/bash + +# Copyright 2016 Vimal Manohar +# Apache 2.0 + +if [ $# -ne 1 ]; then + echo "This script creates a reco2utt file in the data directory, " + echo "which is analogous to spk2utt file but with the first column " + echo "as recording instead of speaker." + echo "Usage: get_reco2utt.sh " + echo " e.g.: get_reco2utt.sh data/train" + exit 1 +fi + +data=$1 + +if [ ! -s $data/segments ]; then + utils/data/get_segments_for_data.sh $data > $data/segments +fi + +cut -d ' ' -f 1,2 $data/segments | utils/utt2spk_to_spk2utt.pl > $data/reco2utt diff --git a/egs/wsj/s5/utils/data/get_segments_for_data.sh b/egs/wsj/s5/utils/data/get_segments_for_data.sh index 694acc6a256..7adc4c465d3 100755 --- a/egs/wsj/s5/utils/data/get_segments_for_data.sh +++ b/egs/wsj/s5/utils/data/get_segments_for_data.sh @@ -19,7 +19,7 @@ fi data=$1 -if [ ! -f $data/utt2dur ]; then +if [ ! -s $data/utt2dur ]; then utils/data/get_utt2dur.sh $data 1>&2 || exit 1; fi diff --git a/egs/wsj/s5/utils/data/get_subsegment_feats.sh b/egs/wsj/s5/utils/data/get_subsegment_feats.sh new file mode 100755 index 00000000000..6baba68eedd --- /dev/null +++ b/egs/wsj/s5/utils/data/get_subsegment_feats.sh @@ -0,0 +1,46 @@ +#! /bin/bash + +# Copyright 2016 Johns Hopkins University (Author: Dan Povey) +# 2016 Vimal Manohar +# Apache 2.0. + +if [ $# -ne 4 ]; then + echo "This scripts gets subsegmented_feats (by adding ranges to data/feats.scp) " + echo "for the subsegments file. This is does one part of the " + echo "functionality in subsegment_data_dir.sh, which additionally " + echo "creates a new subsegmented data directory." + echo "Usage: $0 " + echo " e.g.: $0 data/train/feats.scp 0.01 0.015 subsegments" + exit 1 +fi + +feats=$1 +frame_shift=$2 +frame_overlap=$3 +subsegments=$4 + +# The subsegments format is . +# e.g. 'utt_foo-1 utt_foo 7.21 8.93' +# The first awk command replaces this with the format: +# +# e.g. 'utt_foo-1 utt_foo 721 893' +# and the apply_map.pl command replaces 'utt_foo' (the 2nd field) with its corresponding entry +# from the original wav.scp, so we get a line like: +# e.g. 'utt_foo-1 foo-bar.ark:514231 721 892' +# Note: the reason we subtract one from the last time is that it's going to +# represent the 'last' frame, not the 'end' frame [i.e. not one past the last], +# in the matlab-like, but zero-indexed [first:last] notion. For instance, a segment with 1 frame +# would have start-time 0.00 and end-time 0.01, which would become the frame range +# [0:0] +# The second awk command turns this into something like +# utt_foo-1 foo-bar.ark:514231[721:892] +# It has to be a bit careful because the format actually allows for more general things +# like pipes that might contain spaces, so it has to be able to produce output like the +# following: +# utt_foo-1 some command|[721:892] +# Lastly, utils/data/normalize_data_range.pl will only do something nontrivial if +# the original data-dir already had data-ranges in square brackets. 
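+# As a small worked example (assuming frame_shift=0.01 and frame_overlap=0.015),
+# the subsegments line 'utt_foo-1 utt_foo 7.21 8.93' gives
+#   start frame = int(7.21/0.01 + 0.5)           = 721
+#   last frame  = int((8.93 - 0.015)/0.01 + 0.5) = 892
+# so the resulting feats.scp entry is 'utt_foo-1 foo-bar.ark:514231[721:892]'.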
+awk -v s=$frame_shift -v fovlp=$frame_overlap '{print $1, $2, int(($3/s)+0.5), int(($4-fovlp)/s+0.5);}' <$subsegments| \ + utils/apply_map.pl -f 2 $feats | \ + awk '{p=NF-1; for (n=1;n $data/utt2dur elif [ -f $data/wav.scp ]; then diff --git a/egs/wsj/s5/utils/data/get_utt2num_frames.sh b/egs/wsj/s5/utils/data/get_utt2num_frames.sh new file mode 100755 index 00000000000..e2921601ec9 --- /dev/null +++ b/egs/wsj/s5/utils/data/get_utt2num_frames.sh @@ -0,0 +1,42 @@ +#! /bin/bash + +cmd=run.pl +nj=4 + +frame_shift=0.01 +frame_overlap=0.015 + +. utils/parse_options.sh + +if [ $# -ne 1 ]; then + echo "This script writes a file utt2num_frames with the " + echo "number of frames in each utterance as measured based on the " + echo "duration of the utterances (in utt2dur) and the specified " + echo "frame_shift and frame_overlap." + echo "Usage: $0 " + exit 1 +fi + +data=$1 + +if [ -f $data/utt2num_frames ]; then + echo "$0: $data/utt2num_frames already present!" + exit 0; +fi + +if [ ! -f $data/feats.scp ]; then + utils/data/get_utt2dur.sh $data + awk -v fs=$frame_shift -v fovlp=$frame_overlap \ + '{print $1" "int( ($2 - fovlp) / fs)}' $data/utt2dur > $data/utt2num_frames + exit 0 +fi + +utils/split_data.sh $data $nj || exit 1 +$cmd JOB=1:$nj $data/log/get_utt2num_frames.JOB.log \ + feat-to-len scp:$data/split${nj}/JOB/feats.scp ark,t:$data/split$nj/JOB/utt2num_frames || exit 1 + +for n in `seq $nj`; do + cat $data/split$nj/$n/utt2num_frames +done > $data/utt2num_frames + +echo "$0: Computed and wrote $data/utt2num_frames" diff --git a/egs/wsj/s5/utils/data/modify_speaker_info.sh b/egs/wsj/s5/utils/data/modify_speaker_info.sh index f75e9be5f67..e42f0df551d 100755 --- a/egs/wsj/s5/utils/data/modify_speaker_info.sh +++ b/egs/wsj/s5/utils/data/modify_speaker_info.sh @@ -37,6 +37,7 @@ utts_per_spk_max=-1 seconds_per_spk_max=-1 respect_speaker_info=true +respect_recording_info=true # end configuration section . utils/parse_options.sh @@ -93,10 +94,26 @@ else utt2dur_opt= fi -utils/data/internal/modify_speaker_info.py \ - $utt2dur_opt --respect-speaker-info=$respect_speaker_info \ - --utts-per-spk-max=$utts_per_spk_max --seconds-per-spk-max=$seconds_per_spk_max \ - <$srcdir/utt2spk >$destdir/utt2spk +if ! 
$respect_speaker_info && $respect_recording_info; then + if [ -f $srcdir/segments ]; then + cat $srcdir/segments | awk '{print $1" "$2}' | \ + utils/data/internal/modify_speaker_info.py \ + $utt2dur_opt --respect-speaker-info=true \ + --utts-per-spk-max=$utts_per_spk_max --seconds-per-spk-max=$seconds_per_spk_max \ + >$destdir/utt2spk + else + cat $srcdir/wav.scp | awk '{print $1" "$2}' | \ + utils/data/internal/modify_speaker_info.py \ + $utt2dur_opt --respect-speaker-info=true \ + --utts-per-spk-max=$utts_per_spk_max --seconds-per-spk-max=$seconds_per_spk_max \ + >$destdir/utt2spk + fi +else + utils/data/internal/modify_speaker_info.py \ + $utt2dur_opt --respect-speaker-info=$respect_speaker_info \ + --utts-per-spk-max=$utts_per_spk_max --seconds-per-spk-max=$seconds_per_spk_max \ + <$srcdir/utt2spk >$destdir/utt2spk +fi utils/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt diff --git a/egs/wsj/s5/utils/data/perturb_data_dir_speed_3way.sh b/egs/wsj/s5/utils/data/perturb_data_dir_speed_3way.sh index c575166534e..4b12a94eee9 100755 --- a/egs/wsj/s5/utils/data/perturb_data_dir_speed_3way.sh +++ b/egs/wsj/s5/utils/data/perturb_data_dir_speed_3way.sh @@ -43,5 +43,9 @@ utils/data/combine_data.sh $destdir ${srcdir} ${destdir}_speed0.9 ${destdir}_spe rm -r ${destdir}_speed0.9 ${destdir}_speed1.1 echo "$0: generated 3-way speed-perturbed version of data in $srcdir, in $destdir" -utils/validate_data_dir.sh --no-feats $destdir +if [ -f $srcdir/text ]; then + utils/validate_data_dir.sh --no-feats $destdir +else + utils/validate_data_dir.sh --no-feats --no-text $destdir +fi diff --git a/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh b/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh index bc76939643c..185c7abf426 100755 --- a/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh +++ b/egs/wsj/s5/utils/data/perturb_data_dir_volume.sh @@ -7,6 +7,11 @@ # the wav.scp to perturb the volume (typically useful for training data when # using systems that don't have cepstral mean normalization). +reco2vol= +force=false +scale_low=0.125 +scale_high=2 + . utils/parse_options.sh if [ $# != 1 ]; then @@ -25,30 +30,67 @@ if [ ! -f $data/wav.scp ]; then exit 1 fi -if grep -q "sox --vol" $data/wav.scp; then +if ! $force && grep -q "sox --vol" $data/wav.scp; then echo "$0: It looks like the data was already volume perturbed. Not doing anything." 
exit 0 fi -cat $data/wav.scp | python -c " +if [ -z "$reco2vol" ]; then + cat $data/wav.scp | python -c " import sys, os, subprocess, re, random random.seed(0) -scale_low = 1.0/8 -scale_high = 2.0 +scale_low = $scale_low +scale_high = $scale_high +volume_writer = open('$data/reco2vol', 'w') +for line in sys.stdin.readlines(): + if len(line.strip()) == 0: + continue + # Handle three cases of rxfilenames appropriately; 'input piped command', 'file offset' and 'filename' + vol = random.uniform(scale_low, scale_high) + + parts = line.strip().split() + if line.strip()[-1] == '|': + print '{0} sox --vol {1} -t wav - -t wav - |'.format(line.strip(), vol) + elif re.search(':[0-9]+$', line.strip()) is not None: + print '{id} wav-copy {wav} - | sox --vol {vol} -t wav - -t wav - |'.format(id = parts[0], wav=' '.join(parts[1:]), vol = vol) + else: + print '{id} sox --vol {vol} -t wav {wav} -t wav - |'.format(id = parts[0], wav=' '.join(parts[1:]), vol = vol) + volume_writer.write('{id} {vol}\n'.format(id = parts[0], vol = vol)) +" > $data/wav.scp_scaled || exit 1; +else + cat $data/wav.scp | python -c " +import sys, os, subprocess, re +volumes = {} +for line in open('$reco2vol'): + if len(line.strip()) == 0: + continue + parts = line.strip().split() + volumes[parts[0]] = float(parts[1]) + for line in sys.stdin.readlines(): if len(line.strip()) == 0: continue # Handle three cases of rxfilenames appropriately; 'input piped command', 'file offset' and 'filename' + + parts = line.strip().split() + id = parts[0] + + if id not in volumes: + raise Exception('Could not find volume for id {id}'.format(id = id)) + + vol = volumes[id] + if line.strip()[-1] == '|': - print '{0} sox --vol {1} -t wav - -t wav - |'.format(line.strip(), random.uniform(scale_low, scale_high)) + print '{0} sox --vol {1} -t wav - -t wav - |'.format(line.strip(), vol) elif re.search(':[0-9]+$', line.strip()) is not None: - parts = line.split() - print '{id} wav-copy {wav} - | sox --vol {vol} -t wav - -t wav - |'.format(id = parts[0], wav=' '.join(parts[1:]), vol = random.uniform(scale_low, scale_high)) + print '{id} wav-copy {wav} - | sox --vol {vol} -t wav - -t wav - |'.format(id = parts[0], wav=' '.join(parts[1:]), vol = vol) else: - parts = line.split() - print '{id} sox --vol {vol} -t wav {wav} -t wav - |'.format(id = parts[0], wav=' '.join(parts[1:]), vol = random.uniform(scale_low, scale_high)) + print '{id} sox --vol {vol} -t wav {wav} -t wav - |'.format(id = parts[0], wav=' '.join(parts[1:]), vol = vol) " > $data/wav.scp_scaled || exit 1; + cp $reco2vol $data/reco2vol +fi + len1=$(cat $data/wav.scp | wc -l) len2=$(cat $data/wav.scp_scaled | wc -l) if [ "$len1" != "$len2" ]; then diff --git a/egs/wsj/s5/utils/data/subsegment_data_dir.sh b/egs/wsj/s5/utils/data/subsegment_data_dir.sh index 18a00c3df7d..b018d5ec94a 100755 --- a/egs/wsj/s5/utils/data/subsegment_data_dir.sh +++ b/egs/wsj/s5/utils/data/subsegment_data_dir.sh @@ -24,14 +24,15 @@ segment_end_padding=0.0 . utils/parse_options.sh -if [ $# != 4 ]; then +if [ $# != 4 ] && [ $# != 3 ]; then echo "Usage: " - echo " $0 [options] " + echo " $0 [options] [] " echo "This script sub-segments a data directory. is to" echo "have lines of the form " echo "and is of the form ... ." echo "This script appropriately combines the with the original" echo "segments file, if necessary, and if not, creates a segments file." + echo " is an optional argument." 
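+  # For example (hypothetical utterance ids), a subsegments line
+  #   utt_foo-1 utt_foo 7.21 8.93
+  # defines a new utterance utt_foo-1 covering 7.21s-8.93s of the existing
+  # utterance utt_foo.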
echo "e.g.:" echo " $0 data/train [options] exp/tri3b_resegment/segments exp/tri3b_resegment/text data/train_resegmented" echo " Options:" @@ -50,11 +51,23 @@ export LC_ALL=C srcdir=$1 subsegments=$2 -new_text=$3 -dir=$4 +no_text=true +if [ $# -eq 4 ]; then + new_text=$3 + dir=$4 + no_text=false -for f in "$subsegments" "$new_text" "$srcdir/utt2spk"; do + if [ ! -f "$new_text" ]; then + echo "$0: no such file $new_text" + exit 1 + fi + +else + dir=$3 +fi + +for f in "$subsegments" "$srcdir/utt2spk"; do if [ ! -f "$f" ]; then echo "$0: no such file $f" exit 1; @@ -65,9 +78,11 @@ if ! mkdir -p $dir; then echo "$0: failed to create directory $dir" fi -if ! cmp <(awk '{print $1}' <$subsegments) <(awk '{print $1}' <$new_text); then - echo "$0: expected the first fields of the files $subsegments and $new_text to be identical" - exit 1 +if ! $no_text; then + if ! cmp <(awk '{print $1}' <$subsegments) <(awk '{print $1}' <$new_text); then + echo "$0: expected the first fields of the files $subsegments and $new_text to be identical" + exit 1 + fi fi # create the utt2spk in $dir @@ -86,8 +101,11 @@ awk '{print $1, $2}' < $subsegments > $dir/new2old_utt utils/apply_map.pl -f 2 $srcdir/utt2spk < $dir/new2old_utt >$dir/utt2spk # .. and the new spk2utt file. utils/utt2spk_to_spk2utt.pl <$dir/utt2spk >$dir/spk2utt -# the new text file is just what the user provides. -cp $new_text $dir/text + +if ! $no_text; then + # the new text file is just what the user provides. + cp $new_text $dir/text +fi # copy the source wav.scp cp $srcdir/wav.scp $dir diff --git a/egs/wsj/s5/utils/fix_data_dir.sh b/egs/wsj/s5/utils/fix_data_dir.sh index 0333d628544..33e710a605f 100755 --- a/egs/wsj/s5/utils/fix_data_dir.sh +++ b/egs/wsj/s5/utils/fix_data_dir.sh @@ -6,6 +6,11 @@ # It puts the original contents of data-dir into # data-dir/.backup +utt_extra_files= +spk_extra_files= + +. utils/parse_options.sh + if [ $# != 1 ]; then echo "Usage: utils/data/fix_data_dir.sh " echo "e.g.: utils/data/fix_data_dir.sh data/train" @@ -110,7 +115,7 @@ function filter_speakers { filter_file $tmpdir/speakers $data/spk2utt utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk - for s in cmvn.scp spk2gender; do + for s in cmvn.scp spk2gender $spk_extra_files; do f=$data/$s if [ -f $f ]; then filter_file $tmpdir/speakers $f @@ -158,7 +163,7 @@ function filter_utts { fi fi - for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav; do + for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav $utt_extra_files; do if [ -f $data/$x ]; then cp $data/$x $data/.backup/$x if ! 
cmp -s $data/$x <( utils/filter_scp.pl $tmpdir/utts $data/$x ) ; then diff --git a/egs/wsj/s5/utils/perturb_data_dir_speed.sh b/egs/wsj/s5/utils/perturb_data_dir_speed.sh index 20ff86755eb..e3d56d58b9c 100755 --- a/egs/wsj/s5/utils/perturb_data_dir_speed.sh +++ b/egs/wsj/s5/utils/perturb_data_dir_speed.sh @@ -112,4 +112,9 @@ cat $srcdir/utt2dur | utils/apply_map.pl -f 1 $destdir/utt_map | \ rm $destdir/spk_map $destdir/utt_map 2>/dev/null echo "$0: generated speed-perturbed version of data in $srcdir, in $destdir" -utils/validate_data_dir.sh --no-feats $destdir + +if [ -f $srcdir/text ]; then + utils/validate_data_dir.sh --no-feats $destdir +else + utils/validate_data_dir.sh --no-feats --no-text $destdir +fi diff --git a/egs/wsj/s5/utils/split_data.sh b/egs/wsj/s5/utils/split_data.sh index e44a4ab6359..646830481db 100755 --- a/egs/wsj/s5/utils/split_data.sh +++ b/egs/wsj/s5/utils/split_data.sh @@ -16,9 +16,14 @@ # limitations under the License. split_per_spk=true +split_per_reco=false if [ "$1" == "--per-utt" ]; then split_per_spk=false shift +elif [ "$1" == "--per-reco" ]; then + split_per_spk=false + split_per_reco=true + shift fi if [ $# != 2 ]; then @@ -59,10 +64,14 @@ if [ -f $data/text ] && [ $nu -ne $nt ]; then echo "** use utils/fix_data_dir.sh to fix this." fi - if $split_per_spk; then utt2spk_opt="--utt2spk=$data/utt2spk" utt="" +elif $split_per_reco; then + utils/data/get_reco2utt.sh $data + utils/spk2utt_to_utt2spk.pl $data/reco2utt > $data/utt2reco + utt2spk_opt="--utt2spk=$data/utt2reco" + utt="reco" else utt2spk_opt= utt="utt" @@ -86,6 +95,7 @@ if ! $need_to_split; then fi utt2spks=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n/utt2spk; done) +utt2recos=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n/utt2reco; done) directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}${utt}/$n; done) @@ -100,11 +110,20 @@ fi which lockfile >&/dev/null && lockfile -l 60 $data/.split_lock trap 'rm -f $data/.split_lock' EXIT HUP INT PIPE TERM -utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1 +if $split_per_reco; then + utils/split_scp.pl $utt2spk_opt $data/utt2reco $utt2recos || exit 1 +else + utils/split_scp.pl $utt2spk_opt $data/utt2spk $utt2spks || exit 1 +fi for n in `seq $numsplit`; do dsn=$data/split${numsplit}${utt}/$n - utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1; + + if $split_per_reco; then + utils/filter_scp.pl $dsn/utt2reco $data/utt2spk > $dsn/utt2spk + fi + + utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1 done maybe_wav_scp= diff --git a/egs/wsj/s5/utils/subset_data_dir.sh b/egs/wsj/s5/utils/subset_data_dir.sh index 5fe3217ddad..9533d0216c9 100755 --- a/egs/wsj/s5/utils/subset_data_dir.sh +++ b/egs/wsj/s5/utils/subset_data_dir.sh @@ -108,6 +108,7 @@ function do_filtering { [ -f $srcdir/vad.scp ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/vad.scp >$destdir/vad.scp [ -f $srcdir/utt2lang ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/utt2lang >$destdir/utt2lang [ -f $srcdir/utt2dur ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/utt2dur >$destdir/utt2dur + [ -f $srcdir/utt2uniq ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/utt2uniq >$destdir/utt2uniq [ -f $srcdir/wav.scp ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/wav.scp >$destdir/wav.scp [ -f $srcdir/spk2warp ] && utils/filter_scp.pl $destdir/spk2utt <$srcdir/spk2warp >$destdir/spk2warp [ -f $srcdir/utt2warp ] && utils/filter_scp.pl $destdir/utt2spk <$srcdir/utt2warp >$destdir/utt2warp @@ 
-126,6 +127,10 @@ function do_filtering { [ -f $srcdir/stm ] && utils/filter_scp.pl $destdir/reco < $srcdir/stm > $destdir/stm rm $destdir/reco + else + awk '{print $1;}' $destdir/wav.scp | sort | uniq > $destdir/reco + [ -f $srcdir/reco2file_and_channel ] && \ + utils/filter_scp.pl $destdir/reco <$srcdir/reco2file_and_channel >$destdir/reco2file_and_channel fi srcutts=`cat $srcdir/utt2spk | wc -l` destutts=`cat $destdir/utt2spk | wc -l` diff --git a/src/Makefile b/src/Makefile index 9905be869a0..a42f78f4742 100644 --- a/src/Makefile +++ b/src/Makefile @@ -6,16 +6,16 @@ SHELL := /bin/bash SUBDIRS = base matrix util feat tree thread gmm transform sgmm \ - fstext hmm lm decoder lat kws cudamatrix nnet \ + fstext hmm lm decoder lat kws cudamatrix nnet segmenter \ bin fstbin gmmbin fgmmbin sgmmbin featbin \ nnetbin latbin sgmm2 sgmm2bin nnet2 nnet3 chain nnet3bin nnet2bin kwsbin \ - ivector ivectorbin online2 online2bin lmbin chainbin + ivector ivectorbin online2 online2bin lmbin chainbin segmenterbin MEMTESTDIRS = base matrix util feat tree thread gmm transform sgmm \ - fstext hmm lm decoder lat nnet kws chain \ + fstext hmm lm decoder lat nnet kws chain segmenter \ bin fstbin gmmbin fgmmbin sgmmbin featbin \ nnetbin latbin sgmm2 nnet2 nnet3 nnet2bin nnet3bin sgmm2bin kwsbin \ - ivector ivectorbin online2 online2bin lmbin + ivector ivectorbin online2 online2bin lmbin segmenterbin CUDAMEMTESTDIR = cudamatrix @@ -155,7 +155,7 @@ $(EXT_SUBDIRS) : mklibdir bin fstbin gmmbin fgmmbin sgmmbin sgmm2bin featbin nnetbin nnet2bin nnet3bin chainbin latbin ivectorbin lmbin kwsbin online2bin: \ base matrix util feat tree thread gmm transform sgmm sgmm2 fstext hmm \ - lm decoder lat cudamatrix nnet nnet2 nnet3 ivector chain kws online2 + lm decoder lat cudamatrix nnet nnet2 nnet3 ivector chain kws online2 segmenter #2)The libraries have inter-dependencies base: base/.depend.mk @@ -179,6 +179,7 @@ nnet2: base util matrix thread lat gmm hmm tree transform cudamatrix nnet3: base util matrix thread lat gmm hmm tree transform cudamatrix chain fstext chain: lat hmm tree fstext matrix cudamatrix util thread base ivector: base util matrix thread transform tree gmm +segmenter: base matrix util gmm thread #3)Dependencies for optional parts of Kaldi onlinebin: base matrix util feat tree gmm transform sgmm sgmm2 fstext hmm lm decoder lat cudamatrix nnet nnet2 online thread # python-kaldi-decoding: base matrix util feat tree thread gmm transform sgmm sgmm2 fstext hmm decoder lat online diff --git a/src/bin/Makefile b/src/bin/Makefile index 687040889b3..3dc59fe8112 100644 --- a/src/bin/Makefile +++ b/src/bin/Makefile @@ -24,7 +24,8 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \ matrix-logprob matrix-sum \ build-pfile-from-ali get-post-on-ali tree-info am-info \ vector-sum matrix-sum-rows est-pca sum-lda-accs sum-mllt-accs \ - transform-vec align-text matrix-dim + transform-vec align-text matrix-dim weight-pdf-post weight-matrix \ + matrix-add-offset matrix-dot-product OBJFILES = diff --git a/src/bin/copy-matrix.cc b/src/bin/copy-matrix.cc index d7b8181c64c..56f2e51d90f 100644 --- a/src/bin/copy-matrix.cc +++ b/src/bin/copy-matrix.cc @@ -36,16 +36,30 @@ int main(int argc, char *argv[]) { " e.g.: copy-matrix --binary=false 1.mat -\n" " copy-matrix ark:2.trans ark,t:-\n" "See also: copy-feats\n"; - + bool binary = true; + bool apply_log = false; + bool apply_exp = false; + bool apply_softmax_per_row = false; + BaseFloat apply_power = 1.0; BaseFloat scale = 1.0; + ParseOptions po(usage); 
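+    // As an illustrative invocation (hypothetical archive names), per-frame
+    // log-posteriors can be converted to posteriors with
+    //   copy-matrix --apply-exp=true ark:log_post.ark ark:post.ark
+    // Note that only one of --apply-log, --apply-exp and
+    // --apply-softmax-per-row may be given at a time.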
po.Register("binary", &binary, "Write in binary mode (only relevant if output is a wxfilename)"); po.Register("scale", &scale, "This option can be used to scale the matrices being copied."); - + po.Register("apply-log", &apply_log, + "This option can be used to apply log on the matrices. " + "Must be avoided if matrix has negative quantities."); + po.Register("apply-exp", &apply_exp, + "This option can be used to apply exp on the matrices"); + po.Register("apply-power", &apply_power, + "This option can be used to apply a power on the matrices"); + po.Register("apply-softmax-per-row", &apply_softmax_per_row, + "This option can be used to apply softmax per row of the matrices"); + po.Read(argc, argv); if (po.NumArgs() != 2) { @@ -53,6 +67,10 @@ int main(int argc, char *argv[]) { exit(1); } + if ( (apply_log && apply_exp) || (apply_softmax_per_row && apply_exp) || + (apply_softmax_per_row && apply_log) ) + KALDI_ERR << "Only one of apply-log, apply-exp and " + << "apply-softmax-per-row can be given"; std::string matrix_in_fn = po.GetArg(1), matrix_out_fn = po.GetArg(2); @@ -68,11 +86,15 @@ int main(int argc, char *argv[]) { if (in_is_rspecifier != out_is_wspecifier) KALDI_ERR << "Cannot mix archives with regular files (copying matrices)"; - + if (!in_is_rspecifier) { Matrix mat; ReadKaldiObject(matrix_in_fn, &mat); if (scale != 1.0) mat.Scale(scale); + if (apply_log) mat.ApplyLog(); + if (apply_exp) mat.ApplyExp(); + if (apply_softmax_per_row) mat.ApplySoftMaxPerRow(); + if (apply_power != 1.0) mat.ApplyPow(apply_power); Output ko(matrix_out_fn, binary); mat.Write(ko.Stream(), binary); KALDI_LOG << "Copied matrix to " << matrix_out_fn; @@ -82,9 +104,14 @@ int main(int argc, char *argv[]) { BaseFloatMatrixWriter writer(matrix_out_fn); SequentialBaseFloatMatrixReader reader(matrix_in_fn); for (; !reader.Done(); reader.Next(), num_done++) { - if (scale != 1.0) { + if (scale != 1.0 || apply_log || apply_exp || + apply_power != 1.0 || apply_softmax_per_row) { Matrix mat(reader.Value()); - mat.Scale(scale); + if (scale != 1.0) mat.Scale(scale); + if (apply_log) mat.ApplyLog(); + if (apply_exp) mat.ApplyExp(); + if (apply_softmax_per_row) mat.ApplySoftMaxPerRow(); + if (apply_power != 1.0) mat.ApplyPow(apply_power); writer.Write(reader.Key(), mat); } else { writer.Write(reader.Key(), reader.Value()); diff --git a/src/bin/matrix-add-offset.cc b/src/bin/matrix-add-offset.cc new file mode 100644 index 00000000000..90f72ba3254 --- /dev/null +++ b/src/bin/matrix-add-offset.cc @@ -0,0 +1,84 @@ +// bin/matrix-add-offset.cc + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Add an offset vector to the rows of matrices in a table.\n" + "\n" + "Usage: matrix-add-offset [options] " + " \n" + "e.g.: matrix-add-offset log_post.mat neg_priors.vec log_like.mat\n" + "See also: matrix-sum-rows, matrix-sum, vector-sum\n"; + + + ParseOptions po(usage); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + std::string rspecifier = po.GetArg(1); + std::string vector_rxfilename = po.GetArg(2); + std::string wspecifier = po.GetArg(3); + + SequentialBaseFloatMatrixReader mat_reader(rspecifier); + BaseFloatMatrixWriter mat_writer(wspecifier); + + int32 num_done = 0; + + Vector vec; + { + bool binary_in; + Input ki(vector_rxfilename, &binary_in); + vec.Read(ki.Stream(), binary_in); + } + + for (; !mat_reader.Done(); mat_reader.Next()) { + std::string key = mat_reader.Key(); + Matrix mat(mat_reader.Value()); + if (vec.Dim() != mat.NumCols()) { + KALDI_ERR << "Mismatch in vector dimension and " + << "number of columns in matrix; " + << vec.Dim() << " vs " << mat.NumCols(); + } + mat.AddVecToRows(1.0, vec); + mat_writer.Write(key, mat); + num_done++; + } + + KALDI_LOG << "Added offset to " << num_done << " matrices."; + + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/bin/matrix-dot-product.cc b/src/bin/matrix-dot-product.cc new file mode 100644 index 00000000000..a292cab9a40 --- /dev/null +++ b/src/bin/matrix-dot-product.cc @@ -0,0 +1,183 @@ +// bin/matrix-dot-product.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Get element-wise dot product of matrices. Always returns a matrix " + "that is the same size as the first matrix.\n" + "If there is a mismatch in number of rows, the utterance is skipped, " + "unless the mismatch is within a tolerance. If the second matrix has " + "number of rows that is larger than the first matrix by less than the " + "specified tolerance, then a submatrix of the second matrix is " + "multiplied element-wise with the first matrix.\n" + "\n" + "Usage: matrix-dot-product [options] " + "[ ...] " + "\n" + " e.g.: matrix-dot-product ark:1.weights ark:2.weights " + "ark:combine.weights\n" + "or \n" + "Usage: matrix-dot-product [options] " + "[ ...] 
" + "\n" + " e.g.: matrix-sum --binary=false 1.mat 2.mat product.mat\n" + "See also: matrix-sum, matrix-sum-rows\n"; + + bool binary = true; + int32 length_tolerance = 0; + + ParseOptions po(usage); + + po.Register("binary", &binary, "If true, write output as binary (only " + "relevant for usage types two or three"); + po.Register("length-tolerance", &length_tolerance, + "Tolerance length mismatch of this many frames"); + + po.Read(argc, argv); + + if (po.NumArgs() < 2) { + po.PrintUsage(); + exit(1); + } + + int32 N = po.NumArgs(); + std::string matrix_in_fn1 = po.GetArg(1), + matrix_out_fn = po.GetArg(N); + + if (ClassifyWspecifier(matrix_out_fn, NULL, NULL, NULL) != kNoWspecifier) { + // output to table. + + // Output matrix + BaseFloatMatrixWriter matrix_writer(matrix_out_fn); + + // Input matrices + SequentialBaseFloatMatrixReader matrix_reader1(matrix_in_fn1); + std::vector + matrix_readers(N-2, + static_cast(NULL)); + std::vector matrix_in_fns(N-2); + for (int32 i = 2; i < N; ++i) { + matrix_readers[i-2] = new RandomAccessBaseFloatMatrixReader( + po.GetArg(i)); + matrix_in_fns[i-2] = po.GetArg(i); + } + int32 n_utts = 0, n_total_matrices = 0, + n_success = 0, n_missing = 0, n_other_errors = 0; + + for (; !matrix_reader1.Done(); matrix_reader1.Next()) { + std::string key = matrix_reader1.Key(); + Matrix matrix1 = matrix_reader1.Value(); + matrix_reader1.FreeCurrent(); + n_utts++; + n_total_matrices++; + + Matrix matrix_out(matrix1); + + int32 i = 0; + for (i = 0; i < N-2; ++i) { + bool failed = false; // Indicates failure for this key. + if (matrix_readers[i]->HasKey(key)) { + const Matrix &matrix2 = matrix_readers[i]->Value(key); + n_total_matrices++; + if (SameDim(matrix2, matrix_out)) { + matrix_out.MulElements(matrix2); + } else { + KALDI_WARN << "Dimension mismatch for utterance " << key + << " : " << matrix2.NumRows() << " by " + << matrix2.NumCols() << " for " + << "system " << (i + 2) << ", rspecifier: " + << matrix_in_fns[i] << " vs " << matrix_out.NumRows() + << " by " << matrix_out.NumCols() + << " primary matrix, rspecifier:" << matrix_in_fn1; + if (matrix2.NumRows() - matrix_out.NumRows() <= + length_tolerance) { + KALDI_WARN << "Tolerated length mismatch for key " << key; + matrix_out.MulElements(matrix2.Range(0, matrix_out.NumRows(), + 0, matrix2.NumCols())); + } else { + KALDI_WARN << "Skipping key " << key; + failed = true; + n_other_errors++; + } + } + } else { + KALDI_WARN << "No matrix found for utterance " << key << " for " + << "system " << (i + 2) << ", rspecifier: " + << matrix_in_fns[i]; + failed = true; + n_missing++; + } + + if (failed) break; + } + + if (i != N-2) // Skipping utterance + continue; + + matrix_writer.Write(key, matrix_out); + n_success++; + } + + KALDI_LOG << "Processed " << n_utts << " utterances: with a total of " + << n_total_matrices << " matrices across " << (N-1) + << " different systems."; + KALDI_LOG << "Produced output for " << n_success << " utterances; " + << n_missing << " total missing matrices and skipped " + << n_other_errors << "matrices."; + + DeletePointers(&matrix_readers); + + return (n_success != 0 && n_missing < (n_success - n_missing)) ? 
0 : 1; + } else { + for (int32 i = 1; i < N; i++) { + if (ClassifyRspecifier(po.GetArg(i), NULL, NULL) != kNoRspecifier) { + KALDI_ERR << "Wrong usage: if last argument is not " + << "table, the other arguments must not be tables."; + } + } + + Matrix mat1; + ReadKaldiObject(po.GetArg(1), &mat1); + + for (int32 i = 2; i < N; i++) { + Matrix mat; + ReadKaldiObject(po.GetArg(i), &mat); + + mat1.MulElements(mat); + } + + WriteKaldiObject(mat1, po.GetArg(N), binary); + KALDI_LOG << "Multiplied " << (po.NumArgs() - 1) << " matrices; " + << "wrote product to " << PrintableWxfilename(po.GetArg(N)); + + return 0; + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/bin/matrix-sum-rows.cc b/src/bin/matrix-sum-rows.cc index 7e60483eef2..ee6504ba2b1 100644 --- a/src/bin/matrix-sum-rows.cc +++ b/src/bin/matrix-sum-rows.cc @@ -34,9 +34,13 @@ int main(int argc, char *argv[]) { "e.g.: matrix-sum-rows ark:- ark:- | vector-sum ark:- sum.vec\n" "See also: matrix-sum, vector-sum\n"; + bool do_average = false; ParseOptions po(usage); + po.Register("do-average", &do_average, + "Do average instead of sum"); + po.Read(argc, argv); if (po.NumArgs() != 2) { @@ -45,28 +49,28 @@ int main(int argc, char *argv[]) { } std::string rspecifier = po.GetArg(1); std::string wspecifier = po.GetArg(2); - + SequentialBaseFloatMatrixReader mat_reader(rspecifier); BaseFloatVectorWriter vec_writer(wspecifier); - + int32 num_done = 0; int64 num_rows_done = 0; - + for (; !mat_reader.Done(); mat_reader.Next()) { std::string key = mat_reader.Key(); Matrix mat(mat_reader.Value()); Vector vec(mat.NumCols()); - vec.AddRowSumMat(1.0, mat, 0.0); + vec.AddRowSumMat(!do_average ? 1.0 : 1.0 / mat.NumRows(), mat, 0.0); // Do the summation in double, to minimize roundoff. Vector float_vec(vec); vec_writer.Write(key, float_vec); num_done++; num_rows_done += mat.NumRows(); } - + KALDI_LOG << "Summed rows " << num_done << " matrices, " << num_rows_done << " rows in total."; - + return (num_done != 0 ? 
0 : 1); } catch(const std::exception &e) { std::cerr << e.what(); diff --git a/src/bin/vector-scale.cc b/src/bin/vector-scale.cc index 60d4d3121d2..ea68ae31ad0 100644 --- a/src/bin/vector-scale.cc +++ b/src/bin/vector-scale.cc @@ -30,11 +30,14 @@ int main(int argc, char *argv[]) { const char *usage = "Scale a set of vectors in a Table (useful for speaker vectors and " "per-frame weights)\n" - "Usage: vector-scale [options] \n"; + "Usage: vector-scale [options] \n"; ParseOptions po(usage); BaseFloat scale = 1.0; + bool binary = false; + po.Register("binary", &binary, "If true, write output as binary " + "not relevant for archives"); po.Register("scale", &scale, "Scaling factor for vectors"); po.Read(argc, argv); @@ -43,17 +46,33 @@ int main(int argc, char *argv[]) { exit(1); } - std::string rspecifier = po.GetArg(1); - std::string wspecifier = po.GetArg(2); + std::string vector_in_fn = po.GetArg(1); + std::string vector_out_fn = po.GetArg(2); - BaseFloatVectorWriter vec_writer(wspecifier); - - SequentialBaseFloatVectorReader vec_reader(rspecifier); - for (; !vec_reader.Done(); vec_reader.Next()) { - Vector vec(vec_reader.Value()); + if (ClassifyWspecifier(vector_in_fn, NULL, NULL, NULL) != kNoWspecifier) { + if (ClassifyRspecifier(vector_in_fn, NULL, NULL) == kNoRspecifier) { + KALDI_ERR << "Cannot mix archives and regular files"; + } + BaseFloatVectorWriter vec_writer(vector_out_fn); + SequentialBaseFloatVectorReader vec_reader(vector_in_fn); + for (; !vec_reader.Done(); vec_reader.Next()) { + Vector vec(vec_reader.Value()); + vec.Scale(scale); + vec_writer.Write(vec_reader.Key(), vec); + } + } else { + if (ClassifyRspecifier(vector_in_fn, NULL, NULL) != kNoRspecifier) { + KALDI_ERR << "Cannot mix archives and regular files"; + } + bool binary_in; + Input ki(vector_in_fn, &binary_in); + Vector vec; + vec.Read(ki.Stream(), binary_in); vec.Scale(scale); - vec_writer.Write(vec_reader.Key(), vec); + Output ko(vector_out_fn, binary); + vec.Write(ko.Stream(), binary); } + return 0; } catch(const std::exception &e) { std::cerr << e.what(); diff --git a/src/bin/weight-matrix.cc b/src/bin/weight-matrix.cc new file mode 100644 index 00000000000..c6823b8da29 --- /dev/null +++ b/src/bin/weight-matrix.cc @@ -0,0 +1,84 @@ +// bin/weight-matrix.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
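The --do-average option added to matrix-sum-rows above simply folds the 1/num_rows factor into the AddRowSumMat coefficient, so the per-key output becomes a column-wise mean instead of a sum. A small sketch of the equivalent computation; the function name is illustrative:

    // Sketch: column-wise average of a matrix's rows, equivalent to what
    // matrix-sum-rows --do-average=true writes for each key.
    #include "matrix/kaldi-matrix.h"
    #include "matrix/kaldi-vector.h"

    kaldi::Vector<kaldi::BaseFloat> RowAverage(
        const kaldi::Matrix<kaldi::BaseFloat> &mat) {
      kaldi::Vector<kaldi::BaseFloat> avg(mat.NumCols());
      // The coefficient 1.0 / NumRows() turns the row sum into a row average;
      // beta = 0.0 overwrites the destination instead of accumulating into it.
      avg.AddRowSumMat(1.0 / mat.NumRows(), mat, 0.0);
      return avg;
    }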
+ + +#include "base/kaldi-common.h" +#include "util/common-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + typedef kaldi::int32 int32; + + const char *usage = + "Takes archives (typically per-utterance) of features and " + "per-frame weights,\n" + "and weights the features by the per-frame weights\n" + "\n" + "Usage: weight-matrix " + "\n"; + + ParseOptions po(usage); + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string matrix_rspecifier = po.GetArg(1), + weights_rspecifier = po.GetArg(2), + matrix_wspecifier = po.GetArg(3); + + SequentialBaseFloatMatrixReader matrix_reader(matrix_rspecifier); + RandomAccessBaseFloatVectorReader weights_reader(weights_rspecifier); + BaseFloatMatrixWriter matrix_writer(matrix_wspecifier); + + int32 num_done = 0, num_err = 0; + + for (; !matrix_reader.Done(); matrix_reader.Next()) { + std::string key = matrix_reader.Key(); + Matrix mat = matrix_reader.Value(); + if (!weights_reader.HasKey(key)) { + KALDI_WARN << "No weight vectors for utterance " << key; + num_err++; + continue; + } + const Vector &weights = weights_reader.Value(key); + if (weights.Dim() != mat.NumRows()) { + KALDI_WARN << "Weights for utterance " << key + << " have wrong size, " << weights.Dim() + << " vs. " << mat.NumRows(); + num_err++; + continue; + } + mat.MulRowsVec(weights); + matrix_writer.Write(key, mat); + num_done++; + } + KALDI_LOG << "Applied per-frame weights for " << num_done + << " matrices; errors on " << num_err; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/bin/weight-pdf-post.cc b/src/bin/weight-pdf-post.cc new file mode 100644 index 00000000000..c7477a046c8 --- /dev/null +++ b/src/bin/weight-pdf-post.cc @@ -0,0 +1,154 @@ +// bin/weight-pdf-post.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "gmm/am-diag-gmm.h" +#include "hmm/transition-model.h" +#include "hmm/hmm-utils.h" +#include "hmm/posterior.h" + +namespace kaldi { + +void WeightPdfPost(const ConstIntegerSet &pdf_set, + BaseFloat pdf_scale, + Posterior *post) { + for (size_t i = 0; i < post->size(); i++) { + std::vector > this_post; + this_post.reserve((*post)[i].size()); + for (size_t j = 0; j < (*post)[i].size(); j++) { + int32 pdf_id = (*post)[i][j].first; + BaseFloat weight = (*post)[i][j].second; + if (pdf_set.count(pdf_id) != 0) { // is a silence. 
+ if (pdf_scale != 0.0) + this_post.push_back(std::make_pair(pdf_id, weight*pdf_scale)); + } else { + this_post.push_back(std::make_pair(pdf_id, weight)); + } + } + (*post)[i].swap(this_post); + } +} + +void WeightPdfPostDistributed(const ConstIntegerSet &pdf_set, + BaseFloat pdf_scale, + Posterior *post) { + for (size_t i = 0; i < post->size(); i++) { + std::vector > this_post; + this_post.reserve((*post)[i].size()); + BaseFloat sil_weight = 0.0, nonsil_weight = 0.0; + for (size_t j = 0; j < (*post)[i].size(); j++) { + int32 pdf_id = (*post)[i][j].first; + BaseFloat weight = (*post)[i][j].second; + if (pdf_set.count(pdf_id) != 0) + sil_weight += weight; + else + nonsil_weight += weight; + } + // This "distributed" weighting approach doesn't make sense if we have + // negative weights. + KALDI_ASSERT(sil_weight >= 0.0 && nonsil_weight >= 0.0); + if (sil_weight + nonsil_weight == 0.0) continue; + BaseFloat frame_scale = (sil_weight * pdf_scale + nonsil_weight) / + (sil_weight + nonsil_weight); + if (frame_scale != 0.0) { + for (size_t j = 0; j < (*post)[i].size(); j++) { + int32 pdf_id = (*post)[i][j].first; + BaseFloat weight = (*post)[i][j].second; + this_post.push_back(std::make_pair(pdf_id, weight * frame_scale)); + } + } + (*post)[i].swap(this_post); + } +} + +} // namespace kaldi + +int main(int argc, char *argv[]) { + using namespace kaldi; + typedef kaldi::int32 int32; + try { + const char *usage = + "Apply weight to specific pdfs or tids in posts\n" + "Usage: weight-pdf-post [options] " + " \n" + "e.g.:\n" + " weight-pdf-post 0.00001 0:2 ark:1.post ark:nosil.post\n"; + + ParseOptions po(usage); + + bool distribute = false; + + po.Register("distribute", &distribute, "If true, rather than weighting the " + "individual posteriors, apply the weighting to the " + "whole frame: " + "i.e. on time t, scale all posterior entries by " + "p(sil)*silence-weight + p(non-sil)*1.0"); + + po.Read(argc, argv); + + if (po.NumArgs() != 4) { + po.PrintUsage(); + exit(1); + } + + std::string pdf_weight_str = po.GetArg(1), + pdfs_str = po.GetArg(2), + posteriors_rspecifier = po.GetArg(3), + posteriors_wspecifier = po.GetArg(4); + + BaseFloat pdf_weight = 0.0; + if (!ConvertStringToReal(pdf_weight_str, &pdf_weight)) + KALDI_ERR << "Invalid pdf-weight parameter: expected float, got \"" + << pdf_weight << '"'; + std::vector pdfs; + if (!SplitStringToIntegers(pdfs_str, ":", false, &pdfs)) + KALDI_ERR << "Invalid pdf string string " << pdfs_str; + if (pdfs.empty()) + KALDI_WARN <<"No pdf specified, this will have no effect"; + ConstIntegerSet pdf_set(pdfs); // faster lookup. + + int32 num_posteriors = 0; + SequentialPosteriorReader posterior_reader(posteriors_rspecifier); + PosteriorWriter posterior_writer(posteriors_wspecifier); + + for (; !posterior_reader.Done(); posterior_reader.Next()) { + num_posteriors++; + // Posterior is vector > > + Posterior post = posterior_reader.Value(); + // Posterior is vector > > + if (distribute) + WeightPdfPostDistributed(pdf_set, + pdf_weight, &post); + else + WeightPdfPost(pdf_set, + pdf_weight, &post); + + posterior_writer.Write(posterior_reader.Key(), post); + } + KALDI_LOG << "Done " << num_posteriors << " posteriors."; + return (num_posteriors != 0 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/bin/weight-post.cc b/src/bin/weight-post.cc index d536896eaaa..bbaad465195 100644 --- a/src/bin/weight-post.cc +++ b/src/bin/weight-post.cc @@ -26,32 +26,38 @@ int main(int argc, char *argv[]) { try { using namespace kaldi; - typedef kaldi::int32 int32; + typedef kaldi::int32 int32; + + int32 length_tolerance = 2; const char *usage = "Takes archives (typically per-utterance) of posteriors and per-frame weights,\n" "and weights the posteriors by the per-frame weights\n" "\n" "Usage: weight-post \n"; - + ParseOptions po(usage); + + po.Register("length-tolerance", &length_tolerance, + "Tolerate this many frames of length mismatch"); + po.Read(argc, argv); if (po.NumArgs() != 3) { po.PrintUsage(); exit(1); } - + std::string post_rspecifier = po.GetArg(1), weights_rspecifier = po.GetArg(2), post_wspecifier = po.GetArg(3); SequentialPosteriorReader posterior_reader(post_rspecifier); RandomAccessBaseFloatVectorReader weights_reader(weights_rspecifier); - PosteriorWriter post_writer(post_wspecifier); - + PosteriorWriter post_writer(post_wspecifier); + int32 num_done = 0, num_err = 0; - + for (; !posterior_reader.Done(); posterior_reader.Next()) { std::string key = posterior_reader.Key(); Posterior post = posterior_reader.Value(); @@ -61,7 +67,8 @@ int main(int argc, char *argv[]) { continue; } const Vector &weights = weights_reader.Value(key); - if (weights.Dim() != static_cast(post.size())) { + if (std::abs(weights.Dim() - static_cast(post.size())) > + length_tolerance) { KALDI_WARN << "Weights for utterance " << key << " have wrong size, " << weights.Dim() << " vs. " << post.size(); @@ -71,7 +78,7 @@ int main(int argc, char *argv[]) { for (size_t i = 0; i < post.size(); i++) { if (weights(i) == 0.0) post[i].clear(); for (size_t j = 0; j < post[i].size(); j++) - post[i][j].second *= weights(i); + post[i][j].second *= i < weights.Dim() ? weights(i) : 0.0; } post_writer.Write(key, post); num_done++; diff --git a/src/featbin/Makefile b/src/featbin/Makefile index dc2bea215d8..aaa4abca24c 100644 --- a/src/featbin/Makefile +++ b/src/featbin/Makefile @@ -15,7 +15,8 @@ BINFILES = compute-mfcc-feats compute-plp-feats compute-fbank-feats \ process-kaldi-pitch-feats compare-feats wav-to-duration add-deltas-sdc \ compute-and-process-kaldi-pitch-feats modify-cmvn-stats wav-copy \ wav-reverberate append-vector-to-feats detect-sinusoids shift-feats \ - concat-feats append-post-to-feats post-to-feats + concat-feats append-post-to-feats post-to-feats vector-to-feat \ + extract-column compute-snr-targets OBJFILES = diff --git a/src/featbin/apply-cmvn-sliding.cc b/src/featbin/apply-cmvn-sliding.cc index 4a6d02d16cd..105319761b5 100644 --- a/src/featbin/apply-cmvn-sliding.cc +++ b/src/featbin/apply-cmvn-sliding.cc @@ -35,10 +35,13 @@ int main(int argc, char *argv[]) { "Useful for speaker-id; see also apply-cmvn-online\n" "\n" "Usage: apply-cmvn-sliding [options] \n"; - + + std::string skip_dims_str; ParseOptions po(usage); SlidingWindowCmnOptions opts; opts.Register(&po); + po.Register("skip-dims", &skip_dims_str, "Dimensions for which to skip " + "normalization: colon-separated list of integers, e.g. 13:14:15)"); po.Read(argc, argv); @@ -47,15 +50,24 @@ int main(int argc, char *argv[]) { exit(1); } + std::vector skip_dims; // optionally use "fake" + // (zero-mean/unit-variance) stats for some + // dims to disable normalization. 
+ if (!SplitStringToIntegers(skip_dims_str, ":", false, &skip_dims)) { + KALDI_ERR << "Bad --skip-dims option (should be colon-separated list of " + << "integers)"; + } + + int32 num_done = 0, num_err = 0; - + std::string feat_rspecifier = po.GetArg(1); std::string feat_wspecifier = po.GetArg(2); SequentialBaseFloatMatrixReader feat_reader(feat_rspecifier); BaseFloatMatrixWriter feat_writer(feat_wspecifier); - - for (;!feat_reader.Done(); feat_reader.Next()) { + + for (; !feat_reader.Done(); feat_reader.Next()) { std::string utt = feat_reader.Key(); Matrix feat(feat_reader.Value()); if (feat.NumRows() == 0) { @@ -67,7 +79,7 @@ int main(int argc, char *argv[]) { feat.NumCols(), kUndefined); SlidingWindowCmn(opts, feat, &cmvn_feat); - + feat_writer.Write(utt, cmvn_feat); num_done++; } diff --git a/src/featbin/compute-snr-targets.cc b/src/featbin/compute-snr-targets.cc new file mode 100644 index 00000000000..cdb7ef66c2a --- /dev/null +++ b/src/featbin/compute-snr-targets.cc @@ -0,0 +1,273 @@ +// featbin/compute-snr-targets.cc + +// Copyright 2015-2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Compute snr targets using clean and noisy speech features.\n" + "The targets can be of 3 types -- \n" + "Irm (Ideal Ratio Mask) = Clean fbank / (Clean fbank + Noise fbank)\n" + "FbankMask = Clean fbank / Noisy fbank\n" + "Snr (Signal To Noise Ratio) = Clean fbank / Noise fbank\n" + "Both input and output features are assumed to be in log domain.\n" + "ali-rspecifier and silence-phones are used to identify whether " + "a particular frame is \"clean\" or not. Silence frames in " + "\"clean\" fbank are treated as \"noise\" and hence the SNR for those " + "frames are -inf in log scale.\n" + "Usage: compute-snr-targets [options] \n" + " or compute-snr-targets [options] --binary-targets \n" + "e.g.: compute-snr-targets scp:clean.scp scp:noisy.scp ark:targets.ark\n"; + + std::string target_type = "Irm"; + std::string ali_rspecifier; + std::string silence_phones_str; + std::string floor_str = "-inf", ceiling_str = "inf"; + int32 length_tolerance = 0; + bool binary_targets = false; + int32 target_dim = -1; + + ParseOptions po(usage); + po.Register("target_type", &target_type, "Target type can be FbankMask or IRM"); + po.Register("ali-rspecifier", &ali_rspecifier, "If provided, all the " + "energy in the silence region of clean file is considered noise"); + po.Register("silence-phones", &silence_phones_str, "Comma-separated list of " + "silence phones"); + po.Register("floor", &floor_str, "If specified, the target is floored at " + "this value. 
You may want to do this if you are using targets " + "in original log form as is usual in the case of Snr, but may " + "not if you are applying Exp() as is usual in the case of Irm"); + po.Register("ceiling", &ceiling_str, "If specified, the target is ceiled " + "at this value. You may want to do this if you expect " + "infinities or very large values, particularly for Snr targets."); + po.Register("length-tolerance", &length_tolerance, "Tolerate differences " + "in utterance lengths of these many frames"); + po.Register("binary-targets", &binary_targets, "If specified, then the " + "targets are created considering each frame to be either " + "completely signal or completely noise as decided by the " + "ali-rspecifier option. When ali-rspecifier is not specified, " + "then the entire utterance is considered to be just signal." + "If this option is specified, then only a single argument " + "-- the clean features -- is must be specified."); + po.Register("target-dim", &target_dim, "Overrides the target dimension. " + "Applicable only with --binary-targets is specified"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3 && po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::vector silence_phones; + if (!silence_phones_str.empty()) { + if (!SplitStringToIntegers(silence_phones_str, ":", false, &silence_phones)) { + KALDI_ERR << "Invalid silence-phones string " << silence_phones_str; + } + std::sort(silence_phones.begin(), silence_phones.end()); + } + + double floor = kLogZeroDouble, ceiling = -kLogZeroDouble; + + if (floor_str != "-inf") + if (!ConvertStringToReal(floor_str, &floor)) { + KALDI_ERR << "Invalid --floor value " << floor_str; + } + + if (ceiling_str != "inf") + if (!ConvertStringToReal(ceiling_str, &ceiling)) { + KALDI_ERR << "Invalid --ceiling value " << ceiling_str; + } + + int32 num_done = 0, num_err = 0, num_success = 0; + int64 num_sil_frames = 0; + int64 num_speech_frames = 0; + + if (!binary_targets) { + // This is the 'normal' case, where we have both clean and + // noise/corrupted input features. + // The word 'noisy' in the variable names is used to mean 'corrupted'. 
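Since both inputs are log filterbank energies, the Irm target reduces per element to log(clean) - log(clean + noise): LogAddExpMat builds the denominator in log space and the AddMat(-1.0, ...) below then forms the log ratio. A scalar sketch of that arithmetic using Kaldi's LogAdd; the function name ScalarLogIrm is illustrative:

    // Sketch: per-element Irm target in the log domain, i.e.
    // log( clean / (clean + noise) ), computed without leaving log space.
    #include "base/kaldi-math.h"

    double ScalarLogIrm(double log_clean_energy, double log_noise_energy) {
      // log(clean + noise), accumulated stably via log-add.
      double log_total = kaldi::LogAdd(log_clean_energy, log_noise_energy);
      // Always <= 0; becomes -inf for frames whose clean energy is forced to
      // kLogZeroDouble (silence frames, per the --ali-rspecifier option).
      return log_clean_energy - log_total;
    }

For the Snr target type the same subtraction is done against the noise energy alone, which is why the --ceiling option is mainly useful there: near-silent noise frames would otherwise produce very large or infinite values.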
+ std::string clean_rspecifier = po.GetArg(1), + noisy_rspecifier = po.GetArg(2), + targets_wspecifier = po.GetArg(3); + + SequentialBaseFloatMatrixReader noisy_reader(noisy_rspecifier); + RandomAccessBaseFloatMatrixReader clean_reader(clean_rspecifier); + BaseFloatMatrixWriter kaldi_writer(targets_wspecifier); + + RandomAccessInt32VectorReader alignment_reader(ali_rspecifier); + + for (; !noisy_reader.Done(); noisy_reader.Next(), num_done++) { + const std::string &key = noisy_reader.Key(); + Matrix total_energy(noisy_reader.Value()); + // Although this is called 'energy', it is actually log filterbank + // features of noise or corrupted files + // Actually noise feats in the case of Irm and Snr + + // TODO: Support multiple corrupted version for a particular clean file + std::string uniq_key = key; + if (!clean_reader.HasKey(uniq_key)) { + KALDI_WARN << "Could not find uniq key " << uniq_key << " " + << "in clean feats " << clean_rspecifier; + num_err++; + continue; + } + + Matrix clean_energy(clean_reader.Value(uniq_key)); + + if (target_type == "Irm") { + total_energy.LogAddExpMat(1.0, clean_energy, kNoTrans); + } + + if (!ali_rspecifier.empty()) { + if (!alignment_reader.HasKey(uniq_key)) { + KALDI_WARN << "Could not find uniq key " << uniq_key + << "in alignment " << ali_rspecifier; + num_err++; + continue; + } + const std::vector &ali = alignment_reader.Value(key); + + if (std::abs(static_cast (ali.size()) - clean_energy.NumRows()) > length_tolerance) { + KALDI_WARN << "Mismatch in number of frames in alignment " + << "and feats; " << static_cast(ali.size()) + << " vs " << clean_energy.NumRows(); + num_err++; + continue; + } + + int32 length = std::min(static_cast(ali.size()), clean_energy.NumRows()); + if (ali.size() < length) + // TODO: Support this case + KALDI_ERR << "This code currently does not support the case " + << "where alignment smaller than features because " + << "it is not expected to happen"; + + KALDI_ASSERT(clean_energy.NumRows() == length); + KALDI_ASSERT(total_energy.NumRows() == length); + + if (clean_energy.NumRows() < length) clean_energy.Resize(length, clean_energy.NumCols(), kCopyData); + if (total_energy.NumRows() < length) total_energy.Resize(length, total_energy.NumCols(), kCopyData); + + for (int32 i = 0; i < clean_energy.NumRows(); i++) { + if (std::binary_search(silence_phones.begin(), silence_phones.end(), ali[i])) { + clean_energy.Row(i).Set(kLogZeroDouble); + num_sil_frames++; + } else num_speech_frames++; + } + } + + clean_energy.AddMat(-1.0, total_energy); + if (ceiling_str != "inf") { + clean_energy.ApplyCeiling(ceiling); + } + + if (floor_str != "-inf") { + clean_energy.ApplyFloor(floor); + } + + kaldi_writer.Write(key, Matrix(clean_energy)); + num_success++; + } + } else { + // Copying tables of features. 
+ std::string feats_rspecifier = po.GetArg(1), + targets_wspecifier = po.GetArg(2); + + SequentialBaseFloatMatrixReader feats_reader(feats_rspecifier); + BaseFloatMatrixWriter kaldi_writer(targets_wspecifier); + + RandomAccessInt32VectorReader alignment_reader(ali_rspecifier); + + int64 num_sil_frames = 0; + int64 num_speech_frames = 0; + + for (; !feats_reader.Done(); feats_reader.Next(), num_done++) { + const std::string &key = feats_reader.Key(); + const Matrix &feats = feats_reader.Value(); + + Matrix targets; + + if (target_dim < 0) + targets.Resize(feats.NumRows(), feats.NumCols()); + else + targets.Resize(feats.NumRows(), target_dim); + + if (target_type == "Snr") + targets.Set(-kLogZeroDouble); + + if (!ali_rspecifier.empty()) { + if (!alignment_reader.HasKey(key)) { + KALDI_WARN << "Could not find uniq key " << key + << " in alignment " << ali_rspecifier; + num_err++; + continue; + } + + const std::vector &ali = alignment_reader.Value(key); + + if (std::abs(static_cast (ali.size()) - feats.NumRows()) > length_tolerance) { + KALDI_WARN << "Mismatch in number of frames in alignment " + << "and feats; " << static_cast(ali.size()) + << " vs " << feats.NumRows(); + num_err++; + continue; + } + + int32 length = std::min(static_cast(ali.size()), feats.NumRows()); + KALDI_ASSERT(ali.size() >= length); + + for (int32 i = 0; i < feats.NumRows(); i++) { + if (std::binary_search(silence_phones.begin(), silence_phones.end(), ali[i])) { + targets.Row(i).Set(kLogZeroDouble); + num_sil_frames++; + } else { + num_speech_frames++; + } + } + + if (ceiling_str != "inf") { + targets.ApplyCeiling(ceiling); + } + + if (floor_str != "-inf") { + targets.ApplyFloor(floor); + } + + kaldi_writer.Write(key, targets); + } + } + } + + KALDI_LOG << "Computed SNR targets for " << num_success + << " out of " << num_done << " utterances; failed for " + << num_err; + KALDI_LOG << "Got [ " << num_speech_frames << "," + << num_sil_frames << "] frames of silence and speech"; + return (num_success > 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} diff --git a/src/featbin/extract-column.cc b/src/featbin/extract-column.cc new file mode 100644 index 00000000000..7fa6644af03 --- /dev/null +++ b/src/featbin/extract-column.cc @@ -0,0 +1,84 @@ +// featbin/extract-column.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace std; + + const char *usage = + "Extract a column out of a matrix. \n" + "This is most useful to extract log-energies \n" + "from feature files\n" + "\n" + "Usage: extract-column [options] --column-index= " + " \n" + " e.g. 
extract-column ark:feats-in.ark ark:energies.ark\n" + "See also: select-feats, subset-feats, subsample-feats, extract-rows\n"; + + ParseOptions po(usage); + + int32 column_index = 0; + + po.Register("column-index", &column_index, + "Index of column to extract"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + string feat_rspecifier = po.GetArg(1); + string vector_wspecifier = po.GetArg(2); + + SequentialBaseFloatMatrixReader reader(feat_rspecifier); + BaseFloatVectorWriter writer(vector_wspecifier); + + int32 num_done = 0, num_err = 0; + + string line; + + for (; !reader.Done(); reader.Next(), num_done++) { + const Matrix& feats(reader.Value()); + Vector col(feats.NumRows()); + if (column_index >= feats.NumCols()) { + KALDI_ERR << "Column index " << column_index << " is " + << "not less than number of columns " << feats.NumCols(); + } + col.CopyColFromMat(feats, column_index); + writer.Write(reader.Key(), col); + } + + KALDI_LOG << "Processed " << num_done << " matrices successfully; " + << "errors on " << num_err; + + return (num_done > 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + + diff --git a/src/featbin/vector-to-feat.cc b/src/featbin/vector-to-feat.cc new file mode 100644 index 00000000000..1fe521db864 --- /dev/null +++ b/src/featbin/vector-to-feat.cc @@ -0,0 +1,100 @@ +// featbin/vector-to-feat.cc + +// Copyright 2015 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/kaldi-matrix.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + + const char *usage = + "Convert a vector into a single feature so that it can be appended \n" + "to other feature matrices\n" + "Usage: vector-to-feats \n" + "or: vector-to-feats \n" + "e.g.: vector-to-feats scp:weights.scp ark:weight_feats.ark\n" + " or: vector-to-feats weight_vec feat_mat\n" + "See also: copy-feats, copy-matrix, paste-feats, \n" + "subsample-feats, splice-feats\n"; + + ParseOptions po(usage); + bool compress = false, binary = true; + + po.Register("binary", &binary, "Binary-mode output (not relevant if writing " + "to archive)"); + po.Register("compress", &compress, "If true, write output in compressed form" + "(only currently supported for wxfilename, i.e. 
archive/script," + "output)"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + int32 num_done = 0; + + if (ClassifyRspecifier(po.GetArg(1), NULL, NULL) != kNoRspecifier) { + std::string vector_rspecifier = po.GetArg(1); + std::string feature_wspecifier = po.GetArg(2); + + SequentialBaseFloatVectorReader vector_reader(vector_rspecifier); + BaseFloatMatrixWriter feat_writer(feature_wspecifier); + CompressedMatrixWriter compressed_feat_writer(feature_wspecifier); + + for (; !vector_reader.Done(); vector_reader.Next(), ++num_done) { + const Vector &vec = vector_reader.Value(); + Matrix feat(vec.Dim(), 1); + feat.CopyColFromVec(vec, 0); + + if (!compress) + feat_writer.Write(vector_reader.Key(), feat); + else + compressed_feat_writer.Write(vector_reader.Key(), + CompressedMatrix(feat)); + } + KALDI_LOG << "Converted " << num_done << " vectors into features"; + return (num_done != 0 ? 0 : 1); + } + + KALDI_ASSERT(!compress && "Compression not yet supported for single files"); + + std::string vector_rxfilename = po.GetArg(1), + feature_wxfilename = po.GetArg(2); + + Vector vec; + ReadKaldiObject(vector_rxfilename, &vec); + + Matrix feat(vec.Dim(), 1); + feat.CopyColFromVec(vec, 0); + + WriteKaldiObject(feat, feature_wxfilename, binary); + + KALDI_LOG << "Converted vector " << PrintableRxfilename(vector_rxfilename) + << " to " << PrintableWxfilename(feature_wxfilename); + return 0; + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/featbin/wav-reverberate.cc b/src/featbin/wav-reverberate.cc index a9e6d3509c1..3b92f6e0b3e 100644 --- a/src/featbin/wav-reverberate.cc +++ b/src/featbin/wav-reverberate.cc @@ -156,6 +156,8 @@ int main(int argc, char *argv[]) { bool normalize_output = true; BaseFloat volume = 0; BaseFloat duration = 0; + std::string reverb_wxfilename; + std::string additive_noise_wxfilename; po.Register("multi-channel-output", &multi_channel_output, "Specifies if the output should be multi-channel or not"); @@ -212,6 +214,14 @@ int main(int argc, char *argv[]) { "after reverberating and possibly adding noise. " "If you set this option to a nonzero value, it will be as " "if you had also specified --normalize-output=false."); + po.Register("reverb-out-wxfilename", &reverb_wxfilename, + "Output the reverberated wave file, i.e. before adding the " + "additive noise. " + "Useful for computing SNR features or for debugging"); + po.Register("additive-noise-out-wxfilename", + &additive_noise_wxfilename, + "Output the additive noise file used to corrupt the input wave." + "Useful for computing SNR features or for debugging"); po.Read(argc, argv); if (po.NumArgs() != 2) { @@ -314,10 +324,23 @@ int main(int argc, char *argv[]) { int32 num_samp_output = (duration > 0 ? samp_freq_input * duration : (shift_output ? 
num_samp_input : num_samp_input + num_samp_rir - 1)); + Matrix out_matrix(num_output_channels, num_samp_output); + Matrix out_reverb_matrix; + if (!reverb_wxfilename.empty()) + out_reverb_matrix.Resize(num_output_channels, num_samp_output); + + Matrix out_noise_matrix; + if (!additive_noise_wxfilename.empty()) + out_noise_matrix.Resize(num_output_channels, num_samp_output); + for (int32 output_channel = 0; output_channel < num_output_channels; output_channel++) { Vector input(num_samp_input); + + Vector out_reverb(0); + Vector out_noise(0); + input.CopyRowFromMat(input_matrix, input_channel); float power_before_reverb = VecVec(input, input) / input.Dim(); @@ -337,6 +360,16 @@ int main(int argc, char *argv[]) { } } + if (!reverb_wxfilename.empty()) { + out_reverb.Resize(input.Dim()); + out_reverb.CopyFromVec(input); + } + + if (!additive_noise_wxfilename.empty()) { + out_noise.Resize(input.Dim()); + out_noise.SetZero(); + } + if (additive_signal_matrices.size() > 0) { Vector noise(0); int32 this_noise_channel = (multi_channel_output ? output_channel : noise_channel); @@ -345,33 +378,86 @@ int main(int argc, char *argv[]) { for (int32 i = 0; i < additive_signal_matrices.size(); i++) { noise.Resize(additive_signal_matrices[i].NumCols()); noise.CopyRowFromMat(additive_signal_matrices[i], this_noise_channel); - AddNoise(&noise, snr_vector[i], start_time_vector[i], - samp_freq_input, early_energy, &input); + + if (!additive_noise_wxfilename.empty()) { + AddNoise(&noise, snr_vector[i], start_time_vector[i], + samp_freq_input, early_energy, &out_noise); + } else { + AddNoise(&noise, snr_vector[i], start_time_vector[i], + samp_freq_input, early_energy, &input); + } + } + + if (!additive_noise_wxfilename.empty()) { + input.AddVec(1.0, out_noise); } } float power_after_reverb = VecVec(input, input) / input.Dim(); - if (volume > 0) + if (volume > 0) { input.Scale(volume); - else if (normalize_output) + out_reverb.Scale(volume); + out_noise.Scale(volume); + } else if (normalize_output) { input.Scale(sqrt(power_before_reverb / power_after_reverb)); + out_reverb.Scale(sqrt(power_before_reverb / power_after_reverb)); + out_noise.Scale(sqrt(power_before_reverb / power_after_reverb)); + } if (num_samp_output <= num_samp_input) { // trim the signal from the start out_matrix.CopyRowFromVec(input.Range(shift_index, num_samp_output), output_channel); + + if (!reverb_wxfilename.empty()) { + out_reverb_matrix.CopyRowFromVec(out_reverb.Range(shift_index, num_samp_output), output_channel); + } + + if (!additive_noise_wxfilename.empty()) { + out_noise_matrix.CopyRowFromVec(out_noise.Range(shift_index, num_samp_output), output_channel); + } } else { - // repeat the signal to fill up the duration - Vector extended_input(num_samp_output); - extended_input.SetZero(); - AddVectorsOfUnequalLength(input.Range(shift_index, num_samp_input), &extended_input); - out_matrix.CopyRowFromVec(extended_input, output_channel); + { + // repeat the signal to fill up the duration + Vector extended_input(num_samp_output); + extended_input.SetZero(); + AddVectorsOfUnequalLength(input.Range(shift_index, num_samp_input), &extended_input); + out_matrix.CopyRowFromVec(extended_input, output_channel); + } + if (!reverb_wxfilename.empty()) { + // repeat the signal to fill up the duration + Vector extended_input(num_samp_output); + extended_input.SetZero(); + AddVectorsOfUnequalLength(out_reverb.Range(shift_index, num_samp_input), &extended_input); + out_reverb_matrix.CopyRowFromVec(extended_input, output_channel); + } + if 
(!additive_noise_wxfilename.empty()) { + // repeat the signal to fill up the duration + Vector extended_input(num_samp_output); + extended_input.SetZero(); + AddVectorsOfUnequalLength(out_noise.Range(shift_index, num_samp_input), &extended_input); + out_noise_matrix.CopyRowFromVec(extended_input, output_channel); + } } } + + { + WaveData out_wave(samp_freq_input, out_matrix); + Output ko(output_wave_file, false); + out_wave.Write(ko.Stream()); + } + + if (!reverb_wxfilename.empty()) { + WaveData out_wave(samp_freq_input, out_reverb_matrix); + Output ko(reverb_wxfilename, false); + out_wave.Write(ko.Stream()); + } - WaveData out_wave(samp_freq_input, out_matrix); - Output ko(output_wave_file, false); - out_wave.Write(ko.Stream()); + if (!additive_noise_wxfilename.empty()) { + WaveData out_wave(samp_freq_input, out_noise_matrix); + Output ko(additive_noise_wxfilename, false); + out_wave.Write(ko.Stream()); + } return 0; } catch(const std::exception &e) { diff --git a/src/matrix/compressed-matrix.cc b/src/matrix/compressed-matrix.cc index 2ac2c544bc8..6fc365c8f03 100644 --- a/src/matrix/compressed-matrix.cc +++ b/src/matrix/compressed-matrix.cc @@ -24,14 +24,14 @@ namespace kaldi { -//static +//static MatrixIndexT CompressedMatrix::DataSize(const GlobalHeader &header) { // Returns size in bytes of the data. if (header.format == 1) { return sizeof(GlobalHeader) + header.num_cols * (sizeof(PerColHeader) + header.num_rows); } else { - KALDI_ASSERT(header.format == 2) ; + KALDI_ASSERT(header.format == 2); return sizeof(GlobalHeader) + 2 * header.num_rows * header.num_cols; } @@ -40,7 +40,7 @@ MatrixIndexT CompressedMatrix::DataSize(const GlobalHeader &header) { template void CompressedMatrix::CopyFromMat( - const MatrixBase &mat) { + const MatrixBase &mat, int32 format) { if (data_ != NULL) { delete [] static_cast(data_); // call delete [] because was allocated with new float[] data_ = NULL; @@ -52,7 +52,7 @@ void CompressedMatrix::CopyFromMat( KALDI_COMPILE_TIME_ASSERT(sizeof(global_header) == 20); // otherwise // something weird is happening and our code probably won't work or // won't be robust across platforms. - + // Below, the point of the "safety_margin" is that the minimum // and maximum values in the matrix shouldn't coincide with // the minimum and maximum ranges of the 16-bit range, because @@ -80,16 +80,22 @@ void CompressedMatrix::CopyFromMat( global_header.num_rows = mat.NumRows(); global_header.num_cols = mat.NumCols(); - if (mat.NumRows() > 8) { - global_header.format = 1; // format where each row has a PerColHeader. + if (format <= 0) { + if (mat.NumRows() > 8) { + global_header.format = 1; // format where each row has a PerColHeader. + } else { + global_header.format = 2; // format where all data is uint16. + } + } else if (format == 1 || format == 2) { + global_header.format = format; } else { - global_header.format = 2; // format where all data is uint16. + KALDI_ERR << "Error format for compression:format should be <=2."; } - + int32 data_size = DataSize(global_header); data_ = AllocateData(data_size); - + *(reinterpret_cast(data_)) = global_header; if (global_header.format == 1) { @@ -124,10 +130,12 @@ void CompressedMatrix::CopyFromMat( // Instantiate the template for float and double. 
template -void CompressedMatrix::CopyFromMat(const MatrixBase &mat); +void CompressedMatrix::CopyFromMat(const MatrixBase &mat, + int32 format); template -void CompressedMatrix::CopyFromMat(const MatrixBase &mat); +void CompressedMatrix::CopyFromMat(const MatrixBase &mat, + int32 format); CompressedMatrix::CompressedMatrix( @@ -146,10 +154,10 @@ CompressedMatrix::CompressedMatrix( if (old_num_rows == 0) { return; } // Zero-size matrix stored as zero pointer. if (num_rows == 0 || num_cols == 0) { return; } - + GlobalHeader new_global_header; KALDI_COMPILE_TIME_ASSERT(sizeof(new_global_header) == 20); - + GlobalHeader *old_global_header = reinterpret_cast(cmat.Data()); new_global_header = *old_global_header; @@ -159,10 +167,10 @@ CompressedMatrix::CompressedMatrix( // We don't switch format from 1 -> 2 (in case of size reduction) yet; if this // is needed, we will do this below by creating a temporary Matrix. new_global_header.format = old_global_header->format; - + data_ = AllocateData(DataSize(new_global_header)); // allocate memory *(reinterpret_cast(data_)) = new_global_header; - + if (old_global_header->format == 1) { // Both have the format where we have a PerColHeader and then compress as // chars... @@ -196,7 +204,7 @@ CompressedMatrix::CompressedMatrix( reinterpret_cast(old_global_header + 1); uint16 *new_data = reinterpret_cast(reinterpret_cast(data_) + 1); - + old_data += col_offset + (old_num_cols * row_offset); for (int32 row = 0; row < num_rows; row++) { @@ -281,7 +289,7 @@ void CompressedMatrix::ComputeColHeader( // Now, sdata.begin(), sdata.begin() + quarter_nr, and sdata.begin() + // 3*quarter_nr, and sdata.end() - 1, contain the elements that would appear // at those positions in sorted order. - + header->percentile_0 = std::min(FloatToUint16(global_header, sdata[0]), 65532); header->percentile_25 = @@ -297,7 +305,7 @@ void CompressedMatrix::ComputeColHeader( header->percentile_100 = std::max( FloatToUint16(global_header, sdata[num_rows-1]), header->percentile_75 + static_cast(1)); - + } else { // handle this pathological case. std::sort(sdata.begin(), sdata.end()); // Note: we know num_rows is at least 1. @@ -382,7 +390,7 @@ void CompressedMatrix::CompressColumn( unsigned char *byte_data) { ComputeColHeader(global_header, data, stride, num_rows, header); - + float p0 = Uint16ToFloat(global_header, header->percentile_0), p25 = Uint16ToFloat(global_header, header->percentile_25), p75 = Uint16ToFloat(global_header, header->percentile_75), @@ -491,7 +499,7 @@ void CompressedMatrix::CopyToMat(MatrixBase *mat, mat->CopyFromMat(temp, kTrans); return; } - + if (data_ == NULL) { KALDI_ASSERT(mat->NumRows() == 0); KALDI_ASSERT(mat->NumCols() == 0); @@ -501,7 +509,7 @@ void CompressedMatrix::CopyToMat(MatrixBase *mat, int32 num_cols = h->num_cols, num_rows = h->num_rows; KALDI_ASSERT(mat->NumRows() == num_rows); KALDI_ASSERT(mat->NumCols() == num_cols); - + if (h->format == 1) { PerColHeader *per_col_header = reinterpret_cast(h+1); unsigned char *byte_data = reinterpret_cast(per_col_header + @@ -625,7 +633,7 @@ void CompressedMatrix::CopyToMat(int32 row_offset, GlobalHeader *h = reinterpret_cast(data_); int32 num_rows = h->num_rows, num_cols = h->num_cols, tgt_cols = dest->NumCols(), tgt_rows = dest->NumRows(); - + if (h->format == 1) { // format where we have a per-column header and use one byte per // element. 
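With the extra format argument, CopyFromMat only auto-selects the compression layout when format <= 0, and otherwise forces format 1 (per-column headers, one byte per element) or format 2 (all values stored as uint16). A sketch of that selection rule, mirroring the hunk above; the helper name is illustrative:

    // Sketch: how the compression layout is chosen when
    // CompressedMatrix::CopyFromMat(mat, format) is called.
    #include "base/kaldi-common.h"

    kaldi::int32 ChooseCompressionFormat(kaldi::int32 format,
                                         kaldi::int32 num_rows) {
      if (format <= 0)                  // 0 (the default) means "choose automatically"
        return (num_rows > 8) ? 1 : 2;  // 1: per-column headers; 2: all uint16
      if (format == 1 || format == 2)
        return format;                  // caller forces a specific layout
      KALDI_ERR << "Compression format should be 0, 1 or 2; got " << format;
      return -1;                        // not reached
    }

GeneralMatrix::Compress(format), further down in this patch, forwards the same argument to CopyFromMat.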
diff --git a/src/matrix/compressed-matrix.h b/src/matrix/compressed-matrix.h index 603134ab800..a9dd1e4fcd2 100644 --- a/src/matrix/compressed-matrix.h +++ b/src/matrix/compressed-matrix.h @@ -35,12 +35,12 @@ namespace kaldi { /// column). /// The basic idea is for each column (in the normal configuration) -/// we work out the values at the 0th, 25th, 50th and 100th percentiles +/// we work out the values at the 0th, 25th, 75th and 100th percentiles /// and store them as 16-bit integers; we then encode each value in /// the column as a single byte, in 3 separate ranges with different -/// linear encodings (0-25th, 25-50th, 50th-100th). -/// If the matrix has 8 rows or fewer, we simply store all values as -/// uint16. +/// linear encodings (0-25th, 25-75th, 75th-100th). +/// If the matrix has 8 rows or fewer or format=2, we simply store all values +/// as uint16. class CompressedMatrix { public: @@ -49,7 +49,9 @@ class CompressedMatrix { ~CompressedMatrix() { Clear(); } template - CompressedMatrix(const MatrixBase &mat): data_(NULL) { CopyFromMat(mat); } + CompressedMatrix(const MatrixBase &mat, int32 format = 0): data_(NULL) { + CopyFromMat(mat, format); + } /// Initializer that can be used to select part of an existing /// CompressedMatrix without un-compressing and re-compressing (note: unlike @@ -65,7 +67,7 @@ class CompressedMatrix { /// This will resize *this and copy the contents of mat to *this. template - void CopyFromMat(const MatrixBase &mat); + void CopyFromMat(const MatrixBase &mat, int32 format = 0); CompressedMatrix(const CompressedMatrix &mat); diff --git a/src/matrix/kaldi-matrix.cc b/src/matrix/kaldi-matrix.cc index 34003e8a550..0b5191e1e7a 100644 --- a/src/matrix/kaldi-matrix.cc +++ b/src/matrix/kaldi-matrix.cc @@ -396,6 +396,87 @@ void MatrixBase::AddMat(const Real alpha, const MatrixBase& A, } } +template +void MatrixBase::LogAddExpMat(const Real alpha, const MatrixBase& A, + MatrixTransposeType transA) { + if (alpha == 0) return; + + if (&A == this) { + if (transA == kNoTrans) { + Add(alpha + 1.0); + } else { + KALDI_ASSERT(num_rows_ == num_cols_ && "AddMat: adding to self (transposed): not symmetric."); + Real *data = data_; + if (alpha == 1.0) { // common case-- handle separately. + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < row; col++) { + Real *lower = data + (row * stride_) + col, + *upper = data + (col * stride_) + row; + Real sum = LogAdd(*lower, *upper); + *lower = *upper = sum; + } + *(data + (row * stride_) + row) += Log(2.0); // diagonal. + } + } else { + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < row; col++) { + Real *lower = data + (row * stride_) + col, + *upper = data + (col * stride_) + row; + Real lower_tmp = *lower; + if (alpha > 0) { + *lower = LogAdd(*lower, Log(alpha) + *upper); + *upper = LogAdd(*upper, Log(alpha) + lower_tmp); + } else { + KALDI_ASSERT(alpha < 0); + *lower = LogSub(*lower, Log(-alpha) + *upper); + *upper = LogSub(*upper, Log(-alpha) + lower_tmp); + } + } + if (alpha > -1.0) + *(data + (row * stride_) + row) += Log(1.0 + alpha); // diagonal. 
+ else + KALDI_ERR << "Cannot subtract log-matrices if the difference is " + << "negative"; + } + } + } + } else { + int aStride = (int) A.stride_; + Real *adata = A.data_, *data = data_; + if (transA == kNoTrans) { + KALDI_ASSERT(A.num_rows_ == num_rows_ && A.num_cols_ == num_cols_); + if (num_rows_ == 0) return; + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < num_cols_; col++) { + Real *value = data + (row * stride_) + col, + *aValue = adata + (row * aStride) + col; + if (alpha > 0) + *value = LogAdd(*value, Log(alpha) + *aValue); + else { + KALDI_ASSERT(alpha < 0); + *value = LogSub(*value, Log(-alpha) + *aValue); + } + } + } + } else { + KALDI_ASSERT(A.num_cols_ == num_rows_ && A.num_rows_ == num_cols_); + if (num_rows_ == 0) return; + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < num_cols_; col++) { + Real *value = data + (row * stride_) + col, + *aValue = adata + (col * aStride) + row; + if (alpha > 0) + *value = LogAdd(*value, Log(alpha) + *aValue); + else { + KALDI_ASSERT(alpha < 0); + *value = LogSub(*value, Log(-alpha) + *aValue); + } + } + } + } + } +} + template template void MatrixBase::AddSp(const Real alpha, const SpMatrix &S) { @@ -2533,6 +2614,15 @@ Real MatrixBase::ApplySoftMax() { return max + Log(sum); } +template +void MatrixBase::ApplySoftMaxPerRow() { + for (MatrixIndexT i = 0; i < num_rows_; i++) { + Row(i).ApplySoftMax(); + kaldi::ApproxEqual(Row(i).Sum(), 1.0); + } + KALDI_ASSERT(Max() <= 1.0 && Min() >= 0.0); +} + template void MatrixBase::Tanh(const MatrixBase &src) { KALDI_ASSERT(SameDim(*this, src)); diff --git a/src/matrix/kaldi-matrix.h b/src/matrix/kaldi-matrix.h index e254fcad118..b5a6bc7521d 100644 --- a/src/matrix/kaldi-matrix.h +++ b/src/matrix/kaldi-matrix.h @@ -453,6 +453,11 @@ class MatrixBase { /// Apply soft-max to the collection of all elements of the /// matrix and return normalizer (log sum of exponentials). Real ApplySoftMax(); + + /// Softmax nonlinearity + /// Y = Softmax(X) : Yij = e^Xij / sum_k(e^Xik), done to each row + /// for each row, the max value is first subtracted for good numerical stability + void ApplySoftMaxPerRow(); /// Set each element to the sigmoid of the corresponding element of "src". void Sigmoid(const MatrixBase &src); @@ -543,6 +548,10 @@ class MatrixBase { /// *this += alpha * M [or M^T] void AddMat(const Real alpha, const MatrixBase &M, MatrixTransposeType transA = kNoTrans); + + /// *this += alpha * M [or M^T] when the matrices are stored as log + void LogAddExpMat(const Real alpha, const MatrixBase &M, + MatrixTransposeType transA = kNoTrans); /// *this = beta * *this + alpha * M M^T, for symmetric matrices. It only /// updates the lower triangle of *this. 
It will leave the matrix asymmetric; diff --git a/src/matrix/sparse-matrix.cc b/src/matrix/sparse-matrix.cc index 477d36f190a..777819ed677 100644 --- a/src/matrix/sparse-matrix.cc +++ b/src/matrix/sparse-matrix.cc @@ -705,15 +705,16 @@ MatrixIndexT GeneralMatrix::NumCols() const { } -void GeneralMatrix::Compress() { +void GeneralMatrix::Compress(int32 format) { if (mat_.NumRows() != 0) { - cmat_.CopyFromMat(mat_); + cmat_.CopyFromMat(mat_, format); mat_.Resize(0, 0); } } void GeneralMatrix::Uncompress() { if (cmat_.NumRows() != 0) { + mat_.Resize(cmat_.NumRows(), cmat_.NumCols(), kUndefined); cmat_.CopyToMat(&mat_); cmat_.Clear(); } diff --git a/src/matrix/sparse-matrix.h b/src/matrix/sparse-matrix.h index 9f9362542e1..88619da3034 100644 --- a/src/matrix/sparse-matrix.h +++ b/src/matrix/sparse-matrix.h @@ -228,8 +228,10 @@ class GeneralMatrix { public: GeneralMatrixType Type() const; - void Compress(); // If it was a full matrix, compresses, changing Type() to - // kCompressedMatrix; otherwise does nothing. + /// If it was a full matrix, compresses, changing Type() to + /// kCompressedMatrix; otherwise does nothing. + /// format shows the compression format. + void Compress(int32 format = 0); void Uncompress(); // If it was a compressed matrix, uncompresses, changing // Type() to kFullMatrix; otherwise does nothing. diff --git a/src/nnet3/nnet-combine.cc b/src/nnet3/nnet-combine.cc index 45c1f74477b..d40c63bd3e7 100644 --- a/src/nnet3/nnet-combine.cc +++ b/src/nnet3/nnet-combine.cc @@ -424,15 +424,28 @@ double NnetCombiner::ComputeObjfAndDerivFromNnet( end = egs_.end(); for (; iter != end; ++iter) prob_computer_->Compute(*iter); - const SimpleObjectiveInfo *objf_info = prob_computer_->GetObjective("output"); - if (objf_info == NULL) - KALDI_ERR << "Error getting objective info (unsuitable egs?)"; - KALDI_ASSERT(objf_info->tot_weight > 0.0); + + double tot_weight = 0.0; + double tot_objf = 0.0; + + { + const unordered_map &objf_info = prob_computer_->GetAllObjectiveInfo(); + unordered_map::const_iterator objf_it = objf_info.begin(), + objf_end = objf_info.end(); + + for (; objf_it != objf_end; ++objf_it) { + tot_objf += objf_it->second.tot_objective; + tot_weight += objf_it->second.tot_weight; + } + } + + KALDI_ASSERT(tot_weight > 0.0); + const Nnet &deriv = prob_computer_->GetDeriv(); VectorizeNnet(deriv, nnet_params_deriv); // we prefer to deal with normalized objective functions. 
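The combiner now sums tot_objective and tot_weight over every output returned by GetAllObjectiveInfo() before normalizing (the division appears just below). A minimal sketch of that reduction, templated on the map type so it stays independent of the exact hasher used for the objective-info map:

    // Sketch: overall normalized objective across all outputs, as the
    // rewritten NnetCombiner::ComputeObjfAndDerivFromNnet computes it.
    #include "base/kaldi-common.h"

    template <class ObjfInfoMap>   // e.g. the map returned by GetAllObjectiveInfo()
    double TotalNormalizedObjf(const ObjfInfoMap &objf_info) {
      double tot_objf = 0.0, tot_weight = 0.0;
      for (typename ObjfInfoMap::const_iterator it = objf_info.begin();
           it != objf_info.end(); ++it) {
        tot_objf += it->second.tot_objective;   // SimpleObjectiveInfo fields
        tot_weight += it->second.tot_weight;
      }
      KALDI_ASSERT(tot_weight > 0.0);
      return tot_objf / tot_weight;             // objective per unit of weight
    }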
- nnet_params_deriv->Scale(1.0 / objf_info->tot_weight); - return objf_info->tot_objective / objf_info->tot_weight; + nnet_params_deriv->Scale(1.0 / tot_weight); + return tot_objf / tot_weight; } diff --git a/src/nnet3/nnet-component-itf.cc b/src/nnet3/nnet-component-itf.cc index 00dd802e091..389b9876b3c 100644 --- a/src/nnet3/nnet-component-itf.cc +++ b/src/nnet3/nnet-component-itf.cc @@ -89,6 +89,10 @@ Component* Component::NewComponentOfType(const std::string &component_type) { ans = new SoftmaxComponent(); } else if (component_type == "LogSoftmaxComponent") { ans = new LogSoftmaxComponent(); + } else if (component_type == "LogComponent") { + ans = new LogComponent(); + } else if (component_type == "ExpComponent") { + ans = new ExpComponent(); } else if (component_type == "RectifiedLinearComponent") { ans = new RectifiedLinearComponent(); } else if (component_type == "NormalizeComponent") { @@ -119,6 +123,8 @@ Component* Component::NewComponentOfType(const std::string &component_type) { ans = new NoOpComponent(); } else if (component_type == "ClipGradientComponent") { ans = new ClipGradientComponent(); + } else if (component_type == "ScaleGradientComponent") { + ans = new ScaleGradientComponent(); } else if (component_type == "ElementwiseProductComponent") { ans = new ElementwiseProductComponent(); } else if (component_type == "ConvolutionComponent") { @@ -310,11 +316,14 @@ std::string NonlinearComponent::Info() const { std::stringstream stream; if (InputDim() == OutputDim()) { stream << Type() << ", dim=" << InputDim(); - } else { + } else if (OutputDim() - InputDim() == 1) { // Note: this is a very special case tailored for class NormalizeComponent. stream << Type() << ", input-dim=" << InputDim() << ", output-dim=" << OutputDim() << ", add-log-stddev=true"; + } else { + stream << Type() << ", input-dim=" << InputDim() + << ", output-dim=" << OutputDim(); } if (self_repair_lower_threshold_ != BaseFloat(kUnsetThreshold)) @@ -323,7 +332,7 @@ std::string NonlinearComponent::Info() const { stream << ", self-repair-upper-threshold=" << self_repair_upper_threshold_; if (self_repair_scale_ != 0.0) stream << ", self-repair-scale=" << self_repair_scale_; - if (count_ > 0 && value_sum_.Dim() == dim_ && deriv_sum_.Dim() == dim_) { + if (count_ > 0 && value_sum_.Dim() == dim_) { stream << ", count=" << std::setprecision(3) << count_ << std::setprecision(6); stream << ", self-repaired-proportion=" @@ -333,10 +342,12 @@ std::string NonlinearComponent::Info() const { Vector value_avg(value_avg_dbl); value_avg.Scale(1.0 / count_); stream << ", value-avg=" << SummarizeVector(value_avg); - Vector deriv_avg_dbl(deriv_sum_); - Vector deriv_avg(deriv_avg_dbl); - deriv_avg.Scale(1.0 / count_); - stream << ", deriv-avg=" << SummarizeVector(deriv_avg); + if (deriv_sum_.Dim() == dim_) { + Vector deriv_avg_dbl(deriv_sum_); + Vector deriv_avg(deriv_avg_dbl); + deriv_avg.Scale(1.0 / count_); + stream << ", deriv-avg=" << SummarizeVector(deriv_avg); + } } return stream.str(); } diff --git a/src/nnet3/nnet-component-itf.h b/src/nnet3/nnet-component-itf.h index e5974b46f46..3013c485ea4 100644 --- a/src/nnet3/nnet-component-itf.h +++ b/src/nnet3/nnet-component-itf.h @@ -403,6 +403,11 @@ class UpdatableComponent: public Component { /// Sets the learning rate directly, bypassing learning_rate_factor_. 
virtual void SetActualLearningRate(BaseFloat lrate) { learning_rate_ = lrate; } + /// Sets the learning rate factor + virtual void SetLearningRateFactor(BaseFloat lrate_factor) { + learning_rate_factor_ = lrate_factor; + } + /// Gets the learning rate of gradient descent. Note: if you call /// SetLearningRate(x), and learning_rate_factor_ != 1.0, /// a different value than x will returned. @@ -413,6 +418,9 @@ class UpdatableComponent: public Component { /// NnetTrainer by querying the max-changes for each component. /// See NnetTrainer::UpdateParamsWithMaxChange() in nnet3/nnet-training.cc. BaseFloat MaxChange() const { return max_change_; } + + /// Gets the learning rate factor + BaseFloat LearningRateFactor() const { return learning_rate_factor_; } virtual std::string Info() const; diff --git a/src/nnet3/nnet-component-test.cc b/src/nnet3/nnet-component-test.cc index 3cc6af1c70d..a2e5e23436c 100644 --- a/src/nnet3/nnet-component-test.cc +++ b/src/nnet3/nnet-component-test.cc @@ -379,6 +379,11 @@ bool TestSimpleComponentDataDerivative(const Component &c, KALDI_LOG << "Accepting deriv differences since " << "it is ClipGradientComponent."; return true; + } + else if (c.Type() == "ScaleGradientComponent") { + KALDI_LOG << "Accepting deriv differences since " + << "it is ScaleGradientComponent."; + return true; } return ans; } diff --git a/src/nnet3/nnet-diagnostics.cc b/src/nnet3/nnet-diagnostics.cc index 7f7d485ffe0..64abe8a0578 100644 --- a/src/nnet3/nnet-diagnostics.cc +++ b/src/nnet3/nnet-diagnostics.cc @@ -92,20 +92,24 @@ void NnetComputeProb::ProcessOutputs(const NnetExample &eg, << "mismatch for '" << io.name << "': " << output.NumCols() << " (nnet) vs. " << io.features.NumCols() << " (egs)\n"; } + + const Vector *deriv_weights = NULL; + if (config_.apply_deriv_weights && io.deriv_weights.Dim() > 0) + deriv_weights = &(io.deriv_weights); { BaseFloat tot_weight, tot_objf; bool supply_deriv = config_.compute_deriv; ComputeObjectiveFunction(io.features, obj_type, io.name, supply_deriv, computer, - &tot_weight, &tot_objf); + &tot_weight, &tot_objf, deriv_weights); SimpleObjectiveInfo &totals = objf_info_[io.name]; totals.tot_weight += tot_weight; totals.tot_objective += tot_objf; } - if (obj_type == kLinear && config_.compute_accuracy) { + if (config_.compute_accuracy) { BaseFloat tot_weight, tot_accuracy; ComputeAccuracy(io.features, output, - &tot_weight, &tot_accuracy); + &tot_weight, &tot_accuracy, deriv_weights); SimpleObjectiveInfo &totals = accuracy_info_[io.name]; totals.tot_weight += tot_weight; totals.tot_objective += tot_accuracy; @@ -156,7 +160,8 @@ bool NnetComputeProb::PrintTotalStats() const { void ComputeAccuracy(const GeneralMatrix &supervision, const CuMatrixBase &nnet_output, BaseFloat *tot_weight_out, - BaseFloat *tot_accuracy_out) { + BaseFloat *tot_accuracy_out, + const Vector *deriv_weights) { int32 num_rows = nnet_output.NumRows(), num_cols = nnet_output.NumCols(); KALDI_ASSERT(supervision.NumRows() == num_rows && @@ -181,24 +186,27 @@ void ComputeAccuracy(const GeneralMatrix &supervision, for (int32 r = 0; r < num_rows; r++) { SubVector vec(mat, r); BaseFloat row_sum = vec.Sum(); - KALDI_ASSERT(row_sum >= 0.0); + // KALDI_ASSERT(row_sum >= 0.0); // For conventional ASR systems int32 best_index; vec.Max(&best_index); // discard max value. 
+ if (deriv_weights) + row_sum *= (*deriv_weights)(r); tot_weight += row_sum; if (best_index == best_index_cpu[r]) tot_accuracy += row_sum; } break; - } case kFullMatrix: { const Matrix &mat = supervision.GetFullMatrix(); for (int32 r = 0; r < num_rows; r++) { SubVector vec(mat, r); BaseFloat row_sum = vec.Sum(); - KALDI_ASSERT(row_sum >= 0.0); + // KALDI_ASSERT(row_sum >= 0.0); // For conventional ASR systems int32 best_index; vec.Max(&best_index); // discard max value. + if (deriv_weights) + row_sum *= (*deriv_weights)(r); tot_weight += row_sum; if (best_index == best_index_cpu[r]) tot_accuracy += row_sum; @@ -212,6 +220,8 @@ void ComputeAccuracy(const GeneralMatrix &supervision, BaseFloat row_sum = row.Sum(); int32 best_index; row.Max(&best_index); + if (deriv_weights) + row_sum *= (*deriv_weights)(r); KALDI_ASSERT(best_index < num_cols); tot_weight += row_sum; if (best_index == best_index_cpu[r]) diff --git a/src/nnet3/nnet-diagnostics.h b/src/nnet3/nnet-diagnostics.h index 298548857dd..59f0cd16f47 100644 --- a/src/nnet3/nnet-diagnostics.h +++ b/src/nnet3/nnet-diagnostics.h @@ -36,7 +36,6 @@ struct SimpleObjectiveInfo { double tot_objective; SimpleObjectiveInfo(): tot_weight(0.0), tot_objective(0.0) { } - }; @@ -44,12 +43,15 @@ struct NnetComputeProbOptions { bool debug_computation; bool compute_deriv; bool compute_accuracy; + bool apply_deriv_weights; + NnetOptimizeOptions optimize_config; NnetComputeOptions compute_config; NnetComputeProbOptions(): debug_computation(false), compute_deriv(false), - compute_accuracy(true) { } + compute_accuracy(true), + apply_deriv_weights(true) { } void Register(OptionsItf *opts) { // compute_deriv is not included in the command line options // because it's not relevant for nnet3-compute-prob. @@ -57,6 +59,9 @@ struct NnetComputeProbOptions { "debug for the actual computation (very verbose!)"); opts->Register("compute-accuracy", &compute_accuracy, "If true, compute " "accuracy values as well as objective functions"); + opts->Register("apply-deriv-weights", &apply_deriv_weights, + "Apply per-frame deriv weights"); + // register the optimization options with the prefix "optimization". ParseOptions optimization_opts("optimization", opts); optimize_config.Register(&optimization_opts); @@ -97,11 +102,17 @@ class NnetComputeProb { // or NULL if there is no such info. const SimpleObjectiveInfo *GetObjective(const std::string &output_name) const; + // return objective info for all outputs + const unordered_map & GetAllObjectiveInfo() const { + return objf_info_; + } + // if config.compute_deriv == true, returns a reference to the // computed derivative. Otherwise crashes. 
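Since accuracy is now computed for every objective type and can be weighted per frame, here is a minimal sketch of the weighted accuracy that the ComputeAccuracy changes above implement. WeightedAccuracy is a hypothetical stand-alone helper (plain std::vector in place of Kaldi's matrix types), not part of the patch:

```cpp
// Frame-weighted accuracy: a frame's weight is the row-sum of its supervision,
// optionally multiplied by a per-frame deriv weight; the frame counts as
// correct when the argmax of the network output matches the argmax of the
// supervision row.
#include <cstddef>
#include <vector>

double WeightedAccuracy(const std::vector<std::vector<double> > &supervision,
                        const std::vector<std::vector<double> > &nnet_output,
                        const std::vector<double> *deriv_weights) {
  double tot_weight = 0.0, tot_correct = 0.0;
  for (std::size_t r = 0; r < supervision.size(); r++) {
    double row_sum = 0.0;
    std::size_t best_ref = 0, best_hyp = 0;
    for (std::size_t c = 0; c < supervision[r].size(); c++) {
      row_sum += supervision[r][c];
      if (supervision[r][c] > supervision[r][best_ref]) best_ref = c;
      if (nnet_output[r][c] > nnet_output[r][best_hyp]) best_hyp = c;
    }
    if (deriv_weights != NULL)
      row_sum *= (*deriv_weights)[r];   // zero-weight frames contribute nothing
    tot_weight += row_sum;
    if (best_ref == best_hyp)
      tot_correct += row_sum;
  }
  return tot_weight > 0.0 ? tot_correct / tot_weight : 0.0;
}
```

The new --apply-deriv-weights option (default true) added in this patch to both NnetComputeProbOptions and NnetTrainerOptions controls whether these per-frame weights are used at all.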
const Nnet &GetDeriv() const; ~NnetComputeProb(); + private: void ProcessOutputs(const NnetExample &eg, NnetComputer *computer); @@ -152,7 +163,8 @@ class NnetComputeProb { void ComputeAccuracy(const GeneralMatrix &supervision, const CuMatrixBase &nnet_output, BaseFloat *tot_weight, - BaseFloat *tot_accuracy); + BaseFloat *tot_accuracy, + const Vector *deriv_weights = NULL); } // namespace nnet3 diff --git a/src/nnet3/nnet-example-utils.cc b/src/nnet3/nnet-example-utils.cc index 30f7840f6f8..548fb842385 100644 --- a/src/nnet3/nnet-example-utils.cc +++ b/src/nnet3/nnet-example-utils.cc @@ -63,9 +63,9 @@ static void GetIoSizes(const std::vector &src, KALDI_ASSERT(*names_iter == io.name); int32 i = names_iter - names_begin; int32 this_dim = io.features.NumCols(); - if (dims[i] == -1) + if (dims[i] == -1) { dims[i] = this_dim; - else if(dims[i] != this_dim) { + } else if (dims[i] != this_dim) { KALDI_ERR << "Merging examples with inconsistent feature dims: " << dims[i] << " vs. " << this_dim << " for '" << io.name << "'."; @@ -87,9 +87,20 @@ static void MergeIo(const std::vector &src, const std::vector &sizes, bool compress, NnetExample *merged_eg) { + // The total number of Indexes we have across all examples. int32 num_feats = names.size(); + std::vector cur_size(num_feats, 0); + + // The features in the different NnetIo in the Indexes across all examples std::vector > output_lists(num_feats); + + // The deriv weights in the different NnetIo in the Indexes across all + // examples + std::vector const*> > + output_deriv_weights(num_feats); + + // Initialize the merged_eg merged_eg->io.clear(); merged_eg->io.resize(num_feats); for (int32 f = 0; f < num_feats; f++) { @@ -102,20 +113,27 @@ static void MergeIo(const std::vector &src, std::vector::const_iterator names_begin = names.begin(), names_end = names.end(); - std::vector::const_iterator iter = src.begin(), end = src.end(); - for (int32 n = 0; iter != end; ++iter,++n) { - std::vector::const_iterator iter2 = iter->io.begin(), - end2 = iter->io.end(); - for (; iter2 != end2; ++iter2) { - const NnetIo &io = *iter2; + std::vector::const_iterator eg_iter = src.begin(), + eg_end = src.end(); + for (int32 n = 0; eg_iter != eg_end; ++eg_iter, ++n) { + std::vector::const_iterator io_iter = eg_iter->io.begin(), + io_end = eg_iter->io.end(); + for (; io_iter != io_end; ++io_iter) { + const NnetIo &io = *io_iter; std::vector::const_iterator names_iter = std::lower_bound(names_begin, names_end, io.name); KALDI_ASSERT(*names_iter == io.name); + int32 f = names_iter - names_begin; - int32 this_size = io.indexes.size(), - &this_offset = cur_size[f]; + int32 this_size = io.indexes.size(); + int32 &this_offset = cur_size[f]; KALDI_ASSERT(this_size + this_offset <= sizes[f]); + + // Add f^th Io's features and deriv_weights output_lists[f].push_back(&(io.features)); + output_deriv_weights[f].push_back(&(io.deriv_weights)); + + // Work on the Indexes for the f^th Io in merged_eg NnetIo &output_io = merged_eg->io[f]; std::copy(io.indexes.begin(), io.indexes.end(), output_io.indexes.begin() + this_offset); @@ -139,10 +157,26 @@ static void MergeIo(const std::vector &src, // the following won't do anything if the features were sparse. 
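The MergeIo changes in this hunk carry each example's deriv_weights along with its features and append them in the same row order when examples are merged. A minimal sketch of that concatenation, with a hypothetical helper MergeDerivWeights and std::vector<float> standing in for Kaldi's Vector<BaseFloat>:

```cpp
#include <cstddef>
#include <vector>

std::vector<float> MergeDerivWeights(
    const std::vector<std::vector<float> > &per_eg_weights) {
  std::vector<float> merged;
  // As in MergeIo: if the first example carries no deriv weights, the merged
  // example carries none either (an empty vector means "all ones").
  if (per_eg_weights.empty() || per_eg_weights[0].empty())
    return merged;
  for (std::size_t i = 0; i < per_eg_weights.size(); i++)
    merged.insert(merged.end(), per_eg_weights[i].begin(),
                  per_eg_weights[i].end());
  return merged;
}
```

The real code additionally asserts that each example's weight vector has as many entries as its feature matrix has rows, so the merged vector lines up with the merged features.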
merged_eg->io[f].features.Compress(); } - } -} + Vector &this_deriv_weights = merged_eg->io[f].deriv_weights; + if (output_deriv_weights[f][0]->Dim() > 0) { + this_deriv_weights.Resize( + merged_eg->io[f].indexes.size(), kUndefined); + KALDI_ASSERT(this_deriv_weights.Dim() == + merged_eg->io[f].features.NumRows()); + std::vector const*>::const_iterator + it = output_deriv_weights[f].begin(), + end = output_deriv_weights[f].end(); + + for (int32 i = 0, cur_offset = 0; it != end; ++it, i++) { + KALDI_ASSERT((*it)->Dim() == output_lists[f][i]->NumRows()); + this_deriv_weights.Range(cur_offset, (*it)->Dim()).CopyFromVec(**it); + cur_offset += (*it)->Dim(); + } + } + } +} void MergeExamples(const std::vector &src, bool compress, @@ -282,9 +316,15 @@ void RoundUpNumFrames(int32 frame_subsampling_factor, KALDI_ERR << "--num-frames-overlap=" << (*num_frames_overlap) << " < " << "--num-frames=" << (*num_frames); } - } +int32 NumOutputs(const NnetExample &eg) { + int32 num_outputs = 0; + for (size_t i = 0; i < eg.io.size(); i++) + if (eg.io[i].name.find("output") != std::string::npos) + num_outputs++; + return num_outputs; +} -} // namespace nnet3 -} // namespace kaldi +} // namespace nnet3 +} // namespace kaldi diff --git a/src/nnet3/nnet-example-utils.h b/src/nnet3/nnet-example-utils.h index 3e309e18915..d223c5eb5d1 100644 --- a/src/nnet3/nnet-example-utils.h +++ b/src/nnet3/nnet-example-utils.h @@ -80,6 +80,8 @@ void RoundUpNumFrames(int32 frame_subsampling_factor, int32 *num_frames, int32 *num_frames_overlap); +// Returns the number of outputs in an eg +int32 NumOutputs(const NnetExample &eg); } // namespace nnet3 } // namespace kaldi diff --git a/src/nnet3/nnet-example.cc b/src/nnet3/nnet-example.cc index 9a34258e0ee..2ad90c0f11d 100644 --- a/src/nnet3/nnet-example.cc +++ b/src/nnet3/nnet-example.cc @@ -19,6 +19,7 @@ // limitations under the License. #include "nnet3/nnet-example.h" +#include "nnet3/nnet-example-utils.h" #include "lat/lattice-functions.h" #include "hmm/posterior.h" @@ -31,6 +32,8 @@ void NnetIo::Write(std::ostream &os, bool binary) const { WriteToken(os, binary, name); WriteIndexVector(os, binary, indexes); features.Write(os, binary); + WriteToken(os, binary, ""); // for DerivWeights. Want to save space. + WriteVectorAsChar(os, binary, deriv_weights); WriteToken(os, binary, ""); KALDI_ASSERT(static_cast(features.NumRows()) == indexes.size()); } @@ -40,7 +43,14 @@ void NnetIo::Read(std::istream &is, bool binary) { ReadToken(is, binary, &name); ReadIndexVector(is, binary, &indexes); features.Read(is, binary); - ExpectToken(is, binary, ""); + std::string token; + ReadToken(is, binary, &token); + // in the future this back-compatibility code can be reworked. + if (token != "") { + KALDI_ASSERT(token == ""); + ReadVectorAsChar(is, binary, &deriv_weights); + ExpectToken(is, binary, ""); + } } bool NnetIo::operator == (const NnetIo &other) const { @@ -52,42 +62,75 @@ bool NnetIo::operator == (const NnetIo &other) const { Matrix this_mat, other_mat; features.GetMatrix(&this_mat); other.features.GetMatrix(&other_mat); - return ApproxEqual(this_mat, other_mat); + return (ApproxEqual(this_mat, other_mat) && + deriv_weights.ApproxEqual(other.deriv_weights)); } NnetIo::NnetIo(const std::string &name, - int32 t_begin, const MatrixBase &feats): + int32 t_begin, const MatrixBase &feats, + int32 skip_frame): name(name), features(feats) { - int32 num_rows = feats.NumRows(); - KALDI_ASSERT(num_rows > 0); - indexes.resize(num_rows); // sets all n,t,x to zeros. 
- for (int32 i = 0; i < num_rows; i++) - indexes[i].t = t_begin + i; + int32 num_skipped_rows = feats.NumRows(); + KALDI_ASSERT(num_skipped_rows > 0); + indexes.resize(num_skipped_rows); // sets all n,t,x to zeros. + for (int32 i = 0; i < num_skipped_rows; i++) + indexes[i].t = t_begin + i * skip_frame; +} + +NnetIo::NnetIo(const std::string &name, + const VectorBase &deriv_weights, + int32 t_begin, const MatrixBase &feats, + int32 skip_frame): + name(name), features(feats), deriv_weights(deriv_weights) { + int32 num_skipped_rows = feats.NumRows(); + KALDI_ASSERT(num_skipped_rows > 0); + indexes.resize(num_skipped_rows); // sets all n,t,x to zeros. + for (int32 i = 0; i < num_skipped_rows; i++) + indexes[i].t = t_begin + i * skip_frame; } void NnetIo::Swap(NnetIo *other) { name.swap(other->name); indexes.swap(other->indexes); features.Swap(&(other->features)); + deriv_weights.Swap(&(other->deriv_weights)); } NnetIo::NnetIo(const std::string &name, int32 dim, int32 t_begin, - const Posterior &labels): + const Posterior &labels, + int32 skip_frame): name(name) { - int32 num_rows = labels.size(); - KALDI_ASSERT(num_rows > 0); + int32 num_skipped_rows = labels.size(); + KALDI_ASSERT(num_skipped_rows > 0); SparseMatrix sparse_feats(dim, labels); features = sparse_feats; - indexes.resize(num_rows); // sets all n,t,x to zeros. - for (int32 i = 0; i < num_rows; i++) - indexes[i].t = t_begin + i; + indexes.resize(num_skipped_rows); // sets all n,t,x to zeros. + for (int32 i = 0; i < num_skipped_rows; i++) + indexes[i].t = t_begin + i * skip_frame; } - +NnetIo::NnetIo(const std::string &name, + const VectorBase &deriv_weights, + int32 dim, + int32 t_begin, + const Posterior &labels, + int32 skip_frame): + name(name), deriv_weights(deriv_weights) { + int32 num_skipped_rows = labels.size(); + KALDI_ASSERT(num_skipped_rows > 0); + SparseMatrix sparse_feats(dim, labels); + features = sparse_feats; + indexes.resize(num_skipped_rows); // sets all n,t,x to zeros. + for (int32 i = 0; i < num_skipped_rows; i++) + indexes[i].t = t_begin + i * skip_frame; +} void NnetExample::Write(std::ostream &os, bool binary) const { +#ifdef KALDI_PARANOID + KALDI_ASSERT(NumOutputs(eg) > 0); +#endif // Note: weight, label, input_frames and spk_info are members. This is a // struct. WriteToken(os, binary, ""); @@ -114,12 +157,12 @@ void NnetExample::Read(std::istream &is, bool binary) { } -void NnetExample::Compress() { +void NnetExample::Compress(int32 format) { std::vector::iterator iter = io.begin(), end = io.end(); // calling features.Compress() will do nothing if they are sparse or already // compressed. for (; iter != end; ++iter) - iter->features.Compress(); + iter->features.Compress(format); } } // namespace nnet3 diff --git a/src/nnet3/nnet-example.h b/src/nnet3/nnet-example.h index 1df7cd1e78e..f097369443a 100644 --- a/src/nnet3/nnet-example.h +++ b/src/nnet3/nnet-example.h @@ -45,12 +45,32 @@ struct NnetIo { /// a Matrix, or SparseMatrix (a SparseMatrix would be the natural format for posteriors). GeneralMatrix features; + /// This is a vector of per-frame weights, required to be between 0 and 1, + /// that is applied to the derivative during training (but not during model + /// combination, where the derivatives need to agree with the computed objf + /// values for the optimization code to work). + /// If this vector is empty it means we're not applying per-frame weights, + /// so it's equivalent to a vector of all ones. This vector is written + /// to disk compactly as unsigned char. 
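Putting the new NnetIo pieces together: the constructors now take an optional skip_frame, so the stored Indexes get t = t_begin, t_begin + skip_frame, ..., and an optional per-frame deriv_weights vector, which is serialized compactly (as unsigned char, under its own token) just before the closing NnetIo token, with back-compatible reading of old examples. A small usage fragment, assuming the usual Kaldi headers and namespaces (matrix/kaldi-matrix.h, nnet3/nnet-example.h); the variable names are made up:

```cpp
Matrix<BaseFloat> feats(4, 40);     // 4 frames of 40-dim features
Vector<BaseFloat> weights(4);
weights.Set(1.0);
weights(2) = 0.0;                   // frame 2 will contribute no gradient

// skip_frame = 3: the stored Indexes get t = 0, 3, 6, 9 instead of 0, 1, 2, 3.
NnetIo input_io("input", 0 /* t_begin */, feats, 3 /* skip_frame */);
NnetIo output_io("output", weights, 0 /* t_begin */, feats, 3 /* skip_frame */);
```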
+ Vector deriv_weights; + /// This constructor creates NnetIo with name "name", indexes with n=0, x=0, /// and t values ranging from t_begin to t_begin + feats.NumRows() - 1, and /// the provided features. t_begin should be the frame that feats.Row(0) /// represents. NnetIo(const std::string &name, - int32 t_begin, const MatrixBase &feats); + int32 t_begin, + const MatrixBase &feats, + int32 skip_frame = 1); + + /// This is similar to the above constructor but also takes in a + /// a deriv weights argument. + NnetIo(const std::string &name, + const VectorBase &deriv_weights, + int32 t_begin, + const MatrixBase &feats, + int32 skip_frame = 1); + /// This constructor sets "name" to the provided string, sets "indexes" with /// n=0, x=0, and t from t_begin to t_begin + labels.size() - 1, and the labels @@ -58,12 +78,30 @@ struct NnetIo { NnetIo(const std::string &name, int32 dim, int32 t_begin, - const Posterior &labels); + const Posterior &labels, + int32 skip_frame = 1); + + /// This is similar to the above constructor but also takes in a + /// a deriv weights argument. + NnetIo(const std::string &name, + const VectorBase &deriv_weights, + int32 dim, + int32 t_begin, + const Posterior &labels, + int32 skip_frame = 1); void Swap(NnetIo *other); NnetIo() { } + // Compress the features in this NnetIo structure with specified format. + // the "format" will be 1 for the original format where each column has a + // PerColHeader, and 2 for the format, where everything is represented as + // 16-bit integers. + // If format <= 0, then format 1 will be used, unless the matrix has 8 or + // fewer rows (in which case format 2 will be used). + void Compress(int32 format = 0) { features.Compress(format); } + // Use default copy constructor and assignment operators. void Write(std::ostream &os, bool binary) const; @@ -80,7 +118,6 @@ struct NnetIo { /// more frames of input, used for standard cross-entropy training of neural /// nets (and possibly for other objective functions). struct NnetExample { - /// "io" contains the input and output. In principle there can be multiple /// types of both input and output, with different names. The order is /// irrelevant. @@ -95,8 +132,13 @@ struct NnetExample { void Swap(NnetExample *other) { io.swap(other->io); } - /// Compresses any (input) features that are not sparse. - void Compress(); + // Compresses any features that are not sparse and not compressed. + // The "format" is 1 for the original format where each column has a + // PerColHeader, and 2 for the format, where everything is represented as + // 16-bit integers. + // If format <= 0, then format 1 will be used, unless the matrix has 8 or + // fewer rows (in which case format 2 will be used). + void Compress(int32 format = 0); /// Caution: this operator == is not very efficient. It's only used in /// testing code. diff --git a/src/nnet3/nnet-nnet.cc b/src/nnet3/nnet-nnet.cc index ad5f715a294..4fcbbc70a1f 100644 --- a/src/nnet3/nnet-nnet.cc +++ b/src/nnet3/nnet-nnet.cc @@ -84,8 +84,14 @@ std::string Nnet::GetAsConfigLine(int32 node_index, bool include_dim) const { node.descriptor.WriteConfig(ans, node_names_); if (include_dim) ans << " dim=" << node.Dim(*this); - ans << " objective=" << (node.u.objective_type == kLinear ? 
"linear" : - "quadratic"); + + if (node.u.objective_type == kLinear) + ans << " objective=linear"; + else if (node.u.objective_type == kQuadratic) + ans << " objective=quadratic"; + else if (node.u.objective_type == kXentPerDim) + ans << " objective=xent-per-dim"; + break; case kComponent: ans << "component-node name=" << name << " component=" @@ -390,6 +396,8 @@ void Nnet::ProcessOutputNodeConfigLine( nodes_[node_index].u.objective_type = kLinear; } else if (objective_type == "quadratic") { nodes_[node_index].u.objective_type = kQuadratic; + } else if (objective_type == "xent-per-dim") { + nodes_[node_index].u.objective_type = kXentPerDim; } else { KALDI_ERR << "Invalid objective type: " << objective_type; } diff --git a/src/nnet3/nnet-nnet.h b/src/nnet3/nnet-nnet.h index 16e8333d5b1..b9ed3c1052b 100644 --- a/src/nnet3/nnet-nnet.h +++ b/src/nnet3/nnet-nnet.h @@ -49,7 +49,12 @@ namespace nnet3 { /// - Objective type kQuadratic is used to mean the objective function /// f(x, y) = -0.5 (x-y).(x-y), which is to be maximized, as in the kLinear /// case. -enum ObjectiveType { kLinear, kQuadratic }; +/// - Objective type kXentPerDim is the objective function that is used +/// to learn a set of bernoulli random variables. +/// f(x, y) = x * y + (1-x) * Log(1-Exp(y)), where +/// x is the true probability of class 1 and +/// y is the predicted log probability of class 1 +enum ObjectiveType { kLinear, kQuadratic, kXentPerDim }; enum NodeType { kInput, kDescriptor, kComponent, kDimRange, kNone }; @@ -249,7 +254,7 @@ class Nnet { void ResetGenerators(); // resets random-number generators for all // random components. You must also set srand() for this to be // effective. - + private: void Destroy(); diff --git a/src/nnet3/nnet-simple-component.cc b/src/nnet3/nnet-simple-component.cc index 58908a0fe09..fcfd4b9affa 100644 --- a/src/nnet3/nnet-simple-component.cc +++ b/src/nnet3/nnet-simple-component.cc @@ -922,6 +922,87 @@ void ClipGradientComponent::Add(BaseFloat alpha, const Component &other_in) { num_clipped_ += alpha * other->num_clipped_; } + +void ScaleGradientComponent::Init(const CuVectorBase &scales) { + KALDI_ASSERT(scales.Dim() != 0); + scales_ = scales; +} + + +void ScaleGradientComponent::InitFromConfig(ConfigLine *cfl) { + std::string filename; + // Accepts "scales" config (for filename) or "dim" -> random init, for testing. + if (cfl->GetValue("scales", &filename)) { + if (cfl->HasUnusedValues()) + KALDI_ERR << "Invalid initializer for layer of type " + << Type() << ": \"" << cfl->WholeLine() << "\""; + CuVector vec; + ReadKaldiObject(filename, &vec); + Init(vec); + } else { + int32 dim; + BaseFloat scale = 1.0; + bool scale_ok = cfl->GetValue("scale", &scale); + if (!cfl->GetValue("dim", &dim) || cfl->HasUnusedValues()) + KALDI_ERR << "Invalid initializer for layer of type " + << Type() << ": \"" << cfl->WholeLine() << "\""; + KALDI_ASSERT(dim > 0); + CuVector vec(dim); + if (scale_ok) { + vec.Set(scale); + } else { + vec.SetRandn(); + } + Init(vec); + } +} + + +std::string ScaleGradientComponent::Info() const { + std::ostringstream stream; + stream << Component::Info(); + PrintParameterStats(stream, "scales", scales_, true); + return stream.str(); +} + +void ScaleGradientComponent::Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const { + out->CopyFromMat(in); // does nothing if same matrix. 
+} + +void ScaleGradientComponent::Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &, // in_value + const CuMatrixBase &, // out_value + const CuMatrixBase &out_deriv, + Component *, // to_update + CuMatrixBase *in_deriv) const { + in_deriv->CopyFromMat(out_deriv); // does nothing if same memory. + in_deriv->MulColsVec(scales_); +} + +Component* ScaleGradientComponent::Copy() const { + ScaleGradientComponent *ans = new ScaleGradientComponent(); + ans->scales_ = scales_; + return ans; +} + + +void ScaleGradientComponent::Write(std::ostream &os, bool binary) const { + WriteToken(os, binary, ""); + WriteToken(os, binary, ""); + scales_.Write(os, binary); + WriteToken(os, binary, ""); +} + +void ScaleGradientComponent::Read(std::istream &is, bool binary) { + ExpectOneOrTwoTokens(is, binary, "", ""); + scales_.Read(is, binary); + ExpectToken(is, binary, ""); +} + + void TanhComponent::Propagate(const ComponentPrecomputedIndexes *indexes, const CuMatrixBase &in, CuMatrixBase *out) const { @@ -2517,6 +2598,26 @@ void ConstantFunctionComponent::UnVectorize(const VectorBase ¶ms) output_.CopyFromVec(params); } +void ExpComponent::Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const { + // Applied exp function + out->CopyFromMat(in); + out->ApplyExp(); +} + +void ExpComponent::Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &,//in_value, + const CuMatrixBase &out_value, + const CuMatrixBase &out_deriv, + Component *to_update, + CuMatrixBase *in_deriv) const { + if (in_deriv != NULL) { + in_deriv->CopyFromMat(out_value); + in_deriv->MulElements(out_deriv); + } +} NaturalGradientAffineComponent::NaturalGradientAffineComponent(): max_change_per_sample_(0.0), @@ -2568,10 +2669,15 @@ void NaturalGradientAffineComponent::Read(std::istream &is, bool binary) { ReadBasicType(is, binary, &max_change_scale_stats_); ReadToken(is, binary, &token); } - if (token != "" && - token != "") - KALDI_ERR << "Expected or " - << ", got " << token; + + std::ostringstream ostr_beg, ostr_end; + ostr_beg << "<" << Type() << ">"; // e.g. "" + ostr_end << ""; // e.g. "" + + if (token != ostr_end.str() && + token != ostr_beg.str()) + KALDI_ERR << "Expected " << ostr_beg.str() << " or " + << ostr_end.str() << ", got " << token; SetNaturalGradientConfigs(); } @@ -2720,7 +2826,10 @@ void NaturalGradientAffineComponent::Write(std::ostream &os, WriteBasicType(os, binary, active_scaling_count_); WriteToken(os, binary, ""); WriteBasicType(os, binary, max_change_scale_stats_); - WriteToken(os, binary, ""); + + std::ostringstream ostr_end; + ostr_end << ""; // e.g. "" + WriteToken(os, binary, ostr_end.str()); } std::string NaturalGradientAffineComponent::Info() const { @@ -3095,6 +3204,126 @@ void SoftmaxComponent::StoreStats(const CuMatrixBase &out_value) { StoreStatsInternal(out_value, NULL); } +std::string LogComponent::Info() const { + std::stringstream stream; + stream << NonlinearComponent::Info() + << ", log-floor=" << log_floor_; + return stream.str(); +} + +void LogComponent::InitFromConfig(ConfigLine *cfl) { + cfl->GetValue("log-floor", &log_floor_); + NonlinearComponent::InitFromConfig(cfl); +} + +void LogComponent::Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const { + // Apllies log function (x >= epsi ? log(x) : log(epsi)). 
+ out->CopyFromMat(in); + out->ApplyFloor(log_floor_); + out->ApplyLog(); +} + +void LogComponent::Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in_value, + const CuMatrixBase &out_value, + const CuMatrixBase &out_deriv, + Component *to_update, + CuMatrixBase *in_deriv) const { + if (in_deriv != NULL) { + CuMatrix divided_in_value(in_value), floored_in_value(in_value); + divided_in_value.Set(1.0); + floored_in_value.CopyFromMat(in_value); + floored_in_value.ApplyFloor(log_floor_); // (x > epsi ? x : epsi) + + divided_in_value.DivElements(floored_in_value); // (x > epsi ? 1/x : 1/epsi) + in_deriv->CopyFromMat(in_value); + in_deriv->Add(-1.0 * log_floor_); // (x - epsi) + in_deriv->ApplyHeaviside(); // (x > epsi ? 1 : 0) + in_deriv->MulElements(divided_in_value); // (dy/dx: x > epsi ? 1/x : 0) + in_deriv->MulElements(out_deriv); // dF/dx = dF/dy * dy/dx + } +} + +void LogComponent::Read(std::istream &is, bool binary) { + std::ostringstream ostr_beg, ostr_end; + ostr_beg << "<" << Type() << ">"; // e.g. "" + ostr_end << ""; // e.g. "" + ExpectOneOrTwoTokens(is, binary, ostr_beg.str(), ""); + ReadBasicType(is, binary, &dim_); // Read dimension. + ExpectToken(is, binary, ""); + value_sum_.Read(is, binary); + ExpectToken(is, binary, ""); + deriv_sum_.Read(is, binary); + ExpectToken(is, binary, ""); + ReadBasicType(is, binary, &count_); + value_sum_.Scale(count_); + deriv_sum_.Scale(count_); + + std::string token; + ReadToken(is, binary, &token); + if (token == "") { + ReadBasicType(is, binary, &self_repair_lower_threshold_); + ReadToken(is, binary, &token); + } + if (token == "") { + ReadBasicType(is, binary, &self_repair_upper_threshold_); + ReadToken(is, binary, &token); + } + if (token == "") { + ReadBasicType(is, binary, &self_repair_scale_); + ReadToken(is, binary, &token); + } + if (token == "") { + ReadBasicType(is, binary, &log_floor_); + ReadToken(is, binary, &token); + } + if (token != ostr_end.str()) { + KALDI_ERR << "Expected token " << ostr_end.str() + << ", got " << token; + } +} + +void LogComponent::Write(std::ostream &os, bool binary) const { + std::ostringstream ostr_beg, ostr_end; + ostr_beg << "<" << Type() << ">"; // e.g. "" + ostr_end << ""; // e.g. "" + WriteToken(os, binary, ostr_beg.str()); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, dim_); + // Write the values and derivatives in a count-normalized way, for + // greater readability in text form. 
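For reference, the element-wise forward and backward maps implemented by the new LogComponent and ExpComponent above, with epsilon denoting log_floor_ (settable through the log-floor config value):

```latex
\text{LogComponent:}\quad y=\log\bigl(\max(x,\varepsilon)\bigr),\qquad
\frac{\partial y}{\partial x}=
\begin{cases}
1/x, & x>\varepsilon\\
0, & \text{otherwise,}
\end{cases}

\text{ExpComponent:}\quad y=e^{x},\qquad
\frac{\partial y}{\partial x}=e^{x}=y.
```

This is why ExpComponent::Backprop simply multiplies out_deriv by out_value, and why LogComponent::Backprop zeroes the derivative for floored inputs via the Heaviside step on (x - epsilon).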
+ WriteToken(os, binary, ""); + Vector temp(value_sum_); + if (count_ != 0.0) temp.Scale(1.0 / count_); + temp.Write(os, binary); + WriteToken(os, binary, ""); + + temp.Resize(deriv_sum_.Dim(), kUndefined); + temp.CopyFromVec(deriv_sum_); + if (count_ != 0.0) temp.Scale(1.0 / count_); + temp.Write(os, binary); + WriteToken(os, binary, ""); + WriteBasicType(os, binary, count_); + if (self_repair_lower_threshold_ != kUnsetThreshold) { + WriteToken(os, binary, ""); + WriteBasicType(os, binary, self_repair_lower_threshold_); + } + if (self_repair_upper_threshold_ != kUnsetThreshold) { + WriteToken(os, binary, ""); + WriteBasicType(os, binary, self_repair_upper_threshold_); + } + if (self_repair_scale_ != 0.0) { + WriteToken(os, binary, ""); + WriteBasicType(os, binary, self_repair_scale_); + } + WriteToken(os, binary, ""); + WriteBasicType(os, binary, log_floor_); + WriteToken(os, binary, ostr_end.str()); +} + void LogSoftmaxComponent::Propagate(const ComponentPrecomputedIndexes *indexes, const CuMatrixBase &in, @@ -3135,12 +3364,18 @@ void FixedScaleComponent::InitFromConfig(ConfigLine *cfl) { Init(vec); } else { int32 dim; + BaseFloat scale = 1.0; + bool scale_ok = cfl->GetValue("scale", &scale); if (!cfl->GetValue("dim", &dim) || cfl->HasUnusedValues()) KALDI_ERR << "Invalid initializer for layer of type " << Type() << ": \"" << cfl->WholeLine() << "\""; KALDI_ASSERT(dim > 0); CuVector vec(dim); - vec.SetRandn(); + if (scale_ok) { + vec.Set(scale); + } else { + vec.SetRandn(); + } Init(vec); } } diff --git a/src/nnet3/nnet-simple-component.h b/src/nnet3/nnet-simple-component.h index f09a989759a..ff9ec5fd26b 100644 --- a/src/nnet3/nnet-simple-component.h +++ b/src/nnet3/nnet-simple-component.h @@ -697,6 +697,71 @@ class LogSoftmaxComponent: public NonlinearComponent { LogSoftmaxComponent &operator = (const LogSoftmaxComponent &other); // Disallow. }; +// The LogComponent outputs the log of input values as y = Log(max(x, epsi)) +class LogComponent: public NonlinearComponent { + public: + explicit LogComponent(const LogComponent &other): + NonlinearComponent(other), log_floor_(other.log_floor_) { } + LogComponent(): log_floor_(1e-20) { } + virtual std::string Type() const { return "LogComponent"; } + virtual int32 Properties() const { + return kSimpleComponent|kBackpropNeedsInput|kStoresStats; + } + + virtual std::string Info() const; + + virtual void InitFromConfig(ConfigLine *cfl); + + virtual void Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const; + virtual void Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in_value, + const CuMatrixBase &out_value, + const CuMatrixBase &out_deriv, + Component *to_update, + CuMatrixBase *in_deriv) const; + + virtual Component* Copy() const { return new LogComponent(*this); } + + virtual void Read(std::istream &is, bool binary); + + virtual void Write(std::ostream &os, bool binary) const; + + private: + LogComponent &operator = (const LogComponent &other); // Disallow. 
+ BaseFloat log_floor_; +}; + + +// The ExpComponent outputs the exp of input values as y = Exp(x) +class ExpComponent: public NonlinearComponent { + public: + explicit ExpComponent(const ExpComponent &other): + NonlinearComponent(other) { } + ExpComponent() { } + virtual std::string Type() const { return "ExpComponent"; } + virtual int32 Properties() const { + return kSimpleComponent|kBackpropNeedsOutput|kStoresStats; + } + virtual void Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const; + virtual void Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &, + const CuMatrixBase &out_value, + const CuMatrixBase &, + Component *to_update, + CuMatrixBase *in_deriv) const; + + virtual Component* Copy() const { return new ExpComponent(*this); } + private: + ExpComponent &operator = (const ExpComponent &other); // Disallow. +}; + + /// Keywords: natural gradient descent, NG-SGD, naturalgradient. For /// the top-level of the natural gradient code look here, and also in /// nnet-precondition-online.h. @@ -826,6 +891,8 @@ class FixedAffineComponent: public Component { // Function to provide access to linear_params_. const CuMatrix &LinearParams() const { return linear_params_; } + const CuVector &BiasParams() const { return bias_params_; } + protected: friend class AffineComponent; CuMatrix linear_params_; @@ -1129,6 +1196,46 @@ class ClipGradientComponent: public Component { }; +// Applied a per-element scale only on the gradient during back propagation +// Duplicates the input during forward propagation +class ScaleGradientComponent : public Component { + public: + ScaleGradientComponent() { } + virtual std::string Type() const { return "ScaleGradientComponent"; } + virtual std::string Info() const; + virtual int32 Properties() const { + return kSimpleComponent|kLinearInInput|kPropagateInPlace|kBackpropInPlace; + } + + void Init(const CuVectorBase &scales); + + // The ConfigLine cfl contains only the option scales=, + // where the string is the filename of a Kaldi-format matrix to read. + virtual void InitFromConfig(ConfigLine *cfl); + + virtual int32 InputDim() const { return scales_.Dim(); } + virtual int32 OutputDim() const { return scales_.Dim(); } + + virtual void Propagate(const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &in, + CuMatrixBase *out) const; + virtual void Backprop(const std::string &debug_info, + const ComponentPrecomputedIndexes *indexes, + const CuMatrixBase &, // in_value + const CuMatrixBase &, // out_value + const CuMatrixBase &out_deriv, + Component *, // to_update + CuMatrixBase *in_deriv) const; + virtual Component* Copy() const; + virtual void Read(std::istream &is, bool binary); + virtual void Write(std::ostream &os, bool binary) const; + + protected: + CuVector scales_; + KALDI_DISALLOW_COPY_AND_ASSIGN(ScaleGradientComponent); +}; + + /** PermuteComponent changes the order of the columns (i.e. the feature or activation dimensions). 
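The ScaleGradientComponent declared above is the identity in the forward pass; only in the backward pass does it scale each dimension of the incoming gradient, which is also why the component derivative test earlier in this patch is taught to skip it. In symbols, with per-dimension scales s:

```latex
y = x, \qquad
\frac{\partial \mathcal{F}}{\partial x_i} \;=\; s_i\,\frac{\partial \mathcal{F}}{\partial y_i}.
```

Per its InitFromConfig, it is configured either with scales=<rxfilename of a Kaldi vector>, or, mainly for testing, with dim=N plus an optional constant scale=c (random scales if scale is not given).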
Output dimension i is mapped to input dimension column_map_[i], so it's like doing: diff --git a/src/nnet3/nnet-test-utils.cc b/src/nnet3/nnet-test-utils.cc index 170ea51ca8f..da519fa1cd3 100644 --- a/src/nnet3/nnet-test-utils.cc +++ b/src/nnet3/nnet-test-utils.cc @@ -1104,7 +1104,7 @@ void ComputeExampleComputationRequestSimple( static void GenerateRandomComponentConfig(std::string *component_type, std::string *config) { - int32 n = RandInt(0, 30); + int32 n = RandInt(0, 33); BaseFloat learning_rate = 0.001 * RandInt(1, 3); std::ostringstream os; @@ -1401,8 +1401,7 @@ static void GenerateRandomComponentConfig(std::string *component_type, *component_type = "DropoutComponent"; os << "dim=" << RandInt(1, 200) << " dropout-proportion=" << RandUniform(); - break; - } + } case 30: { *component_type = "LstmNonlinearityComponent"; // set self-repair scale to zero so the derivative tests will pass. @@ -1410,6 +1409,21 @@ static void GenerateRandomComponentConfig(std::string *component_type, << " self-repair-scale=0.0"; break; } + case 31: { + *component_type = "LogComponent"; + os << "dim=" << RandInt(1, 50); + break; + } + case 32: { + *component_type = "ExpComponent"; + os << "dim=" << RandInt(1, 50); + break; + } + case 33: { + *component_type = "ScaleGradientComponent"; + os << "dim=" << RandInt(1, 100); + break; + } default: KALDI_ERR << "Error generating random component"; } diff --git a/src/nnet3/nnet-training.cc b/src/nnet3/nnet-training.cc index 87d64e27871..bdbe244a648 100644 --- a/src/nnet3/nnet-training.cc +++ b/src/nnet3/nnet-training.cc @@ -39,7 +39,7 @@ NnetTrainer::NnetTrainer(const NnetTrainerOptions &config, // natural-gradient updates. SetZero(is_gradient, delta_nnet_); const int32 num_updatable = NumUpdatableComponents(*delta_nnet_); - num_max_change_per_component_applied_.resize(num_updatable, 0); + num_max_change_per_component_applied_.resize(num_updatable, 0); num_max_change_global_applied_ = 0; if (config_.read_cache != "") { @@ -52,7 +52,7 @@ NnetTrainer::NnetTrainer(const NnetTrainerOptions &config, KALDI_WARN << "Could not open cached computation. " "Probably this is the first training iteration."; } - } + } } @@ -88,9 +88,12 @@ void NnetTrainer::ProcessOutputs(const NnetExample &eg, ObjectiveType obj_type = nnet_->GetNode(node_index).u.objective_type; BaseFloat tot_weight, tot_objf; bool supply_deriv = true; + const Vector *deriv_weights = NULL; + if (config_.apply_deriv_weights && io.deriv_weights.Dim() > 0) + deriv_weights = &(io.deriv_weights); ComputeObjectiveFunction(io.features, obj_type, io.name, supply_deriv, computer, - &tot_weight, &tot_objf); + &tot_weight, &tot_objf, deriv_weights); objf_info_[io.name].UpdateStats(io.name, config_.print_interval, num_minibatches_processed_++, tot_weight, tot_objf); @@ -167,7 +170,7 @@ void NnetTrainer::UpdateParamsWithMaxChange() { << " / " << num_updatable << " Updatable Components." << "(smallest factor=" << min_scale << " on " << component_name_with_min_scale - << " with max-change=" << max_change_with_min_scale <<"). "; + << " with max-change=" << max_change_with_min_scale <<"). 
"; if (param_delta > config_.max_param_change) ostr << "Global max-change factor was " << config_.max_param_change / param_delta @@ -185,11 +188,12 @@ bool NnetTrainer::PrintTotalStats() const { unordered_map::const_iterator iter = objf_info_.begin(), end = objf_info_.end(); - bool ans = false; + bool ans = true; for (; iter != end; ++iter) { const std::string &name = iter->first; const ObjectiveFunctionInfo &info = iter->second; - ans = ans || info.PrintTotalStats(name); + if (!info.PrintTotalStats(name)) + ans = false; } PrintMaxChangeStats(); return ans; @@ -276,7 +280,7 @@ bool ObjectiveFunctionInfo::PrintTotalStats(const std::string &name) const { << (tot_objf / tot_weight) << " over " << tot_weight << " frames."; } else { KALDI_LOG << "Overall average objective function for '" << name << "' is " - << objf << " + " << aux_objf << " = " << sum_objf + << objf << " + " << aux_objf << " = " << sum_objf << " over " << tot_weight << " frames."; } KALDI_LOG << "[this line is to be parsed by a script:] " @@ -290,7 +294,7 @@ NnetTrainer::~NnetTrainer() { Output ko(config_.write_cache, config_.binary_write_cache); compiler_.WriteCache(ko.Stream(), config_.binary_write_cache); KALDI_LOG << "Wrote computation cache to " << config_.write_cache; - } + } delete delta_nnet_; } @@ -300,7 +304,8 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, bool supply_deriv, NnetComputer *computer, BaseFloat *tot_weight, - BaseFloat *tot_objf) { + BaseFloat *tot_objf, + const VectorBase *deriv_weights) { const CuMatrixBase &output = computer->GetOutput(output_name); if (output.NumCols() != supervision.NumCols()) @@ -309,6 +314,51 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, << " (nnet) vs. " << supervision.NumCols() << " (egs)\n"; switch (objective_type) { + case kXentPerDim: { + // objective is x * log(y) + (1-x) * log(1-y) + CuMatrix cu_post(supervision.NumRows(), supervision.NumCols(), + kUndefined); // x + cu_post.CopyFromGeneralMat(supervision); + + CuMatrix n_cu_post(cu_post.NumRows(), cu_post.NumCols()); + n_cu_post.Set(1.0); + n_cu_post.AddMat(-1.0, cu_post); // 1-x + + CuMatrix log_prob(output); // y + log_prob.ApplyLog(); // log(y) + + CuMatrix n_output(output.NumRows(), + output.NumCols(), kSetZero); + n_output.Set(1.0); + n_output.AddMat(-1.0, output); // 1-y + n_output.ApplyLog(); // log(1-y) + + BaseFloat num_elements = static_cast(cu_post.NumRows()); + if (deriv_weights) { + CuVector cu_deriv_weights(*deriv_weights); + num_elements = cu_deriv_weights.Sum(); + cu_post.MulRowsVec(cu_deriv_weights); + n_cu_post.MulRowsVec(cu_deriv_weights); + } + + *tot_weight = num_elements * cu_post.NumCols(); + *tot_objf = TraceMatMat(log_prob, cu_post, kTrans) + + TraceMatMat(n_output, n_cu_post, kTrans); + + if (supply_deriv) { + // deriv is x / y - (1-x) / (1-y) + n_output.ApplyExp(); // 1-y + n_cu_post.DivElements(n_output); // 1-x / (1-y) + + log_prob.ApplyExp(); // y + cu_post.DivElements(log_prob); // x / y + + cu_post.AddMat(-1.0, n_cu_post); // x / y - (1-x) / (1-y) + computer->AcceptOutputDeriv(output_name, &cu_post); + } + + break; + } case kLinear: { // objective is x * y. switch (supervision.Type()) { @@ -318,20 +368,38 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, // The cross-entropy objective is computed by a simple dot product, // because after the LogSoftmaxLayer, the output is already in the form // of log-likelihoods that are normalized to sum to one. 
- *tot_weight = cu_post.Sum(); - *tot_objf = TraceMatSmat(output, cu_post, kTrans); - if (supply_deriv) { + if (deriv_weights) { CuMatrix output_deriv(output.NumRows(), output.NumCols(), kUndefined); cu_post.CopyToMat(&output_deriv); - computer->AcceptOutputDeriv(output_name, &output_deriv); + CuVector cu_deriv_weights(*deriv_weights); + output_deriv.MulRowsVec(cu_deriv_weights); + *tot_weight = cu_deriv_weights.Sum(); + *tot_objf = TraceMatMat(output, output_deriv, kTrans); + if (supply_deriv) { + computer->AcceptOutputDeriv(output_name, &output_deriv); + } + } else { + *tot_weight = cu_post.Sum(); + *tot_objf = TraceMatSmat(output, cu_post, kTrans); + if (supply_deriv) { + CuMatrix output_deriv(output.NumRows(), output.NumCols(), + kUndefined); + cu_post.CopyToMat(&output_deriv); + computer->AcceptOutputDeriv(output_name, &output_deriv); + } } + break; } case kFullMatrix: { // there is a redundant matrix copy in here if we're not using a GPU // but we don't anticipate this code branch being used in many cases. CuMatrix cu_post(supervision.GetFullMatrix()); + if (deriv_weights) { + CuVector cu_deriv_weights(*deriv_weights); + cu_post.MulRowsVec(cu_deriv_weights); + } *tot_weight = cu_post.Sum(); *tot_objf = TraceMatMat(output, cu_post, kTrans); if (supply_deriv) @@ -343,6 +411,10 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, supervision.GetMatrix(&post); CuMatrix cu_post; cu_post.Swap(&post); + if (deriv_weights) { + CuVector cu_deriv_weights(*deriv_weights); + cu_post.MulRowsVec(cu_deriv_weights); + } *tot_weight = cu_post.Sum(); *tot_objf = TraceMatMat(output, cu_post, kTrans); if (supply_deriv) @@ -360,6 +432,11 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, diff.CopyFromGeneralMat(supervision); diff.AddMat(-1.0, output); *tot_weight = diff.NumRows(); + if (deriv_weights) { + CuVector cu_deriv_weights(*deriv_weights); + diff.MulRowsVec(cu_deriv_weights); + *tot_weight = deriv_weights->Sum(); + } *tot_objf = -0.5 * TraceMatMat(diff, diff, kTrans); if (supply_deriv) computer->AcceptOutputDeriv(output_name, &diff); diff --git a/src/nnet3/nnet-training.h b/src/nnet3/nnet-training.h index 70c90267c66..7b22bc75211 100644 --- a/src/nnet3/nnet-training.h +++ b/src/nnet3/nnet-training.h @@ -42,6 +42,8 @@ struct NnetTrainerOptions { BaseFloat max_param_change; NnetOptimizeOptions optimize_config; NnetComputeOptions compute_config; + bool apply_deriv_weights; + NnetTrainerOptions(): zero_component_stats(true), store_component_stats(true), @@ -49,7 +51,8 @@ struct NnetTrainerOptions { debug_computation(false), momentum(0.0), binary_write_cache(true), - max_param_change(2.0) { } + max_param_change(2.0), + apply_deriv_weights(true) { } void Register(OptionsItf *opts) { opts->Register("store-component-stats", &store_component_stats, "If true, store activations and derivatives for nonlinear " @@ -69,6 +72,9 @@ struct NnetTrainerOptions { "so that the 'effective' learning rate is the same as " "before (because momentum would normally increase the " "effective learning rate by 1/(1-momentum))"); + opts->Register("apply-deriv-weights", &apply_deriv_weights, + "If true, apply the per-frame derivative weights stored with " + "the example"); opts->Register("read-cache", &read_cache, "the location where we can read " "the cached computation from"); opts->Register("write-cache", &write_cache, "the location where we want to " @@ -226,7 +232,8 @@ void ComputeObjectiveFunction(const GeneralMatrix &supervision, bool supply_deriv, NnetComputer *computer, BaseFloat 
*tot_weight, - BaseFloat *tot_objf); + BaseFloat *tot_objf, + const VectorBase* deriv_weights = NULL); diff --git a/src/nnet3bin/Makefile b/src/nnet3bin/Makefile index d46c56a1044..aeb3dc1dc03 100644 --- a/src/nnet3bin/Makefile +++ b/src/nnet3bin/Makefile @@ -17,7 +17,7 @@ BINFILES = nnet3-init nnet3-info nnet3-get-egs nnet3-copy-egs nnet3-subset-egs \ nnet3-discriminative-merge-egs nnet3-discriminative-shuffle-egs \ nnet3-discriminative-compute-objf nnet3-discriminative-train \ discriminative-get-supervision nnet3-discriminative-subset-egs \ - nnet3-discriminative-compute-from-egs + nnet3-discriminative-compute-from-egs nnet3-get-egs-multiple-targets OBJFILES = diff --git a/src/nnet3bin/nnet3-acc-lda-stats.cc b/src/nnet3bin/nnet3-acc-lda-stats.cc index 0b3b537855e..b41c4a6704d 100644 --- a/src/nnet3bin/nnet3-acc-lda-stats.cc +++ b/src/nnet3bin/nnet3-acc-lda-stats.cc @@ -87,13 +87,18 @@ class NnetLdaStatsAccumulator { // but we're about to do an outer product, so this doesn't dominate. Vector row(cu_row); + BaseFloat deriv_weight = 1.0; + if (output_supervision->deriv_weights.Dim() > 0 && r < output_supervision->deriv_weights.Dim()) { + deriv_weight = output_supervision->deriv_weights(r); + } + const SparseVector &post(smat.Row(r)); const std::pair *post_data = post.Data(), *post_end = post_data + post.NumElements(); for (; post_data != post_end; ++post_data) { MatrixIndexT pdf = post_data->first; BaseFloat weight = post_data->second; - BaseFloat pruned_weight = RandPrune(weight, rand_prune); + BaseFloat pruned_weight = RandPrune(weight, rand_prune) * deriv_weight; if (pruned_weight != 0.0) lda_stats_.Accumulate(row, pdf, pruned_weight); } @@ -110,11 +115,16 @@ class NnetLdaStatsAccumulator { // but we're about to do an outer product, so this doesn't dominate. Vector row(cu_row); + BaseFloat deriv_weight = 1.0; + if (output_supervision->deriv_weights.Dim() > 0 && r < output_supervision->deriv_weights.Dim()) { + deriv_weight = output_supervision->deriv_weights(r); + } + SubVector post(output_mat, r); int32 num_pdfs = post.Dim(); for (int32 pdf = 0; pdf < num_pdfs; pdf++) { BaseFloat weight = post(pdf); - BaseFloat pruned_weight = RandPrune(weight, rand_prune); + BaseFloat pruned_weight = RandPrune(weight, rand_prune) * deriv_weight; if (pruned_weight != 0.0) lda_stats_.Accumulate(row, pdf, pruned_weight); } diff --git a/src/nnet3bin/nnet3-compute-from-egs.cc b/src/nnet3bin/nnet3-compute-from-egs.cc index 66eace0dab5..e35e67bbeb5 100644 --- a/src/nnet3bin/nnet3-compute-from-egs.cc +++ b/src/nnet3bin/nnet3-compute-from-egs.cc @@ -36,7 +36,8 @@ class NnetComputerFromEg { // Compute the output (which will have the same number of rows as the number // of Indexes in the output of the eg), and put it in "output". 
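Summarizing the optional deriv_weights argument threaded through ComputeObjectiveFunction above: each supervision row r is scaled by its weight d_r before the objective and gradient are computed, so for the kLinear (cross-entropy) case

```latex
\mathcal{F} \;=\; \sum_{r} d_r \sum_{i} x_{ri}\,\hat{y}_{ri},
\qquad
W \;=\; \sum_{r} d_r,
```

where x is the supervision and y-hat is the (log-probability) network output. Without weights, W falls back to the sum of all supervision entries, which agrees whenever each posterior row sums to one. For kQuadratic the rows of the difference matrix are scaled the same way and W likewise becomes the sum of the weights rather than the number of rows.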
- void Compute(const NnetExample &eg, Matrix *output) { + void Compute(const NnetExample &eg, const std::string &output_name, + Matrix *output) { ComputationRequest request; bool need_backprop = false, store_stats = false; GetComputationRequest(nnet_, eg, need_backprop, store_stats, &request); @@ -47,7 +48,7 @@ class NnetComputerFromEg { NnetComputer computer(options, computation, nnet_, NULL); computer.AcceptInputs(nnet_, eg.io); computer.Forward(); - const CuMatrixBase &nnet_output = computer.GetOutput("output"); + const CuMatrixBase &nnet_output = computer.GetOutput(output_name); output->Resize(nnet_output.NumRows(), nnet_output.NumCols()); nnet_output.CopyToMat(output); } @@ -80,11 +81,14 @@ int main(int argc, char *argv[]) { bool binary_write = true, apply_exp = false; std::string use_gpu = "yes"; + std::string output_name = "output"; ParseOptions po(usage); po.Register("binary", &binary_write, "Write output in binary mode"); po.Register("apply-exp", &apply_exp, "If true, apply exp function to " "output"); + po.Register("output-name", &output_name, "Do computation for " + "specified output"); po.Register("use-gpu", &use_gpu, "yes|no|optional|wait, only has effect if compiled with CUDA"); @@ -115,7 +119,7 @@ int main(int argc, char *argv[]) { for (; !example_reader.Done(); example_reader.Next(), num_egs++) { Matrix output; - computer.Compute(example_reader.Value(), &output); + computer.Compute(example_reader.Value(), output_name, &output); KALDI_ASSERT(output.NumRows() != 0); if (apply_exp) output.ApplyExp(); diff --git a/src/nnet3bin/nnet3-compute.cc b/src/nnet3bin/nnet3-compute.cc index 9305ef7e6b6..d46220c7ffd 100644 --- a/src/nnet3bin/nnet3-compute.cc +++ b/src/nnet3bin/nnet3-compute.cc @@ -159,6 +159,9 @@ int main(int argc, char *argv[]) { num_success++; } +#if HAVE_CUDA==1 + CuDevice::Instantiate().PrintProfile(); +#endif double elapsed = timer.Elapsed(); KALDI_LOG << "Time taken "<< elapsed << "s: real-time factor assuming 100 frames/sec is " diff --git a/src/nnet3bin/nnet3-copy-egs.cc b/src/nnet3bin/nnet3-copy-egs.cc index efb51f51910..2702ae5fae9 100644 --- a/src/nnet3bin/nnet3-copy-egs.cc +++ b/src/nnet3bin/nnet3-copy-egs.cc @@ -23,10 +23,29 @@ #include "hmm/transition-model.h" #include "nnet3/nnet-example.h" #include "nnet3/nnet-example-utils.h" +#include namespace kaldi { namespace nnet3 { +bool KeepOutputs(const std::vector &keep_outputs, + NnetExample *eg) { + std::vector io_new; + int32 num_outputs = 0; + for (std::vector::iterator it = eg->io.begin(); + it != eg->io.end(); ++it) { + if (it->name.find("output") != std::string::npos) { + if (!std::binary_search(keep_outputs.begin(), keep_outputs.end(), it->name)) + continue; + num_outputs++; + } + io_new.push_back(*it); + } + eg->io.swap(io_new); + + return num_outputs; +} + // returns an integer randomly drawn with expected value "expected_count" // (will be either floor(expected_count) or ceil(expected_count)). int32 GetCount(double expected_count) { @@ -58,7 +77,7 @@ bool ContainsSingleExample(const NnetExample &eg, end = io.indexes.end(); // Should not have an empty input/output type. 
KALDI_ASSERT(!io.indexes.empty()); - if (io.name == "input" || io.name == "output") { + if (io.name == "input" || io.name.find("output") != std::string::npos) { int32 min_t = iter->t, max_t = iter->t; for (; iter != end; ++iter) { int32 this_t = iter->t; @@ -75,7 +94,7 @@ bool ContainsSingleExample(const NnetExample &eg, *min_input_t = min_t; *max_input_t = max_t; } else { - KALDI_ASSERT(io.name == "output"); + KALDI_ASSERT(io.name.find("output") != std::string::npos); done_output = true; *min_output_t = min_t; *max_output_t = max_t; @@ -127,7 +146,7 @@ void FilterExample(const NnetExample &eg, min_t = min_input_t; max_t = max_input_t; is_input_or_output = true; - } else if (name == "output") { + } else if (name.find("output") != std::string::npos) { min_t = min_output_t; max_t = max_output_t; is_input_or_output = true; @@ -137,6 +156,7 @@ void FilterExample(const NnetExample &eg, if (!is_input_or_output) { // Just copy everything. io_out.indexes = io_in.indexes; io_out.features = io_in.features; + io_out.deriv_weights = io_in.deriv_weights; } else { const std::vector &indexes_in = io_in.indexes; std::vector &indexes_out = io_out.indexes; @@ -157,6 +177,19 @@ void FilterExample(const NnetExample &eg, } } KALDI_ASSERT(iter_out == keep.end()); + + if (io_in.deriv_weights.Dim() > 0) { + io_out.deriv_weights.Resize(num_kept, kUndefined); + int32 in_dim = 0, out_dim = 0; + iter_out = keep.begin(); + for (; iter_out != keep.end(); ++iter_out, in_dim++) { + if (*iter_out) + io_out.deriv_weights(out_dim++) = io_in.deriv_weights(in_dim); + } + KALDI_ASSERT(out_dim == num_kept); + KALDI_ASSERT(iter_out == keep.end()); + } + if (num_kept == 0) KALDI_ERR << "FilterExample removed all indexes for '" << name << "'"; @@ -243,6 +276,22 @@ bool SelectFromExample(const NnetExample &eg, return true; } +bool RemoveZeroDerivOutputs(NnetExample *eg) { + std::vector io_new; + int32 num_outputs = 0; + for (std::vector::iterator it = eg->io.begin(); + it != eg->io.end(); ++it) { + if (it->name.find("output") != std::string::npos) { + if (it->deriv_weights.Dim() > 0 && it->deriv_weights.Sum() == 0) + continue; + num_outputs++; + } + io_new.push_back(*it); + } + eg->io.swap(io_new); + + return (num_outputs > 0); +} } // namespace nnet3 } // namespace kaldi @@ -270,6 +319,8 @@ int main(int argc, char *argv[]) { int32 srand_seed = 0; int32 frame_shift = 0; BaseFloat keep_proportion = 1.0; + std::string keep_outputs_str; + bool remove_zero_deriv_outputs = false; // The following config variables, if set, can be used to extract a single // frame of labels from a multi-frame example, and/or to reduce the amount @@ -301,7 +352,11 @@ int main(int argc, char *argv[]) { "feature left-context that we output."); po.Register("right-context", &right_context, "Can be used to truncate the " "feature right-context that we output."); - + po.Register("keep-outputs", &keep_outputs_str, "Comma separated list of " + "output nodes to keep"); + po.Register("remove-zero-deriv-outputs", &remove_zero_deriv_outputs, + "Remove outputs that do not contribute to the objective " + "because of zero deriv-weights"); po.Read(argc, argv); @@ -321,17 +376,29 @@ int main(int argc, char *argv[]) { for (int32 i = 0; i < num_outputs; i++) example_writers[i] = new NnetExampleWriter(po.GetArg(i+2)); + std::vector keep_outputs; + if (!keep_outputs_str.empty()) { + SplitStringToVector(keep_outputs_str, ",:", true, &keep_outputs); + std::sort(keep_outputs.begin(), keep_outputs.end()); + } int64 num_read = 0, num_written = 0; for (; !example_reader.Done(); 
example_reader.Next(), num_read++) { // count is normally 1; could be 0, or possibly >1. int32 count = GetCount(keep_proportion); std::string key = example_reader.Key(); - const NnetExample &eg = example_reader.Value(); + NnetExample eg(example_reader.Value()); + + if (!keep_outputs_str.empty()) { + if (!KeepOutputs(keep_outputs, &eg)) continue; + } + for (int32 c = 0; c < count; c++) { int32 index = (random ? Rand() : num_written) % num_outputs; if (frame_str == "" && left_context == -1 && right_context == -1 && frame_shift == 0) { + if (remove_zero_deriv_outputs) + if (!RemoveZeroDerivOutputs(&eg)) continue; example_writers[index]->Write(key, eg); num_written++; } else { // the --frame option or context options were set. @@ -340,6 +407,8 @@ int main(int argc, char *argv[]) { frame_shift, &eg_modified)) { // this branch of the if statement will almost always be taken (should only // not be taken for shorter-than-normal egs from the end of a file. + if (remove_zero_deriv_outputs) + if (!RemoveZeroDerivOutputs(&eg_modified)) continue; example_writers[index]->Write(key, eg_modified); num_written++; } diff --git a/src/nnet3bin/nnet3-get-egs-dense-targets.cc b/src/nnet3bin/nnet3-get-egs-dense-targets.cc index 23bf8922a5b..502e0700f27 100644 --- a/src/nnet3bin/nnet3-get-egs-dense-targets.cc +++ b/src/nnet3bin/nnet3-get-egs-dense-targets.cc @@ -32,9 +32,13 @@ namespace nnet3 { static void ProcessFile(const MatrixBase &feats, const MatrixBase *ivector_feats, + const VectorBase *deriv_weights, + const MatrixBase *l2reg_targets, const MatrixBase &targets, const std::string &utt_id, bool compress, + int32 input_compress_format, + int32 feats_compress_format, int32 num_targets, int32 left_context, int32 right_context, @@ -42,9 +46,9 @@ static void ProcessFile(const MatrixBase &feats, int64 *num_frames_written, int64 *num_egs_written, NnetExampleWriter *example_writer) { - KALDI_ASSERT(feats.NumRows() == static_cast(targets.NumRows())); - - for (int32 t = 0; t < feats.NumRows(); t += frames_per_eg) { + //KALDI_ASSERT(feats.NumRows() == static_cast(targets.NumRows())); + int min_size = std::min(feats.NumRows(), targets.NumRows()); + for (int32 t = 0; t < min_size; t += frames_per_eg) { // actual_frames_per_eg is the number of frames with actual targets. // At the end of the file, we pad with the last frame repeated @@ -52,18 +56,18 @@ static void ProcessFile(const MatrixBase &feats, // for recompilations). // TODO: We might need to ignore the end of the file. int32 actual_frames_per_eg = std::min(frames_per_eg, - feats.NumRows() - t); + min_size - t); int32 tot_frames = left_context + frames_per_eg + right_context; - Matrix input_frames(tot_frames, feats.NumCols()); + Matrix input_frames(tot_frames, feats.NumCols(), kUndefined); // Set up "input_frames". for (int32 j = -left_context; j < frames_per_eg + right_context; j++) { int32 t2 = j + t; if (t2 < 0) t2 = 0; - if (t2 >= feats.NumRows()) t2 = feats.NumRows() - 1; + if (t2 >= min_size) t2 = min_size - 1; SubVector src(feats, t2), dest(input_frames, j + left_context); dest.CopyFromVec(src); @@ -75,8 +79,11 @@ static void ProcessFile(const MatrixBase &feats, eg.io.push_back(NnetIo("input", - left_context, input_frames)); + if (compress) + eg.io.back().Compress(input_compress_format); + // if applicable, add the iVector feature. - if (ivector_feats != NULL) { + if (ivector_feats) { // try to get closest frame to middle of window to get // a representative iVector. 
int32 closest_frame = t + (actual_frames_per_eg / 2); @@ -102,17 +109,57 @@ static void ProcessFile(const MatrixBase &feats, for (int32 i = actual_frames_per_eg; i < frames_per_eg; i++) { // Copy the i^th row of the target matrix from the last row of the // input targets matrix - KALDI_ASSERT(t + actual_frames_per_eg - 1 == feats.NumRows() - 1); + KALDI_ASSERT(t + actual_frames_per_eg - 1 == min_size - 1); SubVector this_target_dest(targets_dest, i); SubVector this_target_src(targets, t+actual_frames_per_eg-1); this_target_dest.CopyFromVec(this_target_src); } - // push this created targets matrix into the eg - eg.io.push_back(NnetIo("output", 0, targets_dest)); + if (!deriv_weights) { + // push this created targets matrix into the eg + eg.io.push_back(NnetIo("output", 0, targets_dest)); + } else { + Vector this_deriv_weights(targets_dest.NumRows()); + int32 frames_to_copy = std::min(t + actual_frames_per_eg, deriv_weights->Dim()) - t; + this_deriv_weights.Range(0, frames_to_copy).CopyFromVec(deriv_weights->Range(t, frames_to_copy)); + if (this_deriv_weights.Sum() == 0) continue; // Ignore frames that have frame weights 0 + eg.io.push_back(NnetIo("output", this_deriv_weights, 0, targets_dest)); + } + + if (l2reg_targets) { + // add the labels. + Matrix l2reg_targets_dest(frames_per_eg, l2reg_targets->NumCols()); + for (int32 i = 0; i < actual_frames_per_eg; i++) { + // Copy the i^th row of the target matrix from the (t+i)^th row of the + // input targets matrix + SubVector this_target_dest(l2reg_targets_dest, i); + SubVector this_target_src(*l2reg_targets, t+i); + this_target_dest.CopyFromVec(this_target_src); + } + + // Copy the last frame's target to the padded frames + for (int32 i = actual_frames_per_eg; i < frames_per_eg; i++) { + // Copy the i^th row of the target matrix from the last row of the + // input targets matrix + KALDI_ASSERT(t + actual_frames_per_eg - 1 == feats.NumRows() - 1); + SubVector this_target_dest(l2reg_targets_dest, i); + SubVector this_target_src(*l2reg_targets, t+actual_frames_per_eg-1); + this_target_dest.CopyFromVec(this_target_src); + } + + if (!deriv_weights) { + eg.io.push_back(NnetIo("output-l2reg", 0, l2reg_targets_dest)); + } else { + Vector this_deriv_weights(l2reg_targets_dest.NumRows()); + int32 frames_to_copy = std::min(t + actual_frames_per_eg, deriv_weights->Dim()) - t; + this_deriv_weights.Range(0, frames_to_copy).CopyFromVec(deriv_weights->Range(t, frames_to_copy)); + if (this_deriv_weights.Sum() == 0) continue; // Ignore frames that have frame weights 0 + eg.io.push_back(NnetIo("output-l2reg", this_deriv_weights, 0, l2reg_targets_dest)); + } + } if (compress) - eg.Compress(); + eg.Compress(feats_compress_format); std::ostringstream os; os << utt_id << "-" << t; @@ -155,14 +202,20 @@ int main(int argc, char *argv[]) { bool compress = true; + int32 input_compress_format = 0, feats_compress_format = 0; int32 num_targets = -1, left_context = 0, right_context = 0, - num_frames = 1, length_tolerance = 100; + num_frames = 1, length_tolerance = 2; - std::string ivector_rspecifier; + std::string ivector_rspecifier, deriv_weights_rspecifier, + l2reg_targets_rspecifier; ParseOptions po(usage); po.Register("compress", &compress, "If true, write egs in " "compressed format."); + po.Register("compress-format", &feats_compress_format, "Format for " + "compressing all feats in general"); + po.Register("input-compress-format", &input_compress_format, "Format for " + "compressing input feats e.g. 
Use 2 for compressing wave"); po.Register("num-targets", &num_targets, "Number of targets for the neural network"); po.Register("left-context", &left_context, "Number of frames of left " "context the neural net requires."); @@ -174,6 +227,13 @@ int main(int argc, char *argv[]) { "features, as matrix."); po.Register("length-tolerance", &length_tolerance, "Tolerance for " "difference in num-frames between feat and ivector matrices"); + po.Register("deriv-weights-rspecifier", &deriv_weights_rspecifier, + "Per-frame weights (only binary - 0 or 1) that specifies " + "whether a frame's gradient must be backpropagated or not. " + "Not specifying this is equivalent to specifying a vector of " + "all 1s."); + po.Register("l2reg-targets-rspecifier", &l2reg_targets_rspecifier, + "Add l2 regularizer targets"); po.Read(argc, argv); @@ -194,6 +254,8 @@ int main(int argc, char *argv[]) { RandomAccessBaseFloatMatrixReader matrix_reader(matrix_rspecifier); NnetExampleWriter example_writer(examples_wspecifier); RandomAccessBaseFloatMatrixReader ivector_reader(ivector_rspecifier); + RandomAccessBaseFloatVectorReader deriv_weights_reader(deriv_weights_rspecifier); + RandomAccessBaseFloatMatrixReader l2reg_targets_reader(l2reg_targets_rspecifier); int32 num_done = 0, num_err = 0; int64 num_frames_written = 0, num_egs_written = 0; @@ -206,10 +268,10 @@ int main(int argc, char *argv[]) { num_err++; } else { const Matrix &target_matrix = matrix_reader.Value(key); - if (target_matrix.NumRows() != feats.NumRows()) { - KALDI_WARN << "Target matrix has wrong size " - << target_matrix.NumRows() - << " versus " << feats.NumRows(); + if ((target_matrix.NumRows() - feats.NumRows()) > length_tolerance) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and target matrix " << target_matrix.NumRows() + << "exceeds tolerance " << length_tolerance; num_err++; continue; } @@ -226,7 +288,7 @@ int main(int argc, char *argv[]) { } } - if (ivector_feats != NULL && + if (ivector_feats && (abs(feats.NumRows() - ivector_feats->NumRows()) > length_tolerance || ivector_feats->NumRows() == 0)) { KALDI_WARN << "Length difference between feats " << feats.NumRows() @@ -235,8 +297,56 @@ int main(int argc, char *argv[]) { num_err++; continue; } - - ProcessFile(feats, ivector_feats, target_matrix, key, compress, + + const Vector *deriv_weights = NULL; + if (!deriv_weights_rspecifier.empty()) { + if (!deriv_weights_reader.HasKey(key)) { + KALDI_WARN << "No deriv weights for utterance " << key; + num_err++; + continue; + } else { + // this address will be valid until we call HasKey() or Value() + // again. + deriv_weights = &(deriv_weights_reader.Value(key)); + } + } + + if (deriv_weights && + (abs(feats.NumRows() - deriv_weights->Dim()) > length_tolerance + || deriv_weights->Dim() == 0)) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and deriv weights " << deriv_weights->Dim() + << " exceeds tolerance " << length_tolerance; + num_err++; + continue; + } + + const Matrix *l2reg_target_matrix = NULL; + if (!l2reg_targets_rspecifier.empty()) { + if (!l2reg_targets_reader.HasKey(key)) { + KALDI_WARN << "No l2 regularizer targets for utterance " << key; + num_err++; + continue; + } + { + // this address will be valid until we call HasKey() or Value() + // again. 
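The deriv-weights handling added to ProcessFile() above follows one pattern in every branch: copy the slice of the per-utterance weight vector that overlaps the current chunk, leave the padded tail of the chunk at zero, and skip the chunk if the copied slice sums to zero. A minimal standalone sketch of that pattern, using std::vector in place of kaldi's Vector and Range():

#include <algorithm>
#include <numeric>
#include <vector>

// Returns the fixed-length per-eg weight vector for the chunk starting at
// frame t; an empty result means every weight in the chunk is zero and this
// output should be skipped for the chunk.
std::vector<float> ChunkDerivWeights(const std::vector<float> &utt_weights,
                                     int t, int frames_per_eg,
                                     int actual_frames_per_eg) {
  std::vector<float> w(frames_per_eg, 0.0f);   // padded tail stays at zero
  int frames_to_copy =
      std::min<int>(t + actual_frames_per_eg,
                    static_cast<int>(utt_weights.size())) - t;
  for (int i = 0; i < frames_to_copy; i++)
    w[i] = utt_weights[t + i];
  if (std::accumulate(w.begin(), w.end(), 0.0f) == 0.0f)
    w.clear();
  return w;
}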
+ l2reg_target_matrix = &(l2reg_targets_reader.Value(key)); + + if (l2reg_target_matrix->NumRows() != feats.NumRows()) { + KALDI_WARN << "l2 regularizer target matrix has wrong size " + << l2reg_target_matrix->NumRows() + << " versus " << feats.NumRows(); + num_err++; + continue; + } + } + } + + + ProcessFile(feats, ivector_feats, deriv_weights, + l2reg_target_matrix, target_matrix, + key, compress, input_compress_format, feats_compress_format, num_targets, left_context, right_context, num_frames, &num_frames_written, &num_egs_written, &example_writer); diff --git a/src/nnet3bin/nnet3-get-egs-multiple-targets.cc b/src/nnet3bin/nnet3-get-egs-multiple-targets.cc new file mode 100644 index 00000000000..49f0dde4af7 --- /dev/null +++ b/src/nnet3bin/nnet3-get-egs-multiple-targets.cc @@ -0,0 +1,538 @@ +// nnet3bin/nnet3-get-egs-multiple-targets.cc + +// Copyright 2012-2015 Johns Hopkins University (author: Daniel Povey) +// 2014-2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "hmm/transition-model.h" +#include "hmm/posterior.h" +#include "nnet3/nnet-example.h" +#include "nnet3/nnet-example-utils.h" + +namespace kaldi { +namespace nnet3 { + +bool ToBool(std::string str) { + std::transform(str.begin(), str.end(), str.begin(), ::tolower); + + if ((str.compare("true") == 0) || (str.compare("t") == 0) + || (str.compare("1") == 0)) + return true; + if ((str.compare("false") == 0) || (str.compare("f") == 0) + || (str.compare("0") == 0)) + return false; + KALDI_ERR << "Invalid format for boolean argument [expected true or false]: " + << str; + return false; // never reached +} + +static void ProcessFile(const MatrixBase &feats, + const MatrixBase *ivector_feats, + const std::vector &output_names, + const std::vector &output_dims, + const std::vector* > &dense_target_matrices, + const std::vector &posteriors, + const std::vector* > &deriv_weights, + const std::string &utt_id, + bool compress_input, + int32 input_compress_format, + const std::vector &compress_targets, + const std::vector &targets_compress_formats, + int32 left_context, + int32 right_context, + int32 frames_per_eg, + std::vector *num_frames_written, + std::vector *num_egs_written, + NnetExampleWriter *example_writer) { + KALDI_ASSERT(output_names.size() > 0); + //KALDI_ASSERT(feats.NumRows() == static_cast(targets.NumRows())); + for (int32 t = 0; t < feats.NumRows(); t += frames_per_eg) { + + int32 tot_frames = left_context + frames_per_eg + right_context; + + Matrix input_frames(tot_frames, feats.NumCols(), kUndefined); + + // Set up "input_frames". 
+ for (int32 j = -left_context; j < frames_per_eg + right_context; j++) { + int32 t2 = j + t; + if (t2 < 0) t2 = 0; + if (t2 >= feats.NumRows()) t2 = feats.NumRows() - 1; + SubVector src(feats, t2), + dest(input_frames, j + left_context); + dest.CopyFromVec(src); + } + + NnetExample eg; + + // call the regular input "input". + eg.io.push_back(NnetIo("input", - left_context, + input_frames)); + + if (compress_input) + eg.io.back().Compress(input_compress_format); + + // if applicable, add the iVector feature. + if (ivector_feats) { + int32 actual_frames_per_eg = std::min(frames_per_eg, + feats.NumRows() - t); + // try to get closest frame to middle of window to get + // a representative iVector. + int32 closest_frame = t + (actual_frames_per_eg / 2); + KALDI_ASSERT(ivector_feats->NumRows() > 0); + if (closest_frame >= ivector_feats->NumRows()) + closest_frame = ivector_feats->NumRows() - 1; + Matrix ivector(1, ivector_feats->NumCols()); + ivector.Row(0).CopyFromVec(ivector_feats->Row(closest_frame)); + eg.io.push_back(NnetIo("ivector", 0, ivector)); + } + + int32 num_outputs_added = 0; + + for (int32 n = 0; n < output_names.size(); n++) { + Vector this_deriv_weights(0); + if (deriv_weights[n]) { + // actual_frames_per_eg is the number of frames with actual targets. + // At the end of the file, we pad with the last frame repeated + // so that all examples have the same structure (prevents the need + // for recompilations). + int32 actual_frames_per_eg = std::min(std::min(frames_per_eg, + feats.NumRows() - t), deriv_weights[n]->Dim() - t); + + this_deriv_weights.Resize(frames_per_eg); + int32 frames_to_copy = std::min(t + actual_frames_per_eg, + deriv_weights[n]->Dim()) - t; + this_deriv_weights.Range(0, frames_to_copy).CopyFromVec(deriv_weights[n]->Range(t, frames_to_copy)); + if (this_deriv_weights.Sum() == 0) { + continue; // Ignore frames that have frame weights 0 + } + } + + if (dense_target_matrices[n]) { + const MatrixBase &targets = *dense_target_matrices[n]; + Matrix targets_dest(frames_per_eg, targets.NumCols()); + + // actual_frames_per_eg is the number of frames with actual targets. + // At the end of the file, we pad with the last frame repeated + // so that all examples have the same structure (prevents the need + // for recompilations). + int32 actual_frames_per_eg = std::min(std::min(frames_per_eg, + feats.NumRows() - t), targets.NumRows() - t); + + for (int32 i = 0; i < actual_frames_per_eg; i++) { + // Copy the i^th row of the target matrix from the (t+i)^th row of the + // input targets matrix + SubVector this_target_dest(targets_dest, i); + SubVector this_target_src(targets, t+i); + this_target_dest.CopyFromVec(this_target_src); + } + + // Copy the last frame's target to the padded frames + for (int32 i = actual_frames_per_eg; i < frames_per_eg; i++) { + // Copy the i^th row of the target matrix from the last row of the + // input targets matrix + KALDI_ASSERT(t + actual_frames_per_eg - 1 == targets.NumRows() - 1); + SubVector this_target_dest(targets_dest, i); + SubVector this_target_src(targets, t+actual_frames_per_eg-1); + this_target_dest.CopyFromVec(this_target_src); + } + + if (deriv_weights[n]) { + eg.io.push_back(NnetIo(output_names[n], this_deriv_weights, 0, targets_dest)); + } else { + eg.io.push_back(NnetIo(output_names[n], 0, targets_dest)); + } + } else if (posteriors[n]) { + const Posterior &pdf_post = *(posteriors[n]); + + // actual_frames_per_eg is the number of frames with actual targets. 
+ // At the end of the file, we pad with the last frame repeated + // so that all examples have the same structure (prevents the need + // for recompilations). + int32 actual_frames_per_eg = std::min(std::min(frames_per_eg, + feats.NumRows() - t), static_cast(pdf_post.size()) - t); + + Posterior labels(frames_per_eg); + for (int32 i = 0; i < actual_frames_per_eg; i++) + labels[i] = pdf_post[t + i]; + // remaining posteriors for frames are empty. + + if (deriv_weights[n]) { + eg.io.push_back(NnetIo(output_names[n], this_deriv_weights, output_dims[n], 0, labels)); + } else { + eg.io.push_back(NnetIo(output_names[n], output_dims[n], 0, labels)); + } + } else + continue; + if (compress_targets[n]) + eg.io.back().Compress(targets_compress_formats[n]); + + num_outputs_added++; + (*num_frames_written)[n] += frames_per_eg; // Actually actual_frames_per_eg, but that depends on the different output. For simplification, frames_per_eg is used. + (*num_egs_written)[n] += 1; + } + + if (num_outputs_added == 0) continue; + + std::ostringstream os; + os << utt_id << "-" << t; + + std::string key = os.str(); // key is - + + KALDI_ASSERT(NumOutputs(eg) == num_outputs_added); + + example_writer->Write(key, eg); + } +} + + +} // namespace nnet2 +} // namespace kaldi + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace kaldi::nnet3; + typedef kaldi::int32 int32; + typedef kaldi::int64 int64; + + const char *usage = + "Get frame-by-frame examples of data for nnet3 neural network training.\n" + "This program is similar to nnet3-get-egs, but the targets here are " + "dense matrices instead of posteriors (sparse matrices).\n" + "This is useful when you want the targets to be continuous real-valued " + "with the neural network possibly trained with a quadratic objective\n" + "\n" + "Usage: nnet3-get-egs-multiple-targets [options] " + " ::[:] " + "[ :: ... ] \n" + "\n" + "Here is any random string for output node name, \n" + " is the rspecifier for either dense targets in matrix format or sparse targets in posterior format,\n" + "and is the target dimension of output node for sparse targets or -1 for dense targets\n" + "\n" + "An example [where $feats expands to the actual features]:\n" + "nnet-get-egs-multiple-targets --left-context=12 \\\n" + "--right-context=9 --num-frames=8 \"$feats\" \\\n" + "output-snr:\"ark:copy-matrix ark:exp/snrs/snr.1.ark ark:- |\":-1 \n" + " ark:- \n"; + + + bool compress_input = true; + int32 input_compress_format = 0; + int32 left_context = 0, right_context = 0, + num_frames = 1, length_tolerance = 2; + + std::string ivector_rspecifier, + targets_compress_formats_str, + compress_targets_str; + std::string output_dims_str; + std::string output_names_str; + + ParseOptions po(usage); + po.Register("compress-input", &compress_input, "If true, write egs in " + "compressed format."); + po.Register("input-compress-format", &input_compress_format, "Format for " + "compressing input feats e.g. 
Use 2 for compressing wave"); + po.Register("compress-targets", &compress_targets_str, "CSL of whether " + "targets must be compressed for each of the outputs"); + po.Register("targets-compress-formats", &targets_compress_formats_str, "Format for " + "compressing all feats in general"); + po.Register("left-context", &left_context, "Number of frames of left " + "context the neural net requires."); + po.Register("right-context", &right_context, "Number of frames of right " + "context the neural net requires."); + po.Register("num-frames", &num_frames, "Number of frames with labels " + "that each example contains."); + po.Register("ivectors", &ivector_rspecifier, "Rspecifier of ivector " + "features, as matrix."); + po.Register("length-tolerance", &length_tolerance, "Tolerance for " + "difference in num-frames between feat and ivector matrices"); + po.Register("output-dims", &output_dims_str, "CSL of output node dims"); + po.Register("output-names", &output_names_str, "CSL of output node names"); + //po.Register("deriv-weights-rspecifiers", &deriv_weights_rspecifiers_str, + // "CSL of per-frame weights (only binary - 0 or 1) that specifies " + // "whether a frame's gradient must be backpropagated or not. " + // "Not specifying this is equivalent to specifying a vector of " + // "all 1s."); + + po.Read(argc, argv); + + if (po.NumArgs() < 3) { + po.PrintUsage(); + exit(1); + } + + std::string feature_rspecifier = po.GetArg(1), + examples_wspecifier = po.GetArg(po.NumArgs()); + + // Read in all the training files. + SequentialBaseFloatMatrixReader feat_reader(feature_rspecifier); + RandomAccessBaseFloatMatrixReader ivector_reader(ivector_rspecifier); + NnetExampleWriter example_writer(examples_wspecifier); + + int32 num_outputs = (po.NumArgs() - 2) / 2; + KALDI_ASSERT(num_outputs > 0); + + std::vector deriv_weights_readers(num_outputs, + static_cast(NULL)); + std::vector dense_targets_readers(num_outputs, + static_cast(NULL)); + std::vector sparse_targets_readers(num_outputs, + static_cast(NULL)); + + + std::vector compress_targets(1, true); + std::vector compress_targets_vector; + + if (!compress_targets_str.empty()) { + SplitStringToVector(compress_targets_str, ":,", + true, &compress_targets_vector); + } + + if (compress_targets_vector.size() == 1 && num_outputs != 1) { + KALDI_WARN << "compress-targets is of size 1. " + << "Extending it to size num-outputs=" << num_outputs; + compress_targets[0] = ToBool(compress_targets_vector[0]); + compress_targets.resize(num_outputs, ToBool(compress_targets_vector[0])); + } else { + if (compress_targets_vector.size() != num_outputs) { + KALDI_ERR << "Mismatch in length of compress-targets and num-outputs; " + << compress_targets_vector.size() << " vs " << num_outputs; + } + for (int32 n = 0; n < num_outputs; n++) { + compress_targets[n] = ToBool(compress_targets_vector[n]); + } + } + + std::vector targets_compress_formats(1, 1); + if (!targets_compress_formats_str.empty()) { + SplitStringToIntegers(targets_compress_formats_str, ":,", + true, &targets_compress_formats); + } + + if (targets_compress_formats.size() == 1 && num_outputs != 1) { + KALDI_WARN << "targets-compress-formats is of size 1. 
" + << "Extending it to size num-outputs=" << num_outputs; + targets_compress_formats.resize(num_outputs, targets_compress_formats[0]); + } + + if (targets_compress_formats.size() != num_outputs) { + KALDI_ERR << "Mismatch in length of targets-compress-formats and num-outputs; " + << targets_compress_formats.size() << " vs " << num_outputs; + } + + std::vector output_dims(num_outputs); + SplitStringToIntegers(output_dims_str, ":,", + true, &output_dims); + + std::vector output_names(num_outputs); + SplitStringToVector(output_names_str, ":,", true, &output_names); + + //std::vector deriv_weights_rspecifiers; + //if (!deriv_weights_rspecifiers_str.empty()) { + // std::vector parts; + // SplitStringToVector(deriv_weights_rspecifiers_str, ":,", + // false, &deriv_weights_rspecifiers); + + // if (deriv_weights_rspecifiers.size() != num_outputs) { + // KALDI_ERR << "Expecting the number of deriv-weights-rspecifiers to " + // << "be equal to the number of outputs"; + // } + //} + + std::vector targets_rspecifiers(num_outputs); + std::vector deriv_weights_rspecifiers(num_outputs); + + for (int32 n = 0; n < num_outputs; n++) { + const std::string &targets_rspecifier = po.GetArg(2*n + 2); + const std::string &deriv_weights_rspecifier = po.GetArg(2*n + 3); + + targets_rspecifiers[n] = targets_rspecifier; + deriv_weights_rspecifiers[n] = deriv_weights_rspecifier; + + if (output_dims[n] >= 0) { + sparse_targets_readers[n] = new RandomAccessPosteriorReader(targets_rspecifier); + } else { + dense_targets_readers[n] = new RandomAccessBaseFloatMatrixReader(targets_rspecifier); + } + + if (!deriv_weights_rspecifier.empty()) + deriv_weights_readers[n] = new RandomAccessBaseFloatVectorReader(deriv_weights_rspecifier); + + KALDI_LOG << "output-name=" << output_names[n] + << " target-dim=" << output_dims[n] + << " targets-rspecifier=\"" << targets_rspecifiers[n] << "\"" + << " deriv-weights-rspecifier=\"" << deriv_weights_rspecifiers[n] << "\"" + << " compress-target=" << (compress_targets[n] ? "true" : "false") + << " target-compress-format=" << targets_compress_formats[n]; + } + + int32 num_done = 0, num_err = 0; + + std::vector num_frames_written(num_outputs, 0); + std::vector num_egs_written(num_outputs, 0); + + for (; !feat_reader.Done(); feat_reader.Next()) { + std::string key = feat_reader.Key(); + const Matrix &feats = feat_reader.Value(); + + const Matrix *ivector_feats = NULL; + if (!ivector_rspecifier.empty()) { + if (!ivector_reader.HasKey(key)) { + KALDI_WARN << "No iVectors for utterance " << key; + num_err++; + continue; + } else { + // this address will be valid until we call HasKey() or Value() + // again. 
+ ivector_feats = &(ivector_reader.Value(key)); + } + } + + if (ivector_feats && + (abs(feats.NumRows() - ivector_feats->NumRows()) > length_tolerance + || ivector_feats->NumRows() == 0)) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and iVectors " << ivector_feats->NumRows() + << "exceeds tolerance " << length_tolerance; + num_err++; + continue; + } + + std::vector* > dense_targets(num_outputs, static_cast* >(NULL)); + std::vector sparse_targets(num_outputs, static_cast(NULL)); + std::vector* > deriv_weights(num_outputs, static_cast* >(NULL)); + + int32 num_outputs_found = 0; + for (int32 n = 0; n < num_outputs; n++) { + if (dense_targets_readers[n]) { + if (!dense_targets_readers[n]->HasKey(key)) { + KALDI_WARN << "No dense targets matrix for key " << key << " in " + << "rspecifier " << targets_rspecifiers[n] + << " for output " << output_names[n]; + continue; + } + const MatrixBase *target_matrix = &(dense_targets_readers[n]->Value(key)); + + if ((target_matrix->NumRows() - feats.NumRows()) > length_tolerance) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and target matrix " << target_matrix->NumRows() + << "exceeds tolerance " << length_tolerance; + num_err++; + continue; + } + + dense_targets[n] = target_matrix; + } else { + if (!sparse_targets_readers[n]->HasKey(key)) { + KALDI_WARN << "No sparse target matrix for key " << key << " in " + << "rspecifier " << targets_rspecifiers[n] + << " for output " << output_names[n]; + continue; + } + const Posterior *posterior = &(sparse_targets_readers[n]->Value(key)); + + if (abs(static_cast(posterior->size()) - feats.NumRows()) > length_tolerance + || posterior->size() < feats.NumRows()) { + KALDI_WARN << "Posterior has wrong size " << posterior->size() + << " versus " << feats.NumRows(); + num_err++; + continue; + } + + sparse_targets[n] = posterior; + } + + if (deriv_weights_readers[n]) { + if (!deriv_weights_readers[n]->HasKey(key)) { + KALDI_WARN << "No deriv weights for key " << key << " in " + << "rspecifier " << deriv_weights_rspecifiers[n] + << " for output " << output_names[n]; + num_err++; + sparse_targets[n] = NULL; + dense_targets[n] = NULL; + continue; + } else { + // this address will be valid until we call HasKey() or Value() + // again. 
+ deriv_weights[n] = &(deriv_weights_readers[n]->Value(key)); + } + } + + if (deriv_weights[n] && + (abs(feats.NumRows() - deriv_weights[n]->Dim()) > length_tolerance + || deriv_weights[n]->Dim() == 0)) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and deriv weights " << deriv_weights[n]->Dim() + << " exceeds tolerance " << length_tolerance; + num_err++; + sparse_targets[n] = NULL; + dense_targets[n] = NULL; + deriv_weights[n] = NULL; + continue; + } + + num_outputs_found++; + } + + if (num_outputs_found == 0) { + KALDI_WARN << "No output found for key " << key; + num_err++; + continue; + } + + ProcessFile(feats, ivector_feats, output_names, output_dims, + dense_targets, sparse_targets, + deriv_weights, key, + compress_input, input_compress_format, + compress_targets, targets_compress_formats, + left_context, right_context, num_frames, + &num_frames_written, &num_egs_written, + &example_writer); + num_done++; + } + + int64 max_num_egs_written = 0, max_num_frames_written = 0; + for (int32 n = 0; n < num_outputs; n++) { + delete dense_targets_readers[n]; + delete sparse_targets_readers[n]; + delete deriv_weights_readers[n]; + if (num_egs_written[n] == 0) return false; + if (num_egs_written[n] > max_num_egs_written) { + max_num_egs_written = num_egs_written[n]; + max_num_frames_written = num_frames_written[n]; + } + } + + KALDI_LOG << "Finished generating examples, " + << "successfully processed " << num_done + << " feature files, wrote at most " << max_num_egs_written << " examples, " + << " with at most " << max_num_frames_written << " egs in total; " + << num_err << " files had errors."; + + return (num_err > num_done ? 1 : 0); + } catch(const std::exception &e) { + std::cerr << e.what() << '\n'; + return -1; + } +} + + diff --git a/src/nnet3bin/nnet3-get-egs.cc b/src/nnet3bin/nnet3-get-egs.cc index 75f264f1ceb..dbf8b636305 100644 --- a/src/nnet3bin/nnet3-get-egs.cc +++ b/src/nnet3bin/nnet3-get-egs.cc @@ -32,9 +32,12 @@ namespace nnet3 { static void ProcessFile(const MatrixBase &feats, const MatrixBase *ivector_feats, + const VectorBase *deriv_weights, const Posterior &pdf_post, const std::string &utt_id, bool compress, + int32 input_compress_format, + int32 feats_compress_format, int32 num_pdfs, int32 left_context, int32 right_context, @@ -42,16 +45,16 @@ static void ProcessFile(const MatrixBase &feats, int64 *num_frames_written, int64 *num_egs_written, NnetExampleWriter *example_writer) { - KALDI_ASSERT(feats.NumRows() == static_cast(pdf_post.size())); - - for (int32 t = 0; t < feats.NumRows(); t += frames_per_eg) { + //KALDI_ASSERT(feats.NumRows() == static_cast(pdf_post.size())); + int32 min_size = std::min(feats.NumRows(), static_cast(pdf_post.size())); + for (int32 t = 0; t < min_size; t += frames_per_eg) { // actual_frames_per_eg is the number of frames with nonzero // posteriors. At the end of the file we pad with zero posteriors // so that all examples have the same structure (prevents the need // for recompilations). 
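Both get-egs binaries step through the utterance in frames_per_eg chunks, and only the final chunk can have fewer real frames than frames_per_eg; the remainder is padded (with repeated targets or empty posteriors) so every eg has the same structure. A short standalone illustration of the chunk arithmetic:

#include <algorithm>
#include <cstdio>

// Print the chunks ProcessFile() would produce for an utterance of
// num_frames frames (num_frames here plays the role of min_size above).
void ListChunks(int num_frames, int frames_per_eg) {
  for (int t = 0; t < num_frames; t += frames_per_eg) {
    int actual_frames_per_eg = std::min(frames_per_eg, num_frames - t);
    std::printf("eg at t=%d: %d real frames, %d padded\n",
                t, actual_frames_per_eg, frames_per_eg - actual_frames_per_eg);
  }
}
// ListChunks(25, 8) -> chunks at t = 0, 8, 16, 24 with 8, 8, 8 and 1 real frames.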
int32 actual_frames_per_eg = std::min(frames_per_eg, - feats.NumRows() - t); + min_size - t); int32 tot_frames = left_context + frames_per_eg + right_context; @@ -62,7 +65,7 @@ static void ProcessFile(const MatrixBase &feats, for (int32 j = -left_context; j < frames_per_eg + right_context; j++) { int32 t2 = j + t; if (t2 < 0) t2 = 0; - if (t2 >= feats.NumRows()) t2 = feats.NumRows() - 1; + if (t2 >= min_size) t2 = min_size - 1; SubVector src(feats, t2), dest(input_frames, j + left_context); dest.CopyFromVec(src); @@ -73,9 +76,12 @@ static void ProcessFile(const MatrixBase &feats, // call the regular input "input". eg.io.push_back(NnetIo("input", - left_context, input_frames)); + + if (compress) + eg.io.back().Compress(input_compress_format); // if applicable, add the iVector feature. - if (ivector_feats != NULL) { + if (ivector_feats) { // try to get closest frame to middle of window to get // a representative iVector. int32 closest_frame = t + (actual_frames_per_eg / 2); @@ -92,10 +98,20 @@ static void ProcessFile(const MatrixBase &feats, for (int32 i = 0; i < actual_frames_per_eg; i++) labels[i] = pdf_post[t + i]; // remaining posteriors for frames are empty. - eg.io.push_back(NnetIo("output", num_pdfs, 0, labels)); + + if (!deriv_weights) { + eg.io.push_back(NnetIo("output", num_pdfs, 0, labels)); + } else { + Vector this_deriv_weights(frames_per_eg); + int32 frames_to_copy = std::min(t + actual_frames_per_eg, deriv_weights->Dim()) - t; + this_deriv_weights.Range(0, frames_to_copy).CopyFromVec(deriv_weights->Range(t, frames_to_copy)); + if (this_deriv_weights.Sum() == 0) continue; // Ignore frames that have frame weights 0 + eg.io.push_back(NnetIo("output", this_deriv_weights, num_pdfs, 0, labels)); + } + if (compress) - eg.Compress(); + eg.Compress(feats_compress_format); std::ostringstream os; os << utt_id << "-" << t; @@ -140,14 +156,19 @@ int main(int argc, char *argv[]) { bool compress = true; + int32 input_compress_format = 0, feats_compress_format = 0; int32 num_pdfs = -1, left_context = 0, right_context = 0, num_frames = 1, length_tolerance = 100; - std::string ivector_rspecifier; + std::string ivector_rspecifier, deriv_weights_rspecifier; ParseOptions po(usage); po.Register("compress", &compress, "If true, write egs in " "compressed format."); + po.Register("compress-format", &feats_compress_format, "Format for " + "compressing all feats in general"); + po.Register("input-compress-format", &input_compress_format, "Format for " + "compressing input feats e.g. Use 2 for compressing wave"); po.Register("num-pdfs", &num_pdfs, "Number of pdfs in the acoustic " "model"); po.Register("left-context", &left_context, "Number of frames of left " @@ -160,6 +181,11 @@ int main(int argc, char *argv[]) { "features, as a matrix."); po.Register("length-tolerance", &length_tolerance, "Tolerance for " "difference in num-frames between feat and ivector matrices"); + po.Register("deriv-weights-rspecifier", &deriv_weights_rspecifier, + "Per-frame weights (only binary - 0 or 1) that specifies " + "whether a frame's gradient must be backpropagated or not. 
" + "Not specifying this is equivalent to specifying a vector of " + "all 1s."); po.Read(argc, argv); @@ -181,6 +207,7 @@ int main(int argc, char *argv[]) { RandomAccessPosteriorReader pdf_post_reader(pdf_post_rspecifier); NnetExampleWriter example_writer(examples_wspecifier); RandomAccessBaseFloatMatrixReader ivector_reader(ivector_rspecifier); + RandomAccessBaseFloatVectorReader deriv_weights_reader(deriv_weights_rspecifier); int32 num_done = 0, num_err = 0; int64 num_frames_written = 0, num_egs_written = 0; @@ -192,13 +219,17 @@ int main(int argc, char *argv[]) { KALDI_WARN << "No pdf-level posterior for key " << key; num_err++; } else { - const Posterior &pdf_post = pdf_post_reader.Value(key); - if (pdf_post.size() != feats.NumRows()) { + Posterior pdf_post = pdf_post_reader.Value(key); + if (abs(static_cast(pdf_post.size()) - feats.NumRows()) > length_tolerance + || pdf_post.size() < feats.NumRows()) { KALDI_WARN << "Posterior has wrong size " << pdf_post.size() << " versus " << feats.NumRows(); num_err++; continue; } + while (static_cast(pdf_post.size()) > feats.NumRows()) { + pdf_post.pop_back(); + } const Matrix *ivector_feats = NULL; if (!ivector_rspecifier.empty()) { if (!ivector_reader.HasKey(key)) { @@ -212,7 +243,7 @@ int main(int argc, char *argv[]) { } } - if (ivector_feats != NULL && + if (ivector_feats && (abs(feats.NumRows() - ivector_feats->NumRows()) > length_tolerance || ivector_feats->NumRows() == 0)) { KALDI_WARN << "Length difference between feats " << feats.NumRows() @@ -221,8 +252,33 @@ int main(int argc, char *argv[]) { num_err++; continue; } + + const Vector *deriv_weights = NULL; + if (!deriv_weights_rspecifier.empty()) { + if (!deriv_weights_reader.HasKey(key)) { + KALDI_WARN << "No deriv weights for utterance " << key; + num_err++; + continue; + } else { + // this address will be valid until we call HasKey() or Value() + // again. + deriv_weights = &(deriv_weights_reader.Value(key)); + } + } + + if (deriv_weights && + (abs(feats.NumRows() - deriv_weights->Dim()) > length_tolerance + || deriv_weights->Dim() == 0)) { + KALDI_WARN << "Length difference between feats " << feats.NumRows() + << " and deriv weights " << deriv_weights->Dim() + << " exceeds tolerance " << length_tolerance; + num_err++; + continue; + } + - ProcessFile(feats, ivector_feats, pdf_post, key, compress, + ProcessFile(feats, ivector_feats, deriv_weights, pdf_post, + key, compress, input_compress_format, feats_compress_format, num_pdfs, left_context, right_context, num_frames, &num_frames_written, &num_egs_written, &example_writer); diff --git a/src/nnet3bin/nnet3-latgen-faster.cc b/src/nnet3bin/nnet3-latgen-faster.cc index 5a090acb5b5..e0f21e723e7 100644 --- a/src/nnet3bin/nnet3-latgen-faster.cc +++ b/src/nnet3bin/nnet3-latgen-faster.cc @@ -65,6 +65,8 @@ int main(int argc, char *argv[]) { po.Register("ivectors", &ivector_rspecifier, "Rspecifier for " "iVectors as vectors (i.e. not estimated online); per utterance " "by default, or per speaker if you provide the --utt2spk option."); + po.Register("utt2spk", &utt2spk_rspecifier, "Rspecifier for " + "utt2spk option used to get ivectors per speaker"); po.Register("online-ivectors", &online_ivector_rspecifier, "Rspecifier for " "iVectors estimated online, as matrices. 
If you supply this," " you must set the --online-ivector-period option."); diff --git a/src/nnet3bin/nnet3-merge-egs.cc b/src/nnet3bin/nnet3-merge-egs.cc index 8627671f53a..30096ab9988 100644 --- a/src/nnet3bin/nnet3-merge-egs.cc +++ b/src/nnet3bin/nnet3-merge-egs.cc @@ -26,11 +26,13 @@ namespace kaldi { namespace nnet3 { -// returns the number of indexes/frames in the NnetIo named "output" in the eg, -// or crashes if it is not there. +// returns the number of indexes/frames in the output NnetIo +// assumes the output name starts with "output" and only looks at the +// first such output to get the indexes size. +// crashes if it there is no such output int32 NumOutputIndexes(const NnetExample &eg) { for (size_t i = 0; i < eg.io.size(); i++) - if (eg.io[i].name == "output") + if (eg.io[i].name.find("output") != std::string::npos) return eg.io[i].indexes.size(); KALDI_ERR << "No output named 'output' in the eg."; return 0; // Suppress compiler warning. diff --git a/src/nnet3bin/nnet3-show-progress.cc b/src/nnet3bin/nnet3-show-progress.cc index 10898dc0ca6..785d3d0aa88 100644 --- a/src/nnet3bin/nnet3-show-progress.cc +++ b/src/nnet3bin/nnet3-show-progress.cc @@ -107,17 +107,39 @@ int main(int argc, char *argv[]) { eg_end = examples.end(); for (; eg_iter != eg_end; ++eg_iter) prob_computer.Compute(*eg_iter); - const SimpleObjectiveInfo *objf_info = prob_computer.GetObjective("output"); - double objf_per_frame = objf_info->tot_objective / objf_info->tot_weight; + + double tot_weight = 0.0; + + { + const unordered_map &objf_info = prob_computer.GetAllObjectiveInfo(); + + unordered_map::const_iterator objf_it = objf_info.begin(), + objf_end = objf_info.end(); + + + for (; objf_it != objf_end; ++objf_it) { + double objf_per_frame = objf_it->second.tot_objective / objf_it->second.tot_weight; + + if (objf_it->first == "output") { + KALDI_LOG << "At position " << middle + << ", objf per frame is " << objf_per_frame; + } else { + KALDI_LOG << "At position " << middle + << ", objf per frame for '" << objf_it->first + << "' is " << objf_per_frame; + } + + tot_weight += objf_it->second.tot_weight; + } + } + const Nnet &nnet_gradient = prob_computer.GetDeriv(); - KALDI_LOG << "At position " << middle - << ", objf per frame is " << objf_per_frame; Vector old_dotprod(num_updatable), new_dotprod(num_updatable); ComponentDotProducts(nnet_gradient, nnet1, &old_dotprod); ComponentDotProducts(nnet_gradient, nnet2, &new_dotprod); - old_dotprod.Scale(1.0 / objf_info->tot_weight); - new_dotprod.Scale(1.0 / objf_info->tot_weight); + old_dotprod.Scale(1.0 / tot_weight); + new_dotprod.Scale(1.0 / tot_weight); diff.AddVec(1.0/ num_segments, new_dotprod); diff.AddVec(-1.0 / num_segments, old_dotprod); KALDI_VLOG(1) << "By segment " << s << ", objf change is " diff --git a/src/online2bin/ivector-extract-online2.cc b/src/online2bin/ivector-extract-online2.cc index 3251d93b5dd..f597f66763b 100644 --- a/src/online2bin/ivector-extract-online2.cc +++ b/src/online2bin/ivector-extract-online2.cc @@ -55,6 +55,8 @@ int main(int argc, char *argv[]) { g_num_threads = 8; bool repeat = false; + int32 length_tolerance = 0; + std::string frame_weights_rspecifier; po.Register("num-threads", &g_num_threads, "Number of threads to use for computing derived variables " @@ -62,6 +64,12 @@ int main(int argc, char *argv[]) { po.Register("repeat", &repeat, "If true, output the same number of iVectors as input frames " "(including repeated data)."); + po.Register("frame-weights-rspecifier", &frame_weights_rspecifier, + "Archive of frame weights 
to scale stats"); + po.Register("length-tolerance", &length_tolerance, + "Tolerance on the difference in number of frames " + "for feats and weights"); + po.Read(argc, argv); if (po.NumArgs() != 3) { @@ -82,9 +90,9 @@ int main(int argc, char *argv[]) { SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier); + RandomAccessBaseFloatVectorReader frame_weights_reader(frame_weights_rspecifier); BaseFloatMatrixWriter ivector_writer(ivectors_wspecifier); - for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { std::string spk = spk2utt_reader.Key(); const std::vector &uttlist = spk2utt_reader.Value(); @@ -105,6 +113,31 @@ int main(int argc, char *argv[]) { &matrix_feature); ivector_feature.SetAdaptationState(adaptation_state); + + if (!frame_weights_rspecifier.empty()) { + if (!frame_weights_reader.HasKey(utt)) { + KALDI_WARN << "Did not find weights for utterance " << utt; + num_err++; + continue; + } + const Vector &weights = frame_weights_reader.Value(utt); + + if (std::abs(weights.Dim() - feats.NumRows()) > length_tolerance) { + num_err++; + continue; + } + + std::vector > frame_weights; + for (int32 i = 0; i < feats.NumRows(); i++) { + if (i < weights.Dim()) + frame_weights.push_back(std::make_pair(i, weights(i))); + else + frame_weights.push_back(std::make_pair(i, 0.0)); + } + + + ivector_feature.UpdateFrameWeights(frame_weights); + } int32 T = feats.NumRows(), n = (repeat ? 1 : ivector_config.ivector_period), diff --git a/src/segmenter/Makefile b/src/segmenter/Makefile new file mode 100644 index 00000000000..03df6132050 --- /dev/null +++ b/src/segmenter/Makefile @@ -0,0 +1,16 @@ +all: + +include ../kaldi.mk + +TESTFILES = segmentation-io-test + +OBJFILES = segment.o segmentation.o segmentation-utils.o \ + segmentation-post-processor.o + +LIBNAME = kaldi-segmenter + +ADDLIBS = ../gmm/kaldi-gmm.a \ + ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a ../thread/kaldi-thread.a + +include ../makefiles/default_rules.mk + diff --git a/src/segmenter/segment.cc b/src/segmenter/segment.cc new file mode 100644 index 00000000000..b4f485c26bc --- /dev/null +++ b/src/segmenter/segment.cc @@ -0,0 +1,35 @@ +#include "segmenter/segment.h" + +namespace kaldi { +namespace segmenter { + +void Segment::Write(std::ostream &os, bool binary) const { + if (binary) { + os.write(reinterpret_cast(&start_frame), sizeof(start_frame)); + os.write(reinterpret_cast(&end_frame), sizeof(start_frame)); + os.write(reinterpret_cast(&class_id), sizeof(class_id)); + } else { + WriteBasicType(os, binary, start_frame); + WriteBasicType(os, binary, end_frame); + WriteBasicType(os, binary, Label()); + } +} + +void Segment::Read(std::istream &is, bool binary) { + if (binary) { + is.read(reinterpret_cast(&start_frame), sizeof(start_frame)); + is.read(reinterpret_cast(&end_frame), sizeof(end_frame)); + is.read(reinterpret_cast(&class_id), sizeof(class_id)); + } else { + ReadBasicType(is, binary, &start_frame); + ReadBasicType(is, binary, &end_frame); + int32 label; + ReadBasicType(is, binary, &label); + SetLabel(label); + } + + KALDI_ASSERT(end_frame >= start_frame && start_frame >= 0); +} + +} // end namespace segmenter +} // end namespace kaldi diff --git a/src/segmenter/segment.h b/src/segmenter/segment.h new file mode 100644 index 00000000000..1657affc875 --- /dev/null +++ b/src/segmenter/segment.h @@ -0,0 +1,78 @@ +#ifndef KALDI_SEGMENTER_SEGMENT_H_ +#define KALDI_SEGMENTER_SEGMENT_H_ + +#include "base/kaldi-common.h" +#include 
"matrix/kaldi-matrix.h" + +namespace kaldi { +namespace segmenter { + +/** + * This structure defines a single segment. It consists of the following basic + * properties: + * 1) start_frame : This is the frame index of the first frame in the + * segment. + * 2) end_frame : This is the frame index of the last frame in the segment. + * Note that the end_frame is included in the segment. + * 3) class_id : This is the class corresponding to the segments. For e.g., + * could be 0, 1 or 2 depending on whether the segment is + * silence, speech or noise. In general, it can be any + * integer class label. +**/ + +struct Segment { + int32 start_frame; + int32 end_frame; + int32 class_id; + + // Accessors for labels or class id. This is useful in the future when + // we might change the type of label. + inline int32 Label() const { return class_id; } + inline void SetLabel(int32 label) { class_id = label; } + inline int32 Length() const { return end_frame - start_frame + 1; } + + // This is the default constructor that sets everything to undefined values. + Segment() : start_frame(-1), end_frame(-1), class_id(-1) { } + + // This constructor initializes the segmented with the provided start and end + // frames and the segment label. This is the main constructor. + Segment(int32 start, int32 end, int32 label) : + start_frame(start), end_frame(end), class_id(label) { } + + void Write(std::ostream &os, bool binary) const; + void Read(std::istream &is, bool binary); + + // This is a function that returns the size of the elements in the structure. + // It is used during I/O in binary mode, which checks for the total size + // required to store the segment. + static size_t SizeInBytes() { + return (sizeof(int32) + sizeof(int32) + sizeof(int32)); + } +}; + +/** + * Comparator to order segments based on start frame +**/ + +class SegmentComparator { + public: + bool operator() (const Segment &lhs, const Segment &rhs) const { + return lhs.start_frame < rhs.start_frame; + } +}; + +/** + * Comparator to order segments based on length +**/ + +class SegmentLengthComparator { + public: + bool operator() (const Segment &lhs, const Segment &rhs) const { + return lhs.Length() < rhs.Length(); + } +}; + +} // end namespace segmenter +} // end namespace kaldi + +#endif // KALDI_SEGMENTER_SEGMENT_H_ diff --git a/src/segmenter/segmentation-io-test.cc b/src/segmenter/segmentation-io-test.cc new file mode 100644 index 00000000000..f019a653a4a --- /dev/null +++ b/src/segmenter/segmentation-io-test.cc @@ -0,0 +1,63 @@ +// segmenter/segmentation-io-test.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "segmenter/segmentation.h" + +namespace kaldi { +namespace segmenter { + +void UnitTestSegmentationIo() { + Segmentation seg; + int32 max_length = RandInt(0, 1000), + max_segment_length = max_length / 10, + num_classes = RandInt(0, 3); + + if (max_segment_length == 0) + max_segment_length = 1; + + seg.GenRandomSegmentation(max_length, max_segment_length, num_classes); + + bool binary = ( RandInt(0,1) == 0 ); + std::ostringstream os; + + seg.Write(os, binary); + + Segmentation seg2; + std::istringstream is(os.str()); + seg2.Read(is, binary); + + std::ostringstream os2; + seg2.Write(os2, binary); + + KALDI_ASSERT(os2.str() == os.str()); +} + +} // namespace segmenter +} // namespace kaldi + +int main() { + using namespace kaldi; + using namespace kaldi::segmenter; + + for (int32 i = 0; i < 100; i++) + UnitTestSegmentationIo(); + return 0; +} + + diff --git a/src/segmenter/segmentation-post-processor.cc b/src/segmenter/segmentation-post-processor.cc new file mode 100644 index 00000000000..2c97e31db56 --- /dev/null +++ b/src/segmenter/segmentation-post-processor.cc @@ -0,0 +1,198 @@ +// segmenter/segmentation-post-processor.h + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
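UnitTestSegmentationIo() above checks I/O by serializing, deserializing, and serializing again, then comparing the two byte strings; the same write-read-write idiom works for any Kaldi-style object with Write()/Read() methods. A generic sketch of the idiom:

#include <sstream>
#include <string>

// Returns true if obj survives a Write() -> Read() -> Write() round trip
// unchanged; a mismatch means either Write() or Read() loses information.
template <typename T>
bool RoundTripsExactly(const T &obj, bool binary) {
  std::ostringstream os1;
  obj.Write(os1, binary);
  T copy;
  std::istringstream is(os1.str());
  copy.Read(is, binary);
  std::ostringstream os2;
  copy.Write(os2, binary);
  return os1.str() == os2.str();
}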
+ +#include "segmenter/segmentation-utils.h" +#include "segmenter/segmentation-post-processor.h" + +namespace kaldi { +namespace segmenter { + +static inline bool IsMergingLabelsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (!opts.merge_labels_csl.empty() || opts.merge_dst_label != -1); +} + +static inline bool IsPaddingSegmentsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (opts.pad_label != -1 || opts.pad_length != -1); +} + +static inline bool IsShrinkingSegmentsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (opts.shrink_label != -1 || opts.shrink_length != -1); +} + +static inline bool IsBlendingShortSegmentsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (opts.blend_short_segments_class != -1 || opts.max_blend_length != -1); +} + +static inline bool IsRemovingSegmentsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (!opts.remove_labels_csl.empty()); +} + +static inline bool IsMergingAdjacentSegmentsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (opts.merge_adjacent_segments); +} + +static inline bool IsSplittingSegmentsToBeDone( + const SegmentationPostProcessingOptions &opts) { + return (opts.max_segment_length != -1); +} + + +SegmentationPostProcessor::SegmentationPostProcessor( + const SegmentationPostProcessingOptions &opts) : opts_(opts) { + if (!opts_.remove_labels_csl.empty()) { + if (!SplitStringToIntegers(opts_.remove_labels_csl, ":", + false, &remove_labels_)) { + KALDI_ERR << "Bad value for --remove-labels option: " + << opts_.remove_labels_csl; + } + std::sort(remove_labels_.begin(), remove_labels_.end()); + } + + if (!opts_.merge_labels_csl.empty()) { + if (!SplitStringToIntegers(opts_.merge_labels_csl, ":", + false, &merge_labels_)) { + KALDI_ERR << "Bad value for --merge-labels option: " + << opts_.merge_labels_csl; + } + std::sort(merge_labels_.begin(), merge_labels_.end()); + } + + Check(); +} + +void SegmentationPostProcessor::Check() const { + if (IsPaddingSegmentsToBeDone(opts_) && opts_.pad_label < 0) { + KALDI_ERR << "Invalid value " << opts_.pad_label << " for option " + << "--pad-label. It must be non-negative."; + } + + if (IsPaddingSegmentsToBeDone(opts_) && opts_.pad_length <= 0) { + KALDI_ERR << "Invalid value " << opts_.pad_length << " for option " + << "--pad-length. It must be positive."; + } + + if (IsShrinkingSegmentsToBeDone(opts_) && opts_.shrink_label < 0) { + KALDI_ERR << "Invalid value " << opts_.shrink_label << " for option " + << "--shrink-label. It must be non-negative."; + } + + if (IsShrinkingSegmentsToBeDone(opts_) && opts_.shrink_length <= 0) { + KALDI_ERR << "Invalid value " << opts_.shrink_length << " for option " + << "--shrink-length. It must be positive."; + } + + if (IsBlendingShortSegmentsToBeDone(opts_) && + opts_.blend_short_segments_class < 0) { + KALDI_ERR << "Invalid value " << opts_.blend_short_segments_class + << " for option " << "--blend-short-segments-class. " + << "It must be non-negative."; + } + + if (IsBlendingShortSegmentsToBeDone(opts_) && opts_.max_blend_length <= 0) { + KALDI_ERR << "Invalid value " << opts_.max_blend_length << " for option " + << "--max-blend-length. It must be positive."; + } + + if (IsRemovingSegmentsToBeDone(opts_) && remove_labels_[0] < 0) { + KALDI_ERR << "Invalid value " << opts_.remove_labels_csl + << " for option " << "--remove-labels. 
" + << "The labels must be non-negative."; + } + + if (IsMergingAdjacentSegmentsToBeDone(opts_) && + opts_.max_intersegment_length < 0) { + KALDI_ERR << "Invalid value " << opts_.max_intersegment_length + << " for option " + << "--max-intersegment-length. It must be non-negative."; + } + + if (IsSplittingSegmentsToBeDone(opts_) && opts_.max_segment_length <= 0) { + KALDI_ERR << "Invalid value " << opts_.max_segment_length + << " for option " + << "--max-segment-length. It must be positive."; + } + + if (opts_.post_process_label != -1 && opts_.post_process_label < 0) { + KALDI_ERR << "Invalid value " << opts_.post_process_label << " for option " + << "--post-process-label. It must be non-negative."; + } +} + +bool SegmentationPostProcessor::PostProcess(Segmentation *seg) const { + DoMergingLabels(seg); + DoPaddingSegments(seg); + DoShrinkingSegments(seg); + DoBlendingShortSegments(seg); + DoRemovingSegments(seg); + DoMergingAdjacentSegments(seg); + DoSplittingSegments(seg); + + return true; +} + +void SegmentationPostProcessor::DoMergingLabels(Segmentation *seg) const { + if (!IsMergingLabelsToBeDone(opts_)) return; + MergeLabels(merge_labels_, opts_.merge_dst_label, seg); +} + +void SegmentationPostProcessor::DoPaddingSegments(Segmentation *seg) const { + if (!IsPaddingSegmentsToBeDone(opts_)) return; + PadSegments(opts_.pad_label, opts_.pad_length, seg); +} + +void SegmentationPostProcessor::DoShrinkingSegments(Segmentation *seg) const { + if (!IsShrinkingSegmentsToBeDone(opts_)) return; + ShrinkSegments(opts_.shrink_label, opts_.shrink_length, seg); +} + +void SegmentationPostProcessor::DoBlendingShortSegments( + Segmentation *seg) const { + if (!IsBlendingShortSegmentsToBeDone(opts_)) return; + BlendShortSegmentsWithNeighbors(opts_.blend_short_segments_class, + opts_.max_blend_length, + opts_.max_intersegment_length, seg); +} + +void SegmentationPostProcessor::DoRemovingSegments(Segmentation *seg) const { + if (!IsRemovingSegmentsToBeDone(opts_)) return; + RemoveSegments(remove_labels_, seg); +} + +void SegmentationPostProcessor::DoMergingAdjacentSegments( + Segmentation *seg) const { + if (!IsMergingAdjacentSegmentsToBeDone(opts_)) return; + MergeAdjacentSegments(opts_.max_intersegment_length, seg); +} + +void SegmentationPostProcessor::DoSplittingSegments(Segmentation *seg) const { + if (!IsSplittingSegmentsToBeDone(opts_)) return; + SplitSegments(opts_.max_segment_length, + opts_.max_segment_length / 2, + opts_.overlap_length, + opts_.post_process_label, seg); +} + +} // end namespace segmenter +} // end namespace kaldi diff --git a/src/segmenter/segmentation-post-processor.h b/src/segmenter/segmentation-post-processor.h new file mode 100644 index 00000000000..01a23b93b1b --- /dev/null +++ b/src/segmenter/segmentation-post-processor.h @@ -0,0 +1,168 @@ +// segmenter/segmentation-post-processor.h + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_SEGMENTER_SEGMENTATION_POST_PROCESSOR_H_ +#define KALDI_SEGMENTER_SEGMENTATION_POST_PROCESSOR_H_ + +#include "base/kaldi-common.h" +#include "itf/options-itf.h" +#include "segmenter/segmentation.h" + +namespace kaldi { +namespace segmenter { + +/** + * Structure for some common options related to segmentation that would be used + * in multiple segmentation programs. Some of the operations include merging, + * filtering etc. +**/ + +struct SegmentationPostProcessingOptions { + std::string merge_labels_csl; + int32 merge_dst_label; + + int32 pad_label; + int32 pad_length; + + int32 shrink_label; + int32 shrink_length; + + int32 blend_short_segments_class; + int32 max_blend_length; + + std::string remove_labels_csl; + + bool merge_adjacent_segments; + int32 max_intersegment_length; + + int32 max_segment_length; + int32 overlap_length; + + int32 post_process_label; + + SegmentationPostProcessingOptions() : + merge_dst_label(-1), + pad_label(-1), pad_length(-1), + shrink_label(-1), shrink_length(-1), + blend_short_segments_class(-1), max_blend_length(-1), + merge_adjacent_segments(false), max_intersegment_length(0), + max_segment_length(-1), overlap_length(0), + post_process_label(-1) { } + + void Register(OptionsItf *opts) { + opts->Register("merge-labels", &merge_labels_csl, "Merge labels into a " + "single label defined by merge-dst-label. " + "The labels are specified as a colon-separated list. " + "Refer to the MergeLabels() code for details. " + "Used in conjunction with the option --merge-dst-label"); + opts->Register("merge-dst-label", &merge_dst_label, + "Merge labels specified by merge-labels into this label. " + "Refer to the MergeLabels() code for details. " + "Used in conjunction with the option --merge-labels."); + opts->Register("pad-label", &pad_label, + "Pad segments of this label by pad_length frames." + "Refer to the PadSegments() code for details. " + "Used in conjunction with the option --pad-length."); + opts->Register("pad-length", &pad_length, "Pad segments by this many " + "frames on either side. " + "Refer to the PadSegments() code for details. " + "Used in conjunction with the option --pad-label."); + opts->Register("shrink-label", &shrink_label, + "Shrink segments of this label by shrink_length frames. " + "Refer to the ShrinkSegments() code for details. " + "Used in conjunction with the option --shrink-length."); + opts->Register("shrink-length", &shrink_length, "Shrink segments by this " + "many frames on either side. " + "Refer to the ShrinkSegments() code for details. " + "Used in conjunction with the option --shrink-label."); + opts->Register("blend-short-segments-class", &blend_short_segments_class, + "The label for which the short segments are to be " + "blended with the neighboring segments that are less than " + "max_intersegment_length frames away. " + "Refer to BlendShortSegments() code for details. " + "Used in conjunction with the option --max-blend-length " + "and --max-intersegment-length."); + opts->Register("max-blend-length", &max_blend_length, + "The maximum length of segment in number of frames that " + "will be blended with the neighboring segments provided " + "they both have the same label. " + "Refer to BlendShortSegments() code for details. 
" + "Used in conjunction with the option " + "--blend-short-segments-class"); + opts->Register("remove-labels", &remove_labels_csl, + "Remove any segment whose label is contained in " + "remove_labels_csl. " + "Refer to the RemoveLabels() code for details."); + opts->Register("merge-adjacent-segments", &merge_adjacent_segments, + "Merge adjacent segments of the same label if they are " + "within max-intersegment-length distance. " + "Refer to the MergeAdjacentSegments() code for details. " + "Used in conjunction with the option " + "--max-intersegment-length\n"); + opts->Register("max-intersegment-length", &max_intersegment_length, + "The maximum intersegment length that is allowed for " + "two adjacent segments to be merged. " + "Refer to the MergeAdjacentSegments() code for details. " + "Used in conjunction with the option " + "--merge-adjacent-segments or " + "--blend-short-segments-class\n"); + opts->Register("max-segment-length", &max_segment_length, + "If segment is longer than this length, split it into " + "pieces with less than these many frames. " + "Refer to the SplitSegments() code for details. " + "Used in conjunction with the option --overlap-length."); + opts->Register("overlap-length", &overlap_length, + "When splitting segments longer than max-segment-length, " + "have the pieces overlap by these many frames. " + "Refer to the SplitSegments() code for details. " + "Used in conjunction with the option --max-segment-length."); + opts->Register("post-process-label", &post_process_label, + "Do post processing only on this label. This option is " + "applicable to only a few operations including " + "SplitSegments"); + } +}; + +class SegmentationPostProcessor { + public: + explicit SegmentationPostProcessor( + const SegmentationPostProcessingOptions &opts); + + bool PostProcess(Segmentation *seg) const; + + void DoMergingLabels(Segmentation *seg) const; + void DoPaddingSegments(Segmentation *seg) const; + void DoShrinkingSegments(Segmentation *seg) const; + void DoBlendingShortSegments(Segmentation *seg) const; + void DoRemovingSegments(Segmentation *seg) const; + void DoMergingAdjacentSegments(Segmentation *seg) const; + void DoSplittingSegments(Segmentation *seg) const; + + private: + const SegmentationPostProcessingOptions &opts_; + std::vector merge_labels_; + std::vector remove_labels_; + + void Check() const; +}; + +} // end namespace segmenter +} // end namespace kaldi + +#endif // KALDI_SEGMENTER_SEGMENTATION_POST_PROCESSOR_H_ diff --git a/src/segmenter/segmentation-test.cc b/src/segmenter/segmentation-test.cc new file mode 100644 index 00000000000..7654b23b119 --- /dev/null +++ b/src/segmenter/segmentation-test.cc @@ -0,0 +1,226 @@ +// segmenter/segmentation-test.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "segmenter/segmentation.h" + +namespace kaldi { +namespace segmenter { + +void GenerateRandomSegmentation(int32 max_length, int32 num_classes, + Segmentation *segmentation) { + Clear(); + int32 s = max_length; + int32 e = max_length; + + while (s >= 0) { + int32 chunk_size = rand() % (max_length / 10); + s = e - chunk_size + 1; + int32 k = rand() % num_classes; + + if (k != 0) { + segmentation.Emplace(s, e, k); + } + e = s - 1; + } + Check(); +} + + +int32 GenerateRandomAlignment(int32 max_length, int32 num_classes, + std::vector *ali) { + int32 N = RandInt(1, max_length); + int32 C = RandInt(1, num_classes); + + ali->clear(); + + int32 len = 0; + while (len < N) { + int32 c = RandInt(0, C-1); + int32 n = std::min(RandInt(1, N), N - len); + ali->insert(ali->begin() + len, n, c); + len += n; + } + KALDI_ASSERT(ali->size() == N && len == N); + + int32 state = -1, num_segments = 0; + for (std::vector::const_iterator it = ali->begin(); + it != ali->end(); ++it) { + if (*it != state) num_segments++; + state = *it; + } + + return num_segments; +} + +void TestConversionToAlignment() { + std::vector ali; + int32 max_length = 1000, num_classes = 3; + int32 num_segments = GenerateRandomAlignment(max_length, num_classes, &ali); + + Segmentation seg; + KALDI_ASSERT(num_segments == seg.InsertFromAlignment(ali, 0)); + + std::vector out_ali; + { + seg.ConvertToAlignment(&out_ali); + KALDI_ASSERT(ali == out_ali); + } + + { + seg.ConvertToAlignment(&out_ali, num_classes, max_length * 2); + std::vector tmp_ali(out_ali.begin(), out_ali.begin() + ali.size()); + KALDI_ASSERT(ali == tmp_ali); + for (std::vector::const_iterator it = out_ali.begin() + ali.size(); + it != out_ali.end(); ++it) { + KALDI_ASSERT(*it == num_classes); + } + } + + seg.Clear(); + KALDI_ASSERT(num_segments == seg.InsertFromAlignment(ali, max_length)); + { + seg.ConvertToAlignment(&out_ali, num_classes, max_length * 2); + + for (std::vector::const_iterator it = out_ali.begin(); + it != out_ali.begin() + max_length; ++it) { + KALDI_ASSERT(*it == num_classes); + } + std::vector tmp_ali(out_ali.begin() + max_length, out_ali.begin() + max_length + ali.size()); + KALDI_ASSERT(tmp_ali == ali); + + for (std::vector::const_iterator it = out_ali.begin() + max_length + ali.size(); + it != out_ali.end(); ++it) { + KALDI_ASSERT(*it == num_classes); + } + } +} + +void TestRemoveSegments() { + std::vector ali; + int32 max_length = 1000, num_classes = 10; + int32 num_segments = GenerateRandomAlignment(max_length, num_classes, &ali); + + Segmentation seg; + KALDI_ASSERT(num_segments == seg.InsertFromAlignment(ali, 0)); + + for (int32 i = 0; i < num_classes; i++) { + Segmentation out_seg(seg); + out_seg.RemoveSegments(i); + std::vector out_ali; + out_seg.ConvertToAlignment(&out_ali, i, ali.size()); + KALDI_ASSERT(ali == out_ali); + } + + { + std::vector classes; + for (int32 i = 0; i < 3; i++) + classes.push_back(RandInt(0, num_classes - 1)); + std::sort(classes.begin(), classes.end()); + + Segmentation out_seg1(seg); + out_seg1.RemoveSegments(classes); + + Segmentation out_seg2(seg); + for (std::vector::const_iterator it = classes.begin(); + it != classes.end(); ++it) + out_seg2.RemoveSegments(*it); + + std::vector out_ali1, out_ali2; + out_seg1.ConvertToAlignment(&out_ali1); + out_seg2.ConvertToAlignment(&out_ali2); + + KALDI_ASSERT(out_ali1 == out_ali2); + } +} + +void TestIntersectSegments() { + int32 max_length = 100, num_classes = 3; + + std::vector primary_ali; + GenerateRandomAlignment(max_length, num_classes, &primary_ali); + + 
std::vector secondary_ali; + GenerateRandomAlignment(max_length, num_classes, &secondary_ali); + + Segmentation primary_seg; + primary_seg.InsertFromAlignment(primary_ali); + + Segmentation secondary_seg; + secondary_seg.InsertFromAlignment(secondary_ali); + + { + Segmentation out_seg; + primary_seg.IntersectSegments(secondary_seg, &out_seg, num_classes); + + std::vector out_ali; + out_seg.ConvertToAlignment(&out_ali); + + std::vector oracle_ali(primary_ali.size()); + + for (size_t i = 0; i < oracle_ali.size(); i++) { + int32 p = (i < primary_ali.size()) ? primary_ali[i] : -1; + int32 s = (i < secondary_ali.size()) ? secondary_ali[i] : -2; + + oracle_ali[i] = (p == s) ? p : num_classes; + } + + KALDI_ASSERT(oracle_ali == out_ali); + } + + { + Segmentation out_seg; + primary_seg.IntersectSegments(secondary_seg, &out_seg); + + std::vector out_ali; + out_seg.ConvertToAlignment(&out_ali, num_classes); + + std::vector oracle_ali(out_ali.size()); + + for (size_t i = 0; i < oracle_ali.size(); i++) { + int32 p = (i < primary_ali.size()) ? primary_ali[i] : -1; + int32 s = (i < secondary_ali.size()) ? secondary_ali[i] : -2; + + oracle_ali[i] = (p == s) ? p : num_classes; + } + + KALDI_ASSERT(oracle_ali == out_ali); + } + +} + +void UnitTestSegmentation() { + TestConversionToAlignment(); + TestRemoveSegments(); + TestIntersectSegments(); +} + +} // namespace segmenter +} // namespace kaldi + +int main() { + using namespace kaldi; + using namespace kaldi::segmenter; + + for (int32 i = 0; i < 10; i++) + UnitTestSegmentation(); + return 0; +} + + + diff --git a/src/segmenter/segmentation-utils.cc b/src/segmenter/segmentation-utils.cc new file mode 100644 index 00000000000..3adc178d66d --- /dev/null +++ b/src/segmenter/segmentation-utils.cc @@ -0,0 +1,743 @@ +// segmenter/segmentation-utils.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
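+
+// Note on the alignment <-> segmentation conventions used in this file:
+// InsertFromAlignment() turns each maximal run of identical labels in a
+// frame-level alignment into one segment, and ConvertToAlignment() does the
+// reverse.  For example, the alignment {1, 1, 0, 0, 2, 2, 2} corresponds to
+// the segments [0, 1] -> 1, [2, 3] -> 0 and [4, 6] -> 2, and converting that
+// segmentation back to an alignment of length 7 recovers the original vector.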
+ +#include "segmenter/segmentation-utils.h" + +namespace kaldi { +namespace segmenter { + +void MergeLabels(const std::vector &merge_labels, + int32 dest_label, + Segmentation *segmentation) { + KALDI_ASSERT(segmentation); + + // Check if sorted and unique + KALDI_ASSERT(std::adjacent_find(merge_labels.begin(), + merge_labels.end(), std::greater()) + == merge_labels.end()); + + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ++it) { + if (std::binary_search(merge_labels.begin(), merge_labels.end(), + it->Label())) { + it->SetLabel(dest_label); + } + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +void RelabelSegmentsUsingMap(const unordered_map &label_map, + Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ++it) { + unordered_map::const_iterator map_it = label_map.find( + it->Label()); + if (map_it == label_map.end()) + KALDI_ERR << "Could not find label " << it->Label() << " in label map."; + + it->SetLabel(map_it->second); + } +} + +void RelabelAllSegments(int32 label, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ++it) + it->SetLabel(label); +} + +void ScaleFrameShift(BaseFloat factor, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ++it) { + it->start_frame *= factor; + it->end_frame *= factor; + } +} + +void RemoveSegments(int32 label, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ) { + if (it->Label() == label) { + it = segmentation->Erase(it); + } else { + ++it; + } + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +void RemoveSegments(const std::vector &labels, + Segmentation *segmentation) { + // Check if sorted and unique + KALDI_ASSERT(std::adjacent_find(labels.begin(), + labels.end(), std::greater()) == labels.end()); + + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ) { + if (std::binary_search(labels.begin(), labels.end(), it->Label())) { + it = segmentation->Erase(it); + } else { + ++it; + } + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +// Opposite of RemoveSegments() +void KeepSegments(int32 label, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ) { + if (it->Label() != label) { + it = segmentation->Erase(it); + } else { + ++it; + } + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +// TODO(Vimal): Write test function for this. +void SplitInputSegmentation(const Segmentation &in_segmentation, + int32 segment_length, + Segmentation *out_segmentation) { + out_segmentation->Clear(); + for (SegmentList::const_iterator it = in_segmentation.Begin(); + it != in_segmentation.End(); ++it) { + int32 length = it->Length(); + + // Since ceil is used, this results in all pieces to be smaller than + // segment_length rather than being larger. 
+ int32 num_chunks = std::ceil(static_cast(length) + / segment_length); + int32 actual_segment_length = static_cast(length) / num_chunks; + + int32 start_frame = it->start_frame; + for (int32 j = 0; j < num_chunks; j++) { + int32 end_frame = std::min(start_frame + actual_segment_length - 1, + it->end_frame); + out_segmentation->EmplaceBack(start_frame, end_frame, it->Label()); + start_frame = end_frame + 1; + } + } +#ifdef KALDI_PARANOID + out_segmentation->Check(); +#endif +} + +// TODO(Vimal): Write test function for this. +void SplitSegments(int32 segment_length, int32 min_remainder, + int32 overlap_length, int32 segment_label, + Segmentation *segmentation) { + KALDI_ASSERT(segmentation); + KALDI_ASSERT(segment_length > 0 && min_remainder > 0); + KALDI_ASSERT(overlap_length >= 0); + + KALDI_ASSERT(overlap_length < segment_length); + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ++it) { + if (segment_label != -1 && it->Label() != segment_label) continue; + + int32 start_frame = it->start_frame; + int32 length = it->Length(); + + if (length > segment_length + min_remainder) { + // Split segment + // To show what this is doing, consider the following example, where it is + // currently pointing to B. + // A <--> B <--> C + + // Modify the start_frame of the current frame. This prepares the current + // segment to be used as the "next segment" when we move the iterator in + // the next statement. + // In the example, the start_frame for B has just been modified. + it->start_frame = start_frame + segment_length - overlap_length; + + // Create a new segment and add it to the where the current iterator is. + // The statement below results in this: + // A <--> B1 <--> B <--> C + // with the iterator it pointing at B1. So when the iterator is + // incremented in the for loop, it will point to B again, but whose + // start_frame had been modified. + it = segmentation->Emplace(it, start_frame, + start_frame + segment_length - 1, + it->Label()); + } + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +// TODO(Vimal): Write test code for this +void SplitSegmentsUsingAlignment(int32 segment_length, + int32 segment_label, + const std::vector &ali, + int32 ali_label, + int32 min_silence_length, + Segmentation *segmentation) { + KALDI_ASSERT(segmentation); + KALDI_ASSERT(segment_length > 0); + + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End();) { + // Safety check. In practice, should never fail. + KALDI_ASSERT(segmentation->Dim() <= ali.size()); + + if (segment_label != -1 && it->Label() != segment_label) { + ++it; + continue; + } + + int32 start_frame = it->start_frame; + int32 length = it->Length(); + int32 label = it->Label(); + + if (length <= segment_length) { + ++it; + continue; + } + + // Split segment + // To show what this is doing, consider the following example, where it is + // currently pointing to B. 
+ // A <--> B <--> C + + Segmentation ali_segmentation; + InsertFromAlignment(ali, start_frame, + start_frame + length, + 0, &ali_segmentation, NULL); + KeepSegments(ali_label, &ali_segmentation); + MergeAdjacentSegments(0, &ali_segmentation); + + // Get largest alignment chunk where label == ali_label + SegmentList::iterator s_it = ali_segmentation.MaxElement(); + + if (s_it == ali_segmentation.End() || s_it->Length() < min_silence_length) { + ++it; + continue; + } + + KALDI_ASSERT(s_it->start_frame >= start_frame); + KALDI_ASSERT(s_it->end_frame <= start_frame + length); + + // Modify the start_frame of the current frame. This prepares the current + // segment to be used as the "next segment" when we move the iterator in + // the next statement. + // In the example, the start_frame for B has just been modified. + int32 end_frame; + if (s_it->Length() > 1) { + end_frame = s_it->start_frame + s_it->Length() / 2 - 2; + it->start_frame = end_frame + 2; + } else { + end_frame = s_it->start_frame - 1; + it->start_frame = s_it->end_frame + 1; + } + + // end_frame is within this current segment + KALDI_ASSERT(end_frame < start_frame + length); + // The first new segment length is smaller than the old segment length + KALDI_ASSERT(end_frame - start_frame + 1 < length); + + // The second new segment length is smaller than the old segment length + KALDI_ASSERT(it->end_frame - end_frame - 1 < length); + + if (it->Length() < 0) { + // This is possible when the beginning of the segment is silence + it = segmentation->Erase(it); + } + + // Create a new segment and add it to the where the current iterator is. + // The statement below results in this: + // A <--> B1 <--> B <--> C + // with the iterator it pointing at B1. + if (end_frame >= start_frame) { + it = segmentation->Emplace(it, start_frame, end_frame, label); + } + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +// TODO(Vimal): Write test code for this +void IntersectSegmentationAndAlignment(const Segmentation &in_segmentation, + const std::vector &alignment, + int32 ali_label, + int32 min_align_chunk_length, + Segmentation *out_segmentation) { + KALDI_ASSERT(out_segmentation); + + for (SegmentList::const_iterator it = in_segmentation.Begin(); + it != in_segmentation.End(); ++it) { + Segmentation filter_segmentation; + InsertFromAlignment(alignment, it->start_frame, + std::min(it->end_frame + 1, + static_cast(alignment.size())), + 0, &filter_segmentation, NULL); + + for (SegmentList::const_iterator f_it = filter_segmentation.Begin(); + f_it != filter_segmentation.End(); ++f_it) { + if (f_it->Length() < min_align_chunk_length) continue; + if (ali_label != -1 && f_it->Label() != ali_label) continue; + out_segmentation->EmplaceBack(f_it->start_frame, f_it->end_frame, + it->Label()); + } + } +} + +void SubSegmentUsingNonOverlappingSegments( + const Segmentation &primary_segmentation, + const Segmentation &secondary_segmentation, int32 secondary_label, + int32 subsegment_label, int32 unmatched_label, + Segmentation *out_segmentation) { + KALDI_ASSERT(out_segmentation); + KALDI_ASSERT(secondary_segmentation.Dim() > 0); + + std::vector alignment; + ConvertToAlignment(secondary_segmentation, -1, -1, 0, &alignment); + + for (SegmentList::const_iterator it = primary_segmentation.Begin(); + it != primary_segmentation.End(); ++it) { + if (it->end_frame >= alignment.size()) { + alignment.resize(it->end_frame + 1, -1); + } + Segmentation filter_segmentation; + InsertFromAlignment(alignment, it->start_frame, it->end_frame + 1, + 0, 
&filter_segmentation, NULL); + + for (SegmentList::const_iterator f_it = filter_segmentation.Begin(); + f_it != filter_segmentation.End(); ++f_it) { + int32 label = (unmatched_label > 0 ? unmatched_label : it->Label()); + if (f_it->Label() == secondary_label) { + if (subsegment_label >= 0) { + label = subsegment_label; + } else { + label = f_it->Label(); + } + } + out_segmentation->EmplaceBack(f_it->start_frame, f_it->end_frame, + label); + } + } +} + +// TODO(Vimal): Write test code for this +void MergeAdjacentSegments(int32 max_intersegment_length, + Segmentation *segmentation) { + SegmentList::iterator it = segmentation->Begin(), + prev_it = segmentation->Begin(); + + while (it != segmentation->End()) { + KALDI_ASSERT(it->start_frame >= prev_it->start_frame); + + if (it != segmentation->Begin() && + it->Label() == prev_it->Label() && + prev_it->end_frame + max_intersegment_length + 1 >= it->start_frame) { + // merge segments + if (prev_it->end_frame < it->end_frame) { + // If the previous segment end before the current segment, then + // extend the previous segment to the end_frame of the current + // segment and remove the current segment. + prev_it->end_frame = it->end_frame; + } // else simply remove the current segment. + it = segmentation->Erase(it); + } else { + // no merging of segments + prev_it = it; + ++it; + } + } + +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +void PadSegments(int32 label, int32 length, Segmentation *segmentation) { + KALDI_ASSERT(segmentation); + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ++it) { + if (it->Label() != label) continue; + + it->start_frame -= length; + it->end_frame += length; + + if (it->start_frame < 0) it->start_frame = 0; + } +} + +void WidenSegments(int32 label, int32 length, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ++it) { + if (it->Label() == label) { + if (it != segmentation->Begin()) { + // it is not the beginning of the segmentation, so we can widen it on + // the start_frame side + SegmentList::iterator prev_it = it; + --prev_it; + it->start_frame -= length; + if (prev_it->Label() == label && it->start_frame < prev_it->end_frame) { + // After widening this segment, it overlaps the previous segment that + // also has the same class_id. Then turn this segment into a composite + // one + it->start_frame = prev_it->start_frame; + // and remove the previous segment from the list. + segmentation->Erase(prev_it); + } else if (prev_it->Label() != label && + it->start_frame < prev_it->end_frame) { + // Previous segment is not the same class_id, so we cannot turn this + // into a composite segment. + if (it->start_frame <= prev_it->start_frame) { + // The extended segment absorbs the previous segment into it + // So remove the previous segment + segmentation->Erase(prev_it); + } else { + // The extended segment reduces the length of the previous + // segment. But does not completely overlap it. + prev_it->end_frame -= length; + if (prev_it->end_frame < prev_it->start_frame) + segmentation->Erase(prev_it); + } + } + if (it->start_frame < 0) it->start_frame = 0; + } else { + it->start_frame -= length; + if (it->start_frame < 0) it->start_frame = 0; + } + + SegmentList::iterator next_it = it; + ++next_it; + + if (next_it != segmentation->End()) + // We do not know the length of the file. + // So we don't want to extend the last one. 
+ it->end_frame += length; // Line (1) + } else { // if (it->Label() != label) + if (it != segmentation->Begin()) { + SegmentList::iterator prev_it = it; + --prev_it; + if (prev_it->end_frame >= it->end_frame) { + // The extended previous segment in Line (1) completely + // overlaps the current segment. So remove the current segment. + it = segmentation->Erase(it); + // So that we can increment in the for loop + --it; // TODO(Vimal): This is buggy. + } else if (prev_it->end_frame >= it->start_frame) { + // The extended previous segment in Line (1) reduces the length of + // this segment. + it->start_frame = prev_it->end_frame + 1; + } + } + } + } +} + +void ShrinkSegments(int32 label, int32 length, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ) { + if (it->Label() == label) { + if (it->Length() <= 2 * length) { + it = segmentation->Erase(it); + } else { + it->start_frame += length; + it->end_frame -= length; + ++it; + } + } else { + ++it; + } + } + +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +void BlendShortSegmentsWithNeighbors(int32 label, int32 max_length, + int32 max_intersegment_length, + Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ) { + if (it == segmentation->Begin()) { + // Can't blend the first segment + ++it; + continue; + } + + SegmentList::iterator next_it = it; + ++next_it; + + if (next_it == segmentation->End()) // End of segmentation + break; + + SegmentList::iterator prev_it = it; + --prev_it; + + // If the previous and current segments have different labels, + // then ensure that they are not overlapping + KALDI_ASSERT(it->start_frame >= prev_it->start_frame && + (prev_it->Label() == it->Label() || + prev_it->end_frame < it->start_frame)); + + KALDI_ASSERT(next_it->start_frame >= it->start_frame && + (it->Label() == next_it->Label() || + it->end_frame < next_it->start_frame)); + + if (next_it->Label() != prev_it->Label() || it->Label() != label || + it->Length() >= max_length || + next_it->start_frame - it->end_frame - 1 > max_intersegment_length || + it->start_frame - prev_it->end_frame - 1 > max_intersegment_length) { + ++it; + continue; + } + + prev_it->end_frame = next_it->end_frame; + segmentation->Erase(it); + it = segmentation->Erase(next_it); + } +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif +} + +bool ConvertToAlignment(const Segmentation &segmentation, + int32 default_label, int32 length, + int32 tolerance, + std::vector *alignment) { + KALDI_ASSERT(alignment); + alignment->clear(); + + if (length != -1) { + KALDI_ASSERT(length >= 0); + alignment->resize(length, default_label); + } + + SegmentList::const_iterator it = segmentation.Begin(); + for (; it != segmentation.End(); ++it) { + if (length != -1 && it->end_frame >= length + tolerance) { + KALDI_WARN << "End frame (" << it->end_frame << ") " + << ">= length (" << length + << ") + tolerance (" << tolerance << ")." 
+ << "Conversion failed."; + return false; + } + + int32 end_frame = it->end_frame; + if (length == -1) { + alignment->resize(it->end_frame + 1, default_label); + } else { + if (it->end_frame >= length) + end_frame = length - 1; + } + + KALDI_ASSERT(end_frame < alignment->size()); + for (int32 i = it->start_frame; i <= end_frame; i++) { + (*alignment)[i] = it->Label(); + } + } + return true; +} + +int32 InsertFromAlignment(const std::vector &alignment, + int32 start, int32 end, + int32 start_time_offset, + Segmentation *segmentation, + std::vector *frame_counts_per_class) { + KALDI_ASSERT(segmentation); + + if (end <= start) return 0; // nothing to insert + + // Correct boundaries + if (end > alignment.size()) end = alignment.size(); + if (start < 0) start = 0; + + KALDI_ASSERT(end > start); // This is possible if end was originally + // greater than alignment.size(). + // The user must resize alignment appropriately + // before passing to this function. + + int32 num_segments = 0; + int32 state = -100, start_frame = -1; + for (int32 i = start; i < end; i++) { + KALDI_ASSERT(alignment[i] >= -1); + if (alignment[i] != state) { + // Change of state i.e. a different class id. + // So the previous segment has ended. + if (start_frame != -1) { + // start_frame == -1 in the beginning of the alignment. That is just + // initialization step and hence no creation of segment. + segmentation->EmplaceBack(start_frame + start_time_offset, + i-1 + start_time_offset, state); + num_segments++; + + if (frame_counts_per_class && state > 0) { + if (frame_counts_per_class->size() <= state) { + frame_counts_per_class->resize(state + 1, 0); + } + (*frame_counts_per_class)[state] += i - start_frame; + } + } + start_frame = i; + state = alignment[i]; + } + } + + KALDI_ASSERT(state >= -1 && start_frame >= 0 && start_frame < end); + segmentation->EmplaceBack(start_frame + start_time_offset, + end-1 + start_time_offset, state); + num_segments++; + if (frame_counts_per_class && state > 0) { + if (frame_counts_per_class->size() <= state) { + frame_counts_per_class->resize(state + 1, 0); + } + (*frame_counts_per_class)[state] += end - start_frame; + } + +#ifdef KALDI_PARANOID + segmentation->Check(); +#endif + + return num_segments; +} + +int32 InsertFromSegmentation( + const Segmentation &in_segmentation, int32 start_time_offset, + bool sort, + Segmentation *out_segmentation, + std::vector *frame_counts_per_class) { + KALDI_ASSERT(out_segmentation); + + if (in_segmentation.Dim() == 0) return 0; // nothing to insert + + int32 num_segments = 0; + + for (SegmentList::const_iterator it = in_segmentation.Begin(); + it != in_segmentation.End(); ++it) { + out_segmentation->EmplaceBack(it->start_frame + start_time_offset, + it->end_frame + start_time_offset, + it->Label()); + num_segments++; + if (frame_counts_per_class) { + if (frame_counts_per_class->size() <= it->Label()) { + frame_counts_per_class->resize(it->Label() + 1, 0); + } + (*frame_counts_per_class)[it->Label()] += it->Length(); + } + } + + if (sort) out_segmentation->Sort(); + +#ifdef KALDI_PARANOID + out_segmentation->Check(); +#endif + + return num_segments; +} + +void ExtendSegmentation(const Segmentation &in_segmentation, + bool sort, + Segmentation *segmentation) { + InsertFromSegmentation(in_segmentation, 0, sort, segmentation, NULL); +} + +bool GetClassCountsPerFrame( + const Segmentation &segmentation, + int32 length, int32 tolerance, + std::vector > *class_counts_per_frame) { + KALDI_ASSERT(class_counts_per_frame); + + if (length != -1) { + 
KALDI_ASSERT(length >= 0); + class_counts_per_frame->resize(length, std::map()); + } + + SegmentList::const_iterator it = segmentation.Begin(); + for (; it != segmentation.End(); ++it) { + if (length != -1 && it->end_frame >= length + tolerance) { + KALDI_WARN << "End frame (" << it->end_frame << ") " + << ">= length + tolerance (" << length + tolerance << ")." + << "Conversion failed."; + return false; + } + + int32 end_frame = it->end_frame; + if (length == -1) { + class_counts_per_frame->resize(it->end_frame + 1, + std::map()); + } else { + if (it->end_frame >= length) + end_frame = length - 1; + } + + KALDI_ASSERT(end_frame < class_counts_per_frame->size()); + for (int32 i = it->start_frame; i <= end_frame; i++) { + std::map &this_class_counts = (*class_counts_per_frame)[i]; + std::map::iterator c_it = this_class_counts.lower_bound( + it->Label()); + if (c_it == this_class_counts.end() || it->Label() < c_it->first) { + this_class_counts.insert(c_it, std::make_pair(it->Label(), 1)); + } else { + c_it->second++; + } + } + } + + return true; +} + +bool IsNonOverlapping(const Segmentation &segmentation) { + std::vector vec; + for (SegmentList::const_iterator it = segmentation.Begin(); + it != segmentation.End(); ++it) { + vec.resize(it->end_frame + 1, false); + for (int32 i = it->start_frame; i <= it->end_frame; i++) { + if (vec[i]) return false; + vec[i] = true; + } + } + return true; +} + +void Sort(Segmentation *segmentation) { + segmentation->Sort(); +} + +void TruncateToLength(int32 length, Segmentation *segmentation) { + for (SegmentList::iterator it = segmentation->Begin(); + it != segmentation->End(); ) { + if (it->start_frame >= length) { + it = segmentation->Erase(it); + continue; + } + + if (it->end_frame >= length) + it->end_frame = length - 1; + ++it; + } +} + +} // end namespace segmenter +} // end namespace kaldi diff --git a/src/segmenter/segmentation-utils.h b/src/segmenter/segmentation-utils.h new file mode 100644 index 00000000000..9401722ccb7 --- /dev/null +++ b/src/segmenter/segmentation-utils.h @@ -0,0 +1,337 @@ +// segmenter/segmentation-utils.h + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_SEGMENTER_SEGMENTATION_UTILS_H_ +#define KALDI_SEGMENTER_SEGMENTATION_UTILS_H_ + +#include "segmenter/segmentation.h" + +namespace kaldi { +namespace segmenter { + +/** + * This function is very straight forward. It just merges the labels in + * merge_labels to the class-id dest_label. This means any segment that + * originally had the class-id as any of the labels in merge_labels would end + * up having the class-id dest_label. + **/ +void MergeLabels(const std::vector &merge_labels, + int32 dest_label, Segmentation *segmentation); + +// Relabel segments using a map from old to new label. +// If segment label is not found in the map, the function exits with +// an error. 
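+// For example, with label_map = {1 -> 1, 2 -> 1, 3 -> 0}, segments labeled
+// 2 become 1, segments labeled 3 become 0, and a segment labeled 4 causes
+// an error.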
+void RelabelSegmentsUsingMap(const unordered_map &label_map, + Segmentation *segmentation); + +// Relabel all segments to class-id label +void RelabelAllSegments(int32 label, Segmentation *segmentation); + +// Scale frame shift by this factor. +// Usually frame length is 0.01 and frame shift 0.015. But sometimes +// the alignments are obtained using a subsampling factor of 3. This +// function can be used to maintain consistency among different +// alignments and segmentations. +void ScaleFrameShift(BaseFloat factor, Segmentation *segmentation); + +/** + * This is very straight forward. It removes all segments of label "label" +**/ +void RemoveSegments(int32 label, Segmentation *segmentation); + +/** + * This is very straight forward. It removes any segment whose label is + * contained in the vector "labels" +**/ +void RemoveSegments(const std::vector &labels, + Segmentation *segmentation); + +// Keep only segments of label "label" +void KeepSegments(int32 label, Segmentation *segmentation); + +/** + * This function splits an input segmentation in_segmentation into pieces of + * approximately segment_length. Each piece is given the same class id as the + * original segment. + * + * The way this function is written is that it first figures out the number of + * pieces that the segment must be broken into. Then it creates that many pieces + * of equal size (actual_segment_length). This mimics some of the approaches + * used at script level +**/ +void SplitInputSegmentation(const Segmentation &in_segmentation, + int32 segment_length, + Segmentation *out_segmentation); + +/** + * This function splits the segments in the the segmentation + * into pieces of segment_length. + * But if the last remaining piece is smaller than min_remainder, then the last + * piece is merged to the piece before it, resulting in a piece that is of + * length < segment_length + min_remainder. + * If overlap_length > 0, then the created pieces overlap by these many frames. + * If segment_label == -1, then all segments are split. + * Otherwise, only the segments with this label are split. + * + * The way this function works it is it looks at the current segment length and + * checks if it is larger than segment_length + min_remainder. If it is larger, + * then it must be split. To do this, it first modifies the start_frame of + * the current frame to start_frame + segment_length - overlap. + * It then creates a new segment of length segment_length from the original + * start_frame to start_frame + segment_length - 1 and adds it just before the + * current segment. So in the next iteration, we would actually be back to the + * same segment, but whose start_frame had just been modified. +**/ +void SplitSegments(int32 segment_length, + int32 min_remainder, int32 overlap_length, + int32 segment_label, + Segmentation *segmentation); + +/** + * Split this segmentation into pieces of size segment_length, + * but only if possible by creating split points at the + * middle of the chunk where alignment == ali_label and + * the chunk is at least min_segment_length frames long + * + * min_remainder, segment_label serve the same purpose as in the + * above SplitSegments function. 
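+ *
+ * A typical use is splitting a long speech segment near the middle of its
+ * longest internal run of frames whose alignment label matches the given
+ * alignment label (e.g. silence), rather than at an arbitrary frame.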
+**/ +void SplitSegmentsUsingAlignment(int32 segment_length, + int32 segment_label, + const std::vector &alignment, + int32 alignment_label, + int32 min_align_chunk_length, + Segmentation *segmentation); + +/** + * This function is a standard intersection of the set of times represented by + * the segmentation in_segmentation and the set of times of where + * alignment contains ali_label for at least min_align_chunk_length + * consecutive frames +**/ +void IntersectSegmentationAndAlignment(const Segmentation &in_segmentation, + const std::vector &alignment, + int32 ali_label, + int32 min_align_chunk_length, + Segmentation *out_segmentation); + +/** + * This function is a little complicated in what it does. But this is required + * for one of the applications. + * This function creates a new segmentation by sub-segmenting an arbitrary + * "primary_segmentation" and assign new label "subsegment_label" to regions + * where the "primary_segmentation" intersects the non-overlapping + * "secondary_segmentation" segments with label "secondary_label". + * This is similar to the function "IntersectSegments", but instead of keeping + * only the filtered subsegments, all the subsegments are kept, while only + * changing the class_id of the filtered sub-segments. + * The label for the newly created subsegments is determined as follows: + * if secondary segment's label == secondary_label: + * if subsegment_label > 0: + * label = subsegment_label + * else: + * label = secondary_label + * else: + * if unmatched_label > 0: + * label = unmatched_label + * else: + * label = primary_label +**/ +void SubSegmentUsingNonOverlappingSegments( + const Segmentation &primary_segmentation, + const Segmentation &secondary_segmentation, int32 secondary_label, + int32 subsegment_label, int32 unmatched_label, + Segmentation *out_segmentation); + +/** + * This function is used to merge segments next to each other in the SegmentList + * and within a distance of max_intersegment_length frames from each other, + * provided the segments are of the same label. + * This function requires the segmentation to be sorted before passing it. + **/ +void MergeAdjacentSegments(int32 max_intersegment_length, + Segmentation *segmentation); + +/** + * This function is used to pad segments of label "label" by "length" + * frames on either side of the segment. + * This is useful to pad segments of speech. +**/ +void PadSegments(int32 label, int32 length, Segmentation *segmentation); + +/** + * This function is used to widen segments of label "label" by "length" + * frames on either side of the segment. + * This is similar to PadSegments, but while widening, it also reduces the + * length of the segment adjacent to it. + * This may not be required in some applications, but it is ok for speech / + * silence. By this process, we are calling frames within a "length" number of + * frames near the speech segment as speech and hence we reduce the width of the + * silence segment before it. +**/ +void WidenSegments(int32 label, int32 length, Segmentation *segmentation); + +/** + * This function is used to shrink segments of class_id "label" by "length" + * frames on either side of the segment. + * If the whole segment is smaller than 2*length, then the segment is + * removed entirely. 
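+ * For example, with length = 5, a segment spanning frames [100, 130] becomes
+ * [105, 125], while a segment spanning [100, 108] (only 9 <= 2 * 5 frames)
+ * is removed.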
+**/ +void ShrinkSegments(int32 label, int32 length, Segmentation *segmentation); + +/** + * This function blends segments of label "label" that are shorter than + * "max_length" frames, provided the segments before and after it are of the + * same label "other_label" and the distance to the neighbor is less than + * "max_intersegment_distance". + * After blending, the three segments have the same label "other_label" and + * hence can be merged into a composite segment. + * An example where this is useful is when there is a short segment of silence + * with speech segments on either sides. Then the short segment of silence is + * removed and called speech instead. The three continguous segments of speech + * are merged into a single composite segment. +**/ +void BlendShortSegmentsWithNeighbors(int32 label, int32 max_length, + int32 max_intersegment_distance, + Segmentation *segmentation); + +/** + * This function is used to convert the segmentation into frame-level alignment + * with the label for each frame begin the class_id of segment the frame belongs + * to. + * The arguments are used to provided extended functionality that are required + * for most cases. + * default_label : the label that is used as filler in regions where the frame + * is not in any of the segments. In most applications, certain + * segments are removed, such as the ones that are silence. Then + * the segments would not span the entire duration of the file. + * e.g. + * 10 35 1 + * 41 190 2 + * ... + * Here there is no segment from 36-40. These frames are + * filled with default_label. + * length : the number of frames required in the alignment. + * If set to -1, then this length is ignored. + * In most applications, the length of the alignment required is + * known. Usually it must match the length of the features + * (obtained using feat-to-len). Then the alignment is resized + * to this length and filled with default_label. The segments + * are then read and the frames corresponding to the segments + * are relabeled with the class_id of the respective segments. + * tolerance : the tolerance in number of frames that we allow for the + * frame index corresponding to the end_frame of the last + * segment. Applicable when length != -1. + * Since, we use 25 ms widows with 10 ms frame shift, + * it is possible that the features length is 2 frames less than + * the end of the last segment. So the user can set the + * tolerance to 2 in order to avoid returning with error in this + * function. + * Function returns true is successful. +**/ +bool ConvertToAlignment(const Segmentation &segmentation, + int32 default_label, int32 length, + int32 tolerance, + std::vector *alignment); + +/** + * Insert segments created from alignment starting from frame index "start" + * until and excluding frame index "end". + * The inserted segments are shifted by "start_time_offset". + * "start_time_offset" is useful when the "alignment" is per-utterance, in which + * case the start time of the utterance can be provided as the + * "start_time_offset" + * The function returns the number of segments created. + * If "frame_counts_per_class" is provided, then the number of frames per class + * is accumulated there. +**/ +int32 InsertFromAlignment(const std::vector &alignment, + int32 start, int32 end, + int32 start_time_offset, + Segmentation *segmentation, + std::vector *frame_counts_per_class = NULL); + +/** + * Insert segments from in_segmentation, but shift them by + * start_time offset. 
+ * If sort is true, then the final segmentation is sorted. + * It is useful in some applications to set sort to false. + * Returns number of segments inserted. +**/ +int32 InsertFromSegmentation(const Segmentation &in_segmentation, + int32 start_time_offset, bool sort, + Segmentation *segmentation, + std::vector *frame_counts_per_class = NULL); + +/** + * Extend a segmentation by adding another one. + * If "sort" is set to true, then resultant segmentation would be sorted. + * If its known that the other segmentation must all be after this segmentation, + * then the user may set "sort" false. +**/ +void ExtendSegmentation(const Segmentation &in_segmentation, bool sort, + Segmentation *segmentation); + +/** + * This function is used to get per-frame count of number of classes. + * The output is in the format of a vector of maps. + * class_counts_per_frame: A pointer to a vector of maps use to get the output. + * The size of the vector is the number of frames. + * For each frame, there is a map from the "class_id" + * to the number of segments where the label the + * corresponding "class_id". + * The size of the map gives the number of unique + * labels in this frame e.g. number of speakers. + * The count for each "class_id" is the number + * of segments with that "class_id" at that frame. + * length : the number of frames required in the output. + * In most applications, this length is known. + * Usually it must match the length of the features (obtained + * using feat-to-len). Then the output is resized to this + * length. The map is empty for frames where no segments are + * seen. + * tolerance : the tolerance in number of frames that we allow for the + * frame index corresponding to the end_frame of the last + * segment. Since, we use 25 ms widows with 10 ms frame shift, + * it is possible that the features length is 2 frames less than + * the end of the last segment. So the user can set the + * tolerance to 2 in order to avoid returning an error in this + * function. + * Function returns true is successful. +**/ +bool GetClassCountsPerFrame( + const Segmentation &segmentation, + int32 length, int32 tolerance, + std::vector > *class_counts_per_frame); + +// Checks if segmentation is non-overlapping +bool IsNonOverlapping(const Segmentation &segmentation); + +// Sorts segments on start frame. +void Sort(Segmentation *segmentation); + +// Truncate segmentation to "length". +// Removes any segments with "start_time" >= "length" +// and truncates any segments with "end_time" >= "length" +void TruncateToLength(int32 length, Segmentation *segmentation); + +} // end namespace segmenter +} // end namespace kaldi + +#endif // KALDI_SEGMENTER_SEGMENTATION_UTILS_H_ diff --git a/src/segmenter/segmentation.cc b/src/segmenter/segmentation.cc new file mode 100644 index 00000000000..fb83ed5476b --- /dev/null +++ b/src/segmenter/segmentation.cc @@ -0,0 +1,201 @@ +// segmenter/segmentation.cc + +// Copyright 2016 Vimal Manohar + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "segmenter/segmentation.h" +#include + +namespace kaldi { +namespace segmenter { + +void Segmentation::PushBack(const Segment &seg) { + dim_++; + segments_.push_back(seg); +} + +SegmentList::iterator Segmentation::Insert(SegmentList::iterator it, + const Segment &seg) { + dim_++; + return segments_.insert(it, seg); +} + +void Segmentation::EmplaceBack(int32 start_frame, int32 end_frame, + int32 class_id) { + dim_++; + Segment seg(start_frame, end_frame, class_id); + segments_.push_back(seg); +} + +SegmentList::iterator Segmentation::Emplace(SegmentList::iterator it, + int32 start_frame, int32 end_frame, + int32 class_id) { + dim_++; + Segment seg(start_frame, end_frame, class_id); + return segments_.insert(it, seg); +} + +SegmentList::iterator Segmentation::Erase(SegmentList::iterator it) { + dim_--; + return segments_.erase(it); +} + +void Segmentation::Clear() { + segments_.clear(); + dim_ = 0; +} + +void Segmentation::Read(std::istream &is, bool binary) { + Clear(); + + if (binary) { + int32 sz = is.peek(); + if (sz == Segment::SizeInBytes()) { + is.get(); + } else { + KALDI_ERR << "Segmentation::Read: expected to see Segment of size " + << Segment::SizeInBytes() << ", saw instead " << sz + << ", at file position " << is.tellg(); + } + + int32 segmentssz; + is.read(reinterpret_cast(&segmentssz), sizeof(segmentssz)); + if (is.fail() || segmentssz < 0) + KALDI_ERR << "Segmentation::Read: read failure at file position " + << is.tellg(); + + for (int32 i = 0; i < segmentssz; i++) { + Segment seg; + seg.Read(is, binary); + segments_.push_back(seg); + } + dim_ = segmentssz; + } else { + if (int c = is.peek() != static_cast('[')) { + KALDI_ERR << "Segmentation::Read: expected to see [, saw " + << static_cast(c) << ", at file position " << is.tellg(); + } + is.get(); // consume the '[' + is >> std::ws; + while (is.peek() != static_cast(']')) { + KALDI_ASSERT(!is.eof()); + Segment seg; + seg.Read(is, binary); + segments_.push_back(seg); + dim_++; + is >> std::ws; + } + is.get(); + KALDI_ASSERT(!is.eof()); + } +#ifdef KALDI_PARANOID + Check(); +#endif +} + +void Segmentation::Write(std::ostream &os, bool binary) const { +#ifdef KALDI_PARANOID + Check(); +#endif + + SegmentList::const_iterator it = Begin(); + if (binary) { + char sz = Segment::SizeInBytes(); + os.write(&sz, 1); + + int32 segmentssz = static_cast(Dim()); + KALDI_ASSERT((size_t)segmentssz == Dim()); + + os.write(reinterpret_cast(&segmentssz), sizeof(segmentssz)); + + for (; it != End(); ++it) { + it->Write(os, binary); + } + } else { + os << "[ "; + for (; it != End(); ++it) { + it->Write(os, binary); + os << std::endl; + } + os << "]" << std::endl; + } +} + +void Segmentation::Check() const { + int32 dim = 0; + for (SegmentList::const_iterator it = Begin(); it != End(); ++it, dim++) { + KALDI_ASSERT(it->start_frame >= 0); + KALDI_ASSERT(it->end_frame >= 0); + KALDI_ASSERT(it->Label() >= 0); + } + KALDI_ASSERT(dim == dim_); +} + +void Segmentation::Sort() { + segments_.sort(SegmentComparator()); +} + +void Segmentation::SortByLength() { + segments_.sort(SegmentLengthComparator()); +} + +SegmentList::iterator Segmentation::MinElement() { + return std::min_element(segments_.begin(), segments_.end(), + SegmentLengthComparator()); +} + +SegmentList::iterator Segmentation::MaxElement() { + return std::max_element(segments_.begin(), segments_.end(), + SegmentLengthComparator()); +} + +Segmentation::Segmentation() 
{
+  Clear();
+}
+
+
+void Segmentation::GenRandomSegmentation(int32 max_length,
+                                         int32 max_segment_length,
+                                         int32 num_classes) {
+  Clear();
+  int32 st = 0;
+  int32 end = 0;
+
+  while (st < max_length) {
+    int32 segment_length = RandInt(0, max_segment_length);
+
+    end = st + segment_length - 1;
+
+    // Choose random class id
+    int32 k = RandInt(-1, num_classes - 1);
+
+    if (k >= 0) {
+      Segment seg(st, end, k);
+      segments_.push_back(seg);
+      dim_++;
+    }
+
+    // Choose random shift i.e. the distance between two adjacent segments
+    int32 shift = RandInt(0, max_segment_length);
+    st = end + shift;
+  }
+
+  Check();
+}
+
+}  // namespace segmenter
+}  // namespace kaldi
diff --git a/src/segmenter/segmentation.h b/src/segmenter/segmentation.h
new file mode 100644
index 00000000000..aa408374751
--- /dev/null
+++ b/src/segmenter/segmentation.h
@@ -0,0 +1,144 @@
+// segmenter/segmentation.h
+
+// Copyright 2016  Vimal Manohar
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_SEGMENTER_SEGMENTATION_H_
+#define KALDI_SEGMENTER_SEGMENTATION_H_
+
+#include <list>
+#include "base/kaldi-common.h"
+#include "matrix/kaldi-matrix.h"
+#include "util/kaldi-table.h"
+#include "segmenter/segment.h"
+
+namespace kaldi {
+namespace segmenter {
+
+// Segments are stored as a doubly-linked-list. This could be changed later
+// if needed. Hence defining a typedef SegmentList.
+typedef std::list<Segment> SegmentList;
+
+// Declare class
+class SegmentationPostProcessor;
+
+/**
+ * The main class to store segmentation and do operations on it. The segments
+ * are stored in the structure SegmentList, which is currently a doubly-linked
+ * list.
+ * See the .cc file for details of implementation of the different functions.
+ * This file gives only a small description of the functions.
+**/
+
+class Segmentation {
+ public:
+  // Inserts the segment at the back of the list.
+  void PushBack(const Segment &seg);
+
+  // Inserts the segment before the segment at the position specified by the
+  // iterator "it".
+  SegmentList::iterator Insert(SegmentList::iterator it,
+                               const Segment &seg);
+
+  // The following function is a wrapper to the
+  // emplace_back functionality of an STL list of Segments
+  // and inserts a new segment at the back of the list.
+  void EmplaceBack(int32 start_frame, int32 end_frame, int32 class_id);
+
+  // The following function is a wrapper to the
+  // emplace functionality of an STL list of segments
+  // and inserts a segment at the position specified by the iterator "it".
+  // Returns an iterator to the inserted segment.
+  SegmentList::iterator Emplace(SegmentList::iterator it,
+                                int32 start_frame, int32 end_frame,
+                                int32 class_id);
+
+  // Calls the erase operation on the SegmentList, returns the iterator
+  // pointing to the next segment in the SegmentList, and decrements dim_.
+  SegmentList::iterator Erase(SegmentList::iterator it);
+
+  // Reset segmentation i.e.
clear all values + void Clear(); + + // Read segmentation object from input stream + void Read(std::istream &is, bool binary); + + // Write segmentation object to output stream + void Write(std::ostream &os, bool binary) const; + + // Check if all segments have class_id >=0 and if dim_ matches the number of + // segments. + void Check() const; + + // Sort the segments on the start_frame + void Sort(); + + // Sort the segments on the length + void SortByLength(); + + // Returns an iterator to the smallest segment akin to std::min_element + SegmentList::iterator MinElement(); + + // Returns an iterator to the largest segment akin to std::max_element + SegmentList::iterator MaxElement(); + + // Generate a random segmentation for debugging purposes. + // Arguments: + // max_length: The maximum length of the random segmentation to be + // generated. + // max_segment_length: Maximum length of a segment in the segmentation + // num_classes: Maximum number of classes in the generated segmentation + void GenRandomSegmentation(int32 max_length, int32 max_segment_length, + int32 num_classes); + + // Public accessors + inline int32 Dim() const { return dim_; } + SegmentList::iterator Begin() { return segments_.begin(); } + SegmentList::const_iterator Begin() const { return segments_.begin(); } + SegmentList::iterator End() { return segments_.end(); } + SegmentList::const_iterator End() const { return segments_.end(); } + + Segment& Back() { return segments_.back(); } + const Segment& Back() const { return segments_.back(); } + + const SegmentList* Data() const { return &segments_; } + + // Default constructor + Segmentation(); + + private: + // number of segments in the segmentation + int32 dim_; + + // list of segments in the segmentation + SegmentList segments_; + + friend class SegmentationPostProcessor; +}; + +typedef TableWriter > SegmentationWriter; +typedef SequentialTableReader > + SequentialSegmentationReader; +typedef RandomAccessTableReader > + RandomAccessSegmentationReader; +typedef RandomAccessTableReaderMapped > + RandomAccessSegmentationReaderMapped; + +} // end namespace segmenter +} // end namespace kaldi + +#endif // KALDI_SEGMENTER_SEGMENTATION_H_ diff --git a/src/segmenterbin/Makefile b/src/segmenterbin/Makefile new file mode 100644 index 00000000000..1f0efe71181 --- /dev/null +++ b/src/segmenterbin/Makefile @@ -0,0 +1,36 @@ + +all: + +EXTRA_CXXFLAGS = -Wno-sign-compare +include ../kaldi.mk + +BINFILES = segmentation-copy segmentation-get-stats \ + segmentation-init-from-ali segmentation-to-ali \ + segmentation-init-from-segments segmentation-to-segments \ + segmentation-combine-segments segmentation-merge-recordings \ + segmentation-create-subsegments segmentation-intersect-ali \ + segmentation-to-rttm segmentation-post-process \ + segmentation-merge segmentation-split-segments \ + segmentation-remove-segments \ + segmentation-init-from-lengths \ + segmentation-combine-segments-to-recordings \ + segmentation-create-overlapped-subsegments \ + segmentation-intersect-segments \ + segmentation-init-from-additive-signals-info #\ + gmm-acc-pdf-stats-segmentation \ + gmm-est-segmentation gmm-update-segmentation \ + segmentation-init-from-diarization \ + segmentation-compute-class-ctm-conf \ + combine-vector-segments + +OBJFILES = + + + +TESTFILES = + +ADDLIBS = ../hmm/kaldi-hmm.a ../gmm/kaldi-gmm.a ../segmenter/kaldi-segmenter.a ../tree/kaldi-tree.a \ + ../util/kaldi-util.a ../matrix/kaldi-matrix.a ../base/kaldi-base.a ../thread/kaldi-thread.a + +include ../makefiles/default_rules.mk + 
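# ---------------------------------------------------------------------------
# Illustrative usage sketch (editorial note, not part of this patch),
# assembled from the usage strings of the binaries added in this directory;
# the exact binary names and options may differ in the final tools:
#
#   segmentation-init-from-segments --shift-to-zero=false data/dev/segments ark:- | \
#     segmentation-combine-segments-to-recordings ark:- ark,t:data/dev/reco2utt ark:reco.seg
# ---------------------------------------------------------------------------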
diff --git a/src/segmenterbin/segmentation-combine-segments-to-recordings.cc b/src/segmenterbin/segmentation-combine-segments-to-recordings.cc new file mode 100644 index 00000000000..acf71265577 --- /dev/null +++ b/src/segmenterbin/segmentation-combine-segments-to-recordings.cc @@ -0,0 +1,114 @@ +// segmenterbin/segmentation-combine-segments-to-recordings.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Combine kaldi segments in segmentation format to " + "recording-level segmentation\n" + "A reco2utt file is used to specify which utterances are contained " + "in a recording.\n" + "This program expects the input segmentation to be a kaldi segment " + "converted to segmentation using segmentation-init-from-segments. " + "For other segmentations, the user can use the binary " + "segmentation-combine-segments instead.\n" + "\n" + "Usage: segmentation-combine-segments-to-recording [options] " + " " + "\n" + " e.g.: segmentation-combine-segments-to-recording \\\n" + "'ark:segmentation-init-from-segments --shift-to-zero=false " + "data/dev/segments ark:- |' ark,t:data/dev/reco2utt ark:file.seg\n" + "See also: segmentation-combine-segments, " + "segmentation-merge, segmentation-merge-recordings, " + "segmentation-post-process --merge-adjacent-segments\n"; + + ParseOptions po(usage); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_rspecifier = po.GetArg(1), + reco2utt_rspecifier = po.GetArg(2), + segmentation_wspecifier = po.GetArg(3); + + SequentialTokenVectorReader reco2utt_reader(reco2utt_rspecifier); + RandomAccessSegmentationReader segmentation_reader( + segmentation_rspecifier); + SegmentationWriter segmentation_writer(segmentation_wspecifier); + + int32 num_done = 0, num_segmentations = 0, num_err = 0; + + for (; !reco2utt_reader.Done(); reco2utt_reader.Next()) { + const std::vector &utts = reco2utt_reader.Value(); + const std::string &reco_id = reco2utt_reader.Key(); + + Segmentation out_segmentation; + + for (std::vector::const_iterator it = utts.begin(); + it != utts.end(); ++it) { + if (!segmentation_reader.HasKey(*it)) { + KALDI_WARN << "Could not find utterance " << *it << " in " + << "segments segmentation " + << segmentation_rspecifier; + num_err++; + continue; + } + + const Segmentation &segmentation = segmentation_reader.Value(*it); + if (segmentation.Dim() != 1) { + KALDI_ERR << "Segments segmentation for utt " << *it << " is not " + << "kaldi segment converted to segmentation format " + << "in " << segmentation_rspecifier; + } + const Segment &segment = *(segmentation.Begin()); 
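+        // Each utterance's "segments segmentation" holds exactly one segment
+        // whose start and end frames give the utterance's position within the
+        // recording; collecting these segments (and sorting below) produces a
+        // recording-level segmentation.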
+ + out_segmentation.PushBack(segment); + + num_done++; + } + + Sort(&out_segmentation); + segmentation_writer.Write(reco_id, out_segmentation); + num_segmentations++; + } + + KALDI_LOG << "Combined " << num_done << " utterance-level segments " + << "into " << num_segmentations + << " recording-level segmentations; failed with " + << num_err << " utterances."; + + return ((num_done > 0 && num_err < num_done) ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-combine-segments.cc b/src/segmenterbin/segmentation-combine-segments.cc new file mode 100644 index 00000000000..7034a8a1734 --- /dev/null +++ b/src/segmenterbin/segmentation-combine-segments.cc @@ -0,0 +1,128 @@ +// segmenterbin/segmentation-combine-segments.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Combine utterance-level segmentations in an archive to " + "recording-level segmentations using the kaldi segments to map " + "utterances to their positions in the recordings.\n" + "A reco2utt file is used to specify which utterances belong to each " + "recording.\n" + "\n" + "Usage: segmentation-combine-segments [options] " + " " + " " + " \n" + " e.g.: segmentation-combine-segments ark:utt.seg " + "'ark:segmentation-init-from-segments --shift-to-zero=false " + "data/dev/segments ark:- |' ark,t:data/dev/reco2utt ark:file.seg\n" + "See also: segmentation-combine-segments-to-recording, " + "segmentation-merge, segmentatin-merge-recordings, " + "segmentation-post-process --merge-adjacent-segments\n"; + + ParseOptions po(usage); + + po.Read(argc, argv); + + if (po.NumArgs() != 4) { + po.PrintUsage(); + exit(1); + } + + std::string utt_segmentation_rspecifier = po.GetArg(1), + segments_segmentation_rspecifier = po.GetArg(2), + reco2utt_rspecifier = po.GetArg(3), + segmentation_wspecifier = po.GetArg(4); + + SequentialTokenVectorReader reco2utt_reader(reco2utt_rspecifier); + RandomAccessSegmentationReader segments_segmentation_reader( + segments_segmentation_rspecifier); + RandomAccessSegmentationReader utt_segmentation_reader( + utt_segmentation_rspecifier); + SegmentationWriter segmentation_writer(segmentation_wspecifier); + + int32 num_done = 0, num_segmentations = 0, num_err = 0; + int64 num_segments = 0; + + for (; !reco2utt_reader.Done(); reco2utt_reader.Next()) { + const std::vector &utts = reco2utt_reader.Value(); + const std::string &reco_id = reco2utt_reader.Key(); + + Segmentation out_segmentation; + + for (std::vector::const_iterator it = utts.begin(); + it != utts.end(); ++it) { + if 
(!segments_segmentation_reader.HasKey(*it)) { + KALDI_WARN << "Could not find utterance " << *it << " in " + << "segments segmentation " + << segments_segmentation_rspecifier; + num_err++; + continue; + } + + const Segmentation &segments_segmentation = + segments_segmentation_reader.Value(*it); + if (segments_segmentation.Dim() != 1) { + KALDI_ERR << "Segments segmentation for utt " << *it << " is not " + << "kaldi segment converted to segmentation format " + << "in " << segments_segmentation_rspecifier; + } + const Segment &segment = *(segments_segmentation.Begin()); + + if (!utt_segmentation_reader.HasKey(*it)) { + KALDI_WARN << "Could not find utterance " << *it << " in " + << "segmentation " << utt_segmentation_rspecifier; + num_err++; + continue; + } + const Segmentation &utt_segmentation + = utt_segmentation_reader.Value(*it); + + num_segments += InsertFromSegmentation(utt_segmentation, + segment.start_frame, false, + &out_segmentation, NULL); + num_done++; + } + + Sort(&out_segmentation); + segmentation_writer.Write(reco_id, out_segmentation); + num_segmentations++; + } + + KALDI_LOG << "Combined " << num_done << " utterance-level segmentations " + << "into " << num_segmentations + << " recording-level segmentations; failed with " + << num_err << " utterances; " + << "wrote a total of " << num_segments << " segments."; + + return ((num_done > 0 && num_err < num_done) ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-copy.cc b/src/segmenterbin/segmentation-copy.cc new file mode 100644 index 00000000000..26d0f47682d --- /dev/null +++ b/src/segmenterbin/segmentation-copy.cc @@ -0,0 +1,232 @@ +// segmenterbin/segmentation-copy.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
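+
+// Overview: this binary copies segmentations between single files and
+// archives, optionally relabeling them (--label-map, --utt2label-rspecifier,
+// --keep-label), rescaling the frame rate (--frame-subsampling-factor) and
+// selecting utterances (--include / --exclude).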
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Copy segmentation or archives of segmentation.\n" + "If label-map is supplied, then apply the mapping to the labels \n" + "when copying.\n" + "If utt2label-rspecifier is supplied, then ignore the \n" + "original labels, and map all the segments of an utterance using \n" + "the supplied utt2label map.\n" + "\n" + "Usage: segmentation-copy [options] " + "\n" + " e.g.: segmentation-copy ark:1.seg ark,t:-\n" + " or \n" + " segmentation-copy [options] " + "\n" + " e.g.: segmentation-copy --binary=false foo -\n"; + + bool binary = true; + std::string label_map_rxfilename, utt2label_rspecifier; + std::string include_rxfilename, exclude_rxfilename; + int32 keep_label = -1; + BaseFloat frame_subsampling_factor = 1; + + ParseOptions po(usage); + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + po.Register("label-map", &label_map_rxfilename, + "File with mapping from old to new labels"); + po.Register("frame-subsampling-factor", &frame_subsampling_factor, + "Change frame rate by this factor"); + po.Register("utt2label-rspecifier", &utt2label_rspecifier, + "Mapping for each utterance to an integer label"); + po.Register("keep-label", &keep_label, + "If supplied, only segments of this label are written out"); + po.Register("include", &include_rxfilename, + "Text file, the first field of each" + " line being interpreted as an " + "utterance-id whose features will be included"); + po.Register("exclude", &exclude_rxfilename, + "Text file, the first field of each " + "line being interpreted as an utterance-id" + " whose features will be excluded"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + // all these "fn"'s are either rspecifiers or filenames. 
+ + std::string segmentation_in_fn = po.GetArg(1), + segmentation_out_fn = po.GetArg(2); + + // Read mapping from old to new labels + unordered_map label_map; + if (!label_map_rxfilename.empty()) { + Input ki(label_map_rxfilename); + std::string line; + while (std::getline(ki.Stream(), line)) { + std::vector splits; + SplitStringToVector(line, " ", true, &splits); + + if (splits.size() != 2) + KALDI_ERR << "Invalid format of line " << line + << " in " << label_map_rxfilename; + + label_map[std::atoi(splits[0].c_str())] = std::atoi(splits[1].c_str()); + } + } + + unordered_set include_set; + if (include_rxfilename != "") { + if (exclude_rxfilename != "") { + KALDI_ERR << "should not have both --exclude and --include option!"; + } + Input ki(include_rxfilename); + std::string line; + while (std::getline(ki.Stream(), line)) { + std::vector split_line; + SplitStringToVector(line, " \t\r", true, &split_line); + KALDI_ASSERT(!split_line.empty() && + "Empty line encountered in input from --include option"); + include_set.insert(split_line[0]); + } + } + + unordered_set exclude_set; + if (exclude_rxfilename != "") { + if (include_rxfilename != "") { + KALDI_ERR << "should not have both --exclude and --include option!"; + } + Input ki(exclude_rxfilename); + std::string line; + while (std::getline(ki.Stream(), line)) { + std::vector split_line; + SplitStringToVector(line, " \t\r", true, &split_line); + KALDI_ASSERT(!split_line.empty() && + "Empty line encountered in input from --exclude option"); + exclude_set.insert(split_line[0]); + } + } + + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + Segmentation segmentation; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + segmentation.Read(ki.Stream(), binary_in); + } + + if (!label_map_rxfilename.empty()) + RelabelSegmentsUsingMap(label_map, &segmentation); + + if (keep_label != -1) + KeepSegments(keep_label, &segmentation); + + if (frame_subsampling_factor != 1.0) { + ScaleFrameShift(frame_subsampling_factor, &segmentation); + } + + if (!utt2label_rspecifier.empty()) + KALDI_ERR << "It makes no sense to specify utt2label-rspecifier " + << "when not reading segmentation archives."; + + Output ko(segmentation_out_fn, binary); + segmentation.Write(ko.Stream(), binary); + + KALDI_LOG << "Copied segmentation to " << segmentation_out_fn; + return 0; + } else { + RandomAccessInt32Reader utt2label_reader(utt2label_rspecifier); + + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + + for (; !reader.Done(); reader.Next()) { + const std::string &key = reader.Key(); + + if (include_rxfilename != "" && include_set.count(key) == 0) { + continue; + } + + if (exclude_rxfilename != "" && include_set.count(key) > 0) { + continue; + } + + if (label_map_rxfilename.empty() && + frame_subsampling_factor == 1.0 && + utt2label_rspecifier.empty() && + keep_label == -1) { + writer.Write(key, reader.Value()); + } else { + Segmentation segmentation = reader.Value(); + if (!label_map_rxfilename.empty()) + RelabelSegmentsUsingMap(label_map, &segmentation); + if (!utt2label_rspecifier.empty()) { + if (!utt2label_reader.HasKey(key)) { + KALDI_WARN << "Utterance " << key + << " not found in 
utt2label map " + << utt2label_rspecifier; + num_err++; + continue; + } + + RelabelAllSegments(utt2label_reader.Value(key), &segmentation); + } + if (keep_label != -1) + KeepSegments(keep_label, &segmentation); + + if (frame_subsampling_factor != 1.0) + ScaleFrameShift(frame_subsampling_factor, &segmentation); + + writer.Write(key, segmentation); + } + + num_done++; + } + + KALDI_LOG << "Copied " << num_done << " segmentation; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-create-subsegments.cc b/src/segmenterbin/segmentation-create-subsegments.cc new file mode 100644 index 00000000000..9d7f4c08b6d --- /dev/null +++ b/src/segmenterbin/segmentation-create-subsegments.cc @@ -0,0 +1,175 @@ +// segmenterbin/segmentation-create-subsegments.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Create sub-segmentation of a segmentation by intersecting with " + "segments from a 'filter' segmentation. 
\n" + "The labels for the new subsegments are decided " + "depending on whether the label of 'filter' segment " + "matches the specified 'filter_label' or not:\n" + " if filter segment's label == filter_label: \n" + " if subsegment_label is specified:\n" + " label = subsegment_label\n" + " else: \n" + " label = filter_label \n" + " else: \n" + " if unmatched_label is specified:\n" + " label = unmatched_label\n" + " else\n:" + " label = primary_label\n" + "See the function SubSegmentUsingNonOverlappingSegments() " + "for more details.\n" + "\n" + "Usage: segmentation-create-subsegments [options] " + " " + " \n" + " or : segmentation-create-subsegments [options] " + " " + " \n" + " e.g.: segmentation-create-subsegments --binary=false " + "--filter-label=1 --subsegment-label=1000 foo bar -\n" + " segmentation-create-subsegments --filter-label=1 " + "--subsegment-label=1000 ark:1.foo ark:1.bar ark:-\n"; + + bool binary = true, ignore_missing = false; + int32 filter_label = -1, subsegment_label = -1, unmatched_label = -1; + ParseOptions po(usage); + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + po.Register("filter-label", &filter_label, + "The label on which filtering is done."); + po.Register("subsegment-label", &subsegment_label, + "If non-negative, change the class-id of the matched regions " + "in the intersection of the two segmentations to this label."); + po.Register("unmatched-label", &unmatched_label, + "If non-negative, change the class-id of the unmatched " + "regions in the intersection of the two segmentations " + "to this label."); + po.Register("ignore-missing", &ignore_missing, "Ignore missing " + "segmentations in filter. If this is set true, then the " + "segmentations with missing key in filter are written " + "without any modification."); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_in_fn = po.GetArg(1), + secondary_segmentation_in_fn = po.GetArg(2), + segmentation_out_fn = po.GetArg(3); + + // all these "fn"'s are either rspecifiers or filenames. 
+ + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + filter_is_rspecifier = + (ClassifyRspecifier(secondary_segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier || + in_is_rspecifier != filter_is_rspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + Segmentation segmentation; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + segmentation.Read(ki.Stream(), binary_in); + } + Segmentation secondary_segmentation; + { + bool binary_in; + Input ki(secondary_segmentation_in_fn, &binary_in); + secondary_segmentation.Read(ki.Stream(), binary_in); + } + + Segmentation new_segmentation; + SubSegmentUsingNonOverlappingSegments( + segmentation, secondary_segmentation, filter_label, subsegment_label, + unmatched_label, &new_segmentation); + Output ko(segmentation_out_fn, binary); + new_segmentation.Write(ko.Stream(), binary); + + KALDI_LOG << "Created subsegments of " << segmentation_in_fn + << " based on " << secondary_segmentation_in_fn + << " and wrote to " << segmentation_out_fn; + return 0; + } else { + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + RandomAccessSegmentationReader filter_reader( + secondary_segmentation_in_fn); + + for (; !reader.Done(); reader.Next(), num_done++) { + const Segmentation &segmentation = reader.Value(); + const std::string &key = reader.Key(); + + if (!filter_reader.HasKey(key)) { + KALDI_WARN << "Could not find filter segmentation for utterance " + << key; + if (!ignore_missing) + num_err++; + else + writer.Write(key, segmentation); + continue; + } + const Segmentation &secondary_segmentation = filter_reader.Value(key); + + Segmentation new_segmentation; + SubSegmentUsingNonOverlappingSegments(segmentation, + secondary_segmentation, + filter_label, subsegment_label, + unmatched_label, + &new_segmentation); + + writer.Write(key, new_segmentation); + } + + KALDI_LOG << "Created subsegments for " << num_done << " segmentations; " + << "failed with " << num_err << " segmentations"; + + return ((num_done != 0 && num_err < num_done) ? 0 : 1); + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-get-stats.cc b/src/segmenterbin/segmentation-get-stats.cc new file mode 100644 index 00000000000..b25d6913f06 --- /dev/null +++ b/src/segmenterbin/segmentation-get-stats.cc @@ -0,0 +1,125 @@ +// segmenterbin/segmentation-get-per-frame-stats.cc + +// Copyright 2016 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
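+
+// For illustration: a frame covered by three segments with class-ids
+// {1, 1, 2} gets num-overlaps = 3 (total number of segments covering the
+// frame) and num-classes = 2 (number of distinct class-ids at the frame).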
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Get per-frame stats from segmentation. \n" + "Currently supported stats are \n" + " num-overlaps: Number of overlapping segments common to this frame\n" + " num-classes: Number of distinct classes common to this frame\n" + "\n" + "Usage: segmentation-get-stats [options] " + " \n" + " e.g.: segmentation-get-stats ark:1.seg ark:/dev/null " + "ark:num_classes.ark\n"; + + ParseOptions po(usage); + + std::string lengths_rspecifier; + int32 length_tolerance = 2; + + po.Register("lengths-rspecifier", &lengths_rspecifier, + "Archive of frame lengths of the utterances. " + "Fills up any extra length with zero stats."); + po.Register("length-tolerance", &length_tolerance, + "Tolerate shortage of this many frames in the specified " + "lengths file"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_rspecifier = po.GetArg(1), + num_overlaps_wspecifier = po.GetArg(2), + num_classes_wspecifier = po.GetArg(3); + + int64 num_done = 0, num_err = 0; + + SequentialSegmentationReader reader(segmentation_rspecifier); + Int32VectorWriter num_overlaps_writer(num_overlaps_wspecifier); + Int32VectorWriter num_classes_writer(num_classes_wspecifier); + + RandomAccessInt32Reader lengths_reader(lengths_rspecifier); + + for (; !reader.Done(); reader.Next(), num_done++) { + const Segmentation &segmentation = reader.Value(); + const std::string &key = reader.Key(); + + int32 length = -1; + if (!lengths_rspecifier.empty()) { + if (!lengths_reader.HasKey(key)) { + KALDI_WARN << "Could not find length for key " << key; + num_err++; + continue; + } + length = lengths_reader.Value(key); + } + + std::vector > class_counts_per_frame; + if (!GetClassCountsPerFrame(segmentation, length, + length_tolerance, + &class_counts_per_frame)) { + KALDI_WARN << "Failed getting stats for key " << key; + num_err++; + continue; + } + + if (length == -1) + length = class_counts_per_frame.size(); + + std::vector num_classes_per_frame(length, 0); + std::vector num_overlaps_per_frame(length, 0); + + for (int32 i = 0; i < class_counts_per_frame.size(); i++) { + std::map &class_counts = class_counts_per_frame[i]; + + for (std::map::const_iterator it = class_counts.begin(); + it != class_counts.end(); ++it) { + if (it->second > 0) + num_classes_per_frame[i]++; + num_overlaps_per_frame[i] += it->second; + } + } + + num_classes_writer.Write(key, num_classes_per_frame); + num_overlaps_writer.Write(key, num_overlaps_per_frame); + + num_done++; + } + + KALDI_LOG << "Got stats for " << num_done << " segmentations; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-init-from-ali.cc b/src/segmenterbin/segmentation-init-from-ali.cc new file mode 100644 index 00000000000..a98a54368c9 --- /dev/null +++ b/src/segmenterbin/segmentation-init-from-ali.cc @@ -0,0 +1,91 @@ +// segmenterbin/segmentation-init-from-ali.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Initialize utterance-level segmentations from alignments file. \n" + "The user can pass this to segmentation-combine-segments to " + "create recording-level segmentations." + "\n" + "Usage: segmentation-init-from-ali [options] " + " \n" + " e.g.: segmentation-init-from-ali ark:1.ali ark:-\n" + "See also: segmentation-init-from-segments, " + "segmentation-combine-segments\n"; + + ParseOptions po(usage); + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string ali_rspecifier = po.GetArg(1), + segmentation_wspecifier = po.GetArg(2); + + SegmentationWriter segmentation_writer(segmentation_wspecifier); + + int32 num_done = 0, num_segmentations = 0; + int64 num_segments = 0; + int64 num_err = 0; + + std::vector frame_counts_per_class; + + SequentialInt32VectorReader alignment_reader(ali_rspecifier); + + for (; !alignment_reader.Done(); alignment_reader.Next()) { + const std::string &key = alignment_reader.Key(); + const std::vector &alignment = alignment_reader.Value(); + + Segmentation segmentation; + + num_segments += InsertFromAlignment(alignment, 0, alignment.size(), + 0, &segmentation, + &frame_counts_per_class); + + Sort(&segmentation); + segmentation_writer.Write(key, segmentation); + + num_done++; + num_segmentations++; + } + + KALDI_LOG << "Processed " << num_done << " utterances; failed with " + << num_err << " utterances; " + << "wrote " << num_segmentations << " segmentations " + << "with a total of " << num_segments << " segments."; + KALDI_LOG << "Number of frames for the different classes are : "; + WriteIntegerVector(KALDI_LOG, false, frame_counts_per_class); + + return ((num_done > 0 && num_err < num_done) ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-init-from-lengths.cc b/src/segmenterbin/segmentation-init-from-lengths.cc new file mode 100644 index 00000000000..28c998c220b --- /dev/null +++ b/src/segmenterbin/segmentation-init-from-lengths.cc @@ -0,0 +1,82 @@ +// segmenterbin/segmentation-init-from-lengths.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Initialize segmentations from frame lengths file\n" + "\n" + "Usage: segmentation-init-from-lengths [options] " + " \n" + " e.g.: segmentation-init-from-lengths " + "\"ark:feat-to-len scp:feats.scp ark:- |\" ark:-\n" + "\n" + "See also: segmentation-init-from-ali, " + "segmentation-init-from-segments\n"; + + int32 label = 1; + + ParseOptions po(usage); + + po.Register("label", &label, "Label to assign to the created segments"); + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string lengths_rspecifier = po.GetArg(1), + segmentation_wspecifier = po.GetArg(2); + + SequentialInt32Reader lengths_reader(lengths_rspecifier); + SegmentationWriter segmentation_writer(segmentation_wspecifier); + + int32 num_done = 0; + + for (; !lengths_reader.Done(); lengths_reader.Next()) { + const std::string &key = lengths_reader.Key(); + const int32 &length = lengths_reader.Value(); + + Segmentation segmentation; + + if (length > 0) { + segmentation.EmplaceBack(0, length - 1, label); + } + + segmentation_writer.Write(key, segmentation); + num_done++; + } + + KALDI_LOG << "Created " << num_done << " segmentations."; + + return (num_done > 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-init-from-segments.cc b/src/segmenterbin/segmentation-init-from-segments.cc new file mode 100644 index 00000000000..c39996b5ef4 --- /dev/null +++ b/src/segmenterbin/segmentation-init-from-segments.cc @@ -0,0 +1,179 @@ +// segmenterbin/segmentation-init-from-segments.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
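+
+// The frame indices in the examples below follow from the --frame-shift
+// (default 0.01 s) and --frame-overlap (default 0.015 s) options:
+//   num_frames = round((end - frame_overlap) / frame_shift)
+//                - round(start / frame_shift)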
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation.h" + +// If segments file contains +// Alpha-001 Alpha 0.00 0.16 +// Alpha-002 Alpha 1.50 4.10 +// Beta-001 Beta 0.50 2.66 +// Beta-002 Beta 3.50 5.20 +// the output segmentation will contain +// Alpha-001 [ 0 16 1 ] +// Alpha-002 [ 0 360 1 ] +// Beta-001 [ 0 216 1 ] +// Beta-002 [ 0 170 1 ] +// If --shift-to-zero=false is provided, then the output will contain +// Alpha-001 [ 0 16 1 ] +// Alpha-002 [ 150 410 1 ] +// Beta-001 [ 50 266 1 ] +// Beta-002 [ 350 520 1 ] +// +// If the following utt2label-rspecifier was provided: +// Alpha-001 2 +// Alpha-002 2 +// Beta-001 4 +// Beta-002 4 +// then the output segmentation will contain +// Alpha-001 [ 0 16 2 ] +// Alpha-002 [ 0 360 2 ] +// Beta-001 [ 0 216 4 ] +// Beta-002 [ 0 170 4 ] + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Convert segments from segments file into utterance-level " + "segmentation format. \n" + "The user can convert the segmenation to recording-level using " + "the binary segmentation-combine-segments-to-recording.\n" + "\n" + "Usage: segmentation-init-from-segments [options] " + " \n" + " e.g.: segmentation-init-from-segments segments ark:-\n"; + + int32 segment_label = 1; + BaseFloat frame_shift = 0.01, frame_overlap = 0.015; + std::string utt2label_rspecifier; + bool shift_to_zero = true; + + ParseOptions po(usage); + + po.Register("segment-label", &segment_label, + "Label for all the segments in the segmentations"); + po.Register("utt2label-rspecifier", &utt2label_rspecifier, + "Mapping for each utterance to an integer label. " + "If supplied, these labels will be used as the segment " + "labels"); + po.Register("shift-to-zero", &shift_to_zero, + "Shift all segments to 0th frame"); + po.Register("frame-shift", &frame_shift, "Frame shift in seconds"); + po.Register("frame-overlap", &frame_overlap, "Frame overlap in seconds"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string segments_rxfilename = po.GetArg(1), + segmentation_wspecifier = po.GetArg(2); + + SegmentationWriter writer(segmentation_wspecifier); + RandomAccessInt32Reader utt2label_reader(utt2label_rspecifier); + + Input ki(segments_rxfilename); + + int64 num_lines = 0, num_done = 0; + + std::string line; + + while (std::getline(ki.Stream(), line)) { + num_lines++; + + std::vector split_line; + // Split the line by space or tab and check the number of fields in each + // line. There must be 4 fields--segment name , reacording wav file name, + // start time, end time; 5th field (channel info) is optional. + SplitStringToVector(line, " \t\r", true, &split_line); + if (split_line.size() != 4 && split_line.size() != 5) { + KALDI_WARN << "Invalid line in segments file: " << line; + continue; + } + std::string utt = split_line[0], + reco = split_line[1], + start_str = split_line[2], + end_str = split_line[3]; + + // Convert the start time and endtime to real from string. Segment is + // ignored if start or end time cannot be converted to real. 
+ double start, end; + if (!ConvertStringToReal(start_str, &start)) { + KALDI_WARN << "Invalid line in segments file [bad start]: " << line; + continue; + } + if (!ConvertStringToReal(end_str, &end)) { + KALDI_WARN << "Invalid line in segments file [bad end]: " << line; + continue; + } + + // start time must not be negative; start time must not be greater than + // end time, except if end time is -1 + if (start < 0 || (end != -1.0 && end <= 0) || + ((start >= end) && (end > 0))) { + KALDI_WARN << "Invalid line in segments file " + << "[empty or invalid segment]: " << line; + continue; + } + + if (split_line.size() >= 5) + KALDI_ERR << "Not supporting channel in segments file"; + + Segmentation segmentation; + + if (!utt2label_rspecifier.empty()) { + if (!utt2label_reader.HasKey(utt)) { + KALDI_WARN << "Could not find utterance " << utt << " in " + << utt2label_rspecifier; + continue; + } + + segment_label = utt2label_reader.Value(utt); + } + + int32 length = round((end - frame_overlap)/ frame_shift) + - round(start / frame_shift); + + if (shift_to_zero) + segmentation.EmplaceBack(0, length, segment_label); + else + segmentation.EmplaceBack(round(start / frame_shift), + round((end-frame_overlap) / frame_shift) - 1, + segment_label); + + writer.Write(utt, segmentation); + num_done++; + } + + KALDI_LOG << "Successfully processed " << num_done << " lines out of " + << num_lines << " in the segments file"; + + return (num_done > num_lines / 2 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-intersect-ali.cc b/src/segmenterbin/segmentation-intersect-ali.cc new file mode 100644 index 00000000000..a551eee02ce --- /dev/null +++ b/src/segmenterbin/segmentation-intersect-ali.cc @@ -0,0 +1,99 @@ +// segmenterbin/segmentation-intersect-ali.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Intersect (like sets) segmentation with an alignment and retain \n" + "only segments where the alignment is the specified label. 
\n" + "\n" + "Usage: segmentation-intersect-alignment [options] " + " " + "\n" + " e.g.: segmentation-intersect-alignment --binary=false ark:foo.seg " + "ark:filter.ali ark,t:-\n" + "See also: segmentation-combine-segments, " + "segmentation-intersect-segments, segmentation-create-subsegments\n"; + + ParseOptions po(usage); + + int32 ali_label = 0, min_alignment_chunk_length = 0; + + po.Register("ali-label", &ali_label, + "Intersect only at this label of alignments"); + po.Register("min-alignment-chunk-length", &min_alignment_chunk_length, + "The minimmum number of consecutive frames of ali_label in " + "alignment at which the segments can be intersected."); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_rspecifier = po.GetArg(1), + ali_rspecifier = po.GetArg(2), + segmentation_wspecifier = po.GetArg(3); + + int32 num_done = 0, num_err = 0; + + SegmentationWriter writer(segmentation_wspecifier); + SequentialSegmentationReader segmentation_reader(segmentation_rspecifier); + RandomAccessInt32VectorReader alignment_reader(ali_rspecifier); + + for (; !segmentation_reader.Done(); segmentation_reader.Next()) { + const Segmentation &segmentation = segmentation_reader.Value(); + const std::string &key = segmentation_reader.Key(); + + if (!alignment_reader.HasKey(key)) { + KALDI_WARN << "Could not find segmentation for key " << key + << " in " << ali_rspecifier; + num_err++; + continue; + } + const std::vector &ali = alignment_reader.Value(key); + + Segmentation out_segmentation; + IntersectSegmentationAndAlignment(segmentation, ali, ali_label, + min_alignment_chunk_length, + &out_segmentation); + out_segmentation.Sort(); + + writer.Write(key, out_segmentation); + num_done++; + } + + KALDI_LOG << "Intersected " << num_done + << " segmentations with alignments; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-intersect-segments.cc b/src/segmenterbin/segmentation-intersect-segments.cc new file mode 100644 index 00000000000..1c9861ba453 --- /dev/null +++ b/src/segmenterbin/segmentation-intersect-segments.cc @@ -0,0 +1,145 @@ +// segmenterbin/segmentation-intersect-segments.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +namespace kaldi { +namespace segmenter { + +void IntersectSegmentationsNonOverlapping( + const Segmentation &in_segmentation, + const Segmentation &secondary_segmentation, + int32 mismatch_label, + Segmentation *out_segmentation) { + KALDI_ASSERT(out_segmentation); + KALDI_ASSERT(secondary_segmentation.Dim() > 0); + + std::vector alignment; + ConvertToAlignment(secondary_segmentation, -1, -1, 0, &alignment); + + for (SegmentList::const_iterator it = in_segmentation.Begin(); + it != in_segmentation.End(); ++it) { + if (it->end_frame >= alignment.size()) { + alignment.resize(it->end_frame + 1, -1); + } + Segmentation filter_segmentation; + InsertFromAlignment(alignment, it->start_frame, it->end_frame + 1, + 0, &filter_segmentation, NULL); + + for (SegmentList::const_iterator f_it = filter_segmentation.Begin(); + f_it != filter_segmentation.End(); ++f_it) { + int32 label = it->Label(); + if (f_it->Label() != it->Label()) { + if (mismatch_label == -1) continue; + label = mismatch_label; + } + + out_segmentation->EmplaceBack(f_it->start_frame, f_it->end_frame, + label); + } + } +} + +} +} + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Intersect segments from two archives by retaining only regions .\n" + "where the primary and secondary segments match on label\n" + "\n" + "Usage: segmentation-intersect-segments [options] " + " " + "\n" + " e.g.: segmentation-intersect-segments ark:foo.seg ark:bar.seg " + "ark,t:-\n" + "See also: segmentation-create-subsegments, " + "segmentation-intersect-ali\n"; + + int32 mismatch_label = -1; + bool assume_non_overlapping_secondary = true; + + ParseOptions po(usage); + + po.Register("mismatch-label", &mismatch_label, + "Intersect only where secondary segment has this label"); + po.Register("assume-non-overlapping-secondary", & + assume_non_overlapping_secondary, + "Assume secondary segments are non-overlapping"); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string primary_rspecifier = po.GetArg(1), + secondary_rspecifier = po.GetArg(2), + segmentation_writer = po.GetArg(3); + + if (!assume_non_overlapping_secondary) { + KALDI_ERR << "Secondary segment must be non-overlapping for now"; + } + + int64 num_done = 0, num_err = 0; + + SegmentationWriter writer(segmentation_writer); + SequentialSegmentationReader primary_reader(primary_rspecifier); + RandomAccessSegmentationReader secondary_reader(secondary_rspecifier); + + for (; !primary_reader.Done(); primary_reader.Next()) { + const Segmentation &segmentation = primary_reader.Value(); + const std::string &key = primary_reader.Key(); + + if (!secondary_reader.HasKey(key)) { + KALDI_WARN << "Could not find segmentation for key " << key + << " in " << secondary_rspecifier; + num_err++; + continue; + } + const Segmentation &secondary_segmentation = secondary_reader.Value(key); + + Segmentation out_segmentation; + IntersectSegmentationsNonOverlapping(segmentation, + secondary_segmentation, + mismatch_label, + &out_segmentation); + + Sort(&out_segmentation); + + writer.Write(key, out_segmentation); + num_done++; + } + + KALDI_LOG << "Intersected " << num_done << " segmentations; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 
0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-merge-recordings.cc b/src/segmenterbin/segmentation-merge-recordings.cc new file mode 100644 index 00000000000..85b5108be29 --- /dev/null +++ b/src/segmenterbin/segmentation-merge-recordings.cc @@ -0,0 +1,101 @@ +// segmenterbin/segmentation-merge-recordings.cc + +// Copyright 2016 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Merge segmentations of different recordings into one segmentation " + "using a mapping from new to old recording name\n" + "\n" + "Usage: segmentation-merge-recordings [options] " + " \n" + " e.g.: segmentation-merge-recordings ark:sdm2ihm_reco.map " + "ark:ihm_seg.ark ark:sdm_seg.ark\n"; + + ParseOptions po(usage); + + po.Read(argc, argv); + + if (po.NumArgs() != 3) { + po.PrintUsage(); + exit(1); + } + + std::string new2old_list_rspecifier = po.GetArg(1); + std::string segmentation_rspecifier = po.GetArg(2), + segmentation_wspecifier = po.GetArg(3); + + SequentialTokenVectorReader new2old_reader(new2old_list_rspecifier); + RandomAccessSegmentationReader segmentation_reader( + segmentation_rspecifier); + SegmentationWriter segmentation_writer(segmentation_wspecifier); + + int32 num_new_segmentations = 0, num_old_segmentations = 0; + int64 num_segments = 0, num_err = 0; + + for (; !new2old_reader.Done(); new2old_reader.Next()) { + const std::vector &old_key_list = new2old_reader.Value(); + const std::string &new_key = new2old_reader.Key(); + + KALDI_ASSERT(old_key_list.size() > 0); + + Segmentation segmentation; + + for (std::vector::const_iterator it = old_key_list.begin(); + it != old_key_list.end(); ++it) { + num_old_segmentations++; + + if (!segmentation_reader.HasKey(*it)) { + KALDI_WARN << "Could not find key " << *it << " in " + << "old segmentation " << segmentation_rspecifier; + num_err++; + continue; + } + + const Segmentation &this_segmentation = segmentation_reader.Value(*it); + + num_segments += InsertFromSegmentation(this_segmentation, 0, NULL, + &segmentation); + } + Sort(&segmentation); + + segmentation_writer.Write(new_key, segmentation); + + num_new_segmentations++; + } + + KALDI_LOG << "Merged " << num_old_segmentations << " old segmentations " + << "into " << num_new_segmentations << " new segmentations; " + << "created overall " << num_segments << " segments; " + << "failed to merge " << num_err << " old segmentations"; + + return (num_new_segmentations > 0 && num_err < num_old_segmentations / 2); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git 
a/src/segmenterbin/segmentation-merge.cc b/src/segmenterbin/segmentation-merge.cc new file mode 100644 index 00000000000..21e9a410e15 --- /dev/null +++ b/src/segmenterbin/segmentation-merge.cc @@ -0,0 +1,146 @@ +// segmenterbin/segmentation-merge.cc + +// Copyright 2015 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Merge corresponding segments from multiple archives or files.\n" + "i.e. for each utterance in the first segmentation, the segments " + "from all the supplied segmentations are merged and put in a single " + "segmentation." + "\n" + "Usage: segmentation-merge [options] " + " ... " + "\n" + " e.g.: segmentation-merge ark:foo.seg ark:bar.seg ark,t:-\n" + " or \n" + " segmentation-merge " + " ... " + "\n" + " e.g.: segmentation-merge --binary=false foo bar -\n" + "See also: segmentation-copy, segmentation-merge-recordings, " + "segmentation-post-process --merge-labels\n"; + + bool binary = true; + bool sort = true; + + ParseOptions po(usage); + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + po.Register("sort", &sort, "Sort the segements after merging"); + + po.Read(argc, argv); + + if (po.NumArgs() <= 2) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_in_fn = po.GetArg(1), + segmentation_out_fn = po.GetArg(po.NumArgs()); + + // all these "fn"'s are either rspecifiers or filenames. 
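+    // Arguments 1 .. N-1 are the input segmentations and the last argument
+    // is the output. In archive mode the first input is read sequentially
+    // and the remaining inputs are looked up by the first input's utterance
+    // keys (see below).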
+ bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + Segmentation segmentation; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + segmentation.Read(ki.Stream(), binary_in); + } + + for (int32 i = 2; i < po.NumArgs(); i++) { + bool binary_in; + Input ki(po.GetArg(i), &binary_in); + Segmentation other_segmentation; + other_segmentation.Read(ki.Stream(), binary_in); + ExtendSegmentation(other_segmentation, false, + &segmentation); + } + + Sort(&segmentation); + + Output ko(segmentation_out_fn, binary); + segmentation.Write(ko.Stream(), binary); + + KALDI_LOG << "Merged segmentations to " << segmentation_out_fn; + return 0; + } else { + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + std::vector other_readers( + po.NumArgs()-2, + static_cast(NULL)); + + for (size_t i = 0; i < po.NumArgs()-2; i++) { + other_readers[i] = new RandomAccessSegmentationReader(po.GetArg(i+2)); + } + + for (; !reader.Done(); reader.Next()) { + Segmentation segmentation(reader.Value()); + std::string key = reader.Key(); + + for (size_t i = 0; i < po.NumArgs()-2; i++) { + if (!other_readers[i]->HasKey(key)) { + KALDI_WARN << "Could not find segmentation for key " << key + << " in " << po.GetArg(i+2); + num_err++; + } + const Segmentation &other_segmentation = + other_readers[i]->Value(key); + ExtendSegmentation(other_segmentation, false, + &segmentation); + } + + Sort(&segmentation); + + writer.Write(key, segmentation); + num_done++; + } + + KALDI_LOG << "Merged " << num_done << " segmentation; failed with " + << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-post-process.cc b/src/segmenterbin/segmentation-post-process.cc new file mode 100644 index 00000000000..921ee5dc5d8 --- /dev/null +++ b/src/segmenterbin/segmentation-post-process.cc @@ -0,0 +1,142 @@ +// segmenterbin/segmentation-post-process.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
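+
+// An example invocation (hypothetical thresholds), padding speech segments
+// and then merging nearby segments of the same label, might look like:
+//   segmentation-post-process --pad-label=1 --pad-length=5 \
+//     --merge-adjacent-segments --max-intersegment-length=10 \
+//     ark:foo.seg ark,t:-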
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-post-processor.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Post processing of segmentation that does the following operations " + "in order: \n" + "1) Merge labels: Merge labels specified in --merge-labels into a " + "single label specified by --merge-dst-label. \n" + "2) Padding segments: Pad segments of label specified by --pad-label " + "by a few frames as specified by --pad-length. \n" + "3) Shrink segments: Shrink segments of label specified by " + "--shrink-label by a few frames as specified by --shrink-length. \n" + "4) Blend segments with neighbors: Blend short segments of class-id " + "specified by --blend-short-segments-class that are " + "shorter than --max-blend-length frames with their " + "respective neighbors if both the neighbors are within " + "a distance of --max-intersegment-length frames.\n" + "5) Remove segments: Remove segments of class-ids contained " + "in --remove-labels.\n" + "6) Merge adjacent segments: Merge adjacent segments of the same " + "label if they are within a distance of --max-intersegment-length " + "frames.\n" + "7) Split segments: Split segments that are longer than " + "--max-segment-length frames into overlapping segments " + "with an overlap of --overlap-length frames. \n" + "Usage: segmentation-post-process [options] " + "\n" + " or : segmentation-post-process [options] " + "\n" + " e.g.: segmentation-post-process --binary=false foo -\n" + " segmentation-post-process ark:foo.seg ark,t:-\n" + "See also: segmentation-merge, segmentation-copy, " + "segmentation-remove-segments\n"; + + bool binary = true; + + ParseOptions po(usage); + + SegmentationPostProcessingOptions opts; + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + + opts.Register(&po); + + po.Read(argc, argv); + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + SegmentationPostProcessor post_processor(opts); + + std::string segmentation_in_fn = po.GetArg(1), + segmentation_out_fn = po.GetArg(2); + + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + Segmentation segmentation; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + segmentation.Read(ki.Stream(), binary_in); + } + if (post_processor.PostProcess(&segmentation)) { + Output ko(segmentation_out_fn, binary); + Sort(&segmentation); + segmentation.Write(ko.Stream(), binary); + KALDI_LOG << "Post-processed segmentation " << segmentation_in_fn + << " and wrote " << segmentation_out_fn; + return 0; + } + KALDI_LOG << "Failed post-processing segmentation " + << segmentation_in_fn; + return 1; + } + + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + for (; !reader.Done(); reader.Next()) { + Segmentation segmentation(reader.Value()); + const std::string &key = reader.Key(); + + if (!post_processor.PostProcess(&segmentation)) { + num_err++; + continue; + } + + Sort(&segmentation); + + writer.Write(key, segmentation); + num_done++; + } + + KALDI_LOG << 
"Successfully post-processed " << num_done + << " segmentations; " + << "failed with " << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-remove-segments.cc b/src/segmenterbin/segmentation-remove-segments.cc new file mode 100644 index 00000000000..ce3ef2de6fd --- /dev/null +++ b/src/segmenterbin/segmentation-remove-segments.cc @@ -0,0 +1,155 @@ +// segmenterbin/segmentation-remove-segments.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Remove segments of particular class_id (e.g silence or noise) " + "or a set of class_ids.\n" + "The labels to removed can be made utterance-specific by passing " + "--remove-labels-rspecifier option.\n" + "\n" + "Usage: segmentation-remove-segments [options] " + " \n" + " or : segmentation-remove-segments [options] " + " \n" + "\n" + " e.g.: segmentation-remove-segments --remove-label=0 ark:foo.ark " + "ark:foo.speech.ark\n" + "See also: segmentation-post-process --remove-labels, " + "segmentation-post-process --max-blend-length, segmentation-copy\n"; + + bool binary = true; + + int32 remove_label = -1; + std::string remove_labels_rspecifier = ""; + + ParseOptions po(usage); + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + po.Register("remove-label", &remove_label, "Remove segments of this label"); + po.Register("remove-labels-rspecifier", &remove_labels_rspecifier, + "Specify colon separated list of labels for each key"); + + po.Read(argc, argv); + + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_in_fn = po.GetArg(1), + segmentation_out_fn = po.GetArg(2); + + // all these "fn"'s are either rspecifiers or filenames. 
+
+    bool in_is_rspecifier =
+        (ClassifyRspecifier(segmentation_in_fn, NULL, NULL)
+         != kNoRspecifier),
+        out_is_wspecifier =
+        (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL)
+         != kNoWspecifier);
+
+    if (in_is_rspecifier != out_is_wspecifier)
+      KALDI_ERR << "Cannot mix regular files and archives";
+
+    int64 num_done = 0, num_missing = 0;
+
+    if (!in_is_rspecifier) {
+      Segmentation segmentation;
+      {
+        bool binary_in;
+        Input ki(segmentation_in_fn, &binary_in);
+        segmentation.Read(ki.Stream(), binary_in);
+      }
+      if (!remove_labels_rspecifier.empty()) {
+        KALDI_ERR << "It does not make sense to specify "
+                  << "--remove-labels-rspecifier "
+                  << "for single segmentation";
+      }
+
+      RemoveSegments(remove_label, &segmentation);
+
+      {
+        Output ko(segmentation_out_fn, binary);
+        segmentation.Write(ko.Stream(), binary);
+      }
+
+      KALDI_LOG << "Removed segments and wrote segmentation to "
+                << segmentation_out_fn;
+
+      return 0;
+    } else {
+      SegmentationWriter writer(segmentation_out_fn);
+      SequentialSegmentationReader reader(segmentation_in_fn);
+
+      RandomAccessTokenReader remove_labels_reader(remove_labels_rspecifier);
+
+      for (; !reader.Done(); reader.Next(), num_done++) {
+        Segmentation segmentation(reader.Value());
+        std::string key = reader.Key();
+
+        if (!remove_labels_rspecifier.empty()) {
+          if (!remove_labels_reader.HasKey(key)) {
+            KALDI_WARN << "No remove-labels found for recording " << key;
+            num_missing++;
+            writer.Write(key, segmentation);
+            continue;
+          }
+
+          std::vector<int32> remove_labels;
+          const std::string& remove_labels_str =
+              remove_labels_reader.Value(key);
+
+          if (!SplitStringToIntegers(remove_labels_str, ":,", false,
+                                     &remove_labels)) {
+            KALDI_ERR << "Bad colon-separated list "
+                      << remove_labels_str << " for key " << key
+                      << " in " << remove_labels_rspecifier;
+          }
+
+          remove_label = remove_labels[0];
+
+          RemoveSegments(remove_labels, &segmentation);
+        } else {
+          RemoveSegments(remove_label, &segmentation);
+        }
+        writer.Write(key, segmentation);
+      }
+
+      KALDI_LOG << "Removed segments " << "from " << num_done
+                << " segmentations; "
+                << "remove-labels list missing for " << num_missing;
+      return (num_done != 0 ? 0 : 1);
+    }
+  } catch(const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}
+
diff --git a/src/segmenterbin/segmentation-split-segments.cc b/src/segmenterbin/segmentation-split-segments.cc
new file mode 100644
index 00000000000..a45211b28ca
--- /dev/null
+++ b/src/segmenterbin/segmentation-split-segments.cc
@@ -0,0 +1,194 @@
+// segmenterbin/segmentation-split-segments.cc
+
+// Copyright 2016   Vimal Manohar (Johns Hopkins University)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+ +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "segmenter/segmentation-utils.h" + +int main(int argc, char *argv[]) { + try { + using namespace kaldi; + using namespace segmenter; + + const char *usage = + "Split long segments optionally using alignment.\n" + "The splitting works in two possible ways:\n" + " 1) If alignment is not provided: The segments are split if they\n" + " are longer than --max-segment-length frames into overlapping\n" + " segments with an overlap of --overlap-length frames.\n" + " 2) If alignment is provided: The segments are split if they\n" + " are longer than --max-segment-length frames at the region \n" + " where there is a contiguous segment of --ali-label in the \n" + " alignment that is at least --min-alignment-chunk-length frames \n" + " long.\n" + "Usage: segmentation-split-segments [options] " + " \n" + " or : segmentation-split-segments [options] " + " \n" + " e.g.: segmentation-split-segments --binary=false foo -\n" + " segmentation-split-segments ark:foo.seg ark,t:-\n" + "See also: segmentation-post-process\n"; + + bool binary = true; + int32 max_segment_length = -1; + int32 min_remainder = -1; + int32 overlap_length = 0; + int32 split_label = -1; + int32 ali_label = 0; + int32 min_alignment_chunk_length = 2; + + std::string alignments_in_fn; + + ParseOptions po(usage); + + po.Register("binary", &binary, + "Write in binary mode " + "(only relevant if output is a wxfilename)"); + po.Register("max-segment-length", &max_segment_length, + "If segment is longer than this length, split it into " + "pieces with less than these many frames. " + "Refer to the SplitSegments() code for details. " + "Used in conjunction with the option --overlap-length."); + po.Register("min-remainder", &min_remainder, + "The minimum remainder left after splitting that will " + "prevent a splitting from begin done. " + "Set to max-segment-length / 2, if not specified. " + "Applicable only when alignments is not specified."); + po.Register("overlap-length", &overlap_length, + "When splitting segments longer than max-segment-length, " + "have the pieces overlap by these many frames. " + "Refer to the SplitSegments() code for details. " + "Used in conjunction with the option --max-segment-length."); + po.Register("split-label", &split_label, + "If supplied, split only segments of these labels. 
" + "Otherwise, split all segments."); + po.Register("alignments", &alignments_in_fn, + "A single alignment file or archive of alignment used " + "for splitting, " + "depending on whether the input segmentation is single file " + "or archive"); + po.Register("ali-label", &ali_label, + "Split at this label of alignments"); + po.Register("min-alignment-chunk-length", &min_alignment_chunk_length, + "The minimum number of frames of alignment with ali_label " + "at which to split the segments"); + + po.Read(argc, argv); + if (po.NumArgs() != 2) { + po.PrintUsage(); + exit(1); + } + + std::string segmentation_in_fn = po.GetArg(1), + segmentation_out_fn = po.GetArg(2); + + bool in_is_rspecifier = + (ClassifyRspecifier(segmentation_in_fn, NULL, NULL) + != kNoRspecifier), + out_is_wspecifier = + (ClassifyWspecifier(segmentation_out_fn, NULL, NULL, NULL) + != kNoWspecifier); + + if (in_is_rspecifier != out_is_wspecifier) + KALDI_ERR << "Cannot mix regular files and archives"; + + if (min_remainder == -1) { + min_remainder = max_segment_length / 2; + } + + int64 num_done = 0, num_err = 0; + + if (!in_is_rspecifier) { + std::vector ali; + + Segmentation segmentation; + { + bool binary_in; + Input ki(segmentation_in_fn, &binary_in); + segmentation.Read(ki.Stream(), binary_in); + } + + if (!alignments_in_fn.empty()) { + { + bool binary_in; + Input ki(alignments_in_fn, &binary_in); + ReadIntegerVector(ki.Stream(), binary_in, &ali); + } + SplitSegmentsUsingAlignment(max_segment_length, + split_label, ali, ali_label, + min_alignment_chunk_length, + &segmentation); + } else { + SplitSegments(max_segment_length, min_remainder, + overlap_length, split_label, &segmentation); + } + + Sort(&segmentation); + + { + Output ko(segmentation_out_fn, binary); + segmentation.Write(ko.Stream(), binary); + } + + KALDI_LOG << "Split segmentation " << segmentation_in_fn + << " and wrote " << segmentation_out_fn; + return 0; + } + + SegmentationWriter writer(segmentation_out_fn); + SequentialSegmentationReader reader(segmentation_in_fn); + RandomAccessInt32VectorReader ali_reader(alignments_in_fn); + + for (; !reader.Done(); reader.Next()) { + Segmentation segmentation(reader.Value()); + const std::string &key = reader.Key(); + + if (!alignments_in_fn.empty()) { + if (!ali_reader.HasKey(key)) { + KALDI_WARN << "Could not find key " << key + << " in alignments " << alignments_in_fn; + num_err++; + continue; + } + SplitSegmentsUsingAlignment(max_segment_length, split_label, + ali_reader.Value(key), ali_label, + min_alignment_chunk_length, + &segmentation); + } else { + SplitSegments(max_segment_length, min_remainder, + overlap_length, split_label, + &segmentation); + } + + Sort(&segmentation); + + writer.Write(key, segmentation); + num_done++; + } + + KALDI_LOG << "Successfully split " << num_done + << " segmentations; " + << "failed with " << num_err << " segmentations"; + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/segmenterbin/segmentation-to-ali.cc b/src/segmenterbin/segmentation-to-ali.cc new file mode 100644 index 00000000000..9a618247a42 --- /dev/null +++ b/src/segmenterbin/segmentation-to-ali.cc @@ -0,0 +1,99 @@ +// segmenterbin/segmentation-to-ali.cc + +// Copyright 2015-16 Vimal Manohar (Johns Hopkins University) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "segmenter/segmentation-utils.h"
+
+int main(int argc, char *argv[]) {
+  try {
+    using namespace kaldi;
+    using namespace segmenter;
+
+    const char *usage =
+        "Convert segmentation to alignment\n"
+        "\n"
+        "Usage: segmentation-to-ali [options] "
+        "<segmentation-rspecifier> <ali-wspecifier>\n"
+        " e.g.: segmentation-to-ali ark:1.seg ark:1.ali\n";
+
+    std::string lengths_rspecifier;
+    int32 default_label = 0, length_tolerance = 2;
+
+    ParseOptions po(usage);
+
+    po.Register("lengths-rspecifier", &lengths_rspecifier,
+                "Archive of frame lengths "
+                "of the utterances. Fills up any extra length with "
+                "the specified default-label");
+    po.Register("default-label", &default_label, "Fill any extra length "
+                "with this label");
+    po.Register("length-tolerance", &length_tolerance, "Tolerate shortage of "
+                "this many frames in the specified lengths file");
+
+    po.Read(argc, argv);
+
+    if (po.NumArgs() != 2) {
+      po.PrintUsage();
+      exit(1);
+    }
+
+    std::string segmentation_rspecifier = po.GetArg(1);
+    std::string alignment_wspecifier = po.GetArg(2);
+
+    RandomAccessInt32Reader lengths_reader(lengths_rspecifier);
+
+    SequentialSegmentationReader segmentation_reader(segmentation_rspecifier);
+    Int32VectorWriter alignment_writer(alignment_wspecifier);
+
+    int32 num_err = 0, num_done = 0;
+    for (; !segmentation_reader.Done(); segmentation_reader.Next()) {
+      const Segmentation &segmentation = segmentation_reader.Value();
+      const std::string &key = segmentation_reader.Key();
+
+      int32 length = -1;
+      if (lengths_rspecifier != "") {
+        if (!lengths_reader.HasKey(key)) {
+          KALDI_WARN << "Could not find length for utterance " << key;
+          num_err++;
+          continue;
+        }
+        length = lengths_reader.Value(key);
+      }
+
+      std::vector<int32> ali;
+      if (!ConvertToAlignment(segmentation, default_label, length,
+                              length_tolerance, &ali)) {
+        KALDI_WARN << "Conversion failed for utterance " << key;
+        num_err++;
+        continue;
+      }
+      alignment_writer.Write(key, ali);
+      num_done++;
+    }
+
+    KALDI_LOG << "Converted " << num_done << " segmentations into alignments; "
+              << "failed with " << num_err << " segmentations";
+    return (num_done != 0 ? 0 : 1);
+  } catch(const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}
+
diff --git a/src/segmenterbin/segmentation-to-rttm.cc b/src/segmenterbin/segmentation-to-rttm.cc
new file mode 100644
index 00000000000..6ffd1a8b1e8
--- /dev/null
+++ b/src/segmenterbin/segmentation-to-rttm.cc
@@ -0,0 +1,255 @@
+// segmenterbin/segmentation-to-rttm.cc
+
+// Copyright 2015-16   Vimal Manohar
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "segmenter/segmentation.h"
+
+namespace kaldi {
+namespace segmenter {
+
+/**
+ * This function is used to write the segmentation in RTTM format. Each class
+ * is treated as a "SPEAKER". If map_to_speech_and_sil is true, then the
+ * class_id 0 is treated as SILENCE and every other class_id as SPEECH. The
+ * argument start_time is used to set the time corresponding to frame 0 of the
+ * segmentation. Each segment is converted into a line of the form
+ *   SPEAKER <file-id> <channel> <start-time> <duration> <class-id>
+ * , where
+ *   <file-id> is the file_id supplied as an argument,
+ *   <start-time> is the start time of the segment in seconds,
+ *   <duration> is the length of the segment in seconds, and
+ *   <class-id> is the class_id stored in the segment. If map_to_speech_and_sil
+ *   is set true then <class-id> is either SPEECH or SILENCE.
+ * The function returns the largest class_id that it encounters.
+**/
+
+int32 WriteRttm(const Segmentation &segmentation,
+                std::ostream &os, const std::string &file_id,
+                const std::string &channel,
+                BaseFloat frame_shift, BaseFloat start_time,
+                bool map_to_speech_and_sil) {
+  SegmentList::const_iterator it = segmentation.Begin();
+  int32 largest_class = 0;
+  for (; it != segmentation.End(); ++it) {
+    os << "SPEAKER " << file_id << " " << channel << " "
+       << it->start_frame * frame_shift + start_time << " "
+       << (it->Length()) * frame_shift << " ";
+    if (map_to_speech_and_sil) {
+      switch (it->Label()) {
+        case 1:
+          os << "SPEECH ";
+          break;
+        default:
+          os << "SILENCE ";
+          break;
+      }
+      largest_class = 1;
+    } else {
+      if (it->Label() >= 0) {
+        os << it->Label() << " ";
+        if (it->Label() > largest_class)
+          largest_class = it->Label();
+      }
+    }
+    os << "" << std::endl;
+  }
+  return largest_class;
+}
+
+}  // namespace segmenter
+}  // namespace kaldi
+
+int main(int argc, char *argv[]) {
+  try {
+    using namespace kaldi;
+    using namespace segmenter;
+
+    const char *usage =
+        "Convert segmentation into RTTM\n"
+        "\n"
+        "Usage: segmentation-to-rttm [options] "
+        "<segmentation-rspecifier> <rttm-wxfilename>\n"
+        " e.g.: segmentation-to-rttm ark:1.seg -\n";
+
+    bool map_to_speech_and_sil = true;
+
+    BaseFloat frame_shift = 0.01;
+    std::string segments_rxfilename;
+    std::string reco2file_and_channel_rxfilename;
+    ParseOptions po(usage);
+
+    po.Register("frame-shift", &frame_shift, "Frame shift in seconds");
+    po.Register("segments", &segments_rxfilename, "Segments file");
+    po.Register("reco2file-and-channel", &reco2file_and_channel_rxfilename,
+                "reco2file_and_channel file");
+    po.Register("map-to-speech-and-sil", &map_to_speech_and_sil,
+                "Map all classes to SPEECH and SILENCE");
+
+    po.Read(argc, argv);
+
+    if (po.NumArgs() != 2) {
+      po.PrintUsage();
+      exit(1);
+    }
+
+    unordered_map<std::string, std::string, StringHasher> utt2file;
+    unordered_map<std::string, BaseFloat, StringHasher> utt2start_time;
+
+    if (!segments_rxfilename.empty()) {
+      Input ki(segments_rxfilename);  // no binary argument: never binary.
+      int32 i = 0;
+      std::string line;
+      /* read each line from segments file */
+      while (std::getline(ki.Stream(), line)) {
+        std::vector<std::string> split_line;
+        // Split the line by space or tab and check the number of fields in each
+        // line.
+        // There must be 4 fields--segment name, recording wav file name,
+        // start time, end time; the 5th field (channel info) is optional.
+        SplitStringToVector(line, " \t\r", true, &split_line);
+        if (split_line.size() != 4 && split_line.size() != 5) {
+          KALDI_WARN << "Invalid line in segments file: " << line;
+          continue;
+        }
+        std::string segment = split_line[0],
+            utterance = split_line[1],
+            start_str = split_line[2],
+            end_str = split_line[3];
+
+        // Convert the start time and end time from string to real. The segment
+        // is ignored if the start or end time cannot be converted to real.
+        double start, end;
+        if (!ConvertStringToReal(start_str, &start)) {
+          KALDI_WARN << "Invalid line in segments file [bad start]: " << line;
+          continue;
+        }
+        if (!ConvertStringToReal(end_str, &end)) {
+          KALDI_WARN << "Invalid line in segments file [bad end]: " << line;
+          continue;
+        }
+        // The start time must not be negative, and must not be greater than
+        // the end time, except if the end time is -1.
+        if (start < 0 || end <= 0 || start >= end) {
+          KALDI_WARN << "Invalid line in segments file [empty or invalid segment]: "
+                     << line;
+          continue;
+        }
+        int32 channel = -1;  // means channel info is unspecified.
+        // if the line has 5 elements, then the 5th element must be the channel
+        // identifier
+        if (split_line.size() == 5) {
+          if (!ConvertStringToInteger(split_line[4], &channel) || channel < 0) {
+            KALDI_WARN << "Invalid line in segments file [bad channel]: " << line;
+            continue;
+          }
+        }
+
+        utt2file.insert(std::make_pair(segment, utterance));
+        utt2start_time.insert(std::make_pair(segment, start));
+        i++;
+      }
+      KALDI_LOG << "Read " << i << " lines from " << segments_rxfilename;
+    }
+
+    unordered_map<std::string, std::pair<std::string, std::string>,
+                  StringHasher> reco2file_and_channel;
+
+    if (!reco2file_and_channel_rxfilename.empty()) {
+      Input ki(reco2file_and_channel_rxfilename);  // no binary argument: never binary.
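+      // Illustrative note (not part of the original change): each line of the
+      // reco2file_and_channel file is expected to carry exactly three fields,
+      //   <reco-id> <file-id> <channel>
+      // e.g. a hypothetical line "sw02001-A sw02001 A"; lines with any other
+      // number of fields are warned about and skipped below.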
+
+      int32 i = 0;
+      std::string line;
+      /* read each line from reco2file_and_channel file */
+      while (std::getline(ki.Stream(), line)) {
+        std::vector<std::string> split_line;
+        SplitStringToVector(line, " \t\r", true, &split_line);
+        if (split_line.size() != 3) {
+          KALDI_WARN << "Invalid line in reco2file_and_channel file: " << line;
+          continue;
+        }
+
+        const std::string &reco_id = split_line[0];
+        const std::string &file_id = split_line[1];
+        const std::string &channel = split_line[2];
+
+        reco2file_and_channel.insert(
+            std::make_pair(reco_id, std::make_pair(file_id, channel)));
+        i++;
+      }
+
+      KALDI_LOG << "Read " << i << " lines from "
+                << reco2file_and_channel_rxfilename;
+    }
+
+    unordered_set<std::string, StringHasher> seen_files;
+
+    std::string segmentation_rspecifier = po.GetArg(1),
+        rttm_out_wxfilename = po.GetArg(2);
+
+    int64 num_done = 0, num_err = 0;
+
+    Output ko(rttm_out_wxfilename, false);
+    SequentialSegmentationReader reader(segmentation_rspecifier);
+    for (; !reader.Done(); reader.Next(), num_done++) {
+      Segmentation segmentation(reader.Value());
+      const std::string &key = reader.Key();
+
+      std::string reco_id = key;
+      BaseFloat start_time = 0.0;
+      if (!segments_rxfilename.empty()) {
+        if (utt2file.count(key) == 0 || utt2start_time.count(key) == 0)
+          KALDI_ERR << "Could not find key " << key << " in segments "
+                    << segments_rxfilename;
+        KALDI_ASSERT(utt2file.count(key) > 0 && utt2start_time.count(key) > 0);
+        reco_id = utt2file[key];
+        start_time = utt2start_time[key];
+      }
+
+      std::string file_id, channel;
+      if (!reco2file_and_channel_rxfilename.empty()) {
+        if (reco2file_and_channel.count(reco_id) == 0)
+          KALDI_ERR << "Could not find recording " << reco_id
+                    << " in " << reco2file_and_channel_rxfilename;
+        file_id = reco2file_and_channel[reco_id].first;
+        channel = reco2file_and_channel[reco_id].second;
+      } else {
+        file_id = reco_id;
+        channel = "1";
+      }
+
+      int32 largest_class = WriteRttm(segmentation, ko.Stream(), file_id,
+                                      channel, frame_shift, start_time,
+                                      map_to_speech_and_sil);
+
+      if (map_to_speech_and_sil) {
+        if (seen_files.count(reco_id) == 0) {
+          ko.Stream() << "SPKR-INFO " << file_id << " " << channel
+                      << " unknown SILENCE \n";
+          ko.Stream() << "SPKR-INFO " << file_id << " " << channel
+                      << " unknown SPEECH \n";
+          seen_files.insert(reco_id);
+        }
+      } else {
+        for (int32 i = 0; i < largest_class; i++) {
+          ko.Stream() << "SPKR-INFO " << file_id << " " << channel
+                      << " unknown " << i << " \n";
+        }
+      }
+    }
+
+    KALDI_LOG << "Copied " << num_done << " segmentations; failed with "
+              << num_err << " segmentations";
+    return (num_done != 0 ? 0 : 1);
+  } catch(const std::exception &e) {
+    std::cerr << e.what();
+    return -1;
+  }
+}
+
+
+
diff --git a/src/segmenterbin/segmentation-to-segments.cc b/src/segmenterbin/segmentation-to-segments.cc
new file mode 100644
index 00000000000..c57aa827ead
--- /dev/null
+++ b/src/segmenterbin/segmentation-to-segments.cc
@@ -0,0 +1,133 @@
+// segmenterbin/segmentation-to-segments.cc
+
+// Copyright 2015-16   Vimal Manohar (Johns Hopkins University)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include <iomanip>
+#include <sstream>
+
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "segmenter/segmentation.h"
+
+int main(int argc, char *argv[]) {
+  try {
+    using namespace kaldi;
+    using namespace segmenter;
+
+    const char *usage =
+        "Convert segmentation to a segments file and utt2spk file.\n"
+        "Assumes that the input segmentations are indexed by reco-id and "
+        "treats speakers from different recordings as distinct speakers."
+        "\n"
+        "Usage: segmentation-to-segments [options] "
+        "<segmentation-rspecifier> <utt2spk-wspecifier> <segments-wxfilename>\n"
+        " e.g.: segmentation-to-segments ark:foo.seg ark,t:utt2spk segments\n";
+
+    BaseFloat frame_shift = 0.01, frame_overlap = 0.015;
+    bool single_speaker = false, per_utt_speaker = false;
+    ParseOptions po(usage);
+
+    po.Register("frame-shift", &frame_shift, "Frame shift in seconds");
+    po.Register("frame-overlap", &frame_overlap, "Frame overlap in seconds");
+    po.Register("single-speaker", &single_speaker, "If this is set true, "
+                "then all the utterances in a recording are mapped to the "
+                "same speaker");
+    po.Register("per-utt-speaker", &per_utt_speaker,
+                "If this is set true, then each utterance is mapped to a "
+                "distinct speaker with spkr_id = utt_id");
+
+    po.Read(argc, argv);
+
+    if (po.NumArgs() != 3) {
+      po.PrintUsage();
+      exit(1);
+    }
+
+    if (frame_shift < 0.001 || frame_shift > 1) {
+      KALDI_ERR << "Invalid frame-shift " << frame_shift << "; must be in "
+                << "the range [0.001,1]";
+    }
+
+    if (frame_overlap < 0 || frame_overlap > 1) {
+      KALDI_ERR << "Invalid frame-overlap " << frame_overlap << "; must be in "
+                << "the range [0,1]";
+    }
+
+    std::string segmentation_rspecifier = po.GetArg(1),
+        utt2spk_wspecifier = po.GetArg(2),
+        segments_wxfilename = po.GetArg(3);
+
+    SequentialSegmentationReader reader(segmentation_rspecifier);
+    TokenWriter utt2spk_writer(utt2spk_wspecifier);
+
+    Output ko(segments_wxfilename, false);
+
+    int32 num_done = 0;
+    int64 num_segments = 0;
+
+    for (; !reader.Done(); reader.Next(), num_done++) {
+      const Segmentation &segmentation = reader.Value();
+      const std::string &key = reader.Key();
+
+      for (SegmentList::const_iterator it = segmentation.Begin();
+            it != segmentation.End(); ++it) {
+        BaseFloat start_time = it->start_frame * frame_shift;
+        BaseFloat end_time = (it->end_frame + 1) * frame_shift + frame_overlap;
+
+        std::ostringstream oss;
+
+        if (!single_speaker) {
+          oss << key << "-" << it->Label();
+        } else {
+          oss << key;
+        }
+
+        std::string spk = oss.str();
+
+        oss << "-";
+        oss << std::setw(6) << std::setfill('0') << it->start_frame;
+        oss << std::setw(1) << "-";
+        oss << std::setw(6) << std::setfill('0')
+            << it->end_frame + 1
+               + static_cast<int32>(frame_overlap / frame_shift);
+
+        std::string utt = oss.str();
+
+        if (per_utt_speaker)
+          utt2spk_writer.Write(utt, utt);
+        else
+          utt2spk_writer.Write(utt, spk);
+
+        ko.Stream() << utt << " " << key << " ";
+        ko.Stream() << std::fixed << std::setprecision(3) << start_time << " ";
+        ko.Stream() << std::setprecision(3) << end_time << "\n";
+
+        num_segments++;
+      }
+    }
+
+    KALDI_LOG << "Converted " << num_done << " segmentations to segments; "
to segments; " + << "wrote " << num_segments << " segments"; + + return (num_done != 0 ? 0 : 1); + } catch(const std::exception &e) { + std::cerr << e.what(); + return -1; + } +} + diff --git a/src/util/kaldi-holder.cc b/src/util/kaldi-holder.cc index a26bdf2ce29..a86f09a2030 100644 --- a/src/util/kaldi-holder.cc +++ b/src/util/kaldi-holder.cc @@ -34,7 +34,7 @@ bool ExtractObjectRange(const Matrix &input, const std::string &range, SplitStringToVector(range, ",", false, &splits); if (!((splits.size() == 1 && !splits[0].empty()) || (splits.size() == 2 && !splits[0].empty() && !splits[1].empty()))) { - KALDI_ERR << "Invalid range specifier: " << range; + KALDI_ERR << "Invalid range specifier for matrix: " << range; return false; } std::vector row_range, col_range; @@ -75,6 +75,48 @@ template bool ExtractObjectRange(const Matrix &, const std::string &, template bool ExtractObjectRange(const Matrix &, const std::string &, Matrix *); +template +bool ExtractObjectRange(const Vector &input, const std::string &range, + Vector *output) { + if (range.empty()) { + KALDI_ERR << "Empty range specifier."; + return false; + } + std::vector splits; + SplitStringToVector(range, ",", false, &splits); + if (!((splits.size() == 1 && !splits[0].empty()))) { + KALDI_ERR << "Invalid range specifier for vector: " << range; + return false; + } + std::vector index_range; + bool status = true; + if (splits[0] != ":") + status = SplitStringToIntegers(splits[0], ":", false, &index_range); + + if (index_range.size() == 0) { + index_range.push_back(0); + index_range.push_back(input.Dim() - 1); + } + + if (!(status && index_range.size() == 2 && + index_range[0] >= 0 && index_range[0] <= index_range[1] && + index_range[1] < input.Dim())) { + KALDI_ERR << "Invalid range specifier: " << range + << " for vector of size " << input.Dim(); + return false; + } + int32 size = index_range[1] - index_range[0] + 1; + output->Resize(size, kUndefined); + output->CopyFromVec(input.Range(index_range[0], size)); + return true; +} + +// template instantiation +template bool ExtractObjectRange(const Vector &, const std::string &, + Vector *); +template bool ExtractObjectRange(const Vector &, const std::string &, + Vector *); + bool ExtractRangeSpecifier(const std::string &rxfilename_with_range, std::string *data_rxfilename, std::string *range) { diff --git a/src/util/kaldi-holder.h b/src/util/kaldi-holder.h index 06d7ec8e745..9ab148387ee 100644 --- a/src/util/kaldi-holder.h +++ b/src/util/kaldi-holder.h @@ -242,6 +242,11 @@ template bool ExtractObjectRange(const Matrix &input, const std::string &range, Matrix *output); +/// The template is specialized types Vector and Vector. +template +bool ExtractObjectRange(const Vector &input, const std::string &range, + Vector *output); + // In SequentialTableReaderScriptImpl and RandomAccessTableReaderScriptImpl, for // cases where the scp contained 'range specifiers' (things in square brackets diff --git a/tools/config/common_path.sh b/tools/config/common_path.sh index 3e2ea50d685..36b5350dd8e 100644 --- a/tools/config/common_path.sh +++ b/tools/config/common_path.sh @@ -20,4 +20,5 @@ ${KALDI_ROOT}/src/online2bin:\ ${KALDI_ROOT}/src/onlinebin:\ ${KALDI_ROOT}/src/sgmm2bin:\ ${KALDI_ROOT}/src/sgmmbin:\ +${KALDI_ROOT}/src/segmenterbin:\ $PATH