diff --git a/byteps/common/compressor/strategy/onebit.h b/byteps/common/compressor/strategy/onebit.h index 8b935d612..d917dd456 100644 --- a/byteps/common/compressor/strategy/onebit.h +++ b/byteps/common/compressor/strategy/onebit.h @@ -34,7 +34,6 @@ namespace compressor { * server: majority vote * sign(\sum_i c_i) * - * \note this is a deterministic algorithm. * \note 0 represents positive and 1 represents negative. */ class OnebitCompressor : public BaseCompressor { diff --git a/byteps/common/compressor/strategy/randomk.cc b/byteps/common/compressor/strategy/randomk.cc index d7b147e16..63827fbdd 100644 --- a/byteps/common/compressor/strategy/randomk.cc +++ b/byteps/common/compressor/strategy/randomk.cc @@ -37,18 +37,123 @@ CompressorRegistry::Register }); } -RandomkCompressor::RandomkCompressor(int k) : _k(k){}; +RandomkCompressor::RandomkCompressor(int k) : _k(k) { _gen.seed(_rd()); }; RandomkCompressor::~RandomkCompressor() = default; +template +size_t RandomkCompressor::_Packing(index_t* dst, const scalar_t* src, + size_t len) { + static_assert(sizeof(index_t) == sizeof(scalar_t), + "index_t should be the same size as scalar_t"); + BPS_CHECK_LE(this->_k, len / 2); + using pair_t = std::pair; + std::uniform_int_distribution<> dis(0, len-1); + auto ptr = reinterpret_cast(dst); + + for (size_t i = 0; i < this->_k; ++i) { + auto index = dis(_gen); + ptr[i] = std::make_pair(index, src[index]); + } + + return this->_k * sizeof(pair_t); +} + +size_t RandomkCompressor::Packing(const void* src, size_t size, int dtype) { + switch (dtype) { + case BYTEPS_INT8: + return _Packing(reinterpret_cast(_buf.get()), + reinterpret_cast(src), + size / sizeof(int8_t)); + case BYTEPS_UINT8: + return _Packing(reinterpret_cast(_buf.get()), + reinterpret_cast(src), + size / sizeof(uint8_t)); + // case BYTEPS_FLOAT16: + // return _Packing(reinterpret_cast(_buf.get()), + // reinterpret_cast(src), size); + case BYTEPS_FLOAT32: + return _Packing(reinterpret_cast(_buf.get()), + reinterpret_cast(src), + size / sizeof(int32_t)); + case BYTEPS_FLOAT64: + return _Packing(reinterpret_cast(_buf.get()), + reinterpret_cast(src), + size / sizeof(int64_t)); + default: + BPS_CHECK(0) << "Unsupported data type: " << dtype; + } + return 0; +} void RandomkCompressor::Compress(ByteBuf grad, int dtype, ByteBuf& compressed) { - // TODO + compressed.size = Packing(grad.data, grad.size, dtype); + compressed.data = _buf.get(); +} + +template +size_t RandomkCompressor::_Unpacking(scalar_t* dst, const index_t* src, + size_t len) { + static_assert(sizeof(index_t) == sizeof(scalar_t), + "index_t should be the same size as scalar_t"); + using pair_t = std::pair; + auto ptr = reinterpret_cast(src); + + if ((void*)dst == (void*)src) { + auto buf = reinterpret_cast(_buf.get()); + std::copy(ptr, ptr + len, buf); + ptr = const_cast(buf); + } + + // reset to zeros + std::fill(dst, dst + this->_src_len, 0); + for (auto i = 0; i < len; ++i) { + auto& pair = ptr[i]; + dst[pair.first] = pair.second; + } +} + +size_t RandomkCompressor::Unpacking(void* dst, const void* src, size_t size, + int dtype) { + switch (dtype) { + case BYTEPS_INT8: + return _Unpacking(reinterpret_cast(dst), + reinterpret_cast(src), + size / sizeof(int8_t) / 2); + case BYTEPS_UINT8: + return _Unpacking(reinterpret_cast(dst), + reinterpret_cast(src), + size / sizeof(uint8_t) / 2); + // case BYTEPS_FLOAT16: + // return _Unpacking(reinterpret_cast(_buf.get()), + // reinterpret_cast(src), size); + case BYTEPS_FLOAT32: + return _Unpacking(reinterpret_cast(dst), + reinterpret_cast(src), + size / sizeof(float) / 2); + case BYTEPS_FLOAT64: + return _Unpacking(reinterpret_cast(dst), + reinterpret_cast(src), + size / sizeof(double) / 2); + default: + BPS_CHECK(0) << "Unsupported data type: " << dtype; + } + return 0; } +#ifndef BYTEPS_BUILDING_SERVER +// worker version decompressor +void RandomkCompressor::Decompress(ByteBuf compressed, int dtype, + ByteBuf& decompressed) { + BPS_CHECK(decompressed.data); + Unpacking(decompressed.data, compressed.data, compressed.size, dtype); +} +#else void RandomkCompressor::Decompress(ByteBuf compressed, int dtype, ByteBuf& decompressed) { - // TODO + if (decompressed.data == nullptr) decompressed.data = _buf.get(); + Unpacking(decompressed.data, compressed.data, compressed.size, dtype); } +#endif } // namespace compressor } // namespace common } // namespace byteps \ No newline at end of file diff --git a/byteps/common/compressor/strategy/randomk.h b/byteps/common/compressor/strategy/randomk.h index 3ef1fc317..1e464663a 100644 --- a/byteps/common/compressor/strategy/randomk.h +++ b/byteps/common/compressor/strategy/randomk.h @@ -16,6 +16,8 @@ #ifndef BYTEPS_COMPRESS_STRAT_RANDOMK_H #define BYTEPS_COMPRESS_STRAT_RANDOMK_H +#include + #include "../base_compressor.h" namespace byteps { @@ -23,20 +25,57 @@ namespace common { namespace compressor { /*! - * \brief TODO + * \brief RandomK Compressor + * + * paper: Sparsified SGD with Memory + * https://arxiv.org/pdf/1809.07599.pdf + * + * randomly sending k entries of the stochastic gradient */ class RandomkCompressor : public BaseCompressor { public: explicit RandomkCompressor(int k); virtual ~RandomkCompressor(); + /*! + * \brief Compress function + * + * randomly select k entries and corresponding indices + * + * \param grad gradient tensor + * \param dtype data type + * \param compressed compressed tensor + */ void Compress(ByteBuf grad, int dtype, ByteBuf& compressed) override; + /*! + * \brief Decompress function + * + * fill a zero tensor with topk entries and corresponding indices + * + * \param compressed compressed tensor + * \param dtype data type + * \param decompressed decompressed tensor + */ void Decompress(ByteBuf compressed, int dtype, ByteBuf& decompressed) override; + private: + size_t Packing(const void* src, size_t size, int dtype); + + template + size_t _Packing(index_t* dst, const scalar_t* src, size_t len); + + size_t Unpacking(void* dst, const void* src, size_t size, int dtype); + + template + size_t _Unpacking(scalar_t* dst, const index_t* src, size_t len); + private: int _k; + int _src_len; + std::random_device _rd; + std::mt19937 _gen; }; } // namespace compressor } // namespace common diff --git a/byteps/common/compressor/strategy/topk.h b/byteps/common/compressor/strategy/topk.h index 8200e69d5..3c6d06956 100644 --- a/byteps/common/compressor/strategy/topk.h +++ b/byteps/common/compressor/strategy/topk.h @@ -30,7 +30,6 @@ namespace compressor { * * sending the most significant entries of the stochastic gradient * - * \note this is a deterministic algorithm */ class TopkCompressor : public BaseCompressor { public: