Skip to content

Commit 853aa90

Browse files
armandouvOrbax Authors
authored andcommitted
Add mapping in tf_data_preprocessor from StableHLO bfloat16 to tf.bfloat16
PiperOrigin-RevId: 817290258
1 parent a0c26cb commit 853aa90

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

export/orbax/export/data_processors/tf_data_processor_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,23 @@ def test_bfloat16_convert_error(self):
179179
),
180180
)
181181

182+
def test_prepare_with_shlo_bf16_inputs(self):
183+
processor = tf_data_processor.TfDataProcessor(lambda x: x)
184+
processor.prepare(
185+
'identity',
186+
input_signature=(
187+
obm.ShloTensorSpec(shape=(1,), dtype=obm.ShloDType.bf16),
188+
),
189+
)
190+
self.assertEqual(
191+
processor.concrete_function.structured_input_signature[0][0].dtype,
192+
tf.bfloat16,
193+
)
194+
self.assertEqual(
195+
processor.input_signature[0][0],
196+
obm.ShloTensorSpec(shape=(1,), dtype=obm.ShloDType.bf16),
197+
)
198+
182199

183200
if __name__ == '__main__':
184201
googletest.main()

0 commit comments

Comments
 (0)