diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index d5a2286565aa..584b12a6f25f 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -32,6 +32,11 @@ def is_hip(): triton.runtime.driver.active.get_current_target().backend == "hip" +def is_cpu(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "cpu" + + int_dtypes = ['int8', 'int16', 'int32', 'int64'] uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] float_dtypes = ['float16', 'float32', 'float64'] @@ -1590,6 +1595,12 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): if is_hip() and (dtype_z in ("bfloat16", "float8_e4m3fn") or dtype_x == "float8_e4m3fn"): pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.') + # bf16 vector cast is broken in LLVM for large vectors: + # https://github.com/llvm/llvm-project/issues/92471 + # TODO: Remove the change after the bug is fixed. + if is_cpu() and dtype_x == 'bfloat16' and size > 128: + size = 128 + torch.manual_seed(0) # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. if dtype_x.startswith('bfloat'):