Skip to content

Commit

Permalink
add the check for bfloat16 dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
longranger2 committed Jun 29, 2023
1 parent c562c01 commit 8776cde
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4496,9 +4496,9 @@ def lerp(x, y, weight, name=None):
lerp(x, y, weight) = x + weight * (y - x).
Args:
x (Tensor): An N-D Tensor with starting points, the data type is float16, float32, float64.
y (Tensor): An N-D Tensor with ending points, the data type is float16, float32, float64.
weight (float|Tensor): The weight for the interpolation formula. When weight is Tensor, the data type is float16, float32, float64.
x (Tensor): An N-D Tensor with starting points, the data type is bfloat16, float16, float32, float64.
y (Tensor): An N-D Tensor with ending points, the data type is bfloat16, float16, float32, float64.
weight (float|Tensor): The weight for the interpolation formula. When weight is Tensor, the data type is bfloat16, float16, float32, float64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Expand All @@ -4523,13 +4523,16 @@ def lerp(x, y, weight, name=None):
return _C_ops.lerp(x, y, weight)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'lerp'
x, 'x', ['uint16', 'float16', 'float32', 'float64'], 'lerp'
)
check_variable_and_dtype(
y, 'y', ['float16', 'float32', 'float64'], 'lerp'
y, 'y', ['uint16', 'float16', 'float32', 'float64'], 'lerp'
)
check_variable_and_dtype(
weight, 'weight', ['float16', 'float32', 'float64'], 'lerp'
weight,
'weight',
['uint16', 'float16', 'float32', 'float64'],
'lerp',
)

helper = LayerHelper('lerp', **locals())
Expand Down

0 comments on commit 8776cde

Please sign in to comment.