
[ROCm] Triton in XLA for ROCm - changed ir_emitter_triton_test #11352

Closed

zoranjovanovic-ns wants to merge 1 commit into openxla:main from ROCm:rocm_triton_backend_5

Conversation

@zoranjovanovic-ns (Contributor) commented Apr 9, 2024

Modified the test case to pass on both ROCm and CUDA.
There will be at least one more XLA-related PR to enable Triton usage for ROCm, but only after the PR adding Triton build files for ROCm lands on openxla/triton.

const se::GpuComputeCapability& GpuComputeComp() {
  return device_desc().gpu_compute_capability();
}

enum class Switch : uint32_t {
Member:

Could we just use a bool directly instead of this Switch enum? I don't think we gain more insights by using this type, especially given that the constructors are called False and True. Am I missing something?
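
For context, a minimal sketch contrasting the two options under discussion. The Switch enum is reconstructed from the diff context above; the CudaOrRocmCheck signatures are assumptions made for illustration, not necessarily the helper's actual interface.

// As written in the PR: an enum whose two values mirror a bool.
enum class Switch : uint32_t {
  False,
  True,
};

// Current shape (assumed): pick a per-backend result via the enum.
bool CudaOrRocmCheck(Switch cuda_result, Switch rocm_result);

// The reviewer's suggestion: a plain bool carries the same information.
bool CudaOrRocmCheck(bool cuda_result, bool rocm_result);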

Member:

Upon looking at all the CudaOrRocmCheck calls below, it also seems like this could all be deleted. Could you explain the motivation behind going through these APIs for testing?

Contributor (author):

It seems that the reason for this type of check is gone with the removal of the CUDA-related checks (e.g. cc.IsAtLeast(se::CudaComputeCapability::AMPERE)).
I will modify my changes and update the PR.
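
For reference, a presumed implementation of the helper being removed; this sketch is inferred from the call sites quoted in this thread and is an assumption, not the code as it actually appeared in the PR.

// Presumed shape of CudaOrRocmCheck: return the CUDA-side result on
// CUDA devices and the ROCm-side result on ROCm devices.
bool CudaOrRocmCheck(Switch cuda_result, Switch rocm_result) {
  const se::GpuComputeCapability& cc = GpuComputeComp();
  if (std::holds_alternative<se::RocmComputeCapability>(cc)) {
    return rocm_result == Switch::True;
  }
  return cuda_result == Switch::True;
}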

}

TEST_F(TritonGemmTestAny, DoAddConstantToScalarAndBroadcastThat) {
  if (CudaOrRocmCheck(Switch::False, Switch::True)) {
Member:

Could this just be if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {?
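
Concretely, the suggested pattern reads the skip condition directly off the capability variant. A minimal sketch within the existing test file, assuming se::GpuComputeCapability is a std::variant over the CUDA and ROCm capability types (which the holds_alternative suggestion implies):

#include <variant>

TEST_F(TritonGemmTestAny, DoAddConstantToScalarAndBroadcastThat) {
  // Skip on ROCm by inspecting the variant directly; no helper enum needed.
  if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
    GTEST_SKIP() << "Not supported on ROCm yet.";  // illustrative message
  }
  // ... unchanged test body ...
}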

}

TEST_F(TritonGemmTest, SingleElementTileIsHandled) {
  if (CudaOrRocmCheck(Switch::False, Switch::True)) {
Member:

Can this just be if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {?


TEST_F(TritonGemmTest, SingleElementTileIsHandled) {
  if (CudaOrRocmCheck(Switch::False, Switch::True)) {
    GTEST_SKIP() << "Not using autotuner on ROCM yet..";
Member:

.. -> .

}

TEST_F(TritonGemmTestWithoutTritonGemmAny, SkipF32F32) {
  if (CudaOrRocmCheck(Switch::False, Switch::True)) {
Member:

Can this just be if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {?

}

TEST_F(TritonGemmTestWithoutTritonGemmAny, SkipU8) {
  if (CudaOrRocmCheck(Switch::False, Switch::True)) {
Member:

Can this just be if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {?

}

TEST_F(TritonGemmTest, FailIfTooMuchShmem) {
  if (CudaOrRocmCheck(Switch::False, Switch::True)) {
Member:

Can this just be if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {?

      GpuComputeComp());
}
bool SkipBF16Tests() {
  return CudaOrRocmCheck(
@bchetioui (Member) commented Apr 11, 2024:

Can the check logic just be inlined? Overall, it does seem like we have a lot of CudaOrRocmCheck implementations that don't really need to exist; we can just implement SkipBF16Tests directly and check for alternatives in the tests, I believe?
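
A sketch of what the inlined version could look like. The overall shape follows the reviewer's suggestion; has_bf16_dtype_support() is assumed here as the ROCm-side query, and the Ampere check mirrors the cc.IsAtLeast(se::CudaComputeCapability::AMPERE) call quoted earlier in the thread.

bool SkipBF16Tests() {
  const se::GpuComputeCapability& cc = GpuComputeComp();
  // ROCm: skip unless the architecture supports bf16 (assumed query).
  if (auto* rocm = std::get_if<se::RocmComputeCapability>(&cc)) {
    return !rocm->has_bf16_dtype_support();
  }
  // CUDA: bf16 requires Ampere or newer.
  return !std::get<se::CudaComputeCapability>(cc).IsAtLeast(
      se::CudaComputeCapability::AMPERE);
}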

@zoranjovanovic-ns (Contributor, author) commented:

@xla-rotation: modified according to the review comments.

@bchetioui (Member) left a comment:

Thank you!

}

TEST_F(TritonGemmLevel2Test, DoubleBroadcastOfScalarConstantIsHandled) {
  if (SkipBF16Tests()) {
Member:

Note: I think that many of these tests could be rewritten to not use bf16, thereby allowing the removal of this filter and more coverage on ROCm. Be that as it may, this is fine for now.

const se::GpuComputeCapability& GpuComputeComp() {
  return device_desc().gpu_compute_capability();
}
bool SkipBF16Tests() {
Member:

nit: there should be one empty line between definitions. I'll make the edit on my end before submitting; just wanted to highlight it.

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 15, 2024

[ROCm] Triton in XLA for ROCm - changed ir_emitter_triton_test

Imported from GitHub PR openxla/xla#11352

Modified the test case to pass on both ROCm and CUDA.
There will be at least one more XLA-related PR to enable Triton usage for ROCm, but only after the PR adding Triton build files for ROCm lands on openxla/triton.
Copybara import of the project:

--
dd80a6f44295642efd3ae6af6ffd3e2a3302d36e by Zoran Jovanovic <zjovanov@amd.com>:

[ROCm] Triton in XLA for ROCm - changed ir_emitter_triton_test to pass on rocm and cuda

Merging this change closes #11352

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11352 from ROCm:rocm_triton_backend_5 dd80a6f44295642efd3ae6af6ffd3e2a3302d36e
PiperOrigin-RevId: 625033187
copybara-service bot pushed a commit that referenced this pull request Apr 15, 2024

[ROCm] Triton in XLA for ROCm - changed ir_emitter_triton_test

Imported from GitHub PR #11352

Modified the test case to pass on both ROCm and CUDA.
There will be at least one more XLA-related PR to enable Triton usage for ROCm, but only after the PR adding Triton build files for ROCm lands on openxla/triton.
Copybara import of the project:

--
dd80a6f by Zoran Jovanovic <zjovanov@amd.com>:

[ROCm] Triton in XLA for ROCm - changed ir_emitter_triton_test to pass on rocm and cuda

Merging this change closes #11352

FUTURE_COPYBARA_INTEGRATE_REVIEW=#11352 from ROCm:rocm_triton_backend_5 dd80a6f
PiperOrigin-RevId: 625043989
copybara-service bot pushed a commit that referenced this pull request Apr 16, 2024
FUTURE_COPYBARA_INTEGRATE_REVIEW=#11352 from ROCm:rocm_triton_backend_5 dd80a6f
PiperOrigin-RevId: 625365616
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 16, 2024

[ROCm] Triton in XLA for ROCm - changed ir_emitter_triton_test

Imported from GitHub PR openxla/xla#11352

Modified the test case to pass on both ROCm and CUDA.
There will be at least one more XLA-related PR to enable Triton usage for ROCm, but only after the PR adding Triton build files for ROCm lands on openxla/triton.
Copybara import of the project:

--
dd80a6f44295642efd3ae6af6ffd3e2a3302d36e by Zoran Jovanovic <zjovanov@amd.com>:

[ROCm] Triton in XLA for ROCm - changed ir_emitter_triton_test to pass on rocm and cuda

Merging this change closes #11352

PiperOrigin-RevId: 625361694
copybara-service bot pushed a commit that referenced this pull request Apr 16, 2024
FUTURE_COPYBARA_INTEGRATE_REVIEW=#11352 from ROCm:rocm_triton_backend_5 dd80a6f
PiperOrigin-RevId: 625367225
@alekstheod alekstheod deleted the rocm_triton_backend_5 branch July 3, 2025 08:11
@alekstheod alekstheod restored the rocm_triton_backend_5 branch July 15, 2025 08:23