Skip to content

Commit

Permalink
loss_barlow_twins: add get_eccm member function (davisking#2906)
Browse files Browse the repository at this point in the history
This allows us to greatly simplify the self supervised learning example:
- the computation in user code was a bit too distracting
- avoids duplicated computation/allocation of this matrix
- avoids edge case where net outputs are zero due to trainer synchronization
  • Loading branch information
arrufat authored Jan 9, 2024
1 parent 46e59a2 commit b0f6be8
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 33 deletions.
3 changes: 3 additions & 0 deletions dlib/dnn/loss.h
Original file line number Diff line number Diff line change
Expand Up @@ -4154,6 +4154,9 @@ namespace dlib

float get_lambda() const { return lambda; }

tensor& get_eccm() { return eccm; }
const tensor& get_eccm() const { return eccm; }

friend void serialize(const loss_barlow_twins_& item, std::ostream& out)
{
serialize("loss_barlow_twins_", out);
Expand Down
14 changes: 14 additions & 0 deletions dlib/dnn/loss_abstract.h
Original file line number Diff line number Diff line change
Expand Up @@ -2144,6 +2144,20 @@ namespace dlib
in WHAT THIS OBJECT REPRESENTS for details.
!*/

tensor& get_eccm();
/*!
ensures
- returns the empirical cross-correlation matrix computed by the loss.
- this is only meant to be used for visualization/debugging purposes.
!*/

const tensor& get_eccm() const;
/*!
ensures
- returns the empirical cross-correlation matrix computed by the loss.
- this is only meant to be used for visualization/debugging purposes.
!*/

template <
typename SUBNET
>
Expand Down
37 changes: 4 additions & 33 deletions examples/dnn_self_supervised_learning_ex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,20 +206,10 @@ try
trainer.be_verbose();
cout << trainer << endl;

// During the training, we will compute the empirical cross-correlation
// During the training, we will visualize the empirical cross-correlation
// matrix between the features of both versions of the augmented images.
// This matrix should be getting close to the identity matrix as the training
// progresses. Note that this step is already done in the loss layer, and it's
// not necessary to do it here for the example to work. However, it provides
// a nice visualization of the training progress: the closer to the identity
// matrix, the better.
resizable_tensor eccm;
eccm.set_size(dims, dims);
// Some tensors needed to perform batch normalization
resizable_tensor za_norm, zb_norm, means, invstds, rms, rvs, gamma, beta;
const double eps = DEFAULT_BATCH_NORM_EPS;
gamma.set_size(1, dims);
beta.set_size(1, dims);
// progresses. Note that this is done here for visualization purposes only.
image_window win;

std::vector<pair<matrix<rgb_pixel>, matrix<rgb_pixel>>> batch;
Expand All @@ -234,32 +224,13 @@ try
}
trainer.train_one_step(batch);

// Compute the empirical cross-correlation matrix every 100 steps. Again,
// Get the empirical cross-correlation matrix every 100 steps. Again,
// this is not needed for the training to work, but it's nice to visualize.
if (trainer.get_train_one_step_calls() % 100 == 0)
{
// Wait for threaded processing to stop in the trainer.
trainer.get_net(force_flush_to_disk::no);
// Get the output from the last fc layer
const auto& out = net.subnet().get_output();
// The trainer might have synchronized its state to the disk and cleaned
// the network state. If that happens, the output will be empty, in which
// case, we just skip the empirical cross-correlation matrix computation.
if (out.size() == 0)
continue;
// Separate both augmented versions of the images
alias_tensor split(out.num_samples() / 2, dims);
auto za = split(out);
auto zb = split(out, split.size());
gamma = 1;
beta = 0;
// Perform batch normalization on each feature representation, independently.
tt::batch_normalize(eps, za_norm, means, invstds, 1, rms, rvs, za, gamma, beta);
tt::batch_normalize(eps, zb_norm, means, invstds, 1, rms, rvs, zb, gamma, beta);
// Compute the empirical cross-correlation matrix between the features and
// visualize it.
tt::gemm(0, eccm, 1, za_norm, true, zb_norm, false);
eccm /= batch_size;
const matrix<float> eccm = mat(net.loss_details().get_eccm());
win.set_image(round(abs(mat(eccm)) * 255));
win.set_title("Barlow Twins step#: " + to_string(trainer.get_train_one_step_calls()));
}
Expand Down

0 comments on commit b0f6be8

Please sign in to comment.