diff --git a/egs/callhome_diarization/v1/diarization/cluster.sh b/egs/callhome_diarization/v1/diarization/cluster.sh
index fa5ead5b6b9..1d92bef462a 100755
--- a/egs/callhome_diarization/v1/diarization/cluster.sh
+++ b/egs/callhome_diarization/v1/diarization/cluster.sh
@@ -14,6 +14,7 @@ stage=0
 nj=10
 cleanup=true
 threshold=0.5
+max_spk_fraction=1.0
 rttm_channel=0
 read_costs=false
 reco2num_spk=
@@ -36,6 +37,10 @@ if [ $# != 2 ]; then
   echo " --threshold # Cluster stopping criterion. Clusters with scores greater"
   echo " # than this value will be merged until all clusters"
   echo " # exceed this value."
+  echo " --max-spk-fraction # Clusters with total fraction of utterances greater than"
+  echo " # this value will not be merged. This is active only when"
+  echo " # reco2num-spk is supplied and"
+  echo " # 1.0 / num-spk <= max-spk-fraction <= 1.0."
   echo " --rttm-channel # The value passed into the RTTM channel field. Only affects"
   echo " # the format of the RTTM file."
   echo " --read-costs # If true, interpret input scores as costs, i.e. similarity"
@@ -78,8 +83,8 @@ if [ $stage -le 0 ]; then
   echo "$0: clustering scores"
   $cmd JOB=1:$nj $dir/log/agglomerative_cluster.JOB.log \
     agglomerative-cluster --threshold=$threshold --read-costs=$read_costs \
-      --reco2num-spk-rspecifier=$reco2num_spk scp:"$feats" \
-      ark,t:$sdata/JOB/spk2utt ark,t:$dir/labels.JOB || exit 1;
+      --reco2num-spk-rspecifier=$reco2num_spk --max-spk-fraction=$max_spk_fraction \
+      scp:"$feats" ark,t:$sdata/JOB/spk2utt ark,t:$dir/labels.JOB || exit 1;
 fi
 
 if [ $stage -le 1 ]; then
diff --git a/src/ivector/agglomerative-clustering.cc b/src/ivector/agglomerative-clustering.cc
index 30138e00637..bc59733fccc 100644
--- a/src/ivector/agglomerative-clustering.cc
+++ b/src/ivector/agglomerative-clustering.cc
@@ -36,8 +36,12 @@ void AgglomerativeClusterer::Cluster() {
     queue_.pop();
     // check to make sure clusters have not already been merged
     if ((active_clusters_.find(i) != active_clusters_.end()) &&
-        (active_clusters_.find(j) != active_clusters_.end()))
-      MergeClusters(i, j);
+        (active_clusters_.find(j) != active_clusters_.end())) {
+      // The pair was already popped above, so a merge that would exceed
+      // max_cluster_size_ is simply skipped.
+      if (clusters_map_[i]->size + clusters_map_[j]->size <= max_cluster_size_)
+        MergeClusters(i, j);
+    }
   }
 
   std::vector<int32> new_assignments(num_points_);
@@ -123,9 +127,14 @@ void AgglomerativeCluster(
     const Matrix<BaseFloat> &costs,
     BaseFloat thresh,
     int32 min_clust,
+    BaseFloat max_cluster_fraction,
     std::vector<int32> *assignments_out) {
   KALDI_ASSERT(min_clust >= 0);
-  AgglomerativeClusterer ac(costs, thresh, min_clust, assignments_out);
+  // min_clust == 0 passes the assertion above, so skip the lower-bound
+  // check there: 1.0 / 0 is infinite and would reject every fraction.
+  KALDI_ASSERT(min_clust == 0 || max_cluster_fraction >= 1.0 / min_clust);
+  AgglomerativeClusterer ac(costs, thresh, min_clust, max_cluster_fraction,
+                            assignments_out);
   ac.Cluster();
 }
 
