Skip to content

Commit

Permalink
compression: add randomk compressor (#11)
Browse files Browse the repository at this point in the history
* randomk: init commit

* randomk: fix typo

* randomk: fix typo
  • Loading branch information
jasperzhong committed Jun 23, 2020
1 parent 2379e17 commit 720d7e9
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 6 deletions.
1 change: 0 additions & 1 deletion byteps/common/compressor/strategy/onebit.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ namespace compressor {
* server: majority vote
* sign(\sum_i c_i)
*
* \note this is a deterministic algorithm.
* \note 0 represents positive and 1 represents negative.
*/
class OnebitCompressor : public BaseCompressor {
Expand Down
111 changes: 108 additions & 3 deletions byteps/common/compressor/strategy/randomk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,123 @@ CompressorRegistry::Register
});
}

RandomkCompressor::RandomkCompressor(int k) : _k(k){};
RandomkCompressor::RandomkCompressor(int k) : _k(k) { _gen.seed(_rd()); };

RandomkCompressor::~RandomkCompressor() = default;
template <typename index_t, typename scalar_t>
size_t RandomkCompressor::_Packing(index_t* dst, const scalar_t* src,
size_t len) {
static_assert(sizeof(index_t) == sizeof(scalar_t),
"index_t should be the same size as scalar_t");
BPS_CHECK_LE(this->_k, len / 2);
using pair_t = std::pair<index_t, scalar_t>;
std::uniform_int_distribution<> dis(0, len-1);
auto ptr = reinterpret_cast<pair_t*>(dst);

for (size_t i = 0; i < this->_k; ++i) {
auto index = dis(_gen);
ptr[i] = std::make_pair(index, src[index]);
}

return this->_k * sizeof(pair_t);
}

size_t RandomkCompressor::Packing(const void* src, size_t size, int dtype) {
switch (dtype) {
case BYTEPS_INT8:
return _Packing(reinterpret_cast<int8_t*>(_buf.get()),
reinterpret_cast<const int8_t*>(src),
size / sizeof(int8_t));
case BYTEPS_UINT8:
return _Packing(reinterpret_cast<uint8_t*>(_buf.get()),
reinterpret_cast<const uint8_t*>(src),
size / sizeof(uint8_t));
// case BYTEPS_FLOAT16:
// return _Packing(reinterpret_cast<int8_t*>(_buf.get()),
// reinterpret_cast<const int8_t*>(src), size);
case BYTEPS_FLOAT32:
return _Packing(reinterpret_cast<int32_t*>(_buf.get()),
reinterpret_cast<const float*>(src),
size / sizeof(int32_t));
case BYTEPS_FLOAT64:
return _Packing(reinterpret_cast<int64_t*>(_buf.get()),
reinterpret_cast<const double*>(src),
size / sizeof(int64_t));
default:
BPS_CHECK(0) << "Unsupported data type: " << dtype;
}
return 0;
}

void RandomkCompressor::Compress(ByteBuf grad, int dtype, ByteBuf& compressed) {
// TODO
compressed.size = Packing(grad.data, grad.size, dtype);
compressed.data = _buf.get();
}

template <typename index_t, typename scalar_t>
size_t RandomkCompressor::_Unpacking(scalar_t* dst, const index_t* src,
size_t len) {
static_assert(sizeof(index_t) == sizeof(scalar_t),
"index_t should be the same size as scalar_t");
using pair_t = std::pair<index_t, scalar_t>;
auto ptr = reinterpret_cast<const pair_t*>(src);

if ((void*)dst == (void*)src) {
auto buf = reinterpret_cast<pair_t*>(_buf.get());
std::copy(ptr, ptr + len, buf);
ptr = const_cast<const pair_t*>(buf);
}

// reset to zeros
std::fill(dst, dst + this->_src_len, 0);
for (auto i = 0; i < len; ++i) {
auto& pair = ptr[i];
dst[pair.first] = pair.second;
}
}

