Skip to content

Commit

Permalink
compression: optimize implementation of compressors (#18)
Browse files Browse the repository at this point in the history
* compression: update cifar100 training script (#15)

* cifar: update cifar script

* cifar: update lr

* cifar: add warmup

* cifar: update parse

* cifar: update

* cifar: add log

* cifar: fix typo

* cifar: fix bug

* cifar: fix lr

* cifar: fix typo

* cifar: update num samples

* 1bit: update packing

* 1bit: fix compile bug

* 1bit: exp

* 1bit: exp

* 1bit: exp

* 1bit: exp

* 1bit: exp

* 1bit: exp

* 1bit: exp

* 1bit: exp

* 1bit: exp

* 1bit: exp

* 1bit: fix typo

* 1bit: fix compile bug

* 1bit: exp

* 1bit: test

* 1bit: exp

* 1bit: test

* 1bit: exp

* 1bit: exp

* 1bit: exp

* 1bit: fix typo

* 1bit: fix typo

* 1bit: fix typo

* 1bit: try5 final

* 1bit: exp rm decompress in ef

* 1bit: fix typo

* 1bit: fix typo

* 1bit: fix bug

* 1bit: fix typo

* 1bit: add log

* 1bit: debug

* 1bit: debug

* 1bit: debug

* 1bit: debug

* 1bit: debug

* 1bit: debug

* 1bit: debug

* 1bit: debug

* 1bit: fix

* 1bit: debug

* 1bit: fix

* 1bit: debug

* 1bit: fix typo

* 1bit: fix typo

* 1bit: debug

* 1bit: debug

* 1bit: debug

* 1bit: debug

* 1bit: debug

* 1bit: debug

* 1bit: fix bug

* 1bit: fix bug

* 1bit: fix typo

* 1bit: add test

* 1bit: update test

* 1bit: update test

* 1bit: update test script

* 1bit: fix test bug

* 1bit: fix test script

* 1bit: update script

* 1bit: update test

* refactor: update name and api

* refactor: fix indent

* refactor: add fastupdateerror

* refactor: fix link error

* topk: impl fastupdateerror

* topk: debug

* randomk: fix fatal bug

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
jasperzhong and Ubuntu committed Jun 23, 2020
1 parent 3844228 commit c17bfb6
Show file tree
Hide file tree
Showing 31 changed files with 586 additions and 557 deletions.
5 changes: 3 additions & 2 deletions byteps/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
namespace byteps {
namespace common {
namespace compressor {
struct ByteBuf;
struct BPSTensor;
typedef BPSTensor tensor_t;
class BaseCompressor;
class ErrorFeedback;
} // namespace compressor
Expand Down Expand Up @@ -258,7 +259,7 @@ struct TensorTableEntry {
// Compressor
std::shared_ptr<compressor::BaseCompressor> compressor;
// Compressed
std::shared_ptr<compressor::ByteBuf> compressed;
std::shared_ptr<compressor::tensor_t> compressed;
};
using TensorTable = std::unordered_map<std::string, TensorTableEntry>;

Expand Down
6 changes: 4 additions & 2 deletions byteps/common/compressor/base_compressor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ CompressorRegistry::ctor_t CompressorRegistry::Find(const std::string& name) {
std::unique_ptr<BaseCompressor> CompressorRegistry::Create(
const kwargs_t& kwargs) {
#ifndef BYTEPS_BUILDING_SERVER
const std::string types[] = {"momentum_type", "ef_type",
"compressor_type"};
const std::string types[] = {"momentum_type", "ef_type", "compressor_type"};
#else
// server do not need momentum
const std::string types[] = {"ef_type", "compressor_type"};
Expand All @@ -67,6 +66,9 @@ void BaseCompressor::Init(size_t aligned_size) {
_buf.reset(new char[aligned_size]);
_cpu_reducer.reset(new CpuReducer(nullptr));
}

void BaseCompressor::FastUpdateError(tensor_t error, tensor_t corrected,
tensor_t compressed) {}
} // namespace compressor
} // namespace common
} // namespace byteps
30 changes: 21 additions & 9 deletions byteps/common/compressor/base_compressor.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,18 @@ namespace byteps {
namespace common {
namespace compressor {

typedef char byte_t;
/*!
* \brief Byte buffer
* \brief Tensor type
*/
struct ByteBuf {
char* data;
typedef struct BPSTensor {
byte_t* data;
size_t size;
};
int dtype;

BPSTensor(byte_t* data = nullptr, size_t size = 0, int dtype = 0)
: data(data), size(size), dtype(dtype) {}
} tensor_t;

/*!
* \brief Compressor interface used in BytePS core.
Expand All @@ -54,20 +59,27 @@ class BaseCompressor {
* \brief Compress function
*
* \param grad gradient tensor
* \param dtype data type
* \param compressed compressed tensor
*/
virtual void Compress(ByteBuf grad, int dtype, ByteBuf& compressed) = 0;
virtual void Compress(tensor_t grad, tensor_t& compressed) = 0;

/*!
* \brief Decompress function
*
* \param compressed compressed tensor
* \param dtype data type
* \param decompressed decompressed tensor
*/
virtual void Decompress(ByteBuf compressed, int dtype,
ByteBuf& decompressed) = 0;
virtual void Decompress(tensor_t compressed, tensor_t& decompressed) = 0;

/*!
* \brief help function for error feedback `UpdateError`
*
* \param corrected gradient corrected with error
* \param error error
* \param compressed compressed gradient
*/
virtual void FastUpdateError(tensor_t error, tensor_t corrected,
tensor_t compressed);

protected:
/*!
Expand Down
21 changes: 14 additions & 7 deletions byteps/common/compressor/error_feedback.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,26 @@ void ErrorFeedback::Init(size_t aligned_size) {
_cpu_reducer.reset(new CpuReducer(nullptr));
}

void ErrorFeedback::Compress(ByteBuf grad, int dtype, ByteBuf& compressed) {
void ErrorFeedback::Compress(tensor_t grad, tensor_t& compressed) {
// before: grad += error
UpdateGradient(grad, dtype);
UpdateGradient(grad);

// TODO: look strange
compressed.data = _error.get();
// compress
_compressor_ptr->Compress(grad, dtype, compressed);
_compressor_ptr->Compress(grad, compressed);

UpdateError(grad, compressed);
}

UpdateError(grad, dtype, compressed);
void ErrorFeedback::Decompress(tensor_t compressed, tensor_t& decompressed) {
_compressor_ptr->Decompress(compressed, decompressed);
}

void ErrorFeedback::Decompress(ByteBuf compressed, int dtype,
ByteBuf& decompressed) {
_compressor_ptr->Decompress(compressed, dtype, decompressed);
void ErrorFeedback::UpdateError(tensor_t corrected, tensor_t compressed) {
_compressor_ptr->FastUpdateError({_error.get()}, corrected, compressed);
}

} // namespace compressor
} // namespace common
} // namespace byteps
13 changes: 4 additions & 9 deletions byteps/common/compressor/error_feedback.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,17 @@ class ErrorFeedback : public BaseCompressor {
* \brief Compress function
*
* \param grad gradient tensor
* \param dtype data type
* \param compressed compressed tensor
*/
virtual void Compress(ByteBuf grad, int dtype, ByteBuf& compressed);
virtual void Compress(tensor_t grad, tensor_t& compressed);

/*!
* \brief Decompress function
*
* \param compressed compressed tensor
* \param dtype data type
* \param decompressed decompressed tensor
*/
virtual void Decompress(ByteBuf compressed, int dtype,
ByteBuf& decompressed);
virtual void Decompress(tensor_t compressed, tensor_t& decompressed);

protected:
/*!
Expand All @@ -66,19 +63,17 @@ class ErrorFeedback : public BaseCompressor {
* \param grad input gradient to be updated inplace
* \param dtype type
*/
virtual void UpdateGradient(ByteBuf grad, int dtype) = 0;
virtual void UpdateGradient(tensor_t grad) = 0;

/*!
* \brief Update error
*
* error = corrected_grad - decompressed
*
* \param corrected refers to gradient + error
* \param dtype type
* \param compressed compressed tensor
*/
virtual void UpdateError(ByteBuf corrected, int dtype,
ByteBuf compressed) = 0;
virtual void UpdateError(tensor_t corrected, tensor_t compressed);

protected:
std::unique_ptr<char[]> _error;
Expand Down
21 changes: 10 additions & 11 deletions byteps/common/compressor/momentum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,25 @@ Momentum::Momentum(std::unique_ptr<BaseCompressor> compressor_ptr, float mu)
Momentum::~Momentum() = default;

void Momentum::Init(size_t aligned_size) {
_compressor_ptr->Init(aligned_size);
_mom.reset(new char[aligned_size]);
memset(_mom.get(), 0, aligned_size);
_cpu_reducer.reset(new CpuReducer(nullptr));
_compressor_ptr->Init(aligned_size);
_mom.reset(new char[aligned_size]);
memset(_mom.get(), 0, aligned_size);
_cpu_reducer.reset(new CpuReducer(nullptr));
}

void Momentum::Compress(ByteBuf grad, int dtype, ByteBuf& compressed) {
void Momentum::Compress(tensor_t grad, tensor_t& compressed) {
// m_t = \mu * m_{t-1} + g_t
UpdateMom(grad, dtype);
UpdateMom(grad);

// p_t = \mu m_t + g_t
UpdateGradient(grad, dtype);
UpdateGradient(grad);

// compress
_compressor_ptr->Compress(grad, dtype, compressed);
_compressor_ptr->Compress(grad, compressed);
}

void Momentum::Decompress(ByteBuf compressed, int dtype,
ByteBuf& decompressed) {
_compressor_ptr->Decompress(compressed, dtype, decompressed);
void Momentum::Decompress(tensor_t compressed, tensor_t& decompressed) {
_compressor_ptr->Decompress(compressed, decompressed);
}

} // namespace compressor
Expand Down
15 changes: 5 additions & 10 deletions byteps/common/compressor/momentum.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,17 @@ class Momentum : public BaseCompressor {
* \brief Compress function
*
* \param grad gradient tensor
* \param dtype data type
* \param compressed compressed tensor
*/
virtual void Compress(ByteBuf grad, int dtype, ByteBuf& compressed) final;
virtual void Compress(tensor_t grad, tensor_t& compressed) final;

/*!
* \brief Decompress function
*
* \param compressed compressed tensor
* \param dtype data type
* \param decompressed decompressed tensor
*/
virtual void Decompress(ByteBuf compressed, int dtype,
ByteBuf& decompressed) final;
virtual void Decompress(tensor_t compressed, tensor_t& decompressed) final;

protected:
/*!
Expand All @@ -65,19 +62,17 @@ class Momentum : public BaseCompressor {
* m_t = \mu * m_{t-1} + g_t
*
* \param grad refers to gradient
* \param dtype type
*/
virtual void UpdateMom(ByteBuf grad, int dtype) = 0;
virtual void UpdateMom(tensor_t grad) = 0;

/*!
/*!
* \brief Update gradient
*
* p_t = \mu m_t + g_t
*
* \param grad refers to gradient which adds momentum in place.
* \param dtype type
*/
virtual void UpdateGradient(ByteBuf grad, int dtype) = 0;
virtual void UpdateGradient(tensor_t grad) = 0;

protected:
std::unique_ptr<char[]> _mom;
Expand Down
7 changes: 3 additions & 4 deletions byteps/common/compressor/strategy/multibit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,12 @@ MultibitCompressor::MultibitCompressor(int k) : _k(k){};

MultibitCompressor::~MultibitCompressor() = default;

void MultibitCompressor::Compress(ByteBuf grad, int dtype,
ByteBuf& compressed) {
void MultibitCompressor::Compress(tensor_t grad, tensor_t& compressed) {
// TOOD
}

void MultibitCompressor::Decompress(ByteBuf compressed, int dtype,
ByteBuf& decompressed) {
void MultibitCompressor::Decompress(tensor_t compressed,
tensor_t& decompressed) {
// TODO
}
} // namespace compressor
Expand Down
5 changes: 2 additions & 3 deletions byteps/common/compressor/strategy/multibit.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@ class MultibitCompressor : public BaseCompressor {
explicit MultibitCompressor(int k);
virtual ~MultibitCompressor();

void Compress(ByteBuf grad, int dtype, ByteBuf& compressed) override;
void Compress(tensor_t grad, tensor_t& compressed) override;

void Decompress(ByteBuf compressed, int dtype,
ByteBuf& decompressed) override;
void Decompress(tensor_t compressed, tensor_t& decompressed) override;

private:
int _k;
Expand Down
8 changes: 4 additions & 4 deletions byteps/common/compressor/strategy/nesterov_momentum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,16 @@ NesterovMomentumCompressor::NesterovMomentumCompressor(

NesterovMomentumCompressor::~NesterovMomentumCompressor() = default;

void NesterovMomentumCompressor::UpdateMom(ByteBuf grad, int dtype) {
void NesterovMomentumCompressor::UpdateMom(tensor_t grad) {
// m_t = \mu * m_{t-1} + g_t
this->_cpu_reducer->sum(_mom.get(), grad.data, _mom.get(), grad.size,
static_cast<DataType>(dtype), _mu);
static_cast<DataType>(grad.dtype), _mu);
}

void NesterovMomentumCompressor::UpdateGradient(ByteBuf grad, int dtype) {
void NesterovMomentumCompressor::UpdateGradient(tensor_t grad) {
// p_t = \mu m_t + g_t
this->_cpu_reducer->sum(grad.data, _mom.get(), grad.size,
static_cast<DataType>(dtype), _mu);
static_cast<DataType>(grad.dtype), _mu);
}

} // namespace compressor
Expand Down
4 changes: 2 additions & 2 deletions byteps/common/compressor/strategy/nesterov_momentum.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class NesterovMomentumCompressor : public Momentum {
virtual ~NesterovMomentumCompressor();

protected:
void UpdateMom(ByteBuf grad, int dtype) override;
void UpdateGradient(ByteBuf grad, int dtype) override;
void UpdateMom(tensor_t grad) override;
void UpdateGradient(tensor_t grad) override;
};

} // namespace compressor
Expand Down
Loading

0 comments on commit c17bfb6

Please sign in to comment.