From 37ac0ddaef7e637990ca351fae2926660dad600f Mon Sep 17 00:00:00 2001 From: feng_shuai Date: Tue, 26 Oct 2021 10:01:07 +0800 Subject: [PATCH 1/9] feat: Add TRT support for 3D(batch_norm_op and elementwise_add_op) (#36446) (#36702) --- paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc | 7 ++++--- paddle/fluid/inference/tensorrt/convert/elementwise_op.cc | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc b/paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc index 7ea41839cb939..71a2fa68f1749 100644 --- a/paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/batch_norm_op.cc @@ -147,9 +147,10 @@ class BatchNormOpConverter : public OpConverter { X = expand_layer->getOutput(0); } - layer = TRT_ENGINE_ADD_LAYER( - engine_, Scale, *X, nvinfer1::ScaleMode::kCHANNEL, shift_weights.get(), - scale_weights.get(), power_weights.get()); + layer = TRT_ENGINE_ADD_LAYER(engine_, ScaleNd, *X, + nvinfer1::ScaleMode::kCHANNEL, + shift_weights.get(), scale_weights.get(), + power_weights.get(), dynamic_shape_offset); auto output_name = op_desc.Output("Y").front(); engine_->SetWeights(op_desc.Input("Bias").front(), diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index 2f802ea8d181e..8569dd6347852 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -83,8 +83,8 @@ class ElementwiseWeightOpConverter : public OpConverter { } if (op_type_ == "add") { nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER( - engine_, Scale, *X, scale_mode, shift_weights.get(), - scale_weights.get(), power_weights.get()); + engine_, ScaleNd, *X, scale_mode, shift_weights.get(), + scale_weights.get(), power_weights.get(), dynamic_shape_offset); layer = scale_layer; } else if (op_type_ == "mul") { nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER( From beb920cd332217c542021999b88881959053853c Mon Sep 17 00:00:00 2001 From: xiongkun <807377414@qq.com> Date: Tue, 26 Oct 2021 10:25:13 +0800 Subject: [PATCH 2/9] [cherry-pick] Support CPU Parallel in DataParallel Interface by GLOO to speed up training (#35745) (#36605) * User specified backend (#35745) * remove tensordot --- paddle/fluid/framework/fleet/gloo_wrapper.h | 18 ++ paddle/fluid/imperative/gloo_context.cc | 115 ++++++++++- paddle/fluid/imperative/gloo_context.h | 8 + python/paddle/distributed/fleet/launch.py | 51 ++++- .../paddle/distributed/fleet/launch_utils.py | 63 +++++- python/paddle/distributed/parallel.py | 27 +-- python/paddle/distributed/spawn.py | 88 +++++++-- python/paddle/distributed/utils.py | 22 ++- .../fluid/tests/unittests/CMakeLists.txt | 18 ++ .../parallel_dygraph_gradient_check.py | 3 +- .../unittests/parallel_dygraph_se_resnext.py | 1 + .../tests/unittests/test_cpuonly_launch.sh | 42 ++++ .../tests/unittests/test_cpuonly_spawn.py | 72 +++++++ .../fluid/tests/unittests/test_dist_base.py | 179 +++++++++++++++++- .../test_parallel_dygraph_dataparallel.py | 65 +++++++ ..._parallel_dygraph_sparse_embedding_gloo.py | 59 ++++++ ...graph_sparse_embedding_over_height_gloo.py | 44 +++++ .../test_parallel_dygraph_transformer_gloo.py | 61 ++++++ ..._parallel_dygraph_unused_variables_gloo.py | 72 +++++++ .../test_spawn_and_init_parallel_env.py | 5 +- 20 files changed, 948 insertions(+), 65 deletions(-) create mode 100644 
python/paddle/fluid/tests/unittests/test_cpuonly_launch.sh create mode 100644 python/paddle/fluid/tests/unittests/test_cpuonly_spawn.py create mode 100644 python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_gloo.py create mode 100644 python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height_gloo.py create mode 100644 python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer_gloo.py create mode 100644 python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables_gloo.py diff --git a/paddle/fluid/framework/fleet/gloo_wrapper.h b/paddle/fluid/framework/fleet/gloo_wrapper.h index e69439892ca57..3686507043191 100644 --- a/paddle/fluid/framework/fleet/gloo_wrapper.h +++ b/paddle/fluid/framework/fleet/gloo_wrapper.h @@ -218,6 +218,24 @@ class GlooWrapper { return std::move(ret); } + // TODO(xiongkun03): support all gather array of + // numbers with different length + // can use AllgathervOptions, may be work in different + // occasion. Need some survey. + template + void AllGatherVector(T* input_ptr, T* output_ptr, + size_t element_num) { // NOLINT + CHECK_EQ(is_initialized_, true); +#ifdef PADDLE_WITH_GLOO + gloo::AllgatherOptions opts(context_); + opts.setInput(input_ptr, element_num); + opts.setOutput(output_ptr, element_num * size_); + gloo::allgather(opts); +#else + LOG(WARNING) << "AllGather does nothing when WITH_GLOO=OFF"; +#endif + } + protected: bool is_initialized_ = false; #ifdef PADDLE_WITH_GLOO diff --git a/paddle/fluid/imperative/gloo_context.cc b/paddle/fluid/imperative/gloo_context.cc index d7df6ec3c1164..0d93cdf57932f 100644 --- a/paddle/fluid/imperative/gloo_context.cc +++ b/paddle/fluid/imperative/gloo_context.cc @@ -18,6 +18,7 @@ #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/string/split.h" +#include "paddle/fluid/string/string_helper.h" namespace paddle { namespace framework { @@ -67,8 +68,36 @@ void GLOOParallelContext::AllReduceByStream(const framework::Variable &src, framework::Variable *dst, int ring_id, bool use_calc_stream) { // AllReduce(src, dst, strategy_, ring_id, use_calc_stream); - auto src_tensor = src.Get(); - auto *dst_tensor = dst->GetMutable(); + if (src.IsType()) { + if (!dst->IsType()) { + dst->Clear(); + } + AllReduce(src.Get(), + dst->GetMutable()); + } else if (src.IsType()) { + if (&src != dst) { + if (!dst->IsType()) { + dst->Clear(); + } + AllReduce(src.Get(), + dst->GetMutable()); + } else { + // SelectedRows cannot be allreduce in-place + framework::Variable tmp_dst; + AllReduce(src.Get(), + tmp_dst.GetMutable()); + *dst = std::move(tmp_dst); + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Unsupported variable type %s for imperative allreduce, only " + "LoDTensor and SelectedRows are supported.", + platform::demangle(framework::ToTypeName(src.Type())))); + } +} + +void GLOOParallelContext::AllReduce(const framework::Tensor &src_tensor, + framework::Tensor *dst_tensor) { auto gloo_wrapper = framework::GlooWrapper::GetInstance(); dst_tensor->Resize(src_tensor.dims()); switch (src_tensor.type()) { @@ -84,6 +113,88 @@ void GLOOParallelContext::AllReduceByStream(const framework::Variable &src, gloo_wrapper->Barrier(); } +#define GLOO_ALL_GATHER_CASE(type, T, gw) \ + case type: { \ + const auto *src_tensor_ptr = src_tensor.data(); \ + gw->AllGatherVector(const_cast(src_tensor_ptr), \ + reinterpret_cast(dst_tensor_ptr), \ + value_sendcount); \ + break; \ + } + +void 
GLOOParallelContext::AllReduce(const framework::SelectedRows &src, + framework::SelectedRows *dst) { + // auto ; + // int local_rank = strategy_.local_rank_; + int nranks = strategy_.nranks_; + VLOG(3) << "SelectedRows AllReduce start"; + const auto &src_tensor = src.value(); + const auto &place = src_tensor.place(); + auto dtype = src_tensor.type(); + // 1. Gather rows number from all workers. Here use ncclAllGather to do this, + // but we can use other ways to implement is in the future + const auto &src_rows = src.rows(); + auto gloo_wrapper = framework::GlooWrapper::GetInstance(); + size_t local_row_num = src_rows.size(); + std::vector rows_num_vector = + gloo_wrapper->AllGather(local_row_num); + const auto *cpu_rows_num_ptr = rows_num_vector.data(); + auto rows_num = std::accumulate(cpu_rows_num_ptr, cpu_rows_num_ptr + nranks, + static_cast(0)); + dst->set_height(src.height()); + VLOG(3) << "Gather rows: " << string::join_strings(rows_num_vector, ',') + << ", total rows number: " << rows_num + << ", height: " << src.height(); + auto *dst_rows = dst->mutable_rows(); + dst_rows->resize(rows_num); + auto *dst_rows_ptr = dst_rows->MutableData(place); + const int64_t *src_rows_ptr = src_rows.Data(place); + + // VLOG(3) << "Selected Rows of src:" << string::join_strings(dst_rows, ',') + + auto *dst_tensor = dst->mutable_value(); + auto dims = src_tensor.dims(); + dims[0] = rows_num; + auto feature_size = framework::product(dims) / dims[0]; + dst_tensor->Resize(dims); + if (std::all_of(cpu_rows_num_ptr, cpu_rows_num_ptr + nranks, + [&](size_t row) { return row == cpu_rows_num_ptr[0]; })) { + // During sparse communication, the number of each card is same. + // Because gloo wrapper utility class currently don't support + // broadcast, so we only deal the-same case. + VLOG(3) << "Use the gloo all reduce to sync. 
SRC:" << src_tensor; + // framework::SerializeToStream(VLOG(4), src); + VLOG(3) << "allgather replaces broadcast to speed up in sparse allreduce"; + auto value_sendcount = cpu_rows_num_ptr[0] * feature_size; + auto *dst_tensor_ptr = dst_tensor->mutable_data(place, dtype); + + gloo_wrapper->AllGatherVector(const_cast(src_rows_ptr), + static_cast(dst_rows_ptr), + rows_num_vector[0]); + + switch (dtype) { + GLOO_ALL_GATHER_CASE(framework::proto::VarType::FP32, float, + gloo_wrapper); + GLOO_ALL_GATHER_CASE(framework::proto::VarType::FP64, double, + gloo_wrapper); + GLOO_ALL_GATHER_CASE(framework::proto::VarType::INT32, int, gloo_wrapper); + GLOO_ALL_GATHER_CASE(framework::proto::VarType::INT64, int64_t, + gloo_wrapper); + default: { + PADDLE_THROW(platform::errors::InvalidArgument( + "Invalid datatype for allreduce")); + } + } + VLOG(3) << "Selected Row DST:" << *dst_tensor; + VLOG(3) << "Selected Rows of DST:" + << string::join_strings(std::vector(*dst_rows), ','); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The number of each card is not the same, gloo only support the-same" + "batch division")); + } +} + paddle::platform::DeviceContext *GLOOParallelContext::GetDeviceContext( int ring_id) { // return the CPUDeviceContext diff --git a/paddle/fluid/imperative/gloo_context.h b/paddle/fluid/imperative/gloo_context.h index f54dc1a406a92..305a75a881153 100644 --- a/paddle/fluid/imperative/gloo_context.h +++ b/paddle/fluid/imperative/gloo_context.h @@ -16,6 +16,9 @@ #include #include #include +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/imperative/parallel_context.h" #include "paddle/fluid/platform/device_context.h" @@ -52,6 +55,11 @@ class GLOOParallelContext : public ParallelContext { void SynchronizeCompute() override; + private: + void AllReduce(const framework::Tensor& src, framework::Tensor* dst); + void AllReduce(const framework::SelectedRows& src, + framework::SelectedRows* dst); + private: std::unique_ptr device_; }; diff --git a/python/paddle/distributed/fleet/launch.py b/python/paddle/distributed/fleet/launch.py index c0a1c359d17c6..16b39e0fc8e45 100644 --- a/python/paddle/distributed/fleet/launch.py +++ b/python/paddle/distributed/fleet/launch.py @@ -103,7 +103,12 @@ def _parse_args(): type=str, default="log", help="The path for each process's log. Default --log_dir=log/") - + base_group.add_argument( + "--backend", + type=str, + default="auto", + help="Specifize the backend, can be gloo|nccl|bkcl|auto. Default value is auto which perfers nccl or bkcl." + ) base_group.add_argument( "--nproc_per_node", type=int, @@ -230,8 +235,21 @@ def get_cluster_from_args(args, device_mode, devices_per_proc): devices_per_proc) +def cpuonly_check(args): + if args.ips and len(args.ips.split(',')) > 1: + raise RuntimeError( + "CPUONLY launch only support single trainer, that is len(ips)=1, but got %s." 
+ % args.ips) + if args.run_mode: + assert args.run_mode == 'cpuonly', "CPUONLY launch only support run mode is CPUONLY" + if args.servers: + raise RuntimeError("CPUONLY launch can't have --servers as arguments.") + return True + + def launch_collective(args): # parse arguments, used for cloud-single-machine and local + if args.backend == 'gloo': cpuonly_check(args) (device_mode, devices_per_proc) = launch_utils.get_device_proc_info(args) trainers_num = cloud_utils.get_trainers_num() logger.debug("parsed from args trainerss_num:{} mode:{} devices:{}".format( @@ -265,6 +283,7 @@ def launch_collective(args): global_envs["PADDLE_WITH_GLOO"] = str(os.getenv("PADDLE_WITH_GLOO", "0")) global_envs["PADDLE_GLOO_RENDEZVOUS"] = "3" global_envs["PADDLE_GLOO_FS_PATH"] = gloo_rendezvous_dir + global_envs["PADDLE_DISTRI_BACKEND"] = args.backend procs = start_local_trainers( cluster, @@ -349,9 +368,12 @@ def which_distributed_mode(args): if fluid.core.is_compiled_with_cuda(): accelerators = fluid.core.get_cuda_device_count() + args.backend = 'nccl' elif fluid.core.is_compiled_with_npu(): + args.backend = 'unknown' accelerators = fluid.core.get_npu_device_count() elif fluid.core.is_compiled_with_xpu(): + args.backend = 'bkcl' accelerators = fluid.core.get_xpu_device_count() else: accelerators = 0 @@ -372,10 +394,14 @@ def which_distributed_mode(args): else: if not fluid.core.is_compiled_with_cuda( ) and not fluid.core.is_compiled_with_xpu(): - logger.warning( - "Not found distinct arguments and not compiled with cuda or xpu. Default use ps mode" - ) - return DistributeMode.PS + if args.servers: + logger.warning( + "Not found distinct arguments and not compiled with cuda or xpu. \ +But found args.servers not empty, default use ps mode") + return DistributeMode.PS + else: + args.backend = "gloo" + return DistributeMode.COLLECTIVE else: logger.warning( "Not found distinct arguments and compiled with cuda or xpu. 
Default use collective mode" @@ -556,7 +582,20 @@ def launch(): logger = get_logger() _print_arguments(args) - distribute_mode = which_distributed_mode(args) + if args.backend == 'auto': + distribute_mode = which_distributed_mode(args) + assert args.backend in [ + 'gloo', 'nccl', 'bkcl', 'unknown' + ] # which_distributed_mode must modify args.backend + else: + assert args.run_mode == 'collective' or args.run_mode == None, "When backend is not 'auto', run mode must be collective" + check_backend(args.backend) + distribute_mode = DistributeMode.COLLECTIVE + + block_windows_and_macos( + args.backend) # raise error when using gloo on windows or macos + if args.backend == 'gloo': + logger.warning("launch start with CPUONLY mode") if enable_elastic(args, distribute_mode): launch_elastic(args, distribute_mode) diff --git a/python/paddle/distributed/fleet/launch_utils.py b/python/paddle/distributed/fleet/launch_utils.py index e114670440c06..3aced0ab996cb 100644 --- a/python/paddle/distributed/fleet/launch_utils.py +++ b/python/paddle/distributed/fleet/launch_utils.py @@ -22,6 +22,7 @@ import tempfile import shutil from contextlib import closing +import multiprocessing import socket import warnings import six @@ -30,6 +31,7 @@ import paddle import paddle.fluid as fluid from distutils.util import strtobool +import paddle.utils.cpp_extension.extension_utils as utils logger = logging.getLogger("root") logger.propagate = False @@ -669,29 +671,31 @@ def get_xpus(xpus): return res_xpus -def get_device_mode(): +def get_device_mode(backend): if fluid.core.is_compiled_with_npu() and \ fluid.core.get_npu_device_count() > 0: print("launch train in ascend npu mode!") return DeviceMode.ASCEND_NPU - if fluid.core.is_compiled_with_cuda() and \ + if backend == 'nccl' and \ fluid.core.get_cuda_device_count() > 0: print("launch train in GPU mode!") return DeviceMode.GPU - if fluid.core.is_compiled_with_xpu() and fluid.core.get_xpu_device_count( - ) > 0: + if backend == 'bkcl' and fluid.core.get_xpu_device_count() > 0: print("launch train in XPU mode") return DeviceMode.XPU - print("launch train in CPU mode") - return DeviceMode.CPU + if backend == 'gloo': + print("launch train in CPU mode") + return DeviceMode.CPU + + raise RuntimeError("Don't supported devices") def get_device_proc_info(args): # device_mode - device_mode = get_device_mode() + device_mode = get_device_mode(args.backend) # devices devices_per_proc = [] @@ -722,6 +726,9 @@ def get_device_proc_info(args): else: devices_per_proc = xpus elif device_mode == DeviceMode.CPU: + if hasattr(args, "paddle_cpuonly") and args.nproc_per_node is None: + #NOTE (xiongkun03) set it to cpu core number + args.nproc_per_node = multiprocessing.cpu_count() if args.nproc_per_node is None: devices_per_proc = [0] else: @@ -1237,3 +1244,45 @@ def start_pod_heter_worker(self, args, pod): tp.cmd = cmd self.procs["heter_worker"].append(tp) + + +def check_backend(backend): + if backend not in ['nccl', 'gloo', 'bkcl', 'auto']: + raise ValueError( + "paddle.distributed initialize error, " + "backend argument can only be one of 'nccl', 'gloo', 'bkcl', 'auto', but got %s" + % backend) + + if backend == 'nccl' and not fluid.core.is_compiled_with_cuda(): + raise ValueError( + "paddle.distributed initialize error, " + "your paddle is not compiled with cuda but you assign 'nccl' as backend." 
+ ) + + if backend == 'bkcl' and not fluid.core.is_compiled_with_xpu(): + raise ValueError( + "paddle.distributed initialize error, " + "your paddle is not compiled with xpu but you assign 'bkcl' as backend." + ) + + +def block_windows_and_macos(backend): + if backend != 'gloo': return + if utils.OS_NAME.startswith('darwin'): # MACOS , block + raise ValueError( + "You are going to using gloo on macos, but currently is not supported" + ) + if utils.IS_WINDOWS: # MACOS , block + raise ValueError( + "You are going to using gloo on windows, but currently is not supported" + ) + + +def get_backend_by_compile_flag(): + if fluid.core.is_compiled_with_cuda(): + return 'nccl' + + if fluid.core.is_compiled_with_xpu(): + return 'bkcl' + + return 'gloo' diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py index 7789b17429c4e..34c74ad30679e 100644 --- a/python/paddle/distributed/parallel.py +++ b/python/paddle/distributed/parallel.py @@ -26,6 +26,7 @@ from paddle.fluid import core from paddle.fluid.framework import _set_expected_place from paddle.fluid.dygraph import parallel_helper +from paddle.distributed.fleet.launch_utils import check_backend from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.distributed.fleet.base.private_helper_function import wait_server_ready # noqa: F401 @@ -55,25 +56,8 @@ def _start_kv_server(port, http_server_d, size): http_server.stop() -def _check_backend(backend): - if backend not in ['nccl', 'gloo', 'bkcl', 'auto']: - raise ValueError( - "paddle.distributed initialize error, " - "backend argument can only be one of 'nccl', 'gloo', 'bkcl', 'auto', but got %s" - % backend) - - if backend == 'nccl' and not core.is_compiled_with_cuda(): - raise ValueError( - "paddle.distributed initialize error, " - "your paddle is not compiled with cuda but you assign 'nccl' as backend." - ) - - if backend == 'bkcl' and not core.is_compiled_with_xpu(): - raise ValueError( - "paddle.distributed initialize error, " - "your paddle is not compiled with xpu but you assign 'bkcl' as backend." - ) - +def _is_cpuonly(backend): + check_backend(backend) if backend in ['auto', 'nccl', 'bkcl'] and (core.is_compiled_with_cuda() or core.is_compiled_with_xpu()): # passes 'auto' and can use cuda or xpu, use the default logics. so return False @@ -82,7 +66,7 @@ def _check_backend(backend): return True -def init_parallel_env(backend='auto'): +def init_parallel_env(): """ Initialize parallel training environment in dynamic graph mode. @@ -154,7 +138,8 @@ def train(): return # NOTE(xiongkun): support cpu gloo only, add this environment variable to # enable cpu only gloo prarllel training) - is_cpu_only = _check_backend(backend) + backend = os.environ.get('PADDLE_DISTRI_BACKEND', 'auto') + is_cpu_only = _is_cpuonly(backend) # 1. 
gpu xpu check, must be gpu or xpu, if not (is_cpu_only or core.is_compiled_with_cuda() or core.is_compiled_with_xpu()): diff --git a/python/paddle/distributed/spawn.py b/python/paddle/distributed/spawn.py index a60e4642e494d..cea831d9d90b5 100644 --- a/python/paddle/distributed/spawn.py +++ b/python/paddle/distributed/spawn.py @@ -24,8 +24,10 @@ from paddle.distributed.utils import _print_arguments from paddle.distributed.utils import _prepare_trainer_env from paddle.distributed.utils import get_host_name_ip -from paddle.distributed.cloud_utils import get_cluster_and_pod +from paddle.distributed.cloud_utils import get_cluster_and_pod, _get_trainers_num +from paddle.distributed.fleet.launch import get_cluster_from_args from paddle.distributed.fleet.cloud_utils import use_paddlecloud +from paddle.distributed.fleet.launch_utils import DeviceMode, check_backend, block_windows_and_macos from paddle.device import get_device # deprecated module import @@ -71,7 +73,9 @@ def _py_supported_check(): def _options_valid_check(options): # `print_config` keeped as a debug options, not show to users - supported_options = ['start_method', 'ips', 'gpus', 'xpus', 'print_config'] + supported_options = [ + 'start_method', 'ips', 'gpus', 'xpus', 'print_config', 'backend' + ] deprecated_options = [ 'selected_devices', 'started_port', 'cluster_node_ips', 'node_ip', 'use_paddlecloud' @@ -95,6 +99,22 @@ def _get_default_nprocs(): return core.get_cuda_device_count() elif 'xpu' in device: return core.get_xpu_device_count() + elif 'cpu' in device: + return multiprocessing.cpu_count() + else: + raise RuntimeError( + "`paddle.distributed.spawn` does not support parallel training on device `{}` now.". + format(device)) + + +def _get_default_backend(): + device = get_device() + if 'gpu' in device: + return 'nccl' + elif 'xpu' in device: + return 'bkcl' + elif 'cpu' in device: + return 'gloo' else: raise RuntimeError( "`paddle.distributed.spawn` does not support parallel training on device `{}` now.". @@ -112,6 +132,16 @@ def _get_node_ip(ips): def _get_subprocess_env_list(nprocs, options): + # NOTE (xiongkun03) Why put backend deduction here ? + # Becase _get_subprocess_env_list is used by many testcases. + # So for campability, we put backend deduction here + + # logic for handle backend option + if 'backend' not in options or options['backend'] == 'auto': + options['backend'] = _get_default_backend() + check_backend(options['backend']) + block_windows_and_macos(options['backend']) + # contruct processes env list processes_env_list = [] @@ -133,7 +163,7 @@ def _get_subprocess_env_list(nprocs, options): # if we set FLAGS_selected_gpus or FLAGS_selected_xpus to be `0,1,2,3`, it may cause error # when using `ParallelEnv` # NOTE(chenweihang): use absolute gpu or xpu card id - if core.is_compiled_with_cuda(): + if options['backend'] == 'nccl': args.selected_devices = options.get('gpus', None) if args.selected_devices is None: args.selected_devices = options.get('selected_devices', None) @@ -168,7 +198,7 @@ def _get_subprocess_env_list(nprocs, options): "CUDA_VISIBLE_DEVICES (%s)." % (card_id, ",".join(env_devices_list))) - elif core.is_compiled_with_xpu(): + elif options['backend'] == 'bkcl': args.selected_devices = options.get('xpus', None) if args.selected_devices is None: args.selected_devices = options.get('selected_devices', None) @@ -202,6 +232,23 @@ def _get_subprocess_env_list(nprocs, options): raise ValueError("The selected xpu card %s cannot found in " "XPU_VISIBLE_DEVICES (%s)." 
% (card_id, ",".join(env_devices_list))) + elif options['backend'] == 'gloo': + # TODO check gpu / xpu flag must not exist + warnings.warn( + "Your model will be trained under CPUONLY mode by using GLOO," + "because CPUPlace is specified manually or your installed PaddlePaddle only support CPU Device." + ) + args.paddle_cpuonly = True + args.selected_devices = None + args.ips = args.cluster_node_ips + assert options.get( + 'use_paddlecloud', + None) is None, "CPUONLY spawn doesn't support use paddle cloud" + assert len( + args.cluster_node_ips.split(',') + ) <= 1, "CPUONLY spawn only support single trainer, that is len(ips)=1, but got %s." + assert _get_trainers_num( + ) == 1, "CPUONLY spawn doesn't support multi-trainer" # set other inner args args.node_ip = options.get('node_ip', None) @@ -215,11 +262,17 @@ def _get_subprocess_env_list(nprocs, options): args.use_paddlecloud = use_paddlecloud() # get cluster and pod config - cluster, pod = get_cluster_and_pod(args) + if options['backend'] == 'gloo': + devices_per_proc = [x for x in range(0, nprocs)] + cluster, pod = get_cluster_from_args(args, DeviceMode.CPU, + devices_per_proc) + else: + cluster, pod = get_cluster_and_pod(args) # prepare subprocess env list for trainer in pod.trainers: - processes_env_list.append(_prepare_trainer_env(cluster, trainer)) + processes_env_list.append( + _prepare_trainer_env(cluster, trainer, options['backend'])) # [Debug] print config args.print_config = options.get('print_config', False) @@ -236,27 +289,35 @@ def _remove_risky_env(): os.environ.pop("https_proxy", None) -def _set_trainer_env(env_dict): +def _set_trainer_env(env_dict, backend): # NOTE(chenweihang): [ Why need set FLAGS_selected_gpus or FLAGS_selected_xpus here? ] # When the child process starts, it will inherit the configuration of the # main process and set the FLAGS once, but the environment variable has # not been set at this time, which leads to the FLAGS_selected_gpus or FLAGS_selected_xpus # is keep same with mainprocess(usually empty), so manually update the flags here - if core.is_compiled_with_cuda(): + + # NOTE(xiongkun): why put backend here? because if gloo, we shouldn't set FLAGS_selectedXXX + # + + if backend == 'nccl': set_flags({'FLAGS_selected_gpus': env_dict['FLAGS_selected_gpus']}) - elif core.is_compiled_with_xpu(): + elif backend == 'bkcl': set_flags({'FLAGS_selected_xpus': env_dict['FLAGS_selected_xpus']}) else: - raise ValueError("PaddlePaddle should be compiled with XPU or CUDA.") + #NOTE(xiongkun) why not raise Error ? + # So far, we added support for CPU parallel, and will be applied when paddle is not + # compiled with cuda or xp. just do nothing. 
+ pass + for var_name in env_dict: os.environ[var_name] = env_dict[var_name] -def _func_wrapper(func, args, error_queue, return_queue, env_dict): +def _func_wrapper(func, args, error_queue, return_queue, env_dict, backend): try: # config subprocess environment variables _remove_risky_env() - _set_trainer_env(env_dict) + _set_trainer_env(env_dict, backend) # execute function result = func(*args) # record function return value @@ -487,7 +548,8 @@ def train(print_result=False): return_queue = mp.SimpleQueue() process = mp.Process( target=_func_wrapper, - args=(func, args, error_queue, return_queue, procs_env_list[i])) + args=(func, args, error_queue, return_queue, procs_env_list[i], + options['backend'])) process.daemon = daemon process.start() error_queues.append(error_queue) diff --git a/python/paddle/distributed/utils.py b/python/paddle/distributed/utils.py index 6d14b30d18c7f..15b728f25a99d 100644 --- a/python/paddle/distributed/utils.py +++ b/python/paddle/distributed/utils.py @@ -25,6 +25,7 @@ from contextlib import closing import socket from paddle.fluid import core +from paddle.distributed.fleet.launch_utils import get_backend_by_compile_flag from distutils.util import strtobool from paddle.fluid.layer_helper import LayerHelper @@ -622,8 +623,10 @@ def __free_port(): return None -def _prepare_trainer_env(cluster, trainer): - if core.is_compiled_with_xpu(): +def _prepare_trainer_env(cluster, trainer, backend=None): + if backend is None: + backend = get_backend_by_compile_flag() # for compatibility + if backend == 'bkcl': proc_env = { "FLAGS_selected_xpus": "%s" % ",".join([str(g) for g in trainer.gpus]), @@ -632,7 +635,7 @@ def _prepare_trainer_env(cluster, trainer): "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(), "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()) } - elif core.is_compiled_with_cuda(): + elif backend == 'nccl': proc_env = { "FLAGS_selected_gpus": "%s" % ",".join([str(g) for g in trainer.gpus]), @@ -641,6 +644,19 @@ def _prepare_trainer_env(cluster, trainer): "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(), "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()) } + elif backend == 'gloo': + # NOTE (xiongkun) default fall back into cpu only + proc_env = { + "PADDLE_TRAINER_ID": "%d" % trainer.rank, + "PADDLE_CURRENT_ENDPOINT": "%s" % trainer.endpoint, + "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(), + "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()), + "PADDLE_DISTRI_BACKEND": + backend, # only add here, other will be auto + } + else: + raise ValueError("backend must be one of 'gloo, nccl, bkcl'") + return proc_env diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 49a23890bbfd0..1ec0812a6661e 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -193,8 +193,14 @@ endif() list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_hybrid_parallel) +LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_transformer_gloo) # NOTE: @xiongkun03, cpu is too slow, fix it in next PR + if (NOT WITH_GLOO) LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel_cpuonly) + + LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_unused_variables_gloo) + LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sparse_embedding_over_height_gloo) + LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sparse_embedding_gloo) endif() if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) @@ -477,6 +483,10 @@ if (APPLE OR 
WIN32) list(REMOVE_ITEM TEST_OPS test_multiprocess_dataloader_dataset) endif() +if (NOT WITH_GLOO) + LIST(REMOVE_ITEM TEST_OPS test_cpuonly_spawn) +endif() + if(NOT WITH_GPU OR WIN32 OR APPLE) list(REMOVE_ITEM TEST_OPS test_build_strategy_fusion_group_pass) endif() @@ -635,6 +645,9 @@ if(WITH_DISTRIBUTE) endforeach(TEST_OP) # solve it later. bash_test_modules(test_fleet_launch_ps START_BASH test_fleet_launch_ps.sh SERIAL LABELS "RUN_TYPE=EXCLUSIVE" ENVS "PADDLE_DIST_UT_PORT=${dist_ut_port}" PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR} ) + if (WITH_GLOO) + bash_test_modules(test_cpuonly_launch START_BASH test_cpuonly_launch.sh SERIAL LABELS "RUN_TYPE=EXCLUSIVE" ENVS "PADDLE_DIST_UT_PORT=${dist_ut_port}" PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR} ) + endif() bash_test_modules(test_new_group START_BASH test_new_group.sh SERIAL LABELS "RUN_TYPE=EXCLUSIVE" ENVS "PADDLE_DIST_UT_PORT=${dist_ut_port}+20" PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR} ) endif(NOT APPLE) endif() @@ -1043,4 +1056,9 @@ if(WITH_GPU OR WITH_ROCM) endif() set_tests_properties(test_inplace_addto_strategy PROPERTIES TIMEOUT 120) set_tests_properties(test_eigvals_op PROPERTIES TIMEOUT 400) +if (WITH_GLOO) + set_tests_properties(test_parallel_dygraph_unused_variables_gloo PROPERTIES TIMEOUT 120) + set_tests_properties(test_parallel_dygraph_sparse_embedding_gloo PROPERTIES TIMEOUT 120) + set_tests_properties(test_parallel_dygraph_sparse_embedding_over_height_gloo PROPERTIES TIMEOUT 120) +endif() set_tests_properties(test_tensordot PROPERTIES TIMEOUT 1000) diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py index 048c9b399d804..781d606f33b8f 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_gradient_check.py @@ -66,8 +66,7 @@ def forward(self, x): class TestDistTraning(unittest.TestCase): def test_multiple_gpus(self): - backend = os.environ.get('PADDLE_DISTRI_BACKEND', 'auto') - dist.init_parallel_env(backend) + dist.init_parallel_env() self.trainer_id = dist.get_rank() model_a = SimpleNet(self.trainer_id) diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_se_resnext.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_se_resnext.py index 4ce67676c3e85..0387de32c9145 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_se_resnext.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_se_resnext.py @@ -324,6 +324,7 @@ def run_one_loop(self, model, opt, data): bs = len(data) dy_x_data = np.array([x[0].reshape(3, 224, 224) for x in data]).astype('float32') + dy_x_data = dy_x_data / 255.0 y_data = np.array([x[1] for x in data]).astype('int64').reshape(bs, 1) img = to_variable(dy_x_data) label = to_variable(y_data) diff --git a/python/paddle/fluid/tests/unittests/test_cpuonly_launch.sh b/python/paddle/fluid/tests/unittests/test_cpuonly_launch.sh new file mode 100644 index 0000000000000..1c35166cf4434 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_cpuonly_launch.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +function test_launch_cpuonly(){ + python -m paddle.distributed.launch --nproc_per_node=4 --backend=gloo \ + parallel_dygraph_gradient_check.py 2>ut.elog + if grep -q "ABORT" ut.elog; then + echo "test cpu only failed" + exit -1 + else + if grep -q "CPUONLY" ut.elog; then + echo "test_launch_cpuonly successfully" + else + echo "test_launch_cpuonly failed" + exit -1 + fi + fi +} +function test_launch_error_case1(){ + python -m paddle.distributed.launch --nproc_per_node=4 --backend=random_str \ + parallel_dygraph_gradient_check.py 2>ut.elog + if grep -q "ValueError" ut.elog; then + echo "test_launch_error_case1 successfully" + else + exit -1 + fi +} + +test_launch_cpuonly +test_launch_error_case1 diff --git a/python/paddle/fluid/tests/unittests/test_cpuonly_spawn.py b/python/paddle/fluid/tests/unittests/test_cpuonly_spawn.py new file mode 100644 index 0000000000000..1def2ffd82ad7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_cpuonly_spawn.py @@ -0,0 +1,72 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest + +import paddle +import paddle.nn as nn +import paddle.optimizer as opt +import paddle.distributed as dist + + +class LinearNet(nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self._linear1 = nn.Linear(10, 10) + self._linear2 = nn.Linear(10, 1) + + def forward(self, x): + return self._linear2(self._linear1(x)) + + +def train(print_result=False): + # 1. initialize parallel environment + dist.init_parallel_env() + + # 2. create data parallel layer & optimizer + layer = LinearNet() + dp_layer = paddle.DataParallel(layer) + + loss_fn = nn.MSELoss() + adam = opt.Adam(learning_rate=0.001, parameters=dp_layer.parameters()) + + # 3. 
run layer + inputs = paddle.randn([10, 10], 'float32') + outputs = dp_layer(inputs) + labels = paddle.randn([10, 1], 'float32') + loss = loss_fn(outputs, labels) + + if print_result is True: + print("loss:", loss.numpy()) + + loss.backward() + print("Grad is", layer._linear1.weight.grad) + adam.step() + adam.clear_grad() + + +class TestSpawn(unittest.TestCase): + def test_spawn(self): + dist.spawn(train, backend='gloo', nprocs=4) + + def test_wrong_backend(self): + try: + dist.spawn(train, backend='something', nprocs=4) + except ValueError as e: + self.assertEqual(type(e), ValueError) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index eceb484a0184c..63985415c51f6 100755 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -209,7 +209,11 @@ def run_use_fleet_api_20_trainer(self, args): def get_data(): origin_batch = next(reader_generator) - if args.update_method != "local" and args.use_reader_alloc: + if paddle.distributed.get_world_size( + ) == 1 and args.update_method == 'gloo': # Gloo single mode + return origin_batch + + elif args.update_method != "local" and args.use_reader_alloc: new_batch = [] for offset, item in enumerate(origin_batch): if offset % 2 == args.trainer_id: @@ -506,7 +510,10 @@ def run_one_loop(self, model, opt, data): "train_one_loop should be implemented by the child classes.") def _get_data(self, batch, args): - if args.update_method != "local": + if paddle.distributed.get_world_size( + ) == 1 and args.update_method == 'gloo': # Gloo single mode + return batch + elif args.update_method != "local": new_batch = [] for offset, item in enumerate(batch): if offset % 2 == args.trainer_id: @@ -518,14 +525,16 @@ def _get_data(self, batch, args): def run_trainer(self, args): seed = 90 - if fluid.core.is_compiled_with_cuda(): + if args.update_method == 'gloo': + place = fluid.CPUPlace() + elif fluid.core.is_compiled_with_cuda(): device_id = int(os.getenv("FLAGS_selected_gpus", "0")) place = fluid.CUDAPlace(device_id) elif fluid.core.is_compiled_with_xpu(): device_id = int(os.getenv("FLAGS_selected_xpus", "0")) place = fluid.XPUPlace(device_id) else: - assert ("Only support CUDAPlace or XPUPlace for now.") + assert ("Only support CUDAPlace or XPUPlace or CPU(Gloo) for now.") with fluid.dygraph.guard(place): fluid.default_startup_program().random_seed = seed @@ -554,6 +563,16 @@ def run_trainer(self, args): model = dygraph.parallel.DataParallel( model, strategy, find_unused_parameters=True) print_to_err(type(self).__name__, "model built in dygraph") + + elif args.update_method == "gloo": + paddle.distributed.init_parallel_env() + if not args.find_unused_parameters: + model = dygraph.parallel.DataParallel( + model, find_unused_parameters=False) + else: + model = dygraph.parallel.DataParallel( + model, find_unused_parameters=True) + out_losses = [] print_to_err(type(self).__name__, "begin to run dygraph training") for step_id, data in enumerate(train_reader()): @@ -588,12 +607,12 @@ def run_trainer_with_spawn(self, args): args.trainer_id = paddle.distributed.get_rank() # 3. init parallel env - if args.update_method == "nccl2": + if args.update_method in ["nccl2", "gloo"]: paddle.distributed.init_parallel_env() # 4. 
train model model, train_reader, opt = self.get_model() - if args.update_method == "nccl2": + if args.update_method in ["nccl2", "gloo"]: if args.find_unused_parameters: model = paddle.DataParallel(model, find_unused_parameters=True) else: @@ -668,7 +687,9 @@ def runtime_main(test_class): '--update_method', type=str, default="local", - choices=["pserver", "nccl2", "bkcl", "local", "nccl2_reduce_layer"]) + choices=[ + "pserver", "nccl2", "bkcl", "local", "nccl2_reduce_layer", "gloo" + ]) parser.add_argument('--trainer_id', type=int, required=False, default=0) parser.add_argument('--trainers', type=int, required=False, default=1) parser.add_argument('--nccl_comm_num', type=int, required=False, default=1) @@ -685,6 +706,7 @@ def runtime_main(test_class): '--current_endpoint', type=str, required=False, default="") parser.add_argument('--sync_mode', action='store_true') parser.add_argument('--use_cuda', action='store_true') + parser.add_argument('--use_cpu', action='store_true') parser.add_argument('--use_xpu', action='store_true') parser.add_argument('--use_dgc', action='store_true') parser.add_argument('--accumulate_gradient', action='store_true') @@ -713,6 +735,9 @@ def runtime_main(test_class): args = parser.parse_args() + if args.update_method == 'gloo': + paddle.set_device("cpu") + model = test_class() if args.role == "pserver" and args.update_method == "pserver": model.run_pserver(args) @@ -770,6 +795,7 @@ def setUp(self): self._use_reader_alloc = True self._nccl2_mode = False self._bkcl_mode = False + self._gloo_mode = False # now, support gloo backend self._pipeline_mode = False self._mp_mode = False # FIXME(typhoonzero): I added this stupid argument to enable @@ -875,7 +901,7 @@ def _run_local(self, batch_size=DEFAULT_BATCH_SIZE, batch_merge_repeat=1, log_name="", - devices="0"): + devices="1"): cmd = self._python_interp @@ -947,6 +973,21 @@ def _run_local(self, return pickle.loads(local_out) + def _run_local_gloo(self, + model, + envs, + check_error_log=False, + batch_size=DEFAULT_BATCH_SIZE, + batch_merge_repeat=1, + log_name="", + devices="0"): + saved_endpoints = self._ps_endpoints + self._ps_endpoints = self._ps_endpoints.split(',')[0] + result = self._run_cluster_gloo(model, envs, 'gloo', check_error_log, + log_name) + self._ps_endpoints = saved_endpoints + return result + def _run_cluster(self, model, envs, check_error_log, log_name): # Run dist train to compare with local results ps0, ps1, ps0_pipe, ps1_pipe = self.start_pserver( @@ -1037,6 +1078,62 @@ def _run_cluster(self, model, envs, check_error_log, log_name): return pickle.loads(tr0_out), pickle.loads(tr1_out) + def _get_gloo_trainer_cmd(self, model, ep, update_method, trainer_id, + trainer_num): + env = {} + tr_cmd = "%s -u" + + if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': + tr_cmd += " -m coverage run --branch -p" + + tr_cmd += " %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method %s --lr %f" + + tr_cmd = tr_cmd % \ + (self._python_interp, model, self._ps_endpoints, + trainer_id, ep, update_method, self._lr) + + if self._use_reduce: + tr_cmd += " --use_reduce" + if self._use_reader_alloc: + tr_cmd += " --use_reader_alloc" + #assert self._use_reduce == False, "gloo not support _use_reduce" + #assert self._use_reader_alloc == False, "gloo not support _use_reduce" + if self._save_model: + tr_cmd += " --save_model" + self.__use_cuda = False + self.__use_xpu = False + assert self.__use_cuda == False, "gloo not support use cuda" + assert self.__use_xpu == False, "gloo not support use xpu" 
+ tr_cmd += " --use_cpu" + env.update({ + "PADDLE_TRAINERS_NUM": "{}".format(trainer_num), + "PADDLE_TRAINER_ID": "{}".format(trainer_id), + "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, + "PADDLE_CURRENT_ENDPOINT": ep, + "PADDLE_CURRENT_ENDPOINT": ep, + "PADDLE_DISTRI_BACKEND": "gloo", + "GLOG_v": "2", + }) + + assert self._use_dgc == False, "gloo not support use dgc" + if self._accumulate_gradient: + tr_cmd += " --accumulate_gradient" + + if self._find_unused_parameters: + tr_cmd += " --find_unused_parameters" + + assert self._pipeline_mode == False, "gloo not support use pipeline" + + if self._enable_backward_deps: # build strategy, save it + tr_cmd += " --enable_backward_deps" + + if self._fuse_all_reduce is not None: + tr_cmd += " --fuse_all_reduce {}".format(self._fuse_all_reduce) + + assert self._use_fleet_api == False, "gloo not support use fleet api" + assert self._use_fleet_api_20 == False, "gloo not support use fleet api" + return tr_cmd, env + def _get_nccl2_trainer_cmd(self, model, ep, update_method, trainer_id, trainer_num): env = {} @@ -1123,6 +1220,57 @@ def _get_nccl2_trainer_cmd(self, model, ep, update_method, trainer_id, return tr_cmd, env + def _run_cluster_gloo(self, model, envs, update_method, check_error_log, + log_name): + assert update_method == "gloo", "_run_cluster_gloo must have update_method: gloo, but get %s" % update_method + assert not self._use_hallreduce, "_run_cluster_gloo must have _use_hallreduce = false" + + worker_endpoints = self._ps_endpoints.split(",") + + trainer_num = len(worker_endpoints) + + procs = [] + pipes = [] + for i in range(0, trainer_num): + tr_cmd, tr_env = self._get_gloo_trainer_cmd( + model, worker_endpoints[i], update_method, i, trainer_num) + tr_env.update(envs) + tr_env["GLOG_vmodule"] = 'gloo_context=4' + tr_env["GLOG_v"] = '3' + print("use_hallreduce:{} tr_cmd:{}, env: {}".format( + self._use_hallreduce, tr_cmd, tr_env)) + + tr_pipe = open(log_name + "_tr{}_err.log".format(i), "wb") + + print_to_err( + type(self).__name__, + "going to start process {} with nccl2".format(i)) + tr_proc = subprocess.Popen( + tr_cmd.strip().split(" "), + stdout=subprocess.PIPE, + stderr=tr_pipe, + env=tr_env) + + procs.append(tr_proc) + pipes.append(tr_pipe) + + outs = [] + for i in range(0, trainer_num): + tr_out, tr_err = procs[i].communicate() + outs.append(tr_out) + pipes[i].close() + sys.stderr.write('trainer {} stderr: {}\n'.format(i, tr_err)) + + if trainer_num == 1: + if check_error_log: print("outs[0]:", outs[0]) + return pickle.loads(outs[0]) + + else: + if check_error_log: + print("outs[0]:", outs[0]) + print("outs[1]:", outs[1]) + return pickle.loads(outs[0]), pickle.loads(outs[1]) + def _run_cluster_nccl2(self, model, envs, update_method, check_error_log, log_name): if self._use_hallreduce: @@ -1262,7 +1410,12 @@ def check_with_place(self, required_envs = self._get_required_envs(check_error_log, need_envs) - local_losses \ + if self._gloo_mode: + local_losses \ + = self._run_local_gloo(model_file, required_envs, + check_error_log, log_name=log_name) + else: + local_losses \ = self._run_local(model_file, required_envs, check_error_log, log_name=log_name) @@ -1288,6 +1441,14 @@ def check_with_place(self, update_method='bkcl', check_error_log=check_error_log, log_name=log_name) + elif self._gloo_mode: + # gloo mode, cpu only parallel train @xiongkun03 + tr0_losses, tr1_losses = self._run_cluster_gloo( + model_file, + required_envs, + update_method='gloo', + check_error_log=check_error_log, + log_name=log_name) elif self._pipeline_mode: 
tr0_losses, tr1_losses = self._run_pipeline( diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py index c97cd56e8a7a4..edf9aed04f5e0 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_dataparallel.py @@ -49,6 +49,51 @@ def get_gpus(selected_gpus): return selected_gpus +def start_local_trainers_cpu(trainer_endpoints, + training_script, + training_script_args, + log_dir=None): + current_env = copy.copy(os.environ.copy()) + current_env.pop("http_proxy", None) + current_env.pop("https_proxy", None) + + procs = [] + n_rank = len(trainer_endpoints) + print(trainer_endpoints) + for rank_id, endpoint in enumerate(trainer_endpoints): + proc_env = { + "PADDLE_DISTRI_BACKEND": "gloo", + "PADDLE_TRAINER_ID": "%d" % rank_id, + "PADDLE_CURRENT_ENDPOINT": "%s" % endpoint, + "PADDLE_TRAINERS_NUM": "%d" % n_rank, + "PADDLE_TRAINER_ENDPOINTS": ",".join(trainer_endpoints) + } + + current_env.update(proc_env) + + print("trainer proc env:{}".format(current_env)) + + assert os.getenv('WITH_COVERAGE', + 'OFF') == 'OFF', "Gloo don't support WITH_COVERAGE." + cmd = "python -u " + training_script + + print("start trainer proc:{} env:{}".format(cmd, proc_env)) + + fn = None + + proc = subprocess.Popen(cmd.split(" "), env=current_env) + + tp = TrainerProc() + tp.proc = proc + tp.rank = rank_id + tp.log_fn = fn + tp.cmd = cmd + + procs.append(tp) + + return procs + + def start_local_trainers(cluster, pod, training_script, @@ -116,6 +161,26 @@ def run_mnist_2gpu(self, target_file_name): training_script=target_file_name, training_script_args=[]) + while True: + alive = watch_local_trainers(procs, cluster.trainers_endpoints()) + + if not alive: + print("Local procs complete, POD info:{}".format(pod)) + break + time.sleep(3) + + +class TestMultipleWithGloo(unittest.TestCase): + def run_mnist_2cpu(self, target_file_name): + + cluster, pod = get_cluster_from_args( + [0, 1]) #tmp use. for getting trainer_nranks() + + procs = start_local_trainers_cpu( + cluster.trainers_endpoints(), + training_script=target_file_name, + training_script_args=[]) + while True: alive = watch_local_trainers(procs, cluster.trainers_nranks()) diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_gloo.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_gloo.py new file mode 100644 index 0000000000000..56fcf806c4717 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_gloo.py @@ -0,0 +1,59 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import os +import sys +import unittest + +import paddle.fluid as fluid +from test_dist_base import TestDistBase +from spawn_runner_base import TestDistSpawnRunner +from parallel_dygraph_sparse_embedding import TestSparseEmbedding +from parallel_dygraph_sparse_embedding_fp64 import TestSparseEmbeddingFP64 + +flag_name = os.path.splitext(__file__)[0] + + +class TestParallelDygraphSparseEmdedding_GLOO(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._gloo_mode = True + self._dygraph = True + + def test_sparse_embedding(self): + self.check_with_place( + "parallel_dygraph_sparse_embedding.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestParallelDygraphSparseEmdeddingFP64_GLOO(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._gloo_mode = True + self._dygraph = True + + def test_sparse_embedding_fp64(self): + self.check_with_place( + "parallel_dygraph_sparse_embedding_fp64.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height_gloo.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height_gloo.py new file mode 100644 index 0000000000000..ba43e26e23a4e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_over_height_gloo.py @@ -0,0 +1,44 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import sys +import unittest + +import paddle.fluid as fluid +from test_dist_base import TestDistBase +from spawn_runner_base import TestDistSpawnRunner +from parallel_dygraph_sparse_embedding_over_height import TestSparseEmbeddingOverHeight + +flag_name = os.path.splitext(__file__)[0] + + +class TestParallelDygraphSparseEmdeddingOverHeight_GLOO(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._gloo_mode = True + self._dygraph = True + + def test_sparse_embedding(self): + self.check_with_place( + "parallel_dygraph_sparse_embedding_over_height.py", + delta=1e-7, + check_error_log=True, + log_name=flag_name) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer_gloo.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer_gloo.py new file mode 100644 index 0000000000000..d3619cc1b9a00 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_transformer_gloo.py @@ -0,0 +1,61 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import sys +import unittest + +import paddle.fluid as fluid +from test_dist_base import TestDistBase +from spawn_runner_base import TestDistSpawnRunner +from parallel_dygraph_transformer import TestTransformer + +flag_name = os.path.splitext(__file__)[0] + + +class TestParallelDygraphTransformer_GLOO(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._gloo_mode = True + self._dygraph = True + + def test_transformer(self): + self.check_with_place( + "parallel_dygraph_transformer.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestParallelDygraphTransformerAccGrad_GLOO(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._gloo_mode = True + self._dygraph = True + self._accumulate_gradient = True + self._find_unused_parameters = False + + def test_transformer(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_transformer.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables_gloo.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables_gloo.py new file mode 100644 index 0000000000000..89373fcb6eebc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_unused_variables_gloo.py @@ -0,0 +1,72 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import os +import sys +import unittest + +import paddle.fluid as fluid +from test_dist_base import TestDistBase +from spawn_runner_base import TestDistSpawnRunner +from parallel_dygraph_unused_variables import TestSparseEmbeddingUnusedVars + +flag_name = os.path.splitext(__file__)[0] + + +class TestParallelDygraphUnusedVar_GLOO(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._gloo_mode = True + self._dygraph = True + + def test_net(self): + self.check_with_place( + "parallel_dygraph_unused_variables.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestParallelDygraphNoVar_GLOO(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._gloo_mode = True + self._dygraph = True + + def test_net(self): + self.check_with_place( + "parallel_dygraph_none_var.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +class TestParallelDygraphSharedUnusedVariables_GLOO(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._gloo_mode = True + self._dygraph = True + + def test_mnist(self): + self.check_with_place( + "parallel_dygraph_shared_unused_var.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_spawn_and_init_parallel_env.py b/python/paddle/fluid/tests/unittests/test_spawn_and_init_parallel_env.py index 14547eca5aca2..dccc117f6bc15 100644 --- a/python/paddle/fluid/tests/unittests/test_spawn_and_init_parallel_env.py +++ b/python/paddle/fluid/tests/unittests/test_spawn_and_init_parallel_env.py @@ -24,6 +24,7 @@ from paddle.fluid import core from paddle.fluid.dygraph import parallel_helper +import multiprocessing # NOTE(chenweihang): Coverage CI is currently not able to count python3 # unittest, so the unittests here covers some cases that will only be @@ -89,8 +90,8 @@ def test_options_valid_check(self): def test_get_default_nprocs(self): paddle.set_device('cpu') - with self.assertRaises(RuntimeError): - nprocs = _get_default_nprocs() + nprocs = _get_default_nprocs() + self.assertEqual(nprocs, multiprocessing.cpu_count()) paddle.set_device('gpu') nprocs = _get_default_nprocs() From 3fbb6644cefad36b8328f33ab713a70bbd1a0b53 Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Tue, 26 Oct 2021 10:53:53 +0800 Subject: [PATCH 3/9] add slot record dataset (#36200) (#36710) --- paddle/fluid/framework/channel.h | 20 +- paddle/fluid/framework/data_feed.cc | 112 +++++++- paddle/fluid/framework/data_feed.h | 317 +++++++++++++++++++++- paddle/fluid/framework/data_set.cc | 166 +++++++++-- paddle/fluid/framework/data_set.h | 40 ++- paddle/fluid/framework/dataset_factory.cc | 3 +- paddle/fluid/platform/flags.cc | 8 + paddle/fluid/pybind/data_set_py.cc | 2 - 8 files changed, 622 insertions(+), 46 deletions(-) diff --git a/paddle/fluid/framework/channel.h b/paddle/fluid/framework/channel.h index 503f1513aad20..80fee94f1c85d 100644 --- a/paddle/fluid/framework/channel.h +++ b/paddle/fluid/framework/channel.h @@ -157,7 +157,19 @@ class ChannelObject { p.resize(finished); return finished; } + // read once only + size_t ReadOnce(std::vector& p, size_t size) { // NOLINT + if (size == 0) { + return 0; + } + std::unique_lock lock(mutex_); + p.resize(size); + size_t finished = Read(size, &p[0], lock, true); + p.resize(finished); + Notify(); + return finished; + } size_t ReadAll(std::vector& p) { // NOLINT p.clear(); size_t finished = 0; @@ -241,17 +253,21 @@ class 
ChannelObject { return !closed_; } - size_t Read(size_t n, T* p, std::unique_lock& lock) { // NOLINT + size_t Read(size_t n, T* p, std::unique_lock& lock, // NOLINT + bool once = false) { // NOLINT size_t finished = 0; CHECK(n <= MaxCapacity() - reading_count_); reading_count_ += n; while (finished < n && WaitForRead(lock)) { - size_t m = std::min(n - finished, data_.size()); + size_t m = (std::min)(n - finished, data_.size()); for (size_t i = 0; i < m; i++) { p[finished++] = std::move(data_.front()); data_.pop_front(); } reading_count_ -= m; + if (once && m > 0) { + break; + } } reading_count_ -= n - finished; return finished; diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index fdb24ee18eca7..4463fd9fd5340 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -36,6 +36,107 @@ DLManager& global_dlmanager_pool() { return manager; } +class BufferedLineFileReader { + typedef std::function SampleFunc; + static const int MAX_FILE_BUFF_SIZE = 4 * 1024 * 1024; + class FILEReader { + public: + explicit FILEReader(FILE* fp) : fp_(fp) {} + int read(char* buf, int len) { return fread(buf, sizeof(char), len, fp_); } + + private: + FILE* fp_; + }; + + public: + typedef std::function LineFunc; + + private: + template + int read_lines(T* reader, LineFunc func, int skip_lines) { + int lines = 0; + size_t ret = 0; + char* ptr = NULL; + char* eol = NULL; + total_len_ = 0; + error_line_ = 0; + + SampleFunc spfunc = get_sample_func(); + std::string x; + while (!is_error() && (ret = reader->read(buff_, MAX_FILE_BUFF_SIZE)) > 0) { + total_len_ += ret; + ptr = buff_; + eol = reinterpret_cast(memchr(ptr, '\n', ret)); + while (eol != NULL) { + int size = static_cast((eol - ptr) + 1); + x.append(ptr, size - 1); + ++lines; + if (lines > skip_lines && spfunc()) { + if (!func(x)) { + ++error_line_; + } + } + + x.clear(); + ptr += size; + ret -= size; + eol = reinterpret_cast(memchr(ptr, '\n', ret)); + } + if (ret > 0) { + x.append(ptr, ret); + } + } + if (!is_error() && !x.empty()) { + ++lines; + if (lines > skip_lines && spfunc()) { + if (!func(x)) { + ++error_line_; + } + } + } + return lines; + } + + public: + BufferedLineFileReader() + : random_engine_(std::random_device()()), + uniform_distribution_(0.0f, 1.0f) { + total_len_ = 0; + sample_line_ = 0; + buff_ = + reinterpret_cast(calloc(MAX_FILE_BUFF_SIZE + 1, sizeof(char))); + } + ~BufferedLineFileReader() { free(buff_); } + + int read_file(FILE* fp, LineFunc func, int skip_lines) { + FILEReader reader(fp); + return read_lines(&reader, func, skip_lines); + } + uint64_t file_size(void) { return total_len_; } + void set_sample_rate(float r) { sample_rate_ = r; } + size_t get_sample_line() { return sample_line_; } + bool is_error(void) { return (error_line_ > 10); } + + private: + SampleFunc get_sample_func() { + if (std::abs(sample_rate_ - 1.0f) < 1e-5f) { + return [this](void) { return true; }; + } + return [this](void) { + return (uniform_distribution_(random_engine_) < sample_rate_); + }; + } + + private: + char* buff_ = nullptr; + uint64_t total_len_ = 0; + + std::default_random_engine random_engine_; + std::uniform_real_distribution uniform_distribution_; + float sample_rate_ = 1.0f; + size_t sample_line_ = 0; + size_t error_line_ = 0; +}; void RecordCandidateList::ReSize(size_t length) { mutex_.lock(); capacity_ = length; @@ -301,7 +402,7 @@ int InMemoryDataFeed::Next() { << ", thread_id=" << thread_id_; } } else { - VLOG(3) << "enable heter NEXT: " << offset_index_ + VLOG(3) << 
"enable heter next: " << offset_index_ << " batch_offsets: " << batch_offsets_.size(); if (offset_index_ >= batch_offsets_.size()) { VLOG(3) << "offset_index: " << offset_index_ @@ -318,14 +419,7 @@ int InMemoryDataFeed::Next() { VLOG(3) << "finish reading for heterps, batch size zero, thread_id=" << thread_id_; } - /* - if (offset_index_ == batch_offsets_.size() - 1) { - std::vector data; - output_channel_->ReadAll(data); - consume_channel_->Write(std::move(data)); - } - */ - VLOG(3) << "#15 enable heter NEXT: " << offset_index_ + VLOG(3) << "enable heter next: " << offset_index_ << " batch_offsets: " << batch_offsets_.size() << " baych_size: " << this->batch_size_; } diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 198bc51463af3..5527eaf1f6fa4 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -39,8 +39,14 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/string/string_helper.h" +DECLARE_int32(record_pool_max_size); +DECLARE_int32(slotpool_thread_num); +DECLARE_bool(enable_slotpool_wait_release); +DECLARE_bool(enable_slotrecord_reset_shrink); + namespace paddle { namespace framework { class DataFeedDesc; @@ -69,6 +75,50 @@ namespace framework { // while (reader->Next()) { // // trainer do something // } + +template +struct SlotValues { + std::vector slot_values; + std::vector slot_offsets; + + void add_values(const T* values, uint32_t num) { + if (slot_offsets.empty()) { + slot_offsets.push_back(0); + } + if (num > 0) { + slot_values.insert(slot_values.end(), values, values + num); + } + slot_offsets.push_back(static_cast(slot_values.size())); + } + T* get_values(int idx, size_t* size) { + uint32_t& offset = slot_offsets[idx]; + (*size) = slot_offsets[idx + 1] - offset; + return &slot_values[offset]; + } + void add_slot_feasigns(const std::vector>& slot_feasigns, + uint32_t fea_num) { + slot_values.reserve(fea_num); + int slot_num = static_cast(slot_feasigns.size()); + slot_offsets.resize(slot_num + 1); + for (int i = 0; i < slot_num; ++i) { + auto& slot_val = slot_feasigns[i]; + slot_offsets[i] = static_cast(slot_values.size()); + uint32_t num = static_cast(slot_val.size()); + if (num > 0) { + slot_values.insert(slot_values.end(), slot_val.begin(), slot_val.end()); + } + } + slot_offsets[slot_num] = slot_values.size(); + } + void clear(bool shrink) { + slot_offsets.clear(); + slot_values.clear(); + if (shrink) { + slot_values.shrink_to_fit(); + slot_offsets.shrink_to_fit(); + } + } +}; union FeatureFeasign { uint64_t uint64_feasign_; float float_feasign_; @@ -97,6 +147,38 @@ struct FeatureItem { uint16_t slot_; }; +struct AllSlotInfo { + std::string slot; + std::string type; + int used_idx; + int slot_value_idx; +}; +struct UsedSlotInfo { + int idx; + int slot_value_idx; + std::string slot; + std::string type; + bool dense; + std::vector local_shape; + int total_dims_without_inductive; + int inductive_shape_index; +}; +struct SlotRecordObject { + uint64_t search_id; + uint32_t rank; + uint32_t cmatch; + std::string ins_id_; + SlotValues slot_uint64_feasigns_; + SlotValues slot_float_feasigns_; + + ~SlotRecordObject() { clear(true); } + void reset(void) { clear(FLAGS_enable_slotrecord_reset_shrink); } + void clear(bool shrink) { + slot_uint64_feasigns_.clear(shrink); + slot_float_feasigns_.clear(shrink); + } +}; +using 
SlotRecord = SlotRecordObject*; // sizeof Record is much less than std::vector struct Record { std::vector uint64_feasigns_; @@ -108,6 +190,179 @@ struct Record { uint32_t cmatch; }; +inline SlotRecord make_slotrecord() { + static const size_t slot_record_byte_size = sizeof(SlotRecordObject); + void* p = malloc(slot_record_byte_size); + new (p) SlotRecordObject; + return reinterpret_cast(p); +} + +inline void free_slotrecord(SlotRecordObject* p) { + p->~SlotRecordObject(); + free(p); +} + +template +class SlotObjAllocator { + public: + explicit SlotObjAllocator(std::function deleter) + : free_nodes_(NULL), capacity_(0), deleter_(deleter) {} + ~SlotObjAllocator() { clear(); } + + void clear() { + T* tmp = NULL; + while (free_nodes_ != NULL) { + tmp = reinterpret_cast(reinterpret_cast(free_nodes_)); + free_nodes_ = free_nodes_->next; + deleter_(tmp); + --capacity_; + } + CHECK_EQ(capacity_, static_cast(0)); + } + T* acquire(void) { + T* x = NULL; + x = reinterpret_cast(reinterpret_cast(free_nodes_)); + free_nodes_ = free_nodes_->next; + --capacity_; + return x; + } + void release(T* x) { + Node* node = reinterpret_cast(reinterpret_cast(x)); + node->next = free_nodes_; + free_nodes_ = node; + ++capacity_; + } + size_t capacity(void) { return capacity_; } + + private: + struct alignas(T) Node { + union { + Node* next; + char data[sizeof(T)]; + }; + }; + Node* free_nodes_; // a list + size_t capacity_; + std::function deleter_ = nullptr; +}; +static const int OBJPOOL_BLOCK_SIZE = 10000; +class SlotObjPool { + public: + SlotObjPool() + : max_capacity_(FLAGS_record_pool_max_size), alloc_(free_slotrecord) { + ins_chan_ = MakeChannel(); + ins_chan_->SetBlockSize(OBJPOOL_BLOCK_SIZE); + for (int i = 0; i < FLAGS_slotpool_thread_num; ++i) { + threads_.push_back(std::thread([this]() { run(); })); + } + disable_pool_ = false; + count_ = 0; + } + ~SlotObjPool() { + ins_chan_->Close(); + for (auto& t : threads_) { + t.join(); + } + } + void disable_pool(bool disable) { disable_pool_ = disable; } + void set_max_capacity(size_t max_capacity) { max_capacity_ = max_capacity; } + void get(std::vector* output, int n) { + output->resize(n); + return get(&(*output)[0], n); + } + void get(SlotRecord* output, int n) { + int size = 0; + mutex_.lock(); + int left = static_cast(alloc_.capacity()); + if (left > 0) { + size = (left >= n) ? 
n : left; + for (int i = 0; i < size; ++i) { + output[i] = alloc_.acquire(); + } + } + mutex_.unlock(); + count_ += n; + if (size == n) { + return; + } + for (int i = size; i < n; ++i) { + output[i] = make_slotrecord(); + } + } + void put(std::vector* input) { + size_t size = input->size(); + if (size == 0) { + return; + } + put(&(*input)[0], size); + input->clear(); + } + void put(SlotRecord* input, size_t size) { + CHECK(ins_chan_->WriteMove(size, input) == size); + } + void run(void) { + std::vector input; + while (ins_chan_->ReadOnce(input, OBJPOOL_BLOCK_SIZE)) { + if (input.empty()) { + continue; + } + // over max capacity + size_t n = input.size(); + count_ -= n; + if (disable_pool_ || n + capacity() > max_capacity_) { + for (auto& t : input) { + free_slotrecord(t); + } + } else { + for (auto& t : input) { + t->reset(); + } + mutex_.lock(); + for (auto& t : input) { + alloc_.release(t); + } + mutex_.unlock(); + } + input.clear(); + } + } + void clear(void) { + platform::Timer timeline; + timeline.Start(); + mutex_.lock(); + alloc_.clear(); + mutex_.unlock(); + // wait release channel data + if (FLAGS_enable_slotpool_wait_release) { + while (!ins_chan_->Empty()) { + sleep(1); + } + } + timeline.Pause(); + VLOG(3) << "clear slot pool data size=" << count_.load() + << ", span=" << timeline.ElapsedSec(); + } + size_t capacity(void) { + mutex_.lock(); + size_t total = alloc_.capacity(); + mutex_.unlock(); + return total; + } + + private: + size_t max_capacity_; + Channel ins_chan_; + std::vector threads_; + std::mutex mutex_; + SlotObjAllocator alloc_; + bool disable_pool_; + std::atomic count_; // NOLINT +}; + +inline SlotObjPool& SlotRecordPool() { + static SlotObjPool pool; + return pool; +} struct PvInstanceObject { std::vector ads; void merge_instance(Record* ins) { ads.push_back(ins); } @@ -129,7 +384,21 @@ class CustomParser { CustomParser() {} virtual ~CustomParser() {} virtual void Init(const std::vector& slots) = 0; + virtual bool Init(const std::vector& slots) = 0; virtual void ParseOneInstance(const char* str, Record* instance) = 0; + virtual bool ParseOneInstance( + const std::string& line, + std::function&, int)> + GetInsFunc) { // NOLINT + return true; + } + virtual bool ParseFileInstance( + std::function ReadBuffFunc, + std::function&, int, int)> + PullRecordsFunc, // NOLINT + int& lines) { // NOLINT + return false; + } }; typedef paddle::framework::CustomParser* (*CreateParserObjectFunc)(); @@ -194,6 +463,34 @@ class DLManager { return nullptr; } + paddle::framework::CustomParser* Load(const std::string& name, + const std::vector& conf) { +#ifdef _LINUX + std::lock_guard lock(mutex_); + DLHandle handle; + std::map::iterator it = handle_map_.find(name); + if (it != handle_map_.end()) { + return it->second.parser; + } + handle.module = dlopen(name.c_str(), RTLD_NOW); + if (handle.module == nullptr) { + VLOG(0) << "Create so of " << name << " fail"; + exit(-1); + return nullptr; + } + + CreateParserObjectFunc create_parser_func = + (CreateParserObjectFunc)dlsym(handle.module, "CreateParserObject"); + handle.parser = create_parser_func(); + handle.parser->Init(conf); + handle_map_.insert({name, handle}); + + return handle.parser; +#endif + VLOG(0) << "Not implement in windows"; + return nullptr; + } + paddle::framework::CustomParser* ReLoad(const std::string& name, const std::vector& conf) { Close(name); @@ -415,6 +712,11 @@ class InMemoryDataFeed : public DataFeed { virtual void SetCurrentPhase(int current_phase); virtual void LoadIntoMemory(); virtual void 
LoadIntoMemoryFromSo(); + virtual void SetRecord(T* records) { records_ = records; } + int GetDefaultBatchSize() { return default_batch_size_; } + void AddBatchOffset(const std::pair& offset) { + batch_offsets_.push_back(offset); + } protected: virtual bool ParseOneInstance(T* instance) = 0; @@ -424,6 +726,11 @@ class InMemoryDataFeed : public DataFeed { virtual void PutToFeedVec(const std::vector& ins_vec) = 0; virtual void PutToFeedVec(const T* ins_vec, int num) = 0; + std::vector> batch_float_feasigns_; + std::vector> batch_uint64_feasigns_; + std::vector> offset_; + std::vector visit_; + int thread_id_; int thread_num_; bool parse_ins_id_; @@ -783,11 +1090,7 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed { MultiSlotInMemoryDataFeed() {} virtual ~MultiSlotInMemoryDataFeed() {} virtual void Init(const DataFeedDesc& data_feed_desc); - void SetRecord(Record* records) { records_ = records; } - int GetDefaultBatchSize() { return default_batch_size_; } - void AddBatchOffset(const std::pair& offset) { - batch_offsets_.push_back(offset); - } + // void SetRecord(Record* records) { records_ = records; } protected: virtual bool ParseOneInstance(Record* instance); @@ -798,10 +1101,6 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed { virtual void GetMsgFromLogKey(const std::string& log_key, uint64_t* search_id, uint32_t* cmatch, uint32_t* rank); virtual void PutToFeedVec(const Record* ins_vec, int num); - std::vector> batch_float_feasigns_; - std::vector> batch_uint64_feasigns_; - std::vector> offset_; - std::vector visit_; }; class PaddleBoxDataFeed : public MultiSlotInMemoryDataFeed { diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 08c42a93d1fcb..82a39b206e6bd 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -351,10 +351,8 @@ static int compute_thread_batch_nccl( return thread_avg_batch_num; } -template -void DatasetImpl::SetHeterPs(bool enable_heterps) { +void MultiSlotDataset::PrepareTrain() { #ifdef PADDLE_WITH_GLOO - enable_heterps_ = enable_heterps; if (enable_heterps_) { if (input_records_.size() == 0 && input_channel_ != nullptr && input_channel_->Size() != 0) { @@ -541,22 +539,21 @@ void DatasetImpl::LocalShuffle() { << timeline.ElapsedSec() << " seconds"; } -template -void DatasetImpl::GlobalShuffle(int thread_num) { +void MultiSlotDataset::GlobalShuffle(int thread_num) { #ifdef PADDLE_WITH_PSLIB - VLOG(3) << "DatasetImpl::GlobalShuffle() begin"; + VLOG(3) << "MultiSlotDataset::GlobalShuffle() begin"; platform::Timer timeline; timeline.Start(); auto fleet_ptr = FleetWrapper::GetInstance(); if (!input_channel_ || input_channel_->Size() == 0) { - VLOG(3) << "DatasetImpl::GlobalShuffle() end, no data to shuffle"; + VLOG(3) << "MultiSlotDataset::GlobalShuffle() end, no data to shuffle"; return; } // local shuffle input_channel_->Close(); - std::vector data; + std::vector data; input_channel_->ReadAll(data); std::shuffle(data.begin(), data.end(), fleet_ptr->LocalRandomEngine()); input_channel_->Open(); @@ -566,10 +563,10 @@ void DatasetImpl::GlobalShuffle(int thread_num) { input_channel_->Close(); input_channel_->SetBlockSize(fleet_send_batch_size_); - VLOG(3) << "DatasetImpl::GlobalShuffle() input_channel_ size " + VLOG(3) << "MultiSlotDataset::GlobalShuffle() input_channel_ size " << input_channel_->Size(); - auto get_client_id = [this, fleet_ptr](const T& data) -> size_t { + auto get_client_id = [this, fleet_ptr](const Record& data) -> size_t { if (!this->merge_by_insid_) { return 
fleet_ptr->LocalRandomEngine()() % this->trainer_num_; } else { @@ -580,7 +577,7 @@ void DatasetImpl::GlobalShuffle(int thread_num) { auto global_shuffle_func = [this, get_client_id]() { auto fleet_ptr = FleetWrapper::GetInstance(); - std::vector data; + std::vector data; while (this->input_channel_->Read(data)) { std::vector ars(this->trainer_num_); for (auto& t : data) { @@ -835,9 +832,6 @@ void DatasetImpl::CreateReaders() { channel_idx = 0; } } - if (enable_heterps_) { - SetHeterPs(true); - } VLOG(3) << "readers size: " << readers_.size(); } @@ -923,9 +917,8 @@ int64_t DatasetImpl::GetShuffleDataSize() { return sum; } -template -int DatasetImpl::ReceiveFromClient(int msg_type, int client_id, - const std::string& msg) { +int MultiSlotDataset::ReceiveFromClient(int msg_type, int client_id, + const std::string& msg) { #ifdef _LINUX VLOG(3) << "ReceiveFromClient msg_type=" << msg_type << ", client_id=" << client_id << ", msg length=" << msg.length(); @@ -937,9 +930,9 @@ int DatasetImpl::ReceiveFromClient(int msg_type, int client_id, if (ar.Cursor() == ar.Finish()) { return 0; } - std::vector data; + std::vector data; while (ar.Cursor() < ar.Finish()) { - data.push_back(ar.Get()); + data.push_back(ar.Get()); } CHECK(ar.Cursor() == ar.Finish()); @@ -966,6 +959,20 @@ int DatasetImpl::ReceiveFromClient(int msg_type, int client_id, // explicit instantiation template class DatasetImpl; +void MultiSlotDataset::DynamicAdjustReadersNum(int thread_num) { + if (thread_num_ == thread_num) { + VLOG(3) << "DatasetImpl::DynamicAdjustReadersNum thread_num_=" + << thread_num_ << ", thread_num_=thread_num, no need to adjust"; + return; + } + VLOG(3) << "adjust readers num from " << thread_num_ << " to " << thread_num; + thread_num_ = thread_num; + std::vector>().swap(readers_); + CreateReaders(); + VLOG(3) << "adjust readers num done"; + PrepareTrain(); +} + void MultiSlotDataset::PostprocessInstance() { // divide pv instance, and merge to input_channel_ if (enable_pv_merge_) { @@ -1503,5 +1510,126 @@ void MultiSlotDataset::SlotsShuffle( << ", cost time=" << timeline.ElapsedSec() << " seconds"; } +template class DatasetImpl; +void SlotRecordDataset::CreateChannel() { + if (input_channel_ == nullptr) { + input_channel_ = paddle::framework::MakeChannel(); + } +} +void SlotRecordDataset::CreateReaders() { + VLOG(3) << "Calling CreateReaders()"; + VLOG(3) << "thread num in Dataset: " << thread_num_; + VLOG(3) << "Filelist size in Dataset: " << filelist_.size(); + VLOG(3) << "channel num in Dataset: " << channel_num_; + CHECK(thread_num_ > 0) << "thread num should > 0"; + CHECK(channel_num_ > 0) << "channel num should > 0"; + CHECK(channel_num_ <= thread_num_) << "channel num should <= thread num"; + VLOG(3) << "readers size: " << readers_.size(); + if (readers_.size() != 0) { + VLOG(3) << "readers_.size() = " << readers_.size() + << ", will not create again"; + return; + } + VLOG(3) << "data feed class name: " << data_feed_desc_.name(); + for (int i = 0; i < thread_num_; ++i) { + readers_.push_back(DataFeedFactory::CreateDataFeed(data_feed_desc_.name())); + readers_[i]->Init(data_feed_desc_); + readers_[i]->SetThreadId(i); + readers_[i]->SetThreadNum(thread_num_); + readers_[i]->SetFileListMutex(&mutex_for_pick_file_); + readers_[i]->SetFileListIndex(&file_idx_); + readers_[i]->SetFeaNumMutex(&mutex_for_fea_num_); + readers_[i]->SetFeaNum(&total_fea_num_); + readers_[i]->SetFileList(filelist_); + readers_[i]->SetParseInsId(parse_ins_id_); + readers_[i]->SetParseContent(parse_content_); + 
readers_[i]->SetParseLogKey(parse_logkey_); + readers_[i]->SetEnablePvMerge(enable_pv_merge_); + readers_[i]->SetCurrentPhase(current_phase_); + if (input_channel_ != nullptr) { + readers_[i]->SetInputChannel(input_channel_.get()); + } + } + VLOG(3) << "readers size: " << readers_.size(); +} + +void SlotRecordDataset::ReleaseMemory() { + VLOG(3) << "SlotRecordDataset::ReleaseMemory() begin"; + platform::Timer timeline; + timeline.Start(); + + if (input_channel_) { + input_channel_->Clear(); + input_channel_ = nullptr; + } + if (enable_heterps_) { + VLOG(3) << "put pool records size: " << input_records_.size(); + SlotRecordPool().put(&input_records_); + input_records_.clear(); + input_records_.shrink_to_fit(); + VLOG(3) << "release heterps input records records size: " + << input_records_.size(); + } + + readers_.clear(); + readers_.shrink_to_fit(); + + std::vector>().swap(readers_); + + VLOG(3) << "SlotRecordDataset::ReleaseMemory() end"; + VLOG(3) << "total_feasign_num_(" << STAT_GET(STAT_total_feasign_num_in_mem) + << ") - current_fea_num_(" << total_fea_num_ << ") = (" + << STAT_GET(STAT_total_feasign_num_in_mem) - total_fea_num_ << ")" + << " object pool size=" << SlotRecordPool().capacity(); // For Debug + STAT_SUB(STAT_total_feasign_num_in_mem, total_fea_num_); +} +void SlotRecordDataset::GlobalShuffle(int thread_num) { + // TODO(yaoxuefeng) + return; +} + +void SlotRecordDataset::DynamicAdjustChannelNum(int channel_num, + bool discard_remaining_ins) { + if (channel_num_ == channel_num) { + VLOG(3) << "DatasetImpl::DynamicAdjustChannelNum channel_num_=" + << channel_num_ << ", channel_num_=channel_num, no need to adjust"; + return; + } + VLOG(3) << "adjust channel num from " << channel_num_ << " to " + << channel_num; + channel_num_ = channel_num; + + if (static_cast(input_channel_->Size()) >= channel_num) { + input_channel_->SetBlockSize(input_channel_->Size() / channel_num + + (discard_remaining_ins ? 
0 : 1)); + } + + VLOG(3) << "adjust channel num done"; +} + +void SlotRecordDataset::PrepareTrain() { +#ifdef PADDLE_WITH_GLOO + return; +#else + PADDLE_THROW(platform::errors::Unavailable( + "dataset set heterps need compile with GLOO")); +#endif + return; +} + +void SlotRecordDataset::DynamicAdjustReadersNum(int thread_num) { + if (thread_num_ == thread_num) { + VLOG(3) << "DatasetImpl::DynamicAdjustReadersNum thread_num_=" + << thread_num_ << ", thread_num_=thread_num, no need to adjust"; + return; + } + VLOG(3) << "adjust readers num from " << thread_num_ << " to " << thread_num; + thread_num_ = thread_num; + std::vector>().swap(readers_); + CreateReaders(); + VLOG(3) << "adjust readers num done"; + PrepareTrain(); +} + } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index f3ee96fab8297..981fb694e0fec 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -149,7 +149,6 @@ class Dataset { virtual void DynamicAdjustReadersNum(int thread_num) = 0; // set fleet send sleep seconds virtual void SetFleetSendSleepSeconds(int seconds) = 0; - virtual void SetHeterPs(bool enable_heterps) = 0; protected: virtual int ReceiveFromClient(int msg_type, int client_id, @@ -207,7 +206,7 @@ class DatasetImpl : public Dataset { virtual void WaitPreLoadDone(); virtual void ReleaseMemory(); virtual void LocalShuffle(); - virtual void GlobalShuffle(int thread_num = -1); + virtual void GlobalShuffle(int thread_num = -1) {} virtual void SlotsShuffle(const std::set& slots_to_replace) {} virtual const std::vector& GetSlotsOriginalData() { return slots_shuffle_original_data_; @@ -233,7 +232,11 @@ class DatasetImpl : public Dataset { bool discard_remaining_ins = false); virtual void DynamicAdjustReadersNum(int thread_num); virtual void SetFleetSendSleepSeconds(int seconds); - virtual void SetHeterPs(bool enable_heterps); + /* for enable_heterps_ + virtual void EnableHeterps(bool enable_heterps) { + enable_heterps_ = enable_heterps; + } + */ std::vector>& GetMultiOutputChannel() { return multi_output_channel_; @@ -251,7 +254,10 @@ class DatasetImpl : public Dataset { protected: virtual int ReceiveFromClient(int msg_type, int client_id, - const std::string& msg); + const std::string& msg) { + // TODO(yaoxuefeng) for SlotRecordDataset + return -1; + } std::vector> readers_; std::vector> preload_readers_; paddle::framework::Channel input_channel_; @@ -327,6 +333,32 @@ class MultiSlotDataset : public DatasetImpl { const std::unordered_set& slots_to_replace, std::vector* result); virtual ~MultiSlotDataset() {} + virtual void GlobalShuffle(int thread_num = -1); + virtual void DynamicAdjustReadersNum(int thread_num); + virtual void PrepareTrain(); + + protected: + virtual int ReceiveFromClient(int msg_type, int client_id, + const std::string& msg); +}; +class SlotRecordDataset : public DatasetImpl { + public: + SlotRecordDataset() { SlotRecordPool(); } + virtual ~SlotRecordDataset() {} + // create input channel + virtual void CreateChannel(); + // create readers + virtual void CreateReaders(); + // release memory + virtual void ReleaseMemory(); + virtual void GlobalShuffle(int thread_num = -1); + virtual void DynamicAdjustChannelNum(int channel_num, + bool discard_remaining_ins); + virtual void PrepareTrain(); + virtual void DynamicAdjustReadersNum(int thread_num); + + protected: + bool enable_heterps_ = true; }; } // end namespace framework diff --git a/paddle/fluid/framework/dataset_factory.cc 
b/paddle/fluid/framework/dataset_factory.cc
index aeaf961185323..38200927c5586 100644
--- a/paddle/fluid/framework/dataset_factory.cc
+++ b/paddle/fluid/framework/dataset_factory.cc
@@ -53,7 +53,7 @@ std::unique_ptr DatasetFactory::CreateDataset(
     std::string dataset_class) {
   if (g_dataset_map.count(dataset_class) < 1) {
     LOG(WARNING) << "Your Dataset " << dataset_class
-                 << "is not supported currently";
+                 << " is not supported currently";
     LOG(WARNING) << "Supported Dataset: " << DatasetTypeList();
     exit(-1);
   }
@@ -61,5 +61,6 @@ std::unique_ptr DatasetFactory::CreateDataset(
 }
 REGISTER_DATASET_CLASS(MultiSlotDataset);
+REGISTER_DATASET_CLASS(SlotRecordDataset);
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc
index b97c3106439be..1a19bb3aa97f0 100644
--- a/paddle/fluid/platform/flags.cc
+++ b/paddle/fluid/platform/flags.cc
@@ -673,3 +673,11 @@ PADDLE_DEFINE_EXPORTED_int32(get_host_by_name_time, 120,
 PADDLE_DEFINE_EXPORTED_bool(
     apply_pass_to_program, false,
     "It controls whether to apply IR pass to program when using Fleet APIs");
+
+DEFINE_int32(record_pool_max_size, 2000000,
+             "SlotRecordDataset slot record pool max size");
+DEFINE_int32(slotpool_thread_num, 1, "SlotRecordDataset slot pool thread num");
+DEFINE_bool(enable_slotpool_wait_release, false,
+            "enable slotrecord object wait release, default false");
+DEFINE_bool(enable_slotrecord_reset_shrink, false,
+            "enable slotrecord object reset shrink memory, default false");
\ No newline at end of file
diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc
index 41cf0189d3d9d..7a32d8729fc6c 100644
--- a/paddle/fluid/pybind/data_set_py.cc
+++ b/paddle/fluid/pybind/data_set_py.cc
@@ -309,8 +309,6 @@ void BindDataset(py::module *m) {
            &framework::Dataset::SetFleetSendSleepSeconds,
            py::call_guard())
       .def("enable_pv_merge", &framework::Dataset::EnablePvMerge,
-           py::call_guard())
-      .def("set_heter_ps", &framework::Dataset::SetHeterPs,
            py::call_guard());

   py::class_(*m, "IterableDatasetWrapper")

From d2be870a49144987eec5a3b1b18d14a8eec03858 Mon Sep 17 00:00:00 2001
From: Li Min <11663212+limin2021@users.noreply.github.com>
Date: Tue, 26 Oct 2021 11:27:55 +0800
Subject: [PATCH 4/9] [cherry-pick-2.2] Fused attention op forward (#35905)
 (#36708)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Goal: this PR improves the compute performance of the attention module.
To reduce the framework-level op scheduling overhead, the attention module is
implemented by hand at the C++ level and exposed as a single large attention op.
To reduce memory-access overhead, two optimizations are applied:
(1) the Q, K and V projections share the input X, so the GEMM, transpose and
bias-add at that step are reduced from three calls to one;
(2) kernel-fusion techniques pass data between different CUDA kernels through
registers.
---
 cmake/operators.cmake | 2 +-
 paddle/fluid/operators/dropout_impl_util.h | 3 +
 paddle/fluid/operators/fused/CMakeLists.txt | 4 +
 .../operators/fused/fused_attention_op.cc | 336 ++++++++++++++++++
 .../operators/fused/fused_attention_op.cu | 209 +++++++++++
 .../operators/fused/fused_dropout_helper.h | 2 +-
 paddle/fluid/pybind/op_function_generator.cc | 8 +
 .../fluid/tests/unittests/CMakeLists.txt | 4 +
 .../unittests/test_fused_attention_op.py | 235 ++++++++++++
 python/paddle/nn/functional/__init__.py | 2 +
 .../paddle/nn/functional/fused_transformer.py | 127 +++++++
 python/paddle/nn/layer/transformer.py | 2 +-
 12 files changed, 931 insertions(+), 3 deletions(-)
 create mode 100644 paddle/fluid/operators/fused/fused_attention_op.cc
 create mode 100644 paddle/fluid/operators/fused/fused_attention_op.cu
 create mode 100644 python/paddle/fluid/tests/unittests/test_fused_attention_op.py
 create mode 100644
python/paddle/nn/functional/fused_transformer.py diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 7541b234ceaa6..1f25dfd8a9f4b 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -216,7 +216,7 @@ function(op_library TARGET) "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "sparse_attention_op" "dgc_op" "fused_fc_elementwise_layernorm_op" "skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op" -"fused_bn_add_activation_op") +"fused_bn_add_activation_op" "fused_attention_op") if ("${TARGET}" STREQUAL "${manual_pybind_op}") set(pybind_flag 1) endif() diff --git a/paddle/fluid/operators/dropout_impl_util.h b/paddle/fluid/operators/dropout_impl_util.h index f2038d12528c4..e11640d070625 100644 --- a/paddle/fluid/operators/dropout_impl_util.h +++ b/paddle/fluid/operators/dropout_impl_util.h @@ -34,6 +34,9 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx, TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor); *seed_data = static_cast(seed_cpu_tensor.data()[0]); *increment = offset; + } else if (seed && platform::is_cpu_place(seed->place())) { + *seed_data = *(seed->data()); + *increment = offset; } else if (gen_cuda->GetIsInitPy() && (!is_fix_seed)) { auto seed_offset = gen_cuda->IncrementOffset(offset); *seed_data = seed_offset.first; diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index e3dcff949f43c..b993645031054 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -16,6 +16,7 @@ register_operators(EXCLUDES fusion_gru_op fusion_lstm_op fused_bn_add_activation_op + fused_attention_op fused_transformer_op) # fusion_gru_op does not have CUDA kernel @@ -77,5 +78,8 @@ if (WITH_GPU OR WITH_ROCM) nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory) nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory) nv_test(test_fused_layernorm_residual_dropout_bias SRCS fused_layernorm_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory) + # fused_attention_op + op_library(fused_attention_op) + file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_attention);\n") endif() endif() diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc new file mode 100644 index 0000000000000..a286c39f7f8db --- /dev/null +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -0,0 +1,336 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class FusedAttentionOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias", + "FusedAttentionOp"); + + OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut", + "FusedAttentionOp"); + // qkv_out: [batch_size, seq_len, 3, num_head, dim_head] + OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"), "Output", "QKVBiasOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("TransposeOut2"), "Output", "TransposeOut2", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("QKOut"), "Output", "QKOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("QKTVOut"), "Output", "QKTVOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("SrcMaskOut"), "Output", "SrcMaskOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("SoftmaxOut"), "Output", "SoftmaxOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("AttnDropoutMaskOut"), "Output", + "AttnDropoutMaskOut", "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("AttnDropoutOut"), "Output", "AttnDropoutOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("FMHAOut"), "Output", "FMHAOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("OutLinearOut"), "Output", "OutLinearOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output", + "BiasDropoutResidualOut", "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"), "Output", "DropoutMaskOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "FusedAttentionOp"); + + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto x_dim = ctx->GetInputDim("X"); + auto y_dim = ctx->GetInputDim("QKVW"); + PADDLE_ENFORCE_EQ(x_dim.size(), 3, platform::errors::InvalidArgument( + "The dimensions of x must be 3" + "(batch_size, seq_len, dim_embed)," + "but received dimensions of" + "Input is [%d]", + x_dim.size())); + PADDLE_ENFORCE_EQ(y_dim.size(), 4, + platform::errors::InvalidArgument( + "The dimensions of qkv_weight must be 4" + "(3, num_head, dim_head, dim_embed)," + "but received dimensions of" + "Input is [%d]", + y_dim.size())); + PADDLE_ENFORCE_EQ(x_dim[2], y_dim[3], + platform::errors::InvalidArgument( + "ShapeError: the dimension of x_dim[2] and y_dim[3]" + "must be equal. 
But received: the shape " + "of input x = [%s], and the shape of " + "input qkv_weight = [%s]", + x_dim, y_dim)); + + ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("LnOut", ctx->GetInputDim("X")); + // [batch_size, seq_len, 3, num_head, head_size] + ctx->SetOutputDim("QKVOut", + {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); + ctx->SetOutputDim("QKVBiasOut", + {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); + // [3, batch_size, num_head, seq_len, head_size] + ctx->SetOutputDim("TransposeOut2", + {y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); + // [batch, num_head, seq_len, seq_len] + ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + // the same as QKOut's shape. + ctx->SetOutputDim("AttnDropoutOut", + {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + if (ctx->Attrs().Get("attn_dropout_is_test") == false) { + ctx->SetOutputDim("AttnDropoutMaskOut", + {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + } + ctx->SetOutputDim("SoftmaxOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + // [batch_size, num_heads, seq_len, head_dim] + ctx->SetOutputDim("QKTVOut", {x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); + // [batch_size, seq_len, number of heads*head size] + ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]}); + ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X")); + + ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]}); + if (ctx->Attrs().Get("dropout_is_test") == false) { + ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X")); + } + ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X")); + ctx->SetOutputDim("Y", ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input = ctx.Input("X"); + auto input_data_type = input->type(); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; + +class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input tensor."); + AddInput("LnScale", + "(optional) Scale is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDispensable(); + AddInput("LnBias", + "(optional) Bias is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDispensable(); + AddInput("QKVW", "The qkv weight tensor."); + AddInput("QKVBias", "The qkv bias tensor."); + AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") + .AsDispensable(); + AddInput("OutLinearW", "The out_linear weight tensor."); + AddInput("OutLinearBias", "The out_linear bias tensor."); + AddInput("Ln2Scale", + "(optional) Scale is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDispensable(); + AddInput("Ln2Bias", + "(optional) Bias is a 1-dimensional tensor of size " + "H. 
Here, H represents the last dimension of its input tensor.") + .AsDispensable(); + AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate(); + AddOutput("LnVariance", "Variance of the current mini batch.") + .AsIntermediate(); + AddOutput("LnOut", "The output of pre layer_norm.").AsIntermediate(); + AddOutput("QKVOut", "Result after qkv.").AsIntermediate(); + AddOutput("QKVBiasOut", "Result after qkv and bias op.").AsIntermediate(); + AddOutput("TransposeOut2", "Result in fmha.").AsIntermediate(); + AddOutput("QKOut", "Result in fmha.").AsIntermediate(); + AddOutput("QKTVOut", "Result in fmha.").AsIntermediate(); + AddOutput("SoftmaxOut", "Result in fmha.").AsIntermediate(); + AddOutput("AttnDropoutMaskOut", "Result in fmha.").AsIntermediate(); + AddOutput("AttnDropoutOut", "Result in fmha.").AsIntermediate(); + AddOutput("SrcMaskOut", "Result in fmha.").AsIntermediate(); + AddOutput("FMHAOut", "Result after fmha.").AsIntermediate(); + AddOutput("OutLinearOut", "Result after out_linear.").AsIntermediate(); + AddOutput("DropoutMaskOut", "The random sampled dropout mask.") + .AsIntermediate(); + AddOutput("Ln2Mean", "Mean of the current mini batch.").AsIntermediate(); + AddOutput("Ln2Variance", "Variance of the current mini batch.") + .AsIntermediate(); + AddOutput("BiasDropoutResidualOut", + "Result of residual + dropout(src + bias).") + .AsIntermediate(); + AddOutput("Y", "Result after attention."); + + AddAttr("pre_layer_norm", + "if true, the attention op uses pre_layer_norm architecure, " + "else, uses post_layer_norm architecuture. " + "[default false].") + .SetDefault(false); + AddAttr("epsilon", + "Constant for numerical stability [default 1e-5].") + .SetDefault(1e-5) + .AddCustomChecker([](const float &epsilon) { + PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, true, + platform::errors::InvalidArgument( + "'epsilon' in Op(LayerNorm) should be between" + "0.0 and 0.001, But received [%s].", + epsilon)); + }); + + // for dropout in fmha. + AddAttr("attn_dropout_rate", "Probability of setting units to zero.") + .SetDefault(.5f) + .AddCustomChecker([](const float &drop_p) { + PADDLE_ENFORCE_EQ( + drop_p >= 0.0f && drop_p <= 1.0f, true, + platform::errors::InvalidArgument( + "'attn_dropout_rate' must be between 0.0 and 1.0.")); + }); + AddAttr("attn_dropout_is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddAttr("attn_dropout_fix_seed", + "A flag indicating whether to use a fixed seed to generate " + "random mask. NOTE: DO NOT set this flag to true in " + "training. Setting this flag to true is only useful in " + "unittest or for debug that always the same output units " + "will be dropped.") + .SetDefault(true); + AddAttr("attn_dropout_seed", "Dropout random seed.").SetDefault(0); + AddAttr( + "attn_dropout_implementation", + "[\"downgrade_in_infer\"|\"upscale_in_train\"]" + "There are two kinds of ways to implement dropout" + "(the mask below is a tensor have the same shape with input" + "the value of mask is 0 or 1, the ratio of 0 is dropout_rate)" + "1. downgrade_in_infer(default), downgrade the outcome at inference " + "time" + " train: out = input * mask" + " inference: out = input * (1.0 - dropout_rate)" + "2. upscale_in_train, upscale the outcome at training time, do nothing " + "in inference" + " train: out = input * mask / ( 1.0 - dropout_rate )" + " inference: out = input" + " dropout op can be removed from the program. 
the program will be " + "efficient") + .SetDefault("upscale_in_train") + .AddCustomChecker([](const std::string &type) { + PADDLE_ENFORCE_EQ( + type == "downgrade_in_infer" || type == "upscale_in_train", true, + platform::errors::InvalidArgument( + "dropout_implementation can only be downgrade_in_infer or " + "upscale_in_train")); + }); + + AddAttr("dropout_rate", "Probability of setting units to zero.") + .SetDefault(.5f) + .AddCustomChecker([](const float &drop_p) { + PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, true, + platform::errors::InvalidArgument( + "'dropout_rate' must be between 0.0 and 1.0.")); + }); + + AddAttr("dropout_is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddAttr("dropout_fix_seed", + "A flag indicating whether to use a fixed seed to generate " + "random mask. NOTE: DO NOT set this flag to true in " + "training. Setting this flag to true is only useful in " + "unittest or for debug that always the same output units " + "will be dropped.") + .SetDefault(true); + AddAttr("dropout_seed", "Dropout random seed.").SetDefault(0); + AddAttr( + "dropout_implementation", + "[\"downgrade_in_infer\"|\"upscale_in_train\"]" + "The meaning is the same as 'attn_dropout_implementation'.") + .SetDefault("downgrade_in_infer") + .AddCustomChecker([](const std::string &type) { + PADDLE_ENFORCE_EQ( + type == "downgrade_in_infer" || type == "upscale_in_train", true, + platform::errors::InvalidArgument( + "dropout_implementation can only be downgrade_in_infer or " + "upscale_in_train")); + }); + AddAttr("ln_epsilon", + "Constant for numerical stability [default 1e-5].") + .SetDefault(1e-5) + .AddCustomChecker([](const float &ln_epsilon) { + PADDLE_ENFORCE_EQ(ln_epsilon >= 0.0f && ln_epsilon <= 0.001f, true, + platform::errors::InvalidArgument( + "'epsilon' of the second LayerNorm in Fused " + "attention op should be between" + "0.0 and 0.001, But received [%s].", + ln_epsilon)); + }); + + AddComment(R"DOC( + Add fused attention op whose logic is as follows: + // @input: [batch_size, seq_len, 3, num_head, head_dim] + // @final_out: [batch_size, seq_len, num_heads, head_dim] + if (pre_layernorm) + out = layer_norm(input); + out = compute_qkv(out) + bias; + // fmha module + { + out = transpose(out, perm=[2, 0, 3, 1, 4]); + out = q * k^t; + out = attn_mark + out; + out = softmax(out); + out = dropout(out); + out = out * v; + out = transpose(out, perm=[0, 2, 1, 3]); + + } + out = out_linear(out); + final_out = layer_norm(residual + dropout(bias + out)); + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp, + ops::FusedAttentionOpMaker); diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu new file mode 100644 index 0000000000000..18a42b5c2cee2 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -0,0 +1,209 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/cuda_device_function.h" +#include "paddle/fluid/platform/cudnn_helper.h" + +#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" +#include "paddle/fluid/operators/math/math_function.h" + +#include "paddle/fluid/operators/fused/attention_layer_norm.h" +#include "paddle/fluid/operators/fused/attn_gemm.h" +#include "paddle/fluid/operators/fused/fmha_ref.h" +#include "paddle/fluid/operators/fused/fused_dropout_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class FusedAttentionOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + using U = LayerNormParamType; + auto *input_x = ctx.Input("X"); + + const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); + const float epsilon = ctx.Attr("epsilon"); + auto *ln_scale = ctx.Input("LnScale"); + auto *ln_bias = ctx.Input("LnBias"); + auto *ln_mean = ctx.Output("LnMean"); + auto *ln_var = ctx.Output("LnVariance"); + auto *ln_out = ctx.Output("LnOut"); + + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto *qkv_weight = ctx.Input("QKVW"); + auto *qkv_bias = ctx.Input("QKVBias"); + auto *qkv_out = ctx.Output("QKVOut"); + auto *qkv_bias_out = ctx.Output("QKVBiasOut"); + + auto *src_mask = ctx.Input("SrcMask"); + auto *transpose_out_2 = ctx.Output("TransposeOut2"); + auto *qk_out = ctx.Output("QKOut"); + auto *qktv_out = ctx.Output("QKTVOut"); + auto *softmax_out = ctx.Output("SoftmaxOut"); + auto *attn_dropout_mask_out = ctx.Output("AttnDropoutMaskOut"); + auto *attn_dropout_out = ctx.Output("AttnDropoutOut"); + auto *src_mask_out = ctx.Output("SrcMaskOut"); + auto *fmha_out = ctx.Output("FMHAOut"); + + auto *out_linear_weight = ctx.Input("OutLinearW"); + auto *out_linear_bias = ctx.Input("OutLinearBias"); + auto *out_linear_out = ctx.Output("OutLinearOut"); + + auto *ln_scale_2 = ctx.Input("Ln2Scale"); + auto *ln_bias_2 = ctx.Input("Ln2Bias"); + auto *dropout_mask_out = ctx.Output("DropoutMaskOut"); + auto *bias_dropout_residual_out = + ctx.Output("BiasDropoutResidualOut"); + auto *ln_mean_2 = ctx.Output("Ln2Mean"); + auto *ln_var_2 = ctx.Output("Ln2Variance"); + const float ln_epsilon = ctx.Attr("ln_epsilon"); + + float attn_dropout_rate = ctx.Attr("attn_dropout_rate"); + bool is_test_1 = ctx.Attr("attn_dropout_is_test"); + auto &dropout_implementation_1 = + ctx.Attr("attn_dropout_implementation"); + bool is_upscale_in_train_1 = + (dropout_implementation_1 == "upscale_in_train"); + auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input("Seed1") : nullptr; + bool is_fix_seed_1 = ctx.Attr("attn_dropout_fix_seed"); + int seed_val_1 = ctx.Attr("attn_dropout_seed"); + + // final output. + auto *out = ctx.Output("Y"); + + // get data ptr for qkv part. 
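+    // The mutable_data() calls in the sections below allocate the
+    // intermediate outputs (LnOut, QKVOut, TransposeOut2, SoftmaxOut, etc.)
+    // whose shapes were fixed in FusedAttentionOp::InferShape; each fused
+    // stage writes its result directly into one of these buffers.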
+ const auto input_x_dims = input_x->dims(); + const auto qkv_w_dims = qkv_weight->dims(); + + auto *x_data = input_x->data(); + auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data()); + auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data()); + auto *ln_mean_data = ln_mean->mutable_data(ctx.GetPlace()); + auto *ln_var_data = ln_var->mutable_data(ctx.GetPlace()); + auto *ln_out_data = ln_out->mutable_data(ctx.GetPlace()); + + auto *qkv_weight_data = qkv_weight->data(); + auto *qkv_bias_data = qkv_bias->data(); + auto *qkv_out_data = qkv_out->mutable_data(ctx.GetPlace()); + auto *qkv_bias_out_data = qkv_bias_out->mutable_data(ctx.GetPlace()); + + // get data ptr for FMHA. + auto *transpose_out_2_data = + transpose_out_2->mutable_data(ctx.GetPlace()); + auto *qk_out_data = qk_out->mutable_data(ctx.GetPlace()); + auto *qktv_out_data = qktv_out->mutable_data(ctx.GetPlace()); + auto *src_mask_out_data = src_mask_out->mutable_data(ctx.GetPlace()); + auto *softmax_out_data = softmax_out->mutable_data(ctx.GetPlace()); + auto *attn_dropout_mask_out_data = + attn_dropout_mask_out->mutable_data(ctx.GetPlace()); + auto *attn_dropout_out_data = + attn_dropout_out->mutable_data(ctx.GetPlace()); + auto *fmha_out_data = fmha_out->mutable_data(ctx.GetPlace()); + + // get data ptr for out_linear. + auto *out_linear_weight_data = out_linear_weight->data(); + auto *out_linear_bias_data = out_linear_bias->data(); + auto *out_linear_out_data = out_linear_out->mutable_data(ctx.GetPlace()); + + // get data ptr for bias+dropout+residual+layernorm + auto *ln_scale_2_data = + (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data()); + auto *ln_bias_2_data = + (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data()); + auto *dropout_mask_out_data = + dropout_mask_out->mutable_data(ctx.GetPlace()); + auto *bias_dropout_residual_out_data = + bias_dropout_residual_out->mutable_data(ctx.GetPlace()); + auto *ln_mean_2_data = ln_mean_2->mutable_data(ctx.GetPlace()); + auto *ln_var_2_data = ln_var_2->mutable_data(ctx.GetPlace()); + auto *final_out_data = out->mutable_data(ctx.GetPlace()); + + int batch_size = input_x_dims[0]; + int max_seq_len = input_x_dims[1]; + int dim_embed = input_x_dims[2]; + + int num_head = qkv_w_dims[1]; + int dim_head = qkv_w_dims[2]; + + int bsz_seq = batch_size * max_seq_len; + int hidden_size = num_head * dim_head; + int output_size = 3 * hidden_size; + int input_size = dim_embed; + + auto layer_norm_compute = AttnLayerNorm(ctx.cuda_device_context(), + epsilon, bsz_seq, dim_embed); + // (transA, transB, compute_bias) = (false, true, true) + auto qkv_compute = AttnMatMul(ctx.cuda_device_context(), false, true, + bsz_seq, output_size, input_size, true); + + AttnDropoutParam attn_dropout_param( + is_test_1, dropout_implementation_1, attn_dropout_rate, + is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1); + auto fmha_ref_compute = + FMHARef(ctx.cuda_device_context(), batch_size, max_seq_len, num_head, + dim_head, attn_dropout_param); + + output_size = hidden_size; + // (transA, transB, compute_bias) = (false, false, false) + auto out_linear_compute = + AttnMatMul(ctx.cuda_device_context(), false, false, bsz_seq, + output_size, input_size, false); + DropoutParam dropout_param2(ctx, 0); + FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( + ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2, + ln_epsilon); + + if (pre_layer_norm) { + layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data, + ln_out_data, ln_mean_data, ln_var_data); + 
qkv_compute.ComputeForward(qkv_weight_data, ln_out_data, qkv_bias_data, + qkv_out_data, qkv_bias_out_data); + } else { + qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data, + qkv_out_data, qkv_bias_out_data); + } + fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out_2, + qk_out, src_mask_out, softmax_out, + attn_dropout_mask_out, attn_dropout_out, + qktv_out, fmha_out); + // fmha_out: [batch_size, seq_len, num_head, head_dim] + // weight: [embed_dim, embed_dim] + // out_linear_out: [batch_size, seq_len, embed_dim] + out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data, + nullptr, out_linear_out_data, nullptr); + // output = layernorm(residual + dropout(input + bias)) + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + ctx.cuda_device_context(), out_linear_out_data, x_data, + out_linear_bias_data, ln_scale_2_data, ln_bias_2_data, + bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data, + ln_mean_2_data, ln_var_2_data); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(fused_attention, ops::FusedAttentionOpKernel, + ops::FusedAttentionOpKernel, + ops::FusedAttentionOpKernel); diff --git a/paddle/fluid/operators/fused/fused_dropout_helper.h b/paddle/fluid/operators/fused/fused_dropout_helper.h index fcfa405a52f9b..33fde64164d12 100644 --- a/paddle/fluid/operators/fused/fused_dropout_helper.h +++ b/paddle/fluid/operators/fused/fused_dropout_helper.h @@ -66,7 +66,7 @@ struct DropoutParam { } else { pre_fix = pre_fix + "_"; } - dropout_prob = context.Attr(pre_fix + "prob"); + dropout_prob = context.Attr(pre_fix + "rate"); auto& dropout_implementation = context.Attr(pre_fix + "implementation"); is_upscale_in_train = (dropout_implementation == "upscale_in_train"); diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 01d101909b549..53c7e165d8433 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -40,6 +40,9 @@ // need to manually specify them in this map. 
std::map> op_ins_map = { {"layer_norm", {"X", "Scale", "Bias"}}, + {"fused_attention", + {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "SrcMask", "OutLinearW", + "OutLinearBias", "Ln2Scale", "Ln2Bias"}}, {"instance_norm", {"X", "Scale", "Bias"}}, {"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}}, {"label_smooth", {"X", "PriorDist"}}, @@ -91,6 +94,11 @@ std::map> op_outs_map = { {"batch_norm", {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", "ReserveSpace"}}, + {"fused_attention", + {"LnMean", "LnVariance", "LnOut", "QKVOut", "QKVBiasOut", "TransposeOut2", + "QKOut", "QKTVOut", "SoftmaxOut", "AttnDropoutMaskOut", "AttnDropoutOut", + "SrcMaskOut", "FMHAOut", "OutLinearOut", "DropoutMaskOut", "Ln2Mean", + "Ln2Variance", "BiasDropoutResidualOut", "Y"}}, {"sync_batch_norm", {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", "ReserveSpace"}}, diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 1ec0812a6661e..7dca567b64886 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -90,6 +90,10 @@ foreach(TEST_OP ${MIXED_DIST_TEST_OPS}) list(REMOVE_ITEM TEST_OPS ${TEST_OP}) endforeach() +if(NOT WITH_GPU) + LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op) +endif() + if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) LIST(REMOVE_ITEM TEST_OPS test_c_comm_init_all_op) LIST(REMOVE_ITEM TEST_OPS test_c_concat) diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py new file mode 100644 index 0000000000000..a5578d71c5cd0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -0,0 +1,235 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
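+
+# This test checks the fused_attention forward path: GetBaselineOut() builds a
+# reference attention block from separate Linear/LayerNorm/Dropout layers and
+# paddle tensor ops, GetFusedAttentionOut() runs the same computation through
+# paddle.nn.functional.fused_multi_head_attention, and the two outputs are
+# compared with rtol=1e-5 and atol=1e-5.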
+ +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.fluid.core as core +import paddle.nn.functional as F +from paddle.nn.layer.norm import LayerNorm +from paddle.nn.layer.common import Linear, Dropout +from paddle.nn.layer.transformer import _convert_attention_mask +from paddle import tensor +from paddle.fluid import layers +import unittest +from op_test import OpTest + + +class TestFusedAttentionOp(OpTest): + def setUp(self): + self.config() + self.generate_input_data() + paddle.set_default_dtype(self.x_type) + self.__class__.op_type = "fused_attention" + self.q_proj = Linear( + self.embed_dim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + self.k_proj = Linear( + self.kdim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + self.v_proj = Linear( + self.vdim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + self.out_proj = Linear( + self.embed_dim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + paddle.set_default_dtype(np.float32) + self.norm1 = LayerNorm(self.embed_dim) + self.norm2 = LayerNorm(self.embed_dim) + paddle.set_default_dtype(self.x_type) + self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train") + + def config(self): + self.x_type = np.float32 + self.attn_mask_type = np.float64 + self.pre_layer_norm = True + self.training = True + + self.batch_size = 8 + self.query_length = 128 + self.head_dim = 64 + self.num_heads = 16 + self.embed_dim = self.head_dim * self.num_heads + + self.dropout_prob = 0.0 + self.attn_dropout_prob = 0.0 + self.weight_attr = None + self.bias_attr = None + self.kdim, self.vdim = self.embed_dim, self.embed_dim + self.key_length, self.value_length = self.query_length, self.query_length + + def generate_input_data(self): + self.query = np.random.rand(self.batch_size, self.query_length, + self.embed_dim).astype(self.x_type) + self.attn_mask = np.ones( + (self.batch_size, self.num_heads, self.query_length, + self.key_length), + dtype=self.attn_mask_type) + if self.attn_mask_type == np.int64: + self.attn_mask = np.tril(self.attn_mask) + elif self.attn_mask_type == np.float64: + self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9 + else: + raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.") + self.key, self.value = self.query, self.query + + self.dout = np.random.random((self.batch_size, self.query_length, + self.embed_dim)).astype(self.x_type) + + def GetBaselineOut(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + tensor_query = paddle.to_tensor(self.query, stop_gradient=False) + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + residual = tensor_query + + ln1_out = tensor_query + if self.pre_layer_norm: + ln1_out = self.norm1(tensor_query) + + q = self.q_proj(ln1_out) + q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3]) + k = self.k_proj(ln1_out) + v = self.v_proj(ln1_out) + k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3]) + v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + + qk_out = layers.matmul( + x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5) + + if attn_mask is not None: + attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype) + attn_mask_out = qk_out + attn_mask + softmax_out = F.softmax(attn_mask_out) + else: + softmax_out = F.softmax(qk_out) + + if 
self.dropout_prob: + dropout_out = F.dropout( + softmax_out, + self.dropout_prob, + training=self.training, + mode="upscale_in_train") + qktv_out = tensor.matmul(dropout_out, v_out) + else: + qktv_out = tensor.matmul(softmax_out, v_out) + + fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3]) + out_linear_in = tensor.reshape( + x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]) + out = self.out_proj(out_linear_in) + + residual_out = residual + self.dropout(out) + if not self.pre_layer_norm: + final_out = self.norm1(residual_out) + if self.pre_layer_norm: + final_out = self.norm2(residual_out) + return final_out + + def GetFusedAttentionOut(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + q_proj_weight = paddle.to_tensor( + self.q_proj.weight, stop_gradient=False) + q_proj_bias = paddle.to_tensor(self.q_proj.bias, stop_gradient=False) + k_proj_weight = paddle.to_tensor( + self.k_proj.weight, stop_gradient=False) + k_proj_bias = paddle.to_tensor(self.k_proj.bias, stop_gradient=False) + v_proj_weight = paddle.to_tensor( + self.v_proj.weight, stop_gradient=False) + v_proj_bias = paddle.to_tensor(self.v_proj.bias, stop_gradient=False) + out_linear_weight = paddle.to_tensor( + self.out_proj.weight, stop_gradient=False) + out_linear_bias = paddle.to_tensor( + self.out_proj.bias, stop_gradient=False) + + ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False) + ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False) + ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False) + ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False) + + q_proj_weight = q_proj_weight.numpy().transpose((1, 0)) + k_proj_weight = k_proj_weight.numpy().transpose((1, 0)) + v_proj_weight = v_proj_weight.numpy().transpose((1, 0)) + qkv_weight = np.concatenate( + (q_proj_weight, k_proj_weight, v_proj_weight)) + qkv_weight = qkv_weight.reshape( + (3, self.num_heads, self.head_dim, self.embed_dim)) + + qkv_bias = np.concatenate( + (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())) + qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) + + x = paddle.to_tensor(self.query, stop_gradient=False) + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False) + qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False) + epsilon = 1e-05 + ln2_epsilon = 1e-05 + + if attn_mask is not None: + attn_mask = _convert_attention_mask(attn_mask, x.dtype) + final_out = F.fused_multi_head_attention( + x, qkv_weight_tensor, out_linear_weight, self.pre_layer_norm, + ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, qkv_bias_tensor, + out_linear_bias, attn_mask, self.dropout_prob, + self.attn_dropout_prob, ln2_epsilon) + return final_out + + def test_fused_attention_op(self): + final_out_ref = self.GetBaselineOut() + final_out = self.GetFusedAttentionOut() + np.testing.assert_allclose( + final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5) + + +class TestFusedAttentionOpFp16(TestFusedAttentionOp): + def config(self): + self.x_type = np.float16 + self.attn_mask_type = np.float64 + self.pre_layer_norm = True + self.training = True + + self.batch_size = 8 + self.query_length = 128 + self.head_dim = 64 + self.num_heads = 16 + self.embed_dim = self.head_dim * self.num_heads + + self.dropout_prob = 0.0 + self.attn_dropout_prob = 0.0 + self.weight_attr = None + self.bias_attr = None + self.kdim, self.vdim = self.embed_dim, self.embed_dim + self.key_length, 
self.value_length = self.query_length, self.query_length + + def test_fused_attention_op(self): + final_out_ref = self.GetBaselineOut() + final_out = self.GetFusedAttentionOut() + np.testing.assert_allclose( + final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 4151f25b94aff..642e3606cab4f 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -60,6 +60,7 @@ from .conv import conv1d # noqa: F401 from .conv import conv1d_transpose # noqa: F401 from .common import linear # noqa: F401 +from .fused_transformer import fused_multi_head_attention # noqa: F401 from .conv import conv2d # noqa: F401 from .conv import conv2d_transpose # noqa: F401 from .conv import conv3d # noqa: F401 @@ -209,5 +210,6 @@ 'layer_norm', 'instance_norm', 'class_center_sample', + 'fused_multi_head_attention', 'sparse_attention', ] diff --git a/python/paddle/nn/functional/fused_transformer.py b/python/paddle/nn/functional/fused_transformer.py new file mode 100644 index 0000000000000..565ef223a96cb --- /dev/null +++ b/python/paddle/nn/functional/fused_transformer.py @@ -0,0 +1,127 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from ...fluid.framework import in_dygraph_mode +from paddle import _C_ops + +__all__ = [] + + +def fused_multi_head_attention(x, + qkv_weight, + linear_weight, + pre_layer_norm=False, + pre_ln_scale=None, + pre_ln_bias=None, + ln_scale=None, + ln_bias=None, + pre_ln_epsilon=1e-05, + qkv_bias=None, + linear_bias=None, + attn_mask=None, + dropout_rate=0.5, + attn_dropout_rate=0.5, + ln_epsilon=1e-05, + name=None): + """ + Attention mapps queries and a set of key-value pairs to outputs, and + Multi-Head Attention performs multiple parallel attention to jointly attending + to information from different representation subspaces. This API only + support self_attention. The pseudo code is as follows: + if pre_layer_norm: + out = layer_norm(x); + out = linear(out) + qkv)bias + else: + out = linear(x) + bias; + out = transpose(out, perm=[2, 0, 3, 1, 4]); + # extract q, k and v from out. + q = out[0:1,::] + k = out[1:2,::] + v = out[2:3,::] + out = q * k^t; + out = attn_mask + out; + out = softmax(out); + out = dropout(out); + out = out * v; + out = transpose(out, perm=[0, 2, 1, 3]); + out = out_linear(out); + out = layer_norm(x + dropout(linear_bias + out)); + + Parameters: + x (Tensor): The input tensor of fused_multi_head_attention. The shape is + `[batch\_size, sequence\_len, embed\_dim]`. + qkv_weight (Tensor): The qkv weight tensor. The shape is `[3, num_head, dim_head, dim_embed]`. + linear_weight (Tensor): The linear weight tensor. The shape is `[embed_dim, embed_dim]`. + pre_layer_norm (bool, optional): whether it is pre_layer_norm or post_layer_norm architecture. + Default False. 
+ pre_ln_scale (Tensor, optional): The weight tensor of pre layernorm. Default None. + pre_ln_bias (Tensor, optional): The bias tensor of pre layernorm. Default None. + ln_scale (Tensor, optional): The weight tensor of layernorm. Default None. + ln_bias (Tensor, optional): The bias tensor of layernorm. Default None. + pre_ln_epsilon (float, optional): Small float value added to denominator of the pre layer_norm + to avoid dividing by zero. Default is 1e-5. + qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`. + Default None. + linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None. + attn_mask (Tensor, optional): + dropout_rate (float, optional): The dropout probability used on attention + weights to drop some attention targets for the dropout after attention. + 0 for no dropout. Default 0. + attn_dropout_rate (float, optional): The dropout probability used on attention + weights to drop some attention targets for the dropout in attention. + 0 for no dropout. Default 0. + ln_epsilon (float, optional): Small float value added to denominator of layer_norm + to avoid dividing by zero. Default is 1e-5. + + Examples: + + .. code-block:: python + + # required: gpu + import paddle + import paddle.nn.functional as F + + # input: [batch_size, seq_len, embed_dim] + x = paddle.rand(shape=(2, 4, 128), dtype="float32") + # qkv_weight: [3, num_head, dim_head, dim_embed] + qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32") + # qkv_bias: [3, num_head, dim_head] + qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32") + # linear_weight: [embed_dim, embed_dim] + linear_weight = paddle.rand(shape=(128, 128), dtype="float32") + # linear_bias: [embed_dim] + linear_bias = paddle.rand(shape=[128], dtype="float32") + # self attention mask: [batch_size, num_heads, seq_len, seq_len] + attn_mask = paddle.rand(shape=(2, 4, 4, 4), dtype="float32") + + # output: [batch_size, seq_len, embed_dim] + output = F.fused_multi_head_attention( + x, qkv_weight, linear_weight, False, + None, None, None, None, 1e-5, qkv_bias, + linear_bias, attn_mask) + # [2, 4, 128] + print(output.shape) + """ + if in_dygraph_mode(): + # pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out, + # qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out, + # linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out + _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, final_out = _C_ops.fused_attention( + x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, attn_mask, + linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm', + pre_layer_norm, 'epsilon', pre_ln_epsilon, 'dropout_rate', + dropout_rate, 'attn_dropout_rate', attn_dropout_rate, 'ln_epsilon', + ln_epsilon) + return final_out diff --git a/python/paddle/nn/layer/transformer.py b/python/paddle/nn/layer/transformer.py index eacf5aac9daa9..36bc83647965e 100644 --- a/python/paddle/nn/layer/transformer.py +++ b/python/paddle/nn/layer/transformer.py @@ -26,7 +26,7 @@ from ...fluid import layers from .. 
import Layer, LayerList from ...framework import ParamAttr -from ...fluid.data_feeder import convert_dtype +from paddle.fluid.data_feeder import convert_dtype __all__ = [] From 32fe5a49925fc13028c16bc305defd44b72c148e Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 26 Oct 2021 11:32:28 +0800 Subject: [PATCH 5/9] cherry pick CrossEntropy's bug fix (#36647) --- .../unittests/test_cross_entropy_loss.py | 50 +++++++++++++++++ python/paddle/nn/functional/loss.py | 55 ++++++++++++------- 2 files changed, 86 insertions(+), 19 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py index d2eae1cce5bcb..d3ed76e34a614 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py @@ -1175,6 +1175,56 @@ def test_cross_entropy_loss_2d_with_weight_none(self): self.assertTrue(np.allclose(static_ret, expected)) self.assertTrue(np.allclose(dy_ret_value, expected)) + def test_cross_entropy_loss_2d_with_weight_axis_change_mean(self): + input_np = np.random.random(size=(2, 3, 2, 2)).astype(self.dtype) #NCHW + label_np = np.random.randint( + 0, 3, size=(2, 2, 2)).astype(np.int64) #NHW + weight_np = np.random.random(size=(3, )).astype(self.dtype) #C + + paddle.enable_static() + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[2, 3, 2, 2], dtype=self.dtype) + label = fluid.data(name='label', shape=[2, 2, 2], dtype='int64') + weight = fluid.data(name='weight', shape=[3], dtype=self.dtype) + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + weight=weight, reduction='mean', axis=1) + # specify the class channels to axis 1 + ret = cross_entropy_loss(input, label) + + exe = fluid.Executor(place) + static_ret = exe.run(prog, + feed={ + 'input': input_np, + 'label': label_np, + "weight": weight_np + }, + fetch_list=[ret]) + + self.assertIsNotNone(static_ret) + with fluid.dygraph.guard(): + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + weight=fluid.dygraph.to_variable(weight_np), + reduction='mean', + axis=1) + dy_ret = cross_entropy_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_ret_value = dy_ret.numpy() + self.assertIsNotNone(dy_ret_value) + expected = cross_entropy_loss_2d( + np.transpose(input_np, [0, 2, 3, 1]), + label_np, + weight=weight_np, + reduction='mean')[0] + self.assertTrue(np.allclose(static_ret, dy_ret_value)) + self.assertTrue(np.allclose(static_ret, expected)) + self.assertTrue(np.allclose(dy_ret_value, expected)) + def test_cross_entropy_loss_2d_with_weight_mean_ignore_exceedlabel(self): N = 4 C = 3 diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index da2d010c323b5..b1db45ad50669 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1668,12 +1668,13 @@ def cross_entropy(input, format(invalid_label[0], 0)) # TODO: Temporarily use paddle.nonzero instead of paddle.max # to detect and find out possible illegal label values - if len(paddle.nonzero(valid_label >= input.shape[-1])) > 0: + if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0: invalid_label = paddle.gather_nd( - valid_label, paddle.nonzero(valid_label >= input.shape[-1])) + valid_label, + 
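The changes above and below replace the hard-coded last-axis lookups (input.shape[-1]) with the user-supplied class axis, which is exactly what the new NCHW test case exercises. A usage sketch of that case, with the same shapes as the test (random data for illustration only):

    import numpy as np
    import paddle

    logits = paddle.rand([2, 3, 2, 2])                 # NCHW: classes on axis 1
    labels = paddle.to_tensor(
        np.random.randint(0, 3, size=(2, 2, 2)), dtype='int64')
    weight = paddle.rand([3])                          # per-class weights
    loss_fn = paddle.nn.CrossEntropyLoss(weight=weight, reduction='mean', axis=1)
    loss = loss_fn(logits, labels)
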
paddle.nonzero(valid_label >= input.shape[axis])) raise ValueError( "Target({}) is out of class_dimension's upper bound({})". - format(invalid_label[0], input.shape[-1] - 1)) + format(invalid_label[0], input.shape[axis] - 1)) _, out = _C_ops.softmax_with_cross_entropy( input, label, 'soft_label', soft_label, 'ignore_index', @@ -1700,19 +1701,28 @@ def cross_entropy(input, out = _C_ops.elementwise_mul(out, weight_gather_reshape) else: - if input.shape[-1] != weight.shape[-1]: + if input.shape[axis] != weight.shape[-1]: raise ValueError( - "input's class_dimension({}) must equal to \ - weight's class_dimension({}) \ - when weight is provided" - .format(input.shape[-1], weight.shape[-1])) + "input's class_dimension({}) must equal to " + "weight's class_dimension({}) " + "when weight is provided"\ + .format(input.shape[axis], weight.shape[-1])) ignore_weight_mask = paddle.cast((label != ignore_index), out.dtype) if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ - -1] == 1: - ignore_weight_mask.squeeze_(-1) - weight_gather = _C_ops.gather_nd(weight, valid_label) + axis] == 1: + # TODO: Temporarily use squeeze instead of squeeze_ + ignore_weight_mask = paddle.squeeze(ignore_weight_mask, + axis) + if axis != -1 and axis != valid_label.ndim - 1: + temp_perm = list(range(axis % valid_label.ndim)) \ + + list(range((axis % valid_label.ndim + 1) , valid_label.ndim)) \ + + [axis % valid_label.ndim] + weight_gather = _C_ops.gather_nd( + weight, valid_label.transpose(temp_perm)) + else: + weight_gather = _C_ops.gather_nd(weight, valid_label) weight_gather = _C_ops.elementwise_mul(weight_gather, ignore_weight_mask) input_shape = list(label.shape) @@ -1807,20 +1817,27 @@ def cross_entropy(input, weight_gather_reshape = reshape(weight_gather, shape=out_shape) out = paddle.cast(out, weight_gather_reshape.dtype) else: - if input.shape[-1] != weight.shape[-1]: - raise ValueError("input's class_dimension({}) must equal to "\ - "weight's class_dimension({}) "\ - "when weight is provided" - .format(input.shape[-1], weight.shape[-1])) + if input.shape[axis] != weight.shape[-1]: + raise ValueError("input's class_dimension({}) must equal to " + "weight's class_dimension({}) " + "when weight is provided"\ + .format(input.shape[axis], weight.shape[-1])) valid_label = paddle.where(label == ignore_index, paddle.zeros_like(label), label) ignore_weight_mask = paddle.cast((label != ignore_index), input.dtype) if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ - -1] == 1: - ignore_weight_mask = paddle.squeeze(ignore_weight_mask, -1) - weight_gather = paddle.gather_nd(weight, valid_label) + axis] == 1: + ignore_weight_mask = paddle.squeeze(ignore_weight_mask, axis) + if axis != -1 and axis != valid_label.ndim - 1: + temp_perm = list(range(axis % valid_label.ndim)) \ + + list(range((axis % valid_label.ndim + 1), valid_label.ndim)) \ + + [axis % valid_label.ndim] + weight_gather = paddle.gather_nd( + weight, paddle.transpose(valid_label, temp_perm)) + else: + weight_gather = paddle.gather_nd(weight, valid_label) weight_gather = paddle.multiply(weight_gather, ignore_weight_mask) input_shape = list(label.shape) From 53480c9c3f986265629d804a9dfaf5feca2abe1f Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Tue, 26 Oct 2021 13:20:17 +0800 Subject: [PATCH 6/9] add slot record support for GpuPS (#36723) * add slotrecord datafeed (#36099) * fix multi-node (#36329) --- paddle/fluid/framework/data_feed.cc | 642 ++++++++++++++++++ paddle/fluid/framework/data_feed.h | 38 +- paddle/fluid/framework/data_feed_factory.cc | 
5 +- paddle/fluid/framework/data_set.cc | 30 +- .../fluid/framework/fleet/ps_gpu_wrapper.cc | 100 ++- paddle/fluid/framework/fleet/ps_gpu_wrapper.h | 10 +- paddle/fluid/platform/collective_helper.cc | 8 +- paddle/fluid/platform/flags.cc | 4 +- python/paddle/fluid/dataset.py | 2 + 9 files changed, 802 insertions(+), 37 deletions(-) diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index 4463fd9fd5340..2d089b4721b82 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -28,6 +28,7 @@ limitations under the License. */ #include "paddle/fluid/platform/timer.h" USE_INT_STAT(STAT_total_feasign_num_in_mem); +DECLARE_bool(enable_ins_parser_file); namespace paddle { namespace framework { @@ -1929,5 +1930,646 @@ void PaddleBoxDataFeed::PutToFeedVec(const std::vector& ins_vec) { #endif } +template class InMemoryDataFeed; +void SlotRecordInMemoryDataFeed::Init(const DataFeedDesc& data_feed_desc) { + finish_init_ = false; + finish_set_filelist_ = false; + finish_start_ = false; + PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(), + platform::errors::PreconditionNotMet( + "Multi_slot_desc has not been set in data_feed_desc")); + paddle::framework::MultiSlotDesc multi_slot_desc = + data_feed_desc.multi_slot_desc(); + SetBatchSize(data_feed_desc.batch_size()); + size_t all_slot_num = multi_slot_desc.slots_size(); + + all_slots_.resize(all_slot_num); + all_slots_info_.resize(all_slot_num); + used_slots_info_.resize(all_slot_num); + use_slot_size_ = 0; + use_slots_.clear(); + + float_total_dims_size_ = 0; + float_total_dims_without_inductives_.clear(); + for (size_t i = 0; i < all_slot_num; ++i) { + const auto& slot = multi_slot_desc.slots(i); + all_slots_[i] = slot.name(); + + AllSlotInfo& all_slot = all_slots_info_[i]; + all_slot.slot = slot.name(); + all_slot.type = slot.type(); + all_slot.used_idx = slot.is_used() ? 
use_slot_size_ : -1; + all_slot.slot_value_idx = -1; + + if (slot.is_used()) { + UsedSlotInfo& info = used_slots_info_[use_slot_size_]; + info.idx = i; + info.slot = slot.name(); + info.type = slot.type(); + info.dense = slot.is_dense(); + info.total_dims_without_inductive = 1; + info.inductive_shape_index = -1; + + // record float value and uint64_t value pos + if (info.type[0] == 'u') { + info.slot_value_idx = uint64_use_slot_size_; + all_slot.slot_value_idx = uint64_use_slot_size_; + ++uint64_use_slot_size_; + } else if (info.type[0] == 'f') { + info.slot_value_idx = float_use_slot_size_; + all_slot.slot_value_idx = float_use_slot_size_; + ++float_use_slot_size_; + } + + use_slots_.push_back(slot.name()); + + if (slot.is_dense()) { + for (int j = 0; j < slot.shape_size(); ++j) { + if (slot.shape(j) > 0) { + info.total_dims_without_inductive *= slot.shape(j); + } + if (slot.shape(j) == -1) { + info.inductive_shape_index = j; + } + } + } + if (info.type[0] == 'f') { + float_total_dims_without_inductives_.push_back( + info.total_dims_without_inductive); + float_total_dims_size_ += info.total_dims_without_inductive; + } + info.local_shape.clear(); + for (int j = 0; j < slot.shape_size(); ++j) { + info.local_shape.push_back(slot.shape(j)); + } + ++use_slot_size_; + } + } + used_slots_info_.resize(use_slot_size_); + + feed_vec_.resize(used_slots_info_.size()); + const int kEstimatedFeasignNumPerSlot = 5; // Magic Number + for (size_t i = 0; i < all_slot_num; i++) { + batch_float_feasigns_.push_back(std::vector()); + batch_uint64_feasigns_.push_back(std::vector()); + batch_float_feasigns_[i].reserve(default_batch_size_ * + kEstimatedFeasignNumPerSlot); + batch_uint64_feasigns_[i].reserve(default_batch_size_ * + kEstimatedFeasignNumPerSlot); + offset_.push_back(std::vector()); + offset_[i].reserve(default_batch_size_ + + 1); // Each lod info will prepend a zero + } + visit_.resize(all_slot_num, false); + pipe_command_ = data_feed_desc.pipe_command(); + finish_init_ = true; + input_type_ = data_feed_desc.input_type(); + size_t pos = pipe_command_.find(".so"); + if (pos != std::string::npos) { + pos = pipe_command_.rfind('|'); + if (pos == std::string::npos) { + so_parser_name_ = pipe_command_; + pipe_command_.clear(); + } else { + so_parser_name_ = pipe_command_.substr(pos + 1); + pipe_command_ = pipe_command_.substr(0, pos); + } + so_parser_name_ = paddle::string::erase_spaces(so_parser_name_); + } else { + so_parser_name_.clear(); + } +} + +void SlotRecordInMemoryDataFeed::LoadIntoMemory() { + VLOG(3) << "SlotRecord LoadIntoMemory() begin, thread_id=" << thread_id_; + if (!so_parser_name_.empty()) { + LoadIntoMemoryByLib(); + } else { + LoadIntoMemoryByCommand(); + } +} +void SlotRecordInMemoryDataFeed::LoadIntoMemoryByLib(void) { + if (true) { + // user defined file format analysis + LoadIntoMemoryByFile(); + } else { + LoadIntoMemoryByLine(); + } +} + +void SlotRecordInMemoryDataFeed::LoadIntoMemoryByFile(void) { +#ifdef _LINUX + paddle::framework::CustomParser* parser = + global_dlmanager_pool().Load(so_parser_name_, all_slots_info_); + CHECK(parser != nullptr); + // get slotrecord object + auto pull_record_func = [this](std::vector& record_vec, + int max_fetch_num, int offset) { + if (offset > 0) { + input_channel_->WriteMove(offset, &record_vec[0]); + if (max_fetch_num > 0) { + SlotRecordPool().get(&record_vec[0], offset); + } else { // free all + max_fetch_num = static_cast(record_vec.size()); + if (max_fetch_num > offset) { + SlotRecordPool().put(&record_vec[offset], (max_fetch_num - 
offset)); + } + } + } else if (max_fetch_num > 0) { + SlotRecordPool().get(&record_vec, max_fetch_num); + } else { + SlotRecordPool().put(&record_vec); + } + }; + + std::string filename; + while (this->PickOneFile(&filename)) { + VLOG(3) << "PickOneFile, filename=" << filename + << ", thread_id=" << thread_id_; + platform::Timer timeline; + timeline.Start(); + + int lines = 0; + bool is_ok = true; + do { + int err_no = 0; + this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_); + + CHECK(this->fp_ != nullptr); + __fsetlocking(&*(this->fp_), FSETLOCKING_BYCALLER); + is_ok = parser->ParseFileInstance( + [this](char* buf, int len) { + return fread(buf, sizeof(char), len, this->fp_.get()); + }, + pull_record_func, lines); + + if (!is_ok) { + LOG(WARNING) << "parser error, filename=" << filename + << ", lines=" << lines; + } + } while (!is_ok); + timeline.Pause(); + VLOG(3) << "LoadIntoMemoryByLib() read all file, file=" << filename + << ", cost time=" << timeline.ElapsedSec() + << " seconds, thread_id=" << thread_id_ << ", lines=" << lines; + } +#endif +} + +void SlotRecordInMemoryDataFeed::LoadIntoMemoryByLine(void) { +#ifdef _LINUX + paddle::framework::CustomParser* parser = + global_dlmanager_pool().Load(so_parser_name_, all_slots_info_); + std::string filename; + BufferedLineFileReader line_reader; + line_reader.set_sample_rate(sample_rate_); + BufferedLineFileReader::LineFunc line_func = nullptr; + + while (this->PickOneFile(&filename)) { + VLOG(3) << "PickOneFile, filename=" << filename + << ", thread_id=" << thread_id_; + std::vector record_vec; + platform::Timer timeline; + timeline.Start(); + int offset = 0; + int old_offset = 0; + + SlotRecordPool().get(&record_vec, OBJPOOL_BLOCK_SIZE); + // get slotrecord object function + auto record_func = [this, &offset, &record_vec, &old_offset]( + std::vector& vec, int num) { + vec.resize(num); + if (offset + num > OBJPOOL_BLOCK_SIZE) { + input_channel_->WriteMove(offset, &record_vec[0]); + SlotRecordPool().get(&record_vec[0], offset); + record_vec.resize(OBJPOOL_BLOCK_SIZE); + offset = 0; + old_offset = 0; + } + for (int i = 0; i < num; ++i) { + auto& ins = record_vec[offset + i]; + ins->reset(); + vec[i] = ins; + } + offset = offset + num; + }; + + line_func = [this, &parser, &record_vec, &offset, &filename, &record_func, + &old_offset](const std::string& line) { + old_offset = offset; + if (!parser->ParseOneInstance(line, record_func)) { + offset = old_offset; + LOG(WARNING) << "read file:[" << filename << "] item error, line:[" + << line << "]"; + return false; + } + if (offset >= OBJPOOL_BLOCK_SIZE) { + input_channel_->Write(std::move(record_vec)); + record_vec.clear(); + SlotRecordPool().get(&record_vec, OBJPOOL_BLOCK_SIZE); + offset = 0; + } + return true; + }; + + int lines = 0; + + do { + int err_no = 0; + this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_); + CHECK(this->fp_ != nullptr); + __fsetlocking(&*(this->fp_), FSETLOCKING_BYCALLER); + lines = line_reader.read_file(this->fp_.get(), line_func, lines); + } while (line_reader.is_error()); + + if (offset > 0) { + input_channel_->WriteMove(offset, &record_vec[0]); + if (offset < OBJPOOL_BLOCK_SIZE) { + SlotRecordPool().put(&record_vec[offset], + (OBJPOOL_BLOCK_SIZE - offset)); + } + } else { + SlotRecordPool().put(&record_vec); + } + record_vec.clear(); + record_vec.shrink_to_fit(); + timeline.Pause(); + VLOG(3) << "LoadIntoMemoryByLib() read all lines, file=" << filename + << ", cost time=" << timeline.ElapsedSec() + << " seconds, thread_id=" << thread_id_ 
<< ", lines=" << lines + << ", sample lines=" << line_reader.get_sample_line() + << ", filesize=" << line_reader.file_size() / 1024.0 / 1024.0 + << "MB"; + } + + VLOG(3) << "LoadIntoMemoryByLib() end, thread_id=" << thread_id_ + << ", total size: " << line_reader.file_size(); +#endif +} + +void SlotRecordInMemoryDataFeed::LoadIntoMemoryByCommand(void) { +#ifdef _LINUX + std::string filename; + BufferedLineFileReader line_reader; + line_reader.set_sample_rate(sample_rate_); + + while (this->PickOneFile(&filename)) { + VLOG(3) << "PickOneFile, filename=" << filename + << ", thread_id=" << thread_id_; + int lines = 0; + std::vector record_vec; + platform::Timer timeline; + timeline.Start(); + SlotRecordPool().get(&record_vec, OBJPOOL_BLOCK_SIZE); + int offset = 0; + + do { + int err_no = 0; + this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_); + CHECK(this->fp_ != nullptr); + __fsetlocking(&*(this->fp_), FSETLOCKING_BYCALLER); + + lines = line_reader.read_file( + this->fp_.get(), + [this, &record_vec, &offset, &filename](const std::string& line) { + if (ParseOneInstance(line, &record_vec[offset])) { + ++offset; + } else { + LOG(WARNING) << "read file:[" << filename + << "] item error, line:[" << line << "]"; + return false; + } + if (offset >= OBJPOOL_BLOCK_SIZE) { + input_channel_->Write(std::move(record_vec)); + record_vec.clear(); + SlotRecordPool().get(&record_vec, OBJPOOL_BLOCK_SIZE); + offset = 0; + } + return true; + }, + lines); + } while (line_reader.is_error()); + if (offset > 0) { + input_channel_->WriteMove(offset, &record_vec[0]); + if (offset < OBJPOOL_BLOCK_SIZE) { + SlotRecordPool().put(&record_vec[offset], + (OBJPOOL_BLOCK_SIZE - offset)); + } + } else { + SlotRecordPool().put(&record_vec); + } + record_vec.clear(); + record_vec.shrink_to_fit(); + timeline.Pause(); + VLOG(3) << "LoadIntoMemory() read all lines, file=" << filename + << ", lines=" << lines + << ", sample lines=" << line_reader.get_sample_line() + << ", cost time=" << timeline.ElapsedSec() + << " seconds, thread_id=" << thread_id_; + } + VLOG(3) << "LoadIntoMemory() end, thread_id=" << thread_id_ + << ", total size: " << line_reader.file_size(); +#endif +} + +static void parser_log_key(const std::string& log_key, uint64_t* search_id, + uint32_t* cmatch, uint32_t* rank) { + std::string searchid_str = log_key.substr(16, 16); + *search_id = static_cast(strtoull(searchid_str.c_str(), NULL, 16)); + std::string cmatch_str = log_key.substr(11, 3); + *cmatch = static_cast(strtoul(cmatch_str.c_str(), NULL, 16)); + std::string rank_str = log_key.substr(14, 2); + *rank = static_cast(strtoul(rank_str.c_str(), NULL, 16)); +} + +bool SlotRecordInMemoryDataFeed::ParseOneInstance(const std::string& line, + SlotRecord* ins) { + SlotRecord& rec = (*ins); + // parse line + const char* str = line.c_str(); + char* endptr = const_cast(str); + int pos = 0; + + thread_local std::vector> slot_float_feasigns; + thread_local std::vector> slot_uint64_feasigns; + slot_float_feasigns.resize(float_use_slot_size_); + slot_uint64_feasigns.resize(uint64_use_slot_size_); + + if (parse_ins_id_) { + int num = strtol(&str[pos], &endptr, 10); + CHECK(num == 1); // NOLINT + pos = endptr - str + 1; + size_t len = 0; + while (str[pos + len] != ' ') { + ++len; + } + rec->ins_id_ = std::string(str + pos, len); + pos += len + 1; + } + if (parse_logkey_) { + int num = strtol(&str[pos], &endptr, 10); + CHECK(num == 1); // NOLINT + pos = endptr - str + 1; + size_t len = 0; + while (str[pos + len] != ' ') { + ++len; + } + // parse_logkey + 
std::string log_key = std::string(str + pos, len); + uint64_t search_id; + uint32_t cmatch; + uint32_t rank; + parser_log_key(log_key, &search_id, &cmatch, &rank); + + rec->ins_id_ = log_key; + rec->search_id = search_id; + rec->cmatch = cmatch; + rec->rank = rank; + pos += len + 1; + } + + int float_total_slot_num = 0; + int uint64_total_slot_num = 0; + + for (size_t i = 0; i < all_slots_info_.size(); ++i) { + auto& info = all_slots_info_[i]; + int num = strtol(&str[pos], &endptr, 10); + PADDLE_ENFORCE(num, + "The number of ids can not be zero, you need padding " + "it in data generator; or if there is something wrong with " + "the data, please check if the data contains unresolvable " + "characters.\nplease check this error line: %s", + str); + if (info.used_idx != -1) { + if (info.type[0] == 'f') { // float + auto& slot_fea = slot_float_feasigns[info.slot_value_idx]; + slot_fea.clear(); + for (int j = 0; j < num; ++j) { + float feasign = strtof(endptr, &endptr); + if (fabs(feasign) < 1e-6 && !used_slots_info_[info.used_idx].dense) { + continue; + } + slot_fea.push_back(feasign); + ++float_total_slot_num; + } + } else if (info.type[0] == 'u') { // uint64 + auto& slot_fea = slot_uint64_feasigns[info.slot_value_idx]; + slot_fea.clear(); + for (int j = 0; j < num; ++j) { + uint64_t feasign = + static_cast(strtoull(endptr, &endptr, 10)); + if (feasign == 0 && !used_slots_info_[info.used_idx].dense) { + continue; + } + slot_fea.push_back(feasign); + ++uint64_total_slot_num; + } + } + pos = endptr - str; + } else { + for (int j = 0; j <= num; ++j) { + // pos = line.find_first_of(' ', pos + 1); + while (line[pos + 1] != ' ') { + pos++; + } + } + } + } + rec->slot_float_feasigns_.add_slot_feasigns(slot_float_feasigns, + float_total_slot_num); + rec->slot_uint64_feasigns_.add_slot_feasigns(slot_uint64_feasigns, + uint64_total_slot_num); + + return (uint64_total_slot_num > 0); +} + +void SlotRecordInMemoryDataFeed::PutToFeedVec(const SlotRecord* ins_vec, + int num) { + for (int j = 0; j < use_slot_size_; ++j) { + auto& feed = feed_vec_[j]; + if (feed == nullptr) { + continue; + } + + auto& slot_offset = offset_[j]; + slot_offset.clear(); + slot_offset.reserve(num + 1); + slot_offset.push_back(0); + + int total_instance = 0; + auto& info = used_slots_info_[j]; + // fill slot value with default value 0 + if (info.type[0] == 'f') { // float + auto& batch_fea = batch_float_feasigns_[j]; + batch_fea.clear(); + + for (int i = 0; i < num; ++i) { + auto r = ins_vec[i]; + size_t fea_num = 0; + float* slot_values = + r->slot_float_feasigns_.get_values(info.slot_value_idx, &fea_num); + batch_fea.resize(total_instance + fea_num); + memcpy(&batch_fea[total_instance], slot_values, + sizeof(float) * fea_num); + total_instance += fea_num; + slot_offset.push_back(total_instance); + } + + float* feasign = batch_fea.data(); + float* tensor_ptr = + feed->mutable_data({total_instance, 1}, this->place_); + CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(float)); + + } else if (info.type[0] == 'u') { // uint64 + auto& batch_fea = batch_uint64_feasigns_[j]; + batch_fea.clear(); + + for (int i = 0; i < num; ++i) { + auto r = ins_vec[i]; + size_t fea_num = 0; + uint64_t* slot_values = + r->slot_uint64_feasigns_.get_values(info.slot_value_idx, &fea_num); + if (fea_num > 0) { + batch_fea.resize(total_instance + fea_num); + memcpy(&batch_fea[total_instance], slot_values, + sizeof(uint64_t) * fea_num); + total_instance += fea_num; + } + if (fea_num == 0) { + batch_fea.resize(total_instance + fea_num); + 
batch_fea[total_instance] = 0; + total_instance += 1; + } + slot_offset.push_back(total_instance); + } + + // no uint64_t type in paddlepaddle + uint64_t* feasign = batch_fea.data(); + int64_t* tensor_ptr = + feed->mutable_data({total_instance, 1}, this->place_); + CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(int64_t)); + } + + if (info.dense) { + if (info.inductive_shape_index != -1) { + info.local_shape[info.inductive_shape_index] = + total_instance / info.total_dims_without_inductive; + } + feed->Resize(framework::make_ddim(info.local_shape)); + } else { + LoD data_lod{slot_offset}; + feed_vec_[j]->set_lod(data_lod); + } + } +} + +void SlotRecordInMemoryDataFeed::ExpandSlotRecord(SlotRecord* rec) { + SlotRecord& ins = (*rec); + if (ins->slot_float_feasigns_.slot_offsets.empty()) { + return; + } + size_t total_value_size = ins->slot_float_feasigns_.slot_values.size(); + if (float_total_dims_size_ == total_value_size) { + return; + } + int float_slot_num = + static_cast(float_total_dims_without_inductives_.size()); + CHECK(float_slot_num == float_use_slot_size_); + std::vector old_values; + std::vector old_offsets; + old_values.swap(ins->slot_float_feasigns_.slot_values); + old_offsets.swap(ins->slot_float_feasigns_.slot_offsets); + + ins->slot_float_feasigns_.slot_values.resize(float_total_dims_size_); + ins->slot_float_feasigns_.slot_offsets.assign(float_slot_num + 1, 0); + + auto& slot_offsets = ins->slot_float_feasigns_.slot_offsets; + auto& slot_values = ins->slot_float_feasigns_.slot_values; + + uint32_t offset = 0; + int num = 0; + uint32_t old_off = 0; + int dim = 0; + + for (int i = 0; i < float_slot_num; ++i) { + dim = float_total_dims_without_inductives_[i]; + old_off = old_offsets[i]; + num = static_cast(old_offsets[i + 1] - old_off); + if (num == 0) { + // fill slot value with default value 0 + for (int k = 0; k < dim; ++k) { + slot_values[k + offset] = 0.0; + } + } else { + if (num == dim) { + memcpy(&slot_values[offset], &old_values[old_off], dim * sizeof(float)); + } else { + // position fea + // record position index need fix values + int pos_idx = static_cast(old_values[old_off]); + for (int k = 0; k < dim; ++k) { + if (k == pos_idx) { + slot_values[k + offset] = 1.0; + } else { + slot_values[k + offset] = 0.0; + } + } + } + } + slot_offsets[i] = offset; + offset += dim; + } + slot_offsets[float_slot_num] = offset; + CHECK(float_total_dims_size_ == static_cast(offset)); +} + +bool SlotRecordInMemoryDataFeed::Start() { +#ifdef _LINUX + this->CheckSetFileList(); + if (input_channel_->Size() != 0) { + std::vector data; + input_channel_->Read(data); + } +#endif + if (batch_offsets_.size() > 0) { + VLOG(3) << "batch_size offsets: " << batch_offsets_.size(); + enable_heterps_ = true; + this->offset_index_ = 0; + } + this->finish_start_ = true; + return true; +} + +int SlotRecordInMemoryDataFeed::Next() { +#ifdef _LINUX + this->CheckStart(); + + VLOG(3) << "enable heter next: " << offset_index_ + << " batch_offsets: " << batch_offsets_.size(); + if (offset_index_ >= batch_offsets_.size()) { + VLOG(3) << "offset_index: " << offset_index_ + << " batch_offsets: " << batch_offsets_.size(); + return 0; + } + auto& batch = batch_offsets_[offset_index_++]; + this->batch_size_ = batch.second; + VLOG(3) << "batch_size_=" << this->batch_size_ + << ", thread_id=" << thread_id_; + if (this->batch_size_ != 0) { + PutToFeedVec(&records_[batch.first], this->batch_size_); + } else { + VLOG(3) << "finish reading for heterps, batch size zero, thread_id=" + << thread_id_; + } + 
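From the Python side, this reader is picked by name through the dataset's feed type; the dataset.py change at the end of this commit additionally swaps the underlying dataset to SlotRecordDataset when that feed type is requested. A rough setup sketch (the slot variables and file list below are placeholders, not part of this patch):

    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    slots = [fluid.data(name="slot_%d" % i, shape=[None, 1],
                        dtype="int64", lod_level=1)
             for i in range(2)]                          # placeholder slot vars
    dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    dataset.set_feed_type("SlotRecordInMemoryDataFeed")  # selects the reader above
    dataset.set_batch_size(32)
    dataset.set_use_var(slots)
    dataset.set_filelist(["train_data/part-000"])        # placeholder path
    dataset.load_into_memory()
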
VLOG(3) << "enable heter next: " << offset_index_ + << " batch_offsets: " << batch_offsets_.size() + << " baych_size: " << this->batch_size_; + + return this->batch_size_; +#else + return 0; +#endif +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 5527eaf1f6fa4..a4100e66e7285 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -384,7 +384,7 @@ class CustomParser { CustomParser() {} virtual ~CustomParser() {} virtual void Init(const std::vector& slots) = 0; - virtual bool Init(const std::vector& slots) = 0; + virtual bool Init(const std::vector& slots); virtual void ParseOneInstance(const char* str, Record* instance) = 0; virtual bool ParseOneInstance( const std::string& line, @@ -1103,6 +1103,42 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed { virtual void PutToFeedVec(const Record* ins_vec, int num); }; +class SlotRecordInMemoryDataFeed : public InMemoryDataFeed { + public: + SlotRecordInMemoryDataFeed() {} + virtual ~SlotRecordInMemoryDataFeed() {} + virtual void Init(const DataFeedDesc& data_feed_desc); + virtual void LoadIntoMemory(); + void ExpandSlotRecord(SlotRecord* ins); + + protected: + virtual bool Start(); + virtual int Next(); + virtual bool ParseOneInstance(SlotRecord* instance) { return false; } + virtual bool ParseOneInstanceFromPipe(SlotRecord* instance) { return false; } + // virtual void ParseOneInstanceFromSo(const char* str, T* instance, + // CustomParser* parser) {} + virtual void PutToFeedVec(const std::vector& ins_vec) {} + + virtual void LoadIntoMemoryByCommand(void); + virtual void LoadIntoMemoryByLib(void); + virtual void LoadIntoMemoryByLine(void); + virtual void LoadIntoMemoryByFile(void); + virtual void SetInputChannel(void* channel) { + input_channel_ = static_cast*>(channel); + } + bool ParseOneInstance(const std::string& line, SlotRecord* rec); + virtual void PutToFeedVec(const SlotRecord* ins_vec, int num); + float sample_rate_ = 1.0f; + int use_slot_size_ = 0; + int float_use_slot_size_ = 0; + int uint64_use_slot_size_ = 0; + std::vector all_slots_info_; + std::vector used_slots_info_; + size_t float_total_dims_size_ = 0; + std::vector float_total_dims_without_inductives_; +}; + class PaddleBoxDataFeed : public MultiSlotInMemoryDataFeed { public: PaddleBoxDataFeed() {} diff --git a/paddle/fluid/framework/data_feed_factory.cc b/paddle/fluid/framework/data_feed_factory.cc index ec1b8ec773fa6..e46e4aeb0124c 100644 --- a/paddle/fluid/framework/data_feed_factory.cc +++ b/paddle/fluid/framework/data_feed_factory.cc @@ -58,8 +58,8 @@ std::shared_ptr DataFeedFactory::CreateDataFeed( std::string data_feed_class) { if (g_data_feed_map.count(data_feed_class) < 1) { LOG(WARNING) << "Your DataFeed " << data_feed_class - << "is not supported currently"; - LOG(WARNING) << "Supported DataFeed: " << DataFeedTypeList(); + << " is not supported currently"; + LOG(WARNING) << " Supported DataFeed: " << DataFeedTypeList(); exit(-1); } return g_data_feed_map[data_feed_class](); @@ -68,6 +68,7 @@ std::shared_ptr DataFeedFactory::CreateDataFeed( REGISTER_DATAFEED_CLASS(MultiSlotDataFeed); REGISTER_DATAFEED_CLASS(MultiSlotInMemoryDataFeed); REGISTER_DATAFEED_CLASS(PaddleBoxDataFeed); +REGISTER_DATAFEED_CLASS(SlotRecordInMemoryDataFeed); #if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32) REGISTER_DATAFEED_CLASS(MultiSlotFileInstantDataFeed); #endif diff --git a/paddle/fluid/framework/data_set.cc 
b/paddle/fluid/framework/data_set.cc index 82a39b206e6bd..2a071665b263c 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -1609,7 +1609,35 @@ void SlotRecordDataset::DynamicAdjustChannelNum(int channel_num, void SlotRecordDataset::PrepareTrain() { #ifdef PADDLE_WITH_GLOO - return; + if (enable_heterps_) { + if (input_records_.size() == 0 && input_channel_ != nullptr && + input_channel_->Size() != 0) { + input_channel_->ReadAll(input_records_); + VLOG(3) << "read from channel to records with records size: " + << input_records_.size(); + } + VLOG(3) << "input records size: " << input_records_.size(); + int64_t total_ins_num = input_records_.size(); + std::vector> offset; + int default_batch_size = + reinterpret_cast(readers_[0].get()) + ->GetDefaultBatchSize(); + VLOG(3) << "thread_num: " << thread_num_ + << " memory size: " << total_ins_num + << " default batch_size: " << default_batch_size; + compute_thread_batch_nccl(thread_num_, total_ins_num, default_batch_size, + &offset); + VLOG(3) << "offset size: " << offset.size(); + for (int i = 0; i < thread_num_; i++) { + reinterpret_cast(readers_[i].get()) + ->SetRecord(&input_records_[0]); + } + for (size_t i = 0; i < offset.size(); i++) { + reinterpret_cast( + readers_[i % thread_num_].get()) + ->AddBatchOffset(offset[i]); + } + } #else PADDLE_THROW(platform::errors::Unavailable( "dataset set heterps need compile with GLOO")); diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc index 784cbc3d90b86..d1e98a711dc9d 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.cc @@ -45,9 +45,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr gpu_task) { platform::Timer timeline; timeline.Start(); int device_num = heter_devices_.size(); - MultiSlotDataset* dataset = dynamic_cast(dataset_); gpu_task->init(thread_keys_shard_num_, device_num); - auto input_channel = dataset->GetInputChannel(); auto& local_keys = gpu_task->feature_keys_; auto& local_ptr = gpu_task->value_ptr_; @@ -68,35 +66,83 @@ void PSGPUWrapper::BuildTask(std::shared_ptr gpu_task) { for (int i = 0; i < thread_keys_thread_num_; i++) { thread_keys_[i].resize(thread_keys_shard_num_); } - const std::deque& vec_data = input_channel->GetData(); - size_t total_len = vec_data.size(); - size_t len_per_thread = total_len / thread_keys_thread_num_; - int remain = total_len % thread_keys_thread_num_; + + size_t total_len = 0; + size_t len_per_thread = 0; + int remain = 0; size_t begin = 0; - auto gen_func = [this](const std::deque& total_data, int begin_index, - int end_index, int i) { - for (auto iter = total_data.begin() + begin_index; - iter != total_data.begin() + end_index; iter++) { - const auto& ins = *iter; - const auto& feasign_v = ins.uint64_feasigns_; - for (const auto feasign : feasign_v) { - uint64_t cur_key = feasign.sign().uint64_feasign_; - int shard_id = cur_key % thread_keys_shard_num_; - this->thread_keys_[i][shard_id].insert(cur_key); + + std::string data_set_name = std::string(typeid(*dataset_).name()); + + if (data_set_name.find("SlotRecordDataset") != std::string::npos) { + VLOG(0) << "ps_gpu_wrapper use SlotRecordDataset"; + SlotRecordDataset* dataset = dynamic_cast(dataset_); + auto input_channel = dataset->GetInputChannel(); + VLOG(0) << "yxf::buildtask::inputslotchannle size: " + << input_channel->Size(); + const std::deque& vec_data = input_channel->GetData(); + total_len = vec_data.size(); + len_per_thread = total_len / 
thread_keys_thread_num_; + remain = total_len % thread_keys_thread_num_; + VLOG(0) << "total len: " << total_len; + auto gen_func = [this](const std::deque& total_data, + int begin_index, int end_index, int i) { + for (auto iter = total_data.begin() + begin_index; + iter != total_data.begin() + end_index; iter++) { + const auto& ins = *iter; + const auto& feasign_v = ins->slot_uint64_feasigns_.slot_values; + for (const auto feasign : feasign_v) { + int shard_id = feasign % thread_keys_shard_num_; + this->thread_keys_[i][shard_id].insert(feasign); + } } + }; + for (int i = 0; i < thread_keys_thread_num_; i++) { + threads.push_back( + std::thread(gen_func, std::ref(vec_data), begin, + begin + len_per_thread + (i < remain ? 1 : 0), i)); + begin += len_per_thread + (i < remain ? 1 : 0); } - }; - for (int i = 0; i < thread_keys_thread_num_; i++) { - threads.push_back(std::thread(gen_func, std::ref(vec_data), begin, - begin + len_per_thread + (i < remain ? 1 : 0), - i)); - begin += len_per_thread + (i < remain ? 1 : 0); - } - for (std::thread& t : threads) { - t.join(); + for (std::thread& t : threads) { + t.join(); + } + timeline.Pause(); + VLOG(1) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds."; + } else { + CHECK(data_set_name.find("MultiSlotDataset") != std::string::npos); + VLOG(0) << "ps_gpu_wrapper use MultiSlotDataset"; + MultiSlotDataset* dataset = dynamic_cast(dataset_); + auto input_channel = dataset->GetInputChannel(); + + const std::deque& vec_data = input_channel->GetData(); + total_len = vec_data.size(); + len_per_thread = total_len / thread_keys_thread_num_; + remain = total_len % thread_keys_thread_num_; + auto gen_func = [this](const std::deque& total_data, + int begin_index, int end_index, int i) { + for (auto iter = total_data.begin() + begin_index; + iter != total_data.begin() + end_index; iter++) { + const auto& ins = *iter; + const auto& feasign_v = ins.uint64_feasigns_; + for (const auto feasign : feasign_v) { + uint64_t cur_key = feasign.sign().uint64_feasign_; + int shard_id = cur_key % thread_keys_shard_num_; + this->thread_keys_[i][shard_id].insert(cur_key); + } + } + }; + for (int i = 0; i < thread_keys_thread_num_; i++) { + threads.push_back( + std::thread(gen_func, std::ref(vec_data), begin, + begin + len_per_thread + (i < remain ? 1 : 0), i)); + begin += len_per_thread + (i < remain ? 
1 : 0); + } + for (std::thread& t : threads) { + t.join(); + } + timeline.Pause(); + VLOG(1) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds."; } - timeline.Pause(); - VLOG(1) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds."; timeline.Start(); diff --git a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h index b7e8bbb369492..fa2ff6cbdb8c7 100644 --- a/paddle/fluid/framework/fleet/ps_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/ps_gpu_wrapper.h @@ -117,6 +117,15 @@ class PSGPUWrapper { resource_ = std::make_shared(dev_ids); resource_->enable_p2p(); keys_tensor.resize(resource_->total_gpu()); +#ifdef PADDLE_WITH_GLOO + auto gloo = paddle::framework::GlooWrapper::GetInstance(); + if (gloo->Size() > 1) { + multi_node_ = 1; + } +#else + PADDLE_THROW( + platform::errors::Unavailable("heter ps need compile with GLOO")); +#endif if (multi_node_) { int dev_size = dev_ids.size(); // init inner comm @@ -127,7 +136,6 @@ class PSGPUWrapper { // init inter comm #ifdef PADDLE_WITH_GLOO inter_comms_.resize(dev_size); - auto gloo = paddle::framework::GlooWrapper::GetInstance(); if (gloo->Rank() == 0) { for (int i = 0; i < dev_size; ++i) { platform::dynload::ncclGetUniqueId(&inter_ncclids_[i]); diff --git a/paddle/fluid/platform/collective_helper.cc b/paddle/fluid/platform/collective_helper.cc index a765f344daf8a..03359d932b5ab 100644 --- a/paddle/fluid/platform/collective_helper.cc +++ b/paddle/fluid/platform/collective_helper.cc @@ -148,7 +148,7 @@ void NCCLCommContext::CreateNCCLCommMultiTrainer( paddle::platform::errors::InvalidArgument( "dev ids = [%d], it should greater than 0.", dev_ids.size())); const int kDevices = dev_ids.size(); - VLOG(3) << "Begin CreateNCCLCommMultiTrainer. device number: " << kDevices + VLOG(1) << "Begin CreateNCCLCommMultiTrainer. 
device number: " << kDevices << ", ntrainers: " << ntrainers << ", train_id: " << train_id << ", rind_id: " << ring_id; ncclComm_t comms[kDevices]; @@ -162,10 +162,10 @@ void NCCLCommContext::CreateNCCLCommMultiTrainer( #endif platform::dynload::ncclCommInitRank(comms + i, kDevices * ntrainers, *nccl_id, train_id * kDevices + i); - VLOG(3) << "ncclCommInitRank: " << i; + VLOG(1) << "ncclCommInitRank: " << i; } PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclGroupEnd()); - VLOG(3) << "nccl group end seccessss"; + VLOG(1) << "nccl group end seccessss"; } PADDLE_ENFORCE_EQ(comm_map_.count(ring_id), 0, platform::errors::InvalidArgument( @@ -174,7 +174,7 @@ void NCCLCommContext::CreateNCCLCommMultiTrainer( for (int i = 0; i < kDevices; ++i) { AssignNCCLComm(comms[i], kDevices * ntrainers, train_id * kDevices + i, dev_ids[i], ring_id); - VLOG(3) << "nccl communicator of train_id " << train_id * kDevices + i + VLOG(1) << "nccl communicator of train_id " << train_id * kDevices + i << " in ring " << ring_id << " has been created on device " << dev_ids[i]; } diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index 1a19bb3aa97f0..e33baa521630a 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -680,4 +680,6 @@ DEFINE_int32(slotpool_thread_num, 1, "SlotRecordDataset slot pool thread num"); DEFINE_bool(enable_slotpool_wait_release, false, "enable slotrecord obejct wait release, default false"); DEFINE_bool(enable_slotrecord_reset_shrink, false, - "enable slotrecord obejct reset shrink memory, default false"); \ No newline at end of file + "enable slotrecord obejct reset shrink memory, default false"); +DEFINE_bool(enable_ins_parser_file, false, + "enable parser ins file , default false"); diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py index 438831208b66a..d683e36fbe5ab 100644 --- a/python/paddle/fluid/dataset.py +++ b/python/paddle/fluid/dataset.py @@ -396,6 +396,8 @@ def set_feed_type(self, data_feed_type): Set data_feed_desc """ self.proto_desc.name = data_feed_type + if (self.proto_desc.name == "SlotRecordInMemoryDataFeed"): + self.dataset = core.Dataset("SlotRecordDataset") @deprecated( since="2.0.0", From 1ee4fc320601c07257bfae758da7525c13f456ed Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Tue, 26 Oct 2021 14:03:06 +0800 Subject: [PATCH 7/9] [Amp] refine code of amp level (#36362) (#36726) * refine amp level * fix typo * update tracer._amp_level --- paddle/fluid/imperative/amp_auto_cast.cc | 13 +++++++++- paddle/fluid/imperative/amp_auto_cast.h | 24 +++++++++---------- paddle/fluid/imperative/tracer.cc | 4 ++-- paddle/fluid/imperative/tracer.h | 9 ++++--- paddle/fluid/pybind/imperative.cc | 11 +++++++-- .../fleet/meta_parallel/pp_utils/utils.py | 2 +- .../distributed/fleet/utils/recompute.py | 2 +- python/paddle/fluid/dygraph/amp/auto_cast.py | 10 ++++---- 8 files changed, 49 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc index 48e5e430b136a..b0d86f6db9f96 100644 --- a/paddle/fluid/imperative/amp_auto_cast.cc +++ b/paddle/fluid/imperative/amp_auto_cast.cc @@ -24,6 +24,17 @@ namespace imperative { class VarBase; +AutoCastGuard::AutoCastGuard(std::shared_ptr tracer, AmpLevel level) + : tracer_(tracer) { + pre_amp_level_ = tracer_->GetAmpLevel(); + + if (pre_amp_level_ != level) { + tracer_->SetAmpLevel(level); + } +} + +AutoCastGuard::~AutoCastGuard() { tracer_->SetAmpLevel(pre_amp_level_); } + AmpOperators::AmpOperators() : 
allow_ops_(new std::unordered_set()), block_ops_(new std::unordered_set()), @@ -117,7 +128,7 @@ static inline std::shared_ptr CastToType( imperative::NameVarBaseMap outs = {{"Out", {out}}}; { - AutoCastGuard guard(tracer, 0); + AutoCastGuard guard(tracer, AmpLevel::O0); tracer->TraceOp("cast", ins, outs, std::move(attrs)); } diff --git a/paddle/fluid/imperative/amp_auto_cast.h b/paddle/fluid/imperative/amp_auto_cast.h index 79bc83a777aa9..903e2652888d8 100644 --- a/paddle/fluid/imperative/amp_auto_cast.h +++ b/paddle/fluid/imperative/amp_auto_cast.h @@ -19,15 +19,22 @@ #include #include -#include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/type_defs.h" namespace paddle { namespace imperative { -// Singleton implementation with C++ 11 +// NOTE(zhiqiu): only O1 and O2 are valid now +enum class AmpLevel { + O0 = 0, // fp32 + O1, // amp, mixed fp32-fp16 + O2, // almost fp16 + O3, // fp16 +}; + class Tracer; +// Singleton implementation with C++ 11 class AmpOperators { public: ~AmpOperators(); @@ -63,16 +70,9 @@ std::ostream& operator<<(std::ostream& os, AmpOperators& ops); // NOTE(zhiqiu): AutoCastGuard is used for RAII. class AutoCastGuard { public: - AutoCastGuard(std::shared_ptr tracer, int guard_level) - : tracer_(tracer) { - pre_amp_level_ = tracer_->AMPLevel(); - - if (pre_amp_level_ != guard_level) { - tracer_->SetAMPLevel(guard_level); - } - } + AutoCastGuard(std::shared_ptr tracer, AmpLevel guard_level); - ~AutoCastGuard() { tracer_->SetAMPLevel(pre_amp_level_); } + ~AutoCastGuard(); // forbid copy and operator= AutoCastGuard(const AutoCastGuard& guard) = delete; @@ -80,7 +80,7 @@ class AutoCastGuard { private: std::shared_ptr tracer_; - int pre_amp_level_; + AmpLevel pre_amp_level_; }; NameVarBaseMap AutoCastInputs(const std::string& op_type, diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 49e079c58caf3..0f363d0ea1bff 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -176,10 +176,10 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, : attr_checker->GetDefaultAttrMap(); NameVarBaseMap new_ins = ins; - if (amp_level_ == 1) { + if (amp_level_ == AmpLevel::O1) { VLOG(5) << "Auto mixed precision run operator: " << type; new_ins = AutoCastInputs(type, ins); - } else if (amp_level_ == 2) { + } else if (amp_level_ == AmpLevel::O2) { VLOG(5) << "Pure fp16 run operator: " << type; new_ins = CastPureFp16Inputs(type, ins); } diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h index e77623d7a4609..418b2069b5bb6 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -23,6 +23,7 @@ #include #include "ThreadPool.h" #include "paddle/fluid/framework/garbage_collector.h" +#include "paddle/fluid/imperative/amp_auto_cast.h" #include "paddle/fluid/imperative/basic_engine.h" #include "paddle/fluid/imperative/jit/program_desc_tracer.h" #include "paddle/fluid/imperative/layer.h" @@ -31,6 +32,8 @@ namespace paddle { namespace imperative { +enum class AmpLevel; + using GarbageCollectorMap = std::map>; @@ -105,9 +108,9 @@ class Tracer { void SetHasGrad(bool has_grad) { has_grad_ = has_grad; } - void SetAMPLevel(int level) { amp_level_ = level; } + void SetAmpLevel(AmpLevel level) { amp_level_ = level; } - int AMPLevel() const { return amp_level_; } + AmpLevel GetAmpLevel() const { return amp_level_; } paddle::framework::GarbageCollector* MutableGarbageCollectorIfNotExists( const platform::Place& place); @@ -120,7 +123,7 
@@ class Tracer { platform::Place expected_place_; GarbageCollectorMap gcs_; static thread_local bool has_grad_; - int amp_level_{0}; + AmpLevel amp_level_{AmpLevel::O0}; }; // To access static variable current_tracer diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index db84c1142ae23..f94afaa56b8df 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -1940,6 +1940,13 @@ void BindImperative(py::module *m_ptr) { &imperative::jit::ProgramDescTracer::CreateProgramDesc) .def("reset", &imperative::jit::ProgramDescTracer::Reset); + py::enum_(m, "AmpLevel", py::arithmetic()) + .value("O0", paddle::imperative::AmpLevel::O0) + .value("O1", paddle::imperative::AmpLevel::O1) + .value("O2", paddle::imperative::AmpLevel::O2) + .value("O3", paddle::imperative::AmpLevel::O3) + .export_values(); + py::class_>( m, "Tracer", R"DOC()DOC") .def("__init__", @@ -1947,8 +1954,8 @@ void BindImperative(py::module *m_ptr) { .def_property("_enable_program_desc_tracing", &imperative::Tracer::IsProgramDescTracingEnabled, &imperative::Tracer::SetEnableProgramDescTracing) - .def_property("_amp_level", &imperative::Tracer::AMPLevel, - &imperative::Tracer::SetAMPLevel) + .def_property("_amp_level", &imperative::Tracer::GetAmpLevel, + &imperative::Tracer::SetAmpLevel) .def_property("_has_grad", &imperative::Tracer::HasGrad, &imperative::Tracer::SetHasGrad) .def_property( diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index b29b0b3e27557..08266096548c4 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -198,7 +198,7 @@ def forward(ctx, run_function, all_outputs, *args): # TODO support AMP tracer = framework._dygraph_tracer() - if tracer._amp_level == 0: + if tracer._amp_level == core.AmpLevel.O0: ctx.is_fw_autocast = False else: ctx.is_fw_autocast = True diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py index 302877e51fe01..56a64049b16e1 100755 --- a/python/paddle/distributed/fleet/utils/recompute.py +++ b/python/paddle/distributed/fleet/utils/recompute.py @@ -98,7 +98,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): # TODO support AMP tracer = framework._dygraph_tracer() - if tracer._amp_level == 0: + if tracer._amp_level == core.AmpLevel.O0: ctx.is_fw_autocast = False else: ctx.is_fw_autocast = True diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 0d02a383c1bb8..d218e6b7490d9 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -24,6 +24,8 @@ import operator import types +AMP_LEVEL = core.AmpLevel + __all__ = ['amp_guard', 'amp_decorate'] # The set of ops that support fp16 calculation and are considered numerically- @@ -108,7 +110,7 @@ def _in_amp_guard(): """ tracer = _dygraph_tracer() if tracer: - if tracer._amp_level == 1: + if tracer._amp_level == core.AmpLevel.O1: return True else: return False @@ -251,11 +253,11 @@ def amp_guard(enable=True, enable = False if level == 'O1': - amp_level = 1 + amp_level = AMP_LEVEL.O1 _white_list = WHITE_LIST _black_list = BLACK_LIST else: - amp_level = 2 + amp_level = AMP_LEVEL.O2 _white_list = PURE_FP16_WHITE_LIST _black_list = PURE_FP16_BLACK_LIST @@ -264,7 +266,7 @@ def amp_guard(enable=True, custom_black_list, level) 
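
For orientation, a minimal dygraph sketch of how the new AmpLevel enum surfaces on the Python side (a sketch, not part of the patch; it assumes a build of this branch where the public paddle.amp.auto_cast wrapper forwards its level string into amp_guard as shown above, and that core.AmpLevel is available from the new pybind binding):

import paddle
from paddle.fluid import core, framework

x = paddle.rand([4, 4])
with paddle.amp.auto_cast(enable=True, level='O1'):
    tracer = framework._dygraph_tracer()
    # With this patch, tracer._amp_level is a core.AmpLevel enum value, not a bare int.
    # Expect AmpLevel.O1 on a CUDA build; on CPU amp_guard disables itself and falls back to O0.
    print(tracer._amp_level)
    y = paddle.matmul(x, x)  # dispatched through the O1 white/black lists when enabled
print(y.dtype)
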
if not enable: - amp_level = 0 + amp_level = AMP_LEVEL.O0 if tracer: # enable auto_cast From fced11bd89becf3344705391ad221714f250e0c3 Mon Sep 17 00:00:00 2001 From: xiongkun <807377414@qq.com> Date: Tue, 26 Oct 2021 14:07:02 +0800 Subject: [PATCH 8/9] Support various length support for SelectedRows in GLOO::AllGather (#36637) (#36722) Support various length support for SelectedRows in GLOO::AllGather (#36637) In cpu parallel using gloo, add various length support for SelectedRows --- paddle/fluid/framework/fleet/gloo_wrapper.h | 22 +++++- paddle/fluid/imperative/gloo_context.cc | 73 +++++++------------ .../fluid/tests/unittests/CMakeLists.txt | 1 + .../fluid/tests/unittests/test_dist_base.py | 30 +++++++- ...graph_sparse_embedding_diff_length_gloo.py | 46 ++++++++++++ 5 files changed, 119 insertions(+), 53 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_diff_length_gloo.py diff --git a/paddle/fluid/framework/fleet/gloo_wrapper.h b/paddle/fluid/framework/fleet/gloo_wrapper.h index 3686507043191..fd11710538d58 100644 --- a/paddle/fluid/framework/fleet/gloo_wrapper.h +++ b/paddle/fluid/framework/fleet/gloo_wrapper.h @@ -27,6 +27,7 @@ limitations under the License. */ #include #ifdef PADDLE_WITH_GLOO #include +#include #include #include #include @@ -218,10 +219,25 @@ class GlooWrapper { return std::move(ret); } - // TODO(xiongkun03): support all gather array of + // NOTE(@xiongkun03): support all gather array of // numbers with different length - // can use AllgathervOptions, may be work in different - // occasion. Need some survey. + // if the third argument is int, use allgather, + // if it is vector, use AllgathervOptions, + // which works in different length occasion. + template + void AllGatherVector(T* input_ptr, T* output_ptr, + std::vector& element_nums) { // NOLINT + CHECK_EQ(is_initialized_, true); +#ifdef PADDLE_WITH_GLOO + gloo::AllgathervOptions opts(context_); + opts.setInput(input_ptr, element_nums[rank_]); + opts.setOutput(output_ptr, element_nums); + gloo::allgatherv(opts); +#else + LOG(WARNING) << "AllGather does nothing when WITH_GLOO=OFF"; +#endif + } + template void AllGatherVector(T* input_ptr, T* output_ptr, size_t element_num) { // NOLINT diff --git a/paddle/fluid/imperative/gloo_context.cc b/paddle/fluid/imperative/gloo_context.cc index 0d93cdf57932f..ef1bf0d158787 100644 --- a/paddle/fluid/imperative/gloo_context.cc +++ b/paddle/fluid/imperative/gloo_context.cc @@ -53,15 +53,13 @@ void GLOOParallelContext::InitWithRingID(int ring_id) { platform::errors::OutOfRange("Still not implement InitWithRingID")); } -#define GLOO_CASE(type, T, gw) \ - case type: { \ - VLOG(4) << "Use the gloo all reduce to sync. 
SRC:" << src_tensor; \ - std::vector send_vector##T; \ - framework::TensorToVector(src_tensor, &send_vector##T); \ - auto recv_vector##T = gw->AllReduce(send_vector##T); \ - framework::TensorFromVector(recv_vector##T, dst_tensor); \ - VLOG(4) << "DST:" << *dst_tensor; \ - break; \ +#define GLOO_CASE(type, T, gw) \ + case type: { \ + std::vector send_vector##T; \ + framework::TensorToVector(src_tensor, &send_vector##T); \ + auto recv_vector##T = gw->AllReduce(send_vector##T); \ + framework::TensorFromVector(recv_vector##T, dst_tensor); \ + break; \ } void GLOOParallelContext::AllReduceByStream(const framework::Variable &src, @@ -118,7 +116,7 @@ void GLOOParallelContext::AllReduce(const framework::Tensor &src_tensor, const auto *src_tensor_ptr = src_tensor.data(); \ gw->AllGatherVector(const_cast(src_tensor_ptr), \ reinterpret_cast(dst_tensor_ptr), \ - value_sendcount); \ + element_nums); \ break; \ } @@ -150,48 +148,31 @@ void GLOOParallelContext::AllReduce(const framework::SelectedRows &src, auto *dst_rows_ptr = dst_rows->MutableData(place); const int64_t *src_rows_ptr = src_rows.Data(place); - // VLOG(3) << "Selected Rows of src:" << string::join_strings(dst_rows, ',') - auto *dst_tensor = dst->mutable_value(); auto dims = src_tensor.dims(); dims[0] = rows_num; auto feature_size = framework::product(dims) / dims[0]; dst_tensor->Resize(dims); - if (std::all_of(cpu_rows_num_ptr, cpu_rows_num_ptr + nranks, - [&](size_t row) { return row == cpu_rows_num_ptr[0]; })) { - // During sparse communication, the number of each card is same. - // Because gloo wrapper utility class currently don't support - // broadcast, so we only deal the-same case. - VLOG(3) << "Use the gloo all reduce to sync. SRC:" << src_tensor; - // framework::SerializeToStream(VLOG(4), src); - VLOG(3) << "allgather replaces broadcast to speed up in sparse allreduce"; - auto value_sendcount = cpu_rows_num_ptr[0] * feature_size; - auto *dst_tensor_ptr = dst_tensor->mutable_data(place, dtype); - - gloo_wrapper->AllGatherVector(const_cast(src_rows_ptr), - static_cast(dst_rows_ptr), - rows_num_vector[0]); - - switch (dtype) { - GLOO_ALL_GATHER_CASE(framework::proto::VarType::FP32, float, - gloo_wrapper); - GLOO_ALL_GATHER_CASE(framework::proto::VarType::FP64, double, - gloo_wrapper); - GLOO_ALL_GATHER_CASE(framework::proto::VarType::INT32, int, gloo_wrapper); - GLOO_ALL_GATHER_CASE(framework::proto::VarType::INT64, int64_t, - gloo_wrapper); - default: { - PADDLE_THROW(platform::errors::InvalidArgument( - "Invalid datatype for allreduce")); - } + + std::vector element_nums = rows_num_vector; + std::for_each(element_nums.begin(), element_nums.end(), + [feature_size](size_t &x) { x = x * feature_size; }); + + auto *dst_tensor_ptr = dst_tensor->mutable_data(place, dtype); + gloo_wrapper->AllGatherVector(const_cast(src_rows_ptr), + static_cast(dst_rows_ptr), + rows_num_vector); + + switch (dtype) { + GLOO_ALL_GATHER_CASE(framework::proto::VarType::FP32, float, gloo_wrapper); + GLOO_ALL_GATHER_CASE(framework::proto::VarType::FP64, double, gloo_wrapper); + GLOO_ALL_GATHER_CASE(framework::proto::VarType::INT32, int, gloo_wrapper); + GLOO_ALL_GATHER_CASE(framework::proto::VarType::INT64, int64_t, + gloo_wrapper); + default: { + PADDLE_THROW( + platform::errors::InvalidArgument("Invalid datatype for allreduce")); } - VLOG(3) << "Selected Row DST:" << *dst_tensor; - VLOG(3) << "Selected Rows of DST:" - << string::join_strings(std::vector(*dst_rows), ','); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "The number of each card 
is not the same, gloo only support the-same" - "batch division")); } } diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 7dca567b64886..9020a77e49dbf 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -205,6 +205,7 @@ if (NOT WITH_GLOO) LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_unused_variables_gloo) LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sparse_embedding_over_height_gloo) LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sparse_embedding_gloo) + LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sparse_embedding_diff_length_gloo) endif() if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 63985415c51f6..0b8a80f0c837a 100755 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -515,10 +515,28 @@ def _get_data(self, batch, args): return batch elif args.update_method != "local": new_batch = [] - for offset, item in enumerate(batch): - if offset % 2 == args.trainer_id: - new_batch.append(item) - return new_batch + + # NOTE(@xiongkun03) args.diff_batch means batch length is different: + # such as : batch = [2,3,4,5], then the first rank will get [2] and + # the second rank will get [3,4,5]. + # this function is for test sparse_embedding_differ_length + if hasattr(args, "diff_batch") and args.diff_batch: + assert len( + batch) > 2, "in differ_batch mode, len(batch) must > 2." + if paddle.distributed.get_rank() == 0: + new_batch.append(batch[0]) + elif paddle.distributed.get_rank() == 1: + new_batch.extend([_ for _ in batch[1:]]) + else: + raise NotImplementedError( + "Current TestParallelDyGraphRunnerBase don't support world_size > 2" + ) + return new_batch + else: + for offset, item in enumerate(batch): + if offset % 2 == args.trainer_id: + new_batch.append(item) + return new_batch else: return batch @@ -699,6 +717,7 @@ def runtime_main(test_class): parser.add_argument('--use_fleet_api', action='store_true') parser.add_argument('--use_fleet_api_20', action='store_true') parser.add_argument('--use_local_sgd', action='store_true') + parser.add_argument('--diff_batch', action='store_true') parser.add_argument('--ut4grad_allreduce', action='store_true') parser.add_argument( '--hallreduce_inter_nranks', type=int, required=False, default=2) @@ -798,6 +817,7 @@ def setUp(self): self._gloo_mode = False # now, support gloo backend self._pipeline_mode = False self._mp_mode = False + self._diff_batch = False # FIXME(typhoonzero): I added this stupid argument to enable # testing allreduce layers, which users can call layers.allreduce # to accumulate tensors at anywhere. 
Find a better way to do this @@ -1100,6 +1120,8 @@ def _get_gloo_trainer_cmd(self, model, ep, update_method, trainer_id, #assert self._use_reader_alloc == False, "gloo not support _use_reduce" if self._save_model: tr_cmd += " --save_model" + if self._diff_batch: + tr_cmd += " --diff_batch" self.__use_cuda = False self.__use_xpu = False assert self.__use_cuda == False, "gloo not support use cuda" diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_diff_length_gloo.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_diff_length_gloo.py new file mode 100644 index 0000000000000..1c425a40a9b39 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sparse_embedding_diff_length_gloo.py @@ -0,0 +1,46 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import sys +import unittest + +import paddle.fluid as fluid +from test_dist_base import TestDistBase +from spawn_runner_base import TestDistSpawnRunner +from parallel_dygraph_sparse_embedding import TestSparseEmbedding +from parallel_dygraph_sparse_embedding_fp64 import TestSparseEmbeddingFP64 + +flag_name = os.path.splitext(__file__)[0] + + +class TestParallelDygraphSparseEmdedding_GLOO(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._gloo_mode = True + self._dygraph = True + self._diff_batch = True + + def test_sparse_embedding(self): + self.check_with_place( + "parallel_dygraph_sparse_embedding.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +if __name__ == "__main__": + unittest.main() From 616ce203e2a3d55d540cdff1ca929f6db499b88a Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Tue, 26 Oct 2021 14:22:21 +0800 Subject: [PATCH 9/9] [Cherry-pick] Add the forward QR operator (#36627) --- cmake/operators.cmake | 1 + paddle/fluid/operators/qr_op.cc | 152 +++++++++ paddle/fluid/operators/qr_op.cu | 309 ++++++++++++++++++ paddle/fluid/operators/qr_op.h | 135 ++++++++ paddle/fluid/operators/svd_helper.h | 13 + paddle/fluid/platform/dynload/cusolver.h | 18 +- .../fluid/tests/unittests/test_qr_op.py | 168 ++++++++++ python/paddle/linalg.py | 2 + python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/linalg.py | 66 +++- 10 files changed, 864 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/operators/qr_op.cc create mode 100644 paddle/fluid/operators/qr_op.cu create mode 100644 paddle/fluid/operators/qr_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_qr_op.py diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 1f25dfd8a9f4b..2b1c52fca8a0e 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -185,6 +185,7 @@ function(op_library TARGET) list(REMOVE_ITEM hip_srcs "cholesky_op.cu") list(REMOVE_ITEM hip_srcs "matrix_rank_op.cu") list(REMOVE_ITEM hip_srcs "svd_op.cu") + list(REMOVE_ITEM hip_srcs "qr_op.cu") list(REMOVE_ITEM hip_srcs "eigh_op.cu") 
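
Stepping back to the --diff_batch plumbing added to test_dist_base.py above, here is a standalone sketch of the split it performs (a hypothetical helper for illustration only, not part of the patch): rank 0 keeps a single sample and rank 1 keeps the rest, so the two gloo workers feed batches of different lengths into the allgatherv path.

def split_diff_batch(batch, rank):
    # Mirrors _get_data in diff_batch mode: uneven split across exactly two ranks.
    assert len(batch) > 2, "in diff_batch mode, len(batch) must > 2."
    if rank == 0:
        return batch[:1]
    elif rank == 1:
        return list(batch[1:])
    raise NotImplementedError("the test only exercises world_size <= 2")

print(split_diff_batch([2, 3, 4, 5], 0))  # [2]
print(split_diff_batch([2, 3, 4, 5], 1))  # [3, 4, 5]
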
list(REMOVE_ITEM hip_srcs "multinomial_op.cu") list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu") diff --git a/paddle/fluid/operators/qr_op.cc b/paddle/fluid/operators/qr_op.cc new file mode 100644 index 0000000000000..f612bb9e31f93 --- /dev/null +++ b/paddle/fluid/operators/qr_op.cc @@ -0,0 +1,152 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/qr_op.h" +#include +#include +#include +#include +#include "paddle/fluid/framework/ddim.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + +namespace paddle { +namespace operators { +using DDim = framework::DDim; + +class QrOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "qr"); + OP_INOUT_CHECK(ctx->HasOutput("Q"), "Output", "Q", "qr"); + OP_INOUT_CHECK(ctx->HasOutput("R"), "Output", "R", "qr"); + + auto x_dims = ctx->GetInputDim("X"); + int x_rank = x_dims.size(); + PADDLE_ENFORCE_GE(x_dims.size(), 2, + platform::errors::InvalidArgument( + "the rank of input must greater than 2")); + bool compute_q; + bool reduced_mode; + int m = x_dims[x_rank - 2]; + int n = x_dims[x_rank - 1]; + int min_mn = std::min(m, n); + std::string mode = ctx->Attrs().Get("mode"); + std::tie(compute_q, reduced_mode) = _parse_qr_mode(mode); + + if (compute_q) { + int k = reduced_mode ? min_mn : m; + auto q_dims_vec = framework::vectorize(x_dims); + q_dims_vec[q_dims_vec.size() - 1] = k; + ctx->SetOutputDim("Q", framework::make_ddim(q_dims_vec)); + } else { + ctx->SetOutputDim("Q", framework::make_ddim({0})); + } + + int k = reduced_mode ? min_mn : m; + auto r_dims_vec = framework::vectorize(x_dims); + r_dims_vec[r_dims_vec.size() - 2] = k; + r_dims_vec[r_dims_vec.size() - 1] = n; + ctx->SetOutputDim("R", framework::make_ddim(r_dims_vec)); + + ctx->ShareLoD("X", /*->*/ "Q"); + ctx->ShareLoD("X", /*->*/ "R"); + } +}; + +class QrOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of qr op."); + AddOutput("Q", "(Tensor), The output Q tensor of qr op."); + AddOutput("R", "(Tensor), The output R tensor of qr op."); + AddAttr( + "mode", + "(string, default \"reduced\"). " + "If mode is \"reduced\", Qr op will return reduced Q and R matrices. " + "If mode is \"complete\", Qr op will return complete Q and R matrices. " + "If mode is \"r\", Qr op will only return reduced R matrix.") + .SetDefault("reduced"); + AddComment(R"DOC( +Qr Operator. + +This operator is used to perform QR operation for batched matrics $X$. 
+$$Q, R = qr(X)$$ + +)DOC"); + } +}; + +class QrGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Q")), "Input", + "Q@Grad", "QrGrad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("R")), "Input", + "R@Grad", "QrGrad"); + OP_INOUT_CHECK(ctx->HasInput("Q"), "Input", "Q", "QrGrad"); + OP_INOUT_CHECK(ctx->HasInput("R"), "Input", "R", "QrGrad"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", + "X@Grad", "QrGrad"); + + auto x_dims = ctx->GetInputDim(("X")); + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(dtype, ctx.GetPlace()); + } +}; + +template +class QrGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr retv) const override { + retv->SetType("qr_grad"); + retv->SetInput(framework::GradVarName("Q"), this->OutputGrad("Q")); + retv->SetInput(framework::GradVarName("R"), this->OutputGrad("R")); + retv->SetInput("Q", this->Output("Q")); + retv->SetInput("R", this->Output("R")); + retv->SetInput("X", this->Input("X")); + retv->SetAttrMap(this->Attrs()); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(qr, ops::QrOp, ops::QrOpMaker, + ops::QrGradMaker, + ops::QrGradMaker); + +REGISTER_OPERATOR(qr_grad, ops::QrGradOp); + +REGISTER_OP_CPU_KERNEL(qr, ops::QrCPUKernel, ops::QrCPUKernel); + +REGISTER_OP_CPU_KERNEL( + qr_grad, ops::QrGradKernel, + ops::QrGradKernel); diff --git a/paddle/fluid/operators/qr_op.cu b/paddle/fluid/operators/qr_op.cu new file mode 100644 index 0000000000000..992df172ace0c --- /dev/null +++ b/paddle/fluid/operators/qr_op.cu @@ -0,0 +1,309 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifndef PADDLE_WITH_HIP +// HIP not support cusolver + +#include +#include +#include +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/qr_op.h" +#include "paddle/fluid/platform/dynload/cusolver.h" + +// Reuse some helper functions from svd +#include "paddle/fluid/operators/svd_helper.h" + +namespace paddle { +namespace operators { + +template +class QrGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + bool compute_q; + bool reduced_mode; + auto& dev_ctx = + context.template device_context(); + const Tensor& x = *context.Input("X"); + Tensor& q = *context.Output("Q"); + Tensor& r = *context.Output("R"); + const std::string mode = context.Attr("mode"); + std::tie(compute_q, reduced_mode) = _parse_qr_mode(mode); + + auto numel = x.numel(); + PADDLE_ENFORCE_GT(numel, 0, platform::errors::PreconditionNotMet( + "The input of QR is empty.")); + auto x_dims = x.dims(); + int x_rank = x_dims.size(); + int m = x_dims[x_rank - 2]; + int n = x_dims[x_rank - 1]; + int min_mn = std::min(m, n); + int k = reduced_mode ? min_mn : m; + int batch_size = numel / (m * n); + int qr_stride = m * n; + int tau_stride = min_mn; + + if (compute_q) { + q.mutable_data>( + context.GetPlace(), + size_t(batch_size * m * k * sizeof(math::Real))); + } + r.mutable_data>( + context.GetPlace(), size_t(batch_size * k * n * sizeof(math::Real))); + + auto dito = + math::DeviceIndependenceTensorOperations(context); + + // Note: allocate temporary tensors because of lacking in-place operatios. + // Prepare qr + Tensor qr; + qr.mutable_data>( + context.GetPlace(), size_t(batch_size * m * n * sizeof(math::Real))); + // BatchedGeqrf performs computation in-place and 'qr' must be a copy of + // input + TensorCopy(x, context.GetPlace(), &qr); + + // Prepare tau + auto tau_dims_vec = framework::vectorize(x_dims); + tau_dims_vec.pop_back(); + tau_dims_vec[tau_dims_vec.size() - 1] = min_mn; + Tensor tau = dito.Fill(tau_dims_vec, 0); + + // Transpose 'qr' to conform the column-major order + auto tmp_qr = dito.Transpose(qr); + framework::TensorCopy(tmp_qr, qr.place(), &qr); + auto qr_data = qr.mutable_data(context.GetPlace()); + auto tau_data = tau.mutable_data(context.GetPlace()); + + BatchedGeqrf(dev_ctx, batch_size, m, n, qr_data, m, tau_data, qr_stride, + tau_stride); + + if (reduced_mode) { + auto trans_qr = dito.Transpose(qr); + auto sliced_qr = dito.Slice(trans_qr, {-2}, {0}, {min_mn}); + auto tmp_r = dito.TrilTriu(sliced_qr, 0, false); + // Transpose 'tmp_r' to retore the original row-major order + framework::TensorCopy(tmp_r, r.place(), &r); + } else { + auto trans_qr = dito.Transpose(qr); + auto tmp_r = dito.TrilTriu(trans_qr, 0, false); + // Transpose 'tmp_r' to retore the original row-major order + framework::TensorCopy(tmp_r, r.place(), &r); + } + + if (compute_q) { + // Perform QRGQR for Q using the result from GEQRF + // Transpose 'q' to retore the original row-major order + if (reduced_mode) { + BatchedOrgqr(dev_ctx, batch_size, m, min_mn, min_mn, qr_data, m, + tau_data, qr_stride, tau_stride); + auto trans_q = dito.Transpose(qr); + auto sliced_q = dito.Slice(trans_q, {-1}, {0}, {min_mn}); + framework::TensorCopy(sliced_q, q.place(), &q); + } else { + if (m > n) { + auto new_qr_dims_vec = framework::vectorize(x_dims); + new_qr_dims_vec[new_qr_dims_vec.size() - 1] = m; + Tensor new_qr = dito.Fill(new_qr_dims_vec, 0); + auto new_qr_data = new_qr.mutable_data(context.GetPlace()); + auto new_qr_stride = m * m; + for (int 
i = 0; i < batch_size; ++i) { + memory::Copy( + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), + (new_qr_data + i * new_qr_stride), + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), + (qr_data + i * qr_stride), qr_stride * sizeof(math::Real), + dev_ctx.stream()); + } + BatchedOrgqr(dev_ctx, batch_size, m, m, min_mn, new_qr_data, m, + tau_data, new_qr_stride, tau_stride); + auto trans_q = dito.Transpose(new_qr); + framework::TensorCopy(trans_q, q.place(), &q); + } else { + BatchedOrgqr(dev_ctx, batch_size, m, m, min_mn, qr_data, m, tau_data, + qr_stride, tau_stride); + auto trans_q = dito.Transpose(qr); + auto sliced_q = dito.Slice(trans_q, {-1}, {0}, {m}); + framework::TensorCopy(sliced_q, q.place(), &q); + } + } + } + } + + void BatchedGeqrf(const platform::CUDADeviceContext& dev_ctx, int batch_size, + int m, int n, float* a, int lda, float* tau, int a_stride, + int tau_stride) const; + + void BatchedGeqrf(const platform::CUDADeviceContext& dev_ctx, int batch_size, + int m, int n, double* a, int lda, double* tau, int a_stride, + int tau_stride) const; + + void BatchedOrgqr(const platform::CUDADeviceContext& dev_ctx, int batch_size, + int m, int n, int k, float* a, int lda, float* tau, + int a_stride, int tau_stride) const; + + void BatchedOrgqr(const platform::CUDADeviceContext& dev_ctx, int batch_size, + int m, int n, int k, double* a, int lda, double* tau, + int a_stride, int tau_stride) const; +}; + +template <> +void QrGPUKernel::BatchedGeqrf( + const platform::CUDADeviceContext& dev_ctx, int batch_size, int m, int n, + float* a, int lda, float* tau, int a_stride, int tau_stride) const { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSgeqrf_bufferSize( + handle, m, n, a, lda, &lwork)); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float)); + float* workspace_ptr = reinterpret_cast(workspace->ptr()); + auto info = memory::Alloc(dev_ctx, sizeof(int)); + int* info_d = reinterpret_cast(info->ptr()); + + for (int i = 0; i < batch_size; ++i) { + float* a_working_ptr = &a[i * a_stride]; + float* tau_working_ptr = &tau[i * tau_stride]; + // compute geqrf + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSgeqrf( + handle, m, n, a_working_ptr, lda, tau_working_ptr, workspace_ptr, lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + memory::Copy(platform::CPUPlace(), &info_h, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), + info_d, sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: CUSolver geqrf is not zero. 
[%d]", i, info_h)); + } +} + +template <> +void QrGPUKernel::BatchedGeqrf( + const platform::CUDADeviceContext& dev_ctx, int batch_size, int m, int n, + double* a, int lda, double* tau, int a_stride, int tau_stride) const { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnDgeqrf_bufferSize( + handle, m, n, a, lda, &lwork)); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double)); + double* workspace_ptr = reinterpret_cast(workspace->ptr()); + auto info = memory::Alloc(dev_ctx, sizeof(int)); + int* info_d = reinterpret_cast(info->ptr()); + + for (int i = 0; i < batch_size; ++i) { + double* a_working_ptr = &a[i * a_stride]; + double* tau_working_ptr = &tau[i * tau_stride]; + // compute geqrf + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnDgeqrf( + handle, m, n, a_working_ptr, lda, tau_working_ptr, workspace_ptr, lwork, + info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + memory::Copy(platform::CPUPlace(), &info_h, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), + info_d, sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: CUSolver geqrf is not zero. [%d]", i, info_h)); + } +} + +template <> +void QrGPUKernel::BatchedOrgqr( + const platform::CUDADeviceContext& dev_ctx, int batch_size, int m, int n, + int k, float* a, int lda, float* tau, int a_stride, int tau_stride) const { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSorgqr_bufferSize( + handle, m, n, k, a, lda, tau, &lwork)); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float)); + float* workspace_ptr = reinterpret_cast(workspace->ptr()); + auto info = memory::Alloc(dev_ctx, sizeof(int)); + int* info_d = reinterpret_cast(info->ptr()); + + for (int i = 0; i < batch_size; ++i) { + float* a_working_ptr = &a[i * a_stride]; + float* tau_working_ptr = &tau[i * tau_stride]; + // compute orggr + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSorgqr( + handle, m, n, k, a_working_ptr, lda, tau_working_ptr, workspace_ptr, + lwork, info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + memory::Copy(platform::CPUPlace(), &info_h, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), + info_d, sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: CUSolver QR is not zero. 
[%d]", i, info_h)); + } +} + +template <> +void QrGPUKernel::BatchedOrgqr( + const platform::CUDADeviceContext& dev_ctx, int batch_size, int m, int n, + int k, double* a, int lda, double* tau, int a_stride, + int tau_stride) const { + int lwork = 0; + + auto handle = dev_ctx.cusolver_dn_handle(); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnDorgqr_bufferSize( + handle, m, n, k, a, lda, tau, &lwork)); + auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double)); + double* workspace_ptr = reinterpret_cast(workspace->ptr()); + auto info = memory::Alloc(dev_ctx, sizeof(int)); + int* info_d = reinterpret_cast(info->ptr()); + + for (int i = 0; i < batch_size; ++i) { + double* a_working_ptr = &a[i * a_stride]; + double* tau_working_ptr = &tau[i * tau_stride]; + // compute orggr + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnDorgqr( + handle, m, n, k, a_working_ptr, lda, tau_working_ptr, workspace_ptr, + lwork, info_d)); + // Do we need synchronized here? + // check the error info + int info_h; + memory::Copy(platform::CPUPlace(), &info_h, + BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), + info_d, sizeof(int), dev_ctx.stream()); + PADDLE_ENFORCE_EQ( + info_h, 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: CUSolver QR is not zero. [%d]", i, info_h)); + } +} + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(qr, ops::QrGPUKernel, ops::QrGPUKernel); +REGISTER_OP_CUDA_KERNEL( + qr_grad, ops::QrGradKernel, + ops::QrGradKernel); + +#endif // not PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/qr_op.h b/paddle/fluid/operators/qr_op.h new file mode 100644 index 0000000000000..73ba52f590c0d --- /dev/null +++ b/paddle/fluid/operators/qr_op.h @@ -0,0 +1,135 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; +using DDim = framework::DDim; + +static inline std::tuple _parse_qr_mode(std::string mode) { + bool compute_q; + bool reduced; + if (mode == "reduced") { + compute_q = true; + reduced = true; + } else if (mode == "complete") { + compute_q = true; + reduced = false; + } else if (mode == "r") { + compute_q = false; + reduced = true; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "QR received unrecognized mode '%s'" + " but expected one of 'reduced' (default), 'r', or 'complete'", + mode)); + } + return std::make_tuple(compute_q, reduced); +} + +template +class QrCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + bool compute_q; + bool reduced_mode; + const Tensor& x = *context.Input("X"); + Tensor& q = *context.Output("Q"); + Tensor& r = *context.Output("R"); + std::string mode = context.Attr("mode"); + std::tie(compute_q, reduced_mode) = _parse_qr_mode(mode); + + auto numel = x.numel(); + PADDLE_ENFORCE_GT(numel, 0, platform::errors::PreconditionNotMet( + "The input of QR is empty.")); + auto x_dims = x.dims(); + int x_rank = x_dims.size(); + int m = x_dims[x_rank - 2]; + int n = x_dims[x_rank - 1]; + int min_mn = std::min(m, n); + int k = reduced_mode ? min_mn : m; + int batch_size = numel / (m * n); + int x_stride = m * n; + int q_stride = m * k; + int r_stride = k * n; + + auto* x_data = x.data>(); + T* q_data = nullptr; + if (compute_q) { + q_data = q.mutable_data>( + context.GetPlace(), + size_t(batch_size * m * k * sizeof(math::Real))); + } + auto* r_data = r.mutable_data>( + context.GetPlace(), size_t(batch_size * k * n * sizeof(math::Real))); + + // Implement QR by calling Eigen + for (int i = 0; i < batch_size; ++i) { + const T* x_matrix_ptr = x_data + i * x_stride; + T* r_matrix_ptr = r_data + i * r_stride; + using EigenDynamicMatrix = + Eigen::Matrix; + auto x_matrix = Eigen::Map(x_matrix_ptr, m, n); + Eigen::HouseholderQR qr(x_matrix); + if (reduced_mode) { + auto qr_top_matrix = qr.matrixQR().block(0, 0, min_mn, n); + auto r_matrix_view = + qr_top_matrix.template triangularView(); + auto r_matrix = EigenDynamicMatrix(r_matrix_view); + memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T)); + } else { + auto r_matrix_view = + qr.matrixQR().template triangularView(); + auto r_matrix = EigenDynamicMatrix(r_matrix_view); + memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T)); + } + + if (compute_q) { + T* q_matrix_ptr = q_data + i * q_stride; + if (reduced_mode) { + auto q_matrix = + qr.householderQ() * EigenDynamicMatrix::Identity(m, min_mn); + q_matrix.transposeInPlace(); + memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T)); + } else { + auto q_matrix = + qr.householderQ() * EigenDynamicMatrix::Identity(m, m); + q_matrix.transposeInPlace(); + memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T)); + } + } + } + } +}; + +template +class QrGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + PADDLE_THROW(platform::errors::InvalidArgument( + "QR doesn't have the backward kernel now and will be supported soon.")); + } +}; + +} // namespace operators +} // 
namespace paddle diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h index 9ba7c9a3062a0..6b2584682277e 100644 --- a/paddle/fluid/operators/svd_helper.h +++ b/paddle/fluid/operators/svd_helper.h @@ -502,6 +502,19 @@ struct DeviceIndependenceTensorOperations { return ret; } + framework::Tensor TrilTriu(const framework::Tensor& x, int diagonal, + bool lower) { + framework::AttributeMap attrs; + attrs["diagonal"] = diagonal; + attrs["lower"] = lower; + NameInTensorMap inputs({{"X", {&x}}}); + int x_rank = x.dims().size(); + PADDLE_ENFORCE_GE(x_rank, 2, platform::errors::InvalidArgument( + "Rank must be at least 2.")); + std::vector out_shape = framework::vectorize(x.dims()); + return CreateOpRunAndReturnTensor("tril_triu", inputs, attrs, out_shape); + } + Tensor Conj(const Tensor& x) { Tensor out; auto* out_data = out.mutable_data(x.dims(), context.GetPlace()); diff --git a/paddle/fluid/platform/dynload/cusolver.h b/paddle/fluid/platform/dynload/cusolver.h index a8ce1cc9d3a35..4c018908b5945 100644 --- a/paddle/fluid/platform/dynload/cusolver.h +++ b/paddle/fluid/platform/dynload/cusolver.h @@ -65,11 +65,27 @@ CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP); __macro(cusolverDnSpotrfBatched); \ __macro(cusolverDnDpotrfBatched); \ __macro(cusolverDnSgesvdj_bufferSize); \ + __macro(cusolverDnSgeqrf_bufferSize); \ + __macro(cusolverDnDgeqrf_bufferSize); \ + __macro(cusolverDnCgeqrf_bufferSize); \ + __macro(cusolverDnZgeqrf_bufferSize); \ + __macro(cusolverDnSorgqr_bufferSize); \ + __macro(cusolverDnDorgqr_bufferSize); \ + __macro(cusolverDnCungqr_bufferSize); \ + __macro(cusolverDnZungqr_bufferSize); \ __macro(cusolverDnDestroyGesvdjInfo); \ __macro(cusolverDnCreateGesvdjInfo); \ __macro(cusolverDnDgesvdj_bufferSize); \ __macro(cusolverDnSgesvdj); \ - __macro(cusolverDnDgesvdj); + __macro(cusolverDnDgesvdj); \ + __macro(cusolverDnSgeqrf); \ + __macro(cusolverDnDgeqrf); \ + __macro(cusolverDnCgeqrf); \ + __macro(cusolverDnZgeqrf); \ + __macro(cusolverDnSorgqr); \ + __macro(cusolverDnDorgqr); \ + __macro(cusolverDnCungqr); \ + __macro(cusolverDnZungqr); CUSOLVER_ROUTINE_EACH_R1(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP) #endif diff --git a/python/paddle/fluid/tests/unittests/test_qr_op.py b/python/paddle/fluid/tests/unittests/test_qr_op.py new file mode 100644 index 0000000000000..30cb31f50b1fd --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_qr_op.py @@ -0,0 +1,168 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import itertools +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.fluid.core as core + + +class TestQrAPI(unittest.TestCase): + def test_dygraph(self): + paddle.disable_static() + + def run_qr_dygraph(shape, mode, dtype): + if dtype == "float32": + np_dtype = np.float32 + elif dtype == "float64": + np_dtype = np.float64 + a = np.random.rand(*shape).astype(np_dtype) + m = a.shape[-2] + n = a.shape[-1] + min_mn = min(m, n) + if mode == "reduced" or mode == "r": + k = min_mn + else: + k = m + np_q_shape = list(a.shape[:-2]) + np_q_shape.extend([m, k]) + np_r_shape = list(a.shape[:-2]) + np_r_shape.extend([k, n]) + np_q = np.zeros(np_q_shape).astype(np_dtype) + np_r = np.zeros(np_r_shape).astype(np_dtype) + batch_size = a.size // (a.shape[-1] * a.shape[-2]) + for i in range(batch_size): + coord = np.unravel_index(i, a.shape[:-2]) + if mode == "r": + tmp_r = np.linalg.qr(a[coord], mode=mode) + np_r[coord] = tmp_r + else: + tmp_q, tmp_r = np.linalg.qr(a[coord], mode=mode) + np_q[coord] = tmp_q + np_r[coord] = tmp_r + + x = paddle.to_tensor(a, dtype=dtype) + if mode == "r": + r = paddle.linalg.qr(x, mode=mode) + self.assertTrue(np.allclose(r, np_r, atol=1e-5)) + else: + q, r = paddle.linalg.qr(x, mode=mode) + self.assertTrue(np.allclose(q, np_q, atol=1e-5)) + self.assertTrue(np.allclose(r, np_r, atol=1e-5)) + + tensor_shapes = [ + (3, 5), + (5, 5), + (5, 3), # 2-dim Tensors + (2, 3, 5), + (3, 5, 5), + (4, 5, 3), # 3-dim Tensors + (2, 5, 3, 5), + (3, 5, 5, 5), + (4, 5, 5, 3) # 4-dim Tensors + ] + modes = ["reduced", "complete", "r"] + dtypes = ["float32", "float64"] + for tensor_shape, mode, dtype in itertools.product(tensor_shapes, modes, + dtypes): + run_qr_dygraph(tensor_shape, mode, dtype) + + def test_static(self): + paddle.enable_static() + + def run_qr_static(shape, mode, dtype): + if dtype == "float32": + np_dtype = np.float32 + elif dtype == "float64": + np_dtype = np.float64 + a = np.random.rand(*shape).astype(np_dtype) + m = a.shape[-2] + n = a.shape[-1] + min_mn = min(m, n) + if mode == "reduced" or mode == "r": + k = min_mn + else: + k = m + np_q_shape = list(a.shape[:-2]) + np_q_shape.extend([m, k]) + np_r_shape = list(a.shape[:-2]) + np_r_shape.extend([k, n]) + np_q = np.zeros(np_q_shape).astype(np_dtype) + np_r = np.zeros(np_r_shape).astype(np_dtype) + places = [] + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for place in places: + with fluid.program_guard(fluid.Program(), fluid.Program()): + batch_size = a.size // (a.shape[-1] * a.shape[-2]) + for i in range(batch_size): + coord = np.unravel_index(i, a.shape[:-2]) + if mode == "r": + tmp_r = np.linalg.qr(a[coord], mode=mode) + np_r[coord] = tmp_r + else: + tmp_q, tmp_r = np.linalg.qr(a[coord], mode=mode) + np_q[coord] = tmp_q + np_r[coord] = tmp_r + x = paddle.fluid.data( + name="input", shape=shape, dtype=dtype) + if mode == "r": + r = paddle.linalg.qr(x, mode=mode) + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input": a}, + fetch_list=[r]) + self.assertTrue( + np.allclose( + fetches[0], np_r, atol=1e-5)) + else: + q, r = paddle.linalg.qr(x, mode=mode) + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input": a}, + fetch_list=[q, r]) + self.assertTrue( + np.allclose( + fetches[0], np_q, atol=1e-5)) + self.assertTrue( + np.allclose( + fetches[1], np_r, atol=1e-5)) + 
+ tensor_shapes = [ + (3, 5), + (5, 5), + (5, 3), # 2-dim Tensors + (2, 3, 5), + (3, 5, 5), + (4, 5, 3), # 3-dim Tensors + (2, 5, 3, 5), + (3, 5, 5, 5), + (4, 5, 5, 3) # 4-dim Tensors + ] + modes = ["reduced", "complete", "r"] + dtypes = ["float32", "float64"] + for tensor_shape, mode, dtype in itertools.product(tensor_shapes, modes, + dtypes): + run_qr_static(tensor_shape, mode, dtype) + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py index 726355379e7b6..06b512150cee8 100644 --- a/python/paddle/linalg.py +++ b/python/paddle/linalg.py @@ -23,6 +23,7 @@ from .tensor.linalg import multi_dot # noqa: F401 from .tensor.linalg import matrix_rank from .tensor.linalg import svd +from .tensor.linalg import qr from .tensor.linalg import eigh # noqa: F401 from .tensor.linalg import det from .tensor.linalg import slogdet @@ -38,6 +39,7 @@ 'multi_dot', 'matrix_rank', 'svd', + 'qr', 'matrix_power', 'det', 'slogdet', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index c8f897c21648f..b898b60fe4712 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -47,6 +47,7 @@ from .linalg import mv # noqa: F401 from .linalg import eig # noqa: F401 from .linalg import matrix_power # noqa: F401 +from .linalg import qr # noqa: F401 from .linalg import eigvals # noqa: F401 from .linalg import multi_dot # noqa: F401 from .linalg import svd # noqa: F401 @@ -237,6 +238,7 @@ 'histogram', 'mv', 'matrix_power', + 'qr', 'eigvals', 'abs', 'acos', diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index f112603fbb60f..6853d904adbf6 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1594,6 +1594,70 @@ def matrix_power(x, n, name=None): return out +def qr(x, mode="reduced", name=None): + r""" + Computes the QR decomposition of one matrix or batches of matrice (backward is unsupported now). + + Args: + x (Tensor): The input tensor. Its shape should be `[..., M, N]`, + where ... is zero or more batch dimensions. M and N can be arbitrary + positive number. The data type of x should be float32 or float64. + mode (str, optional): A flag to control the behavior of qr, the default is "reduced". + Suppose x's shape is `[..., M, N]` and denoting `K = min(M, N)`: + If mode = "reduced", qr op will return reduced Q and R matrices, + which means Q's shape is `[..., M, K]` and R's shape is `[..., K, N]`. + If mode = "complete", qr op will return complete Q and R matrices, + which means Q's shape is `[..., M, M]` and R's shape is `[..., M, N]`. + If mode = "r", qr op will only return reduced R matrix, which means + R's shape is `[..., K, N]`. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + If mode = "reduced" or mode = "complete", qr will return a two tensor-tuple, which represents Q and R. + If mode = "r", qr will return a tensor which represents R. + + Examples: + .. code-block:: python + + import paddle + + x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]).astype('float64') + q, r = paddle.linalg.qr(x) + print (q) + print (r) + + # Q = [[-0.16903085, 0.89708523], + # [-0.50709255, 0.27602622], + # [-0.84515425, -0.34503278]]) + + # R = [[-5.91607978, -7.43735744], + # [ 0. 
, 0.82807867]]) + + # one can verify : X = Q * R ; + """ + if in_dygraph_mode(): + q, r = _C_ops.qr(x, 'mode', mode) + if mode == "r": + return r + else: + return q, r + check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'qr') + check_type(mode, 'mode', str, 'qr') + helper = LayerHelper('qr', **locals()) + q = helper.create_variable_for_type_inference(dtype=x.dtype) + r = helper.create_variable_for_type_inference(dtype=x.dtype) + attrs = dict() + attrs['mode'] = mode + helper.append_op( + type='qr', inputs={'X': [x]}, outputs={'Q': q, + 'R': r}, attrs=attrs) + if mode == "r": + return r + else: + return q, r + + def eig(x, name=None): """ This API performs the eigenvalue decomposition of a square matrix or a batch of square matrices. @@ -1674,7 +1738,7 @@ def eigvals(x, name=None): Its data type should be float32, float64, complex64, or complex128. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. - + Returns: Tensor: A tensor containing the unsorted eigenvalues which has the same batch dimensions with `x`. The eigenvalues are complex-valued even when `x` is real.
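
To close out the QR patch, a small usage sketch (not part of the patch) that exercises the new paddle.linalg.qr in dygraph and checks it against NumPy; the R comparison uses absolute values to sidestep the sign ambiguity of QR factors:

import numpy as np
import paddle

x_np = np.random.rand(3, 5, 4).astype('float64')
x = paddle.to_tensor(x_np)

# mode='reduced': Q is [..., M, K] and R is [..., K, N] with K = min(M, N);
# Q @ R reconstructs the input batch-wise.
q, r = paddle.linalg.qr(x, mode='reduced')
np.testing.assert_allclose(paddle.matmul(q, r).numpy(), x_np, atol=1e-5)

# mode='r' returns only R; compare each batch item with NumPy up to sign.
r_only = paddle.linalg.qr(x, mode='r')
for i in range(x_np.shape[0]):
    np.testing.assert_allclose(
        np.abs(r_only.numpy()[i]), np.abs(np.linalg.qr(x_np[i], mode='r')), atol=1e-5)
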