From 29c6bcbf31a2200e512e453558636db3b13a881f Mon Sep 17 00:00:00 2001 From: zhaocaibei123 <48509226+zhaocaibei123@users.noreply.github.com> Date: Mon, 1 Nov 2021 15:27:27 +0800 Subject: [PATCH] memory sparse table & brpc communication upgrade dependency (#36734) --- paddle/fluid/distributed/CMakeLists.txt | 1 + .../fluid/distributed/common/CMakeLists.txt | 4 + .../fluid/distributed/common/afs_warpper.cc | 89 ++++++++++ paddle/fluid/distributed/common/afs_warpper.h | 156 ++++++++++++++++++ paddle/fluid/distributed/common/cost_timer.h | 93 +++++++++++ paddle/fluid/distributed/common/utils.h | 15 ++ paddle/fluid/distributed/service/env.h | 7 +- paddle/fluid/distributed/service/ps_client.h | 62 ++++++- paddle/fluid/distributed/table/accessor.h | 9 +- .../fluid/distributed/table/depends/dense.h | 154 +++++++++++++++++ .../framework/distributed_strategy.proto | 66 ++++++++ 11 files changed, 640 insertions(+), 16 deletions(-) create mode 100644 paddle/fluid/distributed/common/CMakeLists.txt create mode 100644 paddle/fluid/distributed/common/afs_warpper.cc create mode 100644 paddle/fluid/distributed/common/afs_warpper.h create mode 100644 paddle/fluid/distributed/common/cost_timer.h diff --git a/paddle/fluid/distributed/CMakeLists.txt b/paddle/fluid/distributed/CMakeLists.txt index 905347d031b35..17e96243878bc 100644 --- a/paddle/fluid/distributed/CMakeLists.txt +++ b/paddle/fluid/distributed/CMakeLists.txt @@ -11,6 +11,7 @@ if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new") endif() +add_subdirectory(common) add_subdirectory(service) add_subdirectory(table) add_subdirectory(test) diff --git a/paddle/fluid/distributed/common/CMakeLists.txt b/paddle/fluid/distributed/common/CMakeLists.txt new file mode 100644 index 0000000000000..eab6165ca689e --- /dev/null +++ b/paddle/fluid/distributed/common/CMakeLists.txt @@ -0,0 +1,4 @@ + +cc_library(afs_wrapper SRCS afs_warpper.cc DEPS fs ps_framework_proto) + +#set_property(GLOBAL PROPERTY COMMON_DEPS afs_warpper) diff --git a/paddle/fluid/distributed/common/afs_warpper.cc b/paddle/fluid/distributed/common/afs_warpper.cc new file mode 100644 index 0000000000000..d539ec6080469 --- /dev/null +++ b/paddle/fluid/distributed/common/afs_warpper.cc @@ -0,0 +1,89 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/distributed/common/afs_warpper.h" +#include "paddle/fluid/framework/io/fs.h" + +namespace paddle { +namespace distributed { +// AfsClient impl +int AfsClient::initialize(const FsClientParameter& fs_client_param) { + // temporarily implemented with hdfs-client + return initialize(fs_client_param.hadoop_bin(), fs_client_param.uri(), + fs_client_param.user(), fs_client_param.passwd(), + fs_client_param.buffer_size()); +} +int AfsClient::initialize(const std::string& hadoop_bin, const std::string& uri, + const std::string& user, const std::string& passwd, + int buffer_size_param) { + return initialize(hadoop_bin, uri, paddle::string::format_string( + "%s,%s", user.c_str(), passwd.c_str()), + buffer_size_param); +} +int AfsClient::initialize(const std::string& hadoop_bin, const std::string& uri, + const std::string& ugi, int buffer_size_param) { + // temporarily implemented with hdfs-client + size_t buffer_size = 1L << 25; // 32MB + if (buffer_size_param > static_cast(buffer_size)) { + buffer_size = buffer_size_param; + } + paddle::framework::hdfs_set_buffer_size(buffer_size); + paddle::framework::hdfs_set_command(paddle::string::format_string( + "2>>./hdfs_err.log %s fs -Dfs.default.name=%s -Dhadoop.job.ugi=%s " + "-Ddfs.client.block.write.retries=15 -Ddfs.rpc.timeout=300000", + hadoop_bin.c_str(), uri.c_str(), ugi.c_str())); + return 0; +} + +// open file in 'w' or 'r' +std::shared_ptr AfsClient::open_r(const FsChannelConfig& config, + uint32_t buffer_size, + int* err_no) { + std::shared_ptr channel = + std::make_shared(buffer_size); + std::shared_ptr fp = + paddle::framework::fs_open_read(config.path, err_no, config.deconverter); + channel->open(fp, config); + return channel; +} +std::shared_ptr AfsClient::open_w(const FsChannelConfig& config, + uint32_t buffer_size, + int* err_no) { + std::shared_ptr channel = + std::make_shared(buffer_size); + std::shared_ptr fp = + paddle::framework::fs_open_write(config.path, err_no, config.converter); + channel->open(fp, config); + return channel; +} + +// remove file in path, path maybe a reg, such as 'part-000-*' +void AfsClient::remove(const std::string& path) { + return paddle::framework::fs_remove(path); +} +void AfsClient::remove_dir(const std::string& dir) { + return paddle::framework::fs_remove(dir); +} + +// list files in path, path maybe a dir with reg +std::vector AfsClient::list(const std::string& path) { + return paddle::framework::fs_list(path); +} + +// exist or not +bool AfsClient::exist(const std::string& dir) { + return paddle::framework::fs_exists(dir); +} +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/common/afs_warpper.h b/paddle/fluid/distributed/common/afs_warpper.h new file mode 100644 index 0000000000000..d10668046c0a7 --- /dev/null +++ b/paddle/fluid/distributed/common/afs_warpper.h @@ -0,0 +1,156 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
diff --git a/paddle/fluid/distributed/common/afs_warpper.h b/paddle/fluid/distributed/common/afs_warpper.h
new file mode 100644
index 0000000000000..d10668046c0a7
--- /dev/null
+++ b/paddle/fluid/distributed/common/afs_warpper.h
@@ -0,0 +1,156 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <functional>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+#include "paddle/fluid/distributed/ps.pb.h"
+#include "paddle/fluid/string/string_helper.h"
+
+namespace paddle {
+namespace distributed {
+struct FsDataConverter {
+  std::string converter;
+  std::string deconverter;
+};
+
+struct FsChannelConfig {
+  std::string path;       // path of file
+  std::string converter;  // data converter
+  std::string deconverter;
+};
+
+class FsReadChannel {
+ public:
+  FsReadChannel() : _buffer_size(0) {}
+  explicit FsReadChannel(uint32_t buffer_size) : _buffer_size(buffer_size) {}
+  virtual ~FsReadChannel() {}
+  FsReadChannel(FsReadChannel&&) = delete;
+  FsReadChannel(const FsReadChannel&) = delete;
+  int open(std::shared_ptr<FILE> fp, const FsChannelConfig& config) {
+    _file = fp;
+    return 0;
+  }
+  inline int close() {
+    _file.reset();
+    return 0;
+  }
+
+  inline uint32_t read_line(std::string& line_data) {  // NOLINT
+    line_data.clear();
+    char buffer = '\0';
+    size_t read_count = 0;
+    while (1 == fread(&buffer, 1, 1, _file.get()) && buffer != '\n') {
+      ++read_count;
+      line_data.append(&buffer, 1);
+    }
+    if (read_count == 0 && buffer != '\n') {
+      return -1;
+    }
+    return 0;
+  }
+
+ private:
+  uint32_t _buffer_size;
+  FsChannelConfig _config;
+  std::shared_ptr<FILE> _file;
+};
+class FsWriteChannel {
+ public:
+  FsWriteChannel() : _buffer_size(0) {}
+  explicit FsWriteChannel(uint32_t buffer_size) : _buffer_size(buffer_size) {}
+  virtual ~FsWriteChannel() {}
+  FsWriteChannel(FsWriteChannel&&) = delete;
+  FsWriteChannel(const FsWriteChannel&) = delete;
+
+  int open(std::shared_ptr<FILE> fp, const FsChannelConfig& config) {
+    _file = fp;
+
+    // the buffer has set in fs.cc
+    // if (_buffer_size != 0) {
+    //   _buffer = std::shared_ptr<char>(new char[_buffer_size]);
+
+    //   CHECK(0 == setvbuf(&*_file, _buffer.get(), _IOFBF, _buffer_size));
+    // }
+    return 0;
+  }
+
+  inline void flush() { return; }
+
+  inline int close() {
+    flush();
+    _file.reset();
+    return 0;
+  }
+
+  inline uint32_t write_line(const char* data, uint32_t size) {
+    size_t write_count = fwrite_unlocked(data, 1, size, _file.get());
+    if (write_count != size) {
+      return -1;
+    }
+    write_count = fwrite_unlocked("\n", 1, 1, _file.get());
+    if (write_count != 1) {
+      return -1;
+    }
+    return 0;
+  }
+  inline uint32_t write_line(const std::string& data) {
+    return write_line(data.c_str(), data.size());
+  }
+
+ private:
+  uint32_t _buffer_size;
+  FsChannelConfig _config;
+  std::shared_ptr<FILE> _file;
+  std::shared_ptr<char> _buffer;
+};
+
+class AfsClient {
+ public:
+  AfsClient() {}
+  virtual ~AfsClient() {}
+  AfsClient(AfsClient&&) = delete;
+  AfsClient(const AfsClient&) = delete;
+
+  int initialize(const FsClientParameter& fs_client_param);
+  int initialize(const std::string& hadoop_bin, const std::string& uri,
+                 const std::string& user, const std::string& passwd,
+                 int buffer_size_param = (1L << 25));
+  int initialize(const std::string& hadoop_bin, const std::string& uri,
+                 const std::string& ugi, int buffer_size_param = (1L << 25));
+
+  // open file in 'w' or 'r'
+  std::shared_ptr<FsReadChannel> open_r(const FsChannelConfig& config,
+                                        uint32_t buffer_size = 0,
+                                        int* err_no = nullptr);
+  std::shared_ptr<FsWriteChannel> open_w(const FsChannelConfig& config,
+                                         uint32_t buffer_size = 0,
+                                         int* err_no = nullptr);
+
+  // remove file in path, path maybe a reg, such as 'part-000-*'
+  void remove(const std::string& path);
+  void remove_dir(const std::string& dir);
+
+  // list files in path, path maybe a dir with reg
+  std::vector<std::string> list(const std::string& path);
+
+  // exist or not
+  bool exist(const std::string& dir);
+};
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/common/cost_timer.h b/paddle/fluid/distributed/common/cost_timer.h
new file mode 100644
index 0000000000000..d7bf4cc11e0a3
--- /dev/null
+++ b/paddle/fluid/distributed/common/cost_timer.h
@@ -0,0 +1,93 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <memory>
+#include <unordered_map>
+#include "butil/time.h"
+#include "bvar/latency_recorder.h"
+#include "glog/logging.h"
+
+namespace paddle {
+namespace distributed {
+
+struct CostProfilerNode {
+  std::shared_ptr<bvar::LatencyRecorder> recorder;
+};
+
+class CostProfiler {
+ public:
+  ~CostProfiler() {}
+  static CostProfiler& instance() {
+    static CostProfiler profiler;
+    return profiler;
+  }
+
+  void register_profiler(const std::string& label) {
+    if (_cost_profiler_map.find(label) != _cost_profiler_map.end()) {
+      return;
+    }
+    auto profiler_node = std::make_shared<CostProfilerNode>();
+    profiler_node->recorder.reset(
+        new bvar::LatencyRecorder("cost_profiler", label));
+    _cost_profiler_map[label] = profiler_node;
+  }
+
+  CostProfilerNode* profiler(const std::string& label) {
+    auto itr = _cost_profiler_map.find(label);
+    if (itr != _cost_profiler_map.end()) {
+      return itr->second.get();
+    }
+    return NULL;
+  }
+
+ private:
+  CostProfiler() {}
+  std::unordered_map<std::string, std::shared_ptr<CostProfilerNode>>
+      _cost_profiler_map;
+};
+
+class CostTimer {
+ public:
+  explicit CostTimer(const std::string& label) {
+    _label = label;
+    auto& profiler = CostProfiler::instance();
+    _profiler_node = profiler.profiler(label);
+    // if the label is not registered in the profiler, report the cost via LOG
+    _is_print_cost = _profiler_node == NULL;
+    _start_time_ms = butil::gettimeofday_ms();
+  }
+  explicit CostTimer(CostProfilerNode& profiler_node) {  // NOLINT
+    _is_print_cost = false;
+    _profiler_node = &profiler_node;
+    _start_time_ms = butil::gettimeofday_ms();
+  }
+  ~CostTimer() {
+    if (_is_print_cost) {
+      LOG(INFO) << "CostTimer label:" << _label
+                << ", cost:" << butil::gettimeofday_ms() - _start_time_ms
+                << "ms";
+    } else {
+      *(_profiler_node->recorder) << butil::gettimeofday_ms() - _start_time_ms;
+    }
+  }
+
+ private:
+  std::string _label;
+  bool _is_print_cost;
+  uint64_t _start_time_ms;
+  CostProfilerNode* _profiler_node;
+};
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/common/utils.h b/paddle/fluid/distributed/common/utils.h
index 2305001ad6f8f..fb2189b8f5a1b 100644
--- a/paddle/fluid/distributed/common/utils.h
+++ b/paddle/fluid/distributed/common/utils.h
@@ -52,6 +52,20 @@ inline void ADD(int n, const T* x, const T y, T* z) {
   }
 }
 
+template <typename T>
+inline void DIV(int n, const T x, const T* y, T* z) {
+  for (int i = 0; i < n; ++i) {
+    z[i] = x / y[i];
+  }
+}
+
+template <typename T>
+inline void ELE_MUL(int n, const T* x, const T* y, T* z) {
+  for (int i = 0; i < n; ++i) {
+    z[i] = x[i] * y[i];
+  }
+}
+
 static bool StartWith(const std::string& str, const std::string& substr) {
   return str.find(substr) == 0;
 }
@@ -91,5 +105,6 @@ inline double GetCurrentUS() {
   gettimeofday(&time, NULL);
   return 1e+6 * time.tv_sec + time.tv_usec;
 }
+
 }  // namespace distributed
 }  // namespace paddle
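For reference, a minimal sketch of how the CostProfiler/CostTimer pair added in cost_timer.h above is meant to be used (not part of the patch; the labels are arbitrary):

// Sketch only: RAII latency recording via the new CostTimer.
#include "paddle/fluid/distributed/common/cost_timer.h"

void demo_cost_timer() {
  // Register once, e.g. during client initialization; duplicate labels are ignored.
  paddle::distributed::CostProfiler::instance().register_profiler(
      "demo_pull_dense");  // label is illustrative

  {
    // Records this scope's elapsed milliseconds into the label's bvar recorder.
    paddle::distributed::CostTimer timer("demo_pull_dense");
    // ... work being measured ...
  }

  {
    // An unregistered label falls back to LOG(INFO) with the elapsed time.
    paddle::distributed::CostTimer log_only("demo_unregistered_label");
  }
}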
diff --git a/paddle/fluid/distributed/service/env.h b/paddle/fluid/distributed/service/env.h
index ca395a776afd4..0cc57229b7a82 100644
--- a/paddle/fluid/distributed/service/env.h
+++ b/paddle/fluid/distributed/service/env.h
@@ -144,8 +144,8 @@ class PSEnvironment {
 
   virtual std::vector<uint64_t> get_client_info() {
     std::vector<uint64_t> client_info;
-    for (auto &i : _ps_client_sign_set) {
-      client_info.push_back(i);
+    for (auto &i : _ps_client_list) {
+      client_info.push_back(i.serialize_to_uint64());
     }
     return client_info;
   }
@@ -250,7 +250,7 @@ class PaddlePSEnvironment : public PSEnvironment {
     return 0;
   }
 
-  virtual int32_t set_ps_clients(std::vector<std::string> *host_sign_list,
+  virtual int32_t set_ps_clients(const std::vector<std::string> *host_sign_list,
                                  int node_num) {
     _ps_client_list.clear();
     _ps_client_sign_set.clear();
@@ -265,6 +265,7 @@ class PaddlePSEnvironment : public PSEnvironment {
     std::sort(
         _ps_client_list.begin(), _ps_client_list.end(),
         [](const PSHost &h1, const PSHost &h2) { return h1.rank < h2.rank; });
+    VLOG(1) << "env.set_ps_clients done\n";
     return 0;
   }
 
diff --git a/paddle/fluid/distributed/service/ps_client.h b/paddle/fluid/distributed/service/ps_client.h
index 74a1e0dde71fc..3be83436cec34 100644
--- a/paddle/fluid/distributed/service/ps_client.h
+++ b/paddle/fluid/distributed/service/ps_client.h
@@ -20,11 +20,13 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include "paddle/fluid/distributed/common/cost_timer.h"
 #include "paddle/fluid/distributed/ps.pb.h"
 #include "paddle/fluid/distributed/service/env.h"
 #include "paddle/fluid/distributed/service/sendrecv.pb.h"
 #include "paddle/fluid/distributed/table/accessor.h"
 #include "paddle/fluid/distributed/table/graph/graph_node.h"
+#include "paddle/fluid/platform/timer.h"
 
 namespace paddle {
 namespace distributed {
@@ -35,7 +37,7 @@ using paddle::distributed::PsResponseMessage;
 
 typedef std::function<void(void *)> PSClientCallBack;
 class PSClientClosure : public google::protobuf::Closure {
  public:
-  PSClientClosure(PSClientCallBack callback) : _callback(callback) {}
+  explicit PSClientClosure(PSClientCallBack callback) : _callback(callback) {}
   virtual ~PSClientClosure() {}
   virtual void set_promise_value(int value) {
     for (auto &promise : _promises) {
@@ -43,12 +45,17 @@ class PSClientClosure : public google::protobuf::Closure {
     }
   }
 
-  void add_promise(std::shared_ptr<std::promise<int32_t>> &promise) {
+  void add_promise(std::shared_ptr<std::promise<int32_t>> &promise) {  // NOLINT
     _promises.push_back(promise);
   }
 
+  void add_timer(std::shared_ptr<CostTimer> &timer) {  // NOLINT
+    _timers.push_back(timer);
+  }
+
  protected:
   PSClientCallBack _callback;
+  std::vector<std::shared_ptr<CostTimer>> _timers;
   std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
 };
 
@@ -59,11 +66,11 @@ class PSClient {
   PSClient(PSClient &&) = delete;
   PSClient(const PSClient &) = delete;
 
-  virtual int32_t configure(
+  virtual int32_t configure(  // NOLINT
       const PSParameter &config,
       const std::map<uint64_t, std::vector<paddle::distributed::Region>>
           &regions,
-      PSEnvironment &_env, size_t client_id) final;
+      PSEnvironment &_env, size_t client_id) final;  // NOLINT
 
   virtual int32_t create_client2client_connection(
       int pserver_timeout_ms, int pserver_connect_timeout_ms,
@@ -86,7 +93,7 @@ class PSClient {
   virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
                                     const std::string &mode) = 0;
 
-  //清空table数据
+  // clear table data
   virtual std::future<int32_t> clear() = 0;
   virtual std::future<int32_t> clear(uint32_t table_id) = 0;
 
@@ -98,7 +105,7 @@
   // server将参数区块中配置的某一维提取返回
   // 返回数据解包后填充到累计的多个buffer中
   virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
-                                          size_t table_id) = 0;  //保留
+                                          size_t table_id) = 0;  // reserved
 
   // firstly push dense param for parameter server
   // this is neccessary because dense weight initialized in trainer on cold
   // start
@@ -107,6 +114,9 @@ class PSClient {
   virtual std::future<int32_t> push_dense_param(const Region *regions,
                                                 size_t region_num,
                                                 size_t table_id) = 0;
 
+  // virtual std::future<int32_t> push_dense(const Region *regions,
+  //                                         size_t region_num,
+  //                                         size_t table_id) = 0;
   // 使用keys进行pull请求,结果填充values
   // keys和values的个数均为num个,每个value占用select_size空间
   // future结束前keys和values缓冲区不能再次使用
   virtual std::future<int32_t> pull_sparse(float **select_values,
                                            size_t table_id,
                                            const uint64_t *keys,
                                            size_t num) = 0;
 
@@ -212,6 +222,10 @@ class PSClient {
                                                   const uint64_t *keys,
                                                   const float **update_values,
                                                   size_t num, void *done) = 0;
 
+  // virtual std::future<int32_t> push_sparse(size_t table_id,
+  //                                          const uint64_t *keys,
+  //                                          const float **update_values,
+  //                                          size_t num) = 0;
 
  protected:
   virtual int32_t initialize() = 0;
@@ -222,8 +236,42 @@ class PSClient {
   PSEnvironment *_env;
   std::unordered_map<uint32_t, std::shared_ptr<ValueAccessor>> _table_accessors;
   std::unordered_map<int32_t, MsgHandlerFunc>
-      _msg_handler_map;  //处理client2client消息
+      _msg_handler_map;  // handle client2client messages
+};
+
+template <class T>
+class AsyncRequestTask {
+ public:
+  AsyncRequestTask() : _promise(std::make_shared<std::promise<int32_t>>()) {}
+  AsyncRequestTask(T &data, size_t table_id, std::shared_ptr<CostTimer> &timer)
+      : _table_id(table_id),
+        _timer(timer),
+        _promise(std::make_shared<std::promise<int32_t>>()) {
+    _data = std::move(data);
+  }
+
+  AsyncRequestTask(AsyncRequestTask &data)  // NOLINT
+      : _table_id(data.table_id()),
+        _timer(data.timer()),
+        _promise(data.promise()) {
+    _data = std::move(data.data());
+  }
+
+  ~AsyncRequestTask() {}
+
+  inline T &data() { return _data; }
+  inline size_t table_id() { return _table_id; }
+  inline std::shared_ptr<CostTimer> &timer() { return _timer; }
+  inline std::future<int32_t> get_future() { return _promise->get_future(); }
+  inline std::shared_ptr<std::promise<int32_t>> &promise() { return _promise; }
+
+ private:
+  T _data;
+  size_t _table_id;
+  std::shared_ptr<CostTimer> _timer;
+  std::shared_ptr<std::promise<int32_t>> _promise;
 };
+
 REGISTER_PSCORE_REGISTERER(PSClient);
 
 class PSClientFactory {
diff --git a/paddle/fluid/distributed/table/accessor.h b/paddle/fluid/distributed/table/accessor.h
index 7cc92ce98ba69..8929e8cd64e84 100644
--- a/paddle/fluid/distributed/table/accessor.h
+++ b/paddle/fluid/distributed/table/accessor.h
@@ -17,15 +17,12 @@
 #include <stdint.h>
 #include <stdio.h>
 #include <vector>
+#include "paddle/fluid/distributed/common/afs_warpper.h"
 #include "paddle/fluid/distributed/common/registerer.h"
 #include "paddle/fluid/distributed/ps.pb.h"
 
 namespace paddle {
 namespace distributed {
-struct FsDataConverter {
-  std::string converter;
-  std::string deconverter;
-};
 
 struct Region {
   Region() : data(NULL), size(0) {}
@@ -50,8 +47,8 @@ struct DataConverter {
 
 class ValueAccessor {
  public:
-  explicit ValueAccessor(){};
-  virtual ~ValueAccessor(){};
+  ValueAccessor() {}
+  virtual ~ValueAccessor() {}
 
   virtual int configure(const TableAccessorParameter& parameter) {
     _config = parameter;
diff --git a/paddle/fluid/distributed/table/depends/dense.h b/paddle/fluid/distributed/table/depends/dense.h
index 8079003d1bf8f..d6b9ba0754550 100644
--- a/paddle/fluid/distributed/table/depends/dense.h
+++ b/paddle/fluid/distributed/table/depends/dense.h
@@ -183,5 +183,159 @@ class DAdam : public DenseOptimizer {
   float epsilon;
 };
 
+// adam optimizer for dense tensor
+class DAdamD2Sum : public DenseOptimizer {
+ public:
+  explicit DAdamD2Sum(const CommonAccessorParameter& accessor,
+                      std::vector<std::vector<float>>* values) {
+    lr_hardcode = 5e-6;
+    auto& names = accessor.params();
+    for (int x = 0; x < static_cast<int>(names.size()); ++x) {
+      if (names[x] == "LearningRate") {
+        learning_rate = (*values)[x].data();
+      }
+      if (names[x] == "Param") {
+        param = (*values)[x].data();
+      }
+      if (names[x] == "Moment") {
+ 
mom_velocity = (*values)[x].data(); + } + if (names[x] == "G2Sum") { + ada_g2sum = (*values)[x].data(); + } + if (names[x] == "D2Sum") { + ada_d2sum = (*values)[x].data(); + } + if (names[x] == "MomentDecayRate") { + mom_decay_rate = (*values)[x].data(); + } + if (names[x] == "AdaDecayRate") { + ada_decay_rate = (*values)[x].data(); + } + if (names[x] == "AdaEpsilon") { + ada_epsilon = (*values)[x].data(); + } + } + } + + void update(const float* update_values, size_t num, int begin, + int end) override { + auto update_numel = end - begin; + + /* + // for debug + std::cout << "before update:\n"; + for (int i = 0; i < 3; ++ i) { + std::cout << "param: " << i << " " << *(param+begin+i) << + "grad: " << *(update_values+begin+i) << "\n"; + }*/ + + std::vector grad, grad2, scale; + grad.resize(update_numel); + grad2.resize(update_numel); + scale.resize(update_numel); + + auto blas = GetBlas(); + // copy grad + blas.VCOPY(update_numel, update_values + begin, grad.data()); + blas.VCOPY(update_numel, update_values + begin, grad2.data()); + + /* + for (int i = 0; i < end-begin; ++ i) { + std::cout << "copy grad: " << i << " " << *(grad.data()+begin+i) << + "copy grad2: " << *(grad2.data()+begin+i) << "\n"; + } + for (int i = 0; i < 3; ++ i) { + std::cout << "d2sum before: " << i << " " << *(ada_d2sum+begin+i) << "\n"; + }*/ + + // d2sum + blas.SCAL(update_numel, ada_decay_rate[0], ada_d2sum + begin); + ADD(update_numel, ada_d2sum + begin, 1, ada_d2sum + begin); + + /* + for (int i = 0; i < end-begin; ++ i) { + std::cout << "d2sum update: " << i << " " << *(ada_d2sum+begin+i) << "\n"; + } + for (int i = 0; i < 3; ++ i) { + std::cout << "g2sum before: " << i << " " << *(ada_g2sum+begin+i) << "\n"; + }*/ + + // g2sum + blas.SCAL(update_numel, ada_decay_rate[0], ada_g2sum + begin); + blas.VSQUARE(update_numel, grad2.data(), grad2.data()); + blas.VADD(update_numel, ada_g2sum + begin, grad2.data(), ada_g2sum + begin); + + /* + for (int i = 0; i < end-begin; ++ i) { + std::cout << "g2sum update: " << i << " " << *(ada_g2sum+begin+i) << "\n"; + } + for (int i = 0; i < 3; ++ i) { + std::cout << "mom before: " << i << " " << *(mom_velocity+begin+i) << + "\n"; + }*/ + + // mom + blas.SCAL(update_numel, mom_decay_rate[0], mom_velocity + begin); + blas.SCAL(update_numel, 1 - mom_decay_rate[0], grad.data()); + blas.VADD(update_numel, mom_velocity + begin, grad.data(), + mom_velocity + begin); + + /* + for (int i = 0; i < end-begin; ++ i) { + std::cout << "mom update: " << i << " " << *(mom_velocity+begin+i) << + "\n"; + } + for (int i = 0; i < 3; ++ i) { + std::cout << "scale before: " << i << " " << *(scale.data()+begin+i) << + "\n"; + }*/ + + // scale + float* scale_ = scale.data(); + blas.VDIV(update_numel, ada_g2sum + begin, ada_d2sum + begin, scale_); + ADD(update_numel, scale_, ada_epsilon[0], scale_); + DIV(update_numel, 1 + ada_epsilon[0], scale_, scale_); + SQRT(update_numel, scale_, scale_); + + /* + for (int i = 0; i < 3; ++ i) { + std::cout << "scale update: " << i << " " << *(scale.data()+begin+i) << + "\n"; + }*/ + + blas.SCAL(update_numel, learning_rate[0], scale_); + + // TODO(zhaocaibei123): check if there exists elementwise_multiply in blas + // TODO(zhaocaibei123): blas.VMUL + ELE_MUL(update_numel, scale_, mom_velocity + begin, scale_); + + /* + for (int i = 0; i < 3; ++ i) { + std::cout << "scale update2: " << i << " " << *(scale.data()+begin+i) << + "\n"; + }*/ + + blas.VSUB(update_numel, param + begin, scale_, param + begin); + + /* + for (int i = 0; i < end-begin; ++ i) { + std::cout << 
"param update " << i << " " << *(param+begin+i) << "\n"; + }*/ + } + + float* learning_rate; + float lr_hardcode; + + float* param; + float* mom_velocity; + float* ada_g2sum; + float* ada_d2sum; + + float* mom_decay_rate; + float* ada_decay_rate; + float* ada_epsilon; +}; + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 28eebeb4d9bdc..bd84471e63ef7 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -173,6 +173,68 @@ message TensorParallelConfig { optional int32 tensor_init_seed = 2 [ default = -1 ]; } +enum TableType { + PS_SPARSE_TABLE = 0; + PS_DENSE_TABLE = 1; +} + +message TableParameter { + optional uint64 table_id = 1; + optional string table_class = 2; + optional uint64 shard_num = 3; + optional TableType type = 4; + optional TableAccessorParameter accessor = 5; +} + +message TableAccessorParameter { + optional string accessor_class = 1; + optional SGDParameter embed_sgd_param = 2; + optional SGDParameter embedx_sgd_param = 3; + optional uint32 fea_dim = 4; // for sparse table, this means field size of one + // value; for dense table, this means total value + // num + optional uint32 embedx_dim = 5; // embedx feature size + optional uint32 embedx_threshold = 6; // embedx feature create threshold + optional CtrAccessorParameter ctr_accessor_param = 7; +} + +// TODO(guanqun): add NaiveSGD/Adam... +message SGDParameter { + optional string name = 1; + optional SGDRuleParameter adagrad = 2; +} + +message SGDRuleParameter { + optional double learning_rate = 1; + optional double initial_g2sum = 2; + optional double initial_range = 3 [ default = 0 ]; + repeated float weight_bounds = 4; +} + +message CtrAccessorParameter { + optional float nonclk_coeff = 1; // to calculate show_click_score + optional float click_coeff = 2; // to calculate show_click_score + optional float base_threshold = + 3; // show_click_score > base_threshold, this feature can be saved + optional float delta_threshold = + 4; // delta_score > delta_threshold, this feature can be saved + optional float delta_keep_days = + 5; // unseen_day < delta_keep_days, this feature can be saved + optional float show_click_decay_rate = 6; // show/click will update to + // show/click * + // show_click_decay_rate after a day + optional float delete_threshold = 7; // threshold to shrink a feasign + optional float delete_after_unseen_days = 8; + optional int32 ssd_unseenday_threshold = 9; +} + +message FsClientParameter { + optional string uri = 1; + optional string user = 2; + optional string passwd = 3; + optional string hadoop_bin = 4; +} + message DistributedStrategy { // bool options optional Mode mode = 1 [ default = COLLECTIVE ]; @@ -210,6 +272,7 @@ message DistributedStrategy { optional bool asp = 33 [ default = false ]; optional bool fuse_grad_merge = 34 [ default = false ]; optional bool semi_auto = 35 [ default = false ]; + optional bool adam_d2sum = 36 [ default = true ]; optional RecomputeConfig recompute_configs = 101; optional AMPConfig amp_configs = 102; @@ -225,6 +288,9 @@ message DistributedStrategy { optional HybridConfig hybrid_configs = 112; optional TensorParallelConfig tensor_parallel_configs = 113; optional TrainerDescConfig trainer_desc_configs = 114; + optional TableParameter downpour_table_param = 115; + optional FsClientParameter fs_client_param = 116; + optional BuildStrategy build_strategy = 201; optional 
ExecutionStrategy execution_strategy = 202; optional GradientScaleConfig gradient_scale_configs = 203;
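For reference, a sketch of how the new table/accessor messages added to distributed_strategy.proto could be filled from C++ and attached to a DistributedStrategy via downpour_table_param (not part of the patch; the table/accessor class names, numeric values, generated header path and the paddle::distributed namespace are assumptions to be checked against the actual proto package):

// Sketch only: illustrative values, assumed namespace and class names.
#include "paddle/fluid/framework/distributed_strategy.pb.h"

namespace pd = paddle::distributed;  // adjust to the proto's actual package

pd::TableParameter make_demo_table() {
  pd::TableParameter table;
  table.set_table_id(0);
  table.set_table_class("MemorySparseTable");  // hypothetical table class
  table.set_shard_num(64);
  table.set_type(pd::PS_SPARSE_TABLE);

  auto* accessor = table.mutable_accessor();
  accessor->set_accessor_class("CtrCommonAccessor");  // hypothetical accessor
  accessor->set_fea_dim(11);           // field size of one sparse value
  accessor->set_embedx_dim(8);         // embedx feature size
  accessor->set_embedx_threshold(10);  // embedx feature create threshold

  auto* ctr = accessor->mutable_ctr_accessor_param();
  ctr->set_nonclk_coeff(0.1);
  ctr->set_click_coeff(1.0);
  ctr->set_base_threshold(1.5);
  ctr->set_delta_threshold(0.25);
  ctr->set_delta_keep_days(16);
  ctr->set_show_click_decay_rate(0.98);

  auto* embed_sgd = accessor->mutable_embed_sgd_param();
  embed_sgd->set_name("adagrad");  // hypothetical SGD rule name
  embed_sgd->mutable_adagrad()->set_learning_rate(0.05);
  embed_sgd->mutable_adagrad()->set_initial_range(1e-4);
  return table;
}

// A DistributedStrategy message would then carry the table via
// mutable_downpour_table_param(), and the HDFS settings via
// mutable_fs_client_param() (uri, user, passwd, hadoop_bin).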