From c57d1e916ec16e138aad1054db9d75f3b25b5c6d Mon Sep 17 00:00:00 2001 From: Liu-xiandong <85323580+Liu-xiandong@users.noreply.github.com> Date: Mon, 25 Oct 2021 11:04:31 +0800 Subject: [PATCH] Add nn.functional.sparse_attention and some test cases, test=develop (#35757) (#36551) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add paddle.nn.functional.sparse_attention API 本个PR主要将sparse_attention功能在python层进行了一层封装,OP的主体代码见:#PR35676 此外,对于封装的python 接口,增加了相应的单测。 --- paddle/fluid/operators/CMakeLists.txt | 2 +- .../fluid/tests/unittests/CMakeLists.txt | 5 + .../unittests/test_sparse_attention_op.py | 151 +++++++++++++++--- python/paddle/nn/functional/__init__.py | 3 + .../paddle/nn/functional/sparse_attention.py | 144 +++++++++++++++++ 5 files changed, 285 insertions(+), 20 deletions(-) create mode 100644 python/paddle/nn/functional/sparse_attention.py diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index c487313f91c58..b910b4ec73901 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -94,7 +94,7 @@ if (WITH_GPU OR WITH_ROCM) endif() op_library(sync_batch_norm_op) file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n") - if ((NOT WIN32) AND (NOT WITH_ROCM) AND (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 11.2) ) + if ((NOT WIN32) AND (NOT WITH_ROCM) AND (NOT PADDLE_WITH_ARM) AND (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 11.2) ) op_library(sparse_attention_op) file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sparse_attention);\n") endif() diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 2bc62c28fd210..49a23890bbfd0 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -456,6 +456,11 @@ list(REMOVE_ITEM TEST_OPS test_imperative_static_runner_while) # disable this unittest temporarily list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_exception) +# disable sparse_attention which not in suitable env +if ( (NOT WITH_GPU) OR (WIN32) OR (PADDLE_WITH_ARM) OR (WITH_ROCM) ) + list(REMOVE_ITEM TEST_OPS test_sparse_attention_op) +endif() + if (APPLE OR WIN32) list(REMOVE_ITEM TEST_OPS test_dataset) list(REMOVE_ITEM TEST_OPS test_dataset_dataloader) diff --git a/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py b/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py index 48401fb55ef3f..5134b885f3307 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py @@ -16,10 +16,13 @@ import numpy as np from op_test import OpTest import paddle.fluid.core as core +from paddle.static import Program, program_guard import paddle +import paddle.fluid as fluid +import paddle.fluid.framework as framework +import paddle.nn.functional as F import os import re -import platform def get_cuda_version(): @@ -34,22 +37,6 @@ def get_cuda_version(): return -1 -def get_linux_platform(): - if platform.system().lower() == 'windows': - return 0 - elif platform.system().lower() == 'linux': - return 1 - else: - return -1 - - -def get_suitable_env(): - if get_cuda_version() >= 11020 and get_linux_platform() == 1: - return True - else: - return False - - def softmax(x): max = np.max(x, axis=1, keepdims=True) e_x = np.exp(x - max) @@ -141,8 +128,9 @@ def init_csr_format(batch_size, num_heads, rows, blocksize): @unittest.skipIf( - not core.is_compiled_with_cuda() or get_suitable_env() == False, - "core is not compiled with CUDA and cuda version need >= 11.2 in windows") + not core.is_compiled_with_cuda() or get_cuda_version() < 11020, + "core is not compiled with CUDA and cuda version need larger than or equal to 11.2" +) class TestSparseAttentionOp(OpTest): def config(self): self.shape = (1, 1, 16, 8) @@ -201,5 +189,130 @@ def config(self): self.dtype = "float64" +@unittest.skipIf( + not core.is_compiled_with_cuda() or get_cuda_version() < 11020, + "core is not compiled with CUDA and cuda version need larger than or equal to 11.2" +) +class TestSparseAttentionAPI(unittest.TestCase): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (1, 1, 8, 4) + self.blocksize = 2 + self.dtype = 'float64' + + def test_static_graph(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + Q = paddle.static.data(name="Q", shape=self.shape, dtype=self.dtype) + K = paddle.static.data(name="K", shape=self.shape, dtype=self.dtype) + V = paddle.static.data(name="V", shape=self.shape, dtype=self.dtype) + + batch_size, num_heads, rows = self.shape[0], self.shape[ + 1], self.shape[2] + block_num = rows / self.blocksize + block_last = rows % self.blocksize + sparse_nnz_num = block_num * self.blocksize * self.blocksize + block_last * block_last + offset_shape = (batch_size, num_heads, rows + 1) + columns_shape = (batch_size, num_heads, int(sparse_nnz_num)) + + offset = paddle.static.data( + name="Offset", shape=offset_shape, dtype="int32") + columns = paddle.static.data( + name="Columns", shape=columns_shape, dtype="int32") + Out = F.sparse_attention(Q, K, V, offset, columns) + + Q_np = np.random.random(self.shape).astype(self.dtype) + K_np = np.random.random(self.shape).astype(self.dtype) + V_np = np.random.random(self.shape).astype(self.dtype) + offset_np, columns_np = init_csr_format( + self.shape[0], self.shape[1], self.shape[2], self.blocksize) + offset_np = offset_np.astype('int32') + columns_np = columns_np.astype('int32') + + exe = fluid.Executor(self.place) + fetches_result = exe.run(feed={ + "Q": Q_np, + "K": K_np, + "V": V_np, + "Offset": offset_np, + "Columns": columns_np + }, + fetch_list=[Out]) + expected_result, __, __ = ref_batch_sparse_attention( + Q_np, K_np, V_np, offset_np, columns_np) + + self.assertTrue( + np.allclose( + fetches_result, expected_result, atol=1e-5)) + + def test_dygraph(self): + paddle.disable_static() + offset, columns = init_csr_format(self.shape[0], self.shape[1], + self.shape[2], self.blocksize) + offset = offset.astype('int32') + columns = columns.astype('int32') + query = np.random.random(self.shape).astype(self.dtype) + key = np.random.random(self.shape).astype(self.dtype) + value = np.random.random(self.shape).astype(self.dtype) + + paddle_query = paddle.to_tensor(query, place=self.place) + paddle_key = paddle.to_tensor(key, place=self.place) + paddle_value = paddle.to_tensor(value, place=self.place) + paddle_offset = paddle.to_tensor(offset, place=self.place) + paddle_colunmns = paddle.to_tensor(columns, place=self.place) + + paddle_result = F.sparse_attention(paddle_query, paddle_key, + paddle_value, paddle_offset, + paddle_colunmns) + + numpy_result, __, __ = ref_batch_sparse_attention(query, key, value, + offset, columns) + numpy_result = numpy_result.astype(self.dtype) + + self.assertTrue( + np.allclose( + paddle_result.numpy(), numpy_result, atol=1e-5)) + + +class TestSparseAttentionAPITestFloat(TestSparseAttentionAPI): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (2, 2, 8, 4) + self.blocksize = 2 + self.dtype = 'float32' + + +class TestSparseAttentionAPITestShape1(TestSparseAttentionAPI): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (2, 2, 64, 32) + self.blocksize = 2 + self.dtype = 'float64' + + +class TestSparseAttentionAPITestShape2(TestSparseAttentionAPI): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (2, 1, 64, 32) + self.blocksize = 2 + self.dtype = 'float64' + + +class TestSparseAttentionAPITestShape3(TestSparseAttentionAPI): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (4, 4, 128, 32) + self.blocksize = 8 + self.dtype = 'float64' + + +class TestSparseAttentionAPITestShape4(TestSparseAttentionAPI): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (3, 3, 35, 15) + self.blocksize = 3 + self.dtype = 'float64' + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 7965b362b9c55..4151f25b94aff 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -112,6 +112,8 @@ from ...fluid.layers import gather_tree # noqa: F401 from ...fluid.layers import temporal_shift # noqa: F401 +from .sparse_attention import sparse_attention + __all__ = [ #noqa 'conv1d', 'conv1d_transpose', @@ -207,4 +209,5 @@ 'layer_norm', 'instance_norm', 'class_center_sample', + 'sparse_attention', ] diff --git a/python/paddle/nn/functional/sparse_attention.py b/python/paddle/nn/functional/sparse_attention.py new file mode 100644 index 0000000000000..f57669f11457f --- /dev/null +++ b/python/paddle/nn/functional/sparse_attention.py @@ -0,0 +1,144 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +import paddle +from ...fluid.framework import in_dygraph_mode, default_main_program +from paddle.fluid.layer_helper import LayerHelper +from ...fluid.framework import in_dygraph_mode +from paddle import _C_ops + + +def sparse_attention(query, + key, + value, + sparse_csr_offset, + sparse_csr_columns, + name=None): + r""" + This operator sparsify the Attention matrix in Transformer module + to achieve the effect of reducing memory consumption and computation. + The sparse layout is expressed in CSR format and contains two parameters, + ``offset`` and ``columns``. + + .. math:: + + result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V + + where : ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module. + The dimensions of the three parameters are the same. + ``d`` represents the size of the last dimension of the three parameters. + + Parameters: + query(Tensor): The query tensor in the Attention module. + It's a 4-D tensor with a shape of + :math:`[batch\_size, num\_heads, seq\_len, head\_dim]`. + The dtype can be ``float32`` and ``float64``. + key(Tensor): The key tensor in the Attention module. + It's a 4-D tensor with a shape of + :math:`[batch\_size, num\_heads, seq\_len, head\_dim]`. + The dtype can be ``float32`` and ``float64``. + value(Tensor): The value tensor in the Attention module. + It's a 4-D tensor with a shape of + :math:`[batch\_size, num\_heads, seq\_len, head\_dim]`. + The dtype can be ``float32`` and ``float64``. + sparse_csr_offset(Tensor): The sparsity feature in the Attention module + is expressed in the CSR format, and the offset represents + the number of non-zero elements in each row of the matrix. + It's a 3-D tensor with a shape of + :math:`[batch\_size, num\_heads, seq\_len + 1]`. + The dtype should be ``int32``. + sparse_csr_columns(Tensor): The sparsity feature in the Attention module + is expressed in the CSR format, and the columns represent + the column index values of non-zero elements in the matrix. + It's a 3-D tensor with a shape of + :math:`[batch\_size, num\_heads, sparse\_nnz]`. + The dtype should be ``int32``. + name(str, optional): The default value is None. Normally there is no need for user + to set this property. For more information, please refer to + :ref:`api_guide_Name`. + + Returns: + A Tensor which refers to the result in the Attention module. + It's a 4-D tensor with a shape of + :math:`[batch\_size, num\_heads, seq\_len, head\_dim]`. + The dtype can be ``float32`` and ``float64``. + + Examples: + .. code-block:: python + + # required: skiptest + import paddle + import numpy as np + + query_data = np.array([[[[0, 1,], [2, 3], + [ 0, 1], [2, 3]]]]).astype("float32") + key_data = np.array([[[[0, 1,], [2, 3], + [ 0, 1], [2, 3]]]]).astype("float32") + value_data = np.array([[[[0, 1,], [2, 3], + [ 0, 1], [2, 3]]]]).astype("float32") + sparse_csr_offset_data = np.array([[[0, 2, + 4, 6, 8]]]).astype("int32") + sparse_csr_columns_data = np.array([[[0, 1, + 0, 1, 2, 3, 2, 3]]]).astype("int32") + print(query_data.shape) + # (1, 1, 4, 2) + print(sparse_csr_offset_data.shape) + # (1, 1, 5) + print(sparse_csr_columns_data.shape) + # (1, 1, 8) + paddle.disable_static() + query = paddle.to_tensor(query_data, stop_gradient=False, + place=paddle.CUDAPlace(0)) + key = paddle.to_tensor(key_data, stop_gradient=False, + place=paddle.CUDAPlace(0)) + value = paddle.to_tensor(value_data, stop_gradient=False, + place=paddle.CUDAPlace(0)) + offset = paddle.to_tensor(sparse_csr_offset_data, stop_gradient=False, + place=paddle.CUDAPlace(0)) + columns = paddle.to_tensor(sparse_csr_columns_data, stop_gradient=False, + place=paddle.CUDAPlace(0)) + output = paddle.nn.functional.sparse_attention(query, key, + value, offset, columns) + print(output) + + # [[[[1.60885942, 2.60885954], + # [1.99830270, 2.99830270], + # [1.60885942, 2.60885954], + # [1.99830270, 2.99830270]]]] + """ + if in_dygraph_mode(): + result_attention, result_sdd, result_softmax = _C_ops.sparse_attention( + query, key, value, sparse_csr_offset, sparse_csr_columns) + return result_attention + + helper = LayerHelper('sparse_attention', **locals()) + dtype = helper.input_dtype(input_param_name='Q') + out = helper.create_variable_for_type_inference(dtype) + result_sdd = helper.create_variable_for_type_inference(dtype) + result_softmax = helper.create_variable_for_type_inference(dtype) + inputs = { + 'Q': query, + 'K': key, + 'V': value, + 'Offset': sparse_csr_offset, + 'Columns': sparse_csr_columns + } + outputs = { + 'Out': out, + 'SparseDotSdd': result_sdd, + 'Softmax': result_softmax + } + helper.append_op(type='sparse_attention', inputs=inputs, outputs=outputs) + return out