Skip to content

Commit

Permalink
refactor: format and rename (#24)
Browse files Browse the repository at this point in the history
* format: refactor

* format: fix typos

* format: update setup.py

* format: update format
  • Loading branch information
jasperzhong authored Jun 24, 2020
1 parent c507e14 commit 2a98d12
Show file tree
Hide file tree
Showing 29 changed files with 420 additions and 389 deletions.
6 changes: 3 additions & 3 deletions byteps/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ namespace common {
namespace compressor {
struct BPSTensor;
typedef BPSTensor tensor_t;
class BaseCompressor;
class Compressor;
class ErrorFeedback;
} // namespace compressor

Expand Down Expand Up @@ -198,7 +198,7 @@ typedef struct BytePSContext {
std::unordered_map<int, std::queue<BPSCommTime*>>>
part_comm_time;
// Compressor list
std::vector<std::shared_ptr<compressor::BaseCompressor>> compressor_list;
std::vector<std::shared_ptr<compressor::Compressor>> compressor_list;
// kwargs
std::unordered_map<std::string, std::string> kwargs;
} BPSContext;
Expand Down Expand Up @@ -257,7 +257,7 @@ struct TensorTableEntry {
// How many partitions
unsigned int total_partnum = 0;
// Compressor
std::shared_ptr<compressor::BaseCompressor> compressor;
std::shared_ptr<compressor::Compressor> compressor;
// Compressed
std::shared_ptr<compressor::tensor_t> compressed;
};
Expand Down
147 changes: 0 additions & 147 deletions byteps/common/compressor/base_compressor.h

This file was deleted.

44 changes: 44 additions & 0 deletions byteps/common/compressor/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright 2019 Amazon Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef BYTEPS_COMPRESSOR_COMMON_H
#define BYTEPS_COMPRESSOR_COMMON_H

#include <unordered_map>

namespace byteps {
namespace common {
namespace compressor {
typedef char byte_t;
/*!
* \brief Tensor type
*/
typedef struct BPSTensor {
byte_t* data;
size_t size;
int dtype;

BPSTensor() : data(nullptr), size(0), dtype(0) {}
BPSTensor(byte_t* data, size_t size=0, int dtype=0)
: data(data), size(size), dtype(dtype) {}
} tensor_t;

using kwargs_t = std::unordered_map<std::string, std::string>;

} // namespace compressor
} // namespace common
} // namespace byteps

#endif // BYTEPS_COMPRESSOR_COMMON_H
89 changes: 89 additions & 0 deletions byteps/common/compressor/compressor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright 2019 Amazon Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef BYTEPS_COMPRESSOR_COMPRESSOR_H
#define BYTEPS_COMPRESSOR_COMPRESSOR_H

#include <memory>

#include "../cpu_reducer.h"
#include "common.h"

namespace byteps {
namespace common {
namespace compressor {
/*!
* \brief Compressor interface used in BytePS core.
*/
class Compressor {
public:
Compressor(size_t size)
: _size(size),
_buf(new byte_t[size]),
_cpu_reducer(new CpuReducer(nullptr)){};
virtual ~Compressor() = default;

/*!
* \brief Compress function
*
* \param grad gradient tensor
* \param compressed compressed tensor
*/
virtual void Compress(tensor_t grad, tensor_t& compressed) = 0;

/*!
* \brief Decompress function
*
* \param compressed compressed tensor
* \param decompressed decompressed tensor
*/
virtual void Decompress(tensor_t compressed, tensor_t& decompressed) = 0;

/*!
* \brief help function for error feedback `UpdateError`
*
* \param corrected gradient corrected with error
* \param error error
* \param compressed compressed gradient
*/
virtual void FastUpdateError(tensor_t error, tensor_t corrected,
tensor_t compressed) {
BPS_LOG(FATAL) << "FastUpdateError is not implemented";
};

size_t size() const { return _size; }

protected:
/*!
* \brief buffer
*/
std::unique_ptr<byte_t[]> _buf;

/*!
* \brief original size
*/
size_t _size;

/*!
* \brief CPU reducer
*/
std::unique_ptr<CpuReducer> _cpu_reducer;
};

} // namespace compressor
} // namespace common
} // namespace byteps

#endif // BYTEPS_COMPRESSOR_COMPRESSOR_H
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
// limitations under the License.
// =============================================================================

#include "base_compressor.h"

#include "../logging.h"
#include "compressor_registry.h"

namespace byteps {
namespace common {
Expand All @@ -38,8 +36,8 @@ CompressorRegistry::ctor_t CompressorRegistry::Find(const std::string& name) {
return it->second;
}

std::unique_ptr<BaseCompressor> CompressorRegistry::Create(
const kwargs_t& kwargs) {
std::unique_ptr<Compressor> CompressorRegistry::Create(const kwargs_t& kwargs,
size_t size, int dtype) {
#ifndef BYTEPS_BUILDING_SERVER
const std::string types[] = {"momentum_type", "ef_type", "compressor_type"};
#else
Expand All @@ -51,24 +49,13 @@ std::unique_ptr<BaseCompressor> CompressorRegistry::Create(
if (iter != kwargs.end()) {
auto ctor = CompressorRegistry::Find(iter->second + "_" + type);
BPS_CHECK_NE(ctor, nullptr);
return ctor(kwargs);
return ctor(kwargs, size, dtype);
}
}

return nullptr;
}

BaseCompressor::BaseCompressor() = default;

BaseCompressor::~BaseCompressor() = default;

void BaseCompressor::Init(size_t aligned_size) {
_buf.reset(new char[aligned_size]);
_cpu_reducer.reset(new CpuReducer(nullptr));
}

void BaseCompressor::FastUpdateError(tensor_t error, tensor_t corrected,
tensor_t compressed) {}
} // namespace compressor
} // namespace common
} // namespace byteps
Loading

0 comments on commit 2a98d12

Please sign in to comment.