From 19b6ed267390ce1c01c94c010028956255bbe457 Mon Sep 17 00:00:00 2001 From: wangxiaoning <71813629+wangxn12138@users.noreply.github.com> Date: Wed, 14 Dec 2022 21:10:48 +0800 Subject: [PATCH] support fp16 index sample (#47897) * add index sample fp16 support * remove fluid APIs in distributed_strategy.py and role_maker.py * Revert "remove fluid APIs in distributed_strategy.py and role_maker.py" This reverts commit 223bbee990d3bf69e252fc3c0f19e3873550a264. * fix instantiated more than once * clean codes --- .../tests/unittests/test_index_sample_op.py | 22 +++++++++++++++++++ python/paddle/tensor/search.py | 2 +- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_index_sample_op.py b/python/paddle/fluid/tests/unittests/test_index_sample_op.py index 84defacc09987..d51474e97990b 100755 --- a/python/paddle/fluid/tests/unittests/test_index_sample_op.py +++ b/python/paddle/fluid/tests/unittests/test_index_sample_op.py @@ -99,6 +99,28 @@ def config(self): self.index_type = "int64" +class TestCase5(TestIndexSampleOp): + def config(self): + """ + For float16 x type + """ + self.x_shape = (10, 128) + self.x_type = "float16" + self.index_shape = (10, 64) + self.index_type = "int32" + + +class TestCase6(TestIndexSampleOp): + def config(self): + """ + For float16 x type + """ + self.x_shape = (10, 128) + self.x_type = "float16" + self.index_shape = (10, 64) + self.index_type = "int64" + + class TestIndexSampleShape(unittest.TestCase): def test_shape(self): paddle.enable_static() diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 2beacff2a081a..76d1ec705eea7 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -792,7 +792,7 @@ def index_sample(x, index): check_variable_and_dtype( x, 'x', - ['float32', 'float64', 'int32', 'int64'], + ['float16', 'float32', 'float64', 'int32', 'int64'], 'paddle.tensor.search.index_sample', ) check_variable_and_dtype(