diff --git a/byteps/common/common.h b/byteps/common/common.h index efd9e04b2..9824488bd 100644 --- a/byteps/common/common.h +++ b/byteps/common/common.h @@ -45,7 +45,7 @@ namespace common { namespace compressor { struct BPSTensor; typedef BPSTensor tensor_t; -class BaseCompressor; +class Compressor; class ErrorFeedback; } // namespace compressor @@ -198,7 +198,7 @@ typedef struct BytePSContext { std::unordered_map>> part_comm_time; // Compressor list - std::vector> compressor_list; + std::vector> compressor_list; // kwargs std::unordered_map kwargs; } BPSContext; @@ -257,7 +257,7 @@ struct TensorTableEntry { // How many partitions unsigned int total_partnum = 0; // Compressor - std::shared_ptr compressor; + std::shared_ptr compressor; // Compressed std::shared_ptr compressed; }; diff --git a/byteps/common/compressor/base_compressor.h b/byteps/common/compressor/base_compressor.h deleted file mode 100644 index 6477c247d..000000000 --- a/byteps/common/compressor/base_compressor.h +++ /dev/null @@ -1,147 +0,0 @@ -// Copyright 2019 Amazon Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#ifndef BYTEPS_COMPRESS_BASE_COMPRESSOR_H -#define BYTEPS_COMPRESS_BASE_COMPRESSOR_H - -#include -#include -#include -#include -#include - -#include "../cpu_reducer.h" - -namespace byteps { -namespace common { -namespace compressor { - -typedef char byte_t; -/*! - * \brief Tensor type - */ -typedef struct BPSTensor { - byte_t* data; - size_t size; - int dtype; - - BPSTensor(byte_t* data = nullptr, size_t size = 0, int dtype = 0) - : data(data), size(size), dtype(dtype) {} -} tensor_t; - -/*! - * \brief Compressor interface used in BytePS core. - */ -class BaseCompressor { - public: - BaseCompressor(); - virtual ~BaseCompressor(); - - /*! - * \brief Allocate encoding buffer for compression. - * \param aligned_size aligned size - */ - virtual void Init(size_t aligned_size); - - /*! - * \brief Compress function - * - * \param grad gradient tensor - * \param compressed compressed tensor - */ - virtual void Compress(tensor_t grad, tensor_t& compressed) = 0; - - /*! - * \brief Decompress function - * - * \param compressed compressed tensor - * \param decompressed decompressed tensor - */ - virtual void Decompress(tensor_t compressed, tensor_t& decompressed) = 0; - - /*! - * \brief help function for error feedback `UpdateError` - * - * \param corrected gradient corrected with error - * \param error error - * \param compressed compressed gradient - */ - virtual void FastUpdateError(tensor_t error, tensor_t corrected, - tensor_t compressed); - - protected: - /*! - * \brief buffer - */ - std::unique_ptr _buf; - - /*! - * \brief CPU reducer - */ - std::unique_ptr _cpu_reducer; -}; - -using kwargs_t = std::unordered_map; - -class CompressorRegistry { - public: - using ctor_t = - std::function(const kwargs_t& kwargs)>; - - using map_t = std::unordered_map; - - struct Register { - explicit Register(std::string name, ctor_t ctor); - }; - - static ctor_t Find(const std::string& name); - - static std::unique_ptr Create(const kwargs_t& kwargs); - - private: - static map_t _ctor_map; - - CompressorRegistry() = delete; - ~CompressorRegistry() = delete; -}; - -inline std::string Serialize(const kwargs_t& kwargs) { - std::ostringstream os; - os << kwargs.size(); - for (auto const& kwarg : kwargs) { - os << " " << kwarg.first << " " << kwarg.second; - } - return os.str(); -} - -inline kwargs_t Deserialize(const std::string& content) { - kwargs_t kwargs; - std::istringstream is(content); - size_t size = 0; - is >> size; - for (size_t i = 0; i < size; ++i) { - kwargs_t::key_type key; - kwargs_t::mapped_type val; - is >> key >> val; - kwargs[key] = val; - } - - return kwargs; -} -} // namespace compressor -} // namespace common -} // namespace byteps - -#endif // BYTEPS_COMPRESS_BASE_COMPRESSOR_H \ No newline at end of file diff --git a/byteps/common/compressor/common.h b/byteps/common/compressor/common.h new file mode 100644 index 000000000..3c929c992 --- /dev/null +++ b/byteps/common/compressor/common.h @@ -0,0 +1,44 @@ +// Copyright 2019 Amazon Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef BYTEPS_COMPRESSOR_COMMON_H +#define BYTEPS_COMPRESSOR_COMMON_H + +#include + +namespace byteps { +namespace common { +namespace compressor { +typedef char byte_t; +/*! + * \brief Tensor type + */ +typedef struct BPSTensor { + byte_t* data; + size_t size; + int dtype; + + BPSTensor() : data(nullptr), size(0), dtype(0) {} + BPSTensor(byte_t* data, size_t size=0, int dtype=0) + : data(data), size(size), dtype(dtype) {} +} tensor_t; + +using kwargs_t = std::unordered_map; + +} // namespace compressor +} // namespace common +} // namespace byteps + +#endif // BYTEPS_COMPRESSOR_COMMON_H \ No newline at end of file diff --git a/byteps/common/compressor/compressor.h b/byteps/common/compressor/compressor.h new file mode 100644 index 000000000..ac7400953 --- /dev/null +++ b/byteps/common/compressor/compressor.h @@ -0,0 +1,89 @@ +// Copyright 2019 Amazon Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef BYTEPS_COMPRESSOR_COMPRESSOR_H +#define BYTEPS_COMPRESSOR_COMPRESSOR_H + +#include + +#include "../cpu_reducer.h" +#include "common.h" + +namespace byteps { +namespace common { +namespace compressor { +/*! + * \brief Compressor interface used in BytePS core. + */ +class Compressor { + public: + Compressor(size_t size) + : _size(size), + _buf(new byte_t[size]), + _cpu_reducer(new CpuReducer(nullptr)){}; + virtual ~Compressor() = default; + + /*! + * \brief Compress function + * + * \param grad gradient tensor + * \param compressed compressed tensor + */ + virtual void Compress(tensor_t grad, tensor_t& compressed) = 0; + + /*! + * \brief Decompress function + * + * \param compressed compressed tensor + * \param decompressed decompressed tensor + */ + virtual void Decompress(tensor_t compressed, tensor_t& decompressed) = 0; + + /*! + * \brief help function for error feedback `UpdateError` + * + * \param corrected gradient corrected with error + * \param error error + * \param compressed compressed gradient + */ + virtual void FastUpdateError(tensor_t error, tensor_t corrected, + tensor_t compressed) { + BPS_LOG(FATAL) << "FastUpdateError is not implemented"; + }; + + size_t size() const { return _size; } + + protected: + /*! + * \brief buffer + */ + std::unique_ptr _buf; + + /*! + * \brief original size + */ + size_t _size; + + /*! + * \brief CPU reducer + */ + std::unique_ptr _cpu_reducer; +}; + +} // namespace compressor +} // namespace common +} // namespace byteps + +#endif // BYTEPS_COMPRESSOR_COMPRESSOR_H \ No newline at end of file diff --git a/byteps/common/compressor/base_compressor.cc b/byteps/common/compressor/compressor_registry.cc similarity index 77% rename from byteps/common/compressor/base_compressor.cc rename to byteps/common/compressor/compressor_registry.cc index dbdc0307d..c92c037b9 100644 --- a/byteps/common/compressor/base_compressor.cc +++ b/byteps/common/compressor/compressor_registry.cc @@ -13,9 +13,7 @@ // limitations under the License. // ============================================================================= -#include "base_compressor.h" - -#include "../logging.h" +#include "compressor_registry.h" namespace byteps { namespace common { @@ -38,8 +36,8 @@ CompressorRegistry::ctor_t CompressorRegistry::Find(const std::string& name) { return it->second; } -std::unique_ptr CompressorRegistry::Create( - const kwargs_t& kwargs) { +std::unique_ptr CompressorRegistry::Create(const kwargs_t& kwargs, + size_t size, int dtype) { #ifndef BYTEPS_BUILDING_SERVER const std::string types[] = {"momentum_type", "ef_type", "compressor_type"}; #else @@ -51,24 +49,13 @@ std::unique_ptr CompressorRegistry::Create( if (iter != kwargs.end()) { auto ctor = CompressorRegistry::Find(iter->second + "_" + type); BPS_CHECK_NE(ctor, nullptr); - return ctor(kwargs); + return ctor(kwargs, size, dtype); } } return nullptr; } -BaseCompressor::BaseCompressor() = default; - -BaseCompressor::~BaseCompressor() = default; - -void BaseCompressor::Init(size_t aligned_size) { - _buf.reset(new char[aligned_size]); - _cpu_reducer.reset(new CpuReducer(nullptr)); -} - -void BaseCompressor::FastUpdateError(tensor_t error, tensor_t corrected, - tensor_t compressed) {} } // namespace compressor } // namespace common } // namespace byteps \ No newline at end of file diff --git a/byteps/common/compressor/compressor_registry.h b/byteps/common/compressor/compressor_registry.h new file mode 100644 index 000000000..aafe941de --- /dev/null +++ b/byteps/common/compressor/compressor_registry.h @@ -0,0 +1,52 @@ +// Copyright 2019 Amazon Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef BYTEPS_COMPRESSOR_COMPRESSOR_REGISTRY_H +#define BYTEPS_COMPRESSOR_COMPRESSOR_REGISTRY_H + +#include "compressor.h" + +namespace byteps { +namespace common { +namespace compressor { + +class CompressorRegistry { + public: + using ctor_t = std::function( + const kwargs_t& kwargs, size_t size, int dtype)>; + + using map_t = std::unordered_map; + + struct Register { + Register(std::string name, ctor_t ctor); + }; + + static ctor_t Find(const std::string& name); + + static std::unique_ptr Create(const kwargs_t& kwargs, size_t size, + int dtype); + + private: + static map_t _ctor_map; + + CompressorRegistry() = delete; + ~CompressorRegistry() = delete; +}; + +} // namespace compressor +} // namespace common +} // namespace byteps + +#endif // BYTEPS_COMPRESSOR_COMPRESSOR_REGISTRY_H \ No newline at end of file diff --git a/byteps/common/compressor/error_feedback.cc b/byteps/common/compressor/error_feedback.cc index 3e3750e6e..52adbc0dc 100644 --- a/byteps/common/compressor/error_feedback.cc +++ b/byteps/common/compressor/error_feedback.cc @@ -18,18 +18,6 @@ namespace byteps { namespace common { namespace compressor { -ErrorFeedback::ErrorFeedback(std::unique_ptr compressor_ptr) - : _compressor_ptr(std::move(compressor_ptr)) {} - -ErrorFeedback::~ErrorFeedback() = default; - -void ErrorFeedback::Init(size_t aligned_size) { - _compressor_ptr->Init(aligned_size); - _error.reset(new char[aligned_size]); - memset(_error.get(), 0, aligned_size); - _cpu_reducer.reset(new CpuReducer(nullptr)); -} - void ErrorFeedback::Compress(tensor_t grad, tensor_t& compressed) { // before: grad += error UpdateGradient(grad); @@ -37,17 +25,18 @@ void ErrorFeedback::Compress(tensor_t grad, tensor_t& compressed) { // TODO: look strange compressed.data = _error.get(); // compress - _compressor_ptr->Compress(grad, compressed); + _cptr->Compress(grad, compressed); UpdateError(grad, compressed); } void ErrorFeedback::Decompress(tensor_t compressed, tensor_t& decompressed) { - _compressor_ptr->Decompress(compressed, decompressed); + _cptr->Decompress(compressed, decompressed); } void ErrorFeedback::UpdateError(tensor_t corrected, tensor_t compressed) { - _compressor_ptr->FastUpdateError({_error.get()}, corrected, compressed); + tensor_t error{_error.get(), _size, corrected.dtype}; + _cptr->FastUpdateError(error, corrected, compressed); } } // namespace compressor diff --git a/byteps/common/compressor/error_feedback.h b/byteps/common/compressor/error_feedback.h index 768949713..9acf9c14d 100644 --- a/byteps/common/compressor/error_feedback.h +++ b/byteps/common/compressor/error_feedback.h @@ -13,10 +13,10 @@ // limitations under the License. // ============================================================================= -#ifndef BYTEPS_COMPRESS_ERROR_FEEDBACK_H -#define BYTEPS_COMPRESS_ERROR_FEEDBACK_H +#ifndef BYTEPS_COMPRESSOR_ERROR_FEEDBACK_H +#define BYTEPS_COMPRESSOR_ERROR_FEEDBACK_H -#include "base_compressor.h" +#include "compressor.h" namespace byteps { namespace common { @@ -27,16 +27,11 @@ namespace compressor { * * add error feedback behavior to any compressor at run-time */ -class ErrorFeedback : public BaseCompressor { +class ErrorFeedback : public Compressor { public: - explicit ErrorFeedback(std::unique_ptr compressor_ptr); - virtual ~ErrorFeedback(); - - /*! - * \brief Allocate encoding buffer for compression. - * \param aligned_size aligned size - */ - virtual void Init(size_t aligned_size); + ErrorFeedback(size_t size, std::unique_ptr cptr) + : Compressor(size), _cptr(std::move(cptr)), _error(new byte_t[size]()) {} + virtual ~ErrorFeedback() = default; /*! * \brief Compress function @@ -76,16 +71,16 @@ class ErrorFeedback : public BaseCompressor { virtual void UpdateError(tensor_t corrected, tensor_t compressed); protected: - std::unique_ptr _error; + std::unique_ptr _error; private: /*! * \brief compressor */ - std::unique_ptr _compressor_ptr; + std::unique_ptr _cptr; }; } // namespace compressor } // namespace common } // namespace byteps -#endif // BYTEPS_COMPRESS_ERROR_FEEDBACK_H \ No newline at end of file +#endif // BYTEPS_COMPRESSOR_ERROR_FEEDBACK_H \ No newline at end of file diff --git a/byteps/common/compressor/momentum.cc b/byteps/common/compressor/momentum.cc index 865c6f4f9..f67761554 100644 --- a/byteps/common/compressor/momentum.cc +++ b/byteps/common/compressor/momentum.cc @@ -19,18 +19,6 @@ namespace byteps { namespace common { namespace compressor { -Momentum::Momentum(std::unique_ptr compressor_ptr, float mu) - : _compressor_ptr(std::move(compressor_ptr)), _mu(mu) {} - -Momentum::~Momentum() = default; - -void Momentum::Init(size_t aligned_size) { - _compressor_ptr->Init(aligned_size); - _mom.reset(new char[aligned_size]); - memset(_mom.get(), 0, aligned_size); - _cpu_reducer.reset(new CpuReducer(nullptr)); -} - void Momentum::Compress(tensor_t grad, tensor_t& compressed) { // m_t = \mu * m_{t-1} + g_t UpdateMom(grad); @@ -39,11 +27,11 @@ void Momentum::Compress(tensor_t grad, tensor_t& compressed) { UpdateGradient(grad); // compress - _compressor_ptr->Compress(grad, compressed); + _cptr->Compress(grad, compressed); } void Momentum::Decompress(tensor_t compressed, tensor_t& decompressed) { - _compressor_ptr->Decompress(compressed, decompressed); + _cptr->Decompress(compressed, decompressed); } } // namespace compressor diff --git a/byteps/common/compressor/momentum.h b/byteps/common/compressor/momentum.h index 377b66ce2..2e849ecc6 100644 --- a/byteps/common/compressor/momentum.h +++ b/byteps/common/compressor/momentum.h @@ -13,10 +13,10 @@ // limitations under the License. // ============================================================================= -#ifndef BYTEPS_COMPRESS_MOMENTUM_H -#define BYTEPS_COMPRESS_MOMENTUM_H +#ifndef BYTEPS_COMPRESSOR_MOMENTUM_H +#define BYTEPS_COMPRESSOR_MOMENTUM_H -#include "base_compressor.h" +#include "compressor.h" namespace byteps { namespace common { @@ -29,15 +29,14 @@ namespace compressor { * NOTE: This should not be used at the same time with the momentum implemented * in the framework such as MXNet, Tensorflow or PyTorch etc. */ -class Momentum : public BaseCompressor { +class Momentum : public Compressor { public: - Momentum(std::unique_ptr compressor_ptr, float mu); - virtual ~Momentum(); - /*! - * \brief Allocate encoding buffer for compression. - * \param aligned_size aligned size - */ - virtual void Init(size_t aligned_size) final; + Momentum(size_t size, std::unique_ptr cptr, float mu) + : Compressor(size), + _cptr(std::move(cptr)), + _mu(mu), + _mom(new byte_t[size]()){}; + virtual ~Momentum() = default; /*! * \brief Compress function @@ -75,7 +74,7 @@ class Momentum : public BaseCompressor { virtual void UpdateGradient(tensor_t grad) = 0; protected: - std::unique_ptr _mom; + std::unique_ptr _mom; float _mu; @@ -83,10 +82,10 @@ class Momentum : public BaseCompressor { /*! * \brief compressor */ - std::unique_ptr _compressor_ptr; + std::unique_ptr _cptr; }; } // namespace compressor } // namespace common } // namespace byteps -#endif \ No newline at end of file +#endif // BYTEPS_COMPRESSOR_MOMENTUM_H \ No newline at end of file diff --git a/byteps/common/compressor/strategy/multibit.cc b/byteps/common/compressor/strategy/multibit.cc index d3aef4e92..b7586d054 100644 --- a/byteps/common/compressor/strategy/multibit.cc +++ b/byteps/common/compressor/strategy/multibit.cc @@ -14,8 +14,7 @@ // ============================================================================= #include "multibit.h" - -#include "../../logging.h" +#include "../compressor_registry.h" namespace byteps { namespace common { @@ -23,7 +22,8 @@ namespace compressor { namespace { CompressorRegistry::Register reg("multibit_compressor", - [](const kwargs_t& kwargs) -> std::unique_ptr { + [](const kwargs_t& kwargs, size_t size, + int dtype) -> std::unique_ptr { auto iter = kwargs.find("compressor_k"); if (iter == kwargs.end()) { BPS_LOG(WARNING) @@ -33,14 +33,10 @@ CompressorRegistry::Register int k = std::stoi(iter->second); BPS_LOG(DEBUG) << "Register Multibit Compressor " << "k=" << k; - return std::unique_ptr(new MultibitCompressor(k)); + return std::unique_ptr(new MultibitCompressor(size, k)); }); } -MultibitCompressor::MultibitCompressor(int k) : _k(k){}; - -MultibitCompressor::~MultibitCompressor() = default; - void MultibitCompressor::Compress(tensor_t grad, tensor_t& compressed) { // TOOD } diff --git a/byteps/common/compressor/strategy/multibit.h b/byteps/common/compressor/strategy/multibit.h index 72e02530b..714bf1552 100644 --- a/byteps/common/compressor/strategy/multibit.h +++ b/byteps/common/compressor/strategy/multibit.h @@ -13,10 +13,10 @@ // limitations under the License. // ============================================================================= -#ifndef BYTEPS_COMPRESS_STRAT_MULTIBIT_H -#define BYTEPS_COMPRESS_STRAT_MULTIBIT_H +#ifndef BYTEPS_COMPRESSOR_STRATEGY_MULTIBIT_H +#define BYTEPS_COMPRESSOR_STRATEGY_MULTIBIT_H -#include "../base_compressor.h" +#include "../compressor.h" namespace byteps { namespace common { @@ -25,10 +25,10 @@ namespace compressor { /*! * \brief TODO */ -class MultibitCompressor : public BaseCompressor { +class MultibitCompressor : public Compressor { public: - explicit MultibitCompressor(int k); - virtual ~MultibitCompressor(); + MultibitCompressor(size_t size, int k) : Compressor(size), _k(k){}; + virtual ~MultibitCompressor() = default; void Compress(tensor_t grad, tensor_t& compressed) override; @@ -41,4 +41,4 @@ class MultibitCompressor : public BaseCompressor { } // namespace common } // namespace byteps -#endif // BYTEPS_COMPRESS_STRAT_MULTIBIT_H \ No newline at end of file +#endif // BYTEPS_COMPRESSOR_STRATEGY_MULTIBIT_H \ No newline at end of file diff --git a/byteps/common/compressor/strategy/nesterov_momentum.cc b/byteps/common/compressor/strategy/nesterov_momentum.cc index 1bbd8c7c0..b85723d5a 100644 --- a/byteps/common/compressor/strategy/nesterov_momentum.cc +++ b/byteps/common/compressor/strategy/nesterov_momentum.cc @@ -14,8 +14,7 @@ // ============================================================================= #include "nesterov_momentum.h" - -#include "../../logging.h" +#include "../compressor_registry.h" namespace byteps { namespace common { @@ -23,28 +22,24 @@ namespace compressor { namespace { CompressorRegistry::Register reg( "nesterov_momentum", - [](const kwargs_t& kwargs) -> std::unique_ptr { + [](const kwargs_t& kwargs, size_t size, + int dtype) -> std::unique_ptr { // register cpr auto kwargs_clone = kwargs; kwargs_clone.erase("momentum_type"); - auto compressor_ptr = CompressorRegistry::Create(kwargs_clone); - BPS_CHECK_NE(compressor_ptr, nullptr); + auto cptr = + CompressorRegistry::Create(kwargs_clone, size, dtype); + BPS_CHECK_NE(cptr, nullptr); // find \mu auto iter = kwargs.find("momentum_mu"); BPS_CHECK_NE(iter, kwargs.end()) << "momentum \mu is not defined"; float mu = std::stof(iter->second); BPS_LOG(DEBUG) << "with momentum"; return std::unique_ptr( - new NesterovMomentumCompressor(std::move(compressor_ptr), mu)); + new NesterovMomentumCompressor(size, std::move(cptr), mu)); }); } -NesterovMomentumCompressor::NesterovMomentumCompressor( - std::unique_ptr compressor_ptr, float mu) - : Momentum(std::move(compressor_ptr), mu){}; - -NesterovMomentumCompressor::~NesterovMomentumCompressor() = default; - void NesterovMomentumCompressor::UpdateMom(tensor_t grad) { // m_t = \mu * m_{t-1} + g_t this->_cpu_reducer->sum(_mom.get(), grad.data, _mom.get(), grad.size, diff --git a/byteps/common/compressor/strategy/nesterov_momentum.h b/byteps/common/compressor/strategy/nesterov_momentum.h index 00bdfee21..b612ca9ee 100644 --- a/byteps/common/compressor/strategy/nesterov_momentum.h +++ b/byteps/common/compressor/strategy/nesterov_momentum.h @@ -13,8 +13,8 @@ // limitations under the License. // ============================================================================= -#ifndef BYTEPS_COMPRESS_NESTEROV_MOM_H -#define BYTEPS_COMPRESS_NESTEROV_MOM_H +#ifndef BYTEPS_COMPRESSOR_STRATEGY_NESTEROV_MOMENTUM_H +#define BYTEPS_COMPRESSOR_STRATEGY_NESTEROV_MOMENTUM_H #include "../momentum.h" @@ -24,9 +24,10 @@ namespace compressor { class NesterovMomentumCompressor : public Momentum { public: - NesterovMomentumCompressor(std::unique_ptr compressor_ptr, - float mu); - virtual ~NesterovMomentumCompressor(); + NesterovMomentumCompressor(size_t size, std::unique_ptr cptr, + float mu) + : Momentum(size, std::move(cptr), mu){}; + virtual ~NesterovMomentumCompressor() = default; protected: void UpdateMom(tensor_t grad) override; @@ -37,4 +38,4 @@ class NesterovMomentumCompressor : public Momentum { } // namespace common } // namespace byteps -#endif \ No newline at end of file +#endif // BYTEPS_COMPRESSOR_STRATEGY_NESTEROV_MOMENTUM_H \ No newline at end of file diff --git a/byteps/common/compressor/strategy/onebit.cc b/byteps/common/compressor/strategy/onebit.cc index 8f3e8386a..9826d75cf 100644 --- a/byteps/common/compressor/strategy/onebit.cc +++ b/byteps/common/compressor/strategy/onebit.cc @@ -14,32 +14,25 @@ // ============================================================================= #include "onebit.h" - -#include "../../logging.h" +#include "../compressor_registry.h" namespace byteps { namespace common { namespace compressor { namespace { -CompressorRegistry::Register reg( - "onebit_compressor", [](const kwargs_t& kwargs) { - BPS_LOG(DEBUG) << "Register Onebit Compressor"; - bool scaled = false; - auto iter = kwargs.find("compressor_onebit_scaling"); - if (iter != kwargs.end()) { - if (iter->second == "true" || iter->second == "True") scaled = true; - } - if (scaled) { - return std::unique_ptr(new OnebitCompressor(true)); - } - return std::unique_ptr(new OnebitCompressor()); - }); +CompressorRegistry::Register reg("onebit_compressor", [](const kwargs_t& kwargs, + size_t size, + int dtype) { + BPS_LOG(DEBUG) << "Register Onebit Compressor"; + bool scaled = false; + auto iter = kwargs.find("compressor_onebit_scaling"); + if (iter != kwargs.end()) { + if (iter->second == "true" || iter->second == "True") scaled = true; + } + return std::unique_ptr(new OnebitCompressor(size, scaled)); +}); } -OnebitCompressor::OnebitCompressor(bool use_scale) : _use_scale(use_scale){}; - -OnebitCompressor::~OnebitCompressor() = default; - template size_t OnebitCompressor::PackingImpl(index_t* dst, const scalar_t* src, size_t len) { diff --git a/byteps/common/compressor/strategy/onebit.h b/byteps/common/compressor/strategy/onebit.h index c0901a178..5326527c9 100644 --- a/byteps/common/compressor/strategy/onebit.h +++ b/byteps/common/compressor/strategy/onebit.h @@ -13,10 +13,10 @@ // limitations under the License. // ============================================================================= -#ifndef BYTEPS_COMPRESS_STRAT_ONEBIT_H -#define BYTEPS_COMPRESS_STRAT_ONEBIT_H +#ifndef BYTEPS_COMPRESSOR_STRATEGY_ONEBIT_H +#define BYTEPS_COMPRESSOR_STRATEGY_ONEBIT_H -#include "../base_compressor.h" +#include "../compressor.h" namespace byteps { namespace common { @@ -36,10 +36,11 @@ namespace compressor { * * \note 0 represents positive and 1 represents negative. */ -class OnebitCompressor : public BaseCompressor { +class OnebitCompressor : public Compressor { public: - OnebitCompressor(bool use_scale = false); - virtual ~OnebitCompressor(); + OnebitCompressor(size_t size, bool use_scale = false) + : Compressor(size), _use_scale(use_scale) {} + virtual ~OnebitCompressor() = default; /*! * \brief Compress function @@ -94,4 +95,4 @@ class OnebitCompressor : public BaseCompressor { } // namespace common } // namespace byteps -#endif // BYTEPS_COMPRESS_STRAT_ONEBIT_H \ No newline at end of file +#endif // BYTEPS_COMPRESSOR_STRATEGY_ONEBIT_H \ No newline at end of file diff --git a/byteps/common/compressor/strategy/randomk.cc b/byteps/common/compressor/strategy/randomk.cc index e134a67d6..3e8952114 100644 --- a/byteps/common/compressor/strategy/randomk.cc +++ b/byteps/common/compressor/strategy/randomk.cc @@ -14,8 +14,7 @@ // ============================================================================= #include "randomk.h" - -#include "../../logging.h" +#include "../compressor_registry.h" namespace byteps { namespace common { @@ -23,7 +22,8 @@ namespace compressor { namespace { CompressorRegistry::Register reg("randomk_compressor", - [](const kwargs_t& kwargs) -> std::unique_ptr { + [](const kwargs_t& kwargs, size_t size, + int dtype) -> std::unique_ptr { auto iter = kwargs.find("compressor_k"); if (iter == kwargs.end()) { BPS_LOG(WARNING) @@ -33,13 +33,10 @@ CompressorRegistry::Register int k = std::stoi(iter->second); BPS_LOG(DEBUG) << "Register Randomk Compressor " << "k=" << k; - return std::unique_ptr(new RandomkCompressor(k)); + return std::unique_ptr(new RandomkCompressor(size, k)); }); } -RandomkCompressor::RandomkCompressor(int k) : _k(k) { _gen.seed(_rd()); }; - -RandomkCompressor::~RandomkCompressor() = default; template size_t RandomkCompressor::PackingImpl(index_t* dst, const scalar_t* src, size_t len) { @@ -142,12 +139,11 @@ void RandomkCompressor::Unpacking(void* dst, const void* src, size_t size, void RandomkCompressor::Decompress(tensor_t compressed, tensor_t& decompressed) { - BPS_CHECK_GT(decompressed.size, 0); #ifdef BYTEPS_BUILDING_SERVER if (decompressed.data == nullptr) decompressed.data = _buf.get(); #endif Unpacking(decompressed.data, compressed.data, compressed.size, - decompressed.size, compressed.dtype); + _size, compressed.dtype); } template diff --git a/byteps/common/compressor/strategy/randomk.h b/byteps/common/compressor/strategy/randomk.h index a1c9cea47..0c791d9de 100644 --- a/byteps/common/compressor/strategy/randomk.h +++ b/byteps/common/compressor/strategy/randomk.h @@ -13,12 +13,12 @@ // limitations under the License. // ============================================================================= -#ifndef BYTEPS_COMPRESS_STRAT_RANDOMK_H -#define BYTEPS_COMPRESS_STRAT_RANDOMK_H +#ifndef BYTEPS_COMPRESSOR_STRATEGY_RANDOMK_H +#define BYTEPS_COMPRESSOR_STRATEGY_RANDOMK_H #include -#include "../base_compressor.h" +#include "../compressor.h" namespace byteps { namespace common { @@ -32,10 +32,12 @@ namespace compressor { * * randomly sending k entries of the stochastic gradient */ -class RandomkCompressor : public BaseCompressor { +class RandomkCompressor : public Compressor { public: - explicit RandomkCompressor(int k); - virtual ~RandomkCompressor(); + RandomkCompressor(size_t size, int k) : Compressor(size), _k(k) { + _gen.seed(_rd()); + }; + virtual ~RandomkCompressor() = default; /*! * \brief Compress function @@ -93,4 +95,4 @@ class RandomkCompressor : public BaseCompressor { } // namespace common } // namespace byteps -#endif // BYTEPS_COMPRESS_STRAT_RANDOMK_H \ No newline at end of file +#endif // BYTEPS_COMPRESSOR_STRATEGY_RANDOMK_H \ No newline at end of file diff --git a/byteps/common/compressor/strategy/topk.cc b/byteps/common/compressor/strategy/topk.cc index 3c2851be3..41b45c25c 100644 --- a/byteps/common/compressor/strategy/topk.cc +++ b/byteps/common/compressor/strategy/topk.cc @@ -13,11 +13,10 @@ // limitations under the License. // ============================================================================= -#include "topk.h" - #include -#include "../../logging.h" +#include "../compressor_registry.h" +#include "topk.h" namespace byteps { namespace common { @@ -25,7 +24,8 @@ namespace compressor { namespace { CompressorRegistry::Register reg( "topk_compressor", - [](const kwargs_t& kwargs) -> std::unique_ptr { + [](const kwargs_t& kwargs, size_t size, + int dtype) -> std::unique_ptr { auto iter = kwargs.find("compressor_k"); if (iter == kwargs.end()) { BPS_LOG(WARNING) << "Topk Compressor needs parameter \"compressor_k\""; @@ -34,14 +34,10 @@ CompressorRegistry::Register reg( int k = std::stoi(iter->second); BPS_LOG(DEBUG) << "Register Topk Compressor " << "k=" << k; - return std::unique_ptr(new TopkCompressor(k)); + return std::unique_ptr(new TopkCompressor(size, k)); }); } -TopkCompressor::TopkCompressor(int k) : _k(k){}; - -TopkCompressor::~TopkCompressor() = default; - template size_t TopkCompressor::PackingImpl(index_t* dst, const scalar_t* src, size_t len) { @@ -158,12 +154,11 @@ void TopkCompressor::Unpacking(void* dst, const void* src, size_t size, } void TopkCompressor::Decompress(tensor_t compressed, tensor_t& decompressed) { - BPS_CHECK_GT(decompressed.size, 0); #ifdef BYTEPS_BUILDING_SERVER if (decompressed.data == nullptr) decompressed.data = _buf.get(); #endif - Unpacking(decompressed.data, compressed.data, compressed.size, - decompressed.size, compressed.dtype); + Unpacking(decompressed.data, compressed.data, compressed.size, _size, + compressed.dtype); } template diff --git a/byteps/common/compressor/strategy/topk.h b/byteps/common/compressor/strategy/topk.h index b0e52731e..1874d9cb4 100644 --- a/byteps/common/compressor/strategy/topk.h +++ b/byteps/common/compressor/strategy/topk.h @@ -13,10 +13,10 @@ // limitations under the License. // ============================================================================= -#ifndef BYTEPS_COMPRESS_STRAT_TOPK_H -#define BYTEPS_COMPRESS_STRAT_TOPK_H +#ifndef BYTEPS_COMPRESSOR_STRATEGY_TOPK_H +#define BYTEPS_COMPRESSOR_STRATEGY_TOPK_H -#include "../base_compressor.h" +#include "../compressor.h" namespace byteps { namespace common { @@ -31,10 +31,10 @@ namespace compressor { * sending the most significant entries of the stochastic gradient * */ -class TopkCompressor : public BaseCompressor { +class TopkCompressor : public Compressor { public: - explicit TopkCompressor(int k); - virtual ~TopkCompressor(); + TopkCompressor(size_t size, int k) : Compressor(size), _k(k){}; + virtual ~TopkCompressor() = default; /*! * \brief Compress function @@ -92,4 +92,4 @@ class TopkCompressor : public BaseCompressor { } // namespace common } // namespace byteps -#endif // BYTEPS_COMPRESS_STRAT_MULTIBIT_H \ No newline at end of file +#endif // BYTEPS_COMPRESSOR_STRATEGY_TOPK_H \ No newline at end of file diff --git a/byteps/common/compressor/strategy/vanilla_error_feedback.cc b/byteps/common/compressor/strategy/vanilla_error_feedback.cc index 832bd73e9..0f36c654a 100644 --- a/byteps/common/compressor/strategy/vanilla_error_feedback.cc +++ b/byteps/common/compressor/strategy/vanilla_error_feedback.cc @@ -13,14 +13,13 @@ // limitations under the License. // ============================================================================= -#include "vanilla_error_feedback.h" - #include #include #include #include -#include "../../logging.h" +#include "../compressor_registry.h" +#include "vanilla_error_feedback.h" namespace byteps { namespace common { @@ -28,30 +27,23 @@ namespace compressor { namespace { CompressorRegistry::Register reg( "vanilla_ef", - [](const kwargs_t& kwargs) -> std::unique_ptr { + [](const kwargs_t& kwargs, size_t size, + int dtype) -> std::unique_ptr { // register cpr auto kwargs_clone = kwargs; kwargs_clone.erase("ef_type"); - auto compressor_ptr = CompressorRegistry::Create(kwargs_clone); - BPS_CHECK_NE(compressor_ptr, nullptr); - + auto cptr = + CompressorRegistry::Create(kwargs_clone, size, dtype); + BPS_CHECK_NE(cptr, nullptr); BPS_LOG(DEBUG) << "with Error feedback"; return std::unique_ptr( - new VanillaErrorFeedbackCompressor(std::move(compressor_ptr))); + new VanillaErrorFeedbackCompressor(size, std::move(cptr))); }); } VanillaErrorFeedbackCompressor::VanillaErrorFeedbackCompressor( - std::unique_ptr compressor_ptr) - : ErrorFeedback(std::move(compressor_ptr)) {} - -VanillaErrorFeedbackCompressor::~VanillaErrorFeedbackCompressor() { - munmap(_mm, 8); - close(_fd); -} - -void VanillaErrorFeedbackCompressor::Init(size_t aligned_size) { - ErrorFeedback::Init(aligned_size); + size_t size, std::unique_ptr cptr) + : ErrorFeedback(size, std::move(cptr)) { _fd = open("lr.s", O_RDONLY); BPS_CHECK(_fd > 0) << "open lr.s failed, errno=" << strerror(errno); void* ptr = mmap(0, 8, PROT_READ, MAP_SHARED, _fd, 0); @@ -60,6 +52,11 @@ void VanillaErrorFeedbackCompressor::Init(size_t aligned_size) { _pre_lr = _cur_lr = *reinterpret_cast(_mm); } +VanillaErrorFeedbackCompressor::~VanillaErrorFeedbackCompressor() { + munmap(_mm, 8); + close(_fd); +} + void VanillaErrorFeedbackCompressor::UpdateGradient(tensor_t grad) { _cur_lr = *reinterpret_cast(_mm); this->_cpu_reducer->sum(grad.data, _error.get(), grad.size, diff --git a/byteps/common/compressor/strategy/vanilla_error_feedback.h b/byteps/common/compressor/strategy/vanilla_error_feedback.h index 7863e8b86..e5247f053 100644 --- a/byteps/common/compressor/strategy/vanilla_error_feedback.h +++ b/byteps/common/compressor/strategy/vanilla_error_feedback.h @@ -13,8 +13,8 @@ // limitations under the License. // ============================================================================= -#ifndef BYTEPS_COMPRESS_VANILLA_EF_H -#define BYTEPS_COMPRESS_VANILLA_EF_H +#ifndef BYTEPS_COMPRESSOR_STRATEGY_VANILLA_ERROR_FEEDBACK_H +#define BYTEPS_COMPRESSOR_STRATEGY_VANILLA_ERROR_FEEDBACK_H #include "../error_feedback.h" @@ -23,16 +23,14 @@ namespace common { namespace compressor { /*! - * \brief TODO + * \brief VanillaErrorFeedbackCompressor */ class VanillaErrorFeedbackCompressor : public ErrorFeedback { public: - explicit VanillaErrorFeedbackCompressor( - std::unique_ptr compressor_ptr); + VanillaErrorFeedbackCompressor(size_t size, + std::unique_ptr cptr); virtual ~VanillaErrorFeedbackCompressor(); - virtual void Init(size_t aligned_size); - protected: void UpdateGradient(tensor_t grad) override; @@ -45,4 +43,4 @@ class VanillaErrorFeedbackCompressor : public ErrorFeedback { } // namespace common } // namespace byteps -#endif // BYTEPS_COMPRESS_VANILLA_EF_H \ No newline at end of file +#endif // BYTEPS_COMPRESSOR_STRATEGY_VANILLA_ERROR_FEEDBACK_H \ No newline at end of file diff --git a/byteps/common/compressor/utils.h b/byteps/common/compressor/utils.h new file mode 100644 index 000000000..fe26f6cd7 --- /dev/null +++ b/byteps/common/compressor/utils.h @@ -0,0 +1,56 @@ +// Copyright 2019 Amazon Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef BYTEPS_COMPRESSOR_UTILS_H +#define BYTEPS_COMPRESSOR_UTILS_H + +#include +#include + +#include "common.h" + +namespace byteps { +namespace common { +namespace compressor { + +inline std::string Serialize(const kwargs_t& kwargs) { + std::ostringstream os; + os << kwargs.size(); + for (auto const& kwarg : kwargs) { + os << " " << kwarg.first << " " << kwarg.second; + } + return os.str(); +} + +inline kwargs_t Deserialize(const std::string& content) { + kwargs_t kwargs; + std::istringstream is(content); + size_t size = 0; + is >> size; + for (size_t i = 0; i < size; ++i) { + kwargs_t::key_type key; + kwargs_t::mapped_type val; + is >> key >> val; + kwargs[key] = val; + } + + return kwargs; +} + +} // namespace compressor +} // namespace common +} // namespace byteps + +#endif // BYTEPS_COMPRESSOR_UTILS_H \ No newline at end of file diff --git a/byteps/common/core_loops.cc b/byteps/common/core_loops.cc index 3c45ee5ea..fe791e048 100644 --- a/byteps/common/core_loops.cc +++ b/byteps/common/core_loops.cc @@ -13,15 +13,14 @@ // limitations under the License. // ============================================================================= -#include "core_loops.h" - #include #include #include #include "common.h" -#include "compressor/base_compressor.h" +#include "compressor/compressor.h" +#include "core_loops.h" #include "global.h" #include "logging.h" @@ -628,8 +627,7 @@ bool RunDecompressLoopOnce() { auto &pskv = BytePSGlobal::EncodeDefaultKey(task->key, 0); auto len = pskv.lens[0]; int dtype = task->tensor->dtype(); - compressor::tensor_t compressed(data, len, dtype), - decompressed(data, task->len); + compressor::tensor_t compressed(data, len, dtype), decompressed{data}; task->compressor->Decompress(compressed, decompressed); BPS_LOG(DEBUG) << "PULL with gradient compression. key=" << task->key; diff --git a/byteps/common/global.cc b/byteps/common/global.cc index f6e9076d5..73ddf8444 100644 --- a/byteps/common/global.cc +++ b/byteps/common/global.cc @@ -13,14 +13,13 @@ // limitations under the License. // ============================================================================= -#include "global.h" - #include #include #include -#include "compressor/base_compressor.h" +#include "compressor/compressor.h" +#include "global.h" namespace byteps { namespace common { @@ -41,7 +40,7 @@ bool BytePSGlobal::_is_root_device; bool BytePSGlobal::_is_distributed_job; bool BytePSGlobal::_is_cross_pcie_switch; uint32_t BytePSGlobal::_partition_bytes = 4096000; -uint32_t BytePSGlobal::_min_compress_bytes = (1<<16); +uint32_t BytePSGlobal::_min_compress_bytes = (1 << 16); int BytePSGlobal::_is_trace = 0; int BytePSGlobal::_start_step = 10; @@ -285,14 +284,13 @@ ps::KVWorker* BytePSGlobal::GetOrInitPS() { // we reuse _init_mutex, because BytePS should have been inited std::lock_guard lock(_init_mutex); if (!_ps && IsDistributed() && - _my_role == - BytePSRole::LOCAL_ROOT) { // only the root needs networking - // init low-level ps implementation - _ps = new ps::KVWorker(0, 0); - ps::StartAsync(0, "byteps\0"); - if (BytePSGlobal::IsResuming() || !ps::Postoffice::Get()->is_recovery()) { - ps::Postoffice::Get()->Barrier( - 0, ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler); + _my_role == BytePSRole::LOCAL_ROOT) { // only the root needs networking + // init low-level ps implementation + _ps = new ps::KVWorker(0, 0); + ps::StartAsync(0, "byteps\0"); + if (BytePSGlobal::IsResuming() || !ps::Postoffice::Get()->is_recovery()) { + ps::Postoffice::Get()->Barrier( + 0, ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler); } } return _ps; @@ -414,7 +412,8 @@ BPSContext& BytePSGlobal::GetContextFromName(const std::string& name) { bool BytePSGlobal::IsTensorDeclared(const std::string& name) { std::lock_guard lock(_context_mutex); if (_name_to_cxt.find(name) == _name_to_cxt.end()) { - if (std::find(_declared_tensors.begin(), _declared_tensors.end(), name) == _declared_tensors.end()) { + if (std::find(_declared_tensors.begin(), _declared_tensors.end(), name) == + _declared_tensors.end()) { _declared_tensors.push_back(name); } _name_to_cxt[name].initialized = false; @@ -430,14 +429,15 @@ bool BytePSGlobal::IsTensorDeclared(const std::string& name) { } void BytePSGlobal::ReDeclareTensor() { - for (auto name: _declared_tensors) { + for (auto name : _declared_tensors) { BPS_LOG(DEBUG) << "Redeclare tensor " << name; BytePSGlobal::IsTensorDeclared(name); } } -void BytePSGlobal::RegisterCompressor(const std::string& name, - std::unordered_map& kwargs) { +void BytePSGlobal::RegisterCompressor( + const std::string& name, + std::unordered_map& kwargs) { std::lock_guard lock(_context_mutex); BPS_CHECK(_name_to_cxt.find(name) != _name_to_cxt.end()) << name << " is not initialized"; diff --git a/byteps/common/operations.cc b/byteps/common/operations.cc index 569e09cf2..3b280cca1 100644 --- a/byteps/common/operations.cc +++ b/byteps/common/operations.cc @@ -13,8 +13,6 @@ // limitations under the License. // ============================================================================= -#include "operations.h" - #include #include @@ -22,10 +20,13 @@ #include #include -#include "compressor/base_compressor.h" +#include "compressor/compressor.h" +#include "compressor/compressor_registry.h" +#include "compressor/utils.h" #include "core_loops.h" #include "global.h" #include "logging.h" +#include "operations.h" namespace byteps { namespace common { @@ -94,8 +95,10 @@ void byteps_shutdown() { void byteps_resume(int num_workers, int num_servers) { // set ps, worker numbers - BPS_LOG(DEBUG) << "Resume worker number: " << num_workers << "DMLC_NUM_WORKER: " << getenv("DMLC_NUM_WORKER"); - BPS_LOG(DEBUG) << "Resume server number: " << num_workers << "DMLC_NUM_SERVER: " << getenv("DMLC_NUM_SERVER"); + BPS_LOG(DEBUG) << "Resume worker number: " << num_workers + << "DMLC_NUM_WORKER: " << getenv("DMLC_NUM_WORKER"); + BPS_LOG(DEBUG) << "Resume server number: " << num_workers + << "DMLC_NUM_SERVER: " << getenv("DMLC_NUM_SERVER"); BPS_LOG(DEBUG) << "Start resuming BytePS"; BytePSGlobal::SetResumingFlag(true); @@ -188,9 +191,9 @@ Status EnqueueTensor(BPSContext &context, std::shared_ptr input, // add queue if (BytePSGlobal::IsRootDevice() && !context.compressor_list.empty()) { auto it = std::find(queue_list->begin(), queue_list->end(), PUSH); - it = queue_list->insert(it, COMPRESS); // before PUSH + it = queue_list->insert(it, COMPRESS); // before PUSH it = std::find(queue_list->begin(), queue_list->end(), PULL); - queue_list->insert(it+1, DECOMPRESS); // after PULL + queue_list->insert(it + 1, DECOMPRESS); // after PULL } std::shared_ptr e(new TensorTableEntry); @@ -205,11 +208,14 @@ Status EnqueueTensor(BPSContext &context, std::shared_ptr input, e->callback = callback; if (device == CPU_DEVICE_ID) { - cudaError_t err = cudaHostRegister(const_cast(input->data()), input->size(), cudaHostRegisterMapped); + cudaError_t err = cudaHostRegister(const_cast(input->data()), + input->size(), cudaHostRegisterMapped); if (err == cudaSuccess) { - BPS_LOG(DEBUG) << name << " cpu address has changed, so it is pinned again."; + BPS_LOG(DEBUG) << name + << " cpu address has changed, so it is pinned again."; } - CUDA_CALL(cudaHostGetDevicePointer(&(context.gpu_ptr), const_cast(input->data()), 0)); + CUDA_CALL(cudaHostGetDevicePointer(&(context.gpu_ptr), + const_cast(input->data()), 0)); } e->cpubuff = context.cpubuff; @@ -319,7 +325,7 @@ void InitTensor(BPSContext &context, size_t size, int dtype, void *cpubuff) { BPS_LOG(DEBUG) << name << " is already on cpu, len=" << size; cudaError_t e = cudaHostRegister(cpubuff, size, cudaHostRegisterMapped); if (e != cudaSuccess) { - BPS_LOG(INFO) << cudaGetErrorString(e) + BPS_LOG(INFO) << cudaGetErrorString(e) << " (You may ignore this if your program continues)"; } CUDA_CALL(cudaHostGetDevicePointer(&(context.gpu_ptr), cpubuff, 0)); @@ -367,8 +373,7 @@ void InitTensor(BPSContext &context, size_t size, int dtype, void *cpubuff) { // register if (!context.kwargs.empty()) { auto compressor_ptr = - compressor::CompressorRegistry::Create(context.kwargs); - compressor_ptr->Init(Align(len, dtype)); + compressor::CompressorRegistry::Create(context.kwargs, Align(len, dtype), dtype); context.compressor_list.push_back(std::move(compressor_ptr)); } } diff --git a/byteps/server/server.cc b/byteps/server/server.cc index d2f0c31f2..45934a92e 100644 --- a/byteps/server/server.cc +++ b/byteps/server/server.cc @@ -14,7 +14,7 @@ // ============================================================================= #include "server.h" - +#include "../common/compressor/utils.h" #include "queue.h" namespace byteps { @@ -101,7 +101,7 @@ void BytePSServerEngineThread(int i) { CHECK_LE(compressed_len, msg.len); common::compressor::tensor_t compressed( reinterpret_cast(msg.src), compressed_len, msg.type.dtype), - decompressed(nullptr, msg.len); + decompressed; iter->second->Decompress(compressed, decompressed); msg.src = decompressed.data; } @@ -227,11 +227,12 @@ void BytePSHandler(const ps::KVMeta& req_meta, std::string content{reinterpret_cast(req_data.vals.data()), static_cast(req_data.lens[0])}; auto kwargs = byteps::common::compressor::Deserialize(content); + auto stored = GetStore(key); + size_t aligned_size = byteps::common::Align(stored->len, stored->dtype); auto compressor_ptr = - byteps::common::compressor::CompressorRegistry::Create(kwargs); + byteps::common::compressor::CompressorRegistry::Create( + kwargs, aligned_size, stored->dtype); CHECK_NE(compressor_ptr, nullptr); - auto stored = GetStore(key); - compressor_ptr->Init(byteps::common::Align(stored->len, stored->dtype)); compressor_map_[key] = std::move(compressor_ptr); if (log_key_info_) { LOG(INFO) << "register compressor for key=" << key; diff --git a/byteps/server/server.h b/byteps/server/server.h index e6d0e8298..f24412d90 100644 --- a/byteps/server/server.h +++ b/byteps/server/server.h @@ -23,7 +23,8 @@ #include #include "ps/ps.h" #include "../common/cpu_reducer.h" -#include "../common/compressor/base_compressor.h" +#include "../common/compressor/compressor.h" +#include "../common/compressor/compressor_registry.h" namespace byteps { namespace server { @@ -108,7 +109,7 @@ std::vector > pull_cnt_; // byteps handler std::mutex handle_mu_; std::unordered_map update_buf_; -std::unordered_map> compressor_map_; +std::unordered_map> compressor_map_; // address map std::mutex store_mu_; diff --git a/setup.py b/setup.py index b59058255..55a59c463 100644 --- a/setup.py +++ b/setup.py @@ -250,7 +250,7 @@ def get_common_options(build_ext): 'byteps/common/shared_memory.cc', 'byteps/common/nccl_manager.cc', 'byteps/common/cpu_reducer.cc'] + [ - 'byteps/common/compressor/base_compressor.cc', + 'byteps/common/compressor/compressor_registry.cc', 'byteps/common/compressor/error_feedback.cc', 'byteps/common/compressor/momentum.cc', 'byteps/common/compressor/strategy/multibit.cc', @@ -302,8 +302,8 @@ def build_server(build_ext, options): server_lib.sources = ['byteps/server/server.cc', 'byteps/common/cpu_reducer.cc', 'byteps/common/logging.cc', - 'byteps/common/common.cc']+ [ - 'byteps/common/compressor/base_compressor.cc', + 'byteps/common/common.cc'] + [ + 'byteps/common/compressor/compressor_registry.cc', 'byteps/common/compressor/error_feedback.cc', 'byteps/common/compressor/strategy/multibit.cc', 'byteps/common/compressor/strategy/onebit.cc',