
Commit 3a110c9

guangyey authored and pytorchmergebot committed
Add a new API torch.xpu.is_tf32_supported for Intel GPU (pytorch#163141)
# Motivation

Aligned with other backends, this PR introduces a new API `torch.xpu.is_tf32_supported`, which should be called before setting `torch.backends.mkldnn.allow_tf32 = True`, or used to provide hardware capability information to Triton.

# Additional Context

On Intel Xe architecture and newer, TF32 operations can be accelerated through DPAS (Dot Product Accumulate Systolic) instructions. Therefore, TF32 support can be determined by checking whether the device supports subgroup matrix multiply-accumulate operations.

Pull Request resolved: pytorch#163141
Approved by: https://github.com/EikanWang
1 parent 5dbca58 commit 3a110c9
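
As a minimal usage sketch (not part of the diff, assuming an XPU-enabled PyTorch build), the new API can gate the TF32 toggle mentioned in the motivation:

    import torch

    # Enable TF32 for oneDNN-backed XPU ops only when the hardware can
    # accelerate it. is_tf32_supported() returns False when no XPU device
    # is available, so this check is safe on any machine.
    if torch.xpu.is_tf32_supported():
        torch.backends.mkldnn.allow_tf32 = True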

File tree: 3 files changed (+17, −0 lines)

docs/source/xpu.md — 1 addition & 0 deletions

@@ -28,6 +28,7 @@
     is_available
     is_bf16_supported
     is_initialized
+    is_tf32_supported
     set_device
     set_stream
     stream

test/test_xpu.py — 4 additions & 0 deletions

@@ -776,6 +776,10 @@ def test_is_bf16_supported(self):
             torch.xpu.is_available(),
         )

+    def test_is_tf32_supported(self):
+        if not torch.xpu.is_available():
+            self.assertFalse(torch.xpu.is_tf32_supported())
+
     def test_get_arch_list(self):
         if not torch.xpu._is_compiled():
             self.assertEqual(len(torch.xpu.get_arch_list()), 0)
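
The new test case can typically be run on its own; the exact invocation below is an assumption based on PyTorch's standard unittest-style harness, which accepts a `-k` name filter:

    python test/test_xpu.py -k test_is_tf32_supported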

torch/xpu/__init__.py — 12 additions & 0 deletions

@@ -78,6 +78,17 @@ def is_bf16_supported(including_emulation: bool = True) -> bool:
     )


+def is_tf32_supported() -> bool:
+    r"""Return a bool indicating if the current XPU device supports dtype tf32."""
+    if not is_available():
+        return False
+    # On Intel Xe architecture and newer, TF32 operations can be accelerated
+    # through DPAS (Dot Product Accumulate Systolic) instructions. Therefore,
+    # TF32 support can be determined by checking whether the device supports
+    # subgroup matrix multiply-accumulate operations.
+    return torch.xpu.get_device_properties().has_subgroup_matrix_multiply_accumulate
+
+
 def is_initialized():
     r"""Return whether PyTorch's XPU state has been initialized."""
     return _initialized and not _is_in_bad_fork()

@@ -559,6 +570,7 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int:
     "is_available",
     "is_bf16_supported",
     "is_initialized",
+    "is_tf32_supported",
     "manual_seed",
     "manual_seed_all",
     "max_memory_allocated",
