Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions egs/callhome_diarization/v1/diarization/cluster.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ stage=0
nj=10
cleanup=true
threshold=0.5
max_spk_fraction=1.0
rttm_channel=0
read_costs=false
reco2num_spk=
Expand All @@ -36,6 +37,10 @@ if [ $# != 2 ]; then
echo " --threshold <threshold|0> # Cluster stopping criterion. Clusters with scores greater"
echo " # than this value will be merged until all clusters"
echo " # exceed this value."
echo " --max-spk-fraction <max-spk-fraction|1.0> # Clusters with total fraction of utterances greater than"
echo " # this value will not be merged. This is active only when"
echo " # reco2num-spk is supplied and"
echo " # 1.0 / num-spk <= max-spk-fraction <= 1.0."
echo " --rttm-channel <rttm-channel|0> # The value passed into the RTTM channel field. Only affects"
echo " # the format of the RTTM file."
echo " --read-costs <read-costs|false> # If true, interpret input scores as costs, i.e. similarity"
Expand Down Expand Up @@ -78,8 +83,8 @@ if [ $stage -le 0 ]; then
echo "$0: clustering scores"
$cmd JOB=1:$nj $dir/log/agglomerative_cluster.JOB.log \
agglomerative-cluster --threshold=$threshold --read-costs=$read_costs \
--reco2num-spk-rspecifier=$reco2num_spk scp:"$feats" \
ark,t:$sdata/JOB/spk2utt ark,t:$dir/labels.JOB || exit 1;
--reco2num-spk-rspecifier=$reco2num_spk --max-spk-fraction=$max_spk_fraction \
scp:"$feats" ark,t:$sdata/JOB/spk2utt ark,t:$dir/labels.JOB || exit 1;
fi

if [ $stage -le 1 ]; then
Expand Down
11 changes: 8 additions & 3 deletions src/ivector/agglomerative-clustering.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ void AgglomerativeClusterer::Cluster() {
queue_.pop();
// check to make sure clusters have not already been merged
if ((active_clusters_.find(i) != active_clusters_.end()) &&
(active_clusters_.find(j) != active_clusters_.end()))
MergeClusters(i, j);
(active_clusters_.find(j) != active_clusters_.end())) {
if (clusters_map_[i]->size + clusters_map_[j]->size <= max_cluster_size_)
MergeClusters(i, j);
}
}

std::vector<int32> new_assignments(num_points_);
Expand Down Expand Up @@ -123,9 +125,12 @@ void AgglomerativeCluster(
const Matrix<BaseFloat> &costs,
BaseFloat thresh,
int32 min_clust,
BaseFloat max_cluster_fraction,
std::vector<int32> *assignments_out) {
KALDI_ASSERT(min_clust >= 0);
AgglomerativeClusterer ac(costs, thresh, min_clust, assignments_out);
KALDI_ASSERT(max_cluster_fraction >= 1.0 / min_clust);
AgglomerativeClusterer ac(costs, thresh, min_clust, max_cluster_fraction,
assignments_out);
ac.Cluster();
}

Expand Down
5 changes: 5 additions & 0 deletions src/ivector/agglomerative-clustering.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,13 @@ class AgglomerativeClusterer {
const Matrix<BaseFloat> &costs,
BaseFloat thresh,
int32 min_clust,
BaseFloat max_cluster_fraction,
std::vector<int32> *assignments_out)
: count_(0), costs_(costs), thresh_(thresh), min_clust_(min_clust),
assignments_(assignments_out) {
num_clusters_ = costs.NumRows();
num_points_ = costs.NumRows();
max_cluster_size_ = ceil(num_points_ * max_cluster_fraction);
}

// Performs the clustering
Expand All @@ -80,6 +82,7 @@ class AgglomerativeClusterer {
const Matrix<BaseFloat> &costs_; // cost matrix
BaseFloat thresh_; // stopping criterion threshold
int32 min_clust_; // minimum number of clusters
int32 max_cluster_size_; // maximum number of points in a cluster
std::vector<int32> *assignments_; // assignments out

// Priority queue using greater (lowest costs are highest priority).
Expand Down Expand Up @@ -107,6 +110,7 @@ class AgglomerativeClusterer {
cost for pairing the utterances for its row and column
* - A threshold which is used as the stopping criterion for the clusters
* - A minimum number of clusters that will not be merged past
* - A maximum fraction of points that can be in a cluster
* - A vector which will be filled with integer IDs corresponding to each
* of the rows/columns of the score matrix.
*
Expand All @@ -131,6 +135,7 @@ void AgglomerativeCluster(
const Matrix<BaseFloat> &costs,
BaseFloat thresh,
int32 min_clust,
BaseFloat max_cluster_fraction,
std::vector<int32> *assignments_out);

} // end namespace kaldi.
Expand Down
16 changes: 12 additions & 4 deletions src/ivectorbin/agglomerative-cluster.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,18 @@ int main(int argc, char *argv[]) {

ParseOptions po(usage);
std::string reco2num_spk_rspecifier;
BaseFloat threshold = 0.0;
BaseFloat threshold = 0.0, max_spk_fraction = 1.0;
bool read_costs = false;

po.Register("reco2num-spk-rspecifier", &reco2num_spk_rspecifier,
"If supplied, clustering creates exactly this many clusters for each"
" recording and the option --threshold is ignored.");
po.Register("threshold", &threshold, "Merge clusters if their distance"
" is less than this threshold.");
po.Register("max-spk-fraction", &max_spk_fraction, "Merge clusters if the"
" total fraction of utterances in them is less than this threshold."
" This is active only when reco2num-spk-rspecifier is supplied and"
" 1.0 / num-spk <= max-spk-fraction <= 1.0.");
po.Register("read-costs", &read_costs, "If true, the first"
" argument is interpreted as a matrix of costs rather than a"
" similarity matrix.");
Expand Down Expand Up @@ -90,10 +94,14 @@ int main(int argc, char *argv[]) {
std::vector<int32> spk_ids;
if (reco2num_spk_rspecifier.size()) {
int32 num_speakers = reco2num_spk_reader.Value(reco);
AgglomerativeCluster(costs,
std::numeric_limits<BaseFloat>::max(), num_speakers, &spk_ids);
if (1.0 / num_speakers <= max_spk_fraction && max_spk_fraction <= 1.0)
AgglomerativeCluster(costs, std::numeric_limits<BaseFloat>::max(),
num_speakers, max_spk_fraction, &spk_ids);
else
AgglomerativeCluster(costs, std::numeric_limits<BaseFloat>::max(),
num_speakers, 1.0, &spk_ids);
} else {
AgglomerativeCluster(costs, threshold, 1, &spk_ids);
AgglomerativeCluster(costs, threshold, 1, 1.0, &spk_ids);
}
for (int32 i = 0; i < spk_ids.size(); i++)
label_writer.Write(uttlist[i], spk_ids[i]);
Expand Down