From 2379e17a42d430e572000b639daa466c59febe21 Mon Sep 17 00:00:00 2001 From: Yuchen Zhong Date: Thu, 14 May 2020 10:02:46 +0800 Subject: [PATCH] compression: add topk compressor (#10) * topk: init commit * topk: update scripts * topk: fix some bugs * topk: fix pq ctor * topk: fix args type * topk: fix register * topk: fix typo * topk: add log * topk: fix bug * topk: fix bug * topk: fix bug * topk: fix const cast * topk: fix typo * topk: rm log * topk: add comments * topk: fix typo --- byteps/common/compressor/strategy/onebit.cc | 8 +- byteps/common/compressor/strategy/onebit.h | 3 +- byteps/common/compressor/strategy/topk.cc | 127 +++++++++++++++++- byteps/common/compressor/strategy/topk.h | 43 +++++- byteps/mxnet/__init__.py | 2 +- .../mxnet/train_gluon_imagenet_byteps_gc.py | 5 +- example/mxnet/train_gluon_mnist_byteps_gc.py | 7 +- 7 files changed, 181 insertions(+), 14 deletions(-) diff --git a/byteps/common/compressor/strategy/onebit.cc b/byteps/common/compressor/strategy/onebit.cc index e8efad223..7fe6f01c4 100644 --- a/byteps/common/compressor/strategy/onebit.cc +++ b/byteps/common/compressor/strategy/onebit.cc @@ -41,7 +41,7 @@ OnebitCompressor::OnebitCompressor(bool use_scale) : _use_scale(use_scale){}; OnebitCompressor::~OnebitCompressor() = default; template -size_t _Packing(T* data, size_t len) { +static size_t _Packing(T* data, size_t len) { constexpr int PACKING_SIZE = sizeof(T) * 8; size_t padding_len = (PACKING_SIZE - (len % PACKING_SIZE)) % PACKING_SIZE; size_t chunk_size = (len + padding_len) / PACKING_SIZE; @@ -56,7 +56,7 @@ size_t _Packing(T* data, size_t len) { return chunk_size * sizeof(T); } -size_t Packing(void* data, size_t len, int dtype) { +static size_t Packing(void* data, size_t len, int dtype) { switch (dtype) { case BYTEPS_INT8: case BYTEPS_UINT8: @@ -96,7 +96,7 @@ void OnebitCompressor::Compress(ByteBuf grad, int dtype, ByteBuf& compressed) { } template -size_t _Unpacking(T1* dst, const T2* src, size_t size) { +static size_t _Unpacking(T1* dst, const T2* src, size_t size) { static_assert(sizeof(T1) == sizeof(T2), "T1 should be the same size as T2"); constexpr int PACKING_SIZE = sizeof(T2) * 8; auto chunk_size = (size - sizeof(float)) / sizeof(T2); @@ -118,7 +118,7 @@ size_t _Unpacking(T1* dst, const T2* src, size_t size) { return chunk_size; } -size_t Unpacking(void* dst, const void* src, size_t len, int dtype) { +static size_t Unpacking(void* dst, const void* src, size_t len, int dtype) { switch (dtype) { case BYTEPS_INT8: return _Unpacking(reinterpret_cast(dst), diff --git a/byteps/common/compressor/strategy/onebit.h b/byteps/common/compressor/strategy/onebit.h index c4c54e7a8..8b935d612 100644 --- a/byteps/common/compressor/strategy/onebit.h +++ b/byteps/common/compressor/strategy/onebit.h @@ -55,13 +55,12 @@ class OnebitCompressor : public BaseCompressor { void Compress(ByteBuf grad, int dtype, ByteBuf& compressed) override; /*! - * \brief Decompress + * \brief Decompress function * * unpack from byte array to FP tensor * * \param compressed compressed tensor * \param dtype data type - * \param src_size uncompressed tensor size * \param decompressed decompressed tensor */ void Decompress(ByteBuf compressed, int dtype, diff --git a/byteps/common/compressor/strategy/topk.cc b/byteps/common/compressor/strategy/topk.cc index 1ae3daa65..a4c1264fe 100644 --- a/byteps/common/compressor/strategy/topk.cc +++ b/byteps/common/compressor/strategy/topk.cc @@ -15,6 +15,8 @@ #include "topk.h" +#include + #include "../../logging.h" namespace byteps { @@ -40,14 +42,135 @@ TopkCompressor::TopkCompressor(int k) : _k(k){}; TopkCompressor::~TopkCompressor() = default; +template +size_t TopkCompressor::_Packing(index_t* dst, const scalar_t* src, size_t len) { + static_assert(sizeof(index_t) == sizeof(scalar_t), + "index_t should be the same size as scalar_t"); + BPS_CHECK_LE(this->_k, len / 2); + using pair_t = std::pair; + using container_t = std::vector; + auto comp = [](const pair_t& lhs, const pair_t& rhs) { + return lhs.second > rhs.second; + }; + this->_src_len = len; + auto beg = reinterpret_cast(dst); + size_t size = 0; + for (index_t i = 0; i < len; ++i) { + if (i < this->_k) { + beg[size] = std::make_pair(i, src[i]); + size++; + std::push_heap(beg, beg + size, comp); + } else { + auto& top = *beg; + // note: compare absolute value + if (std::abs(src[i]) > std::abs(top.second)) { + std::pop_heap(beg, beg + size, comp); + beg[size - 1] = std::make_pair(i, src[i]); + std::push_heap(beg, beg + size, comp); + } + } + } + BPS_LOG(INFO) << "first=" << beg[0].first << " second=" << beg[0].second; + + return this->_k * sizeof(pair_t); +} + +size_t TopkCompressor::Packing(const void* src, size_t size, int dtype) { + switch (dtype) { + case BYTEPS_INT8: + return _Packing(reinterpret_cast(_buf.get()), + reinterpret_cast(src), + size / sizeof(int8_t)); + case BYTEPS_UINT8: + return _Packing(reinterpret_cast(_buf.get()), + reinterpret_cast(src), + size / sizeof(uint8_t)); + // case BYTEPS_FLOAT16: + // return _Packing(reinterpret_cast(_buf.get()), + // reinterpret_cast(src), size); + case BYTEPS_FLOAT32: + return _Packing(reinterpret_cast(_buf.get()), + reinterpret_cast(src), + size / sizeof(int32_t)); + case BYTEPS_FLOAT64: + return _Packing(reinterpret_cast(_buf.get()), + reinterpret_cast(src), + size / sizeof(int64_t)); + default: + BPS_CHECK(0) << "Unsupported data type: " << dtype; + } + return 0; +} + void TopkCompressor::Compress(ByteBuf grad, int dtype, ByteBuf& compressed) { - // TODO + compressed.size = Packing(grad.data, grad.size, dtype); + compressed.data = _buf.get(); } +template +size_t TopkCompressor::_Unpacking(scalar_t* dst, const index_t* src, + size_t len) { + static_assert(sizeof(index_t) == sizeof(scalar_t), + "index_t should be the same size as scalar_t"); + using pair_t = std::pair; + auto ptr = reinterpret_cast(src); + + if ((void*)dst == (void*)src) { + auto buf = reinterpret_cast(_buf.get()); + std::copy(ptr, ptr+len, buf); + ptr = const_cast(buf); + } + + // reset to zeros + std::fill(dst, dst + this->_src_len, 0); + for (auto i = 0; i < len; ++i) { + auto& pair = ptr[i]; + dst[pair.first] = pair.second; + } +} + +size_t TopkCompressor::Unpacking(void* dst, const void* src, size_t size, + int dtype) { + switch (dtype) { + case BYTEPS_INT8: + return _Unpacking(reinterpret_cast(dst), + reinterpret_cast(src), + size / sizeof(int8_t) / 2); + case BYTEPS_UINT8: + return _Unpacking(reinterpret_cast(dst), + reinterpret_cast(src), + size / sizeof(uint8_t) / 2); + // case BYTEPS_FLOAT16: + // return _Unpacking(reinterpret_cast(_buf.get()), + // reinterpret_cast(src), size); + case BYTEPS_FLOAT32: + return _Unpacking(reinterpret_cast(dst), + reinterpret_cast(src), + size / sizeof(float) / 2); + case BYTEPS_FLOAT64: + return _Unpacking(reinterpret_cast(dst), + reinterpret_cast(src), + size / sizeof(double) / 2); + default: + BPS_CHECK(0) << "Unsupported data type: " << dtype; + } + return 0; +} + +#ifndef BYTEPS_BUILDING_SERVER +// worker version decompressor +void TopkCompressor::Decompress(ByteBuf compressed, int dtype, + ByteBuf& decompressed) { + BPS_CHECK(decompressed.data); + Unpacking(decompressed.data, compressed.data, compressed.size, dtype); +} +#else void TopkCompressor::Decompress(ByteBuf compressed, int dtype, ByteBuf& decompressed) { - // TODO + if (decompressed.data == nullptr) decompressed.data = _buf.get(); + Unpacking(decompressed.data, compressed.data, compressed.size, dtype); } +#endif } // namespace compressor } // namespace common } // namespace byteps \ No newline at end of file diff --git a/byteps/common/compressor/strategy/topk.h b/byteps/common/compressor/strategy/topk.h index 442fb884d..8200e69d5 100644 --- a/byteps/common/compressor/strategy/topk.h +++ b/byteps/common/compressor/strategy/topk.h @@ -23,20 +23,59 @@ namespace common { namespace compressor { /*! - * \brief TODO + * \brief TopK Compressor + * + * paper: Sparsified SGD with Memory + * https://arxiv.org/pdf/1809.07599.pdf + * + * sending the most significant entries of the stochastic gradient + * + * \note this is a deterministic algorithm */ class TopkCompressor : public BaseCompressor { public: explicit TopkCompressor(int k); virtual ~TopkCompressor(); - + + /*! + * \brief Compress function + * + * select topk entries and corresponding indices + * + * \note compare with absolute values + * + * \param grad gradient tensor + * \param dtype data type + * \param compressed compressed tensor + */ void Compress(ByteBuf grad, int dtype, ByteBuf& compressed) override; + /*! + * \brief Decompress function + * + * fill a zero tensor with topk entries and corresponding indices + * + * \param compressed compressed tensor + * \param dtype data type + * \param decompressed decompressed tensor + */ void Decompress(ByteBuf compressed, int dtype, ByteBuf& decompressed) override; + private: + size_t Packing(const void* src, size_t size, int dtype); + + template + size_t _Packing(index_t* dst, const scalar_t* src, size_t len); + + size_t Unpacking(void* dst, const void* src, size_t size, int dtype); + + template + size_t _Unpacking(scalar_t* dst, const index_t* src, size_t len); + private: int _k; + int _src_len; }; } // namespace compressor } // namespace common diff --git a/byteps/mxnet/__init__.py b/byteps/mxnet/__init__.py index 379cc437b..f4ae1ed3b 100644 --- a/byteps/mxnet/__init__.py +++ b/byteps/mxnet/__init__.py @@ -253,7 +253,7 @@ def _register_compressor(self, params, optimizer_params, compression_params): for _, param in params.items(): # generic for item in check_list: - if item in compression_params: + if item in compression_params and compression_params[item]: if isinstance(compression_params[item], str): setattr(param, "byteps_%s_type" % item, compression_params[item]) diff --git a/example/mxnet/train_gluon_imagenet_byteps_gc.py b/example/mxnet/train_gluon_imagenet_byteps_gc.py index 2e8674b4d..1bbb5152a 100644 --- a/example/mxnet/train_gluon_imagenet_byteps_gc.py +++ b/example/mxnet/train_gluon_imagenet_byteps_gc.py @@ -122,6 +122,8 @@ def parse_args(): help='which compress momentum') parser.add_argument('--onebit-scaling', action='store_true', default=False, help='enable scaling for onebit compressor') + parser.add_argument('--k', default=1, type=int, + help='topk or randomk') parser.add_argument('--fp16-pushpull', action='store_true', default=False, help='use fp16 compression during pushpull') @@ -402,7 +404,8 @@ def train(ctx): "compressor": opt.compressor, "ef": opt.ef, "momentum": opt.compress_momentum, - "scaling": opt.onebit_scaling + "scaling": opt.onebit_scaling, + "k": opt.k } trainer = bps.DistributedTrainer( diff --git a/example/mxnet/train_gluon_mnist_byteps_gc.py b/example/mxnet/train_gluon_mnist_byteps_gc.py index a60deeb61..674e6e276 100644 --- a/example/mxnet/train_gluon_mnist_byteps_gc.py +++ b/example/mxnet/train_gluon_mnist_byteps_gc.py @@ -48,12 +48,14 @@ help='disable training on GPU (default: False)') parser.add_argument('--compressor', type=str, default='', help='which compressor') -parser.add_argument('--ef', type=str, default=None, +parser.add_argument('--ef', type=str, default='', help='which error feedback') parser.add_argument('--compress-momentum', type=str, default='', help='which compress momentum') parser.add_argument('--scaling', action='store_true', default=False, help='enable scaling for onebit compressor') +parser.add_argument('--k', type=int, default=1, + help='topk or randomk') parser.add_argument('--fp16-pushpull', action='store_true', default=False, help='use fp16 compression during pushpull') args = parser.parse_args() @@ -142,7 +144,8 @@ def evaluate(model, data_iter, context): "compressor": args.compressor, "ef": args.ef, "momentum": args.compress_momentum, - "scaling": args.scaling + "scaling": args.scaling, + "k": args.k } trainer = bps.DistributedTrainer(