diff --git a/examples/mnist/lenet_adadelta_solver.prototxt b/examples/mnist/lenet_adadelta_solver.prototxt
new file mode 100644
index 00000000000..b77b451d56a
--- /dev/null
+++ b/examples/mnist/lenet_adadelta_solver.prototxt
@@ -0,0 +1,22 @@
+# The train/test net protocol buffer definition
+net: "examples/mnist/lenet_train_test.prototxt"
+# test_iter specifies how many forward passes the test should carry out.
+# In the case of MNIST, we have test batch size 100 and 100 test iterations,
+# covering the full 10,000 testing images.
+test_iter: 100
+# Carry out testing every 500 training iterations.
+test_interval: 500
+# The momentum and the weight decay of the network.
+momentum: 0.95
+weight_decay: 0.0005
+# Display every 100 iterations
+display: 100
+# The maximum number of iterations
+max_iter: 10000
+# snapshot intermediate results
+snapshot: 5000
+snapshot_prefix: "examples/mnist/lenet_adadelta"
+# solver mode: CPU or GPU
+solver_mode: GPU
+solver_type: ADADELTA
+delta: 1e-6
diff --git a/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt b/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt
new file mode 100644
index 00000000000..cc4f0bbb4a7
--- /dev/null
+++ b/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt
@@ -0,0 +1,17 @@
+net: "examples/mnist/mnist_autoencoder.prototxt"
+test_state: { stage: 'test-on-train' }
+test_iter: 500
+test_state: { stage: 'test-on-test' }
+test_iter: 100
+test_interval: 500
+test_compute_loss: true
+momentum: 0.95
+display: 100
+max_iter: 65000
+weight_decay: 0.0005
+snapshot: 10000
+snapshot_prefix: "examples/mnist/mnist_autoencoder_adadelta_train"
+# solver mode: CPU or GPU
+solver_mode: GPU
+solver_type: ADADELTA
+delta: 1e-8
diff --git a/examples/mnist/train_mnist_autoencoder_adadelta.sh b/examples/mnist/train_mnist_autoencoder_adadelta.sh
new file mode 100755
index 00000000000..4be0ebddedc
--- /dev/null
+++ b/examples/mnist/train_mnist_autoencoder_adadelta.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+./build/tools/caffe train \
+    --solver=examples/mnist/mnist_autoencoder_solver_adadelta.prototxt
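Note: this diff adds a launcher script only for the autoencoder example. Assuming the same build layout (a caffe binary under ./build/tools, as in the script above), the new LeNet solver can be run the same way:

    ./build/tools/caffe train --solver=examples/mnist/lenet_adadelta_solver.prototxt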
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index 6fd159d0b98..3f22ae10494 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -124,6 +124,27 @@ class AdaGradSolver : public SGDSolver<Dtype> {
   DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
 };
 
+template <typename Dtype>
+class AdaDeltaSolver : public SGDSolver<Dtype> {
+ public:
+  explicit AdaDeltaSolver(const SolverParameter& param)
+      : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
+  explicit AdaDeltaSolver(const string& param_file)
+      : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
+
+ protected:
+  virtual void PreSolve();
+  virtual void ComputeUpdateValue();
+  void constructor_sanity_check() {
+    CHECK_EQ(0, this->param_.base_lr())
+        << "Learning rate cannot be used with AdaDelta.";
+    CHECK_EQ("", this->param_.lr_policy())
+        << "Learning rate policy cannot be applied to AdaDelta.";
+  }
+
+  DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
+};
+
 template <typename Dtype>
 Solver<Dtype>* GetSolver(const SolverParameter& param) {
   SolverParameter_SolverType type = param.solver_type();
@@ -135,6 +156,8 @@ Solver<Dtype>* GetSolver(const SolverParameter& param) {
       return new NesterovSolver<Dtype>(param);
     case SolverParameter_SolverType_ADAGRAD:
       return new AdaGradSolver<Dtype>(param);
+    case SolverParameter_SolverType_ADADELTA:
+      return new AdaDeltaSolver<Dtype>(param);
     default:
       LOG(FATAL) << "Unknown SolverType: " << type;
   }
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 9395c38f3e9..83f4d6af9a3 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -146,6 +146,7 @@ message SolverParameter {
     SGD = 0;
     NESTEROV = 1;
     ADAGRAD = 2;
+    ADADELTA = 3;
   }
   optional SolverType solver_type = 30 [default = SGD];
   // numerical stability for AdaGrad
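For reference, the rule the new solver implements is Zeiler's AdaDelta (arXiv:1212.5701). In the implementation that follows, the `momentum` field serves as the decay rate rho and `delta` as the conditioning constant epsilon; no learning rate appears anywhere, which is exactly what constructor_sanity_check() above enforces. In LaTeX:

    E[g^2]_t        = \rho \, E[g^2]_{t-1} + (1 - \rho) \, g_t^2
    \Delta x_t      = \sqrt{(E[\Delta x^2]_{t-1} + \epsilon) / (E[g^2]_t + \epsilon)} \cdot g_t
    E[\Delta x^2]_t = \rho \, E[\Delta x^2]_{t-1} + (1 - \rho) \, \Delta x_t^2

(Zeiler writes the step with a minus sign; here Delta x_t is computed as a positive quantity because Caffe's net_->Update() subtracts the diff from the parameters.)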
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 0ea4edcf9b8..1d3f6f11594 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -753,9 +753,208 @@ void AdaGradSolver<Dtype>::ComputeUpdateValue() {
   }
 }
 
+template <typename Dtype>
+void AdaDeltaSolver<Dtype>::PreSolve() {
+  // Initialize the history
+  vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  this->history_.clear();
+  this->update_.clear();
+  this->temp_.clear();
+  for (int i = 0; i < net_params.size(); ++i) {
+    const Blob<Dtype>* net_param = net_params[i].get();
+    this->history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
+        net_param->num(), net_param->channels(), net_param->height(),
+        net_param->width())));
+    this->update_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
+        net_param->num(), net_param->channels(), net_param->height(),
+        net_param->width())));
+    this->temp_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
+        net_param->num(), net_param->channels(), net_param->height(),
+        net_param->width())));
+  }
+  for (int i = 0; i < net_params.size(); ++i) {
+    const Blob<Dtype>* net_param = net_params[i].get();
+    this->history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
+        net_param->num(), net_param->channels(), net_param->height(),
+        net_param->width())));
+  }
+}
+
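After PreSolve, history_ holds twice as many blobs as there are learnable parameters. A sketch of the layout, writing N for net_params.size() as in the code above:

    // history_[i]     (0 <= i < N): decayed average of squared gradients, E[g^2]
    // history_[N + i]             : decayed average of squared updates,  E[dx^2]
    // update_ and temp_ are per-parameter scratch buffers reused below.

ComputeUpdateValue() below reaches the second set through update_history_offset = N; the solver test further down asserts history.size() == 4 for its two-parameter net (one weight blob, one bias blob).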
+template <typename Dtype>
+void AdaDeltaSolver<Dtype>::ComputeUpdateValue() {
+  vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
+  vector<float>& net_params_weight_decay = this->net_->params_weight_decay();
+  Dtype delta = this->param_.delta();
+  Dtype momentum = this->param_.momentum();
+  Dtype weight_decay = this->param_.weight_decay();
+  string regularization_type = this->param_.regularization_type();
+  size_t update_history_offset = net_params.size();
+  switch (Caffe::mode()) {
+  case Caffe::CPU:
+    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
+      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
+
+      if (local_decay) {
+        if (regularization_type == "L2") {
+          // add weight decay
+          caffe_axpy(net_params[param_id]->count(),
+              local_decay,
+              net_params[param_id]->cpu_data(),
+              net_params[param_id]->mutable_cpu_diff());
+        } else if (regularization_type == "L1") {
+          caffe_cpu_sign(net_params[param_id]->count(),
+              net_params[param_id]->cpu_data(),
+              this->temp_[param_id]->mutable_cpu_data());
+          caffe_axpy(net_params[param_id]->count(),
+              local_decay,
+              this->temp_[param_id]->cpu_data(),
+              net_params[param_id]->mutable_cpu_diff());
+        } else {
+          LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+        }
+      }
+
+      // compute square of gradient in update
+      caffe_powx(net_params[param_id]->count(),
+          net_params[param_id]->cpu_diff(), Dtype(2),
+          this->update_[param_id]->mutable_cpu_data());
+
+      // update history of gradients
+      caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+          this->update_[param_id]->cpu_data(), momentum,
+          this->history_[param_id]->mutable_cpu_data());
+
+      // add delta to history to guard against dividing by zero later
+      caffe_set(net_params[param_id]->count(), delta,
+          this->temp_[param_id]->mutable_cpu_data());
+
+      caffe_add(net_params[param_id]->count(),
+          this->temp_[param_id]->cpu_data(),
+          this->history_[update_history_offset + param_id]->cpu_data(),
+          this->update_[param_id]->mutable_cpu_data());
+
+      caffe_add(net_params[param_id]->count(),
+          this->temp_[param_id]->cpu_data(),
+          this->history_[param_id]->cpu_data(),
+          this->temp_[param_id]->mutable_cpu_data());
+
+      // divide history of updates by history of gradients
+      caffe_div(net_params[param_id]->count(),
+          this->update_[param_id]->cpu_data(),
+          this->temp_[param_id]->cpu_data(),
+          this->update_[param_id]->mutable_cpu_data());
+
+      // take the square root of the ratio, yielding RMS[update] / RMS[gradient]
+      caffe_powx(net_params[param_id]->count(),
+          this->update_[param_id]->cpu_data(), Dtype(0.5),
+          this->update_[param_id]->mutable_cpu_data());
+
+      // compute the update
+      caffe_mul(net_params[param_id]->count(),
+          net_params[param_id]->cpu_diff(),
+          this->update_[param_id]->cpu_data(),
+          net_params[param_id]->mutable_cpu_diff());
+
+      // compute square of update
+      caffe_powx(net_params[param_id]->count(),
+          net_params[param_id]->cpu_diff(), Dtype(2),
+          this->update_[param_id]->mutable_cpu_data());
+
+      // update history of updates
+      caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+          this->update_[param_id]->cpu_data(), momentum,
+          this->history_[update_history_offset + param_id]->mutable_cpu_data());
+    }
+    break;
+  case Caffe::GPU:
+#ifndef CPU_ONLY
+    for (int param_id = 0; param_id < net_params.size(); ++param_id) {
+      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
+
+      if (local_decay) {
+        if (regularization_type == "L2") {
+          // add weight decay
+          caffe_gpu_axpy(net_params[param_id]->count(),
+              local_decay,
+              net_params[param_id]->gpu_data(),
+              net_params[param_id]->mutable_gpu_diff());
+        } else if (regularization_type == "L1") {
+          caffe_gpu_sign(net_params[param_id]->count(),
+              net_params[param_id]->gpu_data(),
+              this->temp_[param_id]->mutable_gpu_data());
+          caffe_gpu_axpy(net_params[param_id]->count(),
+              local_decay,
+              this->temp_[param_id]->gpu_data(),
+              net_params[param_id]->mutable_gpu_diff());
+        } else {
+          LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+        }
+      }
+
+      // compute square of gradient in update
+      caffe_gpu_powx(net_params[param_id]->count(),
+          net_params[param_id]->gpu_diff(), Dtype(2),
+          this->update_[param_id]->mutable_gpu_data());
+
+      // update history of gradients
+      caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+          this->update_[param_id]->gpu_data(), momentum,
+          this->history_[param_id]->mutable_gpu_data());
+
+      // add delta to history to guard against dividing by zero later
+      caffe_gpu_set(net_params[param_id]->count(), delta,
+          this->temp_[param_id]->mutable_gpu_data());
+
+      caffe_gpu_add(net_params[param_id]->count(),
+          this->temp_[param_id]->gpu_data(),
+          this->history_[update_history_offset + param_id]->gpu_data(),
+          this->update_[param_id]->mutable_gpu_data());
+
+      caffe_gpu_add(net_params[param_id]->count(),
+          this->temp_[param_id]->gpu_data(),
+          this->history_[param_id]->gpu_data(),
+          this->temp_[param_id]->mutable_gpu_data());
+
+      // divide history of updates by history of gradients
+      caffe_gpu_div(net_params[param_id]->count(),
+          this->update_[param_id]->gpu_data(),
+          this->temp_[param_id]->gpu_data(),
+          this->update_[param_id]->mutable_gpu_data());
+
+      // take the square root of the ratio, yielding RMS[update] / RMS[gradient]
+      caffe_gpu_powx(net_params[param_id]->count(),
+          this->update_[param_id]->gpu_data(), Dtype(0.5),
+          this->update_[param_id]->mutable_gpu_data());
+
+      // compute the update and copy to net_diff
+      caffe_gpu_mul(net_params[param_id]->count(),
+          net_params[param_id]->gpu_diff(),
+          this->update_[param_id]->gpu_data(),
+          net_params[param_id]->mutable_gpu_diff());
+
+      // compute square of update
+      caffe_gpu_powx(net_params[param_id]->count(),
+          net_params[param_id]->gpu_diff(), Dtype(2),
+          this->update_[param_id]->mutable_gpu_data());
+
+      // update history of updates
+      caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum,
+          this->update_[param_id]->gpu_data(), momentum,
+          this->history_[update_history_offset + param_id]->mutable_gpu_data());
+    }
+#else
+    NO_GPU;
+#endif
+    break;
+  default:
+    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+  }
+}
+
 INSTANTIATE_CLASS(Solver);
 INSTANTIATE_CLASS(SGDSolver);
 INSTANTIATE_CLASS(NesterovSolver);
 INSTANTIATE_CLASS(AdaGradSolver);
+INSTANTIATE_CLASS(AdaDeltaSolver);
 
 }  // namespace caffe
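Condensed to a single element, both the CPU and GPU branches above perform the arithmetic below. This is a minimal standalone sketch, not part of the patch; adadelta_step is a hypothetical name, with rho standing in for `momentum` and eps for `delta`:

    #include <cmath>

    // One AdaDelta step for a single weight. sq_grad_avg and sq_update_avg
    // persist across iterations, like the two history_ blobs per parameter.
    float adadelta_step(float grad, float& sq_grad_avg, float& sq_update_avg,
                        float rho, float eps) {
      // E[g^2] <- rho * E[g^2] + (1 - rho) * g^2
      sq_grad_avg = rho * sq_grad_avg + (1.0f - rho) * grad * grad;
      // scale the raw gradient by RMS[dx] / RMS[g]
      float update = grad * std::sqrt((sq_update_avg + eps) /
                                      (sq_grad_avg + eps));
      // E[dx^2] <- rho * E[dx^2] + (1 - rho) * dx^2
      sq_update_avg = rho * sq_update_avg + (1.0f - rho) * update * update;
      // written into the parameter diff; net_->Update() then subtracts it
      return update;
    }

This is the same quantity the test below computes as its expected value: grad * sqrt((update_history_value + delta_) / (weighted_gradient_average + delta_)).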
diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp
index 3040eb134a4..8314106ed5a 100644
--- a/src/caffe/test/test_gradient_based_solver.cpp
+++ b/src/caffe/test/test_gradient_based_solver.cpp
@@ -51,8 +51,9 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
       LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode();
     }
     InitSolver(param);
-    delta_ = (solver_type() == SolverParameter_SolverType_ADAGRAD) ?
-        param.delta() : 0;
+    delta_ = (solver_type() == SolverParameter_SolverType_ADAGRAD)
+        || (solver_type() == SolverParameter_SolverType_ADADELTA) ?
+        param.delta() : 0;
   }
 
   void RunLeastSquaresSolver(const Dtype learning_rate,
@@ -60,8 +61,6 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     ostringstream proto;
     proto <<
        "max_iter: " << num_iters << " "
-       "base_lr: " << learning_rate << " "
-       "lr_policy: 'fixed' "
        "net_param { "
        "  name: 'TestNetwork' "
        "  layers: { "
@@ -107,6 +106,10 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
        "    bottom: 'targets' "
        "  } "
        "} ";
+    if (learning_rate != 0) {
+      proto << "base_lr: " << learning_rate << " ";
+      proto << "lr_policy: 'fixed' ";
+    }
     if (weight_decay != 0) {
       proto << "weight_decay: " << weight_decay << " ";
     }
@@ -189,7 +192,11 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
           ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]);
       // Finally, compute update.
       const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
-      ASSERT_EQ(2, history.size());  // 1 blob for weights, 1 for bias
+      if (solver_type() != SolverParameter_SolverType_ADADELTA) {
+        ASSERT_EQ(2, history.size());  // 1 blob for weights, 1 for bias
+      } else {
+        ASSERT_EQ(4, history.size());  // additional blobs for update history
+      }
       Dtype update_value = learning_rate * grad;
       const Dtype history_value = (i == D) ?
           history[1]->cpu_data()[0] : history[0]->cpu_data()[i];
@@ -206,6 +213,19 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
       case SolverParameter_SolverType_ADAGRAD:
         update_value /= std::sqrt(history_value + grad * grad) + delta_;
         break;
+      case SolverParameter_SolverType_ADADELTA:
+      {
+        const Dtype update_history_value = (i == D) ?
+            history[3]->cpu_data()[0] : history[2]->cpu_data()[i];
+        const Dtype weighted_gradient_average =
+            momentum * history_value + (1 - momentum) * (grad * grad);
+        update_value = grad * std::sqrt((update_history_value + delta_) /
+            (weighted_gradient_average + delta_));
+        // not actually needed, just here for illustrative purposes
+        // const Dtype weighted_update_average =
+        //     momentum * update_history_value + (1 - momentum) * (update_value);
+        break;
+      }
       default:
         LOG(FATAL) << "Unknown solver type: " << solver_type();
       }
@@ -485,4 +505,79 @@ TYPED_TEST(NesterovSolverTest, TestNesterovLeastSquaresUpdateWithEverything) {
   }
 }
 
+template <typename TypeParam>
+class AdaDeltaSolverTest : public GradientBasedSolverTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  virtual void InitSolver(const SolverParameter& param) {
+    this->solver_.reset(new AdaDeltaSolver<Dtype>(param));
+  }
+
+  virtual SolverParameter_SolverType solver_type() {
+    return SolverParameter_SolverType_ADADELTA;
+  }
+};
+
+TYPED_TEST_CASE(AdaDeltaSolverTest, TestDtypesAndDevices);
+
+TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdate) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.0;
+  this->TestLeastSquaresUpdate(kLearningRate);
+}
+
+TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithWeightDecay) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.0;
+  const Dtype kWeightDecay = 0.5;
+  const Dtype kMomentum = 0.95;
+  this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
+}
+
+TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithHalfMomentum) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.0;
+  const Dtype kWeightDecay = 0.0;
+  const Dtype kMomentum = 0.5;
+  const int kNumIters = 1;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
+  }
+}
+
+TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithMomentum) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.0;
+  const Dtype kWeightDecay = 0.0;
+  const Dtype kMomentum = 0.95;
+  const int kNumIters = 1;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum);
+  }
+}
+
+TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.0;
+  const Dtype kWeightDecay = 0.0;
+  const Dtype kMomentum = 0.95;
+  const int kNumIters = 500;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
+TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kLearningRate = 0.0;
+  const Dtype kWeightDecay = 0.1;
+  const Dtype kMomentum = 0.95;
+  const int kNumIters = 500;
+  for (int i = 0; i <= kNumIters; ++i) {
+    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
+  }
+}
+
 }  // namespace caffe
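To exercise only the new solver tests, one option is gtest filtering on the test binary. A sketch assuming the test binary lands under ./build/test, mirroring the ./build/tools layout used by the example script; adjust the path to your build:

    ./build/test/test_all.testbin --gtest_filter='*AdaDelta*'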