Skip to content

Commit

Permalink
support fp16 index sample (PaddlePaddle#47897)
Browse files Browse the repository at this point in the history
* add index sample fp16 support

* remove fluid APIs in distributed_strategy.py and role_maker.py

* Revert "remove fluid APIs in distributed_strategy.py and role_maker.py"

This reverts commit 223bbee.

* fix instantiated more than once

* clean codes
  • Loading branch information
wangxn12138 authored Dec 14, 2022
1 parent 0148968 commit 19b6ed2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
22 changes: 22 additions & 0 deletions python/paddle/fluid/tests/unittests/test_index_sample_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,28 @@ def config(self):
self.index_type = "int64"


class TestCase5(TestIndexSampleOp):
def config(self):
"""
For float16 x type
"""
self.x_shape = (10, 128)
self.x_type = "float16"
self.index_shape = (10, 64)
self.index_type = "int32"


class TestCase6(TestIndexSampleOp):
def config(self):
"""
For float16 x type
"""
self.x_shape = (10, 128)
self.x_type = "float16"
self.index_shape = (10, 64)
self.index_type = "int64"


class TestIndexSampleShape(unittest.TestCase):
def test_shape(self):
paddle.enable_static()
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/tensor/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ def index_sample(x, index):
check_variable_and_dtype(
x,
'x',
['float32', 'float64', 'int32', 'int64'],
['float16', 'float32', 'float64', 'int32', 'int64'],
'paddle.tensor.search.index_sample',
)
check_variable_and_dtype(
Expand Down

0 comments on commit 19b6ed2

Please sign in to comment.