size_t RandomkCompressor::Unpacking(void* dst, const void* src, size_t size,
int dtype) {
switch (dtype) {
case BYTEPS_INT8:
return _Unpacking(reinterpret_cast<int8_t*>(dst),
reinterpret_cast<const int8_t*>(src),
size / sizeof(int8_t) / 2);
case BYTEPS_UINT8:
return _Unpacking(reinterpret_cast<uint8_t*>(dst),
reinterpret_cast<const uint8_t*>(src),
size / sizeof(uint8_t) / 2);
// case BYTEPS_FLOAT16:
// return _Unpacking(reinterpret_cast<int8_t*>(_buf.get()),
// reinterpret_cast<const int8_t*>(src), size);
case BYTEPS_FLOAT32:
return _Unpacking(reinterpret_cast<float*>(dst),
reinterpret_cast<const int32_t*>(src),
size / sizeof(float) / 2);
case BYTEPS_FLOAT64:
return _Unpacking(reinterpret_cast<double*>(dst),
reinterpret_cast<const int64_t*>(src),
size / sizeof(double) / 2);
default:
BPS_CHECK(0) << "Unsupported data type: " << dtype;
}
return 0;
}

#ifndef BYTEPS_BUILDING_SERVER
// worker version decompressor
void RandomkCompressor::Decompress(ByteBuf compressed, int dtype,
ByteBuf& decompressed) {
BPS_CHECK(decompressed.data);
Unpacking(decompressed.data, compressed.data, compressed.size, dtype);
}
#else
void RandomkCompressor::Decompress(ByteBuf compressed, int dtype,
ByteBuf& decompressed) {
// TODO
if (decompressed.data == nullptr) decompressed.data = _buf.get();
Unpacking(decompressed.data, compressed.data, compressed.size, dtype);
}
#endif
} // namespace compressor
} // namespace common
} // namespace byteps
41 changes: 40 additions & 1 deletion byteps/common/compressor/strategy/randomk.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,66 @@
#ifndef BYTEPS_COMPRESS_STRAT_RANDOMK_H
#define BYTEPS_COMPRESS_STRAT_RANDOMK_H

#include <random>

#include "../base_compressor.h"

namespace byteps {
namespace common {
namespace compressor {

/*!
* \brief TODO
* \brief RandomK Compressor
*
* paper: Sparsified SGD with Memory
* https://arxiv.org/pdf/1809.07599.pdf
*
* randomly sending k entries of the stochastic gradient
*/
class RandomkCompressor : public BaseCompressor {
public:
explicit RandomkCompressor(int k);
virtual ~RandomkCompressor();

/*!
* \brief Compress function
*
* randomly select k entries and corresponding indices
*
* \param grad gradient tensor
* \param dtype data type
* \param compressed compressed tensor
*/
void Compress(ByteBuf grad, int dtype, ByteBuf& compressed) override;

/*!
* \brief Decompress function
*
* fill a zero tensor with topk entries and corresponding indices
*
* \param compressed compressed tensor
* \param dtype data type
* \param decompressed decompressed tensor
*/
void Decompress(ByteBuf compressed, int dtype,
ByteBuf& decompressed) override;

private:
size_t Packing(const void* src, size_t size, int dtype);

template <typename index_t, typename scalar_t>
size_t _Packing(index_t* dst, const scalar_t* src, size_t len);

size_t Unpacking(void* dst, const void* src, size_t size, int dtype);

template <typename index_t, typename scalar_t>
size_t _Unpacking(scalar_t* dst, const index_t* src, size_t len);

private:
int _k;
int _src_len;
std::random_device _rd;
std::mt19937 _gen;
};
} // namespace compressor
} // namespace common
Expand Down
1 change: 0 additions & 1 deletion byteps/common/compressor/strategy/topk.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ namespace compressor {
*
* sending the most significant entries of the stochastic gradient
*
* \note this is a deterministic algorithm
*/
class TopkCompressor : public BaseCompressor {
public:
Expand Down

0 comments on commit 720d7e9

Please sign in to comment.