diff --git a/src/ivector/agglomerative-clustering.h b/src/ivector/agglomerative-clustering.h
index 310a336f8b5..bf8b9c0f91b 100644
--- a/src/ivector/agglomerative-clustering.h
+++ b/src/ivector/agglomerative-clustering.h
@@ -57,11 +57,13 @@ class AgglomerativeClusterer {
       const Matrix<BaseFloat> &costs,
       BaseFloat thresh,
       int32 min_clust,
+      BaseFloat max_cluster_fraction,
       std::vector<int32> *assignments_out)
       : count_(0), costs_(costs), thresh_(thresh), min_clust_(min_clust),
         assignments_(assignments_out) {
     num_clusters_ = costs.NumRows();
     num_points_ = costs.NumRows();
+    max_cluster_size_ = ceil(num_points_ * max_cluster_fraction);
   }
 
   // Performs the clustering
@@ -80,6 +82,7 @@ class AgglomerativeClusterer {
   const Matrix<BaseFloat> &costs_;  // cost matrix
   BaseFloat thresh_;  // stopping criterion threshold
   int32 min_clust_;  // minimum number of clusters
+  int32 max_cluster_size_;  // maximum number of points in a cluster
   std::vector<int32> *assignments_;  // assignments out
 
   // Priority queue using greater (lowest costs are highest priority).
@@ -107,6 +110,7 @@ class AgglomerativeClusterer {
         cost for pairing the utterances for its row and column
  *   - A threshold which is used as the stopping criterion for the clusters
  *   - A minimum number of clusters that will not be merged past
+ *   - A maximum fraction of points that can be in a cluster
  *   - A vector which will be filled with integer IDs corresponding to each
  *      of the rows/columns of the score matrix.
  *
@@ -131,6 +135,7 @@ void AgglomerativeCluster(
     const Matrix<BaseFloat> &costs,
     BaseFloat thresh,
     int32 min_clust,
+    BaseFloat max_cluster_fraction,
     std::vector<int32> *assignments_out);
 
 }  // end namespace kaldi.
diff --git a/src/ivectorbin/agglomerative-cluster.cc b/src/ivectorbin/agglomerative-cluster.cc
index 9dca9bfeb83..dbfa2c25b69 100644
--- a/src/ivectorbin/agglomerative-cluster.cc
+++ b/src/ivectorbin/agglomerative-cluster.cc
@@ -47,7 +47,7 @@ int main(int argc, char *argv[]) {
 
     ParseOptions po(usage);
     std::string reco2num_spk_rspecifier;
-    BaseFloat threshold = 0.0;
+    BaseFloat threshold = 0.0, max_spk_fraction = 1.0;
     bool read_costs = false;
 
     po.Register("reco2num-spk-rspecifier", &reco2num_spk_rspecifier,
@@ -55,6 +55,10 @@ int main(int argc, char *argv[]) {
       " recording and the option --threshold is ignored.");
     po.Register("threshold", &threshold, "Merge clusters if their distance"
       " is less than this threshold.");
+    po.Register("max-spk-fraction", &max_spk_fraction, "Merge clusters if the"
+      " total fraction of utterances in them is less than this threshold."
+      " This is active only when reco2num-spk-rspecifier is supplied and"
+      " 1.0 / num-spk <= max-spk-fraction <= 1.0.");
     po.Register("read-costs", &read_costs, "If true, the first"
       " argument is interpreted as a matrix of costs rather than a"
       " similarity matrix.");
@@ -90,10 +94,17 @@ int main(int argc, char *argv[]) {
       std::vector<int32> spk_ids;
       if (reco2num_spk_rspecifier.size()) {
         int32 num_speakers = reco2num_spk_reader.Value(reco);
-        AgglomerativeCluster(costs,
-          std::numeric_limits<BaseFloat>::max(), num_speakers, &spk_ids);
+        // Honor --max-spk-fraction only inside [1.0 / num_speakers, 1.0];
+        // any other value falls back to 1.0 (no size cap). The
+        // num_speakers > 0 test keeps the division well-defined.
+        BaseFloat spk_fraction = 1.0;
+        if (num_speakers > 0 && 1.0 / num_speakers <= max_spk_fraction &&
+            max_spk_fraction <= 1.0)
+          spk_fraction = max_spk_fraction;
+        AgglomerativeCluster(costs, std::numeric_limits<BaseFloat>::max(),
+          num_speakers, spk_fraction, &spk_ids);
       } else {
-        AgglomerativeCluster(costs, threshold, 1, &spk_ids);
+        AgglomerativeCluster(costs, threshold, 1, 1.0, &spk_ids);
       }
       for (int32 i = 0; i < spk_ids.size(); i++)
         label_writer.Write(uttlist[i], spk_ids[i]);