Merge pull request BVLC#1228 from longjon/solver-step
Refactor Solver to allow interactive stepping

Conflicts:
	src/caffe/solver.cpp
longjon committed Jan 2, 2015
2 parents d389945 + fda4809 commit f9fa540
Showing 6 changed files with 86 additions and 37 deletions.
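
For context, the refactor splits the training loop out of `Solve()` into a public `Step(int iters)` and exposes it to Python, so training can be driven incrementally. A minimal sketch of the resulting workflow, assuming pycaffe built from this branch is on `PYTHONPATH`; the prototxt path is a placeholder:

```python
import caffe  # pycaffe built from this branch

# 'solver.prototxt' is a placeholder; any SGD solver definition works.
solver = caffe.SGDSolver('solver.prototxt')

solver.step(100)    # run 100 training iterations, then return control
print(solver.iter)  # solver state persists between calls -> 100
solver.step(100)    # resume for another 100 iterations
```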
7 changes: 4 additions & 3 deletions README.md
@@ -1,4 +1,5 @@
# Caffe
This is Caffe with several unmerged PRs and no guarantees.

Caffe is a deep learning framework developed with cleanliness, readability, and speed in mind.<br />
Consult the [project website](http://caffe.berkeleyvision.org) for all documentation.
Everything here is subject to change, including the history of this branch.

See `future.sh` for details.
29 changes: 29 additions & 0 deletions future.sh
@@ -0,0 +1,29 @@
#!/bin/bash
git checkout dev
git branch -D future
git checkout -b future
# deconv layer, coord maps, net pointer, crop layer
hub merge https://github.com/BVLC/caffe/pull/1639
# reshaping data layer
hub merge https://github.com/BVLC/caffe/pull/1313
# softmax missing values
hub merge https://github.com/BVLC/caffe/pull/1654
# python testing
hub merge https://github.com/BVLC/caffe/pull/1473
# gradient accumulation
hub merge https://github.com/BVLC/caffe/pull/1663
# solver stepping
hub merge https://github.com/BVLC/caffe/pull/1228
git add future.sh
git commit -m 'add creation script'

cat << 'EOF' > README.md
This is Caffe with several unmerged PRs and no guarantees.
Everything here is subject to change, including the history of this branch.
See `future.sh` for details.
EOF

git add README.md
git commit -m 'update readme'
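
Note: `future.sh` merges pull requests by URL with `hub merge`, so recreating this branch assumes the GitHub `hub` CLI is installed and the checkout has the BVLC remote configured.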
4 changes: 3 additions & 1 deletion include/caffe/solver.hpp
@@ -26,6 +26,7 @@ class Solver {
// in a non-zero iter number to resume training for a pre-trained net.
virtual void Solve(const char* resume_file = NULL);
inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
void Step(int iters);
virtual ~Solver() {}
inline shared_ptr<Net<Dtype> > net() { return net_; }
inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
@@ -36,7 +37,7 @@
protected:
// PreSolve is run before any solving iteration starts, allowing one to
// put up some scaffold.
virtual void PreSolve() {}
virtual void PreSolve();
// Get the update value for the current iteration.
virtual void ComputeUpdateValue() = 0;
// The Solver::Snapshot function implements the basic snapshotting utility
@@ -60,6 +61,7 @@
int current_step_;
shared_ptr<Net<Dtype> > net_;
vector<shared_ptr<Net<Dtype> > > test_nets_;
bool initialized_;

DISABLE_COPY_AND_ASSIGN(Solver);
};
3 changes: 2 additions & 1 deletion python/caffe/_caffe.cpp
@@ -197,7 +197,8 @@ BOOST_PYTHON_MODULE(_caffe) {
.add_property("test_nets", &PySGDSolver::test_nets)
.add_property("iter", &PySGDSolver::iter)
.def("solve", &PySGDSolver::Solve)
.def("solve", &PySGDSolver::SolveResume);
.def("solve", &PySGDSolver::SolveResume)
.def("step", &PySGDSolver::Step);

bp::class_<vector<shared_ptr<PyNet> > >("NetVec")
.def(bp::vector_indexing_suite<vector<shared_ptr<PyNet> >, true>());
1 change: 1 addition & 0 deletions python/caffe/_caffe.hpp
@@ -181,6 +181,7 @@ class PySGDSolver {
vector<shared_ptr<PyNet> > test_nets() { return test_nets_; }
int iter() { return solver_->iter(); }
void Solve() { return solver_->Solve(); }
void Step(int iters) { solver_->Step(iters); }
void SolveResume(const string& resume_file);

protected:
79 changes: 47 additions & 32 deletions src/caffe/solver.cpp
@@ -29,9 +29,12 @@ Solver<Dtype>::Solver(const string& param_file)

template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
initialized_ = false;
iter_ = 0;
LOG(INFO) << "Initializing solver from parameters: " << std::endl
<< param.DebugString();
param_ = param;
CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
if (param_.random_seed() >= 0) {
Caffe::set_random_seed(param_.random_seed());
}
@@ -155,39 +158,19 @@ void Solver<Dtype>::InitTestNets() {
}

template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
Caffe::set_phase(Caffe::TRAIN);
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
PreSolve();

iter_ = 0;
current_step_ = 0;
if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
Restore(resume_file);
void Solver<Dtype>::Step(int iters) {
if (!initialized_) {
PreSolve();
}
// Remember the initial iter_ value; will be non-zero if we loaded from a
// resume_file above.
const int start_iter = iter_;

vector<Blob<Dtype>*> bottom_vec;
const int start_iter = iter_;
const int stop_iter = iter_ + iters;
int average_loss = this->param_.average_loss();

CHECK_GE(average_loss, 1) << "average_loss should be non-negative.";

vector<Dtype> losses;
Dtype smoothed_loss = 0;

// For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
vector<Blob<Dtype>*> bottom_vec;
for (; iter_ < param_.max_iter(); ++iter_) {
// Save a snapshot if needed.
if (param_.snapshot() && iter_ > start_iter &&
iter_ % param_.snapshot() == 0) {
Snapshot();
}

for (; iter_ < stop_iter; ++iter_) {
// zero-init the params
for (int i = 0; i < net_->params().size(); ++i) {
shared_ptr<Blob<Dtype> > blob = net_->params()[i];
@@ -252,13 +235,44 @@ void Solver<Dtype>::Solve(const char* resume_file) {
}
}
}

ComputeUpdateValue();
net_->Update();

// Save a snapshot if needed.
if (param_.snapshot() && (iter_ + 1) % param_.snapshot() == 0) {
Snapshot();
}
}
}

template <typename Dtype>
void Solver<Dtype>::PreSolve() {
initialized_ = true;
iter_ = 0;
current_step_ = 0;
}

template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
Caffe::set_phase(Caffe::TRAIN);
LOG(INFO) << "Solving " << net_->name();
LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();

PreSolve();
if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
Restore(resume_file);
}

// For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
Step(param_.max_iter() - iter_);
// If we haven't already, save a snapshot after optimization, unless
// overridden by setting snapshot_after_train := false
if (param_.snapshot_after_train()
&& (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
Snapshot();
}
// Always save a snapshot after optimization, unless overridden by setting
// snapshot_after_train := false.
if (param_.snapshot_after_train()) { Snapshot(); }
// After the optimization is done, run an additional train and test pass to
// display the train and test loss/outputs if appropriate (based on the
// display and test_interval settings, respectively). Unlike in the rest of
@@ -267,7 +281,7 @@ void Solver<Dtype>::Solve(const char* resume_file) {
// display the loss, which is computed in the forward pass.
if (param_.display() && iter_ % param_.display() == 0) {
Dtype loss;
net_->Forward(bottom_vec, &loss);
net_->ForwardPrefilled(&loss);
LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss;
}
if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
@@ -439,6 +453,7 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {

template <typename Dtype>
void SGDSolver<Dtype>::PreSolve() {
Solver<Dtype>::PreSolve();
// Initialize the history
vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
history_.clear();
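
Taken together, these changes let an interactive session interleave training with inspection, which is the point of the refactor: `Step()` lazily runs `PreSolve()` via the new `initialized_` flag, so stepping works without a prior `Solve()` call. A hedged sketch, assuming the net defines a blob named `loss` (the prototxt path is again a placeholder):

```python
import caffe

solver = caffe.SGDSolver('solver.prototxt')  # placeholder path
for _ in range(10):
    # 100 SGD iterations per call; PreSolve() runs lazily on the first step.
    solver.step(100)
    # State (iter, weights, momentum history) persists across calls,
    # so it can be inspected between bursts of training:
    print('iter', solver.iter, 'loss', solver.net.blobs['loss'].data)
```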
