diff --git a/projects/hipblaslt/CMakeLists.txt b/projects/hipblaslt/CMakeLists.txt index 9148d862dd1..11b8fdfe5c3 100644 --- a/projects/hipblaslt/CMakeLists.txt +++ b/projects/hipblaslt/CMakeLists.txt @@ -88,6 +88,10 @@ if(HIPBLASLT_ENABLE_HOST) option(HIPBLASLT_ENABLE_HIPBLAS_DIRECT "Use the hipblas header directly." OFF) endif() +# mxDataGenerator is used by hipblaslt clients and tensilelite client; default OFF on Windows +cmake_dependent_option(HIPBLASLT_ENABLE_MXDATAGENERATOR "Use mxDataGenerator for MX format data generation (clients/tests)." ON "NOT WIN32" OFF) +message(STATUS "Enable mxDataGenerator for MX format data generation: ${HIPBLASLT_ENABLE_MXDATAGENERATOR}") + set(CMAKE_SKIP_BUILD_RPATH FALSE CACHE BOOL "Skip build RPATH") set(CMAKE_BUILD_WITH_INSTALL_RPATH FALSE CACHE BOOL "Build with install RPATH") set(CMAKE_INSTALL_RPATH "$ORIGIN/../lib:$ORIGIN/../llvm/lib" CACHE STRING "Install RPATH") @@ -255,6 +259,48 @@ if(NOT ROCM_LIBS_SUPERBUILD) endif() endif() +# INTERFACE library that owns the public hipblaslt API headers (the in-tree +# `library/include` subtree plus its build-tree counterpart, where +# `hipblaslt-export.h` and `hipblaslt-version.h` are generated). Routing the +# include directories through a target lets consumers pick them up via the +# build graph rather than via `target_include_directories( BEFORE +# PRIVATE .../library/include ...)`. `hip::host` is exposed as an INTERFACE +# link because the in-tree headers `#include `. +add_library(hipblaslt-headers INTERFACE) +add_library(hipblaslt::headers ALIAS hipblaslt-headers) +target_include_directories(hipblaslt-headers + INTERFACE + $ + $ +) +target_link_libraries(hipblaslt-headers INTERFACE hip::host) + +# `hipblaslt-version.h` is consumed by `library/include/hipblaslt/hipblaslt.h` +# (only included when HIPBLASLT_ENABLE_HOST is ON). When HOST is ON, the +# canonical `configure_file` call in `library/include/CMakeLists.txt` +# regenerates the same file from the same template; this generation is +# unconditional only because the file is cheap to produce. +configure_file( + "${CMAKE_CURRENT_SOURCE_DIR}/library/include/hipblaslt-version.h.in" + "${CMAKE_CURRENT_BINARY_DIR}/library/include/hipblaslt/hipblaslt-version.h" +) + +# `hipblaslt::mxdatagen` lives in clients/common (next to its source), and is +# only built when at least one consumer needs it. tensilelite/client and +# tensilelite/tests reach for it on the same condition, so we drive the +# subdir from here to keep the dependency wiring close to the gate that +# decides whether to build it at all. +set(_hipblaslt_mxdatagen_enabled FALSE) +if(HIPBLASLT_ENABLE_MXDATAGENERATOR + AND (HIPBLASLT_ENABLE_CLIENT + OR TENSILELITE_ENABLE_CLIENT + OR TENSILELITE_BUILD_TESTING)) + set(_hipblaslt_mxdatagen_enabled TRUE) +endif() +if(_hipblaslt_mxdatagen_enabled) + add_subdirectory(clients/common) +endif() + add_subdirectory(tensilelite) if(HIPBLASLT_ENABLE_HOST) diff --git a/projects/hipblaslt/clients/CMakeLists.txt b/projects/hipblaslt/clients/CMakeLists.txt index e314af282b7..071a560e4d7 100755 --- a/projects/hipblaslt/clients/CMakeLists.txt +++ b/projects/hipblaslt/clients/CMakeLists.txt @@ -29,19 +29,19 @@ target_link_libraries(hipblaslt-clients-common tensilelite::tensilelite-host ) +if(HIPBLASLT_ENABLE_MXDATAGENERATOR) + # PUBLIC because public headers such as testing_matmul.hpp guard + # `` on `HIPBLASLT_ENABLE_MXDATAGENERATOR`, so downstream + # targets (hipblaslt-test, hipblaslt-bench, ...) need to see the macro, + # the include path, and the C++20 floor all the way through. + target_link_libraries(hipblaslt-clients-common PUBLIC hipblaslt::mxdatagen) +endif() + if(HIPBLASLT_ENABLE_ROCROLLER) - if(NOT ROCM_LIBS_SUPERBUILD) - if(HIPBLASLT_ENABLE_THEROCK) - find_package(mxDataGenerator REQUIRED) - else() - add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/../../../shared/mxdatagenerator" "${CMAKE_CURRENT_BINARY_DIR}/mxdatagenerator") - endif() - endif() target_compile_definitions(hipblaslt-clients-common PRIVATE HIPBLASLT_USE_ROCROLLER) - target_link_libraries(hipblaslt-clients-common PRIVATE roc::mxDataGenerator) - target_compile_features(hipblaslt-clients-common PRIVATE cxx_std_20) endif() + if(HIPBLASLT_ENABLE_ASAN) hipblaslt_target_configure_sanitizers(hipblaslt-clients-common PUBLIC) endif() @@ -59,7 +59,13 @@ if(HIPBLASLT_ENABLE_OPENMP) target_link_libraries(hipblaslt-bench PRIVATE OpenMP::OpenMP_CXX) endif() -add_subdirectory(common) +# `clients/common/CMakeLists.txt` itself only declares the mxdatagen helper +# (which the project root already added when needed); the `include` and +# `src` subdirectories populate hipblaslt-clients-common's sources, so they +# are added directly here, avoiding a second `add_subdirectory(common)` that +# would clash with the root-level one. +add_subdirectory(common/include) +add_subdirectory(common/src) add_subdirectory(bench) add_executable(hipblaslt-api-overhead "${CMAKE_CURRENT_SOURCE_DIR}/bench/src/client_api_overhead.cpp") diff --git a/projects/hipblaslt/clients/bench/src/client.cpp b/projects/hipblaslt/clients/bench/src/client.cpp index 1a8bfcf3504..8ca7d4b9c91 100644 --- a/projects/hipblaslt/clients/bench/src/client.cpp +++ b/projects/hipblaslt/clients/bench/src/client.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -969,16 +969,16 @@ try if(isBlockScaling(arg.scaleA)) { if(arg.a_type != HIP_R_8F_E4M3 && arg.a_type != HIP_R_8F_E5M2 - && arg.a_type != HIP_R_4F_E2M1_EXT && arg.a_type != HIP_R_6F_E2M3_EXT - && arg.a_type != HIP_R_6F_E3M2_EXT) + && arg.a_type != HIP_R_4F_E2M1 && arg.a_type != HIP_R_6F_E2M3 + && arg.a_type != HIP_R_6F_E3M2) throw std::invalid_argument("Invalid a_type for block scaling format: "s + hip_datatype_to_string(arg.a_type)); } if(isBlockScaling(arg.scaleB)) { if(arg.b_type != HIP_R_8F_E4M3 && arg.b_type != HIP_R_8F_E5M2 - && arg.b_type != HIP_R_4F_E2M1_EXT && arg.b_type != HIP_R_6F_E2M3_EXT - && arg.b_type != HIP_R_6F_E3M2_EXT) + && arg.b_type != HIP_R_4F_E2M1 && arg.b_type != HIP_R_6F_E2M3 + && arg.b_type != HIP_R_6F_E3M2) throw std::invalid_argument("Invalid b_type for block scaling format: "s + hip_datatype_to_string(arg.b_type)); } diff --git a/projects/hipblaslt/clients/common/CMakeLists.txt b/projects/hipblaslt/clients/common/CMakeLists.txt index 5953fcd289f..5b764a06918 100755 --- a/projects/hipblaslt/clients/common/CMakeLists.txt +++ b/projects/hipblaslt/clients/common/CMakeLists.txt @@ -1,2 +1,78 @@ -add_subdirectory(include) -add_subdirectory(src) +# Copyright Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# Owns the `hipblaslt::mxdatagen` STATIC helper. The source and header live +# under this directory, so the target is defined here rather than at the root. +# +# This file is added by the project root via `add_subdirectory(clients/common)` +# only when: +# * HIPBLASLT_ENABLE_MXDATAGENERATOR is ON, and +# * at least one consumer subtree is enabled (HIPBLASLT_ENABLE_CLIENT, +# TENSILELITE_ENABLE_CLIENT, or TENSILELITE_BUILD_TESTING). +# +# Consumers should `target_link_libraries( PUBLIC hipblaslt::mxdatagen)` +# (PRIVATE if no further consumer needs the macro). Linking propagates: +# * the include path for `` +# * the link to `roc::mxDataGenerator` +# * the C++20 requirement +# * the `HIPBLASLT_ENABLE_MXDATAGENERATOR` macro +# * `hipblaslt::headers`, which carries the in-tree `` API +# headers (the source pulls in `` and +# ``). + +if(NOT ROCM_LIBS_SUPERBUILD) + if(HIPBLASLT_ENABLE_THEROCK) + find_package(mxDataGenerator REQUIRED) + else() + # `${PROJECT_SOURCE_DIR}/../../shared/mxdatagenerator` resolves to the + # sibling `shared/mxdatagenerator` subtree of the project. The binary + # dir is anchored under the hipblaslt project's build tree to match + # the upstream subdirectory's existing layout assumptions. + add_subdirectory( + "${PROJECT_SOURCE_DIR}/../../shared/mxdatagenerator" + "${PROJECT_BINARY_DIR}/mxdatagenerator" + ) + endif() +endif() + +add_library(hipblaslt-mxdatagen STATIC + "${CMAKE_CURRENT_SOURCE_DIR}/src/mxDataGen.cpp" +) +add_library(hipblaslt::mxdatagen ALIAS hipblaslt-mxdatagen) + +target_include_directories(hipblaslt-mxdatagen + PUBLIC + $ +) + +target_link_libraries(hipblaslt-mxdatagen + PUBLIC + hipblaslt::headers + roc::mxDataGenerator + PRIVATE + # `-x hip` (via hip::device's INTERFACE_COMPILE_OPTIONS) makes the + # translation unit compile through hipcc/clang. mxDataGen.cpp itself + # does not launch kernels, but the modern ROCm bf16 header it pulls + # in transitively uses clang-only builtins (e.g. __builtin_elementwise_rint). + hip::device +) + +target_compile_features(hipblaslt-mxdatagen PUBLIC cxx_std_20) +target_compile_definitions(hipblaslt-mxdatagen PUBLIC HIPBLASLT_ENABLE_MXDATAGENERATOR) +set_target_properties(hipblaslt-mxdatagen + PROPERTIES POSITION_INDEPENDENT_CODE ON +) + +# `` is generated by +# `library/include/CMakeLists.txt`'s `generate_export_header(hipblaslt ...)` +# call when HIPBLASLT_ENABLE_HOST=ON. In a tensilelite-only build that target +# does not exist, so produce the same file from this helper instead. Both +# code paths emit the file to the same location so `hipblaslt::headers`'s +# `${CMAKE_BINARY_DIR}/library/include` interface dir resolves it either way. +if(NOT HIPBLASLT_ENABLE_HOST) + include(GenerateExportHeader) + generate_export_header(hipblaslt-mxdatagen + BASE_NAME hipblaslt + EXPORT_FILE_NAME "${PROJECT_BINARY_DIR}/library/include/hipblaslt/hipblaslt-export.h" + ) +endif() diff --git a/projects/hipblaslt/clients/common/include/datatype_interface.hpp b/projects/hipblaslt/clients/common/include/datatype_interface.hpp index d5366ffcafe..e3ceb4156ce 100644 --- a/projects/hipblaslt/clients/common/include/datatype_interface.hpp +++ b/projects/hipblaslt/clients/common/include/datatype_interface.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -86,8 +86,8 @@ inline std::size_t realDataTypeSize(hipDataType dtype) { // These types were not defined in older versions of ROCm, so need to be handled specially here. auto const dtype_int = static_cast(dtype); - if(dtype_int == HIP_R_4F_E2M1_EXT || dtype_int == HIP_R_6F_E2M3_EXT - || dtype_int == HIP_R_6F_E3M2_EXT) + if(dtype_int == HIP_R_4F_E2M1 || dtype_int == HIP_R_6F_E2M3 + || dtype_int == HIP_R_6F_E3M2) { return 1; } diff --git a/projects/hipblaslt/clients/common/include/hipblaslt_init.hpp b/projects/hipblaslt/clients/common/include/hipblaslt_init.hpp index fa28dcec47c..42b31314039 100644 --- a/projects/hipblaslt/clients/common/include/hipblaslt_init.hpp +++ b/projects/hipblaslt/clients/common/include/hipblaslt_init.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -153,13 +153,13 @@ inline void hipblaslt_init(void* A, hipblaslt_init( static_cast(A), M, N, lda, stride, batch_count); break; - case HIP_R_6F_E2M3_EXT: + case HIP_R_6F_E2M3: hipblaslt_cerr << "hipblaslt_init not supports FP6" << std::endl; break; - case HIP_R_6F_E3M2_EXT: + case HIP_R_6F_E3M2: hipblaslt_cerr << "hipblaslt_init not supports BF6" << std::endl; break; - case HIP_R_4F_E2M1_EXT: + case HIP_R_4F_E2M1: hipblaslt_cerr << "hipblaslt_init not supports FP4" << std::endl; break; default: @@ -263,13 +263,13 @@ inline void hipblaslt_init_sin(void* A, hipblaslt_init_sin( static_cast(A), M, N, lda, stride, batch_count); break; - case HIP_R_6F_E2M3_EXT: + case HIP_R_6F_E2M3: hipblaslt_cerr << "hipblaslt_init_sin not supports FP6" << std::endl; break; - case HIP_R_6F_E3M2_EXT: + case HIP_R_6F_E3M2: hipblaslt_cerr << "hipblaslt_init_sin not supports BF6" << std::endl; break; - case HIP_R_4F_E2M1_EXT: + case HIP_R_4F_E2M1: hipblaslt_cerr << "hipblaslt_init_sin not supports FP4" << std::endl; break; default: @@ -353,13 +353,13 @@ inline void hipblaslt_init_alternating_sign(void* A, hipblaslt_init_alternating_sign( static_cast(A), M, N, lda, stride, batch_count); break; - case HIP_R_6F_E2M3_EXT: + case HIP_R_6F_E2M3: hipblaslt_cerr << "hipblaslt_init_alternating_sign not supports FP6" << std::endl; break; - case HIP_R_6F_E3M2_EXT: + case HIP_R_6F_E3M2: hipblaslt_cerr << "hipblaslt_init_alternating_sign not supports BF6" << std::endl; break; - case HIP_R_4F_E2M1_EXT: + case HIP_R_4F_E2M1: hipblaslt_cerr << "hipblaslt_init_alternating_sign not supports FP4" << std::endl; break; default: @@ -440,13 +440,13 @@ inline void hipblaslt_init_hpl_alternating_sign(void* A, hipblaslt_init_hpl_alternating_sign( static_cast(A), M, N, lda, stride, batch_count); break; - case HIP_R_6F_E2M3_EXT: + case HIP_R_6F_E2M3: hipblaslt_cerr << "hipblaslt_init_hpl_alternating_sign not supports FP6" << std::endl; break; - case HIP_R_6F_E3M2_EXT: + case HIP_R_6F_E3M2: hipblaslt_cerr << "hipblaslt_init_hpl_alternating_sign not supports BF6" << std::endl; break; - case HIP_R_4F_E2M1_EXT: + case HIP_R_4F_E2M1: hipblaslt_cerr << "hipblaslt_init_hpl_alternating_sign not supports FP4" << std::endl; break; default: @@ -521,13 +521,13 @@ inline void hipblaslt_init_cos(void* A, hipblaslt_init_cos( static_cast(A), M, N, lda, stride, batch_count); break; - case HIP_R_6F_E2M3_EXT: + case HIP_R_6F_E2M3: hipblaslt_cerr << "hipblaslt_init_cos not supports FP6" << std::endl; break; - case HIP_R_6F_E3M2_EXT: + case HIP_R_6F_E3M2: hipblaslt_cerr << "hipblaslt_init_cos not supports BF6" << std::endl; break; - case HIP_R_4F_E2M1_EXT: + case HIP_R_4F_E2M1: hipblaslt_cerr << "hipblaslt_init_cos not supports FP4" << std::endl; break; default: @@ -608,13 +608,13 @@ inline void hipblaslt_init_hpl(void* A, hipblaslt_init_hpl( static_cast(A), M, N, lda, stride, batch_count); break; - case HIP_R_6F_E2M3_EXT: + case HIP_R_6F_E2M3: hipblaslt_cerr << "hipblaslt_init_hpl not supports FP6" << std::endl; break; - case HIP_R_6F_E3M2_EXT: + case HIP_R_6F_E3M2: hipblaslt_cerr << "hipblaslt_init_hpl not supports BF6" << std::endl; break; - case HIP_R_4F_E2M1_EXT: + case HIP_R_4F_E2M1: hipblaslt_cerr << "hipblaslt_init_hpl not supports FP4" << std::endl; break; default: @@ -678,13 +678,13 @@ inline void hipblaslt_init_nan(void* A, size_t N, hipDataType type) case HIP_R_8I: hipblaslt_init_nan(static_cast(A), N); break; - case HIP_R_6F_E2M3_EXT: + case HIP_R_6F_E2M3: hipblaslt_cerr << "hipblaslt_init_nan not supports FP6" << std::endl; break; - case HIP_R_6F_E3M2_EXT: + case HIP_R_6F_E3M2: hipblaslt_cerr << "hipblaslt_init_nan not supports BF6" << std::endl; break; - case HIP_R_4F_E2M1_EXT: + case HIP_R_4F_E2M1: hipblaslt_cerr << "hipblaslt_init_nan not supports FP4" << std::endl; break; default: @@ -733,13 +733,13 @@ inline void hipblaslt_init_nan(void* A, size_t start_offset, size_t end_offset, case HIP_R_8I: hipblaslt_init_nan(static_cast(A), start_offset, end_offset); break; - case HIP_R_6F_E2M3_EXT: + case HIP_R_6F_E2M3: hipblaslt_cerr << "hipblaslt_init_nan not supports FP6" << std::endl; break; - case HIP_R_6F_E3M2_EXT: + case HIP_R_6F_E3M2: hipblaslt_cerr << "hipblaslt_init_nan not supports BF6" << std::endl; break; - case HIP_R_4F_E2M1_EXT: + case HIP_R_4F_E2M1: hipblaslt_cerr << "hipblaslt_init_nan not supports FP4" << std::endl; break; default: diff --git a/projects/hipblaslt/clients/common/include/mxDataGen.hpp b/projects/hipblaslt/clients/common/include/mxDataGen.hpp index 3df5c6103ab..a766c77df87 100644 --- a/projects/hipblaslt/clients/common/include/mxDataGen.hpp +++ b/projects/hipblaslt/clients/common/include/mxDataGen.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2025 Advanced Micro Devices, Inc. + * Copyright (C) 2025-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -26,13 +26,21 @@ #pragma once -#include +// All content below is gated so the header is harmless to include when the +// feature is disabled (the entire translation unit collapses to nothing). +// Callers that actually invoke `generateMXInput` must guard their use sites +// on the same macro. +#if HIPBLASLT_ENABLE_MXDATAGENERATOR + +#include +#include +#include +#include #include #include #include -#ifdef HIPBLASLT_USE_ROCROLLER std::vector generateMXInput(hipDataType dataType, hipDataType scaleType, void* data, @@ -49,4 +57,5 @@ std::vector generateMXInput(hipDataType dataType, std::string_view const initMethod = "Bounded", float min_val = -1.0f, float max_val = 1.0f); + #endif diff --git a/projects/hipblaslt/clients/common/include/norm.hpp b/projects/hipblaslt/clients/common/include/norm.hpp index 23687ff935d..65e9329e632 100644 --- a/projects/hipblaslt/clients/common/include/norm.hpp +++ b/projects/hipblaslt/clients/common/include/norm.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -584,9 +584,9 @@ bool norm_check(double norm_error, hipDataType type) case HIP_R_8I: return norm_error < 0.01; // TODO: find a suitable rnom value for f6 and f4 - case HIP_R_6F_E2M3_EXT: - case HIP_R_6F_E3M2_EXT: - case HIP_R_4F_E2M1_EXT: + case HIP_R_6F_E2M3: + case HIP_R_6F_E3M2: + case HIP_R_4F_E2M1: return norm_error < 0.5; default: return false; diff --git a/projects/hipblaslt/clients/common/include/testing_auxiliary.hpp b/projects/hipblaslt/clients/common/include/testing_auxiliary.hpp index bc8ce8e9149..b1dd656de9b 100644 --- a/projects/hipblaslt/clients/common/include/testing_auxiliary.hpp +++ b/projects/hipblaslt/clients/common/include/testing_auxiliary.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -1730,9 +1730,9 @@ void testing_aux_auxiliary_func(const Arguments& arg) ASSERT_TRUE(hip_datatype_to_string(HIP_R_8F_E5M2_FNUZ) == "bf8_fnuz_r"); ASSERT_TRUE(hip_datatype_to_string(HIP_R_8F_E4M3) == "f8_r"); ASSERT_TRUE(hip_datatype_to_string(HIP_R_8F_E5M2) == "bf8_r"); - ASSERT_TRUE(hip_datatype_to_string(static_cast(HIP_R_6F_E2M3_EXT)) == "f6_r"); - ASSERT_TRUE(hip_datatype_to_string(static_cast(HIP_R_6F_E3M2_EXT)) == "bf6_r"); - ASSERT_TRUE(hip_datatype_to_string(static_cast(HIP_R_4F_E2M1_EXT)) == "f4_r"); + ASSERT_TRUE(hip_datatype_to_string(static_cast(HIP_R_6F_E2M3)) == "f6_r"); + ASSERT_TRUE(hip_datatype_to_string(static_cast(HIP_R_6F_E3M2)) == "bf6_r"); + ASSERT_TRUE(hip_datatype_to_string(static_cast(HIP_R_4F_E2M1)) == "f4_r"); // Test hipblas_computetype_to_string hipblas_computetype_to_string(HIPBLAS_COMPUTE_16F); @@ -1756,9 +1756,9 @@ void testing_aux_auxiliary_func(const Arguments& arg) ASSERT_TRUE(string_to_hip_datatype("f16_r") == HIP_R_16F); ASSERT_TRUE(string_to_hip_datatype("bf16_r") == HIP_R_16BF); ASSERT_TRUE(string_to_hip_datatype("i8_r") == HIP_R_8I); - ASSERT_TRUE(string_to_hip_datatype("f6_r") == static_cast(HIP_R_6F_E2M3_EXT)); - ASSERT_TRUE(string_to_hip_datatype("bf6_r") == static_cast(HIP_R_6F_E3M2_EXT)); - ASSERT_TRUE(string_to_hip_datatype("f4_r") == static_cast(HIP_R_4F_E2M1_EXT)); + ASSERT_TRUE(string_to_hip_datatype("f6_r") == static_cast(HIP_R_6F_E2M3)); + ASSERT_TRUE(string_to_hip_datatype("bf6_r") == static_cast(HIP_R_6F_E3M2)); + ASSERT_TRUE(string_to_hip_datatype("f4_r") == static_cast(HIP_R_4F_E2M1)); ASSERT_TRUE(string_to_hip_datatype("i32_r") == HIP_R_32I); ASSERT_TRUE(string_to_hip_datatype("") == HIPBLASLT_DATATYPE_INVALID); @@ -2332,11 +2332,11 @@ void testing_aux_rocblaslt_utility_func(const Arguments& arg) ASSERT_TRUE(std::string_view{hipDataType_to_string(HIP_R_8F_E4M3)} == "R_8F_E4M3"); ASSERT_TRUE(std::string_view{hipDataType_to_string(HIP_R_8F_E5M2)} == "R_8F_E5M2"); ASSERT_TRUE(std::string_view{hipDataType_to_string(HIP_R_8I)} == "R_8I"); - ASSERT_TRUE(std::string_view{hipDataType_to_string(static_cast(HIP_R_6F_E2M3_EXT))} + ASSERT_TRUE(std::string_view{hipDataType_to_string(static_cast(HIP_R_6F_E2M3))} == "R_6F_E2M3"); - ASSERT_TRUE(std::string_view{hipDataType_to_string(static_cast(HIP_R_6F_E3M2_EXT))} + ASSERT_TRUE(std::string_view{hipDataType_to_string(static_cast(HIP_R_6F_E3M2))} == "R_6F_E3M2"); - ASSERT_TRUE(std::string_view{hipDataType_to_string(static_cast(HIP_R_4F_E2M1_EXT))} + ASSERT_TRUE(std::string_view{hipDataType_to_string(static_cast(HIP_R_4F_E2M1))} == "R_4F_E2M1"); ASSERT_TRUE(std::string_view{hipDataType_to_string(static_cast(999))} == "Invalid"); @@ -2353,16 +2353,16 @@ void testing_aux_rocblaslt_utility_func(const Arguments& arg) ASSERT_TRUE(std::string_view{hipDataType_to_bench_string(HIP_R_8F_E5M2)} == "bf8_r"); ASSERT_TRUE(std::string_view{hipDataType_to_bench_string(HIP_R_8F_E5M2_FNUZ)} == "bf8_r"); ASSERT_TRUE( - std::string_view{hipDataType_to_bench_string(static_cast(HIP_R_6F_E2M3_EXT))} + std::string_view{hipDataType_to_bench_string(static_cast(HIP_R_6F_E2M3))} == "f6_r"); ASSERT_TRUE( - std::string_view{hipDataType_to_bench_string(static_cast(HIP_R_6F_E3M2_EXT))} + std::string_view{hipDataType_to_bench_string(static_cast(HIP_R_6F_E3M2))} == "bf6_r"); ASSERT_TRUE( - std::string_view{hipDataType_to_bench_string(static_cast(HIP_R_4F_E2M1_EXT))} + std::string_view{hipDataType_to_bench_string(static_cast(HIP_R_4F_E2M1))} == "f4_r"); ASSERT_TRUE(std::string_view{ - hipDataType_to_bench_string(static_cast(HIP_R_4F_E2M1_EXT + 1))} + hipDataType_to_bench_string(static_cast(HIP_R_4F_E2M1 + 1))} == "invalid"); // Test rocblaslt_compute_type_to_string diff --git a/projects/hipblaslt/clients/common/include/testing_matmul.hpp b/projects/hipblaslt/clients/common/include/testing_matmul.hpp index 4ca734ae9a8..8526952ade3 100644 --- a/projects/hipblaslt/clients/common/include/testing_matmul.hpp +++ b/projects/hipblaslt/clients/common/include/testing_matmul.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -38,7 +38,9 @@ #include "hipblaslt_random.hpp" #include "hipblaslt_test.hpp" #include "hipblaslt_vector.hpp" +#if HIPBLASLT_ENABLE_MXDATAGENERATOR #include "mxDataGen.hpp" +#endif #include "near.hpp" #include "norm.hpp" #include "unit.hpp" @@ -84,7 +86,7 @@ bool isSwizzleSupported(hipDataType datatype) case HIP_R_16BF: case HIP_R_16F: case HIP_R_8F_E4M3_FNUZ: - case HIP_R_4F_E2M1_EXT: + case HIP_R_4F_E2M1: return true; default: return false; @@ -100,7 +102,7 @@ hipblasLtOrder_t orderForDatatype(hipDataType datatype) return HIPBLASLT_ORDER_COL16_4R8; case HIP_R_8F_E4M3_FNUZ: return HIPBLASLT_ORDER_COL16_4R16; - case HIP_R_4F_E2M1_EXT: + case HIP_R_4F_E2M1: return HIPBLASLT_ORDER_COL16_4R32; default: throw std::runtime_error("unsupported datatype in orderForDatatype"); @@ -141,7 +143,7 @@ void calculateKforSwizzling( MiK = 32; MiKv = 8; break; - case HIP_R_4F_E2M1_EXT: + case HIP_R_4F_E2M1: // For fp4 viewed as uint8: matches shuffle_weight with layout=(16,16) // BK=32 bytes, K=16 bytes, BK/K=2 MiK = 16; // K inner block = 16 bytes @@ -243,7 +245,7 @@ std::vector mx_type_to_f32(hipDataType type, throw std::runtime_error("Error type in mx_type_to_f32()"); } #if defined(HIPBLASLT_USE_FP6) - case HIP_R_6F_E2M3_EXT: + case HIP_R_6F_E2M3: switch(stype) { case HIP_R_8F_UE8M0: @@ -255,7 +257,7 @@ std::vector mx_type_to_f32(hipDataType type, } #endif #if defined(HIPBLASLT_USE_BF6) - case HIP_R_6F_E3M2_EXT: + case HIP_R_6F_E3M2: switch(stype) { case HIP_R_8F_UE8M0: @@ -267,7 +269,7 @@ std::vector mx_type_to_f32(hipDataType type, } #endif #if defined(HIPBLASLT_USE_FP4) - case HIP_R_4F_E2M1_EXT: + case HIP_R_4F_E2M1: switch(stype) { case HIP_R_8F_UE8M0: @@ -396,7 +398,7 @@ void swizzle_tensor_type(HipHostBuffer& dst, swizzle_tensor( dst.as(), src.as(), datatype, arg, b, m_n, k, ld, colMaj); return; - case HIP_R_4F_E2M1_EXT: + case HIP_R_4F_E2M1: // fp4: 2 elements per byte, so k_bytes = k/2, ld_bytes = ld/2 swizzle_tensor( dst.as(), src.as(), datatype, arg, b, m_n, k / 2, ld / 2, colMaj); @@ -466,13 +468,13 @@ Tout cast_from_type(void* in, hipDataType type, size_t index) return static_cast((static_cast(in))[index]); case HIP_R_8I: return static_cast((static_cast(in))[index]); - case HIP_R_6F_E2M3_EXT: + case HIP_R_6F_E2M3: hipblaslt_cerr << "cast_from_type() does not support FP6" << std::endl; return 0; - case HIP_R_6F_E3M2_EXT: + case HIP_R_6F_E3M2: hipblaslt_cerr << "cast_from_type() does not support BF6" << std::endl; return 0; - case HIP_R_4F_E2M1_EXT: + case HIP_R_4F_E2M1: hipblaslt_cerr << "cast_from_type() does not support FP4" << std::endl; return 0; default: @@ -516,13 +518,13 @@ void saturate_cast_to_type(void* dst, Tin src, hipDataType typeD, size_t indexD) case HIP_R_8I: static_cast(dst)[indexD] = saturate_cast(src); return; - case HIP_R_6F_E2M3_EXT: + case HIP_R_6F_E2M3: hipblaslt_cerr << "cast_from_type() does not support FP6!" << std::endl; return; - case HIP_R_6F_E3M2_EXT: + case HIP_R_6F_E3M2: hipblaslt_cerr << "cast_from_type() does not support BF6!" << std::endl; return; - case HIP_R_4F_E2M1_EXT: + case HIP_R_4F_E2M1: hipblaslt_cerr << "cast_from_type() does not support FP4!" << std::endl; return; default: @@ -1369,9 +1371,9 @@ hipDataType derive_unset_bias_type(const Arguments& arg) else //more default cases once support C != D real_bias_type = HIP_R_16F; } - else if((arg.a_type == HIP_R_6F_E2M3_EXT && arg.b_type == HIP_R_6F_E2M3_EXT) - || (arg.a_type == HIP_R_6F_E3M2_EXT && arg.b_type == HIP_R_6F_E3M2_EXT) - || (arg.a_type == HIP_R_4F_E2M1_EXT && arg.b_type == HIP_R_4F_E2M1_EXT)) + else if((arg.a_type == HIP_R_6F_E2M3 && arg.b_type == HIP_R_6F_E2M3) + || (arg.a_type == HIP_R_6F_E3M2 && arg.b_type == HIP_R_6F_E3M2) + || (arg.a_type == HIP_R_4F_E2M1 && arg.b_type == HIP_R_4F_E2M1)) { if(arg.d_type == HIP_R_32F || arg.d_type == HIP_R_16BF) real_bias_type = HIP_R_16BF; @@ -1429,9 +1431,9 @@ std::tuple derive_unset_compute_input_type(const Argum HIP_R_8F_E5M2, HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E5M2_FNUZ, - static_cast(HIP_R_6F_E2M3_EXT), - static_cast(HIP_R_6F_E3M2_EXT), - static_cast(HIP_R_4F_E2M1_EXT), + static_cast(HIP_R_6F_E2M3), + static_cast(HIP_R_6F_E3M2), + static_cast(HIP_R_4F_E2M1), }; hipDataType real_compute_input_typeA = arg.compute_input_typeA; @@ -2481,6 +2483,9 @@ void testing_matmul_with_bias(const Arguments& arg, size_t scaleA_row = ((transA == HIPBLAS_OP_T) ? blockSize(arg.scaleA) : 1); size_t scaleA_col = ((transA == HIPBLAS_OP_T) ? 1 : blockSize(arg.scaleA)); + // TODO: mxDataGenerator can only generate data on CPU. Using + // GPU to generate data might be more efficient and avoid + // unnecessary hipMemCpy when CPU verification is not needed. if(isBlockScaling(arg.scaleA)) { #ifdef HIPBLASLT_USE_ROCROLLER @@ -2500,10 +2505,6 @@ void testing_matmul_with_bias(const Arguments& arg, } // For MX format, use mxDataGenerator to generate input data // (consists of data part and scale part) - // TODO: mxDataGenerator can only generate data on CPU. Using - // GPU to generate data might be more efficient and avoid - // unnecessary hipMemCpy when CPU verification is not needed. - // preTile for A: {tileK, tileM} - swap from preTileSizeForScaleA which returns {tileM, tileK} auto preTileATmp = preTileSizeForScaleA(arg.scaleA); auto preTileA = (preTileATmp.size() == 2) @@ -2606,9 +2607,6 @@ void testing_matmul_with_bias(const Arguments& arg, } // For MX format, use mxDataGenerator to generate // input data (consists of data part and scale part) - // TODO: mxDataGenerator can only generate data on CPU. Using - // GPU to generate data might be more efficient and avoid - // unnecessary hipMemCpy when CPU verification is not needed. // preTile for B: {tileK, tileN} auto preTileB = preTileSizeForScaleB(arg.scaleB); refB.emplace_back(generateMXInput(TiB, @@ -2728,34 +2726,7 @@ void testing_matmul_with_bias(const Arguments& arg, do_swizzle_b, stream)); CHECK_HIP_ERROR(synchronize(hC[i], dC[i], 0, 0, 0, 0, 1, false, stream)); -#ifndef HIPBLASLT_USE_ROCROLLER - if(isBlockScaling(arg.scaleA)) - { - CHECK_HIP_ERROR( - synchronize(hScaleA[i], dScaleA[i], 0, 0, 0, 0, 1, false, stream)); - refA.emplace_back(mx_type_to_f32(TiA, - scaleDataType(arg.scaleA), - hA[i], - hScaleA[i], - A_row[i], - A_col[i], - scaleA_row, - scaleA_col)); - } - if(isBlockScaling(arg.scaleB)) - { - CHECK_HIP_ERROR( - synchronize(hScaleB[i], dScaleB[i], 0, 0, 0, 0, 1, false, stream)); - refB.emplace_back(mx_type_to_f32(TiB, - scaleDataType(arg.scaleB), - hB[i], - hScaleB[i], - B_row[i], - B_col[i], - scaleB_row, - scaleB_col)); - } -#endif + if(arg.dump_matrix) { diff --git a/projects/hipblaslt/clients/common/src/CMakeLists.txt b/projects/hipblaslt/clients/common/src/CMakeLists.txt index 93ba1d972c7..ffad7f3b1bf 100755 --- a/projects/hipblaslt/clients/common/src/CMakeLists.txt +++ b/projects/hipblaslt/clients/common/src/CMakeLists.txt @@ -14,10 +14,6 @@ target_sources(hipblaslt-clients-common "${CMAKE_CURRENT_SOURCE_DIR}/hipblaslt_init_device.cpp" ) -if(HIPBLASLT_ENABLE_ROCROLLER) - target_sources(hipblaslt-clients-common PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/mxDataGen.cpp") -endif() - if(HIPBLASLT_ENABLE_BLIS) target_sources(hipblaslt-clients-common PRIVATE diff --git a/projects/hipblaslt/clients/common/src/cblas_interface.cpp b/projects/hipblaslt/clients/common/src/cblas_interface.cpp index 48a5248b30c..a04d291a1bd 100644 --- a/projects/hipblaslt/clients/common/src/cblas_interface.cpp +++ b/projects/hipblaslt/clients/common/src/cblas_interface.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -520,7 +520,7 @@ void cast_mul(customVector& dst, size); break; #if defined(HIPBLASLT_USE_FP4) - case static_cast(HIP_R_4F_E2M1_EXT): + case static_cast(HIP_R_4F_E2M1): cast_mul(dst, static_cast(src), isScaleAVec, @@ -533,7 +533,7 @@ void cast_mul(customVector& dst, break; #endif #if defined(HIPBLASLT_USE_FP6) - case static_cast(HIP_R_6F_E2M3_EXT): + case static_cast(HIP_R_6F_E2M3): cast_mul(dst, static_cast(src), isScaleAVec, @@ -546,7 +546,7 @@ void cast_mul(customVector& dst, break; #endif #if defined(HIPBLASLT_USE_BF6) - case static_cast(HIP_R_6F_E3M2_EXT): + case static_cast(HIP_R_6F_E3M2): cast_mul(dst, static_cast(src), isScaleAVec, diff --git a/projects/hipblaslt/clients/common/src/hipblaslt_init_device.cpp b/projects/hipblaslt/clients/common/src/hipblaslt_init_device.cpp index f328c79392a..ca998a0a138 100644 --- a/projects/hipblaslt/clients/common/src/hipblaslt_init_device.cpp +++ b/projects/hipblaslt/clients/common/src/hipblaslt_init_device.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2024-2025 Advanced Micro Devices, Inc. + * Copyright (C) 2024-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -343,11 +343,15 @@ template __device__ T trig_float(size_t idx, size_t M, size_t N, size_t lda, size_t stride, Func func) { + // M_PI is not part of ISO C/C++ and is not defined by on Windows (MSVC CRT / clang-cl) + // unless _USE_MATH_DEFINES is set before the first / include. Use a local + // constexpr instead so this translation unit builds on all platforms without macro gymnastics. + constexpr double two_pi = 6.28318530717958647692528676655900576; auto calc = [&](size_t k) { auto b = k / stride; auto j = (k - b * stride) / lda; auto i = (k - b * stride) - j * lda; - return fmod(double(i + j * M + b * M * N), 2 * M_PI); + return fmod(double(i + j * M + b * M * N), two_pi); }; #if defined(HIPBLASLT_USE_FP4) @@ -482,8 +486,8 @@ void hipblaslt_init_device(ABC_dims abc, || std::is_same_v || std::is_same_v) { - hipblaslt_cerr << "No support nan for HIP_R_4F_E2M1_EXT and HIP_R_6F_E2M3_EXT and " - "HIP_R_6F_E3M2_EXT in hipblaslt_init_device" + hipblaslt_cerr << "No support nan for HIP_R_4F_E2M1 and HIP_R_6F_E2M3 and " + "HIP_R_6F_E3M2 in hipblaslt_init_device" << std::endl; } else @@ -755,19 +759,19 @@ void hipblaslt_init_device(ABC_dims abc, abc, init, is_nan, static_cast(A), M, N, lda, stride, batch_count); break; #if defined(HIPBLASLT_USE_FP6) - case static_cast(HIP_R_6F_E2M3_EXT): + case static_cast(HIP_R_6F_E2M3): hipblaslt_init_device( abc, init, is_nan, static_cast(A), M, N, lda, stride, batch_count); break; #endif #if defined(HIPBLASLT_USE_BF6) - case static_cast(HIP_R_6F_E3M2_EXT): + case static_cast(HIP_R_6F_E3M2): hipblaslt_init_device( abc, init, is_nan, static_cast(A), M, N, lda, stride, batch_count); break; #endif #if defined(HIPBLASLT_USE_FP4) - case static_cast(HIP_R_4F_E2M1_EXT): + case static_cast(HIP_R_4F_E2M1): hipblaslt_init_device( abc, init, is_nan, static_cast(A), M, N, lda, stride, batch_count); break; diff --git a/projects/hipblaslt/clients/common/src/mxDataGen.cpp b/projects/hipblaslt/clients/common/src/mxDataGen.cpp index d9ff0c6bd67..fede1dbd701 100644 --- a/projects/hipblaslt/clients/common/src/mxDataGen.cpp +++ b/projects/hipblaslt/clients/common/src/mxDataGen.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2025 Advanced Micro Devices, Inc. + * Copyright (C) 2025-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -27,7 +27,6 @@ #include "mxDataGen.hpp" #include #include -#include #include #include #include @@ -253,7 +252,6 @@ std::vector generateData(T dgen, std::vector scaleBytes = dgen.getScaleBytes(); -#ifdef HIPBLASLT_USE_ROCROLLER // Apply pre-swizzle to scale data size_t scaleRows = sizes[0] / elementsPerMXBlock; size_t scaleCols = sizes[1]; @@ -263,7 +261,6 @@ std::vector generateData(T dgen, scaleBytes = DGen::preSwizzleScalesGFX950(scaleBytes, {scaleCols, scaleRows}); } -#endif std::memcpy(scale, scaleBytes.data(), scaleBytes.size() * sizeof(uint8_t)); @@ -313,7 +310,6 @@ std::vector generateData(T dgen, } } -#ifdef HIPBLASLT_USE_ROCROLLER /** * @brief Generate random data for OCP (MX) F8/F6/F4 types * @@ -412,7 +408,7 @@ std::vector generateMXInput(hipDataType dataType, preSwizzleTile, preTile); } - else if(static_cast(dataType) == HIP_R_6F_E2M3_EXT) + else if(static_cast(dataType) == HIP_R_6F_E2M3) { DGen::DataGenerator dgen; return generateData(dgen, @@ -428,7 +424,7 @@ std::vector generateMXInput(hipDataType dataType, preSwizzleTile, preTile); } - else if(static_cast(dataType) == HIP_R_6F_E3M2_EXT) + else if(static_cast(dataType) == HIP_R_6F_E3M2) { DGen::DataGenerator dgen; return generateData(dgen, @@ -444,7 +440,7 @@ std::vector generateMXInput(hipDataType dataType, preSwizzleTile, preTile); } - else if(static_cast(dataType) == HIP_R_4F_E2M1_EXT) + else if(static_cast(dataType) == HIP_R_4F_E2M1) { if(scaleType == HIP_R_8F_E4M3) { @@ -500,4 +496,3 @@ std::vector generateMXInput(hipDataType dataType, throw std::runtime_error("Unsupported data types in MX data generation!"); } } -#endif diff --git a/projects/hipblaslt/clients/tests/data/matmul_gtest.yaml b/projects/hipblaslt/clients/tests/data/matmul_gtest.yaml index 8b9bf80c617..0bba4cfe96c 100755 --- a/projects/hipblaslt/clients/tests/data/matmul_gtest.yaml +++ b/projects/hipblaslt/clients/tests/data/matmul_gtest.yaml @@ -2681,4 +2681,23 @@ Tests: beta: [ 0.0, 1.0 ] gpu_arch: '90a' +# This is for testing MX FP4 kernel using Tensile +- name: matmul_tensile_fp4 + category: quick + function: + matmul: + - { a_type: f4_r, b_type: f4_r, c_type: f32_r, d_type: f32_r, compute_type: c_f32_r, scaleA: 3, scaleB: 3, scale_type: f32_r} + M: [2048] + N: [2048] + K: [4096] + transA: T + transB: N + alpha: 1.0 + beta: 0.0 + initialization: hpl + unit_check: 0 + norm_check: 1 + requested_solution_num: 1 + gpu_arch: '950' + ... diff --git a/projects/hipblaslt/library/include/hipblaslt/hipblaslt-types.h b/projects/hipblaslt/library/include/hipblaslt/hipblaslt-types.h index b1c7f5b0796..d17651bcf6e 100644 --- a/projects/hipblaslt/library/include/hipblaslt/hipblaslt-types.h +++ b/projects/hipblaslt/library/include/hipblaslt/hipblaslt-types.h @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -75,8 +75,5 @@ typedef int32_t hipblasLtInt32; } #endif -int const HIP_R_6F_E2M3_EXT = 31; -int const HIP_R_6F_E3M2_EXT = 32; -int const HIP_R_4F_E2M1_EXT = 33; int const HIP_R_8F_E5M3_EXT = 34; #endif /* _HIPBLASLT_TYPES_H_ */ diff --git a/projects/hipblaslt/library/src/amd_detail/include/auxiliary.hpp b/projects/hipblaslt/library/src/amd_detail/include/auxiliary.hpp index ab290e57899..a07b283d237 100644 --- a/projects/hipblaslt/library/src/amd_detail/include/auxiliary.hpp +++ b/projects/hipblaslt/library/src/amd_detail/include/auxiliary.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -137,11 +137,11 @@ constexpr const char* hip_datatype_to_string(hipDataType type) return "bf8_r"; case HIP_R_8F_UE8M0: return "e8_r"; - case HIP_R_6F_E2M3_EXT: + case HIP_R_6F_E2M3: return "f6_r"; - case HIP_R_6F_E3M2_EXT: + case HIP_R_6F_E3M2: return "bf6_r"; - case HIP_R_4F_E2M1_EXT: + case HIP_R_4F_E2M1: return "f4_r"; case static_cast(HIP_R_8F_E5M3_EXT): return "e5m3_r"; @@ -193,9 +193,9 @@ constexpr hipDataType string_to_hip_datatype(const std::string& value) value == "f16_r" || value == "h" ? HIP_R_16F : value == "bf16_r" ? HIP_R_16BF : value == "i8_r" || value == "i8" ? HIP_R_8I : - value == "f6_r" ? static_cast(HIP_R_6F_E2M3_EXT) : - value == "bf6_r" ? static_cast(HIP_R_6F_E3M2_EXT) : - value == "f4_r" ? static_cast(HIP_R_4F_E2M1_EXT) : + value == "f6_r" ? static_cast(HIP_R_6F_E2M3) : + value == "bf6_r" ? static_cast(HIP_R_6F_E3M2) : + value == "f4_r" ? static_cast(HIP_R_4F_E2M1) : value == "i32_r" || value == "i" ? HIP_R_32I : value == "f8_fnuz_r" ? HIP_R_8F_E4M3_FNUZ : value == "bf8_fnuz_r" ? HIP_R_8F_E5M2_FNUZ : diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/include/tensile_host.hpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/include/tensile_host.hpp index 217d3b4a56e..ba2f6f6a663 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/include/tensile_host.hpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/include/tensile_host.hpp @@ -3,7 +3,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -241,12 +241,17 @@ inline rocisa::DataType hipDataType_to_tensile_type(hipDataType type) return rocisa::DataType::Int8; case HIP_R_32I: return rocisa::DataType::Int32; - case HIP_R_4F_E2M1_EXT: - return rocisa::DataType::Float4; - case HIP_R_6F_E2M3_EXT: + // MX 6/4 data types + case HIP_R_6F_E2M3: return rocisa::DataType::Float6; - case HIP_R_6F_E3M2_EXT: + case HIP_R_6F_E3M2: return rocisa::DataType::BFloat6; + case HIP_R_4F_E2M1: + return rocisa::DataType::Float4; + case HIP_C_32F: + return rocisa::DataType::ComplexFloat; + case HIP_C_64F: + return rocisa::DataType::ComplexDouble; default: assert(!"hipDataType_to_tensile_type: non-supported type"); return rocisa::DataType::None; diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/rocroller_host.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/rocroller_host.cpp index e503ad9ec6b..1ce368167ff 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/rocroller_host.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/rocroller_host.cpp @@ -303,15 +303,15 @@ rocRoller::DataType hipDataType_to_rocRoller_type(hipDataType type) { // Older versions of ROCm do not have these types defined, // so they need to be handled specially. - if(static_cast(type) == HIP_R_6F_E2M3_EXT) + if(static_cast(type) == HIP_R_6F_E2M3) { return rocRoller::DataType::FP6; } - if(static_cast(type) == HIP_R_6F_E3M2_EXT) + if(static_cast(type) == HIP_R_6F_E3M2) { return rocRoller::DataType::BF6; } - if(static_cast(type) == HIP_R_4F_E2M1_EXT) + if(static_cast(type) == HIP_R_4F_E2M1) { return rocRoller::DataType::FP4; } diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/tensile_host.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/tensile_host.cpp index 88302ed0f71..8925ecdc769 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/tensile_host.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/tensile_host.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -390,12 +390,16 @@ namespace return rocisa::DataType::Int8; case HIP_R_32I: return rocisa::DataType::Int32; - case HIP_R_6F_E2M3_EXT: + case HIP_R_6F_E2M3: return rocisa::DataType::Float6; - case HIP_R_6F_E3M2_EXT: + case HIP_R_6F_E3M2: return rocisa::DataType::BFloat6; - case HIP_R_4F_E2M1_EXT: + case HIP_R_4F_E2M1: return rocisa::DataType::Float4; + case HIP_C_32F: + return rocisa::DataType::ComplexFloat; + case HIP_C_64F: + return rocisa::DataType::ComplexDouble; default: throw std::runtime_error("Unsupported type."); } @@ -427,11 +431,15 @@ namespace case rocisa::DataType::Int32: return HIP_R_32I; case rocisa::DataType::Float6: - return static_cast(HIP_R_6F_E2M3_EXT); + return static_cast(HIP_R_6F_E2M3); case rocisa::DataType::BFloat6: - return static_cast(HIP_R_6F_E3M2_EXT); + return static_cast(HIP_R_6F_E3M2); case rocisa::DataType::Float4: - return static_cast(HIP_R_4F_E2M1_EXT); + return static_cast(HIP_R_4F_E2M1); + case rocisa::DataType::ComplexFloat: + return HIP_C_32F; + case rocisa::DataType::ComplexDouble: + return HIP_C_64F; default: throw std::runtime_error("Unsupported type."); } @@ -497,17 +505,37 @@ namespace default:; } + if(typeA == rocisa::DataType::Float8_fnuz && typeB == rocisa::DataType::BFloat8_fnuz) + { + return rocisa::DataType::Float8_fnuz; + } + else if(typeA == rocisa::DataType::BFloat8_fnuz && typeB == rocisa::DataType::Float8_fnuz) + { + return rocisa::DataType::BFloat8_fnuz; + } + + if(typeA == rocisa::DataType::Float8 && typeB == rocisa::DataType::BFloat8) + { + return rocisa::DataType::Float8; + } + else if(typeA == rocisa::DataType::BFloat8 && typeB == rocisa::DataType::Float8) + { + return rocisa::DataType::BFloat8; + } + if(typeA == rocisa::DataType::Float8 || typeA == rocisa::DataType::BFloat8 || typeA == rocisa::DataType::Float8_fnuz || typeA == rocisa::DataType::BFloat8_fnuz || typeA == rocisa::DataType::Float6 || typeA == rocisa::DataType::BFloat6 || typeA == rocisa::DataType::Float4) return typeA; + + return TensileLite::DataTypeInfo::Get(typeA).elementSize <= TensileLite::DataTypeInfo::Get(typeB).elementSize ? typeA : typeB; } - - inline const rocisa::DataType - roc2TensileComputeInputTypeB(const rocisa::DataType& typeA, - const rocisa::DataType& typeB, - const rocblaslt_compute_type& typeCompute) + + inline const rocisa::DataType + roc2TensileComputeInputTypeB(const rocisa::DataType& typeA, + const rocisa::DataType& typeB, + const rocblaslt_compute_type& typeCompute) { switch(typeCompute) { @@ -534,7 +562,27 @@ namespace default:; } + if(typeA == rocisa::DataType::Float8_fnuz && typeB == rocisa::DataType::BFloat8_fnuz) + { + return rocisa::DataType::BFloat8_fnuz; + } + else if(typeA == rocisa::DataType::BFloat8_fnuz && typeB == rocisa::DataType::Float8_fnuz) + { + return rocisa::DataType::Float8_fnuz; + } + + if(typeA == rocisa::DataType::Float8 && typeB == rocisa::DataType::BFloat8) + { + return rocisa::DataType::BFloat8; + } + else if(typeA == rocisa::DataType::BFloat8 && typeB == rocisa::DataType::Float8) + { + return rocisa::DataType::Float8; + } + if(typeB == rocisa::DataType::Float8 || typeB == rocisa::DataType::BFloat8 || typeB == rocisa::DataType::Float8_fnuz || typeB == rocisa::DataType::BFloat8_fnuz ||typeB == rocisa::DataType::Float6 || typeB == rocisa::DataType::BFloat6 || typeB == rocisa::DataType::Float4) return typeB; + + return TensileLite::DataTypeInfo::Get(typeA).elementSize <= TensileLite::DataTypeInfo::Get(typeB).elementSize ? typeA @@ -572,6 +620,20 @@ namespace auto typeBTensile = hip2TensileType(typeB); std::vector biasDataTypeWhiteList; // dummy std::vector biasSrcWhiteList; // dummy + + TensileLite::TensorOps aOps, bOps, cOps, dOps; + + if(opA == HIPBLAS_OP_C) + aOps = {TensileLite::TensorOp::ComplexConjugate()}; + + if(opB == HIPBLAS_OP_C) + bOps = {TensileLite::TensorOp::ComplexConjugate()}; + + bool isComplexInput = (typeATensile == rocisa::DataType::ComplexFloat + || typeATensile == rocisa::DataType::ComplexDouble); + + auto alphaBetaType = isComplexInput ? typeATensile : roc2TensileType(typeCompute); + return TensileLite::ContractionProblemGemm::createDefaultProblem( (opA != HIPBLAS_OP_N), (opB != HIPBLAS_OP_N), @@ -579,8 +641,8 @@ namespace typeBTensile, hip2TensileType(typeC), hip2TensileType(typeD), - roc2TensileType(typeCompute), - roc2TensileType(typeCompute), + alphaBetaType, + alphaBetaType, roc2TensileComputeInputTypeA(typeATensile, typeBTensile, typeCompute), roc2TensileComputeInputTypeB(typeATensile, typeBTensile, typeCompute), roc2TensileType(typeCompute), @@ -1621,8 +1683,14 @@ namespace roc2TensileComputeInputTypeA(a_type, b_type, prob.compute_type)); tensileProblem.setComputeInputTypeB( roc2TensileComputeInputTypeB(a_type, b_type, prob.compute_type)); - tensileProblem.setAlphaType(compute_type); - tensileProblem.setBetaType(compute_type); + + bool isComplexInput = (a_type == rocisa::DataType::ComplexFloat + || a_type == rocisa::DataType::ComplexDouble); + + auto alphaBetaType = isComplexInput ? a_type : compute_type; + + tensileProblem.setAlphaType(alphaBetaType); + tensileProblem.setBetaType(alphaBetaType); // HPA is active iff sizeof(compute type) > sizeof(input type) tensileProblem.setHighPrecisionAccumulate( @@ -1678,6 +1746,8 @@ namespace case RocblasltContractionProblem::ScalingFormat::Vector: break; case RocblasltContractionProblem::ScalingFormat::Block_32_UE8M0: + case RocblasltContractionProblem::ScalingFormat::Block_32_UE8M0_32_8_EXT: + // Block_32_UE8M0_32_8_EXT (commit fe9a04d) is pre-swizzled scale data in `32x8` tile tensileProblem.setMXScaleA(rocisa::DataType::E8, 32); break; case RocblasltContractionProblem::ScalingFormat::Block_16_UE8M0: @@ -1704,6 +1774,8 @@ namespace case RocblasltContractionProblem::ScalingFormat::Vector: break; case RocblasltContractionProblem::ScalingFormat::Block_32_UE8M0: + case RocblasltContractionProblem::ScalingFormat::Block_32_UE8M0_32_8_EXT: + // Block_32_UE8M0_32_8_EXT (commit fe9a04d) is pre-swizzled scale data in `32x8` tile tensileProblem.setMXScaleB(rocisa::DataType::E8, 32); break; case RocblasltContractionProblem::ScalingFormat::Block_16_UE8M0: @@ -1861,8 +1933,14 @@ namespace roc2TensileComputeInputTypeA(a_type, b_type, prob.compute_type)); tensileProblem.setComputeInputTypeB( roc2TensileComputeInputTypeB(a_type, b_type, prob.compute_type)); - tensileProblem.setAlphaType(compute_type); - tensileProblem.setBetaType(compute_type); + + bool isComplexInput = (a_type == rocisa::DataType::ComplexFloat + || a_type == rocisa::DataType::ComplexDouble); + + auto alphaBetaType = isComplexInput ? a_type : compute_type; + + tensileProblem.setAlphaType(alphaBetaType); + tensileProblem.setBetaType(alphaBetaType); // HPA is active iff sizeof(compute type) > sizeof(input type) tensileProblem.setHighPrecisionAccumulate( @@ -1911,6 +1989,7 @@ namespace case RocblasltContractionProblem::ScalingFormat::Vector: break; case RocblasltContractionProblem::ScalingFormat::Block_32_UE8M0: + case RocblasltContractionProblem::ScalingFormat::Block_32_UE8M0_32_8_EXT: tensileProblem.setMXScaleA(rocisa::DataType::E8, 32); break; case RocblasltContractionProblem::ScalingFormat::Block_16_UE8M0: @@ -1937,6 +2016,7 @@ namespace case RocblasltContractionProblem::ScalingFormat::Vector: break; case RocblasltContractionProblem::ScalingFormat::Block_32_UE8M0: + case RocblasltContractionProblem::ScalingFormat::Block_32_UE8M0_32_8_EXT: tensileProblem.setMXScaleB(rocisa::DataType::E8, 32); break; case RocblasltContractionProblem::ScalingFormat::Block_16_UE8M0: @@ -2008,6 +2088,13 @@ namespace tensileProblem.setSwizzleTensorA(prob.swizzleA); tensileProblem.setSwizzleTensorB(prob.swizzleB); + + if(prob.scaleAType == RocblasltContractionProblem::ScalingFormat::Block_32_UE8M0 or + prob.scaleAType == RocblasltContractionProblem::ScalingFormat::Block_32_UE8M0_32_8_EXT) + tensileProblem.setMXScaleA(rocisa::DataType::E8, 32); + if(prob.scaleBType == RocblasltContractionProblem::ScalingFormat::Block_32_UE8M0 or + prob.scaleBType == RocblasltContractionProblem::ScalingFormat::Block_32_UE8M0_32_8_EXT) + tensileProblem.setMXScaleB(rocisa::DataType::E8, 32); } /*************************************************************** @@ -2045,6 +2132,7 @@ namespace inputs.bias = nullptr; if(prob.scaleAType == RocblasltContractionProblem::ScalingFormat::Block_32_UE8M0 + || prob.scaleAType == RocblasltContractionProblem::ScalingFormat::Block_32_UE8M0_32_8_EXT || prob.scaleAType == RocblasltContractionProblem::ScalingFormat::Block_16_UE4M3 || prob.scaleAType == RocblasltContractionProblem::ScalingFormat::Block_16_UE8M0 || prob.scaleAType == RocblasltContractionProblem::ScalingFormat::Block_32_UE4M3 @@ -2061,6 +2149,7 @@ namespace } if(prob.scaleBType == RocblasltContractionProblem::ScalingFormat::Block_32_UE8M0 + || prob.scaleBType == RocblasltContractionProblem::ScalingFormat::Block_32_UE8M0_32_8_EXT || prob.scaleBType == RocblasltContractionProblem::ScalingFormat::Block_16_UE4M3 || prob.scaleBType == RocblasltContractionProblem::ScalingFormat::Block_16_UE8M0 || prob.scaleBType == RocblasltContractionProblem::ScalingFormat::Block_32_UE4M3 @@ -2098,13 +2187,22 @@ namespace // push 2 activation arguments std::visit( [&inputs, &prob](auto val) { - inputs.activationArgs.push_back(static_cast(prob.act0)); - inputs.activationArgs.push_back(static_cast(prob.act1)); + using ValType = decltype(val); + if constexpr (std::is_constructible_v) + { + inputs.activationArgs.push_back(static_cast(prob.act0)); + inputs.activationArgs.push_back(static_cast(prob.act1)); + } + else + { + inputs.activationArgs.push_back(prob.act0); + inputs.activationArgs.push_back(prob.act1); + } if(prob.k) - inputs.alpha = *(decltype(val)*)(prob.alpha); + inputs.alpha = *(ValType*)(prob.alpha); else inputs.alpha = val; - inputs.beta = *(decltype(val)*)(prob.beta); + inputs.beta = *(ValType*)(prob.beta); }, argument_vals.at(compute_type)); @@ -2638,7 +2736,6 @@ TensileLite::ProblemOverride TensileDataGemm2ProblemOverride(std::shared_ptr data = std::static_pointer_cast(gemmData); rocisa::DataType computeType = rocisa::DataType::None; - // TODO: Check correctness of computeInputTypeA() rocisa::DataType computeInputType = data->problem.computeInputTypeA(); if(data->problem.f32XdlMathOp() == rocisa::DataType::XFloat32) diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/utility.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/utility.cpp index 372d98d18f0..f68fca52d5f 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/utility.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/utility.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -92,11 +92,11 @@ const char* hipDataType_to_string(hipDataType type) return "R_8F_E5M2"; case HIP_R_8I: return "R_8I"; - case static_cast(HIP_R_6F_E2M3_EXT): + case static_cast(HIP_R_6F_E2M3): return "R_6F_E2M3"; - case static_cast(HIP_R_6F_E3M2_EXT): + case static_cast(HIP_R_6F_E3M2): return "R_6F_E3M2"; - case static_cast(HIP_R_4F_E2M1_EXT): + case static_cast(HIP_R_4F_E2M1): return "R_4F_E2M1"; default: return "Invalid"; @@ -135,11 +135,11 @@ const char* hipDataType_to_bench_string(hipDataType type) return "f8_r"; case HIP_R_8F_E5M2: return "bf8_r"; - case static_cast(HIP_R_6F_E2M3_EXT): + case static_cast(HIP_R_6F_E2M3): return "f6_r"; - case static_cast(HIP_R_6F_E3M2_EXT): + case static_cast(HIP_R_6F_E3M2): return "bf6_r"; - case static_cast(HIP_R_4F_E2M1_EXT): + case static_cast(HIP_R_4F_E2M1): return "f4_r"; default: return "invalid"; diff --git a/projects/hipblaslt/tensilelite/CMakeLists.txt b/projects/hipblaslt/tensilelite/CMakeLists.txt index 8d631455a3d..c7c699163fe 100644 --- a/projects/hipblaslt/tensilelite/CMakeLists.txt +++ b/projects/hipblaslt/tensilelite/CMakeLists.txt @@ -56,6 +56,9 @@ if(TENSILELITE_ENABLE_HOST) add_subdirectory(include) if(TENSILELITE_ENABLE_CLIENT) + # mxDataGenerator bootstrap and `hipblaslt::mxdatagen` helper target + # live in projects/hipblaslt/CMakeLists.txt and are guaranteed to be + # available before this subdirectory is processed. add_subdirectory(client) endif() diff --git a/projects/hipblaslt/tensilelite/LayerNormGenerator.py b/projects/hipblaslt/tensilelite/LayerNormGenerator.py index 7eaeddcd4a0..dcca1471956 100644 --- a/projects/hipblaslt/tensilelite/LayerNormGenerator.py +++ b/projects/hipblaslt/tensilelite/LayerNormGenerator.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2023-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2023-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -121,7 +121,7 @@ def __init__(self, arch: str, isa: IsaVersion): self.io_type = io_type - self.bpe = io_type.numBytes() + self.bpe = int(io_type.numBytes()) self.num_workitems = num_workitems self.num_load_count = num_load_count self.num_load_size = num_load_size diff --git a/projects/hipblaslt/tensilelite/SoftmaxGenerator.py b/projects/hipblaslt/tensilelite/SoftmaxGenerator.py index 1ffaaa2e864..861bc9761a7 100644 --- a/projects/hipblaslt/tensilelite/SoftmaxGenerator.py +++ b/projects/hipblaslt/tensilelite/SoftmaxGenerator.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2023-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2023-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -241,18 +241,6 @@ def shiftSrd(self, srdIdx): self.sgpr_pool.checkIn(stmp) return module - def shiftSrd(self, srdIdx): - module = Module() - if self.isa[0] == 12 and self.isa[1] == 5: - stmp = self.sgpr_pool.checkOutAligned(1, 1) - module.add(ri.SAndB32(sgpr(stmp), sgpr(srdIdx+2), 0x7F)) - module.add(ri.SLShiftLeftB32(sgpr(stmp), 25, sgpr(stmp))) - module.add(ri.SAndB32(sgpr(srdIdx+1), sgpr(srdIdx+1), 0x1FFFFFF)) - module.add(ri.SOrB32(sgpr(srdIdx+1), sgpr(srdIdx+1), sgpr(stmp))) - module.add(ri.SLShiftRightB32(sgpr(srdIdx+2), 7, sgpr(srdIdx+2))) - self.sgpr_pool.checkIn(stmp) - return module - def load_kernel_args(self): kernel_args_addr = 0 kernel_args_addr_size = 2 diff --git a/projects/hipblaslt/tensilelite/Tensile/Activation.py b/projects/hipblaslt/tensilelite/Tensile/Activation.py index 82e8ec4874b..77bd4474007 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Activation.py +++ b/projects/hipblaslt/tensilelite/Tensile/Activation.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -1350,14 +1350,14 @@ def generateInlineAssemblyBody(self, spaces, activationType): kStr += (padSpacesStr + "%s.s = %s.s & 0x7fff;\n"%(unionName, unionName)) kStr += (padSpacesStr + "value = %s.f;\n"%unionName) elif (self.dataType.isSingle() or self.dataType.isDouble() or self.dataType.isInt32()): - kStr += (padSpacesStr + "value = abs(value);\n") + kStr += (padSpacesStr + "value = (value < decltype(value)(0)) ? -value : value;\n") else: raise RuntimeError("Unrecognized data type %s."%self.dataType) elif (activationType == 'clippedrelu'): if (self.dataType.isSingle() or self.dataType.isHalf() or self.dataType.isDouble()): - kStr += (padSpacesStr + "value = (value > alpha) ? min(value, beta) : min(0.0, beta);\n") + kStr += (padSpacesStr + "value = (value > alpha) ? std::min(value, beta) : std::min(decltype(beta)(0), beta);\n") elif self.dataType.isInt32(): - kStr += (padSpacesStr + "value = (value > alpha) ? min(value, beta) : min(0, beta);\n") + kStr += (padSpacesStr + "value = (value > alpha) ? std::min(value, beta) : std::min(0, beta);\n") elif (activationType == 'exp'): kStr += (asm + " // Exp\n") module = activation.getExpModule(self.dataType, 0, 0) @@ -1385,9 +1385,9 @@ def generateInlineAssemblyBody(self, spaces, activationType): raise RuntimeError("Unsupported data type %s."%ptrStr) elif (activationType == 'relu'): if (self.dataType.isSingle() or self.dataType.isHalf() or self.dataType.isDouble()): - kStr += (padSpacesStr + "value = max(0.0, value);\n") + kStr += (padSpacesStr + "value = std::max(decltype(value)(0), value);\n") elif self.dataType.isInt32(): - kStr += (padSpacesStr + "value = max(0, value);\n") + kStr += (padSpacesStr + "value = std::max(0, value);\n") else: raise RuntimeError("Unsupported data type %s."%ptrStr) elif (activationType == 'sigmoid'): @@ -1427,7 +1427,7 @@ def generateInlineAssemblyBody(self, spaces, activationType): kStr += addSpace(asm, ": \"+v\"(value) : \"s\"(alpha)\n") kStr += self.getRequiredRegStr(asm, activation.vgprCounter, activation.sgprCounter) elif (activationType == 'clamp'): - kStr += (padSpacesStr + "value = max(alpha, min(value, beta));\n") + kStr += (padSpacesStr + "value = std::max(alpha, std::min(value, beta));\n") # clamp else: if (activationType != 'none'): raise RuntimeError("Unrecognized type %s."%activationType) diff --git a/projects/hipblaslt/tensilelite/Tensile/ClientWriter.py b/projects/hipblaslt/tensilelite/Tensile/ClientWriter.py index 52b36529edc..48a151f19d6 100644 --- a/projects/hipblaslt/tensilelite/Tensile/ClientWriter.py +++ b/projects/hipblaslt/tensilelite/Tensile/ClientWriter.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -213,6 +213,10 @@ def runNewClient(scriptPath, clientParametersPath, cxxCompiler: str, cCompiler: iniFile = "--config-file={}".format(clientParametersPath) args = [clientExe, iniFile] + # Add MX scale format if set + if globalParameters["MXScaleFormat"]: + args.append("--mx-scale-format={}".format(globalParameters["MXScaleFormat"])) + try: subprocess.run(args, check=True) except (subprocess.CalledProcessError, OSError) as e: @@ -328,8 +332,9 @@ def writeRunScript(path, forBenchmark, enableTileSelection, cxxCompiler: str, cC clientExe = getClientExecutablePath() timingFlag = " --timing-instrumentation" if globalParameters["TimingInstrumentation"] else "" + mxScaleFormatFlag = " --mx-scale-format={}".format(globalParameters["MXScaleFormat"]) if globalParameters["MXScaleFormat"] else "" for configFile in configPaths: - runScriptFile.write("{} --config-file {}{}\n".format(clientExe, configFile, timingFlag)) + runScriptFile.write("{} --config-file {}{}{}\n".format(clientExe, configFile, timingFlag, mxScaleFormatFlag)) runScriptFile.write("ERR2=$?\n\n") runScriptFile.write(""" @@ -351,8 +356,9 @@ def writeRunScript(path, forBenchmark, enableTileSelection, cxxCompiler: str, cC runScriptFile.write("%s -d 0 --resetclocks\n" % globalParameters["ROCmSMIPath"]) runScriptFile.write("%s -d 0 --setfan 50\n" % globalParameters["ROCmSMIPath"]) else: + mxScaleFormatFlag = " --mx-scale-format={}".format(globalParameters["MXScaleFormat"]) if globalParameters["MXScaleFormat"] else "" for configFile in configPaths: - runScriptFile.write("{} --config-file {} --best-solution 1\n".format(getClientExecutablePath(), configFile)) + runScriptFile.write("{} --config-file {} --best-solution 1{}\n".format(getClientExecutablePath(), configFile, mxScaleFormatFlag)) if os.name != "nt": runScriptFile.write("exit $ERR\n") diff --git a/projects/hipblaslt/tensilelite/Tensile/Common/GlobalParameters.py b/projects/hipblaslt/tensilelite/Tensile/Common/GlobalParameters.py index 5ca02af0853..3844831d904 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Common/GlobalParameters.py +++ b/projects/hipblaslt/tensilelite/Tensile/Common/GlobalParameters.py @@ -137,6 +137,7 @@ ) globalParameters["LogicFormat"] = "yaml" # set library backend (yaml, or json) globalParameters["LibraryFormat"] = "yaml" # set library backend (yaml, or msgpack) +globalParameters["MXScaleFormat"] = 0 # MX scale data format (0=none, 1=pre-swizzle for GPU kernel layout) # True/False: CSV will/won't export WinnerGFlops, WinnerTimeUS, WinnerIdx, WinnerName. # TODO - if no side-effect, we can set default to True. This can make analyzing "LibraryLogic" (AddFromCSV) faster @@ -177,9 +178,9 @@ globalParameters["DataInitTypeScaleC"] = 2 globalParameters["DataInitTypeScaleD"] = 2 globalParameters["DataInitTypeScaleAlphaVec"] = 3 -globalParameters["DataInitValueActivationArgs"] = [2.0, 2.0] globalParameters["DataInitTypeMXSA"] = 1 globalParameters["DataInitTypeMXSB"] = 1 +globalParameters["DataInitValueActivationArgs"] = [2.0, 2.0] globalParameters["CEqualD"] = ( False # Set to true if testing for the case where the pointer to C is the same as D. ) @@ -453,6 +454,7 @@ {"GlobalSplitUCoalesced": [False]}, {"GlobalSplitUWorkGroupMappingRoundRobin": [False]}, {"Use64bShadowLimit": [True]}, + {"Use64bShadowLimitMX": [False]}, # Disable Use64bShadowLimit for MXSA/B by default {"NumLoadsCoalescedA": [1]}, {"NumLoadsCoalescedB": [1]}, {"WorkGroup": [[16, 16, 1]]}, diff --git a/projects/hipblaslt/tensilelite/Tensile/Common/RequiredParameters.py b/projects/hipblaslt/tensilelite/Tensile/Common/RequiredParameters.py index 491dc79b094..b73211524c7 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Common/RequiredParameters.py +++ b/projects/hipblaslt/tensilelite/Tensile/Common/RequiredParameters.py @@ -102,6 +102,7 @@ def getRequiredParametersMin() -> set: 'TDMInst', 'UnrollLoopSwapGlobalReadOrder', 'Use64bShadowLimit', + 'Use64bShadowLimitMX', 'UseInstOffsetForGRO', 'UseSgprForGRO', 'VectorStore', diff --git a/projects/hipblaslt/tensilelite/Tensile/Common/ValidParameters.py b/projects/hipblaslt/tensilelite/Tensile/Common/ValidParameters.py index 87f9205af8a..a65649ca2b9 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Common/ValidParameters.py +++ b/projects/hipblaslt/tensilelite/Tensile/Common/ValidParameters.py @@ -91,10 +91,10 @@ def makeValidSWMMAC(): @lru_cache def makeValidMFMA(): validMFMA = {} - validMFMA["H"] = [[32, 32, 16, 1], [32, 32, 4, 2], [32, 32, 8, 1], [16, 16, 32, 1], [16, 16, 4, 4], [16, 16, 16, 1], [4, 4, 4, 16]] - validMFMA["S"] = [[32, 32, 1, 2], [32, 32, 2, 1], [16, 16, 1, 4], [16, 16, 4, 1], [4, 4, 1, 16], [16, 16, 32, 1], [32, 32, 16, 1]] - validMFMA["B"] = [[32, 32, 16, 1], [32, 32, 2, 2], [32, 32, 4, 1], [16, 16, 32, 1], [16, 16, 2, 4], [16, 16, 8, 1], [4, 4, 2, 16]] - validMFMA["4xi8"] = [ + validMFMA["HH"] = [[32, 32, 16, 1], [32, 32, 4, 2], [32, 32, 8, 1], [16, 16, 32, 1], [16, 16, 4, 4], [16, 16, 16, 1], [4, 4, 4, 16]] + validMFMA["SS"] = [[32, 32, 1, 2], [32, 32, 2, 1], [16, 16, 1, 4], [16, 16, 4, 1], [4, 4, 1, 16], [16, 16, 32, 1], [32, 32, 16, 1]] + validMFMA["BB"] = [[32, 32, 16, 1], [32, 32, 2, 2], [32, 32, 4, 1], [16, 16, 32, 1], [16, 16, 2, 4], [16, 16, 8, 1], [4, 4, 2, 16]] + validMFMA["4xi84xi8"] = [ [32, 32, 4, 2], [32, 32, 8, 1], [16, 16, 4, 4], @@ -103,35 +103,48 @@ def makeValidMFMA(): [32, 32, 16, 1], [16, 16, 32, 1], ] - validMFMA["D"] = [[16, 16, 4, 1], [4, 4, 4, 4]] + validMFMA["DD"] = [[16, 16, 4, 1], [4, 4, 4, 4]] validMFMA["B1k"] = [[32, 32, 4, 2], [32, 32, 8, 1], [16, 16, 4, 4], [16, 16, 16, 1], [4, 4, 4, 16]] - validMFMA["C"] = validMFMA["S"] - validMFMA["Z"] = validMFMA["D"] - validMFMA["I8"] = [ + validMFMA["CC"] = validMFMA["SS"] + validMFMA["ZZ"] = validMFMA["DD"] + validMFMA["I8I8"] = [ [32, 32, 4, 2], [32, 32, 8, 1], [16, 16, 4, 4], [16, 16, 16, 1], [4, 4, 4, 16], ] + [[32, 32, 16, 1], [16, 16, 32, 1],] + [[16, 16, 64, 1],[32, 32, 32, 1]] - validMFMA["X"] = [[32, 32, 4, 1], [16, 16, 8, 1], [16, 16, 16, 1], [16, 16, 32, 1], [32, 32, 16, 1]] - validMFMA["F8"] = [[32, 32, 16, 1], [16, 16, 32, 1], [32, 32, 64, 1], [16, 16, 128, 1]] - validMFMA["B8"] = validMFMA["F8"] - validMFMA["F8B8"] = validMFMA["F8"] - validMFMA["B8F8"] = validMFMA["F8"] - validMFMA["F8N"] = [[32, 32, 16, 1], [16, 16, 32, 1]] - validMFMA["B8N"] = validMFMA["F8N"] - validMFMA["F8B8N"] = validMFMA["F8N"] - validMFMA["B8F8N"] = validMFMA["F8N"] + validMFMA["XX"] = [[32, 32, 4, 1], [16, 16, 8, 1], [16, 16, 16, 1], [16, 16, 32, 1], [32, 32, 16, 1]] + validMFMA["F8F8"] = [[32, 32, 16, 1], [16, 16, 32, 1], [32, 32, 64, 1], [16, 16, 128, 1]] + validMFMA["B8B8"] = validMFMA["F8F8"] + validMFMA["F8B8"] = validMFMA["F8F8"] + validMFMA["B8F8"] = validMFMA["F8F8"] + validMFMA["F8NF8N"] = [[32, 32, 16, 1], [16, 16, 32, 1]] + validMFMA["B8NB8N"] = validMFMA["F8NF8N"] + validMFMA["F8NB8N"] = validMFMA["F8NF8N"] + validMFMA["B8NF8N"] = validMFMA["F8NF8N"] + validMFMA["F6F6"] = [[16,16,128,1], [32,32,64,1]] + validMFMA["B6B6"] = [[16,16,128,1], [32,32,64,1]] + validMFMA["F6B6"] = [[16,16,128,1], [32,32,64,1]] + validMFMA["B6F6"] = [[16,16,128,1], [32,32,64,1]] + validMFMA["F8F6"] = [[16,16,128,1], [32,32,64,1]] + validMFMA["F6F8"] = [[16,16,128,1], [32,32,64,1]] + validMFMA["F8F4"] = [[16,16,128,1], [32,32,64,1]] + validMFMA["F4F8"] = [[16,16,128,1], [32,32,64,1]] + validMFMA["F6F4"] = [[16,16,128,1], [32,32,64,1]] + validMFMA["F4F6"] = [[16,16,128,1], [32,32,64,1]] + validMFMA["B6F4"] = [[16,16,128,1], [32,32,64,1]] + validMFMA["F4B6"] = [[16,16,128,1], [32,32,64,1]] + validMFMA["F4F4"] = [[16,16,128,1], [32,32,64,1]] validMFMA["_format9"] = [] for MFMA in [ - validMFMA["H"], - validMFMA["S"], - validMFMA["B"], - validMFMA["D"], - validMFMA["X"], - validMFMA["F8N"], + validMFMA["HH"], + validMFMA["SS"], + validMFMA["BB"], + validMFMA["DD"], + validMFMA["XX"], + validMFMA["F8F8"], makeValidWMMA(), ]: for MI in MFMA: @@ -148,20 +161,20 @@ def makeValidMFMA(): @lru_cache def makeValidSMFMA(): validSMFMA = {} - validSMFMA["H"] = [[32, 32, 16, 1], [16, 16, 32, 1], [16, 16, 64, 1], [32, 32, 32, 1]] - validSMFMA["B"] = [[32, 32, 16, 1], [16, 16, 32, 1], [16, 16, 64, 1], [32, 32, 32, 1]] - validSMFMA["4xi8"] = [[32, 32, 32, 1], [16, 16, 64, 1], [16, 16, 128, 1], [32, 32, 64, 1]] - validSMFMA["I8"] = validSMFMA["4xi8"] - validSMFMA["F8"] = [[32, 32, 32, 1], [16, 16, 64, 1], [16, 16, 128, 1], [32, 32, 64, 1]] - validSMFMA["B8"] = validSMFMA["F8"] - validSMFMA["F8B8"] = validSMFMA["F8"] - validSMFMA["B8F8"] = validSMFMA["F8"] - validSMFMA["F8N"] = validSMFMA["F8"] - validSMFMA["B8N"] = validSMFMA["F8"] - validSMFMA["F8B8N"] = validSMFMA["F8N"] - validSMFMA["B8F8N"] = validSMFMA["F8N"] + validSMFMA["HH"] = [[32, 32, 16, 1], [16, 16, 32, 1], [16, 16, 64, 1], [32, 32, 32, 1]] + validSMFMA["BB"] = [[32, 32, 16, 1], [16, 16, 32, 1], [16, 16, 64, 1], [32, 32, 32, 1]] + validSMFMA["4xi84xi8"] = [[32, 32, 32, 1], [16, 16, 64, 1], [16, 16, 128, 1], [32, 32, 64, 1]] + validSMFMA["I8I8"] = validSMFMA["4xi84xi8"] + validSMFMA["F8F8"] = [[32, 32, 32, 1], [16, 16, 64, 1], [16, 16, 128, 1], [32, 32, 64, 1]] + validSMFMA["B8B8"] = validSMFMA["F8F8"] + validSMFMA["F8B8"] = validSMFMA["F8F8"] + validSMFMA["B8F8"] = validSMFMA["F8F8"] + validSMFMA["F8NF8N"] = validSMFMA["F8F8"] + validSMFMA["B8NB8N"] = validSMFMA["F8F8"] + validSMFMA["F8NB8N"] = validSMFMA["F8NF8N"] + validSMFMA["B8NF8N"] = validSMFMA["F8NF8N"] validSMFMA["_format9"] = [] - for SMFMA in [validSMFMA["H"], validSMFMA["B"], validSMFMA["4xi8"], validSMFMA["F8N"], makeValidSWMMAC()]: + for SMFMA in [validSMFMA["HH"], validSMFMA["BB"], validSMFMA["4xi84xi8"], validSMFMA["F8NF8N"], makeValidSWMMAC()]: for MI in SMFMA: for bm in range(int(math.log(MI[3], 2)) + 1): for tt0 in range(1, validTT + 1): @@ -180,15 +193,15 @@ def makeValidMatrixInstructions(): wmma = makeValidWMMA() validMatrixInstructions = ( [[], [-1]] - + mfma["H"] - + mfma["S"] - + mfma["B"] - + mfma["D"] + + mfma["HH"] + + mfma["SS"] + + mfma["BB"] + + mfma["DD"] + mfma["B1k"] - + mfma["X"] - + smfma["H"] - + smfma["B"] - + smfma["4xi8"] + + mfma["XX"] + + smfma["HH"] + + smfma["BB"] + + smfma["4xi84xi8"] + wmma ) return validMatrixInstructions + mfma["_format9"] + smfma["_format9"] @@ -423,6 +436,8 @@ def makeValidMatrixInstructions(): "UseSgprForGRO": [-1, 0, 1], # Use a 64-bit shadow limit register to allow buffers larger than 2^32 bytes "Use64bShadowLimit": [True, False], + # Use a 64-bit shadow limit register for MXSA/B to allow buffers larger than 2^32 bytes + "Use64bShadowLimitMX": [True, False], # Assertion properties # These provide information or assertions that the problem size meets certain requirements # for sizes or alignments. The kernel generator can use this information to produce diff --git a/projects/hipblaslt/tensilelite/Tensile/Component.py b/projects/hipblaslt/tensilelite/Tensile/Component.py index b41d134a390..d86898344be 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Component.py +++ b/projects/hipblaslt/tensilelite/Tensile/Component.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -34,7 +34,8 @@ class FMA_NonPacked(MAC): asmCaps = {"v_fma_f16": True, "v_pk_fma_f16": False} #archCaps = {} - kernel = {"ProblemType": {"DataType": DataType(DataTypeEnum.Half), + kernel = {"ProblemType": {"MacDataTypeA": DataType(DataTypeEnum.Half), + "MacDataTypeB": DataType(DataTypeEnum.Half), "HighPrecisionAccumulate": False}} ``` @@ -44,7 +45,8 @@ class FMA_NonPacked(MAC): class FMA_HPA_MAD_MIX(MAC): asmCaps = {"v_mad_mix_f32": True} #archCaps = {} - kernel = {"ProblemType": {"DataType": DataType(DataTypeEnum.Half), + kernel = {"ProblemType": {"MacDataTypeA": DataType(DataTypeEnum.Half), + "MacDataTypeB": DataType(DataTypeEnum.Half), "HighPrecisionAccumulate": True}, } ``` diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/GSU.py b/projects/hipblaslt/tensilelite/Tensile/Components/GSU.py index 5f5983c03f1..c877b73d2c7 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/GSU.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/GSU.py @@ -21,7 +21,7 @@ ################################################################################ from rocisa import countInstruction -from rocisa.code import Module, Label, RegSet +from rocisa.code import Module, Label, RegSet, ValueSet from rocisa.container import ContinuousRegister, SMEMModifiers, vgpr, sgpr, replaceHolder from rocisa.instruction import SAddCU32, SAddU32, SAndB32, SLoadB32, SStoreB32, SBranch, \ SCBranchSCC0, SCBranchSCC1, SCMovB32, SCSelectB32, SCmpEQU32, SCmpLgU32, SCmpLtU32, SCmpGtI32, \ @@ -65,8 +65,22 @@ def graIncrementsCommon(self, writer, loopIdx, tc, stride, m): # multiply by stride, optimizing if unit stride if writer.isConstUnitStride(stride): - module.add(SMovB32(dst=sgpr("GlobalReadIncs%s+%u"%(tc, loopIdx)), src=m, \ - comment="incr%s (unrollIdx)"%(tc) )) + if tc == "A": + abinfo = writer.states.a + elif tc == "B": + abinfo = writer.states.b + elif tc == "MXSA": + abinfo = writer.states.mxsa + elif tc == "MXSB": + abinfo = writer.states.mxsb + else: + abinfo = None + if abinfo != None and abinfo.useConstSgprGlobalReadIncs: + # useConstSgprGlobalReadIncs case, define value set for GlobalReadIncs here instead of initializing sgpr + module.add(ValueSet("GlobalReadIncs%s"%tc, m)) + else: + module.add(SMovB32(dst=sgpr("GlobalReadIncs%s+%u"%(tc, loopIdx)), src=m, \ + comment="incr%s (unrollIdx)"%(tc) )) else: module.add(SMulI32(dst=sgpr("GlobalReadIncs%s+%u"%(tc, loopIdx)), \ src0=m, src1=stride, \ @@ -198,7 +212,6 @@ def graIncrements(self, writer, kernel, loopIdx, tP): module = Module("GSU Off graIncrements") tc = tP["tensorChar"] - tcGR = tc if tc == "Metadata" else (tc + "GR") dimIdx = kernel["ProblemType"]["IndicesSummation"][loopIdx] # dimension index loopChar = writer.states.indexChars[dimIdx] stride = writer.strideRef(tc, dimIdx) @@ -344,8 +357,10 @@ def computeLoadSrd(self, writer, kernel, tP, stmp, tileStart): module = Module("GSU On computeLoadSrd") tc = tP["tensorChar"] + isgfx950 = kernel["ISA"][:2] == (9, 5) + isgfx950mx = isgfx950 and ("MXS" in tc) depthU = kernel["DepthU"] - depthUDiv = kernel["DepthU"] + depthUDiv = kernel["_DepthU%s"%tc] if isgfx950mx else kernel["_DepthU"] # swizzle if (tP["isSwizzled"] and tc == 'A'): depthUDiv = kernel["DepthU"] * kernel["MatrixInstM"] @@ -915,7 +930,7 @@ def reductionProcedure(self, writer, kernel, elements, alpha, beta, edges, atomi # are likely to be low-performing so likely not worth optimizing. print("WARNING: half requires at least two elements per batch") self.overflowedResources = 3 - # elif kernel["ProblemType"]["DataType"].is8bitFloat(): + # elif kernel["ProblemType"]["MacDataTypeA"].is8bitFloat(): # if numElementsPerBatch > 1: # numElementsPerBatch = int(numElementsPerBatch/4)*4 diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py b/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py index 39a1133dff7..e83ec332579 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py @@ -63,12 +63,12 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): # dot2: currently only support unroll major LDS if kernel["UseDotInstruction"]: numVectorsPerTile = kernel["ThreadTile%u"%tile01] - numReadsPerVector = ceil((writer.states.lrvwUnrollA * tP["bpe"]) // (blockWidth*4)) # bytes/register + numReadsPerVector = ceil((writer.states.lrvwUnrollA * tP["bpe"]) / (blockWidth*4)) # bytes/register LdsPad = kernel["LdsPad%s"%tc] if kernel["LdsBlockSizePerPad%s"%tc] == 0 else 0 tileStride = kernel["_DepthU%s"%tc] + LdsPad if kernel["UnrollMajorLDS%s" % tP["tensorChar"]] else 1 else: numVectorsPerTile = (kernel["ThreadTile%u"%tile01]//kernel["VectorWidthA"]) - numReadsPerVector = ceil((kernel["VectorWidthA"] * tP["bpe"]) // (blockWidth*4)) # bytes/register + numReadsPerVector = ceil((kernel["VectorWidthA"] * tP["bpe"]) / (blockWidth*4)) # bytes/register for vIdx in range(0, numVectorsPerTile): for rIdx in range(0, int(numReadsPerVector)): @@ -562,7 +562,7 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): tc = tP["tensorChar"] if tc == "A": - writer.states.localReadDoCntA += 1 + writer.states.localReadDoCntA += 1 elif tc == "MXSA": writer.states.localReadDoCntMXSA += 1 elif tc == "Metadata": @@ -574,6 +574,8 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): else: raise Exception(f"unsupport tc %s{tc}") + isgfx950 = kernel["ISA"][:2] == (9, 5) + isgfx950mx = isgfx950 and ("MXS" in tc) MacDataType = f"MacDataType{tc}" if(tc=="A" or tc=="B") else "DataType" tile01 = tP["tile01Idx"] instruction = tP["localReadInstruction"] @@ -643,6 +645,10 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): needPack = blockWidth == 0.25 needPack |= (kernel["ConvertAfterDS"] and (tP["bpe"] != tP["bpeDS"])) needPack |= kernel["UseF32XEmulation"] + if isgfx950mx: + # Keep the gfx950 workaround, but allow normal MX packing on gfx1250. + needPack = False + pack = Module("pack%s_I%s"%(tc,iui)) packPre = Module("pack%s_I%s Pre"%(tc,iui)) @@ -655,7 +661,6 @@ def __call__(self, writer, kernel, bufferIdx, iui, epsi, tP): useDirect32XEmulation = writer.states.a.useDirect32XEmulationThis if tc == "A" else writer.states.b.useDirect32XEmulationThis indexTranpose = lrvwTile > 1 and (not useTransposeCode) - # split Metadata when localread width > mi input numSplitMetadata = max(ceil((blockWidth * 4) // tP["bpeDS"]) - 1, 0) if tP["isM"] else 0 # caculate SMFMA layout @@ -1395,6 +1400,10 @@ def applyPad(offset_val): if kernel["ConvertAfterDS"] and kernel["UnrollMajorLDS%s"%tc]: valufIdx += blockWidth * (tP["bpe"] // tP["bpeDS"]) if (not tP["isM"]) else 1 + # workaround for gfx950 MX + # need to increment valufIdx by 1 + elif isgfx950mx: + valufIdx += 1 else: valufIdx += blockWidth if (not tP["isM"]) else (numVgpr if writer.states.asmCaps["HasSWMMAC_gfx1250"] else 1) @@ -1429,7 +1438,7 @@ def calcGfx1250LdsOffset(): else: offset_val = offset_val + (blockOffsetSMFMA * blockId) * UnrollStride offset_val = int((rIdx * numElementPerRead * UnrollStride + offset_val + tP["localReadOffset"]) * tP["bpeDS"]) - elif writer.states.asmCaps["HasMFMA_f8f6f4"] and kernel["ProblemType"][MacDataType].is8bitFloat() and kernel["MatrixInstK"] > 32: + elif writer.states.asmCaps["HasMFMA_f8f6f4"] and kernel["ProblemType"][MacDataType].is8bitFloat() and kernel["MatrixInstK"] > 32 and not isgfx950mx: incOffset = 0 midIdx = numReadsPerUnroll // 2 if rIdx >= midIdx: @@ -1484,7 +1493,7 @@ def calcGfx1250LdsOffset(): offset_val = offset_val + tP["localReadSwapByteOffset"] # TODO: Add NLC>1 offset calcs here? if (kernel["DirectToLds%s" % tc] and \ - kernel["GlobalReadVectorWidth%c"%tc] * tP["bpeDS"] > 4) and not kernel["UseGeneralizedNLCOne%s"%tc]: + kernel["GlobalReadVectorWidth%s"%tc] * tP["bpeDS"] > 4) and not kernel["UseGeneralizedNLCOne%s"%tc]: # another address conversion for DirectToLds + NumLoadsCoalesced > 1 dummy, offset_val = writer.lraOffsetConversionForDTLandNLC(kernel, tP, offset_val) diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/LraTileAssignment.py b/projects/hipblaslt/tensilelite/Tensile/Components/LraTileAssignment.py index e4290b39a25..b3c19570c19 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/LraTileAssignment.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/LraTileAssignment.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -763,14 +763,25 @@ def LraTileAssignmentCode(self, writer, kernel, tP, tReg, kReg, tmpVgprRes, divi if enableLDSTr: sReg = writer.vgprPool.checkOut(1,"sReg") # remainder - noUnrollOffset = writer.states.asmCaps["HasWMMA_V1"] or ("MXS" in tP["tensorChar"]) - # get constant parameter tc = tP["tensorChar"] tile01 = tP["tile01Idx"] waveWidth = writer.states.kernel["WavefrontSize"] + + noUnrollOffset = writer.states.asmCaps["HasWMMA_V1"] or ("MXS" in tc) + isgfx950 = kernel["ISA"][:2] == (9, 5) + isgfx950mx = isgfx950 and ("MXS" in tc) + # workaround for gfx950 + # force noUnrollOffset=False for MX + if isgfx950: + noUnrollOffset = False + lrvw = kernel["LocalReadVectorWidthMXS"] if ("MXS" in tc) else kernel[f"LocalReadVectorWidth{tc}"] inputPerThread = lrvw if not writer.states.inTailLoop else kernel["MIInputPerThread%s"%tc] + # workaround for gfx950 + fp4 + # use MIInputPerThread if lrvw < MIInputPerThread + if isgfx950 and tP["bpeDS"] == 0.5 and lrvw < kernel["MIInputPerThread%s"%tc]: + inputPerThread = kernel["MIInputPerThread%s"%tc] if kernel["ProblemType"]["Sparse"]: if (kernel["ProblemType"]["Sparse"] == 2 and tP["isB"]) or (kernel["ProblemType"]["Sparse"] == 1 and tP["isA"]): inputPerThread = inputPerThread // 2 @@ -806,6 +817,7 @@ def LraTileAssignmentCode(self, writer, kernel, tP, tReg, kReg, tmpVgprRes, divi dividedForWaveId = dividedForWaveId, \ vectorWidth=vectorWidth, \ maxKId=maxKId) + abmatrixinfo = writer.states.a if tc == 'A' else writer.states.b perpStride = abmatrixinfo.gNLCPerpStride permBlock = abmatrixinfo.gNLCPermBlock @@ -835,9 +847,9 @@ def LraTileAssignmentCode(self, writer, kernel, tP, tReg, kReg, tmpVgprRes, divi # GFX1250 Sparse if writer.states.asmCaps["HasSWMMAC"] and writer.states.asmCaps["HasSWMMAC_gfx1250"] and (not isSparseTrack or tP["isM"]): strideK *= 2 - + # special case for new F8 MFMA, need to exclude wmma_v3 - elif kernel["ProblemType"]["DataType"].is8bitFloat() and kernel["MatrixInstK"] > 32 and (not writer.states.asmCaps["HasWMMA_V3"]): + elif kernel["ProblemType"]["DataType"].is8bitFloat() and kernel["MatrixInstK"] > 32 and (not writer.states.asmCaps["HasWMMA_V3"]) and (not isgfx950mx): if umlds: strideK = 16 else: diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/MAC_BF16_HPA.py b/projects/hipblaslt/tensilelite/Tensile/Components/MAC_BF16_HPA.py index 6d3ae1d9c12..85d577af793 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/MAC_BF16_HPA.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/MAC_BF16_HPA.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -31,7 +31,8 @@ class FMA_BF16_HPA(MAC): asmCaps = {"v_fma_f32": True} - kernel = {"ProblemType": {"DataType": DataType(DataTypeEnum.BFloat16), + kernel = {"ProblemType": {"MacDataTypeA": DataType(DataTypeEnum.BFloat16), + "MacDataTypeB": DataType(DataTypeEnum.BFloat16), "HighPrecisionAccumulate": True}, "UseDotInstruction": False, } @@ -118,7 +119,8 @@ def __call__(self, writer, m, innerUnroll): class FMA_BF16_HPA_DOT2(MAC): asmCaps = lambda caps: caps['v_dot2_f32_bf16'] or caps['v_dot2c_f32_bf16'] #archCaps = {} - kernel = {"ProblemType": {"DataType": DataType(DataTypeEnum.BFloat16), + kernel = {"ProblemType": {"MacDataTypeA": DataType(DataTypeEnum.BFloat16), + "MacDataTypeB": DataType(DataTypeEnum.BFloat16), "HighPrecisionAccumulate": True}, "UseDotInstruction": True, } diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F16.py b/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F16.py index 23fe421ca74..2bc1d8a997b 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F16.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F16.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -35,7 +35,8 @@ class MAC_F16_Plain(MAC): "v_pk_fma_f16": False, "v_fma_f16": False} #archCaps = {} - kernel = {"ProblemType": {"DataType": DataType(DataTypeEnum.Half), + kernel = {"ProblemType": {"MacDataTypeA": DataType(DataTypeEnum.Half), + "MacDataTypeB": DataType(DataTypeEnum.Half), "HighPrecisionAccumulate": False}} def __call__(self, writer, m, innerUnroll): @@ -80,7 +81,8 @@ class FMA_F16_NonPacked(MAC): asmCaps = {"v_fma_f16": True, "v_pk_fma_f16": False} #archCaps = {} - kernel = {"ProblemType": {"DataType": DataType(DataTypeEnum.Half), + kernel = {"ProblemType": {"MacDataTypeA": DataType(DataTypeEnum.Half), + "MacDataTypeB": DataType(DataTypeEnum.Half), "HighPrecisionAccumulate": False}} def __call__(self, writer, m, innerUnroll): @@ -131,7 +133,8 @@ def __call__(self, writer, m, innerUnroll): class FMA_F16_Packed(MAC): asmCaps = {"v_pk_fma_f16": True} #archCaps = {} - kernel = {"ProblemType": {"DataType": DataType(DataTypeEnum.Half), + kernel = {"ProblemType": {"MacDataTypeA": DataType(DataTypeEnum.Half), + "MacDataTypeB": DataType(DataTypeEnum.Half), "HighPrecisionAccumulate": False}} def __call__(self, writer, m, innerUnroll): diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F16_HPA.py b/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F16_HPA.py index d17293446f6..f8189607e36 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F16_HPA.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F16_HPA.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -34,7 +34,8 @@ class FMA_F16_HPA_DOT2(MAC): asmCaps = lambda caps: caps['v_dot2_f32_f16'] or caps['v_dot2c_f32_f16'] #archCaps = {} - kernel = {"ProblemType": {"DataType": DataType(DataTypeEnum.Half), + kernel = {"ProblemType": {"MacDataTypeA": DataType(DataTypeEnum.Half), + "MacDataTypeB": DataType(DataTypeEnum.Half), "HighPrecisionAccumulate": True}, "UseDotInstruction": True, } @@ -89,7 +90,8 @@ def __call__(self, writer, tPA, tPB, m, innerUnroll): class FMA_F16_HPA_MAD_MIX(MAC): asmCaps = lambda caps: caps['v_mad_mix_f32'] or caps['v_fma_mix_f32'] #archCaps = {} - kernel = {"ProblemType": {"DataType": DataType(DataTypeEnum.Half), + kernel = {"ProblemType": {"MacDataTypeA": DataType(DataTypeEnum.Half), + "MacDataTypeB": DataType(DataTypeEnum.Half), "HighPrecisionAccumulate": True}, "UseDotInstruction": False, } diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F32.py b/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F32.py index 685c227b06e..82020603830 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F32.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F32.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -37,7 +37,8 @@ class MAC_F32_Plain(MAC): def asmCaps(caps): return caps["v_mac_f32"] or caps["v_fma_f32"] - kernel = {"ProblemType": {"DataType": DataType(DataTypeEnum.Float)}} + kernel = {"ProblemType": {"MacDataTypeA": DataType(DataTypeEnum.Float), + "MacDataTypeB": DataType(DataTypeEnum.Float),}} def __call__(self, writer, tPA, tPB, m, innerUnroll): kernel = writer.states.kernel diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F32C.py b/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F32C.py index 8e1d89bfa2e..20f7e61570d 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F32C.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F32C.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -28,7 +28,8 @@ from ..Component import Component, MAC class MAC_F32C_Plain(MAC): - kernel = {"ProblemType": {"DataType": DataType(DataTypeEnum.ComplexFloat)}} + kernel = {"ProblemType": {"MacDataTypeA": DataType(DataTypeEnum.ComplexFloat), + "MacDataTypeB": DataType(DataTypeEnum.ComplexFloat)}} def __call__(self, writer, m, innerUnroll): kernel = writer.states.kernel diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F64.py b/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F64.py index c263ac33394..ecfed2b78ec 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F64.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F64.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -34,7 +34,8 @@ class FMA_F64_Plain(MAC): Plain MAC instruction implementation """ asmCaps = {"v_fma_f64": True} - kernel = {"ProblemType": {"DataType": DataType(DataTypeEnum.Double)}} + kernel = {"ProblemType": {"MacDataTypeA": DataType(DataTypeEnum.Double), + "MacDataTypeB": DataType(DataTypeEnum.Double)}} def __call__(self, writer, tPA, tPB, m, innerUnroll): kernel = writer.states.kernel diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F64C.py b/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F64C.py index 59c8b99b239..b5472b0839a 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F64C.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/MAC_F64C.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -29,7 +29,8 @@ class FMA_F64C_Plain(MAC): asmCaps = {"v_fma_f64": True} - kernel = {"ProblemType": {"DataType": DataType(DataTypeEnum.ComplexDouble)}} + kernel = {"ProblemType": {"MacDataTypeA": DataType(DataTypeEnum.ComplexDouble), + "MacDataTypeB": DataType(DataTypeEnum.ComplexDouble)}} def __call__(self, writer, m, innerUnroll): kernel = writer.states.kernel diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/MAC_I8_HPA.py b/projects/hipblaslt/tensilelite/Tensile/Components/MAC_I8_HPA.py index 21184431725..2ab1c3fa6b4 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/MAC_I8_HPA.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/MAC_I8_HPA.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -33,13 +33,15 @@ def asmCaps(caps): return True kernel = { - "ProblemType": {"DataType": DataType(DataTypeEnum.Int8), "HighPrecisionAccumulate": True}, + "ProblemType": {"MacDataTypeA": DataType(DataTypeEnum.Int8), + "MacDataTypeB": DataType(DataTypeEnum.Int8), + "HighPrecisionAccumulate": True}, } def __call__(self, writer, m, innerUnroll): kernel = writer.states.kernel priority = Component.Priority.find(writer) - spacePerReg = writer.states.bpr // writer.states.bpeAB + spacePerReg = writer.states.bpr elemPerReg = min(kernel['VectorWidth'], spacePerReg) module = Module("FMA_I8_HPA") diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/SIA.py b/projects/hipblaslt/tensilelite/Tensile/Components/SIA.py index 54d135e9881..b5dcf2ed08b 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/SIA.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/SIA.py @@ -245,10 +245,10 @@ def calculateLatencyLeft(numReads, localReadBlockWidth, localReadLatency): # ds_read[MXSA][0] if kernel["ProblemType"]["MXBlockA"]: calculateLatencyLeft(writer.states.numReadsPerUnrollMXSA, tensorParametersA["MX"]["localReadInstruction"].blockWidth, tensorParametersA["MX"]["localReadInstruction"].issueLatency) - # ds_read[M][0] # ds_read[MXSB][0] if kernel["ProblemType"]["MXBlockB"]: calculateLatencyLeft(writer.states.numReadsPerUnrollMXSB, tensorParametersB["MX"]["localReadInstruction"].blockWidth, tensorParametersB["MX"]["localReadInstruction"].issueLatency) + # ds_read[M][0] if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: calculateLatencyLeft(writer.states.numReadsPerUnrollMetadata, tPM["localReadInstruction"].blockWidth, tPM["localReadInstruction"].issueLatency) # ds_read[B][0] @@ -258,19 +258,20 @@ def calculateLatencyLeft(numReads, localReadBlockWidth, localReadLatency): calculateLatencyLeft((writer.states.numReadsPerIterA//kernel["InnerUnroll"] - writer.states.numReadsPerUnrollA), tensorParametersA["localReadInstruction"].blockWidth, tensorParametersA["localReadInstruction"].issueLatency) # ds_read[MXSA][1:] if kernel["ProblemType"]["MXBlockA"]: - calculateLatencyLeft((writer.states.numReadsPerIterMXSA//kernel["InnerUnroll"] - writer.states.numReadsPerUnrollMXSA), tensorParametersA["MX"]["localReadInstruction"].blockWidth, tensorParametersA["MX"]["localReadInstruction"].issueLatency) + calculateLatencyLeft(writer.states.numReadsPerIterMXSA//kernel["InnerUnroll"] - writer.states.numReadsPerUnrollMXSA, tensorParametersA["MX"]["localReadInstruction"].blockWidth, tensorParametersA["MX"]["localReadInstruction"].issueLatency) # ds_read[M][1:] if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: calculateLatencyLeft((writer.states.numReadsPerIterMetadata//kernel["InnerUnroll"] - writer.states.numReadsPerUnrollMetadata), tPM["localReadInstruction"].blockWidth, tPM["localReadInstruction"].issueLatency) # ds_read[MXSB][1:] if kernel["ProblemType"]["MXBlockB"]: - calculateLatencyLeft((writer.states.numReadsPerIterMXSB//kernel["InnerUnroll"] - writer.states.numReadsPerUnrollMXSB), tensorParametersB["MX"]["localReadInstruction"].blockWidth, tensorParametersB["MX"]["localReadInstruction"].issueLatency) + calculateLatencyLeft((writer.states.numReadsPerIterMXSB//kernel["InnerUnroll"] - writer.states.numReadsPerUnrollMXSB), tensorParametersB["MX"]["localReadInstruction"].blockWidth, tensorParametersB["MX"]["localReadInstruction"].issueLatency) # get the latency before B[:1] if isLocalReadsOpt: tmpLatencyLeft = latencyLeft tmpNumMfmaForLR = writer.states.numMfmaForLR # ds_read[B][1:] calculateLatencyLeft((writer.states.numReadsPerIterB//kernel["InnerUnroll"] - writer.states.numReadsPerUnrollB), tensorParametersB["localReadInstruction"].blockWidth, tensorParametersB["localReadInstruction"].issueLatency) + # to calculate number of mfma we need to wait before data arrive from lds to vgpr. # latency: 40 quad-cycle for 4 word, 20 quad-cycle for 2 word, 10 quad-cycle for 1 word / half word writer.states.numMfmaForNextLoopLR = writer.states.numMfmaForLR @@ -317,6 +318,9 @@ def getLocalWriteMFMAStart(writer, kernel, tensorParametersA, tensorParametersB, # TODO: replace here for real number of globalReadIncInst # numGRIncInst = 18 # Always on. Original logic: 12 if not kernel["StaggerU"] else 18 numGRIncInst = 12 if not writer.states.staggerUCode else 18 + if kernel["ProblemType"]["MXBlockA"] and kernel["ProblemType"]["MXBlockB"]: + # MXSA+MXSB case, double the number + numGRIncInst *= 2 numInstPerMfma = max(roundUp(writer.states.miLatencyLeft/2),1) numMfmaToSched = roundUp(numGRIncInst/numInstPerMfma) lwStartMfmaIndex = 1 + numMfmaToSched @@ -334,9 +338,13 @@ def getLocalWriteMFMAStart(writer, kernel, tensorParametersA, tensorParametersB, lookRange = 1 if subIter else kernel["LoopIters"] - writer.states.numItersPLR for u in range(lookRange): doReadA = ((u < kernel["LoopIters"] // writer.states.numIterPerCoalescedReadA - writer.states.numItersPLR) and not kernel["DirectToVgprA"]) or subIter - doReadMXSA = (u < kernel["LoopIters"] // writer.states.numIterPerCoalescedReadMXSA - writer.states.numItersPLR) + doReadMXSA = False + if kernel["ProblemType"]["MXBlockA"] > 0: + doReadMXSA = ((u < kernel["LoopIters"] // writer.states.numIterPerCoalescedReadMXSA - writer.states.numItersPLR) and not kernel["DirectToVgprMXSA"]) or subIter doReadB = ((u < kernel["LoopIters"] // writer.states.numIterPerCoalescedReadB - writer.states.numItersPLR) and not kernel["DirectToVgprB"]) or subIter - doReadMXSB = (u < kernel["LoopIters"] // writer.states.numIterPerCoalescedReadMXSB - writer.states.numItersPLR) + doReadMXSB = False + if kernel["ProblemType"]["MXBlockB"] > 0: + doReadMXSB = ((u < kernel["LoopIters"] // writer.states.numIterPerCoalescedReadMXSB - writer.states.numItersPLR) and not kernel["DirectToVgprMXSB"]) or subIter doReadM = ((u < kernel["LoopIters"] // writer.states.numIterPerCoalescedReadMetadata - writer.states.numItersPLR)) doReadMXSA = doReadMXSA and kernel["ProblemType"]["MXBlockA"] and (not kernel["DirectToVgprMXSA"]) doReadMXSB = doReadMXSB and kernel["ProblemType"]["MXBlockB"] and (not kernel["DirectToVgprMXSB"]) @@ -404,7 +412,11 @@ def getNumLocalWritePerMfma(writer, kernel, lwStartMfmaIndex): ######### numMfmaCanSched = writer.states.lwEndMfmaIndex - lwStartMfmaIndex + 1 numLoadsA = kernel["DepthU"]*kernel["MacroTileA"]//kernel["GlobalReadVectorWidthA"]//kernel["NumThreads"] + if kernel["ProblemType"]["MXBlockA"]: + numLoadsA += kernel["_DepthUMXSA"]*kernel["MacroTileA"]//kernel["GlobalReadVectorWidthMXSA"]//kernel["NumThreads"] numLoadsB = kernel["DepthU"]*kernel["MacroTileB"]//kernel["GlobalReadVectorWidthB"]//kernel["NumThreads"] + if kernel["ProblemType"]["MXBlockB"]: + numLoadsB += kernel["_DepthUMXSB"]*kernel["MacroTileB"]//kernel["GlobalReadVectorWidthMXSB"]//kernel["NumThreads"] if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: macroTile = kernel["MacroTileB"] if kernel["ProblemType"]["Sparse"] == 2 else kernel["MacroTileA"] numLoadsM = kernel["DepthU"]*macroTile//kernel["GlobalReadVectorWidthMetadata"]//kernel["NumThreads"] diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/ShiftVectorComponents.py b/projects/hipblaslt/tensilelite/Tensile/Components/ShiftVectorComponents.py index 9a6ac3a4696..36a090a2542 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/ShiftVectorComponents.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/ShiftVectorComponents.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -306,7 +306,8 @@ def ShiftVectorComponentsMFMAPartialThread(self, writer, kernel, tP): numOutputsPrep = (matrixInstCoal * matrixInstPrep // numThreadInWave) if conThInProcDim else 1 numOutputsPrep = numOutputsPrep * matrixInstBPrep * miWaveTitlePrep - complexMultiplier = 2 if kernel["ProblemType"]["DataType"].isComplex() else 1 + datatyp = kernel["ProblemType"]["MacDataTypeA"] if tP["isA"] else kernel["ProblemType"]["MacDataTypeB"] + complexMultiplier = 2 if datatyp.isComplex() else 1 # unify process for dimension M/N regStrideCoal = 1 if tP["isA"] else numOutputsPrep diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/StreamK.py b/projects/hipblaslt/tensilelite/Tensile/Components/StreamK.py index d7e759d9f7c..a90e326b19f 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/StreamK.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/StreamK.py @@ -2222,9 +2222,9 @@ def graWorkGroup(self, writer, kernel, tPA, tPB): return module - def computeLoadSrd(self, writer, kernel, tc, sTmp): + def computeLoadSrd(self, writer, kernel, tP, sTmp): module = Module("StreamK TwoTileOriginal computeLoadSrd") - module.add(self.computeLoadSrdCommon(writer, kernel, tc, sTmp)) + module.add(self.computeLoadSrdCommon(writer, kernel, tP, sTmp)) return module def computeStoreSrdStart(self, writer, kernel): diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/SumUnroll.py b/projects/hipblaslt/tensilelite/Tensile/Components/SumUnroll.py index acf7e130392..3f77abfdbfc 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/SumUnroll.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/SumUnroll.py @@ -73,7 +73,7 @@ def loopSum(self, writer, kernel, tP, u, innerUnroll): m = (u) % (writer.states.numVgprBuffer) # local to use for MACs # calculate constant - numRegistersIn = kernel["ProblemType"]["DataType"].numRegisters() + numRegistersIn = kernel["ProblemType"]["MacDataType%s"%tc].numRegisters() numMIInput = kernel["MIInputPerThread%s"%tc] vgprPerInput = int(numMIInput * numRegistersIn) @@ -96,7 +96,7 @@ def loopSum(self, writer, kernel, tP, u, innerUnroll): vgprBuffer_new = (m//numIterPerCoalescedRead)*numIterPerCoalescedRead vgprBuffer_new_offset = m%numIterPerCoalescedRead*kernel["InnerUnroll"]*vgprPerInput - if kernel["ProblemType"]["DataType"].isBFloat16(): + if kernel["ProblemType"]["MacDataType%s"%tc].isBFloat16(): hiBitsMaskVgpr = writer.vgprPool.checkOut(1) imod.add(VMovB32(dst=vgpr(hiBitsMaskVgpr), src=hex(0xffff0000), comment="mask 0xffff0000 for pack two bfloat16 element to 32bit")) @@ -130,7 +130,7 @@ def loopSum(self, writer, kernel, tP, u, innerUnroll): while inputIdx < vgprPerInput: imod.add(VAddF32(dst=vgpr(valuSumStr), src0=vgpr("%s+%s"%(valuStr, iui_new_offset + inputIdx)), src1=vgpr(valuSumStr), comment="sum K")) inputIdx += 1 - elif kernel["ProblemType"]["DataType"].isBFloat16(): + elif kernel["ProblemType"]["MacDataType%s"%tc].isBFloat16(): # BF16 BiasSrcA,B tmpVgpr = writer.vgprPool.checkOutAligned(2,2) if vgprPerInput > 1 and (vgprPerInput % 2 == 0): @@ -145,8 +145,7 @@ def loopSum(self, writer, kernel, tP, u, innerUnroll): else: printExit("Currently unsupported vgprPerInput %u"%vgprPerInput) writer.vgprPool.checkIn(tmpVgpr) - elif (kernel["ProblemType"]["DataType"].isAnyFloat8A() and tc == "A") or \ - (kernel["ProblemType"]["DataType"].isAnyFloat8B() and tc == "B"): + elif kernel["ProblemType"]["MacDataType%s"%tc].isAnyFloat8(): #FP8 tmpVgpr = writer.vgprPool.checkOutAligned(4,2) if vgprPerInput > 1 and (vgprPerInput % 2 == 0): @@ -161,8 +160,7 @@ def loopSum(self, writer, kernel, tP, u, innerUnroll): else: printExit("Currently unsupported vgprPerInput %u"%vgprPerInput) writer.vgprPool.checkIn(tmpVgpr) - elif (kernel["ProblemType"]["DataType"].isAnyBFloat8A() and tc == "A") or \ - (kernel["ProblemType"]["DataType"].isAnyBFloat8B() and tc == "B"): + elif kernel["ProblemType"]["MacDataType%s"%tc].isAnyBFloat8(): #BF8 tmpVgpr = writer.vgprPool.checkOutAligned(4,2) if vgprPerInput > 1 and (vgprPerInput % 2 == 0): @@ -180,7 +178,7 @@ def loopSum(self, writer, kernel, tP, u, innerUnroll): else: printExit("Currently unsupported data type") - if kernel["ProblemType"]["DataType"].isBFloat16(): + if kernel["ProblemType"]["MacDataType%s"%tc].isBFloat16(): writer.vgprPool.checkIn(hiBitsMaskVgpr) return imod diff --git a/projects/hipblaslt/tensilelite/Tensile/Contractions.py b/projects/hipblaslt/tensilelite/Tensile/Contractions.py index f5c7f1ac2dd..98f316e4ae7 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Contractions.py +++ b/projects/hipblaslt/tensilelite/Tensile/Contractions.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -133,8 +133,8 @@ def FromOriginalState(cls, d): rv.transA = bool(d['TransposeA']) rv.transB = bool(d['TransposeB']) - # it will either be set as d['MacDataType'] or a specified input - if 'MacDataTypeA' in d: + + if 'MacDataTypeA' in d: #it will either be set as d['MacDataType'] or a specified input rv.computeInputTypeA = DataType(d['MacDataTypeA']) else: rv.computeInputTypeA = srcType @@ -414,7 +414,6 @@ def predicates(self, includeBatch=False, includeOperation=False, includeType=Fal predicates.append(ProblemPredicate("MXBlockB", value=self.mxBlockB)) if self.mxBlockB: predicates.append(ProblemPredicate("DataTypeMXSB", value=self.mxTypeB)) - return predicates def extractDimPredicate(cls, key, value, predicateName): diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py index 42a99c8dd72..238533eb0c5 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py @@ -107,6 +107,7 @@ class ABMatrixInfo(MatrixInfo): numVgprLocalWriteSwapAddr: int = -1 startVgprLocalWriteSwapAddr: int= -1 numSgprGlobalReadIncs: int = -1 + useConstSgprGlobalReadIncs: bool = False numPackCvt: int = 0 useTransposeCodeThis: bool = False useTransposeCodeNext: bool = False @@ -227,6 +228,7 @@ class StateValues: globalReadIncsUseVgpr: bool = False groOffsetInMacroTile: int = 0 use64bShadowLimit: bool = True + use64bShadowLimitMX: bool = False preventVgprOverflowDuringNewTile: int = -1 interleaveStoreVmcnt: bool = False srdShiftLeft:dict = field(init=False) @@ -306,9 +308,9 @@ class StateValues: localReadDoCntMXSB: int = 0 localReadDoCntMetadata: int = 0 savedLocalReadDoCntA: int = 0 + savedLocalReadDoCntB: int = 0 savedLocalReadDoCntMXSA: int = 0 savedLocalReadDoCntMXSB: int = 0 - savedLocalReadDoCntB: int = 0 savedLocalReadDoCntMetadata: int = 0 dtvKIntervalA: int = 1 @@ -346,6 +348,7 @@ class StateValues: doPackPreSchedulingNextLoop: bool = False doFullPackCodePrefetch: bool = False lockLdsReadTokenSwap: bool = False + useCommonSgprSwap: bool = False # Epilogue states preloadScaleA = False @@ -518,9 +521,9 @@ def __init__( # Only works if the problem uses full tiles (no edges) # Mismatches will assert (generate GPUVM fault) self.db["CheckValue1A"] = self.debugConfig.enableDebugA + self.db["CheckValue1B"] = self.debugConfig.enableDebugB self.db["CheckValue1MXSA"] = False self.db["CheckValue1MXSB"] = False - self.db["CheckValue1B"] = self.debugConfig.enableDebugB self.db["CheckValue1Metadata"] = False # Check value in C matrix. # Caveats: @@ -1079,6 +1082,12 @@ def _makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, point numPackedB = 0 numPackedM = 0 pointerLWCodeInserted = False + PointerLRCodeItems = pointerLRCode.flatitems() + lenPointerLRCode = 0 + # count number of PointerLRCode (excluding comment) + for item in PointerLRCodeItems: + if isinstance(item, Instruction): + lenPointerLRCode += 1 ##### # Prepare localReadCode #### @@ -1755,7 +1764,19 @@ def calculateRangeAndUpdateCounter(itemCounter, writeCounters, length): (not pointerLWCodeInserted)): iterCode.add(pointerLWCode) pointerLWCodeInserted = True - if i == numMfmaPerIter - 1: + # schedule PointerLRCode into multiple MFMA (only v operation case (means excludes IncLdsBufSwitch)) + if lenPointerLRCode > 2 and not self.states.IncLdsBufSwitch: + if len(localReadItemsThisLoop) == 0: + if i >= numMfmaPerIter - lenPointerLRCode: + while PointerLRCodeItems: + item = PointerLRCodeItems.pop(0) + iterCode.add(item) + if isinstance(item, Instruction): + break + if i == numMfmaPerIter - 1: + while PointerLRCodeItems: + iterCode.add(PointerLRCodeItems.pop(0)) + elif i == numMfmaPerIter - 1: iterCode.add(pointerLRCode) #### @@ -2817,11 +2838,11 @@ def releaseTensorTmpGprs(tP): module.add(replaceHolder(moduleTmp, 0)) module.add(self.globalReadDo(kernel, 0, tensorParameters1st)) if "MX" in tensorParameters1st: - moduleTmp = self.directToLdsM0Update(kernel, 0, tensorParameters1st["MX"]) + moduleTmp = self.directToLdsM0Update(kernel, 0, tensorParameters1st["MX"], skipWait=True) module.add(replaceHolder(moduleTmp, 0)) module.add(self.globalReadDo(kernel, 0, tensorParameters1st["MX"])) if "MX" in tensorParameters2nd: - moduleTmp = self.directToLdsM0Update(kernel, 0, tensorParameters2nd["MX"]) + moduleTmp = self.directToLdsM0Update(kernel, 0, tensorParameters2nd["MX"], skipWait=True) module.add(replaceHolder(moduleTmp, 0)) module.add(self.globalReadDo(kernel, 0, tensorParameters2nd["MX"])) skip2ndWaitForDtl = kernel["DirectToLds%s"%tensorParameters1st["tensorChar"]] @@ -3080,8 +3101,6 @@ def noLoadLoopBody( self, kernel, tensorParametersA, tensorParametersB, pack, pa doReadMXSB = doReadMXSB and iui*self.states.numReadsIterCoalescedMXSB < kernel["InnerUnroll"] doReadB = doReadB and iui*self.states.numReadsIterCoalescedB < kernel["InnerUnroll"] doReadM = doReadM and iui*self.states.numReadsIterCoalescedMetadata < kernel["InnerUnroll"] - if (doReadA or doReadB or doReadM or doReadMXSA or doReadMXSB) and kernel["ForceUnrollSubIter"] and not self.states.doFullPackCodePrefetch: - pack[1] = Module() if doReadA: localReads.addComment1("local read a") bufferIdx = plrIdx*self.states.numIterPerCoalescedReadA @@ -3103,20 +3122,17 @@ def noLoadLoopBody( self, kernel, tensorParametersA, tensorParametersB, pack, pa packPre[packStoreIdx*self.states.numIterPerCoalescedReadA].add(packCodeA) else: pack[packStoreIdx*self.states.numIterPerCoalescedReadA].add(packCodeA) - if kernel["ForceUnrollSubIter"] and not self.states.doFullPackCodePrefetch: - pack[1].add(packCodeA) if doReadMXSA: localReads.addComment1("local read mxsa") bufferIdx = plrIdx*self.states.numIterPerCoalescedReadMXSA if self.states.packDTVA or self.states.convDTVA: # DTV + pack or input conversion case, offset bufferIdx for local read packing instructions bufferIdx = plrIdxDTV*self.states.numIterPerCoalescedReadMXSA + vregSetIdxLR * kernel["LoopIters"] - localReadCodeMXSA, packCodeMXSA, _ = self.localReadDo(kernel, bufferIdx, iui*self.states.numReadsIterCoalescedMXSA, 0, tensorParametersA["MX"]) + localReadCodeMXSA, packCodeMXSA, packPreMXSA = self.localReadDo(kernel, bufferIdx, iui*self.states.numReadsIterCoalescedMXSA, 0, tensorParametersA["MX"]) if needNextBufLR: localReads.add(localReadCodeMXSA) - pack[plrIdx*self.states.numIterPerCoalescedReadMXSA].add(packCodeMXSA) - if kernel["ForceUnrollSubIter"]: - pack[1].add(packCodeMXSA) + pack[packStoreIdx*self.states.numIterPerCoalescedReadMXSA].add(packPreMXSA) + pack[packStoreIdx*self.states.numIterPerCoalescedReadMXSA].add(packCodeMXSA) if doReadM: localReads.addComment1("local read metadata") localReadCodeM, packCodeM, packPreM = self.localReadDo(kernel, plrIdx*self.states.numIterPerCoalescedReadMetadata, iui*self.states.numReadsIterCoalescedMetadata, 0, tPM) @@ -3131,12 +3147,11 @@ def noLoadLoopBody( self, kernel, tensorParametersA, tensorParametersB, pack, pa if self.states.packDTVB or self.states.convDTVB: # DTV + pack or input conversion case, offset bufferIdx for local read packing instructions bufferIdx = plrIdxDTV*self.states.numIterPerCoalescedReadMXSB + vregSetIdxLR * kernel["LoopIters"] - localReadCodeMXSB, packCodeMXSB, _ = self.localReadDo(kernel, bufferIdx, iui*self.states.numReadsIterCoalescedMXSB, 0, tensorParametersB["MX"]) + localReadCodeMXSB, packCodeMXSB, packPreMXSB = self.localReadDo(kernel, bufferIdx, iui*self.states.numReadsIterCoalescedMXSB, 0, tensorParametersB["MX"]) if needNextBufLR: localReads.add(localReadCodeMXSB) - pack[plrIdx*self.states.numIterPerCoalescedReadMXSB].add(packCodeMXSB) - if kernel["ForceUnrollSubIter"]: - pack[1].add(packCodeMXSB) + pack[packStoreIdx*self.states.numIterPerCoalescedReadMXSB].add(packPreMXSB) + pack[packStoreIdx*self.states.numIterPerCoalescedReadMXSB].add(packCodeMXSB) if doReadB: localReads.addComment1("local read b") bufferIdx = plrIdx*self.states.numIterPerCoalescedReadB @@ -3237,10 +3252,10 @@ def noLoadLoopBody( self, kernel, tensorParametersA, tensorParametersB, pack, pa if not kernel["ForceUnrollSubIter"] or (doReadA and (u 1): - unitA = 4 // (bpeGRA * asem) + unitA = int(4 / (bpeGRA * asem)) if ((not tluB) and (bpeGRB * asem < 4) and grvwb > 1): - unitB = 4 // (bpeGRB * asem) - self.states.tailloopInNllmaxUnit = int(max(unitA, unitB)) + unitB = int(4 / (bpeGRB * asem)) + self.states.tailloopInNllmaxUnit = max(unitA, unitB) # Only assembly supports scheduling if kernel["KernelLanguage"] == "Assembly": @@ -5588,7 +5625,7 @@ def _initKernel(self, kernel, tensorParametersA, tensorParametersB): """ if kernel["EnableMatrixInstruction"] and kernel["LocalReadVectorWidthA"] >= kernel["MIInputPerThread"]: - WLR = max(max(kernel["LocalReadVectorWidthA"]//kernel["MIInputPerThread"], 1), 1) + WLR = int(max(kernel["LocalReadVectorWidthA"]//kernel["MIInputPerThread"], 1)) self.states.numItersPLR = kernel["PrefetchLocalRead"]%(kernel["LoopIters"]//WLR) else: self.states.numItersPLR = kernel["PrefetchLocalRead"]%(kernel["LoopIters"]) @@ -5638,29 +5675,31 @@ def _initKernel(self, kernel, tensorParametersA, tensorParametersB): # TODO: implement extra logic to swap vgprs after local read to suport lrvwTile > 1 for numBytes >= 4 + MIInputPerThread > 1 isCMS = kernel["UseCustomMainLoopSchedule"] isMfmaXf32 = kernel["UseMFMAF32XEmulation"] - forceLrvwTile1 = kernel["ProblemType"]["MacDataTypeA"].numBytes() >= 4 and \ + forceLrvwTile1A = kernel["ProblemType"]["MacDataTypeA"].numBytes() >= 4 and \ (kernel["EnableMatrixInstruction"] and kernel["MIInputPerThread"] > 1) and \ not (kernel["UseF32XEmulation"] and (isMfmaXf32 or isCMS)) - if not kernel["UnrollMajorLDSA"] and not forceLrvwTile1: + if not kernel["UnrollMajorLDSA"] and not forceLrvwTile1A: self.states.lrvwTileA = kernel["VectorWidthA"] else: self.states.lrvwTileA = 1 - forceLrvwTile1 = kernel["ProblemType"]["MacDataTypeB"].numBytes() >= 4 and \ + forceLrvwTile1B = kernel["ProblemType"]["MacDataTypeB"].numBytes() >= 4 and \ (kernel["EnableMatrixInstruction"] and kernel["MIInputPerThreadB"] > 1) and \ not (kernel["UseF32XEmulation"] and (isMfmaXf32 or isCMS)) - if not kernel["UnrollMajorLDSB"] and not forceLrvwTile1: + if not kernel["UnrollMajorLDSB"] and not forceLrvwTile1B: self.states.lrvwTileB = kernel["VectorWidthB"] else: + if kernel["ProblemType"]["MXBlockB"]: + self.states.lrvwTileMXSB = 1 self.states.lrvwTileB = 1 if kernel["ProblemType"]["MXBlockA"]: - if not kernel["UnrollMajorLDSMXSA"] and not forceLrvwTile1: + if not kernel["UnrollMajorLDSMXSA"] and not forceLrvwTile1A: self.states.lrvwTileMXSA = kernel["VectorWidthA"] else: self.states.lrvwTileMXSA = 1 if kernel["ProblemType"]["MXBlockB"]: - if not kernel["UnrollMajorLDSMXSB"] and not forceLrvwTile1: + if not kernel["UnrollMajorLDSMXSB"] and not forceLrvwTile1B: self.states.lrvwTileMXSB = kernel["VectorWidthB"] else: self.states.lrvwTileMXSB = 1 @@ -5694,16 +5733,20 @@ def _initKernel(self, kernel, tensorParametersA, tensorParametersB): self.states.numVgprBufferPackMXSB = 1 if kernel["UnrollMajorLDSA"]: - divider = 2 if (kernel["ProblemType"]["Sparse"] == 1) and (kernel["MIInputPerThread"] * kernel["ProblemType"]["DataType"].numBytes() <= 16) else 1 + divider = 2 if (kernel["ProblemType"]["Sparse"] == 1) and (kernel["MIInputPerThread"] * kernel["ProblemType"]["MacDataTypeA"].numBytes() <= 16) else 1 self.states.lrvwUnrollA = kernel["LocalReadVectorWidthA"] // divider else: self.states.lrvwUnrollA = 1 + if kernel["ProblemType"]["MXBlockA"]: + self.states.lrvwUnrollMXSA = 1 if kernel["UnrollMajorLDSB"]: - divider = 2 if (kernel["ProblemType"]["Sparse"] == 2) and (kernel["MIInputPerThread"] * kernel["ProblemType"]["DataType"].numBytes() <= 16) else 1 + divider = 2 if (kernel["ProblemType"]["Sparse"] == 2) and (kernel["MIInputPerThread"] * kernel["ProblemType"]["MacDataTypeB"].numBytes() <= 16) else 1 self.states.lrvwUnrollB = kernel["LocalReadVectorWidthB"] // divider else: self.states.lrvwUnrollB = 1 + if kernel["ProblemType"]["MXBlockB"]: + self.states.lrvwUnrollMXSB = 1 if kernel["ProblemType"]["MXBlockA"]: if kernel["UnrollMajorLDSMXSA"]: @@ -5764,6 +5807,15 @@ def _initKernel(self, kernel, tensorParametersA, tensorParametersB): self.states.oneBufferScheduling = (kernel["1LDSBuffer"]) or \ ((kernel["DirectToLdsA"] or kernel["DirectToLdsB"]) and \ self.states.numLDSBlk == kernel["PrefetchGlobalRead"]) + # common sgprSwap + # enable it for the following case. + # DTLA+B + numLDSBlk==2 + (not EPS) + (not sparse) + (MX or (StoreSwapAddr and (not CMS)) + self.states.useCommonSgprSwap = False + if kernel["DirectToLds"] == 1 and self.states.numLDSBlk == 2 and \ + (not kernel["ExpandPointerSwap"]) and (not kernel["ProblemType"]["Sparse"]) and \ + ((kernel["ProblemType"]["MXBlockA"] or kernel["ProblemType"]["MXBlockB"]) or \ + kernel["StoreSwapAddr"] and not kernel["UseCustomMainLoopSchedule"]): + self.states.useCommonSgprSwap = True # set memory token by LDS buffer setting self.states.memTokenLdsBufferMeta = 4 if kernel["1LDSBuffer"]: @@ -5901,6 +5953,12 @@ def readWriteVectors(mat, vw, kernel): # use 64-bit buffer limit shadow register # but not implemented or tested self.states.use64bShadowLimit = kernel["Use64bShadowLimit"] and kernel["BufferLoad"] + # Keep gfx950's dedicated MX shadow-limit switch, but match legacy gfx1250 + # behavior where MX tensors share the same shadow-limit mode as A/B. + if isgfx950: + self.states.use64bShadowLimitMX = kernel["Use64bShadowLimitMX"] and kernel["BufferLoad"] + else: + self.states.use64bShadowLimitMX = self.states.use64bShadowLimit # Check if the address setup code for LWA and GRO causes register growth. # This is not an error condition but bears further investigation. @@ -5919,9 +5977,11 @@ def readWriteVectors(mat, vw, kernel): self.states.srdShiftLeft["A"] = kernel["GlobalReadVectorWidthA"] self.states.srdShiftLeft["B"] = kernel["GlobalReadVectorWidthB"] if kernel["ProblemType"]["MXBlockA"]: - self.states.srdShiftLeft["MXSA"] = kernel["GlobalReadVectorWidthA"] + # use MXS version for gfx950 only + self.states.srdShiftLeft["MXSA"] = kernel["GlobalReadVectorWidthMXSA"] if isgfx950 else kernel["GlobalReadVectorWidthA"] if kernel["ProblemType"]["MXBlockB"]: - self.states.srdShiftLeft["MXSB"] = kernel["GlobalReadVectorWidthB"] + # use MXS version for gfx950 only + self.states.srdShiftLeft["MXSB"] = kernel["GlobalReadVectorWidthMXSB"] if isgfx950 else kernel["GlobalReadVectorWidthB"] if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: self.states.srdShiftLeft["Metadata"] = kernel["GlobalReadVectorWidthMetadata"] @@ -6204,16 +6264,32 @@ def readWriteVectors(mat, vw, kernel): if kernel["ProblemType"]["MXBlockA"]: self.states.mxsa.numVgprValuPerBlock = kernel["MIWaveTileMXSA"] * kernel["MIInputPerThreadMXSA"] // self.states.bpr + # workaround for gfx950 + # need to allocate same amount of MIWaveTile + if isgfx950: + self.states.mxsa.numVgprValuPerBlock = kernel["MIWaveTileMXSA"] if kernel["DirectToVgprMXSA"] and not (self.states.packDTVA or self.states.convDTVA): self.states.mxsa.numVgprValuPerBlock = 0 + # MX scale registers are consumed by local-read and wmma paths; avoid + # emitting unresolved vgprValuMXSA symbols when integer division rounds to 0. + elif self.states.mxsa.numVgprValuPerBlock == 0: + self.states.mxsa.numVgprValuPerBlock = kernel["MIWaveTileMXSA"] self.states.mxsa.numVgprValu = self.states.mxsa.numVgprValuPerBlock * valuBlocksA if self.states.lrvwTileMXSA > 1: self.states.mxsa.numVgprValu = self.states.mxsa.numVgprValuPerBlock * kernel["InnerUnroll"] if kernel["ProblemType"]["MXBlockB"]: self.states.mxsb.numVgprValuPerBlock = kernel["MIWaveTileMXSB"] * kernel["MIInputPerThreadMXSB"] // self.states.bpr + # workaround for gfx950 + # need to allocate same amount of MIWaveTile + if isgfx950: + self.states.mxsb.numVgprValuPerBlock = kernel["MIWaveTileMXSB"] if kernel["DirectToVgprMXSB"] and not (self.states.packDTVB or self.states.convDTVB): self.states.mxsb.numVgprValuPerBlock = 0 + # MX scale registers are consumed by local-read and wmma paths; avoid + # emitting unresolved vgprValuMXSB symbols when integer division rounds to 0. + elif self.states.mxsb.numVgprValuPerBlock == 0: + self.states.mxsb.numVgprValuPerBlock = kernel["MIWaveTileMXSB"] self.states.mxsb.numVgprValu = self.states.mxsb.numVgprValuPerBlock * valuBlocksB if self.states.lrvwTileMXSB > 1: self.states.mxsb.numVgprValu = self.states.mxsb.numVgprValuPerBlock * kernel["InnerUnroll"] @@ -6484,6 +6560,8 @@ def readWriteVectors(mat, vw, kernel): numVgprMultiplierA = 1 numVgprMultiplierB = 1 + numVgprMultiplierMXSA = 1 + numVgprMultiplierMXSB = 1 numVgprMultiplierMetadata = 1 if kernel["ProblemType"]["MXBlockA"]: numVgprMultiplierMXSA = 1 @@ -6538,9 +6616,14 @@ def readWriteVectors(mat, vw, kernel): # do not allocate local read address register if DirectToVgpr is enabled if kernel["DirectToVgprA"]: self.states.a.numVgprLocalReadAddr = 0 + if kernel["ProblemType"]["MXBlockA"] and kernel["DirectToVgprMXSA"]: + self.states.mxsa.numVgprLocalReadAddr = 0 + self.states.mxsa.numVgprLocalWriteAddr = 0 if kernel["DirectToVgprB"]: self.states.b.numVgprLocalReadAddr = 0 - + if kernel["ProblemType"]["MXBlockB"] and kernel["DirectToVgprMXSB"]: + self.states.mxsb.numVgprLocalReadAddr = 0 + self.states.mxsb.numVgprLocalWriteAddr = 0 if not (kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]): self.states.m.numVgprLocalWriteAddr = 0 @@ -6579,9 +6662,9 @@ def readWriteVectors(mat, vw, kernel): self.states.b.numVgprLocalWriteSwapAddr = 1 if kernel["ProblemType"]["Sparse"] and (not kernel["LocalWriteUseSgprMetadata"]) and (self.states.m.numVgprLocalWriteAddr > 0): self.states.m.numVgprLocalWriteSwapAddr = 1 - if kernel["ProblemType"]["MXBlockA"] and (not kernel["LocalWriteUseSgprA"]) and (self.states.mxsa.numVgprLocalWriteAddr > 0): + if kernel["ProblemType"]["MXBlockA"] and (not kernel["LocalWriteUseSgprMXSA"]) and (self.states.mxsa.numVgprLocalWriteAddr > 0): self.states.mxsa.numVgprLocalWriteSwapAddr = 1 - if kernel["ProblemType"]["MXBlockB"] and (not kernel["LocalWriteUseSgprB"]) and (self.states.mxsb.numVgprLocalWriteAddr > 0): + if kernel["ProblemType"]["MXBlockB"] and (not kernel["LocalWriteUseSgprMXSB"]) and (self.states.mxsb.numVgprLocalWriteAddr > 0): self.states.mxsb.numVgprLocalWriteSwapAddr = 1 #################################### @@ -6983,8 +7066,14 @@ def GNLCOInit(tc): # Need backup for the first LocalReadAddr only (others will be calculated from the first one) self.states.a.startVgprLocalReadAddrOrig = vgprIdx vgprIdx += 1 if self.states.a.numVgprLocalReadAddr > 0 else 0 + if kernel["ProblemType"]["MXBlockA"]: + self.states.mxsa.startVgprLocalReadAddrOrig = vgprIdx + vgprIdx += 1 if self.states.mxsa.numVgprLocalReadAddr > 0 else 0 self.states.b.startVgprLocalReadAddrOrig = vgprIdx vgprIdx += 1 if self.states.b.numVgprLocalReadAddr > 0 else 0 + if kernel["ProblemType"]["MXBlockB"]: + self.states.mxsb.startVgprLocalReadAddrOrig = vgprIdx + vgprIdx += 1 if self.states.mxsb.numVgprLocalReadAddr > 0 else 0 # ---------------------------- # TODO: alignment hack, figure out a better solution @@ -7069,8 +7158,6 @@ def GNLCOInit(tc): vgprIdx = self.states.b.startVgprValu \ + max(self.states.b.numVgprValu + numVgprValuPackB, self.states.b.numVgprG2LAllocated) - - if ((tensorParametersA["bpe"] < 4 and not kernel["UnrollMajorLDSA"]) or \ (tensorParametersB["bpe"] < 4 and not kernel["UnrollMajorLDSB"]) or \ (kernel["ProblemType"]["Sparse"] and not kernel["UnrollMajorLDSMetadata"] and (kernel["MIInputPerThreadMetadata"] == 4))) \ @@ -7216,6 +7303,14 @@ def GNLCOInit(tc): if self.states.b.numVgprLocalWriteSwapAddr > 0: self.states.b.startVgprLocalWriteSwapAddr = vgprIdx vgprIdx += 1 + if kernel["ProblemType"]["MXBlockA"]: + if self.states.mxsa.numVgprLocalWriteSwapAddr > 0: + self.states.mxsa.startVgprLocalWriteSwapAddr = vgprIdx + vgprIdx += 1 + if kernel["ProblemType"]["MXBlockB"]: + if self.states.mxsb.numVgprLocalWriteSwapAddr > 0: + self.states.mxsb.startVgprLocalWriteSwapAddr = vgprIdx + vgprIdx += 1 # X32F Emulation initializations # meaning of variables @@ -7489,6 +7584,37 @@ def checkVregOverflowTF32Emu(vgprIdx, numV): if kernel["ProblemType"]["MXBlockB"]: self.states.mxsb.numSgprGlobalReadIncs = kernel["ProblemType"]["NumIndicesSummation"] * self.states.rpgo self.states.m.numSgprGlobalReadIncs = kernel["ProblemType"]["NumIndicesSummation"] * self.states.rpgo + # check for constSgprGlobalReadInc + # use const version of GlobalReadInc for the following case + # - StreamK (not GSU) + # - no staggerUCode + # - TLUA false for A, TLUB false for B + # - numSgprGlobalReadIncs is 1 + if kernel["StreamK"] and (not self.states.staggerUCode): + if kernel["ProblemType"]["TLUA"] == False: + if self.states.a.numSgprGlobalReadIncs == 1: + # use const GR Inc + self.states.a.useConstSgprGlobalReadIncs = True + # do not allocate GRInc sgpr + self.states.a.numSgprGlobalReadIncs = 0 + if kernel["ProblemType"]["MXBlockA"]: + if self.states.mxsa.numSgprGlobalReadIncs == 1: + # use const GR Inc + self.states.mxsa.useConstSgprGlobalReadIncs = True + # do not allocate GRInc sgpr + self.states.mxsa.numSgprGlobalReadIncs = 0 + if kernel["ProblemType"]["TLUB"] == False: + if self.states.b.numSgprGlobalReadIncs == 1: + # use const GR Inc + self.states.b.useConstSgprGlobalReadIncs = True + # do not allocate GRInc sgpr + self.states.b.numSgprGlobalReadIncs = 0 + if kernel["ProblemType"]["MXBlockB"]: + if self.states.mxsb.numSgprGlobalReadIncs == 1: + # use const GR Inc + self.states.mxsb.useConstSgprGlobalReadIncs = True + # do not allocate GRInc sgpr + self.states.mxsb.numSgprGlobalReadIncs = 0 ######################################## # SGPR Assignment according to AMDGPU-ABI @@ -7657,13 +7783,23 @@ def checkVregOverflowTF32Emu(vgprIdx, numV): self.defineSgpr("LocalWriteAddrA", 1) if kernel["LocalWriteUseSgprB"]: self.defineSgpr("LocalWriteAddrB", 1) + if kernel["ProblemType"]["MXBlockA"] and kernel["LocalWriteUseSgprMXSA"]: + self.defineSgpr("LocalWriteAddrMXSA", 1) + if kernel["ProblemType"]["MXBlockB"] and kernel["LocalWriteUseSgprMXSB"]: + self.defineSgpr("LocalWriteAddrMXSB", 1) # Allocate registers to swap between lds buffers - if kernel["StoreSwapAddr"]: + if self.states.useCommonSgprSwap: + self.defineSgpr("SwapCommon", 1) + elif kernel["StoreSwapAddr"]: if kernel["LocalWriteUseSgprA"]: self.defineSgpr("SwapA", 1) if kernel["LocalWriteUseSgprB"]: self.defineSgpr("SwapB", 1) + if kernel["ProblemType"]["MXBlockA"] and kernel["LocalWriteUseSgprMXSA"]: + self.defineSgpr("SwapMXSA", 1) + if kernel["ProblemType"]["MXBlockB"] and kernel["LocalWriteUseSgprMXSB"]: + self.defineSgpr("SwapMXSB", 1) if kernel["ProblemType"]["Sparse"] and kernel["LocalWriteUseSgprMetadata"]: self.defineSgpr("SwapMetadata", 1) @@ -7782,7 +7918,7 @@ def checkVregOverflowTF32Emu(vgprIdx, numV): self.states.numReadsPerUnrollA = ceil(tensorParametersA["bpe"] * kernel["MIInputPerThreadA"] / int(tensorParametersA["localReadInstruction"].blockWidth * 4)) else: self.states.numReadsPerUnrollA = kernel["MIInputPerThreadA"] - if kernel["ForceUnrollSubIter"]: + if kernel["ForceUnrollSubIter"] and self.states.numReadsPerUnrollA > 1: self.states.numReadsPerUnrollA //= kernel["numSubTiles"] factorSubIterA = 1 numA = kernel["InnerUnroll"]*(kernel["MIWaveTile"][0] * self.states.numReadsPerUnrollA) // tensorParametersA["localReadInstruction"].numOffsets @@ -7790,6 +7926,11 @@ def checkVregOverflowTF32Emu(vgprIdx, numV): numA = numA // kernel["VectorWidthA"] if kernel["ForceUnrollSubIter"]: numA = numA // factorSubIterA + if kernel["ProblemType"]["MXBlockA"]: + self.states.numReadsPerUnrollMXSA = 1 + numMXSA = kernel["InnerUnroll"] * kernel["MIWaveTile"][0] // tensorParametersMXSA["localReadInstruction"].numOffsets + if self.states.lrvwTileMXSA > 1: + numMXSA = numMXSA // kernel["VectorWidthA"] if kernel["ProblemType"]["MXBlockA"]: if kernel["UnrollMajorLDSMXSA"]: @@ -7815,7 +7956,7 @@ def checkVregOverflowTF32Emu(vgprIdx, numV): self.states.numReadsPerUnrollB = ceil(tensorParametersB["bpe"] * kernel["MIInputPerThreadB"] / int(tensorParametersB["localReadInstruction"].blockWidth * 4)) else: self.states.numReadsPerUnrollB = kernel["MIInputPerThreadB"] - if kernel["ForceUnrollSubIter"]: + if kernel["ForceUnrollSubIter"] and self.states.numReadsPerUnrollB > 1: self.states.numReadsPerUnrollB //= kernel["numSubTiles"] factorSubIterB = 1 numB = kernel["InnerUnroll"]*(kernel["MIWaveTile"][1] * self.states.numReadsPerUnrollB) // tensorParametersB["localReadInstruction"].numOffsets @@ -7823,6 +7964,11 @@ def checkVregOverflowTF32Emu(vgprIdx, numV): numB = numB // kernel["VectorWidthB"] if kernel["ForceUnrollSubIter"]: numB = numB // factorSubIterB + if kernel["ProblemType"]["MXBlockB"]: + self.states.numReadsPerUnrollMXSB = 1 + numMXSB = kernel["InnerUnroll"] * kernel["MIWaveTile"][1] // tensorParametersMXSB["localReadInstruction"].numOffsets + if self.states.lrvwTileMXSB > 1: + numMXSB = numMXSB // kernel["VectorWidthB"] if kernel["ProblemType"]["MXBlockB"]: if kernel["UnrollMajorLDSMXSB"]: @@ -8113,6 +8259,12 @@ def checkResources(self, kernel, mkb) -> None: def graWorkGroup(self, kernel): return "" + def tpBpe(self, kernel, key, cM): + if cM in ("Metadata", "MXSA", "MXSB"): + return 1 + else: + return kernel["ProblemType"][key].numBytes() + ############################################################################## # Get Input Tensor Bpe ############################################################################## diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py index 8b298d82a16..83bfc7ec360 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py @@ -576,6 +576,7 @@ def removeGRSrdVariableSgprsFromPool(self, kernel): if self.states.use64bShadowLimit: self.removeSgprVarFromPool("ShadowLimitA") self.removeSgprVarFromPool("ShadowLimitB") + if self.states.use64bShadowLimitMX: if kernel["ProblemType"]["MXBlockA"]: self.removeSgprVarFromPool("ShadowLimitMXSA") if kernel["ProblemType"]["MXBlockB"]: @@ -595,6 +596,11 @@ def removeGROffsetsVariableSgprsFromPool(self, kernel): self.removeSgprVarFromPool("GlobalReadIncsA") self.removeSgprVarFromPool("GlobalReadIncsB") + if kernel["ProblemType"]["MXBlockA"]: + self.removeSgprVarFromPool("GlobalReadIncsMXSA") + if kernel["ProblemType"]["MXBlockB"]: + self.removeSgprVarFromPool("GlobalReadIncsMXSB") + return module def defineVariableSgprs(self, kernel): @@ -657,14 +663,15 @@ def defineVariableSgprs(self, kernel): if not kernel["enableTDMB"]: module.add(self.defineSgpr("ShadowLimitB", 2, 2)) self.addSgprVarToPool("ShadowLimitB") + if kernel["ProblemType"]["Sparse"]: + module.add(self.defineSgpr("ShadowLimitMetadata", 2, 2)) + if self.states.use64bShadowLimitMX: if kernel["ProblemType"]["MXBlockA"] and not kernel["enableTDMA"]: module.add(self.defineSgpr("ShadowLimitMXSA", 2, 2)) self.addSgprVarToPool("ShadowLimitMXSA") if kernel["ProblemType"]["MXBlockB"] and not kernel["enableTDMB"]: module.add(self.defineSgpr("ShadowLimitMXSB", 2, 2)) self.addSgprVarToPool("ShadowLimitMXSB") - if kernel["ProblemType"]["Sparse"]: - module.add(self.defineSgpr("ShadowLimitMetadata", 2, 2)) if self.states.staggerUCode: module.add(self.defineSgpr("StaggerUIter", 1)) # stagger loop iterations, used for various iter counts in the code @@ -686,16 +693,20 @@ def defineVariableSgprs(self, kernel): self.addSgprVarToPool("WrapUA") self.addSgprVarToPool("WrapUB") - module.add(self.defineSgpr("GlobalReadIncsA", self.states.a.numSgprGlobalReadIncs)) - module.add(self.defineSgpr("GlobalReadIncsB", self.states.b.numSgprGlobalReadIncs)) - if kernel["ProblemType"]["MXBlockA"]: + if self.states.a.numSgprGlobalReadIncs > 0: + module.add(self.defineSgpr("GlobalReadIncsA", self.states.a.numSgprGlobalReadIncs)) + self.addSgprVarToPool("GlobalReadIncsA") + if kernel["ProblemType"]["MXBlockA"] and self.states.mxsa.numSgprGlobalReadIncs > 0: module.add(self.defineSgpr("GlobalReadIncsMXSA", self.states.mxsa.numSgprGlobalReadIncs)) - if kernel["ProblemType"]["MXBlockB"]: + self.addSgprVarToPool("GlobalReadIncsMXSA") + if self.states.b.numSgprGlobalReadIncs > 0: + module.add(self.defineSgpr("GlobalReadIncsB", self.states.b.numSgprGlobalReadIncs)) + self.addSgprVarToPool("GlobalReadIncsB") + if kernel["ProblemType"]["MXBlockB"] and self.states.mxsb.numSgprGlobalReadIncs > 0: module.add(self.defineSgpr("GlobalReadIncsMXSB", self.states.mxsb.numSgprGlobalReadIncs)) + self.addSgprVarToPool("GlobalReadIncsMXSB") if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: module.add(self.defineSgpr("GlobalReadIncsMetadata", self.states.m.numSgprGlobalReadIncs)) - self.addSgprVarToPool("GlobalReadIncsA") - self.addSgprVarToPool("GlobalReadIncsB") if self.states.IncLdsBufSwitch: module.add(self.defineSgpr("LDSBufferReadInc", 1)) module.add(self.defineSgpr("LDSBufferWriteInc", 1)) @@ -1054,7 +1065,6 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module: moduleVgprMacroValuAPack.add(RegSet("v", "vgprValuA_X%u_I%u_D%u"%(bi,iui,data),"vgprValuA_X0_I0_D0_PACK", ri)) ri += self.states.a.numVgprValuPerBlock - ri = 0 if self.states.b.numVgprValu > 0: # Do not generate vgprValuB if numVgprValuB is 0 numBiFactor = numBi @@ -1155,7 +1165,7 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module: module.add(RegSet("v", "vgprLocalWriteAddrOverhangA", \ self.states.a.startVgprLocalWriteAddr+1)) if kernel["ProblemType"]["MXBlockA"]: - if not kernel["LocalWriteUseSgprA"] and self.states.mxsa.numVgprLocalWriteAddr > 0: + if not kernel["LocalWriteUseSgprMXSA"] and self.states.mxsa.numVgprLocalWriteAddr > 0: module.add(RegSet("v", "vgprLocalWriteAddrMXSA", \ self.states.mxsa.startVgprLocalWriteAddr)) if self.states.mxsa.numVgprLocalWriteAddr > 1: @@ -1168,7 +1178,7 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module: module.add(RegSet("v", "vgprLocalWriteAddrOverhangB", \ self.states.b.startVgprLocalWriteAddr+1)) if kernel["ProblemType"]["MXBlockB"]: - if not kernel["LocalWriteUseSgprB"] and self.states.mxsb.numVgprLocalWriteAddr > 0: + if not kernel["LocalWriteUseSgprMXSB"] and self.states.mxsb.numVgprLocalWriteAddr > 0: module.add(RegSet("v", "vgprLocalWriteAddrMXSB", \ self.states.mxsb.startVgprLocalWriteAddr)) if self.states.mxsb.numVgprLocalWriteAddr > 1: @@ -1275,9 +1285,17 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module: if self.states.a.numVgprLocalReadAddr > 0: module.add(RegSet("v", "vgprLocalReadAddrOrigA", \ self.states.a.startVgprLocalReadAddrOrig)) + if kernel["ProblemType"]["MXBlockA"]: + if self.states.mxsa.numVgprLocalReadAddr > 0: + module.add(RegSet("v", "vgprLocalReadAddrOrigMXSA", \ + self.states.mxsa.startVgprLocalReadAddrOrig)) if self.states.b.numVgprLocalReadAddr > 0: module.add(RegSet("v", "vgprLocalReadAddrOrigB", \ self.states.b.startVgprLocalReadAddrOrig)) + if kernel["ProblemType"]["MXBlockB"]: + if self.states.mxsb.numVgprLocalReadAddr > 0: + module.add(RegSet("v", "vgprLocalReadAddrOrigMXSB", \ + self.states.mxsb.startVgprLocalReadAddrOrig)) if self.states.m.numVgprLocalReadAddr > 0: module.add(RegSet("v", "vgprLocalReadAddrMetadata", \ self.states.m.startVgprLocalReadAddr)) @@ -1296,6 +1314,14 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module: if self.states.m.numVgprLocalReadSwapAddr > 0: module.add(RegSet("v", "vgprLocalReadSwapAddrMetadata", \ self.states.m.startVgprLocalReadSwapAddr)) + if kernel["ProblemType"]["MXBlockA"]: + if self.states.mxsa.numVgprLocalReadSwapAddr > 0: + module.add(RegSet("v", "vgprLocalReadSwapAddrMXSA", \ + self.states.mxsa.startVgprLocalReadSwapAddr)) + if kernel["ProblemType"]["MXBlockB"]: + if self.states.mxsb.numVgprLocalReadSwapAddr > 0: + module.add(RegSet("v", "vgprLocalReadSwapAddrMXSB", \ + self.states.mxsb.startVgprLocalReadSwapAddr)) if self.states.a.numVgprLocalWriteSwapAddr > 0: module.add(RegSet("v", "vgprLocalWriteSwapAddrA", \ self.states.a.startVgprLocalWriteSwapAddr)) @@ -1305,6 +1331,14 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module: if self.states.m.numVgprLocalWriteSwapAddr > 0: module.add(RegSet("v", "vgprLocalWriteSwapAddrMetadata", \ self.states.m.startVgprLocalWriteSwapAddr)) + if kernel["ProblemType"]["MXBlockA"]: + if self.states.mxsa.numVgprLocalWriteSwapAddr > 0: + module.add(RegSet("v", "vgprLocalWriteSwapAddrMXSA", \ + self.states.mxsa.startVgprLocalWriteSwapAddr)) + if kernel["ProblemType"]["MXBlockB"]: + if self.states.mxsb.numVgprLocalWriteSwapAddr > 0: + module.add(RegSet("v", "vgprLocalWriteSwapAddrMXSB", \ + self.states.mxsb.startVgprLocalWriteSwapAddr)) if kernel["ProblemType"]["OutputAmaxD"]: module.add(RegSet("v", "vgprAmaxOut", self.startVgprAmaxOut)) @@ -1445,7 +1479,7 @@ def macroAndSet(self, kernel, tPA, tPB) -> Module: module.add(ValueSet("MTOffset", reductionOffsetLow32, format=1)) module.add(ValueSet("MTOffsetH32", reductionOffsetHigh32, format=1)) - if self.states.IncLdsBufSwitch: + if self.states.IncLdsBufSwitch or self.states.useCommonSgprSwap: module.addComment0("%d LDS Blocks for PGR %d"%(self.states.numLDSBlk, kernel["PrefetchGlobalRead"])) module.add(ValueSet("LdsOneBlockSize", kernel["LdsOffsetA_Blk"])) module.add(ValueSet("LdsBlockEndSize", kernel["LdsOffsetA_Blk"] * self.states.numLDSBlk)) @@ -1894,19 +1928,19 @@ def localReadAddresses(self, kernel, tPA, tPB, tPM): module.add(self.lraDeclareAddresses(kernel, tPB["MX"])) module.addComment1("local read addresses: declare addresses b") module.add(self.lraDeclareAddresses(kernel, tPB)) - if self.states.a.numVgprLocalReadAddr > 0: module.add(self.lraSwapAddressesForDTLPad(kernel, tPA)) module.add(self.lraAddressesInitFor3LDSBlk(kernel, tPA, False, True)) - if self.states.mxsa.numVgprLocalReadAddr > 0: module.add(self.lraSwapAddressesForDTLPad(kernel, tPA["MX"])) + module.add(self.lraAddressesInitFor3LDSBlk(kernel, tPA["MX"], False, True)) if self.states.b.numVgprLocalReadAddr > 0: module.add(self.lraSwapAddressesForDTLPad(kernel, tPB)) module.add(self.lraAddressesInitFor3LDSBlk(kernel, tPB, False, True)) if self.states.mxsb.numVgprLocalReadAddr > 0: module.add(self.lraSwapAddressesForDTLPad(kernel, tPB["MX"])) + module.add(self.lraAddressesInitFor3LDSBlk(kernel, tPB["MX"], False, True)) if self.states.m.numVgprLocalReadAddr > 0: module.add(self.lraSwapAddressesForDTLPad(kernel, tPM)) @@ -3337,8 +3371,8 @@ def graUnrollOffsets(self, kernel, tP): v = tP["gpr"]["unrollOffsets"] strideIdx = (tP["lsp"] if tP["tlu"] else tP["lsc"]) stride = kernel[strideIdx] - # DTV + LSU case, we need to divide stride by LSU if (tc in ("A", "B", "MXSA", "MXSB")) and kernel["DirectToVgpr%s"%tc] and kernel["LocalSplitU"] > 1: + # DTV + LSU case, we need to divide stride by LSU stride = stride // kernel["LocalSplitU"] prevStride = 0 totalStride = 0 @@ -3797,7 +3831,7 @@ def graFinalOffsetsSingleLoopGNLC(self, kernel, tP, tc, margin = -1): self.vgprPool.checkIn(tmpv4) stride = "Strides%s"%(tc) - module.add(VLShiftLeftB32(dst=vgpr(grov), shiftHex=log2(kernel["GlobalReadVectorWidth%c"%tc]), src=vgpr(tmpv2))) + module.add(VLShiftLeftB32(dst=vgpr(grov), shiftHex=log2(kernel["GlobalReadVectorWidth%s"%tc]), src=vgpr(tmpv2))) module.add(VMulLOU32(dst=vgpr(tmpv), src0=sgpr(stride), src1=vgpr(tmpv))) if kernel["EdgeType"] == "ShiftPtr" and kernel["ProblemType"]["TLU%s"%tc] == 1: @@ -3805,8 +3839,9 @@ def graFinalOffsetsSingleLoopGNLC(self, kernel, tP, tc, margin = -1): module.add(VAddU32(dst=vgpr(grov), src0=vgpr(tmpv), src1=vgpr(grov), \ comment="final" )) - module.add(VLShiftLeftB32(dst=vgpr(grov), shiftHex=log2(tP["bpeGR"]), src=vgpr(grov))) - module.add(VAddU32(dst=vgpr(grov), src0=int(self.states.srdShiftLeft[tc] * tP["bpeGR"]) , src1=vgpr(grov), \ + module.add(vectorMultiplyBpe(grov, grov, tP["bpeGR"])) + ptrshift = int(self.states.srdShiftLeft[tc] * tP["bpeGR"]) + module.add(VAddU32(dst=vgpr(grov), src0=ptrshift , src1=vgpr(grov), \ comment="ptr-shift" )) self.vgprPool.checkIn(tmpv) @@ -3898,11 +3933,11 @@ def graFinalOffsetsSingleLoop(self, kernel, tP, tc, tmp, graIdx, perp, sPerp, pa module.add(SMovB32(dst=sgpr(tmpSgpr), src=self.buff_load_inst_offset_max)) module.add(VAddU32(dst=vgpr(groVgpr), src0=vgpr(groVgpr), src1=sgpr(tmpSgpr), comment="shift for UseInstOffsetForGRO")) - ldsInc = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"]) * kernel["GlobalReadVectorWidth%c"%tc] * tP["bpeGR"] + ldsInc = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"]) * kernel["GlobalReadVectorWidth%s"%tc] * tP["bpeGR"] if kernel["LdsBlockSizePerPad%s"%tc] != 0: ldsInc += (ldsInc // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeGR"] else: - padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"]) * self.states.bpr + padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"]) * self.states.bpr ldsInc += (ldsInc // padInterval) * kernel["LdsPad%s"%tc] * tP["bpeGR"] # buffer_load only support 12 bit instruction offset @@ -3952,17 +3987,17 @@ def computeScalarGroImpl(scalarGro): # Using offsets so GRO holds a byte offset not an element offset # So scale here before comparison: - module.add(scalarMultiplyBpe(scalarGro, scalarGro, tP["bpeGR"])) + module.add(scalarMultiplyBpe(scalarGro, scalarGro, float(tP["bpeGR"]))) if kernel["DirectToLds%s"%tc] and kernel["UseInstOffsetForGRO"]: # add room for instruction offset module.add(SAddU32(dst=sgpr(scalarGro), src0=sgpr(scalarGro), src1=self.buff_load_inst_offset_max, comment="shift for UseInstOffsetForGRO")) - ldsInc = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"]) * kernel["GlobalReadVectorWidth%c"%tc] * tP["bpeGR"] + ldsInc = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"]) * kernel["GlobalReadVectorWidth%s"%tc] * tP["bpeGR"] if kernel["LdsBlockSizePerPad%s"%tc] != 0: ldsInc += (ldsInc // kernel["LdsBlockSizePerPad%s"%tc]) * kernel["LdsPad%s"%tc] * tP["bpeGR"] else: - padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%c"%tc] else kernel["NumThreads"]) * self.states.bpr + padInterval = (self.states.kernel["WavefrontSize"] if kernel["WaveSeparateGlobalRead%s"%tc] else kernel["NumThreads"]) * self.states.bpr ldsInc += (ldsInc // padInterval) * kernel["LdsPad%s"%tc] * tP["bpeGR"] # buffer_load only support 12 bit instruction offset @@ -4103,11 +4138,12 @@ def computeMetaDataSrd(self, kernel, tP, tc, indices): def computeLoadSrd(self, kernel, tP, tc, indices, bpe): module = Module("computeLoadSrd") moduleLoadGeneralBatch = Module("computeLoadSrd-GeneralBatch") - moduleLoadStridedBatch = Module("computeLoadSrd-StridedBatch") - with self.allocTmpSgpr(2 + 2 + (0 if self.states.use64bShadowLimit else 2)) as tmpSgprInfo: + moduleLoadStridedBatch = Module("computeLoadSrd-StridedBatch") + use64bShadowLimit = self.states.use64bShadowLimitMX if tc in ["MXSA", "MXSB"] else self.states.use64bShadowLimit + with self.allocTmpSgpr(2 + 2 + (0 if use64bShadowLimit else 2)) as tmpSgprInfo: stmp = tmpSgprInfo.idx tileStart = stmp+2 - if self.states.use64bShadowLimit: + if use64bShadowLimit: tensor2dSize0 = "ShadowLimit%s+0"%tc tensor2dSize1 = "ShadowLimit%s+1"%tc else: @@ -4195,7 +4231,7 @@ def computeLoadSrd(self, kernel, tP, tc, indices, bpe): module.add(SMovB64(dst=sgpr(tileStart, 2), src=0, comment="set default tileStart")) #Calculate tensor 2d size - if self.states.use64bShadowLimit or ((not self.states.use64bShadowLimit) and tensor2dSize0 % 2 == 0): + if use64bShadowLimit or ((not use64bShadowLimit) and tensor2dSize0 % 2 == 0): module.add(SMovB64(dst=sgpr(tensor2dSize0, 2), src=0x1, comment="Init tensor size")) else: module.add(SMovB32(dst=sgpr(tensor2dSize0), src=0x1, comment="Init tensor size")) @@ -4219,14 +4255,18 @@ def computeLoadSrd(self, kernel, tP, tc, indices, bpe): divider = 8 if tP["isM"] else 2 module.add(SLShiftRightB32(dst=sgpr(stmp), src=size, shiftHex=hex(int(log(divider,2))), comment="(size/%u)"%divider)) module.add(SSubU32(dst=sgpr(stmp), src0=sgpr(stmp), src1=0x1, comment="(size/%u-1)"%divider)) - elif tc == "MXSA": - mxBlock = kernel["ProblemType"]["MXBlockA"] - module.add(SLShiftRightB32(dst=sgpr(stmp), src=size, shiftHex=log2(mxBlock), comment="(size/%d-1)" %mxBlock)) - module.add(SSubU32(dst=sgpr(stmp), src0=sgpr(stmp), src1=0x1, comment="(size/%d-1)" %mxBlock)) - elif tc == "MXSB": - mxBlock = kernel["ProblemType"]["MXBlockB"] - module.add(SLShiftRightB32(dst=sgpr(stmp), src=size, shiftHex=log2(mxBlock), comment="(size/%d-1)" %mxBlock)) - module.add(SSubU32(dst=sgpr(stmp), src0=sgpr(stmp), src1=0x1, comment="(size/%d-1)" %mxBlock)) + elif tc in ("MXSA", "MXSB"): + mxBlock = kernel["ProblemType"]["MXBlockA"] if tc == "MXSA" else kernel["ProblemType"]["MXBlockB"] + if mxBlock > 0: + if kernel["AssertSummationElementMultiple"] % mxBlock != 0: + module.add(SAddU32(dst=sgpr(stmp), src0=size, src1=(mxBlock-1), comment="(size/%d-1)" %mxBlock)) + src0 = sgpr(stmp) + else: + src0 = size + module.add(SLShiftRightB32(dst=sgpr(stmp), src=src0, shiftHex=log2(mxBlock), comment="(size/%d-1)" %mxBlock)) + module.add(SSubU32(dst=sgpr(stmp), src0=sgpr(stmp), src1=0x1, comment="(size/%d-1)" %mxBlock)) + else: + module.add(SSubU32(dst=sgpr(stmp), src0=size, src1=0x1, comment="(size-1)")) elif tP["isSwizzled"]: module.addModuleAsFlatItems(self.alignTo(stmp, "SizeL", tP["swizzleK"])) module.add(SSubU32(dst=sgpr(stmp), src0=sgpr(stmp), src1=1, comment="SWZ-%s align: (sizeL-1)"%tc)) @@ -4249,7 +4289,7 @@ def computeLoadSrd(self, kernel, tP, tc, indices, bpe): module.add(SAddU32(dst=sgpr(tensor2dSize0), src0=sgpr(tensor2dSize0), src1=sgpr(stmp+0), comment="sum tensor size")) module.add(SAddCU32(dst=sgpr(tensor2dSize1), src0=sgpr(tensor2dSize1), src1=sgpr(stmp+1), comment="sum tensor size")) - if self.states.use64bShadowLimit: + if use64bShadowLimit: limitTmp0 = "ShadowLimit%s+0"%tc limitTmp1 = "ShadowLimit%s+1"%tc else: @@ -4259,7 +4299,7 @@ def computeLoadSrd(self, kernel, tP, tc, indices, bpe): module.add(SSubU32(dst=sgpr(limitTmp0), src0=sgpr(tensor2dSize0), src1=sgpr(tileStart+0), comment="sub tileStart")) module.add(SSubBU32(dst=sgpr(limitTmp1), src0=sgpr(tensor2dSize1), src1=sgpr(tileStart+1), comment="sub tileStart")) - if self.states.use64bShadowLimit: + if use64bShadowLimit: # Set initial buffer limit # if the limit is >64bit, incrementSrd decrements the shadow as the SRD increments, # and when we get within 32-bit we start to step down the SRD @@ -4279,7 +4319,7 @@ def computeLoadSrd(self, kernel, tP, tc, indices, bpe): module.add(self.shiftSrd(tc)) else: # put limit directly into SRD: - module.add(scalarMultiplyBpe("Srd%s+2"%tc, stmp, tP["bpeGR"], comment="Set limit to use bytes")) + module.add(scalarMultiplyBpe("Srd%s+2"%tc, stmp, float(tP["bpeGR"]), comment="Set limit to use bytes")) module.add(SAddU32(dst=sgpr("Srd%s+2"%tc), src0=sgpr("Srd%s+2"%tc), src1=prePad, comment="extend limit for pre-pad")) # Apply any high-order address components to the tileStart and eventually the SRD - batch idx for batched gemm @@ -4665,7 +4705,7 @@ def lwaTileAssignment(self, kernel, tP): if lrvwOther >= 2 and (not tluOther) and tP["tlu"]: # DirectToVgpr + LocalReadVectorWidth>=2 case, multiply qReg by lrvwOther - dtvKInterval = lrvwOther + dtvKInterval = int(lrvwOther) if tluOther and tP["tlu"]: # DirectToVgpr + both TLU case, multiply qReg by kernel["MIInputPerThread"] dtvKInterval = kernel["MIInputPerThread"] @@ -4782,7 +4822,7 @@ def lwaFirstOffset(self, kernel, tP): "padding %u per block %u" % (int(kernel["LdsPad%s"%tc] * tP["bpeDS"]), kernel["LdsBlockSizePerPad%s"%tc]))) with self.allocTmpSgpr(1) as tmpSgprInfo: module.add(vectorStaticMultiplyAdd(vgpr(destVgpr), vgpr(tmpVgpr), int(kernel["LdsPad%s"%tc] * tP["bpeDS"]), vgpr(destVgpr), tmpSgprInfo, \ - "padding %u per block %u" % (kernel["LdsPad%s"%tc] * tP["bpeDS"], kernel["LdsBlockSizePerPad%s"%tc]))) + "padding %u per block %u" % (int(kernel["LdsPad%s"%tc] * tP["bpeDS"]), kernel["LdsBlockSizePerPad%s"%tc]))) self.vgprPool.checkIn(tmpVgpr) if tc in ("B", "MXSA", "MXSB", "Metadata"): @@ -4827,7 +4867,7 @@ def lwaFirstOffset(self, kernel, tP): else: validBytesPerLoad *= (kernel["MacroTile%s"%tc] // kernel["NumLoadsPerpendicular%s"%tc] // (kernel["NumThreads"] // kernel["WavefrontSize"])) - isDTVAB = ((tP["isA"] or tP["isB"]) and kernel["DirectToVgpr%s"%tc]) + isDTVAB = ((tP["isA"] or tP["isB"] or tP["isMXSA"] or tP["isMXSB"]) and kernel["DirectToVgpr%s"%tc]) # For GNLC we don't need to check these asserts since num threads coalesced may not divide num threads assert (validBytesPerLoad <= maxBytesPerLoad) or isDTVAB or kernel["UseGeneralizedNLCOne%s"%tc] assert (kernel[tP["lsc"]] * kernel[tP["lsp"]] % tP["glvw"] == 0) or isDTVAB or kernel["UseGeneralizedNLCOne%s"%tc] @@ -4870,20 +4910,25 @@ def lwaFirstOffset(self, kernel, tP): src=vgpr(tmpv), \ comment="Copy lds write address VGPR to SGPR")) module.add(SMulI32(dst=sgpr("LocalWriteAddr%s"%tc), src0=sgpr("LocalWriteAddr%s"%tc), \ - src1=int((kernel["WavefrontSize"] * kernel["GlobalReadVectorWidth%c"%tc]+kernel["LdsPad%s"%tc]) * tP["bpeGR"]) )) + src1=int((kernel["WavefrontSize"] * kernel["GlobalReadVectorWidth%s"%tc]+kernel["LdsPad%s"%tc]) * tP["bpeGR"]) )) if tc == 'B': module.add(SAddU32(dst=sgpr("LocalWriteAddr%s"%tc), src0=sgpr("LocalWriteAddr%s"%tc), \ src1=kernel["LdsOffsetB"] )) self.vgprPool.checkIn(tmpv) self.vgprPool.checkIn(destVgpr) - if kernel["StoreSwapAddr"]: + if kernel["StoreSwapAddr"] or self.states.useCommonSgprSwap: if kernel["LocalWriteUseSgpr%s"%tc]: - # needed for the VReadfirstlaneB32 in the prior code block - if self.states.archCaps["CrosslaneWait"]: - module.add(SNop(waitState=0, comment="1 wait states")) - module.add(SAddU32(dst=sgpr("Swap%s"%tc), src0=sgpr("LocalWriteAddr%s"%tc), src1=kernel["LdsOffsetA_Blk"], comment="Calculate starting lds addr of second buffer")) - module.add(SXorB32(dst=sgpr("Swap%s"%tc), src0=sgpr("Swap%s"%tc), src1=sgpr("LocalWriteAddr%s"%tc), comment="xor both lds buffer offsets to enable swapping")) + if self.states.useCommonSgprSwap: + # Need only once. Generate the code for "A" only + if tc == "A": + module.add(SMovB32(dst=sgpr("SwapCommon"), src=0, comment="Initialize SwapCommon")) + else: + # needed for the VReadfirstlaneB32 in the prior code block + if self.states.archCaps["CrosslaneWait"]: + module.add(SNop(waitState=0, comment="1 wait states")) + module.add(SAddU32(dst=sgpr("Swap%s"%tc), src0=sgpr("LocalWriteAddr%s"%tc), src1=kernel["LdsOffsetA_Blk"], comment="Calculate starting lds addr of second buffer")) + module.add(SXorB32(dst=sgpr("Swap%s"%tc), src0=sgpr("Swap%s"%tc), src1=sgpr("LocalWriteAddr%s"%tc), comment="xor both lds buffer offsets to enable swapping")) else: module.add(VAddU32(dst=vgpr("LocalWriteSwapAddr%s"%tc), src0=kernel["LdsOffsetA_Blk"], src1=vgpr("LocalWriteAddr%s"%tc), \ comment="starting lds addr of second buffer" )) @@ -4986,7 +5031,7 @@ def lraFinalOffset(self, kernel, tP): "Final Offset: padding %u per block %u" % (int(kernel["LdsPad%s"%tc] * tP["bpeDS"]), kernel["LdsBlockSizePerPad%s"%tc]))) with self.allocTmpSgpr(1) as tmpSgprInfo: module.add(vectorStaticMultiplyAdd(vgpr("LocalReadAddr%s"%tc), vgpr(rReg), int(kernel["LdsPad%s"%tc] * tP["bpeDS"]), vgpr("LocalReadAddr%s"%tc), tmpSgprInfo, \ - "Final Offset: padding %u per block %u" % (kernel["LdsPad%s"%tc] * tP["bpeDS"], kernel["LdsBlockSizePerPad%s"%tc]))) + "Final Offset: padding %u per block %u" % (int(kernel["LdsPad%s"%tc] * tP["bpeDS"]), kernel["LdsBlockSizePerPad%s"%tc]))) # release resources self.vgprPool.checkIn(tmpVgpr) @@ -5045,7 +5090,7 @@ def lraFinalOffset(self, kernel, tP): module.add(vectorStaticDivide(rReg, "LocalReadAddr%s"%tc, kernel["LdsBlockSizePerPad%s"%tc], tmpSgpr, \ "Final Offset: padding %u per block %u" % (kernel["LdsPad%s"%tc], kernel["LdsBlockSizePerPad%s"%tc]))) module.add(vectorStaticMultiplyAdd(vgpr("LocalReadAddr%s"%tc), vgpr(rReg), int(kernel["LdsPad%s"%tc] * tP["bpe"]), vgpr("LocalReadAddr%s"%tc), tmpSgprInfo, \ - "Final Offset: padding %u per block %u" % (kernel["LdsPad%s"%tc] * tP["bpeDS"], kernel["LdsBlockSizePerPad%s"%tc]))) + "Final Offset: padding %u per block %u" % (int(kernel["LdsPad%s"%tc] * tP["bpeDS"]), kernel["LdsBlockSizePerPad%s"%tc]))) self.vgprPool.checkIn(rReg) return module @@ -5174,17 +5219,21 @@ def lwaInitAddressesForDTLTailLoop(self, kernel, tP): tc = tP["tensorChar"] waveSize = kernel["WavefrontSize"] - if kernel["DirectToLds%c"%tc] and kernel["NonDTLTailLoop%s"%tc]: - module.addComment0("Set local write offsets for %c to be same as DTL %uB load"%(tc, kernel["GlobalReadVectorWidth%c"%tc] * tP["bpe"])) - module.add(VAndB32(dst=vgpr("LocalWriteAddr%c"%tc), src0=(waveSize - 1), src1=vgpr("Serial"), comment="Serial % wavesize")) - module.add(VLShiftLeftB32(dst=vgpr("LocalWriteAddr%c"%tc), shiftHex=hex(log2(kernel["GlobalReadVectorWidth%c"%tc] * tP["bpe"])), src=vgpr("LocalWriteAddr%c"%tc), comment="")) - module.add(VAddU32(dst=vgpr("LocalWriteAddr%s"%tc), src0=sgpr("LocalWriteAddr%c"%tc), src1=vgpr("LocalWriteAddr%s"%tc), \ + if kernel["DirectToLds%s"%tc] and kernel["NonDTLTailLoop%s"%tc]: + module.addComment0("Set local write offsets for %s to be same as DTL %uB load"%(tc, kernel["GlobalReadVectorWidth%s"%tc] * tP["bpe"])) + module.add(VAndB32(dst=vgpr("LocalWriteAddr%s"%tc), src0=(waveSize - 1), src1=vgpr("Serial"), comment="Serial % wavesize")) + module.add(VLShiftLeftB32(dst=vgpr("LocalWriteAddr%s"%tc), shiftHex=hex(log2(kernel["GlobalReadVectorWidth%s"%tc] * tP["bpe"])), src=vgpr("LocalWriteAddr%s"%tc), comment="")) + module.add(VAddU32(dst=vgpr("LocalWriteAddr%s"%tc), src0=sgpr("LocalWriteAddr%s"%tc), src1=vgpr("LocalWriteAddr%s"%tc), \ comment="" )) numLw = 0 if tP["isA"]: numLw = self.states.a.numVgprLocalWriteAddrTailLoop elif tP["isB"]: numLw = self.states.b.numVgprLocalWriteAddrTailLoop + elif tc == "MXSA": + numLw = self.states.mxsa.numVgprLocalWriteAddrTailLoop + elif tc == "MXSB": + numLw = self.states.mxsb.numVgprLocalWriteAddrTailLoop for i in range(1,numLw): module.add(VAddU32(dst=vgpr("LocalWriteAddr%s + %s"%(tc, i)), src0=(i * 0x10000), \ src1=vgpr("LocalWriteAddr%s"%tc), comment="" )) @@ -5721,7 +5770,7 @@ def tailLoopAllocValuVgpr(self, kernel, tensorParametersA, tensorParametersB, te if numVgprCvtTemp > 0: imodMisc.add(RegSet("v", "vgprCvtTemp", vgprBaseMisc + numVgprPackTemp)) - return ([vgprBaseA, imodA],[vgprBaseB, imodB], [vgprBaseM, imodM],[vgprBaseMisc, imodMisc]) + return ([vgprBaseA, imodA],[vgprBaseB, imodB],[vgprBaseM, imodM],[vgprBaseMisc, imodMisc]) def tailLoopAllocG2LVgpr(self, kernel): imod = Module("tailLoopAllocG2LVgpr") @@ -5744,15 +5793,14 @@ def tailLoopAllocG2LVgpr(self, kernel): numG2LB = self.states.b.numVgprG2LTailloopAllocated if kernel["ProblemType"]["MXBlockA"] and (not kernel["DirectToVgprMXSA"]): if ("ULSGRODoubleG2L" in kernel) and kernel["ULSGRODoubleG2L"] == 1: - numG2LMXSA = self.states.mxsa.numVgprG2LAllocated*2 + numG2LMXSA = self.states.mxsa.numVgprG2LTailloopAllocated*2 else: - numG2LMXSA = self.states.mxsa.numVgprG2LAllocated + numG2LMXSA = self.states.mxsa.numVgprG2LTailloopAllocated if kernel["ProblemType"]["MXBlockB"] and (not kernel["DirectToVgprMXSB"]): if ("ULSGRODoubleG2L" in kernel) and kernel["ULSGRODoubleG2L"] == 1: - numG2LMXSB = self.states.mxsb.numVgprG2LAllocated*2 + numG2LMXSB = self.states.mxsb.numVgprG2LTailloopAllocated*2 else: - numG2LMXSB = self.states.mxsb.numVgprG2LAllocated - + numG2LMXSB = self.states.mxsb.numVgprG2LTailloopAllocated if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: numG2LMetadata = self.states.m.numVgprG2LAllocated @@ -5764,8 +5812,7 @@ def tailLoopAllocG2LVgpr(self, kernel): if numG2LA + numG2LB + numG2LMXSA + numG2LMXSB + numG2LMetadata > 0: vgprBase = self.vgprPool.checkOutAligned(numG2LA + numG2LB + numG2LMXSA + numG2LMXSB + numG2LMetadata, 2) - imod.addComment0("Check out VGPR (numG2LA,numG2LB,numG2LMetadata) = (%d,%d,%d)"%(numG2LA,numG2LB,numG2LMetadata)) - + imod.addComment0("Check out VGPR (numG2LA,numG2LB,numG2LMXSA,numG2LMXSB,numG2LMetadata) = (%d,%d,%d,%d,%d)"%(numG2LA,numG2LB,numG2LMXSA,numG2LMXSB,numG2LMetadata)) if numG2LA > 0: imod.add(RegSet("v", "vgprG2LA_BASE", vgprBase)) if kernel["DirectToLdsA"] and kernel["NonDTLTailLoopA"]: @@ -5782,11 +5829,17 @@ def tailLoopAllocG2LVgpr(self, kernel): if numG2LMXSA > 0: imod.add(RegSet("v", "vgprG2LMXSA_BASE", vgprBase + numG2LA + numG2LB)) - imod.add(self.moduleVgprMacroG2LMXSA) + if kernel["DirectToLdsMXSA"] and kernel["NonDTLTailLoopMXSA"]: + imod.add(RegSet("v", "vgprG2LMXSA", "vgprG2LMXSA_BASE", 0)) + else: + imod.add(self.moduleVgprMacroG2LMXSA) if numG2LMXSB > 0: imod.add(RegSet("v", "vgprG2LMXSB_BASE", vgprBase + numG2LA + numG2LB + numG2LMXSA)) - imod.add(self.moduleVgprMacroG2LMXSB) + if kernel["DirectToLdsMXSB"] and kernel["NonDTLTailLoopMXSB"]: + imod.add(RegSet("v", "vgprG2LMXSB", "vgprG2LMXSB_BASE", 0)) + else: + imod.add(self.moduleVgprMacroG2LMXSB) if numG2LMetadata > 0: imod.add(RegSet("v", "vgprG2LMetadata", vgprBase + numG2LA + numG2LB + numG2LMXSA + numG2LMXSB)) @@ -5798,14 +5851,20 @@ def tailLoopAllocDTLLWVgpr(self, kernel): vgprBase = -1 numLWA = self.states.a.numVgprLocalWriteAddrTailLoop numLWB = self.states.b.numVgprLocalWriteAddrTailLoop + numLWMXSA = self.states.mxsa.numVgprLocalWriteAddrTailLoop if kernel["ProblemType"]["MXBlockA"] else 0 + numLWMXSB = self.states.mxsb.numVgprLocalWriteAddrTailLoop if kernel["ProblemType"]["MXBlockB"] else 0 - if numLWA + numLWB > 0: - vgprBase = self.vgprPool.checkOutAligned(numLWA + numLWB, 2) - imod.addComment0("Check out VGPR (numLWA,numLWB) = (%d,%d)"%(numLWA,numLWB)) + if numLWA + numLWB + numLWMXSA + numLWMXSB > 0: + vgprBase = self.vgprPool.checkOutAligned(numLWA + numLWB + numLWMXSA + numLWMXSB, 2) + imod.addComment0("Check out VGPR (numLWA,numLWB,numLWMXSA,numLWMXSB) = (%d,%d,%d,%d)"%(numLWA,numLWB,numLWMXSA,numLWMXSB)) if numLWA > 0 and (kernel["DirectToLdsA"] and kernel["NonDTLTailLoopA"]): imod.add(RegSet("v", "vgprLocalWriteAddrA", vgprBase)) + if numLWMXSA > 0 and (kernel["DirectToLdsMXSA"] and kernel["NonDTLTailLoopMXSA"]): + imod.add(RegSet("v", "vgprLocalWriteAddrMXSA", vgprBase + numLWA)) if numLWB > 0 and (kernel["DirectToLdsB"] and kernel["NonDTLTailLoopB"]): - imod.add(RegSet("v", "vgprLocalWriteAddrB", vgprBase + numLWA)) + imod.add(RegSet("v", "vgprLocalWriteAddrB", vgprBase + numLWA + numLWMXSA)) + if numLWMXSB > 0 and (kernel["DirectToLdsMXSB"] and kernel["NonDTLTailLoopMXSB"]): + imod.add(RegSet("v", "vgprLocalWriteAddrMXSB", vgprBase + numLWA + numLWMXSA + numLWB)) return imod, vgprBase @@ -5907,9 +5966,9 @@ def tailLoopGlobalRead(self, kernel, tPA, tPB, doA, doB): imod = Module("tailLoopGlobalRead") glvwWorkaround = 8 * kernel["ProblemType"]["DataType"].numRegisters() - dataTypeA = kernel["ProblemType"]["DataType"] if tPA["glvw"] < glvwWorkaround else \ + dataTypeA = kernel["ProblemType"]["MacDataTypeA"] if tPA["glvw"] < glvwWorkaround else \ kernel["ProblemType"]["DataTypeA"] - dataTypeB = kernel["ProblemType"]["DataType"] if tPB["glvw"] < glvwWorkaround else \ + dataTypeB = kernel["ProblemType"]["MacDataTypeB"] if tPB["glvw"] < glvwWorkaround else \ kernel["ProblemType"]["DataTypeB"] numElementsPerLoadA = -1 numElementsPerLoadB = -1 @@ -7296,13 +7355,15 @@ def fixPreloadOffset(offset, sgpxIdxVec, numStoreSgprToLoad): module.add(RegSet("s", "sgpr"+name, self.sgprs[name])) argOffset = self.argLoader.getOffset() # Backup offset - if kernel["ProblemType"]["UseScaleAB"] == "Scalar" and (kernel["ProblemType"]["DataTypeA"].numRegisters() > kernel["ProblemType"]["MacDataTypeA"].numRegisters() or kernel["ProblemType"]["DataTypeB"].numRegisters() > kernel["ProblemType"]["MacDataTypeB"].numRegisters()): + if kernel["ProblemType"]["UseScaleAB"] == "Scalar" and (kernel["ProblemType"]["DataTypeA"].numRegisters() > kernel["ProblemType"]["MacDataTypeA"].numRegisters() \ + or kernel["ProblemType"]["DataTypeB"].numRegisters() > kernel["ProblemType"]["MacDataTypeB"].numRegisters()): self.argLoader.setOffset(argOffset + (self.states.numStoreSgprToLoad)*4 + (self.states.rpga * self.states.bpr)) # Restore offset else: self.argLoader.setOffset(argOffset + (self.states.numStoreSgprToLoad)*4) # Restore offset numStoreSgprToLoad = self.states.numStoreSgprToLoad2 - if kernel["ProblemType"]["UseScaleAB"] == "Scalar" and (kernel["ProblemType"]["DataTypeA"].numRegisters() > kernel["ProblemType"]["MacDataTypeA"].numRegisters() or kernel["ProblemType"]["DataTypeB"].numRegisters() > kernel["ProblemType"]["MacDataTypeB"].numRegisters()): + if kernel["ProblemType"]["UseScaleAB"] == "Scalar" and (kernel["ProblemType"]["DataTypeA"].numRegisters() > kernel["ProblemType"]["MacDataTypeA"].numRegisters() \ + or kernel["ProblemType"]["DataTypeB"].numRegisters() > kernel["ProblemType"]["MacDataTypeB"].numRegisters()): argOffsettmp = (argOffset + (self.states.numStoreSgprToLoad)*4 + (self.states.rpga * self.states.bpr)) # Restore offset else: argOffsettmp = (argOffset + (self.states.numStoreSgprToLoad)*4) # Restore offset @@ -7388,8 +7449,7 @@ def fixPreloadOffset(offset, sgpxIdxVec, numStoreSgprToLoad): def mfmaIter_waitCount(self, kernel): if self.states.version in [(9,4,2), (9,5,0)]: - # TODO: check if this is correct for all MI versions - dataType = kernel["ProblemType"]["MacDataTypeA"] + dataType = kernel["ProblemType"]["DataType"] miM = kernel["MatrixInstM"] miN = kernel["MatrixInstN"] if dataType.isSingle() or dataType.isHalf() or dataType.isBFloat16(): @@ -7486,7 +7546,7 @@ def generateSrcStrForMFMA(self, kernel, tP, innerUnroll, vregSetIdx, vgprPerInpu def macIter(self, kernel, tPA, tPB, bufferIdx, iuiCount, useMacro, isTail=False): imod = Module("macIter_X%u_I%u"%(bufferIdx, iuiCount)) - # if kernel["ProblemType"]["DataType"].isHalf(): + # if kernel["ProblemType"]["MacDataTypeA"].isHalf(): # # imod.addText(".align32 8, 0xbf800001", "align v_pk_fma") # Align v_pk_fma instructions used in MAC_ blocks # imod.addText(".align32 8, 0xbf800001\n") # Align v_pk_fma instructions used in MAC_ blocks @@ -7728,6 +7788,7 @@ def dataTypeNameAbbrevToInstType(abbrev: str, sourceSwap: bool = False) -> InstT assert("Unsupported data type.") return InstType.INST_NOTYPE + isgfx950 = kernel["ISA"][:2] == (9, 5) miInputTypeA = kernel["ProblemType"]["F32XdlMathOp"] if kernel["EnableF32XdlMathOp"] else kernel["ProblemType"]["MacDataTypeA"] miInputTypeB = kernel["ProblemType"]["F32XdlMathOp"] if kernel["EnableF32XdlMathOp"] else kernel["ProblemType"]["MacDataTypeB"] # calculate constant @@ -7952,6 +8013,14 @@ def findSparseOffset(isA:bool): bk = se + group * vgprPerSet0Group + ti * vgprPerInUnrollA aStr = vgpr(self.generateSrcStrForMFMAshiftK(kernel, tPA, innerUnroll, vregSetIdx, vgprPerInputA, m, u, iui, a, bk=bk), 1) shiftK.add(VCndMaskB32(dst=aStr, src0=aStr, src1=0, src2=sgpr(tmpSgprX2, self.states.laneSGPRCount), comment="set 0 if K_idx >= sizeL")) + # separate code for gfx950 mx + if group == 0 and kernel["ProblemType"]["MXBlockA"] and isgfx950: + for mxsa in range(0, kernel["MIWaveTileMXSA"]): + for iui in range(0, innerUnroll): + for bk in range(0, vgprPerInputMXSA): + mxsaStr_base = self.generateSrcStrForMFMAshiftK(kernel, tPA["MX"], innerUnroll, vregSetIdx, vgprPerInputMXSA, m, u, iui, mxsa, bk=bk) + mxsaStr = vgpr(mxsaStr_base, 1) + shiftK.add(VCndMaskB32(dst=mxsaStr, src0=mxsaStr, src1=0, src2=sgpr(tmpSgprX2, self.states.laneSGPRCount), comment="")) if kernel["ProblemType"]["Sparse"] == 2 and numMIInUnroll//8 >= 1: shiftK.add(vectorStaticRemainder(dummy, kReg_first, "Serial", kernel["WavefrontSize"], tmpVgpr, tmpSgprInfo)) @@ -8040,6 +8109,14 @@ def findSparseOffset(isA:bool): bk = se + group * vgprPerSet0Group + ti * vgprPerInUnrollB bStr = vgpr(self.generateSrcStrForMFMAshiftK(kernel, tPB, innerUnroll, vregSetIdx, vgprPerInputB, m, u, iui, b, bk=bk), 1) shiftK.add(VCndMaskB32(dst=bStr, src0=bStr, src1=0, src2=sgpr(tmpSgprX2, self.states.laneSGPRCount), comment="set 0 if K_idx >= sizeL")) + # separate code for gfx950 mx + if group == 0 and kernel["ProblemType"]["MXBlockB"] and isgfx950: + for mxsb in range(0, kernel["MIWaveTileMXSB"]): + for iui in range(0, innerUnroll): + for bk in range(0, vgprPerInputMXSB): + mxsbStr_base = self.generateSrcStrForMFMAshiftK(kernel, tPB["MX"], innerUnroll, vregSetIdx, vgprPerInputMXSB, m, u, iui, mxsb, bk=bk) + mxsbStr = vgpr(mxsbStr_base, 1) + shiftK.add(VCndMaskB32(dst=mxsbStr, src0=mxsbStr, src1=0, src2=sgpr(tmpSgprX2, self.states.laneSGPRCount), comment="")) # replace elements with 0 for same thread, this conducting shift and mask between numElementsPerRead # Skip when numRegistersIn >= 1 on wmma_v3 with XF32 emulation: each element fills a full register, @@ -8276,7 +8353,7 @@ def findSparseOffset(isA:bool): if vgprPerInUnrollB == 4 and is_wmma_v2: if tmpVgpr2 is not None: self.vgprPool.checkIn(tmpVgpr2) - if kernel["ProblemType"]["MXBlockA"] or kernel["ProblemType"]["MXBlockB"]: + if (kernel["ProblemType"]["MXBlockA"] or kernel["ProblemType"]["MXBlockB"]) and not isgfx950: mxK = kernel["MatrixInstK"] mxBlock = max(kernel["ProblemType"]["MXBlockA"], kernel["ProblemType"]["MXBlockB"]) vgprPerInUnroll = max(vgprPerInputMXSA, vgprPerInputMXSB) @@ -8295,31 +8372,31 @@ def findSparseOffset(isA:bool): else: raise Exception(f"unsupport vgprPerInUnroll {vgprPerInUnroll}") - if kernel["ProblemType"]["MXBlockA"]: + if kernel["ProblemType"]["MXBlockA"] and not isgfx950: for mxsa in range(0, kernel["MIWaveTileMXSA"]): for iui in range(0, innerUnroll): mxsaStr_base = self.generateSrcStrForMFMA(kernel, tPA["MX"], innerUnroll, vregSetIdx, vgprPerInputMXSA, m, u, iui, mxsa) mxsaStr = vgpr(mxsaStr_base, vgprPerInputMXSA) shiftK.add(VShiftLeft(dst=vgpr(mxReg, vgprPerInUnroll), shiftHex=sgpr(tmpSgprX1), src=mxsaStr, comment="")) shiftK.add(VShiftRight(dst=vgpr(mxReg, vgprPerInUnroll), shiftHex=sgpr(tmpSgprX1), src=vgpr(mxReg, vgprPerInUnroll), comment="")) - shiftK.add(VCmpGTI32(dst=sgpr(tmpSgprX2), src0=mxK, src1=sgpr(loopCntSgpr), comment="check K index >= Size L")) + shiftK.add(VCmpGTI32(dst=sgpr(tmpSgprX2, self.states.laneSGPRCount), src0=mxK, src1=sgpr(loopCntSgpr), comment="check K index >= Size L")) for bk in range(0, vgprPerInUnroll): mxsaStr_base = self.generateSrcStrForMFMA(kernel, tPA["MX"], innerUnroll, vregSetIdx, vgprPerInputMXSA, m, u, iui, mxsa, bk=bk) mxsaStr = vgpr(mxsaStr_base, 1) - shiftK.add(VCndMaskB32(dst=mxsaStr, src0=mxsaStr, src1=vgpr(mxReg+bk), src2=sgpr(tmpSgprX2), comment="")) + shiftK.add(VCndMaskB32(dst=mxsaStr, src0=mxsaStr, src1=vgpr(mxReg+bk), src2=sgpr(tmpSgprX2, self.states.laneSGPRCount), comment="")) - if kernel["ProblemType"]["MXBlockB"]: + if kernel["ProblemType"]["MXBlockB"] and not isgfx950: for mxsb in range(0, kernel["MIWaveTileMXSB"]): for iui in range(0, innerUnroll): mxsbStr_base = self.generateSrcStrForMFMA(kernel, tPB["MX"], innerUnroll, vregSetIdx, vgprPerInputMXSB, m, u, iui, mxsb) mxsbStr = vgpr(mxsbStr_base, vgprPerInputMXSB) shiftK.add(VShiftLeft(dst=vgpr(mxReg, vgprPerInUnroll), shiftHex=sgpr(tmpSgprX1), src=mxsbStr, comment="")) shiftK.add(VShiftRight(dst=vgpr(mxReg, vgprPerInUnroll), shiftHex=sgpr(tmpSgprX1), src=vgpr(mxReg, vgprPerInUnroll), comment="")) - shiftK.add(VCmpGTI32(dst=sgpr(tmpSgprX2), src0=mxK, src1=sgpr(loopCntSgpr), comment="check K index >= Size L")) + shiftK.add(VCmpGTI32(dst=sgpr(tmpSgprX2, self.states.laneSGPRCount), src0=mxK, src1=sgpr(loopCntSgpr), comment="check K index >= Size L")) for bk in range(0, vgprPerInUnroll): mxsbStr_base = self.generateSrcStrForMFMA(kernel, tPB["MX"], innerUnroll, vregSetIdx, vgprPerInputMXSB, m, u, iui, mxsb, bk=bk) mxsbStr = vgpr(mxsbStr_base, 1) - shiftK.add(VCndMaskB32(dst=mxsbStr, src0=mxsbStr, src1=vgpr(mxReg+bk), src2=sgpr(tmpSgprX2), comment="")) + shiftK.add(VCndMaskB32(dst=mxsbStr, src0=mxsbStr, src1=vgpr(mxReg+bk), src2=sgpr(tmpSgprX2, self.states.laneSGPRCount), comment="")) if kernel["LocalSplitU"] > 1: self.sgprPool.checkIn(loopCntSgpr) @@ -8907,6 +8984,7 @@ def swapBackGprPool(gprPool, savedGprPool, isNotLast): def incrementSrd(self, tP, incLower, incUpper): imod = Module("incrementSrd") tc = tP["tensorChar"] + use64bShadowLimit = self.states.use64bShadowLimitMX if tc in ["MXSA", "MXSB"] else self.states.use64bShadowLimit imod.add(SAddU32(dst=sgpr("Srd%s+0"%(tc)), \ src0=sgpr("Srd%s+0"%(tc)), \ @@ -8919,7 +8997,7 @@ def incrementSrd(self, tP, incLower, incUpper): # also have to move the boundary since we change the base # so less buffers to the edge: - if self.states.use64bShadowLimit: + if use64bShadowLimit: imod.add(SSubU32(dst=sgpr("ShadowLimit%s+0"%tc), \ src0=sgpr("ShadowLimit%s+0"%tc), \ src1=incLower, \ @@ -9002,7 +9080,8 @@ def setTailSrd(self, tP, incLower): comment="gra SRD -= inc(upper)" )) # using Shadow limit here which only works with 64-bit PBC: - assert(self.states.use64bShadowLimit) + use64bShadowLimit = self.states.use64bShadowLimitMX if tc in ["MXSA", "MXSB"] else self.states.use64bShadowLimit + assert(use64bShadowLimit) module.add(SAddU32(dst=sgpr("ShadowLimit%s+0"%tc), \ src0=sgpr("ShadowLimit%s+0"%tc), src1=incLower, \ @@ -9057,6 +9136,22 @@ def globalReadIncrement(self, kernel, imod, loopIdx, tP, prefetchIndex): comment="incUpper <- ?")) imod.addModuleAsFlatItems(self.incrementSrd(tP, sgpr(incLower), sgpr(incUpper))) + if "MX" in tP: + # TODO: DirectToVgpr + tc = tP["MX"]["tensorChar"] + imod.addComment1("global read inc %s loop%s"%(tc, loopChar)) + if prefetchIndex: + imod.add(SAddU32(dst=sgpr(tmpS), src0=self.loopCounter(kernel, self.states.unrollIdx), src1=prefetchIndex, comment="remove pf(%u)"%prefetchIndex)) + imod.add(SCmpEQU32(src0=sgpr(suStr), src1=sgpr(tmpS), comment="Is this wrapIter? (pf)")) + else: + imod.add(SCmpEQU32(src0=self.loopCounter(kernel, self.states.unrollIdx), \ + src1=sgpr(suStr), comment="Is this the wrapIter?")) + imod.add(SCSelectB32(dst=sgpr(incLower), src0=sgpr("WrapU%s+0"%tc), src1=sgpr("GlobalReadIncs%s+%u"%(tc,self.states.unrollIdx)), \ + comment="incLower <- ?")) + imod.add(SCSelectB32(dst=sgpr(incUpper), src0=sgpr("WrapU%s+1"%tc), src1=0, + comment="incUpper <- ?")) + imod.addModuleAsFlatItems(self.incrementSrd(tP["MX"], sgpr(incLower), sgpr(incUpper))) + if kernel["ProblemType"]["Sparse"]: if (kernel["ProblemType"]["Sparse"] == 2 and tP["isB"]) or (kernel["ProblemType"]["Sparse"] == 1 and tP["isA"]) : tc = "Metadata" @@ -9088,7 +9183,29 @@ def globalReadIncrement(self, kernel, imod, loopIdx, tP, prefetchIndex): imod.addModuleAsFlatItems(self.incrementSrd(tP, sgpr("GlobalReadIncs%s+%u"%(tc,loopIdx)), sgpr(incUpper))) else: incUpper = 0 # GRO is positive for loop unroll - imod.addModuleAsFlatItems(self.incrementSrd(tP, sgpr("GlobalReadIncs%s+%u"%(tc,loopIdx)), hex(incUpper))) + srcGRInc = sgpr("GlobalReadIncs%s+%u"%(tc,loopIdx)) + if tc in ('A', 'B'): + useConstSgprGlobalReadIncs = self.states.a.useConstSgprGlobalReadIncs if tc == 'A' else self.states.b.useConstSgprGlobalReadIncs + if useConstSgprGlobalReadIncs: + srcGRInc = "GlobalReadIncs%s"%tc + imod.addModuleAsFlatItems(self.incrementSrd(tP, srcGRInc, hex(incUpper))) + + if "MX" in tP: + tc = tP["MX"]["tensorChar"] + imod.addComment1("global read inc %s loop%s"%(tc, loopChar)) + if loopIdx != self.states.unrollIdx or (tc in ('MXSA', 'MXSB') and kernel["ProblemType"]["IndicesSummation"][self.states.unrollIdx] in kernel["ProblemType"]["MirrorDims%s"%tc]): + with self.allocTmpSgpr(1) as tmpSgprInfo: + incUpper = tmpSgprInfo.idx + # GRO may be negative for other summation if stride-other < stride-unroll or if mirror dim. + imod.add(SAShiftRightI32(dst=sgpr(incUpper), shiftHex=31, src=sgpr("GlobalReadIncs%s+%u"%(tc,loopIdx)), comment="sign-extend")) + imod.addModuleAsFlatItems(self.incrementSrd(tP["MX"], sgpr("GlobalReadIncs%s+%u"%(tc,loopIdx)), sgpr(incUpper))) + else: + incUpper = 0 # GRO is positive for loop unroll + srcGRInc = sgpr("GlobalReadIncs%s+%u"%(tc,loopIdx)) + useConstSgprGlobalReadIncs = self.states.mxsa.useConstSgprGlobalReadIncs if tc == 'MXSA' else self.states.mxsb.useConstSgprGlobalReadIncs + if useConstSgprGlobalReadIncs: + srcGRInc = "GlobalReadIncs%s"%tc + imod.addModuleAsFlatItems(self.incrementSrd(tP["MX"], srcGRInc, hex(incUpper))) if kernel["ProblemType"]["Sparse"]: if (kernel["ProblemType"]["Sparse"] == 2 and tP["isB"]) or (kernel["ProblemType"]["Sparse"] == 1 and tP["isA"]) : @@ -9174,16 +9291,12 @@ def globalReadIncrementAB(self, kernel, tPA, tPB, loopIdx, prefetchIndex): self.globalReadIncrement(kernel, incCodeA, loopIdx, tPA, prefetchIndex) else: incCodeA.add(self.tdmIncrementAB(kernel, tPA)) - if "MX" in tPA and not tdmA: - self.globalReadIncrement(kernel, incCodeA, loopIdx, tPA["MX"], prefetchIndex) incCodeB = imod.add(Module("globalReadIncrementB")) if tPB != None: if not tdmB: self.globalReadIncrement(kernel, incCodeB, loopIdx, tPB, prefetchIndex) else: incCodeB.add(self.tdmIncrementAB(kernel, tPB)) - if "MX" in tPB and not tdmB: - self.globalReadIncrement(kernel, incCodeB, loopIdx, tPB["MX"], prefetchIndex) return imod ############################################################################## @@ -9411,7 +9524,7 @@ def globalReadGuardKBody(tP, optParams = None): r = 0 numLoadVectorComp = int(loadWidth*self.states.bpr//tP["bpeGR"]) - if (kernel["ProblemType"]["DataType%s"%tcDataType].isDouble() or kernel["ProblemType"]["DataType%s"%tcDataType].isSingle())\ + if (dataType.isDouble() or dataType.isSingle())\ and kernel["BufferLoad"]: # adjustment for {d,s}gemm + BufferLoad # use same buffer_load instruction for tail loop as out of tail loop @@ -9910,8 +10023,8 @@ def doTailLoopOpt(self, kernel, tP): loadWidth = tP["globalReadInstruction"].totalWidth oriNumLoadVectorComp = (int(loadWidth*self.states.bpr//tP["bpeGR"])) numElementsPerLoad = 1 - glvwWorkaround = 8 * kernel["ProblemType"]["DataType"].numRegisters() - dataType = kernel["ProblemType"]["DataType"] if tP["glvw"] < glvwWorkaround \ + glvwWorkaround = 8 * kernel["ProblemType"]["MacDataType%s"%tc].numRegisters() + dataType = kernel["ProblemType"]["MacDataType%s"%tc] if tP["glvw"] < glvwWorkaround \ else kernel["ProblemType"]["DataType%s"%tc] if dataType.isHalf() or dataType.isBFloat16(): @@ -10053,6 +10166,9 @@ def directToLdsM0Update(self, kernel, mode, tP, skipWait = False): if self.states.IncLdsBufSwitch: DtldsModule.add(SAddU32(dst=mgpr(0), src0=sgpr("LocalWriteAddr%s"%tc), \ src1=sgpr("LDSBufferWriteInc"), comment="m0 <- LDS write address (base + inc)")) + elif self.states.useCommonSgprSwap: + DtldsModule.add(SAddU32(dst=mgpr(0), src0=sgpr("LocalWriteAddr%s"%tc), \ + src1=sgpr("SwapCommon"), comment="m0 <- LDS write address (base + inc)")) elif kernel["ExpandPointerSwap"]: DtldsModule.add(SAddU32(dst=mgpr(0), src0=sgpr("LocalWriteAddr%s"%tc), \ src1=tP["localWriteSwapByteOffset"], comment="m0 <- LDS write address")) @@ -10564,6 +10680,27 @@ def localWriteAddRound(tc): module.add(SCmpEQU32(src0=sgpr("LDSBufferWriteInc"), src1="LdsBlockEndSize", comment="LDSBufferWriteInc == End ?")) module.add(SCMovB32(dst=sgpr("LDSBufferWriteInc"), src=0, comment="LDSBufferWriteInc loop back to 0")) + def localWriteSwapCommon(tc): + is1st = tc == "A" # so far, A is always first + if is1st: + module.add(SXorB32( + dst=sgpr("SwapCommon"), \ + src0="LdsOneBlockSize", \ + src1=sgpr("SwapCommon"), \ + comment="xor LDS block size")) + + def getSrc0Val(tc): + src0Val = None + if kernel["StoreSwapAddr"]: + if kernel["LocalWriteUseSgpr%s"%tc]: + src0Val = sgpr("Swap%s"%tc) + else: + src0Val = vgpr("LocalWriteSwapAddr%s"%tc) + else: + # Using inlined constants + src0Val = hex(kernel["LdsOffsetA_Blk"]) + return src0Val + if needSwap: #fixme-iui need to use wrapping increment for double or triple buffering: if internalPointerSwap and not kernel["StoreSwapAddr"]: @@ -10573,32 +10710,15 @@ def localWriteAddRound(tc): # 3 or more LDS block case, we do not use xor. Instead, use add and max check for round back # (numLDSBlk>=3 is for DTL (and LocalWriteUseSgpr) only) localWriteAddRound(tc) + elif self.states.useCommonSgprSwap: + # commonSwap case, need only 1 swap + # (generate at "A" only) + localWriteSwapCommon(tc) else: - src0Val = None - if kernel["StoreSwapAddr"]: - if kernel["LocalWriteUseSgpr%s"%tc]: - src0Val = sgpr("Swap%s"%tc) - else: - src0Val = vgpr("LocalWriteSwapAddr%s"%tc) - else: - # Using inlined constants - src0Val = hex(kernel["LdsOffsetA_Blk"]) - - numLwa = 0 - if tP["isA"]: - numLwa = self.states.a.numVgprLocalWriteAddr - elif tP["isB"]: - numLwa = self.states.b.numVgprLocalWriteAddr - elif tP["isMXSA"]: - numLwa = self.states.mxsa.numVgprLocalWriteAddr - elif tP["isMXSB"]: - numLwa = self.states.mxsb.numVgprLocalWriteAddr - else: - raise Exception(f"unsupport tc %s{tc}") - + src0Val = getSrc0Val(tc) numLwa = self.states.a.numVgprLocalWriteAddr if tP["isA"] else self.states.b.numVgprLocalWriteAddr localWriteSwapXOR(tc, src0Val, numLwa) - + # This used to control where to store the metadata if needMetaSwap: if kernel["DirectToVgprSparseMetadata"]: @@ -10611,14 +10731,7 @@ def localWriteAddRound(tc): tPM["localWriteSwapByteOffset"] = 0 if tPM["localWriteSwapByteOffset"] else kernel["LdsOffsetA_Blk"] module.addComment1("(EPS=1) local write swap internal offset -> %u" % tPM["localWriteSwapByteOffset"]) else: - if kernel["StoreSwapAddr"]: - if kernel["LocalWriteUseSgpr%s"%tc]: - src0Val = sgpr("Swap%s"%tc) - else: - src0Val = vgpr("LocalWriteSwapAddr%s"%tc) - else: - # Using inlined constants - src0Val = hex(kernel["LdsOffsetA_Blk"]) + src0Val = getSrc0Val(tc) numLwa = self.states.m.numVgprLocalWriteAddr localWriteSwapXOR(tc, src0Val, numLwa) @@ -10646,7 +10759,8 @@ def localWriteResetOffsets(self, kernel, internalPointerSwap, tP): module = Module("localWriteResetOffsets") if needReset: resetMask = hex(kernel["LdsOffsetA_Blk"]-1 | self.consts.ldsOOB) - if internalPointerSwap or kernel["StoreSwapAddr"]: + useSwapAddr = (internalPointerSwap or (kernel["StoreSwapAddr"] and not self.states.useCommonSgprSwap)) + if useSwapAddr: if internalPointerSwap: tP["localWriteSwapByteOffset"] = 0 else: @@ -10680,6 +10794,14 @@ def localWriteResetOffsets(self, kernel, internalPointerSwap, tP): dst=sgpr("LDSBufferWriteInc"), \ src=0, \ comment="reset incSgpr")) + elif self.states.useCommonSgprSwap: + # commonSwap case, back to 0 + # (generate at "A" only) + if tc == "A": + module.add(SMovB32( + dst=sgpr("SwapCommon"), \ + src=0, \ + comment="reset swapCommon")) else: if kernel["LocalWriteUseSgpr%s"%tc]: module.add(SAndB32( @@ -10742,6 +10864,10 @@ def localWriteResetOffsets(self, kernel, internalPointerSwap, tP): numLwa = self.states.mxsb.numVgprLocalWriteAddr elif tP["isB"]: numLwa = self.states.b.numVgprLocalWriteAddr + elif tP["isMXSA"]: + numLwa = self.states.mxsa.numVgprLocalWriteAddr + elif tP["isMXSB"]: + numLwa = self.states.mxsb.numVgprLocalWriteAddr elif tP["isM"]: numLwa = self.states.m.numVgprLocalWriteAddr else: @@ -11280,9 +11406,8 @@ def localWriteBody(tP): elif (kernel["ProblemType"]["DataType%s"%tc].isHalf() and kernel["ProblemType"]["MacDataType%s"%tc if (tc=='A' or tc=='B') else "DataType"].is8bitFloat()): #HH_F8/B8/F8B8/B8F8_ toF8 = False - if (kernel["ProblemType"]["DataType"].isAnyFloat8() or - (tc == "A" and kernel["ProblemType"]["DataType"].isAnyFloat8BFloat8()) or - (tc == "B" and kernel["ProblemType"]["DataType"].isAnyBFloat8Float8())): + if ((tc == "A" and kernel["ProblemType"]["MacDataTypeA"].isAnyFloat8()) or + (tc == "B" and kernel["ProblemType"]["MacDataTypeB"].isAnyFloat8())): toF8 = True newBlockWidth = (tP["bpeGR"] / tP["bpe"]) * blockWidth @@ -11300,7 +11425,7 @@ def localWriteBody(tP): localWriteCVTCode.add(ECvtF16toF32(dst=vgpr(vgprTmp), src=paramList[0], sel=HighBitSel.HIGH if isHigh16Bits else HighBitSel.LOW, comment="convert to F32")) # ScaleA/B - if kernel["ProblemType"]["UseScaleAB"] == "Scalar" and kernel["ProblemType"]["DataType%s"%tc].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters(): + if kernel["ProblemType"]["UseScaleAB"] == "Scalar" and kernel["ProblemType"]["DataType%s"%tc].numRegisters() > kernel["ProblemType"]["MacDataType%s"%tc].numRegisters(): localWriteCVTCode.add(VMulF32(dst=vgpr(vgprTmp), src0=vgpr(vgprTmp), src1=sgpr("Scale%s"%tc), comment="Input *= scale %s"%tc)) if kernel["ProblemType"]["StochasticRounding"]: @@ -11349,7 +11474,7 @@ def localWriteBody(tP): if kernel["ProblemType"]["StochasticRounding"]: # ScaleA/B, sgpr upper is dummy. - if kernel["ProblemType"]["UseScaleAB"] == "Scalar" and kernel["ProblemType"]["DataType%s"%tc].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters(): + if kernel["ProblemType"]["UseScaleAB"] == "Scalar" and kernel["ProblemType"]["DataType%s"%tc].numRegisters() > kernel["ProblemType"]["MacDataType%s"%tc].numRegisters(): localWriteCVTCode.add(VMulPKF32S(dst=vgpr(vgprTmp, 2), src0=vgpr(vgprTmp, 2), src1=sgpr("Scale%s"%tc, 2), vop3=VOP3PModifiers(op_sel_hi=[1,0,1]), comment="Input *= scale %s"%tc)) vRand = vgprTmp+2 if self.states.asmCaps["v_prng_b32"]: @@ -11383,7 +11508,7 @@ def localWriteBody(tP): localWriteCVTCode.add(VCvtSRF32toBF8(dst=vgpr(destVgprPrefix + "+%u+%u"%(g2lIdx, vi//2)), src0=vgpr(vgprTmp2), src1=vgpr(vRand), vop3=VOP3PModifiers(op_sel=[0,0,1,sel]), comment="Convert to BF8")) else: # ScaleA/B, sgpr upper is dummy. - if kernel["ProblemType"]["UseScaleAB"] == "Scalar" and kernel["ProblemType"]["DataType%s"%tc].numRegisters() > kernel["ProblemType"]["DataType"].numRegisters(): + if kernel["ProblemType"]["UseScaleAB"] == "Scalar" and kernel["ProblemType"]["DataType%s"%tc].numRegisters() > kernel["ProblemType"]["MacDataType%s"%tc].numRegisters(): localWriteCVTCode.add(VMulPKF32S(dst=vgpr(vgprTmp, 2), src0=vgpr(vgprTmp, 2), src1=sgpr("Scale%s"%tc, 2), vop3=VOP3PModifiers(op_sel_hi=[1,0,1]), comment="Input *= scale %s"%tc)) if (toF8): @@ -11392,7 +11517,7 @@ def localWriteBody(tP): localWriteCVTCode.add(VCvtPkF32toBF8(dst=vgpr(destVgprPrefix + "+%u+%u"%(g2lIdx, vi//2)), src0=vgpr(vgprTmp), src1=vgpr(vgprTmp2), vop3=VOP3PModifiers(op_sel=[0,0,sel]), comment="Convert to BF8")) self.vgprPool.checkIn(vgprTmp) - elif (kernel["ProblemType"]["DataType%s"%tc].isAnyFloat8() and kernel["ProblemType"]["DataType"].isHalf()): + elif (kernel["ProblemType"]["DataType%s"%tc].isAnyFloat8() and kernel["ProblemType"]["MacDataType%s"%tc].isHalf()): noSDWA = ti.getArchCaps()["NoSDWA"] newBlockWidth = tP["globalReadInstruction"].blockWidth if newBlockWidth == 0.25: @@ -11531,7 +11656,7 @@ def localWriteBody(tP): if vi%2 == 0: localWriteCVTCode.add(VCvtPkF32toBF16(dst=vgpr(destVgprPrefix + "+%u+%u"%(g2lIdx, vi//2)), src0=vgpr(destVgprPrefix + "+%u+%u"%(g2lIdx, vi)), src1=vgpr(destVgprPrefix + "+%u+%u"%(g2lIdx, vi+1)), comment="convert/pack fp32 to bf16")) else: - printExit("Unsupported combination DataType%s (%s) -> DataType (%s)"%(tc, kernel["ProblemType"]["DataType%s"%tc].toChar(), kernel["ProblemType"]["DataType"].toChar())) + printExit("Unsupported combination DataType%s (%s) -> DataType (%s)"%(tc, kernel["ProblemType"]["DataType%s"%tc].toChar(), kernel["ProblemType"]["MacDataType%s"%tc].toChar())) LocalWriteX = tP["localWriteInstruction"].getInst(isHigh16Bits) localWriteMemToken = [self.states.ldsWriteTokenIdx] @@ -11589,8 +11714,7 @@ def localWriteBody(tP): # Enable local write if not DTL or using nonDTL loads in tail loop if (not kernel["DirectToLds%s"%tc]) or \ - (((tP["isA"] and kernel["NonDTLTailLoopA"]) or \ - (tP["isB"] and kernel["NonDTLTailLoopB"])) and self.states.inTailLoop): + ((tc in ("A","B","MXSA","MXSB") and kernel["NonDTLTailLoop%s"%tc]) and self.states.inTailLoop): # Skip local write if DTVA or DTVB if not ((tP["isA"] or tP["isB"]) and kernel["DirectToVgpr%s"%tc]): localWriteBody(tP) @@ -11772,6 +11896,10 @@ def localReadInitPointers(self, kernel, tPA, tP): numLra = self.states.a.numVgprLocalReadAddr elif tP["isB"]: numLra = self.states.b.numVgprLocalReadAddr + elif tP["isMXSA"]: + numLra = self.states.mxsa.numVgprLocalReadAddr + elif tP["isMXSB"]: + numLra = self.states.mxsb.numVgprLocalReadAddr elif tP["isM"]: numLra = self.states.m.numVgprLocalReadAddr module.add(VAndB32( @@ -11811,17 +11939,17 @@ def localReadInc(self, kernel, iui, tP): if self.states.inTailLoop: if kernel["UseDotInstruction"]: # dot2 - inc = self.states.lrvwUnrollA * kernel["NumWaveSplitK"] * tP["bpeDS"] + inc = int(self.states.lrvwUnrollA * kernel["NumWaveSplitK"] * tP["bpeDS"]) comment = "(LocalReadVectorWidth*NumWaveSplitK*bpeDS)" else: - inc = (kernel["MacroTile%s" % tP["tensorChar"]] + LdsPad) * tP["bpeDS"] + inc = int((kernel["MacroTile%s" % tP["tensorChar"]] + LdsPad) * tP["bpeDS"]) comment = " ((MT+PAD)*bpeDS)" if kernel["EnableMatrixInstruction"]: if kernel["UnrollMajorLDS%s" % tc]: if tc in ("MXSA", "MXSB"): - inc = tP["bpeDS"] * max(self.states.numReadsIterCoalescedMXSA,self.states.numReadsIterCoalescedMXSB) + inc = tP["bpeDS"] * max(self.states.numReadsIterCoalescedMXSA, self.states.numReadsIterCoalescedMXSB) else: - inc = tP["bpeDS"] * max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB) + inc = tP["bpeDS"] * max(self.states.numReadsIterCoalescedA, self.states.numReadsIterCoalescedB) comment = " (bpeDS)" inc *= matrixInstK if kernel["ProblemType"]["Sparse"]: @@ -11847,7 +11975,7 @@ def localReadInc(self, kernel, iui, tP): with self.allocTmpSgpr(1) as tmpSgprInfo: tmpSgpr = tmpSgprInfo.idx - module.add(SMovB32(dst=sgpr(tmpSgpr), src=int(inc+padd), comment="inc")) + module.add(SMovB32(dst=sgpr(tmpSgpr), src=int(inc + padd), comment="inc")) numLra = 0 if tP["isA"]: numLra = self.states.a.numVgprLocalReadAddr @@ -11884,8 +12012,7 @@ def localReadInc(self, kernel, iui, tP): if tc == "A": sparseA = kernel["ProblemType"]["Sparse"] == 1 lrvw = kernel["LocalReadVectorWidth%s"%tc] // (2 if sparseA else 1) - wlr = lrvw//kernel["MIInputPerThreadA"] - wlr = max(wlr, 1) + wlr = max(lrvw//kernel["MIInputPerThreadA"], 1) if kernel["ProblemType"]["Sparse"] and lrvw < kernel["MIInputPerThreadA"]: if not sparseA: offsetInc = (kernel["MacroTile%s"%tP["tensorChar"]] + LdsPad) * (kernel["MatrixInstK"]) @@ -11896,31 +12023,27 @@ def localReadInc(self, kernel, iui, tP): elif (self.states.localReadDoCntA)%(wlr): offsetInc = (kernel["MacroTile%s"%tP["tensorChar"]] + LdsPad) * kernel["MIInputPerThreadA"] else: - isTF32EmuGfx950 = kernel["UseF32XEmulation"] and (kernel["MatrixInstK"] >= 32 or (kernel["MatrixInstM"] == 32 and kernel["MatrixInstK"] == 16)) - if kernel["UnrollMajorLDSA"] == False and kernel["MatrixInstK"] >= 64 or isTF32EmuGfx950: # Special handling for new f8 mfmas and tf32 - offsetInc = (kernel["MacroTile%s"%tP["tensorChar"]] + LdsPad) * (kernel["MatrixInstK"]) - else: - offsetInc = (kernel["MacroTile%s"%tP["tensorChar"]] + LdsPad) * (kernel["MatrixInstK"]*lrvw//kernel["MIInputPerThreadA"]-kernel["MIInputPerThreadA"]*(lrvw//kernel["MIInputPerThreadA"]-1)) + offsetInc = (kernel["MacroTile%s"%tP["tensorChar"]] + LdsPad) * ((matrixInstK * wlr) - (kernel["MIInputPerThreadA"] * (wlr - 1))) if sparseA: offsetInc //= 2 elif tc in ("MXSA", "MXSB"): offsetInc = (kernel["MacroTile%s"%tP["tensorChar"]] + LdsPad) * (matrixInstK) elif tc == "Metadata": lrvw = kernel["LocalReadVectorWidth%s"%tc] // 8 + wlr = max(lrvw//kernel["MIInputPerThreadMetadata"], 1) if lrvw < kernel["MIInputPerThreadMetadata"]: offsetInc = (kernel["MacroTile%s"%tP["tensorChar"]] + LdsPad) * (kernel["MatrixInstK"]*lrvw//kernel["MIInputPerThreadMetadata"]) // 4 if kernel["WavefrontSize"] == 32: offsetInc *= 2 - elif (self.states.localReadDoCntMetadata)%(lrvw//kernel["MIInputPerThreadMetadata"]): + elif (self.states.localReadDoCntMetadata)%(wlr): offsetInc = (kernel["MacroTile%s"%tP["tensorChar"]] + LdsPad) * kernel["MIInputPerThreadMetadata"] else: - offsetInc = (kernel["MacroTile%s"%tP["tensorChar"]] + LdsPad) * (kernel["MatrixInstK"]*lrvw//kernel["MIInputPerThreadMetadata"]-kernel["MIInputPerThreadMetadata"]*(lrvw//kernel["MIInputPerThreadMetadata"]-1)) + offsetInc = (kernel["MacroTile%s"%tP["tensorChar"]] + LdsPad) * (kernel["MatrixInstK"] * wlr - (kernel["MIInputPerThreadMetadata"] * (wlr - 1))) offsetInc //= 8 elif tc == "B": sparseB = kernel["ProblemType"]["Sparse"] == 2 lrvw = kernel["LocalReadVectorWidth%s"%tc] // (2 if sparseB else 1) - wlr = lrvw//kernel["MIInputPerThreadB"] - wlr = max(wlr, 1) + wlr = max(lrvw//kernel["MIInputPerThreadB"], 1) if kernel["ProblemType"]["Sparse"] and lrvw < kernel["MIInputPerThreadB"]: if not sparseB: offsetInc = (kernel["MacroTile%s"%tP["tensorChar"]] + LdsPad) * (kernel["MatrixInstK"]) @@ -11930,14 +12053,8 @@ def localReadInc(self, kernel, iui, tP): offsetInc *= 2 elif (self.states.localReadDoCntB)%(wlr): offsetInc = (kernel["MacroTile%s"%tP["tensorChar"]] + LdsPad) * kernel["MIInputPerThreadB"] - else: - isTF32EmuGfx950 = kernel["UseF32XEmulation"] and (kernel["MatrixInstK"] >= 32 or (kernel["MatrixInstM"] == 32 and kernel["MatrixInstK"] == 16)) - if kernel["UnrollMajorLDSB"] == False and kernel["MatrixInstK"] >= 64 or isTF32EmuGfx950: # Special handling for new f8 mfmas and tf32 - offsetInc = (kernel["MacroTile%s"%tP["tensorChar"]] + LdsPad) * (kernel["MatrixInstK"]) - else: - offsetInc = (kernel["MacroTile%s"%tP["tensorChar"]] + LdsPad) * (kernel["MatrixInstK"]*lrvw//kernel["MIInputPerThreadB"]-kernel["MIInputPerThreadB"]*(lrvw//kernel["MIInputPerThreadB"]-1)) - + offsetInc = (kernel["MacroTile%s"%tP["tensorChar"]] + LdsPad) * (kernel["MatrixInstK"] * wlr - (kernel["MIInputPerThreadB"] * (wlr - 1))) if sparseB: offsetInc //= 2 else: @@ -14579,7 +14696,7 @@ def chooseGlobalRead(self, useBuffer, bpl, destVgpr, \ if useBuffer: rv = Module("Global Read") - mubuf = MUBUFModifiers(offen=True, offset12=offset, glc=glc, slc=slc, nt=nt, lds=lds) + mubuf = MUBUFModifiers(offen=True, offset12=int(offset), glc=glc, slc=slc, nt=nt, lds=lds) # Nested buffer load implementation function for easy branching for soffset def bufferLoadImpl(soffset): diff --git a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Problem.py b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Problem.py index d4c956656ad..f266140cefc 100644 --- a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Problem.py +++ b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Problem.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -36,7 +36,6 @@ - class ProblemSizeRange: def __init__(self, problemType, config): @@ -775,7 +774,7 @@ def __init__(self, config, printIndexAssignmentInfo: bool): else: self["DataTypeB"] = self["MacDataTypeB"] self["DataTypeB"] = getRealDataTypeB(self["DataTypeB"]) - + if "DestDataType" in config: self["DestDataType"] = DataType(config["DestDataType"]) else: @@ -963,9 +962,6 @@ def _checkIfSupportedGEMMType(self): computeType = self["ComputeDataType"] gemmType = ( inType.toChar(), outType.toChar(), computeType.toChar() ) - if gemmType not in _validGEMMTypes: - raise Exception("This typed-GEMM (Ti, To, Tc) = (%s, %s, %s) is not supported yet."%(gemmType[0], gemmType[1], gemmType[2])) - if self["MXBlockA"] or self["MXBlockB"]: if gemmType not in _validMXGEMMTypes: raise Exception("This typed-MX-GEMM (Ti, To, Tc) = (%s, %s, %s) is not supported yet." % (gemmType[0], gemmType[1], gemmType[2])) @@ -979,6 +975,8 @@ def _checkIfSupportedGEMMType(self): raise Exception("MXShape is not supported") elif (self["MXBlockA"] not in _validMXGEMMBlock): raise Exception("MXShape is not supported") + elif gemmType not in _validGEMMTypes: + raise Exception("This typed-GEMM (Ti, To, Tc) = (%s, %s, %s) is not supported yet."%(gemmType[0], gemmType[1], gemmType[2])) ######################################## def initGEMM(self): @@ -1148,9 +1146,9 @@ def assignDerivedParameters(state, printIndexAssignmentInfo: bool=False): if printIndexAssignmentInfo: print("TLUA: %s (stridePosA(%d) int: + isMX = state["ProblemType"].get("MXBlockA", 0) != 0 or state["ProblemType"].get("MXBlockB", 0) != 0 + numBytesA = state["ProblemType"]["MacDataTypeA"].numBytes() + numBytesB = state["ProblemType"]["MacDataTypeB"].numBytes() lrvwA = state["LocalReadVectorWidthA"] lrvwB = state["LocalReadVectorWidthB"] ldsPadA = state["LdsPadA"] @@ -2216,63 +2252,69 @@ def calcLdsPad(isaInfoMap: Dict[str, IsaInfo]) -> int: if readRegsB == 4 or readRegsB == 1: optPadB *= 2 if ldsPadA == -1: - if not state["UnrollMajorLDSA"]: - if state["EnableMatrixInstruction"]: - ldsPadA = 0 - if state["MatrixInstB"] == 1 and state["MatrixInstM"] == 16: - ldsPadA = int(((16 * state["VectorWidthA"] * state["ProblemType"]["MacDataTypeA"].numBytes() + state["MacroTile0"] * state["ProblemType"]["MacDataTypeA"].numBytes() * state["LocalReadVectorWidthA"]) % 128) // state["ProblemType"]["MacDataTypeA"].numBytes()) - if state["GlobalReadVectorWidthA"] * state["ProblemType"]["MacDataTypeA"].numBytes() == 32 and ldsPadA == 0: - ldsPadA = int(16 // state["ProblemType"]["MacDataTypeA"].numBytes()) - if state["DirectToLdsA"]: - # TODO: Check if there are cases which benefit from padding, currently set to zero by default - ldsPadA = state["MatrixInstM"] if state["enableLDSTrA"] else 0 - else: # mac instruction - if state["ProblemType"]["TLUA"]: - ldsPadA = 0 - else: - ldsPadA = state["VectorWidthA"] + if isMX and state["ProblemType"]["DataTypeA"].is6bitFloat(): + ldsPadA = 0 else: - if state["DirectToLdsA"]: - if not state["ProblemType"]["TLUA"]: - bpeA = state["ProblemType"]["DataTypeA"].numBytes() - LdsStride = state["VectorWidthA"] * bpeA * state["DepthU"] - MinLdsBlockSizePerPadA = (state[f"GlobalReadVectorWidthA"] * bpeA) * state["WavefrontSize"] - isM0PadEnough = LdsStride >= MinLdsBlockSizePerPadA - ldsPadA = state["MatrixInstK"] if bpeA == 2 and not isM0PadEnough else 2 * lrvwA - else: + if not state["UnrollMajorLDSA"]: + if state["EnableMatrixInstruction"]: ldsPadA = 0 + if state["MatrixInstB"] == 1 and state["MatrixInstM"] == 16: + ldsPadA = int(((16 * state["VectorWidthA"] * numBytesA + state["MacroTile0"] * numBytesA * lrvwA) % 128) // numBytesA) + if state["GlobalReadVectorWidthA"] * numBytesA == 32 and ldsPadA == 0: + ldsPadA = int(16 // numBytesA) + if state["DirectToLdsA"]: + # TODO: Check if there are cases which benefit from padding, currently set to zero by default + ldsPadA = state["MatrixInstM"] if state["enableLDSTrA"] else 0 + else: # mac instruction + if state["ProblemType"]["TLUA"]: + ldsPadA = 0 + else: + ldsPadA = state["VectorWidthA"] else: - ldsPadA = max(state["GlobalReadVectorWidthA"],optPadA) + if state["DirectToLdsA"]: + if not state["ProblemType"]["TLUA"]: + bpeA = state["ProblemType"]["DataTypeA"].numBytes() + LdsStride = state["VectorWidthA"] * bpeA * state["DepthU"] + MinLdsBlockSizePerPadA = (state[f"GlobalReadVectorWidthA"] * bpeA) * state["WavefrontSize"] + isM0PadEnough = LdsStride >= MinLdsBlockSizePerPadA + ldsPadA = state["MatrixInstK"] if bpeA == 2 and not isM0PadEnough else 2 * lrvwA + else: + ldsPadA = 0 + else: + ldsPadA = max(state["GlobalReadVectorWidthA"],optPadA) assert(ldsPadA >= 0) if ldsPadB == -1: - if not state["UnrollMajorLDSB"]: - if state["EnableMatrixInstruction"]: - ldsPadB = 0 - if state["MatrixInstB"] == 1 and state["MatrixInstM"] == 16: - ldsPadB = int(((16 * state["VectorWidthB"] * state["ProblemType"]["MacDataTypeB"].numBytes() + state["MacroTile1"] * state["ProblemType"]["MacDataTypeB"].numBytes() * state["LocalReadVectorWidthB"]) % 128) // state["ProblemType"]["MacDataTypeB"].numBytes()) - if state["GlobalReadVectorWidthB"] * state["ProblemType"]["MacDataTypeB"].numBytes() == 32 and ldsPadB == 0: - ldsPadB = int(16 // state["ProblemType"]["MacDataTypeB"].numBytes()) - if state["DirectToLdsB"]: - # TODO: Check if there are cases which benefit from padding, currently set to zero by default - ldsPadB = state["MatrixInstM"] if state["enableLDSTrB"] else 0 - else: - if state["ProblemType"]["TLUB"]: - ldsPadB = 0 - else: - ldsPadB = state["VectorWidthB"] + if isMX and state["ProblemType"]["DataTypeB"].is6bitFloat(): + ldsPadB = 0 else: - if state["DirectToLdsB"]: - if not state["ProblemType"]["TLUB"]: - bpeB = state["ProblemType"]["DataTypeB"].numBytes() - LdsStride = state["VectorWidthB"] * bpeB * state["DepthU"] - MinLdsBlockSizePerPadB = (state[f"GlobalReadVectorWidthB"] * bpeB) * state["WavefrontSize"] - isM0PadEnough = LdsStride >= MinLdsBlockSizePerPadB - ldsPadB = state["MatrixInstK"] if bpeB == 2 and not isM0PadEnough else 2 * lrvwB - else: + if not state["UnrollMajorLDSB"]: + if state["EnableMatrixInstruction"]: ldsPadB = 0 + if state["MatrixInstB"] == 1 and state["MatrixInstM"] == 16: + ldsPadB = int(((16 * state["VectorWidthB"] * numBytesB + state["MacroTile1"] * numBytesB * lrvwB) % 128) // numBytesB) + if state["GlobalReadVectorWidthB"] * numBytesB == 32 and ldsPadB == 0: + ldsPadB = int(16 // numBytesB) + if state["DirectToLdsB"]: + # TODO: Check if there are cases which benefit from padding, currently set to zero by default + ldsPadB = state["MatrixInstM"] if state["enableLDSTrB"] else 0 + else: # mac instruction + if state["ProblemType"]["TLUB"]: + ldsPadB = 0 + else: + ldsPadB = state["VectorWidthB"] else: - ldsPadB = max(state["GlobalReadVectorWidthB"],optPadB) + if state["DirectToLdsB"]: + if not state["ProblemType"]["TLUB"]: + bpeB = state["ProblemType"]["DataTypeB"].numBytes() + LdsStride = state["VectorWidthB"] * bpeB * state["DepthU"] + MinLdsBlockSizePerPadB = (state[f"GlobalReadVectorWidthB"] * bpeB) * state["WavefrontSize"] + isM0PadEnough = LdsStride >= MinLdsBlockSizePerPadB + ldsPadB = state["MatrixInstK"] if bpeB == 2 and not isM0PadEnough else 2 * lrvwB + else: + ldsPadB = 0 + else: + ldsPadB = max(state["GlobalReadVectorWidthB"],optPadB) assert(ldsPadB >= 0) if state["ProblemType"]["Sparse"] and not state["DirectToVgprSparseMetadata"]: @@ -2439,145 +2481,231 @@ def calcLdsNumBytesM(): if state["LocalReadVectorWidthB"] == -1: state["LocalReadVectorWidthB"] = state["LocalReadVectorWidth"] - # Default LocalReadVectorWidth - if state["EnableMatrixInstruction"]: + def calLRVW(): # Default LocalReadVectorWidth - autoLRVWA = False - maxLRVWA = int(Solution.MAX_NUM_DS_LOAD_BYTES // state["ProblemType"]["MacDataTypeA"].numBytes()) - # Set maxLRVW to 32 for 6 bits float: use two load instructions b128(4 vgpr) and b64(2 vgpr) to mimic b192 - if isaInfoMap[isa].asmCaps["HasWMMA_f8f6f4"] and state["ProblemType"]["MacDataTypeA"].numBytes() == 0.75: - maxLRVWA = 32 - if state["LocalReadVectorWidthA"] == -1: - autoLRVWA = True - if state["TransposeLDS"] or (state["MIInputPerThread"] * state["ProblemType"]["MacDataTypeA"].numBytes() > Solution.MAX_NUM_DS_LOAD_BYTES): - state["LocalReadVectorWidthA"] = maxLRVWA + if state["EnableMatrixInstruction"]: + # Default LocalReadVectorWidth + autoLRVWA = False + maxLRVWA = int(Solution.MAX_NUM_DS_LOAD_BYTES // state["ProblemType"]["MacDataTypeA"].numBytes()) + # Set maxLRVW to 32 for 6 bits float: use two load instructions b128(4 vgpr) and b64(2 vgpr) to mimic b192 + if isaInfoMap[isa].asmCaps["HasWMMA_f8f6f4"] and state["ProblemType"]["MacDataTypeA"].numBytes() == 0.75: + maxLRVWA = 32 + if state["LocalReadVectorWidthA"] == -1: + autoLRVWA = True + if state["TransposeLDS"] or (state["MIInputPerThread"] * state["ProblemType"]["MacDataTypeA"].numBytes() > Solution.MAX_NUM_DS_LOAD_BYTES): + state["LocalReadVectorWidthA"] = maxLRVWA + else: + state["LocalReadVectorWidthA"] = min(state["MIInputPerThread"], maxLRVWA) + if state["LocalReadVectorWidthA"] > maxLRVWA: + raise RuntimeError("LocalReadVectorWidthA (%d) exceeds max %d (# bytes of lrvw > 32)" \ + % (state["LocalReadVectorWidthA"], maxLRVWA)) else: - state["LocalReadVectorWidthA"] = min(state["MIInputPerThread"], maxLRVWA) - assert state["LocalReadVectorWidthA"] <= maxLRVWA, "# bytes of lrvw > 32" - else: - if isaInfoMap[isa].asmCaps["HasWMMA_V3"]: - if state["LocalReadVectorWidthA"] != maxLRVWA and state["TransposeLDS"]: - reject(state, printRejectionReason, f"gfx1250 requires lrvwA == {maxLRVWA} for datatype {state['ProblemType']['MacDataTypeA']}, actual value: {state['LocalReadVectorWidthA']}") - if state["ProblemType"]["Sparse"]: - if isaInfoMap[isa].asmCaps["HasSMFMA"]: # gfx942, gfx950 - if state["MIInputPerThread"] * state["ProblemType"]["MacDataTypeA"].numBytes() > Solution.MAX_NUM_DS_LOAD_BYTES: - if state["LocalReadVectorWidthA"] < state["MIInputPerThread"] // 2: - reject(state, printRejectionReason, "LocalReadVectorWidthA < %u" %(state["MIInputPerThread"] // 2)) - elif isaInfoMap[isa].asmCaps["HasSWMMAC_gfx1250"]: # gfx1250 - if state["ProblemType"]["Sparse"] and state["LocalReadVectorWidthA"] * state["ProblemType"]["MacDataTypeA"].numBytes() > Solution.MAX_NUM_DS_LOAD_BYTES: - reject(state, printRejectionReason, "LocalReadVectorWidthA * BytePerMacDataTypeA(%s) > %d bytes." % (state["ProblemType"]["MacDataTypeA"].numBytes(), Solution.MAX_NUM_DS_LOAD_BYTES)) - elif not state["ProblemType"]["Sparse"] and not state["UseF32XEmulation"] and not(state["ProblemType"]["MacDataTypeA"].is8bitFloat() and (state["MatrixInstK"] in [64, 128,])): - if state["LocalReadVectorWidthA"] < state["MIInputPerThread"] and not state["LDSTrInst"] and not isaInfoMap[isa].asmCaps["HasWMMA_V3"]: - reject(state, printRejectionReason, "LocalReadVectorWidthA < %u" %(state["MIInputPerThread"])) - if state["LocalReadVectorWidthA"] > state["MIInputPerThread"] and not state["TransposeLDS"]: - reject(state, printRejectionReason, "LocalReadVectorWidth require Transpose LDS") - - if autoLRVWA: - if state["LocalReadVectorWidthA"] // state["MIInputPerThread"] > 1: - if (state["DepthU"] // state["MatrixInstK"] <= state["LocalReadVectorWidthA"] // state["MIInputPerThread"]): - # if only have 1 iteration with wider local read, reduce LRVW to have better scheduling (at least 2 iterations) - state["LocalReadVectorWidthA"] //= 2 + if isaInfoMap[isa].asmCaps["HasWMMA_V3"]: + if state["LocalReadVectorWidthA"] != maxLRVWA and state["TransposeLDS"]: + reject(state, printRejectionReason, f"gfx1250 requires lrvwA == {maxLRVWA} for datatype {state['ProblemType']['MacDataTypeA']}, actual value: {state['LocalReadVectorWidthA']}") + if state["ProblemType"]["Sparse"]: + if isaInfoMap[isa].asmCaps["HasSMFMA"]: # gfx942, gfx950 + if state["MIInputPerThread"] * state["ProblemType"]["MacDataTypeA"].numBytes() > Solution.MAX_NUM_DS_LOAD_BYTES: + if state["LocalReadVectorWidthA"] < state["MIInputPerThread"] // 2: + reject(state, printRejectionReason, "LocalReadVectorWidthA < %u" %(state["MIInputPerThread"] // 2)) + elif isaInfoMap[isa].asmCaps["HasSWMMAC_gfx1250"]: # gfx1250 + if state["ProblemType"]["Sparse"] and state["LocalReadVectorWidthA"] * state["ProblemType"]["MacDataTypeA"].numBytes() > Solution.MAX_NUM_DS_LOAD_BYTES: + reject(state, printRejectionReason, "LocalReadVectorWidthA * BytePerMacDataTypeA(%s) > %d bytes." % (state["ProblemType"]["MacDataTypeA"].numBytes(), Solution.MAX_NUM_DS_LOAD_BYTES)) + elif not state["ProblemType"]["Sparse"] and not state["UseF32XEmulation"] and not(state["ProblemType"]["MacDataTypeA"].is8bitFloat() and (state["MatrixInstK"] in [64, 128,])): + if state["LocalReadVectorWidthA"] < state["MIInputPerThread"] and not state["LDSTrInst"] and not isaInfoMap[isa].asmCaps["HasWMMA_V3"]: + reject(state, printRejectionReason, "LocalReadVectorWidthA < %u" %(state["MIInputPerThread"])) # << Rejected here + if state["LocalReadVectorWidthA"] > state["MIInputPerThread"] and not state["TransposeLDS"]: + reject(state, printRejectionReason, "LocalReadVectorWidth require Transpose LDS") + + if autoLRVWA: + if state["LocalReadVectorWidthA"] // state["MIInputPerThread"] > 1: + if (state["DepthU"] // state["MatrixInstK"] <= state["LocalReadVectorWidthA"] // state["MIInputPerThread"]): + # if only have 1 iteration with wider local read, reduce LRVW to have better scheduling (at least 2 iterations) + state["LocalReadVectorWidthA"] //= 2 - # Default LocalReadVectorWidth - autoLRVWB = False - maxLRVWB = int(Solution.MAX_NUM_DS_LOAD_BYTES // state["ProblemType"]["MacDataTypeB"].numBytes()) - # Set maxLRVW to 32 for 6 bits float: use two load instructions b128(4 vgpr) and b64(2 vgpr) to mimic b192 - if isaInfoMap[isa].asmCaps["HasWMMA_f8f6f4"] and state["ProblemType"]["MacDataTypeB"].numBytes() == 0.75: - maxLRVWB = 32 - if state["LocalReadVectorWidthB"] == -1: - autoLRVWB = True - if state["TransposeLDS"] or (state["MIInputPerThread"] * state["ProblemType"]["MacDataTypeB"].numBytes() > Solution.MAX_NUM_DS_LOAD_BYTES): - state["LocalReadVectorWidthB"] = maxLRVWB + # Default LocalReadVectorWidth + autoLRVWB = False + maxLRVWB = int(Solution.MAX_NUM_DS_LOAD_BYTES // state["ProblemType"]["MacDataTypeB"].numBytes()) + # Set maxLRVW to 32 for 6 bits float: use two load instructions b128(4 vgpr) and b64(2 vgpr) to mimic b192 + if isaInfoMap[isa].asmCaps["HasWMMA_f8f6f4"] and state["ProblemType"]["MacDataTypeB"].numBytes() == 0.75: + maxLRVWB = 32 + if state["LocalReadVectorWidthB"] == -1: + autoLRVWB = True + if state["TransposeLDS"] or (state["MIInputPerThread"] * state["ProblemType"]["MacDataTypeB"].numBytes() > Solution.MAX_NUM_DS_LOAD_BYTES): + state["LocalReadVectorWidthB"] = maxLRVWB + else: + state["LocalReadVectorWidthB"] = min(state["MIInputPerThread"], maxLRVWB) + if state["LocalReadVectorWidthB"] > maxLRVWB: + raise RuntimeError("LocalReadVectorWidthB (%d) exceeds max %d (# bytes of lrvw > 32)" \ + % (state["LocalReadVectorWidthB"], maxLRVWB)) else: - state["LocalReadVectorWidthB"] = min(state["MIInputPerThread"], maxLRVWB) - assert state["LocalReadVectorWidthB"] <= maxLRVWB, "# bytes of lrvwB > 32" - else: - if isaInfoMap[isa].asmCaps["HasWMMA_V3"]: - if state["LocalReadVectorWidthB"] != maxLRVWB and state["TransposeLDS"]: - reject(state, printRejectionReason, f"gfx1250 requires lrvwB == {maxLRVWB} for MacDataTypeB {state['ProblemType']['MacDataTypeB']}, actual value: {state['LocalReadVectorWidthB']}") - # TODO: Find better conditons to filter gfx1250 solutions - if state["ProblemType"]["Sparse"]: - if isaInfoMap[isa].asmCaps["HasSMFMA"]: # gfx942, gfx950 - if state["MIInputPerThread"] * state["ProblemType"]["MacDataTypeB"].numBytes() > Solution.MAX_NUM_DS_LOAD_BYTES: - if state["LocalReadVectorWidthB"] < state["MIInputPerThread"] // 2: - reject(state, printRejectionReason, "LocalReadVectorWidthB < %u" %(state["MIInputPerThread"] // 2)) - elif isaInfoMap[isa].asmCaps["HasSWMMAC_gfx1250"]: # gfx1250 - if state["ProblemType"]["Sparse"] and state["LocalReadVectorWidthB"] * state["ProblemType"]["MacDataTypeB"].numBytes() > Solution.MAX_NUM_DS_LOAD_BYTES: - reject(state, printRejectionReason, "LocalReadVectorWidthB * BytePerMacDataTypeB(%s) > %d bytes." % (state["ProblemType"]["MacDataTypeB"].numBytes(), Solution.MAX_NUM_DS_LOAD_BYTES)) - elif not state["ProblemType"]["Sparse"] and not state["UseF32XEmulation"] and not(state["ProblemType"]["MacDataTypeB"].is8bitFloat() and (state["MatrixInstK"] in [64, 128,])): - if state["LocalReadVectorWidthB"] < state["MIInputPerThread"] and not state["LDSTrInst"] and not isaInfoMap[isa].asmCaps["HasWMMA_V3"]: - reject(state, printRejectionReason, "LocalReadVectorWidthB < %u" %(state["MIInputPerThread"])) - if state["LocalReadVectorWidthB"] > state["MIInputPerThread"] and not state["TransposeLDS"]: - reject(state, printRejectionReason, "LocalReadVectorWidthB require Transpose LDS") - - if autoLRVWB: - if state["LocalReadVectorWidthB"] // state["MIInputPerThread"] > 1: - if (state["DepthU"] // state["MatrixInstK"] <= state["LocalReadVectorWidthB"] // state["MIInputPerThread"]): - # if only have 1 iteration with wider local read, reduce LRVW to have better scheduling (at least 2 iterations) - state["LocalReadVectorWidthB"] //= 2 - - if autoLRVWA or autoLRVWB: - wlrA = max(state["LocalReadVectorWidthA"] // state["MIInputPerThread"], 1) - wlrB = max(state["LocalReadVectorWidthB"] // state["MIInputPerThread"], 1) - if (wlrA > 1) or (wlrB > 1): - padA, padB, padM, padMXSA, padMXSB = calcLdsPad(isaInfoMap) - ldsBlockSizePerPadA = calcLdsBlockSizePerPad("A", state["LocalReadVectorWidthA"]) - ldsBlockSizePerPadB = calcLdsBlockSizePerPad("B", state["LocalReadVectorWidthB"]) - ldsBlockSizePerPadMXSA = calcLdsBlockSizePerPad("MXSA", state["LocalReadVectorWidthMXS"]) if state["ProblemType"]["MXBlockA"] else 0 - ldsBlockSizePerPadMXSB = calcLdsBlockSizePerPad("MXSB", state["LocalReadVectorWidthMXS"]) if state["ProblemType"]["MXBlockB"] else 0 - ldsBlockSizePerPadA = 0 if padA == 0 else ldsBlockSizePerPadA - ldsBlockSizePerPadB = 0 if padB == 0 else ldsBlockSizePerPadB - ldsBlockSizePerPadMXSA = 0 if padMXSA == 0 else ldsBlockSizePerPadMXSA - ldsBlockSizePerPadMXSB = 0 if padMXSB == 0 else ldsBlockSizePerPadMXSB - checkLdsBlockSizePerPadForTDM(ldsBlockSizePerPadA, ldsBlockSizePerPadB, ldsBlockSizePerPadMXSA, ldsBlockSizePerPadMXSB) - (ldsNumBytesA, ldsNumBytesAlignedA) = calcLdsNumBytesAB("A", padA, ldsBlockSizePerPadA) - (ldsNumBytesB, ldsNumBytesAlignedB) = calcLdsNumBytesAB("B", padB, ldsBlockSizePerPadB) - (ldsNumBytesMXSA, ldsNumBytesAlignedMXSA) = calcLdsNumBytesAB("MXSA", padMXSA, ldsBlockSizePerPadMXSA) if state["ProblemType"]["MXBlockA"] else 0, 0 - (ldsNumBytesMXSB, ldsNumBytesAlignedMXSB) = calcLdsNumBytesAB("MXSB", padMXSB, ldsBlockSizePerPadMXSB) if state["ProblemType"]["MXBlockB"] else 0, 0 - (ldsNumBytesMetadata, ldsNumBytesAlignedMetadata) = calcLdsNumBytesM() - ldsNumBytesMetadata, ldsNumBytesAlignedMetadata = calcLdsNumBytesM() - if (ldsNumBytesAlignedA + ldsNumBytesAlignedB) > state["MaxLDS"]: - if wlrA > 1: - state["LocalReadVectorWidthA"] //= 2 - if wlrB > 1: + if isaInfoMap[isa].asmCaps["HasWMMA_V3"]: + if state["LocalReadVectorWidthB"] != maxLRVWB and state["TransposeLDS"]: + reject(state, printRejectionReason, f"gfx1250 requires lrvwB == {maxLRVWB} for MacDataTypeB {state['ProblemType']['MacDataTypeB']}, actual value: {state['LocalReadVectorWidthB']}") + # TODO: Find better conditons to filter gfx1250 solutions + if state["ProblemType"]["Sparse"]: + if isaInfoMap[isa].asmCaps["HasSMFMA"]: # gfx942, gfx950 + if state["MIInputPerThread"] * state["ProblemType"]["MacDataTypeB"].numBytes() > Solution.MAX_NUM_DS_LOAD_BYTES: + if state["LocalReadVectorWidthB"] < state["MIInputPerThread"] // 2: + reject(state, printRejectionReason, "LocalReadVectorWidthB < %u" %(state["MIInputPerThread"] // 2)) + elif isaInfoMap[isa].asmCaps["HasSWMMAC_gfx1250"]: # gfx1250 + if state["ProblemType"]["Sparse"] and state["LocalReadVectorWidthB"] * state["ProblemType"]["MacDataTypeB"].numBytes() > Solution.MAX_NUM_DS_LOAD_BYTES: + reject(state, printRejectionReason, "LocalReadVectorWidthB * BytePerMacDataTypeB(%s) > %d bytes." % (state["ProblemType"]["MacDataTypeB"].numBytes(), Solution.MAX_NUM_DS_LOAD_BYTES)) + elif not state["ProblemType"]["Sparse"] and not state["UseF32XEmulation"] and not(state["ProblemType"]["MacDataTypeB"].is8bitFloat() and (state["MatrixInstK"] in [64, 128,])): + if state["LocalReadVectorWidthB"] < state["MIInputPerThread"] and not state["LDSTrInst"] and not isaInfoMap[isa].asmCaps["HasWMMA_V3"]: + reject(state, printRejectionReason, "LocalReadVectorWidthB < %u" %(state["MIInputPerThread"])) + if state["LocalReadVectorWidthB"] > state["MIInputPerThread"] and not state["TransposeLDS"]: + reject(state, printRejectionReason, "LocalReadVectorWidthB require Transpose LDS") + + if autoLRVWB: + if state["LocalReadVectorWidthB"] // state["MIInputPerThread"] > 1: + if (state["DepthU"] // state["MatrixInstK"] <= state["LocalReadVectorWidthB"] // state["MIInputPerThread"]): + # if only have 1 iteration with wider local read, reduce LRVW to have better scheduling (at least 2 iterations) state["LocalReadVectorWidthB"] //= 2 - wlrA = max(state["LocalReadVectorWidthA"] // state["MIInputPerThread"], 1) - wlrB = max(state["LocalReadVectorWidthB"] // state["MIInputPerThread"], 1) - if wlrA > wlrB: - state["LocalReadVectorWidthA"] = wlrB * state["MIInputPerThread"] - if wlrA < wlrB: - state["LocalReadVectorWidthB"] = wlrA * state["MIInputPerThread"] - - if state["ProblemType"]["Sparse"] == 1: - state["LocalReadVectorWidthMetadata"] = state["LocalReadVectorWidthA"] - elif state["ProblemType"]["Sparse"] == 2: - state["LocalReadVectorWidthMetadata"] = state["LocalReadVectorWidthB"] - - if state["enableGLTrA"] or state["enableGLTrB"]: - state["LocalReadVectorWidth"] = 8 - state["LocalReadVectorWidthA"] = 8 - state["LocalReadVectorWidthB"] = 8 - else: - if state["UseDotInstruction"]: - # dot2: LRVW should be equal to NumDotElements * InnerUnroll - if state["LocalReadVectorWidth"] not in [-1, state["NumDotElements"] * state["InnerUnroll"]]: - reject(state, printRejectionReason, "dot kernel requires LocalReadVectorWidth = NumDotElements(%u) * InnerUnroll(%u)" \ - % (state["NumDotElements"], state["InnerUnroll"])) - return - state["LocalReadVectorWidth"] = state["NumDotElements"] * state["InnerUnroll"] - state["LocalReadVectorWidthA"] = state["NumDotElements"] * state["InnerUnroll"] - state["LocalReadVectorWidthB"] = state["NumDotElements"] * state["InnerUnroll"] - else: # mac - if state["LocalReadVectorWidth"] == -1: - state["LocalReadVectorWidth"] = state["VectorWidthA"] - if state["LocalReadVectorWidth"] != state["VectorWidthA"] or \ - state["LocalReadVectorWidth"] != state["VectorWidthB"]: - reject(state, printRejectionReason, "LocalReadVectorWidth must equal VectorWidthA/B for MAC kernels") - if state["LocalReadVectorWidthA"] == -1: - state["LocalReadVectorWidthA"] = state["VectorWidthA"] - if state["LocalReadVectorWidthB"] == -1: - state["LocalReadVectorWidthB"] = state["VectorWidthB"] + if autoLRVWA or autoLRVWB: + wlrA = max(state["LocalReadVectorWidthA"] // state["MIInputPerThread"], 1) + wlrB = max(state["LocalReadVectorWidthB"] // state["MIInputPerThread"], 1) + if (wlrA > 1) or (wlrB > 1): + padA, padB, padM, padMXSA, padMXSB = calcLdsPad(isaInfoMap) + ldsBlockSizePerPadA = calcLdsBlockSizePerPad("A", state["LocalReadVectorWidthA"]) + ldsBlockSizePerPadB = calcLdsBlockSizePerPad("B", state["LocalReadVectorWidthB"]) + ldsBlockSizePerPadMXSA = calcLdsBlockSizePerPad("MXSA", state["LocalReadVectorWidthMXS"]) if state["ProblemType"]["MXBlockA"] else 0 + ldsBlockSizePerPadMXSB = calcLdsBlockSizePerPad("MXSB", state["LocalReadVectorWidthMXS"]) if state["ProblemType"]["MXBlockB"] else 0 + ldsBlockSizePerPadA = 0 if padA == 0 else ldsBlockSizePerPadA + ldsBlockSizePerPadB = 0 if padB == 0 else ldsBlockSizePerPadB + ldsBlockSizePerPadMXSA = 0 if padMXSA == 0 else ldsBlockSizePerPadMXSA + ldsBlockSizePerPadMXSB = 0 if padMXSB == 0 else ldsBlockSizePerPadMXSB + checkLdsBlockSizePerPadForTDM(ldsBlockSizePerPadA, ldsBlockSizePerPadB, ldsBlockSizePerPadMXSA, ldsBlockSizePerPadMXSB) + (ldsNumBytesA, ldsNumBytesAlignedA) = calcLdsNumBytesAB("A", padA, ldsBlockSizePerPadA) + (ldsNumBytesB, ldsNumBytesAlignedB) = calcLdsNumBytesAB("B", padB, ldsBlockSizePerPadB) + (ldsNumBytesMXSA, ldsNumBytesAlignedMXSA) = calcLdsNumBytesAB("MXSA", padMXSA, ldsBlockSizePerPadMXSA) if state["ProblemType"]["MXBlockA"] else (0, 0) + (ldsNumBytesMXSB, ldsNumBytesAlignedMXSB) = calcLdsNumBytesAB("MXSB", padMXSB, ldsBlockSizePerPadMXSB) if state["ProblemType"]["MXBlockB"] else (0, 0) + (ldsNumBytesMetadata, ldsNumBytesAlignedMetadata) = calcLdsNumBytesM() + if (ldsNumBytesAlignedA + ldsNumBytesAlignedB) > state["MaxLDS"]: + if wlrA > 1: + state["LocalReadVectorWidthA"] //= 2 + if wlrB > 1: + state["LocalReadVectorWidthB"] //= 2 + wlrA = max(state["LocalReadVectorWidthA"] // state["MIInputPerThread"], 1) + wlrB = max(state["LocalReadVectorWidthB"] // state["MIInputPerThread"], 1) + if wlrA > wlrB: + state["LocalReadVectorWidthA"] = wlrB * state["MIInputPerThread"] + if wlrA < wlrB: + state["LocalReadVectorWidthB"] = wlrA * state["MIInputPerThread"] + + if state["ProblemType"]["Sparse"] == 1: + state["LocalReadVectorWidthMetadata"] = state["LocalReadVectorWidthA"] + elif state["ProblemType"]["Sparse"] == 2: + state["LocalReadVectorWidthMetadata"] = state["LocalReadVectorWidthB"] + + if state["enableGLTrA"] or state["enableGLTrB"]: + state["LocalReadVectorWidth"] = 8 + state["LocalReadVectorWidthA"] = 8 + state["LocalReadVectorWidthB"] = 8 + else: + if state["UseDotInstruction"]: + # dot2: LRVW should be equal to NumDotElements * InnerUnroll + if state["LocalReadVectorWidth"] not in [-1, state["NumDotElements"] * state["InnerUnroll"]]: + reject(state, printRejectionReason, "dot kernel requires LocalReadVectorWidth = NumDotElements(%u) * InnerUnroll(%u)" \ + % (state["NumDotElements"], state["InnerUnroll"])) + return + state["LocalReadVectorWidth"] = state["NumDotElements"] * state["InnerUnroll"] + state["LocalReadVectorWidthA"] = state["NumDotElements"] * state["InnerUnroll"] + state["LocalReadVectorWidthB"] = state["NumDotElements"] * state["InnerUnroll"] + else: # mac + if state["LocalReadVectorWidth"] == -1: + state["LocalReadVectorWidth"] = state["VectorWidthA"] + if state["LocalReadVectorWidth"] != state["VectorWidthA"] or \ + state["LocalReadVectorWidth"] != state["VectorWidthB"]: + reject(state, printRejectionReason, "LocalReadVectorWidth must equal VectorWidthA/B for MAC kernels") + if state["LocalReadVectorWidthA"] == -1: + state["LocalReadVectorWidthA"] = state["VectorWidthA"] + if state["LocalReadVectorWidthB"] == -1: + state["LocalReadVectorWidthB"] = state["VectorWidthB"] + + + def calLRVWFor950MX(): + # Determine if we need to infer LRVW. Returns True if, + # - state["LocalReadVectorWidth{tc}"] is -1, and, + # - state["ProblemType"]["MacDataType{tc}"] is not 6-bit float + # If the LRVW is set by the user, validate the configuration and rejects if, + # - state["LocalReadVectorWidth{tc}"] * state["ProblemType"]["MacDataType{tc}"].numRegisters() < 1 if not sparse + # - state["LocalReadVectorWidth{tc}"] // 2 * state["ProblemType"]["MacDataType{tc}"].numRegisters() < 1 is sparse + # - state["LocalReadVectorWidth{tc}"] > state["MIInputPerThread"] and LDS is not transposed + def isAutoLRVW(tc) -> bool: + autoLRVW = False + if state[f"LocalReadVectorWidth{tc}"] != -1: + tmplrvw = (state[f"LocalReadVectorWidth{tc}"] // 2) if state["ProblemType"]["Sparse"] else state[f"LocalReadVectorWidth{tc}"] + if tmplrvw * state["ProblemType"][f"MacDataType{tc}"].numRegisters() < 1: + reject(state, "LocalReadVectorWidth * dataRegister < 1") + if state[f"LocalReadVectorWidth{tc}"] > state["MIInputPerThread"] and not state["TransposeLDS"]: + reject(state, "LocalReadVectorWidth require Transpose LDS") + else: + if state["ProblemType"][f"MacDataType{tc}"].is6bitFloat(): + state[f"LocalReadVectorWidth{tc}"] = 32 if state[f"UnrollMajorLDS{tc}"] else 16 + else: + autoLRVW = True + if state["TransposeLDS"] and (not state[f"DirectToLds{tc}"]): + state[f"LocalReadVectorWidth{tc}"] = int(16 // state["ProblemType"][f"MacDataType{tc}"].numBytes()) + else: + if state["ProblemType"]["Sparse"] and state["MIInputPerThread"] * state["ProblemType"][f"MacDataType{tc}"].numBytes() > 16: + state[f"LocalReadVectorWidth{tc}"] = int(16 // state["ProblemType"][f"MacDataType{tc}"].numBytes()) + else: + state[f"LocalReadVectorWidth{tc}"] = state["MIInputPerThread"] + if state[f"LocalReadVectorWidth{tc}"] // state["MIInputPerThread"] > 1: + if (state["DepthU"] // state["MatrixInstK"] <= state[f"LocalReadVectorWidth{tc}"] // state["MIInputPerThread"]): + # if only have 1 iteration with wider local read, reduce LRVW to have better scheduling (at least 2 iterations) + state[f"LocalReadVectorWidth{tc}"] //= 2 + return autoLRVW + + if state["EnableMatrixInstruction"]: + autoLRVWA = isAutoLRVW("A") + autoLRVWB = isAutoLRVW("B") + if autoLRVWA or autoLRVWB: + wlrA = max(state["LocalReadVectorWidthA"] // state["MIInputPerThread"], 1) + wlrB = max(state["LocalReadVectorWidthB"] // state["MIInputPerThread"], 1) + + if (wlrA > 1) or (wlrB > 1): + padA, padB, padM = calcLdsPad(state["LocalReadVectorWidth"], isaInfoMap) + ldsBlockSizePerPadA, ldsBlockSizePerPadB = calcLdsBlockSizePerPad(state["LocalReadVectorWidth"]) + ldsNumBytesA, ldsNumBytesAlignedA, ldsNumBytesB, ldsNumBytesAlignedB, ldsNumBytesMetadata, ldsNumBytesAlignedMetadata, \ + ldsNumBytesMXSA, ldsNumBytesAlignedMXSA, ldsNumBytesMXSB, ldsNumBytesAlignedMXSB \ + = calcLdsNumBytes(padA, ldsBlockSizePerPadA, padB, ldsBlockSizePerPadB) + ldsNumBytes = ldsNumBytesAlignedA + ldsNumBytesAlignedB + \ + ldsNumBytesAlignedMXSA + ldsNumBytesAlignedMXSB + \ + ldsNumBytesAlignedMetadata + if ldsNumBytes > state["MaxLDS"]: + if wlrA > 1: + state["LocalReadVectorWidthA"] //= 2 + if wlrB > 1: + state["LocalReadVectorWidthB"] //= 2 + + wlrA = max(state["LocalReadVectorWidthA"] // state["MIInputPerThread"], 1) + wlrB = max(state["LocalReadVectorWidthB"] // state["MIInputPerThread"], 1) + if wlrA > wlrB: + state["LocalReadVectorWidthA"] = wlrB * state["MIInputPerThread"] + if wlrA < wlrB: + state["LocalReadVectorWidthB"] = wlrA * state["MIInputPerThread"] + + if state["ProblemType"]["Sparse"] == 1: + state["LocalReadVectorWidthMetadata"] = state["LocalReadVectorWidthA"] + elif state["ProblemType"]["Sparse"] == 2: + state["LocalReadVectorWidthMetadata"] = state["LocalReadVectorWidthB"] + + if state["ProblemType"]["MXBlockA"]: + state["LocalReadVectorWidthMXSA"] = 1 # TODO: check if need to fomulization + if state["ProblemType"]["MXBlockB"]: + state["LocalReadVectorWidthMXSB"] = 1 # TODO: check if need to fomulization + else: + reject(state, printRejectionReason, "expecting MFMA for MX datatypes") + return + + # Default LocalReadVectorWidth + if state["ISA"] == IsaVersion(9,5,0) and (state["ProblemType"]["MXBlockA"] or state["ProblemType"]["MXBlockB"]): + calLRVWFor950MX() + else: + calLRVW() def calcOptGRVW(lrvw: int, unrollMajorLDS: bool, datatype: DataType) -> int: # with UnrollMajorLDS, GRVW need to less or equal than LRVW to have conflict free LDS read with padding. @@ -2603,6 +2731,8 @@ def calSwizzlePackK(state, tc): elif state["GlobalReadVectorWidthA"] == -1: if state["ProblemType"]["SwizzleTensorA"]: state["GlobalReadVectorWidthA"] = state["MIInputPerThreadA"] * calSwizzlePackK(state, "A") + elif state["ProblemType"]["DataTypeA"].is6bitFloat(): + state["GlobalReadVectorWidthA"] = 32 elif state["enableGLTrA"]: state["GlobalReadVectorWidthA"] = 8 else: @@ -2613,6 +2743,12 @@ def calSwizzlePackK(state, tc): if (state["MacroTile0"]*state["_DepthUA"]//state["NumThreads"]) % curGRVW == 0: state["GlobalReadVectorWidthA"] = int(curGRVW) curGRVW *= 2 + if state["ProblemType"]["MXBlockA"]: + state["GlobalReadVectorWidthMXSA"] = max(state["MacroTile0"] * state["_DepthUMXSA"] // state["NumThreads"], 1) + # workaround for DTL + # use 32bit load for DTL if GlobalReadVectorWidthMXS is 8 + if state["DirectToLdsMXSA"] and state["GlobalReadVectorWidthMXSA"] == 8: + state["GlobalReadVectorWidthMXSA"] = 4 else: # dot2 if state["UseDotInstruction"]: @@ -2642,6 +2778,8 @@ def calSwizzlePackK(state, tc): elif state["GlobalReadVectorWidthB"] == -1: if state["ProblemType"]["SwizzleTensorB"]: state["GlobalReadVectorWidthB"] = state["MIInputPerThreadB"] * calSwizzlePackK(state, "B") + elif state["ProblemType"]["DataTypeB"].is6bitFloat(): + state["GlobalReadVectorWidthB"] = 32 elif state["enableGLTrB"]: state["GlobalReadVectorWidthB"] = 8 else: @@ -2652,6 +2790,12 @@ def calSwizzlePackK(state, tc): if (state["MacroTile1"]*state["_DepthUB"]//state["NumThreads"]) % curGRVW == 0: state["GlobalReadVectorWidthB"] = int(curGRVW) curGRVW *= 2 + if state["ProblemType"]["MXBlockB"]: + state["GlobalReadVectorWidthMXSB"] = max(state["MacroTile1"] * state["_DepthUMXSB"] // state["NumThreads"], 1) + # workaround for DTL + # use 32bit load for DTL if GlobalReadVectorWidthMXS is 8 + if state["DirectToLdsMXSB"] and state["GlobalReadVectorWidthMXSB"] == 8: + state["GlobalReadVectorWidthMXSB"] = 4 else: # dot2 if state["UseDotInstruction"]: @@ -2727,9 +2871,9 @@ def calSwizzlePackK(state, tc): reject(state, printRejectionReason, "VWB * MacDataTypeB.numBytes() > 16") # reject - GRVW too big - if (state["GlobalReadVectorWidthA"] * state["ProblemType"]["DataTypeA"].numBytes()) > 24 and not state["UseF32XEmulation"]: + if (state["GlobalReadVectorWidthA"] * state["ProblemType"]["DataTypeA"].numBytes()) > 24: reject(state, printRejectionReason, "GRVWA * DataTypeA.numBytes() > 24") - if (state["GlobalReadVectorWidthB"] * state["ProblemType"]["DataTypeB"].numBytes()) > 24 and not state["UseF32XEmulation"]: + if (state["GlobalReadVectorWidthB"] * state["ProblemType"]["DataTypeB"].numBytes()) > 24: reject(state, printRejectionReason, "GRVWB * DataTypeB.numBytes() > 24") disableGNLC = False # Set to true to disable GNLC if needed @@ -2749,6 +2893,9 @@ def calSwizzlePackK(state, tc): sameLayout = state["ProblemType"]["TLU%s"%tc] != state["UnrollMajorLDS%s"%tc] state["UseGeneralizedNLCOne%s"%tc] = grwidth != 8 and sameLayout \ and state["WaveSeparateGlobalRead%s"%tc] == 0 and not state["DirectToVgpr%s"%tc] + # workaround: disable UseGeneralizedNLCOne for MXBlock (TODO: enable it for MXBlock) + if state["ProblemType"]["MXBlock%s"%tc]: + state["UseGeneralizedNLCOne%s"%tc] = False else: state["UseGeneralizedNLCOne%s"%tc] = False state["UseGeneralizedNLCOneMetadata"] = False @@ -2851,6 +2998,8 @@ def calSwizzlePackK(state, tc): reject(state, printRejectionReason, f"GRVWB({state['GlobalReadVectorWidthB']}) > Coaleased({totalElementsCoalescedB})") totalElementsB = totalElementsCoalescedB * totalElementsPerpB + + # handle global read vector width M if state["ProblemType"]["Sparse"] and not state["DirectToVgprSparseMetadata"]: if state["ProblemType"]["TLUMetadata"]: totalElementsCoalescedM = state["MacroTileMetadata"] @@ -3167,6 +3316,11 @@ def calSwizzlePackK(state, tc): totalVectorsCoalescedMXSB, totalElementsPerpMXSB, state["_DepthUMXSB"], printRejectionReason): return + if state["ProblemType"]["MXBlockB"]: + if not Solution.setGlobalLoadTileDimClassic(state, "MXSB", state["NumLoadsMXSB"], \ + totalVectorsCoalescedMXSB, totalElementsPerpMXSB, state["_DepthUMXSB"], printRejectionReason): + return + if state["ProblemType"]["Sparse"] and not state["DirectToVgprSparseMetadata"]: if state["ProblemType"]["TLUMetadata"]: totalElementsCoalescedM = state["MacroTileMetadata"] @@ -3237,6 +3391,10 @@ def calSwizzlePackK(state, tc): state["LVCB"] = roundupRatio(state["LSCB"] , state["GlobalReadVectorWidthB"]) state["LVPB"] = roundupRatio(state["LSPB"] , state["GlobalReadVectorWidthB"]) + if state["ProblemType"]["MXBlockB"]: + state["LVCMXSB"] = roundupRatio(state["LSCMXSB"] , state["GlobalReadVectorWidthMXSB"]) + state["LVPMXSB"] = roundupRatio(state["LSPMXSB"] , state["GlobalReadVectorWidthMXSB"]) + # KRingShift wrap handling exists only in the tail loop. # If (k + KRingShift) would wrap inside the main loop, the kernel will be incorrect (no main-loop wrap fix). # Enforce a host-side runtime predicate which guarantees any KRS wrap happens only in tail. @@ -3261,10 +3419,6 @@ def calSwizzlePackK(state, tc): f"bpeB={bpeB})") return - if state["ProblemType"]["MXBlockB"]: - state["LVCMXSB"] = roundupRatio(state["LSCMXSB"] , state["GlobalReadVectorWidthMXSB"]) - state["LVPMXSB"] = roundupRatio(state["LSPMXSB"] , state["GlobalReadVectorWidthMXSB"]) - if state["ProblemType"]["Sparse"] and not state["DirectToVgprSparseMetadata"]: state["LVCMetadata"] = roundupRatio(state["LSCMetadata"] , state["GlobalReadVectorWidthMetadata"]) state["LVPMetadata"] = roundupRatio(state["LSPMetadata"] , state["GlobalReadVectorWidthMetadata"]) @@ -3300,22 +3454,30 @@ def calSwizzlePackK(state, tc): # - NoTailLoop # - DepthU is not power of 2 # - LocalSplitU > 1 + # - MX case if ((not state["EnableMatrixInstruction"]) or isaInfoMap[isa].asmCaps["HasWMMA"]) or \ (state["PrefetchGlobalRead"] == 0) or \ state["NoTailLoop"] or \ (state["DepthU"] <=1 or (state["DepthU"] & (state["DepthU"] - 1) != 0)) or \ - state["LocalSplitU"] > 1: + state["LocalSplitU"] > 1 or \ + state["ProblemType"]["MXBlockA"] or state["ProblemType"]["MXBlockB"]: state["TailloopInNll"] = False # need restrictions for TailloopInNll if state["TailloopInNll"]: + # need to disable SuppressNoLoadLoop + state["SuppressNoLoadLoop"] = False + + # disable StaggerUcode + # - TailloopInNll + # - MX + StreamK (not enough sgpr) + if state["TailloopInNll"] or \ + (state["StreamK"] and (state["ProblemType"]["MXBlockA"] or state["ProblemType"]["MXBlockB"])): # need to disable StaggerU state["StaggerU"] = 0 state["StaggerUMapping"] = 0 state["StaggerUStride"] = 0 - # need to disable SuppressNoLoadLoop - state["SuppressNoLoadLoop"] = False - state["InternalSupportParams"]["SupportCustomStaggerU"] = False # Disable CustomStaggerU for TailloopInNll + state["InternalSupportParams"]["SupportCustomStaggerU"] = False # Disable CustomStaggerU for no StagggerU code # Determine if we can load directly-to-Vgpr # need to check after state["LocalReadVectorWidth"] = -1 is resolved @@ -3354,14 +3516,22 @@ def calSwizzlePackK(state, tc): # LDS (load size coalesced) * LSPA must load some multiple of 256 bytes. # No longer support loadX2/loadx4 . for tc in ['A', 'B']: + tcmx = "MXS%s"%tc if state["DirectToLds%s"%tc]: isDtlDoable = Solution.isDirectToLdsDoable(state, tc, isaInfoMap, printRejectionReason) if (not state["DirectToVgpr%s"%tc]) and isDtlDoable: state["DirectToLds%s"%tc] = True state["LocalWriteUseSgpr%s"%tc] = True + # MX case + if state["ProblemType"]["MXBlock%s"%tc]: + isDtlMxDoable = Solution.isDirectToLdsDoable(state, tcmx, isaInfoMap, printRejectionReason) + state["DirectToLds%s"%tcmx] = isDtlMxDoable else: state["DirectToLds%s"%tc] = False state["LocalWriteUseSgpr%s"%tc] = False + # MX case + if state["ProblemType"]["MXBlock%s"%tc]: + state["DirectToLds%s"%tcmx] = False if not isDtlDoable: if state["UseGeneralizedNLCOne%s"%tc]: reject(state, printRejectionReason, "DirectToLds%s not doable, but GNLC%s enabled, rejecting"%(tc, tc)) @@ -3378,6 +3548,12 @@ def calSwizzlePackK(state, tc): if state["1LDSBuffer"] == -1 and state["DirectToLds"]: #1LDS buffer must be 0 for DirectToLdsA state["1LDSBuffer"] = 0 + # MX case + if state["ProblemType"]["MXBlockA"] or state["ProblemType"]["MXBlockB"]: + if state["DirectToLdsA"] != state["DirectToLdsMXSA"] or state["DirectToLdsB"] != state["DirectToLdsMXSB"]: + reject(state, printRejectionReason, "DirectToLdsA/B and DirectToLdsMXSA/B should match") + if state["DirectToLdsA"] != state["DirectToLdsB"]: + reject(state, printRejectionReason, "DirectToLdsA and DirectToLdsB should match") # does not work with UnrollLoopSwapGlobalReadOrder if (state["DirectToLds"] == 2 or state["DirectToLds"] == 3) and state["UnrollLoopSwapGlobalReadOrder"]: @@ -3532,7 +3708,7 @@ def findValidWriteBlockWidth(nwcv, bpe, bpr): def subCheckLdsBlockSizePerPad(tc, idx): lbspp = state["LdsBlockSizePerPad%s"%tc] - bpe = state["ProblemType"]["MacDataTypeB"].numBytes() + bpe = state["ProblemType"]["MacDataType%s"%tc].numBytes() bpr = 4 vw = state["GlobalReadVectorWidth%s"%tc] tlu = state["ProblemType"]["TLU%s"%tc] @@ -3598,9 +3774,20 @@ def subCheckLdsBlockSizePerPad(tc, idx): state["NoLdsWriteCode"] = False if (state["DirectToVgprA"] or state["DirectToLdsA"]) and (state["DirectToVgprB"] or state["DirectToLdsB"]): state["NoLdsWriteCode"] = True + # MX case + if state["ProblemType"]["MXBlockA"]: + if not (state["DirectToVgprMXSA"] or state["DirectToLdsMXSA"]): + state["NoLdsWriteCode"] = False + if state["ProblemType"]["MXBlockB"]: + if not (state["DirectToVgprMXSB"] or state["DirectToLdsMXSB"]): + state["NoLdsWriteCode"] = False elif state["enableTDMA"] and state["enableTDMB"]: state["NoLdsWriteCode"] = True + # Use64bShadowLimitMX for MX case only + if not (state["ProblemType"]["MXBlockA"] or state["ProblemType"]["MXBlockB"]): + state["Use64bShadowLimitMX"] = False + # enable scheduling GR (in LWcode for PGR2) over barrier sync if not state["EnableMatrixInstruction"]: # not valid for non MFMA case @@ -3643,14 +3830,15 @@ def subCheckLdsBlockSizePerPad(tc, idx): # number of minimum GR inc inst per MFMA # default 1 # Set at least 2 for gfx950 + MI16 + smaller MT case + # Set 3 for MX (number of GRInc is doubled) if state["MinGRIncPerMfma"] == -1: state["MinGRIncPerMfma"] = 1 if isa == (9, 5, 0): - if state["EnableMatrixInstruction"] and state["MatrixInstM"] == 16 and state["MatrixInstK"] == 1: - if numMFMA<=32: - state["MinGRIncPerMfma"] = 2 - if numMFMA<=16: + if state["EnableMatrixInstruction"] and state["MatrixInstM"] == 16 and state["MatrixInstB"] == 1: + if numMFMA<=16 or (state["ProblemType"]["MXBlockA"] or state["ProblemType"]["MXBlockB"]): state["MinGRIncPerMfma"] = 3 + elif numMFMA<=32: + state["MinGRIncPerMfma"] = 2 # calculate ldsPad state["LdsPadA"], state["LdsPadB"], state["LdsPadMetadata"], state["LdsPadMXSA"], state["LdsPadMXSB"] = calcLdsPad(isaInfoMap) @@ -3670,6 +3858,14 @@ def subCheckLdsBlockSizePerPad(tc, idx): if state["LdsPad%s"%tc] == 0: state["LdsBlockSizePerPad%s"%tc] = 0 + # Normalize lds block-size-per-pad fields to native Python int. + assert(int(state["LdsBlockSizePerPadA"]) == state["LdsBlockSizePerPadA"]) + assert(int(state["LdsBlockSizePerPadB"]) == state["LdsBlockSizePerPadB"]) + assert(int(state["LdsBlockSizePerPadMetadata"]) == state["LdsBlockSizePerPadMetadata"]) + state["LdsBlockSizePerPadA"] = int(state["LdsBlockSizePerPadA"]) + state["LdsBlockSizePerPadB"] = int(state["LdsBlockSizePerPadB"]) + state["LdsBlockSizePerPadMetadata"] = int(state["LdsBlockSizePerPadMetadata"]) + if (state["UnrollMajorLDSA"] or state["UnrollMajorLDSB"]) and (not state["EnableMatrixInstruction"]) and (not state["UseDotInstruction"]): reject(state, printRejectionReason, "UnrollMajorLDS Supports only in EnableMatrixInstruction=1 or dot2 kernel") @@ -3687,11 +3883,11 @@ def subCheckLdsBlockSizePerPad(tc, idx): # todo, can the alignment be a power of 2? state["LdsOffsetA"] = 0 - state["LdsNumElementsAlignedA"] = ldsNumBytesAlignedA - state["LdsNumElementsAlignedMXSA"] = ldsNumBytesAlignedMXSA - state["LdsNumElementsAlignedB"] = ldsNumBytesAlignedB - state["LdsNumElementsAlignedMXSB"] = ldsNumBytesAlignedMXSB - state["LdsNumElementsAlignedMetadata"] = ldsNumBytesAlignedMetadata + state["LdsNumElementsAlignedA"] = int(ldsNumBytesAlignedA) + state["LdsNumElementsAlignedMXSA"] = int(ldsNumBytesAlignedMXSA) + state["LdsNumElementsAlignedB"] = int(ldsNumBytesAlignedB) + state["LdsNumElementsAlignedMXSB"] = int(ldsNumBytesAlignedMXSB) + state["LdsNumElementsAlignedMetadata"] = int(ldsNumBytesAlignedMetadata) # check for auto DtlPlusLdsBuf if state["DtlPlusLdsBuf"] == -1: if state["PrefetchGlobalRead"] > 2: @@ -4259,7 +4455,6 @@ def calcEpilogueTurns(factorDims: List) -> int: not state["InnerUnroll"] >= state["LocalReadVectorWidth"] // state["MIInputPerThread"]: reject(state, printRejectionReason, "wider localRead only support ClusterLocalRead or (InnerUnroll > WiderLocalReadxN)") - if state["GlobalReadPerMfma"] > 1 and state["PrefetchGlobalRead"] >= 2: reject(state, printRejectionReason, "GlobalReadPerMfma need to be 1 if PGR>=2") @@ -4268,19 +4463,19 @@ def calcEpilogueTurns(factorDims: List) -> int: state["ULSGRODoubleG2L"] = 0 if state["UnrollLoopSwapGlobalReadOrder"] == 1: - bpeAB = state["ProblemType"]["DataType"].numBytes() + bpeA = state["ProblemType"]["DataType"].numBytes() bpr = 4 numVgprG2LA = roundUp((state["NumLoadsCoalescedA"] * state["NumLoadsPerpendicularA"] * \ - state["GlobalReadVectorWidthA"] * bpeAB) / (float)(bpr)) + state["GlobalReadVectorWidthA"] * bpeA) / (float)(bpr)) numVgprG2LB = roundUp((state["NumLoadsCoalescedB"] * state["NumLoadsPerpendicularB"] * \ - state["GlobalReadVectorWidthB"] * bpeAB) / (float)(bpr)) + state["GlobalReadVectorWidthB"] * bpeA) / (float)(bpr)) if numVgprG2LA % 2 == 1 or numVgprG2LB % 2 == 1: reject(state, printRejectionReason, "G2LA/B vgpr has bubble inside. Cannot use UnrollLoopSwapGlobalReadOrder=1.") if state["GlobalReadVectorWidthA"] != state["GlobalReadVectorWidthB"]: # TODO: Add a configuration to schedule better. state["ULSGRODoubleG2L"] = 1 minGRVW = min(state["GlobalReadVectorWidthA"], state["GlobalReadVectorWidthB"]) - if minGRVW * bpeAB < 4: + if minGRVW * bpeA < 4: # G2LA/B vgpr index will jump. state["ULSGRODoubleG2L"] = 1 if state["ExpandPointerSwap"] == 1: @@ -4289,6 +4484,8 @@ def calcEpilogueTurns(factorDims: List) -> int: reject(state, printRejectionReason, "PrefetchGlobalRead need to be >=2 if UnrollLoopSwapGlobalReadOrder") if state["ProblemType"]["DataTypeA"].numBytes() != state["ProblemType"]["DataTypeB"].numBytes(): reject(state, printRejectionReason, "UnrollLoopSwapGlobalReadOrder doesn't support mixed precision.") + if state["ProblemType"]["MXBlockA"] and (not state["DirectToLdsMXSA"]) or state["ProblemType"]["MXBlockB"] and (not state["DirectToLdsMXSB"]): + reject(state, printRejectionReason, "UnrollLoopSwapGlobalReadOrder doesn't support MX + non DTL") if state["ExpandPointerSwap"] == 1 and state["LDSTrInst"]: reject(state, printRejectionReason, "LDSTrInst + ExpandPointerSwap not supported") diff --git a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Utilities.py b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Utilities.py index 4cda6a12a46..12d824e3f55 100644 --- a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Utilities.py +++ b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Utilities.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -25,13 +25,16 @@ import sys import math +from Tensile.Common.DataType import DataType +from rocisa.enum import DataTypeEnum + def getMiInputType(kernel: dict): """Select the effective MI operand type for MFMA latency lookup. Handles three cases based on kernel flags: 1. EnableF32XdlMathOp + UseF32XEmulation → BFloat16 (BF16 emulation) 2. EnableF32XdlMathOp only → F32XdlMathOp (native XF32) - 3. Neither → DataType (plain) + 3. Neither → MacDataTypeA (plain) UseF32XEmulation implies EnableF32XdlMathOp (guaranteed by Solution.__init__). @@ -40,8 +43,6 @@ def getMiInputType(kernel: dict): """ if kernel["EnableF32XdlMathOp"]: if kernel["UseF32XEmulation"]: - from rocisa.enum import DataTypeEnum - from Tensile.Common.DataType import DataType return DataType(DataTypeEnum.BFloat16) return kernel["ProblemType"]["F32XdlMathOp"] return kernel["ProblemType"]["DataType"] @@ -86,3 +87,27 @@ def pvar(state, field): def roundupRatio(dividend, divisor): return int(math.ceil(float(dividend) / float(divisor))) + +def getRealDataTypeA(dataType): + if dataType.value == DataTypeEnum.Float8BFloat8.value: + return DataType(DataTypeEnum.Float8) + elif dataType.value == DataTypeEnum.BFloat8Float8.value: + return DataType(DataTypeEnum.BFloat8) + elif dataType.value == DataTypeEnum.Float8BFloat8_fnuz.value: + return DataType(DataTypeEnum.Float8_fnuz) + elif dataType.value == DataTypeEnum.BFloat8Float8_fnuz.value: + return DataType(DataTypeEnum.BFloat8_fnuz) + else: + return dataType + +def getRealDataTypeB(dataType): + if dataType.value == DataTypeEnum.Float8BFloat8.value: + return DataType(DataTypeEnum.BFloat8) + elif dataType.value == DataTypeEnum.BFloat8Float8.value: + return DataType(DataTypeEnum.Float8) + elif dataType.value == DataTypeEnum.Float8BFloat8_fnuz.value: + return DataType(DataTypeEnum.BFloat8_fnuz) + elif dataType.value == DataTypeEnum.BFloat8Float8_fnuz.value: + return DataType(DataTypeEnum.Float8_fnuz) + else: + return dataType diff --git a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Validators/MatrixInstruction.py b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Validators/MatrixInstruction.py index cc0fcdf76fa..adc52f21f1c 100644 --- a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Validators/MatrixInstruction.py +++ b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Validators/MatrixInstruction.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -62,6 +62,8 @@ def matrixInstructionToMIParameters( result = {} result["ISA"] = isa + isgfx950 = isa[:2] == (9, 5) + # Enable F32 XDL math operation only when the input type is f32. enableF32xdl = ( "F32XdlMathOp" in problemType @@ -136,11 +138,13 @@ def matrixInstructionToMIParameters( sparseB = False if not isSparse else True if isSparse == 2 else False result['MIInputPerThreadA'] = result['MIInputPerThreadA'] if not sparseA else result['MIInputPerThreadA'] // 2 if ("MXBlockA" in problemType) and problemType["MXBlockA"]: - duplicateFactor = 32 // result["MatrixInstM"] + # work around for gfx950. Use duplicateFactor = 1 + duplicateFactor = 32 // result["MatrixInstM"] if not isgfx950 else 1 result['MIInputPerThreadMXSA'] = result['MIInputPerThreadA'] // problemType["MXBlockA"] * duplicateFactor result['MIInputPerThreadB'] = result['MIInputPerThreadB'] if not sparseB else result['MIInputPerThreadB'] // 2 if ("MXBlockB" in problemType) and problemType["MXBlockB"]: - duplicateFactor = 32 // result["MatrixInstN"] + # work around for gfx950. Use duplicateFactor = 1 + duplicateFactor = 32 // result["MatrixInstN"] if not isgfx950 else 1 result['MIInputPerThreadMXSB'] = result['MIInputPerThreadB'] // problemType["MXBlockB"] * duplicateFactor result['MIInputPerThreadMetadata'] = result['MIInputPerThread'] if not isSparse else result['MIInputPerThread'] // 8 @@ -194,12 +198,23 @@ def validateMIParameters( ptype = solution["ProblemType"] isSparse = ptype.get("Sparse", 0) + isX = solution.get("EnableF32XdlMathOp", False) miDataType = DataType( ptype["DataType"] - if not solution.get("EnableF32XdlMathOp", False) + if not isX else ptype["F32XdlMathOp"] ) + # For MFMA validation, determine the key based on MAC data types + macDataTypeA = ptype.get("MacDataTypeA", miDataType) if not isX else miDataType + macDataTypeB = ptype.get("MacDataTypeB", miDataType) if not isX else miDataType + + # If both MAC types are the same, use doubled character key; otherwise use combined key + if macDataTypeA == macDataTypeB: + miDataTypeKey = macDataTypeA.toChar() + macDataTypeA.toChar() + else: + miDataTypeKey = macDataTypeA.toChar() + macDataTypeB.toChar() + mi4 = solution[MI_KEY] miEnabled = solution[MI_ENABLED_KEY] assert len(mi4) == 4 or len(mi4) == 0, elineno() + ": MI length not 4 or 0" @@ -244,7 +259,7 @@ def validateMIParameters( if not isSparse: if hasMFMA: - if not (miDataType.toChar() in validMFMA and mi4 in validMFMA[miDataType.toChar()]): + if not (miDataTypeKey in validMFMA and mi4 in validMFMA[miDataTypeKey]): if miDataType.isBFloat16() and mi4 in validMFMA["B1k"]: # but is valid bf16 MFMA assert solution["MFMA_BF16_1K"], elineno() else: @@ -258,7 +273,7 @@ def validateMIParameters( solution, printSolutionRejectionReason, f"Invalid WMMA configuration: {solution}" ) else: - if hasSMFMA and not (miDataType.toChar() in validSMFMA and mi4 in validSMFMA[miDataType.toChar()]): + if hasSMFMA and not (miDataTypeKey in validSMFMA and mi4 in validSMFMA[miDataTypeKey]): return not reject( solution, printSolutionRejectionReason, f"Invalid SMFMA configuration: {solution}" ) diff --git a/projects/hipblaslt/tensilelite/Tensile/Tensile.py b/projects/hipblaslt/tensilelite/Tensile/Tensile.py index f68ab260864..000b1d47d66 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Tensile.py +++ b/projects/hipblaslt/tensilelite/Tensile/Tensile.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -218,10 +218,11 @@ def splitExtraParameters(par): argParser.add_argument("--library-format", dest="LibraryFormat", choices=["yaml", "msgpack"], \ action="store", default="yaml", help="select which library format to use") argParser.add_argument("--client-lock", default=None) - argParser.add_argument("--prebuilt-client", default=str(TENSILE_CLIENT_PATH), + argParser.add_argument("--prebuilt-client", default=str(TENSILE_CLIENT_PATH), \ type=os.path.abspath, help="Specify the full path to a pre-built tensilelite-client executable") + argParser.add_argument("--mx-scale-format", dest="MXScaleFormat", type=int, default=0, \ + help="MX scale data format (0=none, 1=pre-swizzle for GPU kernel layout)") argParser.add_argument("--rocm-agent-enumerator", default=None, action="store", dest="rocm_agent_enumerator") - argParser.add_argument("--global-parameters", nargs="+", type=splitExtraParameters, default=[]) @@ -247,6 +248,9 @@ def argUpdatedGlobalParameters(args): rv["ClientExecutionLockPath"] = args.client_lock if args.prebuilt_client: rv["PrebuiltClient"] = args.prebuilt_client + if args.MXScaleFormat: + print1("# Command-line override: MXScaleFormat") + rv["MXScaleFormat"] = args.MXScaleFormat for key, value in args.global_parameters: rv[key] = value diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/fp8_mxfp4_bf16_tn_act.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/fp8_mxfp4_bf16_tn_act.yaml new file mode 100644 index 00000000000..211830a63f4 --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/fp8_mxfp4_bf16_tn_act.yaml @@ -0,0 +1,128 @@ +TestParameters: + marks: [xfail-gfx950, skip-gfx900, skip-gfx906, skip-gfx908, skip-gfx90a, skip-gfx942, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102, skip-gfx1200, skip-gfx1201, skip-gfx1250, skip-gfx940, skip-gfx941] # Only for gfx950 + +GlobalParameters: + NumElementsToValidate: -1 + MinimumRequiredVersion: 5.0.0 + PrintLevel: 1 + PrintSolutionRejectionReason: True + Device: 0 + CMakeBuildType: Release + KernelTime: True + MaxWorkspaceSize: 13421772800 + DataInitTypeA: 21 + DataInitTypeB: 21 + DataInitTypeC: 21 + DataInitTypeAlpha: 1 + DataInitTypeBeta: 1 + BoundsCheck: 2 + +BenchmarkProblems: + ######################################## + # FP8 A + MXFP4 B -> BF16 C/D (TN (TransposeA=True, TransposeB=False)) + # Mixed FP8+MXFP4, BF16 output, All tile sizes (64x64, 128x128, 256x256), with Activation + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: F4 + DataTypeA: F8 + DestDataType: B + ComputeDataType: S + HighPrecisionAccumulate: True + MXBlockA: 32 + MXBlockB: 32 + TransposeA: True # TN configuration + TransposeB: False # TN configuration + UseBeta: True + Batched: True + Activation: True + ActivationType: hipblaslt_all + UseBias: 1 + BiasDataTypeList: [s] + - # BenchmarkProblemSizeGroup + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 128, 1, 1, 1, 1, 1, 1] + - [16, 16, 128, 1, 1, 4, 2, 2, 2] + - [16, 16, 128, 1, 1, 2, 4, 2, 2] + - [16, 16, 128, 1, 1, 8, 8, 2, 2] + - [32, 32, 64, 1, 1, 4, 4, 2, 2] + - [32, 32, 64, 1, 1, 8, 8, 2, 2] + - DepthU: [32, 64, 128] + - AssertFree0ElementMultiple: [1] + - AssertFree1ElementMultiple: [1] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - DirectToLds: [0] + - GlobalReadVectorWidthA: [16] + - GlobalReadVectorWidthB: [32] + - LocalReadVectorWidth: [32] + - VectorWidthA: [1] + - VectorWidthB: [1] + - ClusterLocalRead: [1] + - 1LDSBuffer: [0] + - GlobalSplitU: [1] + - GlobalSplitUAlgorithm: ["MultipleBuffer"] + - GlobalReadPerMfma: [1] + - LocalWritePerMfma: [-1] + - StoreVectorWidth: [4] + - InnerUnroll: [1] + - ScheduleIterAlg: [3] + - LdsPadA: [4] + - LdsPadB: [4] + - WorkGroupMapping: [64] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + ######################################## + # 1. Small, Power-of-2 - to test optimized code path + ######################################## + - Exact: [256, 256, 1, 256] + + ######################################## + # 2. Odd M - to test edge M + ######################################## + - Exact: [63, 64, 1, 64] + - Exact: [255, 256, 1, 256] + + ######################################## + # 3. Odd N - to test edge N + ######################################## + - Exact: [64, 63, 1, 64] + - Exact: [256, 255, 1, 256] + + ######################################## + # 4. Odd M and N - to test both edge dimensions + ######################################## + - Exact: [63, 63, 1, 64] + - Exact: [255, 255, 1, 256] + + ######################################## + # 5. Odd K - to test tail loop + ######################################## + - Exact: [64, 64, 1, 63] + - Exact: [256, 256, 1, 255] + + ######################################## + # 6. Small size with batch > 1 + ######################################## + - Exact: [32, 32, 8, 32] + + ######################################## + # 7. Medium size that doesn't divide evenly on CUs (stream-k) + ######################################## + - Exact: [1000, 1000, 1, 256] + - Exact: [1500, 1500, 1, 512] + - Exact: [1024, 1024, 1, 333] + - BiasTypeArgs: ['s'] + - ActivationArgs: + - [Enum: none] + - [Enum: gelu] + - [Enum: relu] + - [Enum: sigmoid] + - [Enum: Silu] + - [Enum: Clamp] diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/fp8_mxfp4_bf16_tn_act_groupgemm.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/fp8_mxfp4_bf16_tn_act_groupgemm.yaml new file mode 100644 index 00000000000..de344424ac4 --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/fp8_mxfp4_bf16_tn_act_groupgemm.yaml @@ -0,0 +1,129 @@ +TestParameters: + marks: [xfail-gfx950, xfail-gfx1250, skip-gfx900, skip-gfx906, skip-gfx908, skip-gfx90a, skip-gfx942, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102, skip-gfx1200, skip-gfx1201, skip-gfx940, skip-gfx941] # Only for gfx950 + +GlobalParameters: + NumElementsToValidate: -1 + MinimumRequiredVersion: 5.0.0 + PrintLevel: 1 + PrintSolutionRejectionReason: True + Device: 0 + CMakeBuildType: Release + KernelTime: True + MaxWorkspaceSize: 13421772800 + DataInitTypeA: 21 + DataInitTypeB: 21 + DataInitTypeC: 21 + DataInitTypeAlpha: 1 + DataInitTypeBeta: 1 + BoundsCheck: 2 + +BenchmarkProblems: + ######################################## + # FP8 A + MXFP4 B -> BF16 C/D (TN (TransposeA=True, TransposeB=False)) + # Mixed FP8+MXFP4, BF16 output, All tile sizes (64x64, 128x128, 256x256), with Activation, GroupedGemm + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: F4 + DataTypeA: F8 + DestDataType: B + ComputeDataType: S + HighPrecisionAccumulate: True + MXBlockA: 32 + MXBlockB: 32 + TransposeA: True # TN configuration + TransposeB: False # TN configuration + UseBeta: True + Batched: True + Activation: True + ActivationType: hipblaslt_all + UseBias: 1 + BiasDataTypeList: [s] + GroupedGemm: True + - # BenchmarkProblemSizeGroup + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 128, 1, 1, 1, 1, 1, 1] + - [16, 16, 128, 1, 1, 4, 2, 2, 2] + - [16, 16, 128, 1, 1, 2, 4, 2, 2] + - [16, 16, 128, 1, 1, 8, 8, 2, 2] + - [32, 32, 64, 1, 1, 4, 4, 2, 2] + - [32, 32, 64, 1, 1, 8, 8, 2, 2] + - DepthU: [32, 64, 128] + - AssertFree0ElementMultiple: [1] + - AssertFree1ElementMultiple: [1] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - DirectToLds: [0] + - GlobalReadVectorWidthA: [16] + - GlobalReadVectorWidthB: [32] + - LocalReadVectorWidth: [32] + - VectorWidthA: [1] + - VectorWidthB: [1] + - ClusterLocalRead: [1] + - 1LDSBuffer: [0] + - GlobalSplitU: [1] + - GlobalSplitUAlgorithm: ["MultipleBuffer"] + - GlobalReadPerMfma: [1] + - LocalWritePerMfma: [-1] + - StoreVectorWidth: [4] + - InnerUnroll: [1] + - ScheduleIterAlg: [3] + - LdsPadA: [4] + - LdsPadB: [4] + - WorkGroupMapping: [64] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + ######################################## + # 1. Small, Power-of-2 - to test optimized code path + ######################################## + - Exact: [256, 256, 1, 256] + + ######################################## + # 2. Odd M - to test edge M + ######################################## + - Exact: [63, 64, 1, 64] + - Exact: [255, 256, 1, 256] + + ######################################## + # 3. Odd N - to test edge N + ######################################## + - Exact: [64, 63, 1, 64] + - Exact: [256, 255, 1, 256] + + ######################################## + # 4. Odd M and N - to test both edge dimensions + ######################################## + - Exact: [63, 63, 1, 64] + - Exact: [255, 255, 1, 256] + + ######################################## + # 5. Odd K - to test tail loop + ######################################## + - Exact: [64, 64, 1, 63] + - Exact: [256, 256, 1, 255] + + ######################################## + # 6. Small size with batch > 1 + ######################################## + - Exact: [32, 32, 8, 32] + + ######################################## + # 7. Medium size that doesn't divide evenly on CUs (stream-k) + ######################################## + - Exact: [1000, 1000, 1, 256] + - Exact: [1500, 1500, 1, 512] + - Exact: [1024, 1024, 1, 333] + - BiasTypeArgs: ['s'] + - ActivationArgs: + - [Enum: none] + - [Enum: gelu] + - [Enum: relu] + - [Enum: sigmoid] + - [Enum: Silu] + - [Enum: Clamp] diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/fp8_mxfp4_fp32_tn_act.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/fp8_mxfp4_fp32_tn_act.yaml new file mode 100644 index 00000000000..048e28e2cc5 --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/fp8_mxfp4_fp32_tn_act.yaml @@ -0,0 +1,128 @@ +TestParameters: + marks: [xfail-gfx950, skip-gfx900, skip-gfx906, skip-gfx908, skip-gfx90a, skip-gfx942, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102, skip-gfx1200, skip-gfx1201, skip-gfx1250, skip-gfx940, skip-gfx941] # Only for gfx950 + +GlobalParameters: + NumElementsToValidate: -1 + MinimumRequiredVersion: 5.0.0 + PrintLevel: 1 + PrintSolutionRejectionReason: True + Device: 0 + CMakeBuildType: Release + KernelTime: True + MaxWorkspaceSize: 13421772800 + DataInitTypeA: 21 + DataInitTypeB: 21 + DataInitTypeC: 21 + DataInitTypeAlpha: 1 + DataInitTypeBeta: 1 + BoundsCheck: 2 + +BenchmarkProblems: + ######################################## + # FP8 A + MXFP4 B -> FP32 C/D (TN (TransposeA=True, TransposeB=False)) + # Mixed FP8+MXFP4, FP32 output, All tile sizes (64x64, 128x128, 256x256), with Activation + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: F4 + DataTypeA: F8 + DestDataType: s + ComputeDataType: S + HighPrecisionAccumulate: True + MXBlockA: 32 + MXBlockB: 32 + TransposeA: True # TN configuration + TransposeB: False # TN configuration + UseBeta: True + Batched: True + Activation: True + ActivationType: hipblaslt_all + UseBias: 1 + BiasDataTypeList: [s] + - # BenchmarkProblemSizeGroup + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 128, 1, 1, 1, 1, 1, 1] + - [16, 16, 128, 1, 1, 4, 2, 2, 2] + - [16, 16, 128, 1, 1, 2, 4, 2, 2] + - [16, 16, 128, 1, 1, 8, 8, 2, 2] + - [32, 32, 64, 1, 1, 4, 4, 2, 2] + - [32, 32, 64, 1, 1, 8, 8, 2, 2] + - DepthU: [32, 64, 128] + - AssertFree0ElementMultiple: [1] + - AssertFree1ElementMultiple: [1] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - DirectToLds: [0] + - GlobalReadVectorWidthA: [16] + - GlobalReadVectorWidthB: [32] + - LocalReadVectorWidth: [32] + - VectorWidthA: [1] + - VectorWidthB: [1] + - ClusterLocalRead: [1] + - 1LDSBuffer: [0] + - GlobalSplitU: [1] + - GlobalSplitUAlgorithm: ["MultipleBuffer"] + - GlobalReadPerMfma: [1] + - LocalWritePerMfma: [-1] + - StoreVectorWidth: [4] + - InnerUnroll: [1] + - ScheduleIterAlg: [3] + - LdsPadA: [4] + - LdsPadB: [4] + - WorkGroupMapping: [64] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + ######################################## + # 1. Small, Power-of-2 - to test optimized code path + ######################################## + - Exact: [256, 256, 1, 256] + + ######################################## + # 2. Odd M - to test edge M + ######################################## + - Exact: [63, 64, 1, 64] + - Exact: [255, 256, 1, 256] + + ######################################## + # 3. Odd N - to test edge N + ######################################## + - Exact: [64, 63, 1, 64] + - Exact: [256, 255, 1, 256] + + ######################################## + # 4. Odd M and N - to test both edge dimensions + ######################################## + - Exact: [63, 63, 1, 64] + - Exact: [255, 255, 1, 256] + + ######################################## + # 5. Odd K - to test tail loop + ######################################## + - Exact: [64, 64, 1, 63] + - Exact: [256, 256, 1, 255] + + ######################################## + # 6. Small size with batch > 1 + ######################################## + - Exact: [32, 32, 8, 32] + + ######################################## + # 7. Medium size that doesn't divide evenly on CUs (stream-k) + ######################################## + - Exact: [1000, 1000, 1, 256] + - Exact: [1500, 1500, 1, 512] + - Exact: [1024, 1024, 1, 333] + - BiasTypeArgs: ['s'] + - ActivationArgs: + - [Enum: none] + - [Enum: gelu] + - [Enum: relu] + - [Enum: sigmoid] + - [Enum: Silu] + - [Enum: Clamp] diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/fp8_mxfp4_fp32_tn_act_groupgemm.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/fp8_mxfp4_fp32_tn_act_groupgemm.yaml new file mode 100644 index 00000000000..7a6ca30a1b4 --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/fp8_mxfp4_fp32_tn_act_groupgemm.yaml @@ -0,0 +1,129 @@ +TestParameters: + marks: [xfail-gfx950, skip-gfx900, skip-gfx906, skip-gfx908, skip-gfx90a, skip-gfx942, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102, skip-gfx1200, skip-gfx1201, skip-gfx1250, skip-gfx940, skip-gfx941] # Only for gfx950 + +GlobalParameters: + NumElementsToValidate: -1 + MinimumRequiredVersion: 5.0.0 + PrintLevel: 1 + PrintSolutionRejectionReason: True + Device: 0 + CMakeBuildType: Release + KernelTime: True + MaxWorkspaceSize: 13421772800 + DataInitTypeA: 21 + DataInitTypeB: 21 + DataInitTypeC: 21 + DataInitTypeAlpha: 1 + DataInitTypeBeta: 1 + BoundsCheck: 2 + +BenchmarkProblems: + ######################################## + # FP8 A + MXFP4 B -> FP32 C/D (TN (TransposeA=True, TransposeB=False)) + # Mixed FP8+MXFP4, FP32 output, All tile sizes (64x64, 128x128, 256x256), with Activation, GroupedGemm + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: F4 + DataTypeA: F8 + DestDataType: s + ComputeDataType: S + HighPrecisionAccumulate: True + MXBlockA: 32 + MXBlockB: 32 + TransposeA: True # TN configuration + TransposeB: False # TN configuration + UseBeta: True + Batched: True + Activation: True + ActivationType: hipblaslt_all + UseBias: 1 + BiasDataTypeList: [s] + GroupedGemm: True + - # BenchmarkProblemSizeGroup + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 128, 1, 1, 1, 1, 1, 1] + - [16, 16, 128, 1, 1, 4, 2, 2, 2] + - [16, 16, 128, 1, 1, 2, 4, 2, 2] + - [16, 16, 128, 1, 1, 8, 8, 2, 2] + - [32, 32, 64, 1, 1, 4, 4, 2, 2] + - [32, 32, 64, 1, 1, 8, 8, 2, 2] + - DepthU: [32, 64, 128] + - AssertFree0ElementMultiple: [1] + - AssertFree1ElementMultiple: [1] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - DirectToLds: [0] + - GlobalReadVectorWidthA: [16] + - GlobalReadVectorWidthB: [32] + - LocalReadVectorWidth: [32] + - VectorWidthA: [1] + - VectorWidthB: [1] + - ClusterLocalRead: [1] + - 1LDSBuffer: [0] + - GlobalSplitU: [1] + - GlobalSplitUAlgorithm: ["MultipleBuffer"] + - GlobalReadPerMfma: [1] + - LocalWritePerMfma: [-1] + - StoreVectorWidth: [4] + - InnerUnroll: [1] + - ScheduleIterAlg: [3] + - LdsPadA: [4] + - LdsPadB: [4] + - WorkGroupMapping: [64] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + ######################################## + # 1. Small, Power-of-2 - to test optimized code path + ######################################## + - Exact: [256, 256, 1, 256] + + ######################################## + # 2. Odd M - to test edge M + ######################################## + - Exact: [63, 64, 1, 64] + - Exact: [255, 256, 1, 256] + + ######################################## + # 3. Odd N - to test edge N + ######################################## + - Exact: [64, 63, 1, 64] + - Exact: [256, 255, 1, 256] + + ######################################## + # 4. Odd M and N - to test both edge dimensions + ######################################## + - Exact: [63, 63, 1, 64] + - Exact: [255, 255, 1, 256] + + ######################################## + # 5. Odd K - to test tail loop + ######################################## + - Exact: [64, 64, 1, 63] + - Exact: [256, 256, 1, 255] + + ######################################## + # 6. Small size with batch > 1 + ######################################## + - Exact: [32, 32, 8, 32] + + ######################################## + # 7. Medium size that doesn't divide evenly on CUs (stream-k) + ######################################## + - Exact: [1000, 1000, 1, 256] + - Exact: [1500, 1500, 1, 512] + - Exact: [1024, 1024, 1, 333] + - BiasTypeArgs: ['s'] + - ActivationArgs: + - [Enum: none] + - [Enum: gelu] + - [Enum: relu] + - [Enum: sigmoid] + - [Enum: Silu] + - [Enum: Clamp] diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mx32f4_tn.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mx32f4_tn.yaml new file mode 100644 index 00000000000..a0dc563ef3a --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mx32f4_tn.yaml @@ -0,0 +1,202 @@ +TestParameters: + marks: [skip-gfx1250, skip-gfx900, skip-gfx906, skip-gfx908, skip-gfx90a, skip-gfx940, skip-gfx941, skip-gfx942, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102, skip-gfx1200, skip-gfx1201] + +GlobalParameters: + NumElementsToValidate: -1 + MinimumRequiredVersion: 5.0.0 + PrintLevel: 1 + #PrintSolutionRejectionReason: True + Device: 0 + CMakeBuildType: Release +# MergeFiles: False + KernelTime: True + MaxWorkspaceSize: 13421772800 + DataInitTypeA: 3 + DataInitTypeB: 3 + DataInitTypeC: 0 + DataInitTypeMXSA: 3 + DataInitTypeMXSB: 3 + DataInitTypeAlpha: 1 + DataInitTypeBeta: 0 + NumElementsToValidate: -1 + BoundsCheck: 0 + MaxFileName: 128 + DeviceLDS: 163840 + MaxLDS: 163840 +# PrintTensorA: True +# PrintTensorB: True +# PrintTensorC: True +# PrintTensorD: True +# PrintTensorRef: True + +BenchmarkProblems: + ######################################## + # F4SS TN + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: F4 + DestDataType: S + ComputeDataType: S + HighPrecisionAccumulate: True + MXBlockA: 32 + MXBlockB: 32 + TransposeA: 1 + TransposeB: 0 + UseBeta: True + Batched: True + Activation: True + ActivationType: hipblaslt_all +# UseScaleAB: "Scalar" +# UseScaleCD: True + UseScaleAlphaVec: 1 + UseBias: 1 + BiasDataTypeList: [s] + + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 128, 1, 1, 1,1, 1,1] # 64x64 + - [16, 16, 128, 1, 1, 4,2, 2,2] # 64x64 + - [16, 16, 128, 1, 1, 2,4, 2,2] # 64x64 +# - [16, 16, 128, 1, 1, 8,8, 2,2] # 64x64 + #- [32, 32, 64, 1, 1, 1,1, 1,1] # 64x64 + #- [32, 32, 64, 1, 1, 2,2, 2,2] # 64x64 + #- [32, 32, 64, 1, 1, 4,2, 2,2] # 64x64 +# - [32, 32, 64, 1, 1, 2,4, 2,2] # 64x64 + #- ForceDisableShadowInit: [True, False] + - UseSgprForGRO: [0,1] + - DepthU: [128, 256] + - AssertFree0ElementMultiple: [1] + - AssertFree1ElementMultiple: [1] + - AssertSummationElementMultiple: [32] + - LocalReadVectorWidth: [32] + - PrefetchGlobalRead: [0,2] + - PrefetchLocalRead: [0,1] + #- PreloadKernArgs: [0,1] + - ClusterLocalRead: [0,1] + - VectorWidthA: [1] + - VectorWidthB: [1] + - GlobalReadVectorWidthA: [32] + - GlobalReadVectorWidthB: [32] + - ScheduleIterAlg: [3] + - InnerUnroll: [1] + - ExpandPointerSwap: [0,1] +# - LdsBlockSizePerPadA: [0] +# - LdsBlockSizePerPadB: [0] + - TransposeLDS: [1] +# - LdsPadA: [0] +# - LdsPadB: [0] + - WaveSeparateGlobalReadA: [0] + - WaveSeparateGlobalReadB: [0] + - 1LDSBuffer: [0] + - GlobalSplitU: [1,2] +# - GlobalSplitUAlgorithm: ["MultipleBuffer", "MultipleBufferSingleKernel"] + - GlobalReadPerMfma: [1] + - LocalWritePerMfma: [-1] + - StoreVectorWidth: [-1] + - SourceSwap: [1] +# - NumElementsPerBatchStore: [0] +# - StorePriorityOpt: [0] +# - StaggerU: [0] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + # - Exact: [16, 16, 1, 128] + - Exact: [32, 16, 1, 1024] + - Exact: [256, 256, 1, 128] + - Exact: [1025, 513, 1, 2048] +# - Exact: [1026, 514, 1, 2048] +# - Exact: [1027, 515, 1, 2048] +# - Exact: [1028, 516, 1, 2048] +# - Exact: [1029, 517, 1, 2048] +# - Exact: [1030, 518, 1, 2048] +# - Exact: [1031, 519, 1, 2048] +# - Exact: [1032, 520, 1, 2048] +# - Exact: [1033, 521, 1, 2048] +# - Exact: [1034, 522, 1, 2048] +# - Exact: [1035, 523, 1, 2048] +# - Exact: [1036, 524, 1, 2048] +# - Exact: [1037, 525, 1, 2048] +# - Exact: [1038, 526, 1, 2048] +# - Exact: [1039, 527, 1, 2048] +# - Exact: [1040, 528, 1, 2048] +# - Exact: [1041, 529, 1, 2048] +# - Exact: [1042, 530, 1, 2048] +# - Exact: [1043, 531, 1, 2048] +# - Exact: [1044, 532, 1, 2048] # dropped: [1025, 513, 1, 2048] above covers the same odd-edge case + - Exact: [127, 127, 1, 640] #special cleanup case + - Exact: [128, 128, 1, 128] + - Exact: [129, 129, 1, 640] + # - Exact: [128, 128, 1, 127] + # - Exact: [128, 128, 1, 129] + # - Exact: [111, 111, 1, 111] + # - Exact: [777, 777, 1, 777] + - Exact: [128, 128, 1, 512] + #- Range: [[128], [128], [1], [1024,32,1280]] + - Range: [[128], [128], [1], [1024,128,1280]] #make K value is multiple of DepthU + - Range: [[128,8,256], [128], [1], [1024]] + - Range: [[128], [128,8,256], [1], [1024]] + - BiasTypeArgs: ['s'] + - ActivationArgs: + - [Enum: none] + - # BenchmarkProblemSizeGroup - DTL + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 128, 1, 1, 8,8, 2,2] # 256x256 + - [16, 16, 128, 1, 1, 8,2, 2,2] # 256x64 + - [16, 16, 128, 1, 1, 2,8, 2,2] # 64x256 + - [32, 32, 64, 1, 1, 4,4, 2,2] # 256x256 + - [32, 32, 64, 1, 1, 2,2, 2,2] # 128x128 + - DepthU: [64,128,256] + - AssertSummationElementMultiple: [32] + - LocalReadVectorWidth: [32] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - ScheduleIterAlg: [3] + - 1LDSBuffer: [0] + - DirectToLds: [0,1,2,3] + - Use64bShadowLimitMX: [0,1] + - StreamK: [0,3] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Exact: [32, 16, 1, 1024] + - Exact: [256, 256, 1, 128] + - Exact: [1025, 513, 1, 2048] + - Exact: [127, 127, 1, 640] #special cleanup case + - Exact: [129, 129, 1, 640] + - Exact: [128, 128, 1, 512] + - Exact: [128, 128, 1, 32] + - Exact: [128, 128, 1, 64] + - Exact: [128, 128, 1, 96] + - Exact: [128, 128, 1, 160] + - Exact: [128, 128, 1, 192] + - Exact: [128, 128, 1, 224] + # mutiple of 16 (not working for now) + #- Exact: [128, 128, 1, 16] + #- Exact: [128, 128, 1, 48] + #- Exact: [128, 128, 1, 80] + #- Exact: [128, 128, 1, 112] + #- Exact: [128, 128, 1, 144] + #- Exact: [128, 128, 1, 176] + #- Exact: [128, 128, 1, 208] + #- Exact: [128, 128, 1, 240] + - BiasTypeArgs: ['s'] + - ActivationArgs: + - [Enum: none] + +# LibraryLogic: +# ScheduleName: "aquavanjaram" +# DeviceNames: ["Device 0049"] +# ArchitectureName: "gfx950" +# LibraryType: "GridBased" +# +# LibraryClient: diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mx32f8_tn.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mx32f8_tn.yaml new file mode 100644 index 00000000000..d246f36cdb9 --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mx32f8_tn.yaml @@ -0,0 +1,160 @@ +TestParameters: + marks: [skip-gfx1250, skip-gfx900, skip-gfx906, skip-gfx908, skip-gfx90a, skip-gfx940, skip-gfx941, skip-gfx942, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102, skip-gfx1200, skip-gfx1201] + +GlobalParameters: + NumElementsToValidate: -1 + MinimumRequiredVersion: 5.0.0 + PrintLevel: 1 + # PrintSolutionRejectionReason: True + Device: 0 + CMakeBuildType: Debug +# MergeFiles: False + KernelTime: True + MaxWorkspaceSize: 13421772800 + DataInitTypeA: 3 + DataInitTypeB: 3 + DataInitTypeC: 0 + DataInitTypeMXSA: 3 + DataInitTypeMXSB: 3 + DataInitTypeAlpha: 1 + DataInitTypeBeta: 0 + NumElementsToValidate: -1 + BoundsCheck: 0 + KeepBuildTmp: True + MaxFileName: 128 + DeviceLDS: 163840 + MaxLDS: 163840 +# PrintTensorA: True +# PrintTensorB: True +# PrintTensorC: True +# PrintTensorD: True +# PrintTensorRef: True + +BenchmarkProblems: + ######################################## + # FP8HS with ScaleABCD + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: F8 + DestDataType: S + ComputeDataType: S + HighPrecisionAccumulate: True + MXBlockA: 32 + MXBlockB: 32 + TransposeA: 1 + TransposeB: 0 + UseBeta: True + Batched: True + Activation: True + ActivationType: hipblaslt_all +# UseScaleAB: "Scalar" +# UseScaleCD: True + UseScaleAlphaVec: 1 + UseBias: 1 + BiasDataTypeList: [s] + + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 128, 1, 1, 1,1, 1,1] # 64x64 + - [16, 16, 128, 1, 1, 4,2, 2,2] # 64x64 + - [16, 16, 128, 1, 1, 2,4, 2,2] # 64x64 +# - [16, 16, 128, 1, 1, 8,8, 2,2] # 64x64 + - [32, 32, 64, 1, 1, 1,1, 1,1] # 64x64 + - [32, 32, 64, 1, 1, 2,2, 2,2] # 64x64 + - [32, 32, 64, 1, 1, 4,2, 2,2] # 64x64 +# - [32, 32, 64, 1, 1, 2,4, 2,2] # 64x64 + #- ForceDisableShadowInit: [True, False] + - UseSgprForGRO: [0,1] + - DepthU: [64, 128, 256] + - AssertFree0ElementMultiple: [1] + - AssertFree1ElementMultiple: [1] + - AssertSummationElementMultiple: [64] + - LocalReadVectorWidth: [16] + - PrefetchGlobalRead: [0,1,2] + - PrefetchLocalRead: [0,1] + #- PreloadKernArgs: [0,1] + - ClusterLocalRead: [0,1] + - VectorWidthA: [1] + - VectorWidthB: [1] + - GlobalReadVectorWidthA: [16] + - GlobalReadVectorWidthB: [16] + - ScheduleIterAlg: [1,2,3] + - InnerUnroll: [1] + - ExpandPointerSwap: [0,1] +# - LdsBlockSizePerPadA: [0] +# - LdsBlockSizePerPadB: [0] + - TransposeLDS: [1] +# - LdsPadA: [0] +# - LdsPadB: [0] + - WaveSeparateGlobalReadA: [0] + - WaveSeparateGlobalReadB: [0] + - 1LDSBuffer: [0] + - GlobalSplitU: [1,2] +# - GlobalSplitUAlgorithm: ["MultipleBuffer", "MultipleBufferSingleKernel"] + - GlobalReadPerMfma: [1] + - LocalWritePerMfma: [-1] + - StoreVectorWidth: [-1] + - SourceSwap: [1] + - StreamK: [0,3] +# - NumElementsPerBatchStore: [0] +# - StorePriorityOpt: [0] +# - StaggerU: [0] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + # - Exact: [16, 16, 1, 128] + - Exact: [32, 16, 1, 1024] + - Exact: [256, 256, 1, 32] + - Exact: [256, 256, 1, 64] + - Exact: [256, 256, 1, 96] + - Exact: [256, 256, 1, 160] + - Exact: [256, 256, 1, 192] + - Exact: [256, 256, 1, 224] + - Exact: [1025, 513, 1, 2048] +# - Exact: [1026, 514, 1, 2048] +# - Exact: [1027, 515, 1, 2048] +# - Exact: [1028, 516, 1, 2048] +# - Exact: [1029, 517, 1, 2048] +# - Exact: [1030, 518, 1, 2048] +# - Exact: [1031, 519, 1, 2048] +# - Exact: [1032, 520, 1, 2048] +# - Exact: [1033, 521, 1, 2048] +# - Exact: [1034, 522, 1, 2048] +# - Exact: [1035, 523, 1, 2048] +# - Exact: [1036, 524, 1, 2048] +# - Exact: [1037, 525, 1, 2048] +# - Exact: [1038, 526, 1, 2048] +# - Exact: [1039, 527, 1, 2048] +# - Exact: [1040, 528, 1, 2048] +# - Exact: [1041, 529, 1, 2048] +# - Exact: [1042, 530, 1, 2048] +# - Exact: [1043, 531, 1, 2048] + - Exact: [1044, 532, 1, 2048] + - Exact: [127, 127, 1, 640] #special cleanup case + - Exact: [128, 128, 1, 128] + - Exact: [129, 129, 1, 640] + # - Exact: [128, 128, 1, 127] + # - Exact: [128, 128, 1, 129] + # - Exact: [111, 111, 1, 111] + # - Exact: [777, 777, 1, 777] + - Exact: [128, 128, 1, 512] + - Range: [[128], [128], [1], [1024,32,1280]] + - Range: [[128,1,256], [128], [1], [1024]] + - Range: [[128], [128,1,256], [1], [1024]] + - BiasTypeArgs: ['s'] + - ActivationArgs: + - [Enum: none] + +# LibraryLogic: +# ScheduleName: "aquavanjaram" +# DeviceNames: ["Device 0049"] +# ArchitectureName: "gfx950" +# LibraryType: "GridBased" +# +# LibraryClient: diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_fp8_bf16_tn_act.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_fp8_bf16_tn_act.yaml new file mode 100644 index 00000000000..afe643a7f74 --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_fp8_bf16_tn_act.yaml @@ -0,0 +1,128 @@ +TestParameters: + marks: [xfail-gfx950, skip-gfx900, skip-gfx906, skip-gfx908, skip-gfx90a, skip-gfx942, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102, skip-gfx1200, skip-gfx1201, skip-gfx1250, skip-gfx940, skip-gfx941] # Only for gfx950 + +GlobalParameters: + NumElementsToValidate: -1 + MinimumRequiredVersion: 5.0.0 + PrintLevel: 1 + PrintSolutionRejectionReason: True + Device: 0 + CMakeBuildType: Release + KernelTime: True + MaxWorkspaceSize: 13421772800 + DataInitTypeA: 21 + DataInitTypeB: 21 + DataInitTypeC: 21 + DataInitTypeAlpha: 1 + DataInitTypeBeta: 1 + BoundsCheck: 2 + +BenchmarkProblems: + ######################################## + # MXFP4 A + FP8 B -> BF16 C/D (TN (TransposeA=True, TransposeB=False)) + # Mixed MXFP4+FP8, BF16 output, All tile sizes (64x64, 128x128, 256x256), with Activation + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: F4 + DataTypeB: F8 + DestDataType: B + ComputeDataType: S + HighPrecisionAccumulate: True + MXBlockA: 32 + MXBlockB: 32 + TransposeA: True # TN configuration + TransposeB: False # TN configuration + UseBeta: True + Batched: True + Activation: True + ActivationType: hipblaslt_all + UseBias: 1 + BiasDataTypeList: [s] + - # BenchmarkProblemSizeGroup + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 128, 1, 1, 1, 1, 1, 1] + - [16, 16, 128, 1, 1, 4, 2, 2, 2] + - [16, 16, 128, 1, 1, 2, 4, 2, 2] + - [16, 16, 128, 1, 1, 8, 8, 2, 2] + - [32, 32, 64, 1, 1, 4, 4, 2, 2] + - [32, 32, 64, 1, 1, 8, 8, 2, 2] + - DepthU: [32, 64, 128] + - AssertFree0ElementMultiple: [1] + - AssertFree1ElementMultiple: [1] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - DirectToLds: [0] + - GlobalReadVectorWidthA: [32] + - GlobalReadVectorWidthB: [16] + - LocalReadVectorWidth: [32] + - VectorWidthA: [1] + - VectorWidthB: [1] + - ClusterLocalRead: [1] + - 1LDSBuffer: [0] + - GlobalSplitU: [1] + - GlobalSplitUAlgorithm: ["MultipleBuffer"] + - GlobalReadPerMfma: [1] + - LocalWritePerMfma: [-1] + - StoreVectorWidth: [4] + - InnerUnroll: [1] + - ScheduleIterAlg: [3] + - LdsPadA: [4] + - LdsPadB: [4] + - WorkGroupMapping: [64] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + ######################################## + # 1. Small, Power-of-2 - to test optimized code path + ######################################## + - Exact: [256, 256, 1, 256] + + ######################################## + # 2. Odd M - to test edge M + ######################################## + - Exact: [63, 64, 1, 64] + - Exact: [255, 256, 1, 256] + + ######################################## + # 3. Odd N - to test edge N + ######################################## + - Exact: [64, 63, 1, 64] + - Exact: [256, 255, 1, 256] + + ######################################## + # 4. Odd M and N - to test both edge dimensions + ######################################## + - Exact: [63, 63, 1, 64] + - Exact: [255, 255, 1, 256] + + ######################################## + # 5. Odd K - to test tail loop + ######################################## + - Exact: [64, 64, 1, 63] + - Exact: [256, 256, 1, 255] + + ######################################## + # 6. Small size with batch > 1 + ######################################## + - Exact: [32, 32, 8, 32] + + ######################################## + # 7. Medium size that doesn't divide evenly on CUs (stream-k) + ######################################## + - Exact: [1000, 1000, 1, 256] + - Exact: [1500, 1500, 1, 512] + - Exact: [1024, 1024, 1, 333] + - BiasTypeArgs: ['s'] + - ActivationArgs: + - [Enum: none] + - [Enum: gelu] + - [Enum: relu] + - [Enum: sigmoid] + - [Enum: Silu] + - [Enum: Clamp] diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_fp8_bf16_tn_act_groupgemm.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_fp8_bf16_tn_act_groupgemm.yaml new file mode 100644 index 00000000000..158960fec55 --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_fp8_bf16_tn_act_groupgemm.yaml @@ -0,0 +1,129 @@ +TestParameters: + marks: [xfail-gfx950, skip-gfx900, skip-gfx906, skip-gfx908, skip-gfx90a, skip-gfx942, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102, skip-gfx1200, skip-gfx1201, skip-gfx1250,skip-gfx940, skip-gfx941] # Only for gfx950 + +GlobalParameters: + NumElementsToValidate: -1 + MinimumRequiredVersion: 5.0.0 + PrintLevel: 1 + PrintSolutionRejectionReason: True + Device: 0 + CMakeBuildType: Release + KernelTime: True + MaxWorkspaceSize: 13421772800 + DataInitTypeA: 21 + DataInitTypeB: 21 + DataInitTypeC: 21 + DataInitTypeAlpha: 1 + DataInitTypeBeta: 1 + BoundsCheck: 2 + +BenchmarkProblems: + ######################################## + # MXFP4 A + FP8 B -> BF16 C/D (TN (TransposeA=True, TransposeB=False)) + # Mixed MXFP4+FP8, BF16 output, All tile sizes (64x64, 128x128, 256x256), with Activation, GroupedGemm + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: F4 + DataTypeB: F8 + DestDataType: B + ComputeDataType: S + HighPrecisionAccumulate: True + MXBlockA: 32 + MXBlockB: 32 + TransposeA: True # TN configuration + TransposeB: False # TN configuration + UseBeta: True + Batched: True + Activation: True + ActivationType: hipblaslt_all + UseBias: 1 + BiasDataTypeList: [s] + GroupedGemm: True + - # BenchmarkProblemSizeGroup + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 128, 1, 1, 1, 1, 1, 1] + - [16, 16, 128, 1, 1, 4, 2, 2, 2] + - [16, 16, 128, 1, 1, 2, 4, 2, 2] + - [16, 16, 128, 1, 1, 8, 8, 2, 2] + - [32, 32, 64, 1, 1, 4, 4, 2, 2] + - [32, 32, 64, 1, 1, 8, 8, 2, 2] + - DepthU: [32, 64, 128] + - AssertFree0ElementMultiple: [1] + - AssertFree1ElementMultiple: [1] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - DirectToLds: [0] + - GlobalReadVectorWidthA: [32] + - GlobalReadVectorWidthB: [16] + - LocalReadVectorWidth: [32] + - VectorWidthA: [1] + - VectorWidthB: [1] + - ClusterLocalRead: [1] + - 1LDSBuffer: [0] + - GlobalSplitU: [1] + - GlobalSplitUAlgorithm: ["MultipleBuffer"] + - GlobalReadPerMfma: [1] + - LocalWritePerMfma: [-1] + - StoreVectorWidth: [4] + - InnerUnroll: [1] + - ScheduleIterAlg: [3] + - LdsPadA: [4] + - LdsPadB: [4] + - WorkGroupMapping: [64] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + ######################################## + # 1. Small, Power-of-2 - to test optimized code path + ######################################## + - Exact: [256, 256, 1, 256] + + ######################################## + # 2. Odd M - to test edge M + ######################################## + - Exact: [63, 64, 1, 64] + - Exact: [255, 256, 1, 256] + + ######################################## + # 3. Odd N - to test edge N + ######################################## + - Exact: [64, 63, 1, 64] + - Exact: [256, 255, 1, 256] + + ######################################## + # 4. Odd M and N - to test both edge dimensions + ######################################## + - Exact: [63, 63, 1, 64] + - Exact: [255, 255, 1, 256] + + ######################################## + # 5. Odd K - to test tail loop + ######################################## + - Exact: [64, 64, 1, 63] + - Exact: [256, 256, 1, 255] + + ######################################## + # 6. Small size with batch > 1 + ######################################## + - Exact: [32, 32, 8, 32] + + ######################################## + # 7. Medium size that doesn't divide evenly on CUs (stream-k) + ######################################## + - Exact: [1000, 1000, 1, 256] + - Exact: [1500, 1500, 1, 512] + - Exact: [1024, 1024, 1, 333] + - BiasTypeArgs: ['s'] + - ActivationArgs: + - [Enum: none] + - [Enum: gelu] + - [Enum: relu] + - [Enum: sigmoid] + - [Enum: Silu] + - [Enum: Clamp] diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_fp8_fp32_tn_act.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_fp8_fp32_tn_act.yaml new file mode 100644 index 00000000000..93d6a3467a9 --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_fp8_fp32_tn_act.yaml @@ -0,0 +1,128 @@ +TestParameters: + marks: [xfail-gfx950, skip-gfx900, skip-gfx906, skip-gfx908, skip-gfx90a, skip-gfx942, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102, skip-gfx1200, skip-gfx1201, skip-gfx1250, skip-gfx940, skip-gfx941] # Only for gfx950 + +GlobalParameters: + NumElementsToValidate: -1 + MinimumRequiredVersion: 5.0.0 + PrintLevel: 1 + PrintSolutionRejectionReason: True + Device: 0 + CMakeBuildType: Release + KernelTime: True + MaxWorkspaceSize: 13421772800 + DataInitTypeA: 21 + DataInitTypeB: 21 + DataInitTypeC: 21 + DataInitTypeAlpha: 1 + DataInitTypeBeta: 1 + BoundsCheck: 2 + +BenchmarkProblems: + ######################################## + # MXFP4 A + FP8 B -> FP32 C/D (TN (TransposeA=True, TransposeB=False)) + # Mixed MXFP4+FP8, FP32 output, All tile sizes (64x64, 128x128, 256x256), with Activation + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: F4 + DataTypeB: F8 + DestDataType: s + ComputeDataType: S + HighPrecisionAccumulate: True + MXBlockA: 32 + MXBlockB: 32 + TransposeA: True # TN configuration + TransposeB: False # TN configuration + UseBeta: True + Batched: True + Activation: True + ActivationType: hipblaslt_all + UseBias: 1 + BiasDataTypeList: [s] + - # BenchmarkProblemSizeGroup + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 128, 1, 1, 1, 1, 1, 1] + - [16, 16, 128, 1, 1, 4, 2, 2, 2] + - [16, 16, 128, 1, 1, 2, 4, 2, 2] + - [16, 16, 128, 1, 1, 8, 8, 2, 2] + - [32, 32, 64, 1, 1, 4, 4, 2, 2] + - [32, 32, 64, 1, 1, 8, 8, 2, 2] + - DepthU: [32, 64, 128] + - AssertFree0ElementMultiple: [1] + - AssertFree1ElementMultiple: [1] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - DirectToLds: [0] + - GlobalReadVectorWidthA: [32] + - GlobalReadVectorWidthB: [16] + - LocalReadVectorWidth: [32] + - VectorWidthA: [1] + - VectorWidthB: [1] + - ClusterLocalRead: [1] + - 1LDSBuffer: [0] + - GlobalSplitU: [1] + - GlobalSplitUAlgorithm: ["MultipleBuffer"] + - GlobalReadPerMfma: [1] + - LocalWritePerMfma: [-1] + - StoreVectorWidth: [4] + - InnerUnroll: [1] + - ScheduleIterAlg: [3] + - LdsPadA: [4] + - LdsPadB: [4] + - WorkGroupMapping: [64] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + ######################################## + # 1. Small, Power-of-2 - to test optimized code path + ######################################## + - Exact: [256, 256, 1, 256] + + ######################################## + # 2. Odd M - to test edge M + ######################################## + - Exact: [63, 64, 1, 64] + - Exact: [255, 256, 1, 256] + + ######################################## + # 3. Odd N - to test edge N + ######################################## + - Exact: [64, 63, 1, 64] + - Exact: [256, 255, 1, 256] + + ######################################## + # 4. Odd M and N - to test both edge dimensions + ######################################## + - Exact: [63, 63, 1, 64] + - Exact: [255, 255, 1, 256] + + ######################################## + # 5. Odd K - to test tail loop + ######################################## + - Exact: [64, 64, 1, 63] + - Exact: [256, 256, 1, 255] + + ######################################## + # 6. Small size with batch > 1 + ######################################## + - Exact: [32, 32, 8, 32] + + ######################################## + # 7. Medium size that doesn't divide evenly on CUs (stream-k) + ######################################## + - Exact: [1000, 1000, 1, 256] + - Exact: [1500, 1500, 1, 512] + - Exact: [1024, 1024, 1, 333] + - BiasTypeArgs: ['s'] + - ActivationArgs: + - [Enum: none] + - [Enum: gelu] + - [Enum: relu] + - [Enum: sigmoid] + - [Enum: Silu] + - [Enum: Clamp] diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_fp8_fp32_tn_act_groupgemm.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_fp8_fp32_tn_act_groupgemm.yaml new file mode 100644 index 00000000000..d8dd486672e --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_fp8_fp32_tn_act_groupgemm.yaml @@ -0,0 +1,129 @@ +TestParameters: + marks: [xfail-gfx950, skip-gfx900, skip-gfx906, skip-gfx908, skip-gfx90a, skip-gfx942, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102, skip-gfx1200, skip-gfx1201, skip-gfx1250, skip-gfx940, skip-gfx941] # Only for gfx950 + +GlobalParameters: + NumElementsToValidate: -1 + MinimumRequiredVersion: 5.0.0 + PrintLevel: 1 + PrintSolutionRejectionReason: True + Device: 0 + CMakeBuildType: Release + KernelTime: True + MaxWorkspaceSize: 13421772800 + DataInitTypeA: 21 + DataInitTypeB: 21 + DataInitTypeC: 21 + DataInitTypeAlpha: 1 + DataInitTypeBeta: 1 + BoundsCheck: 2 + +BenchmarkProblems: + ######################################## + # MXFP4 A + FP8 B -> FP32 C/D (TN (TransposeA=True, TransposeB=False)) + # Mixed MXFP4+FP8, FP32 output, All tile sizes (64x64, 128x128, 256x256), with Activation, GroupedGemm + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: F4 + DataTypeB: F8 + DestDataType: s + ComputeDataType: S + HighPrecisionAccumulate: True + MXBlockA: 32 + MXBlockB: 32 + TransposeA: True # TN configuration + TransposeB: False # TN configuration + UseBeta: True + Batched: True + Activation: True + ActivationType: hipblaslt_all + UseBias: 1 + BiasDataTypeList: [s] + GroupedGemm: True + - # BenchmarkProblemSizeGroup + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 128, 1, 1, 1, 1, 1, 1] + - [16, 16, 128, 1, 1, 4, 2, 2, 2] + - [16, 16, 128, 1, 1, 2, 4, 2, 2] + - [16, 16, 128, 1, 1, 8, 8, 2, 2] + - [32, 32, 64, 1, 1, 4, 4, 2, 2] + - [32, 32, 64, 1, 1, 8, 8, 2, 2] + - DepthU: [32, 64, 128] + - AssertFree0ElementMultiple: [1] + - AssertFree1ElementMultiple: [1] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - DirectToLds: [0] + - GlobalReadVectorWidthA: [32] + - GlobalReadVectorWidthB: [16] + - LocalReadVectorWidth: [32] + - VectorWidthA: [1] + - VectorWidthB: [1] + - ClusterLocalRead: [1] + - 1LDSBuffer: [0] + - GlobalSplitU: [1] + - GlobalSplitUAlgorithm: ["MultipleBuffer"] + - GlobalReadPerMfma: [1] + - LocalWritePerMfma: [-1] + - StoreVectorWidth: [4] + - InnerUnroll: [1] + - ScheduleIterAlg: [3] + - LdsPadA: [4] + - LdsPadB: [4] + - WorkGroupMapping: [64] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + ######################################## + # 1. Small, Power-of-2 - to test optimized code path + ######################################## + - Exact: [256, 256, 1, 256] + + ######################################## + # 2. Odd M - to test edge M + ######################################## + - Exact: [63, 64, 1, 64] + - Exact: [255, 256, 1, 256] + + ######################################## + # 3. Odd N - to test edge N + ######################################## + - Exact: [64, 63, 1, 64] + - Exact: [256, 255, 1, 256] + + ######################################## + # 4. Odd M and N - to test both edge dimensions + ######################################## + - Exact: [63, 63, 1, 64] + - Exact: [255, 255, 1, 256] + + ######################################## + # 5. Odd K - to test tail loop + ######################################## + - Exact: [64, 64, 1, 63] + - Exact: [256, 256, 1, 255] + + ######################################## + # 6. Small size with batch > 1 + ######################################## + - Exact: [32, 32, 8, 32] + + ######################################## + # 7. Medium size that doesn't divide evenly on CUs (stream-k) + ######################################## + - Exact: [1000, 1000, 1, 256] + - Exact: [1500, 1500, 1, 512] + - Exact: [1024, 1024, 1, 333] + - BiasTypeArgs: ['s'] + - ActivationArgs: + - [Enum: none] + - [Enum: gelu] + - [Enum: relu] + - [Enum: sigmoid] + - [Enum: Silu] + - [Enum: Clamp] diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_mxfp4_bf16_tn_act.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_mxfp4_bf16_tn_act.yaml new file mode 100644 index 00000000000..db0024a6a77 --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_mxfp4_bf16_tn_act.yaml @@ -0,0 +1,127 @@ +TestParameters: + marks: [xfail-gfx950, skip-gfx900, skip-gfx906, skip-gfx908, skip-gfx90a, skip-gfx942, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102, skip-gfx1200, skip-gfx1201, skip-gfx1250, skip-gfx940, skip-gfx941] # Only for gfx950 + +GlobalParameters: + NumElementsToValidate: -1 + MinimumRequiredVersion: 5.0.0 + PrintLevel: 1 + PrintSolutionRejectionReason: True + Device: 0 + CMakeBuildType: Release + KernelTime: True + MaxWorkspaceSize: 13421772800 + DataInitTypeA: 21 + DataInitTypeB: 21 + DataInitTypeC: 21 + DataInitTypeAlpha: 1 + DataInitTypeBeta: 1 + BoundsCheck: 2 + +BenchmarkProblems: + ######################################## + # MXFP4 A + MXFP4 B -> BF16 C/D (TN (TransposeA=True, TransposeB=False)) + # Both inputs MXFP4, BF16 output, All tile sizes (64x64, 128x128, 256x256), with Activation + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: F4 + DestDataType: B + ComputeDataType: S + HighPrecisionAccumulate: True + MXBlockA: 32 + MXBlockB: 32 + TransposeA: True # TN configuration + TransposeB: False # TN configuration + UseBeta: True + Batched: True + Activation: True + ActivationType: hipblaslt_all + UseBias: 1 + BiasDataTypeList: [s] + - # BenchmarkProblemSizeGroup + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 128, 1, 1, 1, 1, 1, 1] + - [16, 16, 128, 1, 1, 4, 2, 2, 2] + - [16, 16, 128, 1, 1, 2, 4, 2, 2] + - [16, 16, 128, 1, 1, 8, 8, 2, 2] + - [32, 32, 64, 1, 1, 4, 4, 2, 2] + - [32, 32, 64, 1, 1, 8, 8, 2, 2] + - DepthU: [32, 64, 128] + - AssertFree0ElementMultiple: [1] + - AssertFree1ElementMultiple: [1] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - DirectToLds: [0] + - GlobalReadVectorWidthA: [32] + - GlobalReadVectorWidthB: [32] + - LocalReadVectorWidth: [32] + - VectorWidthA: [1] + - VectorWidthB: [1] + - ClusterLocalRead: [1] + - 1LDSBuffer: [0] + - GlobalSplitU: [1] + - GlobalSplitUAlgorithm: ["MultipleBuffer"] + - GlobalReadPerMfma: [1] + - LocalWritePerMfma: [-1] + - StoreVectorWidth: [4] + - InnerUnroll: [1] + - ScheduleIterAlg: [3] + - LdsPadA: [4] + - LdsPadB: [4] + - WorkGroupMapping: [64] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + ######################################## + # 1. Small, Power-of-2 - to test optimized code path + ######################################## + - Exact: [256, 256, 1, 256] + + ######################################## + # 2. Odd M - to test edge M + ######################################## + - Exact: [63, 64, 1, 64] + - Exact: [255, 256, 1, 256] + + ######################################## + # 3. Odd N - to test edge N + ######################################## + - Exact: [64, 63, 1, 64] + - Exact: [256, 255, 1, 256] + + ######################################## + # 4. Odd M and N - to test both edge dimensions + ######################################## + - Exact: [63, 63, 1, 64] + - Exact: [255, 255, 1, 256] + + ######################################## + # 5. Odd K - to test tail loop + ######################################## + - Exact: [64, 64, 1, 63] + - Exact: [256, 256, 1, 255] + + ######################################## + # 6. Small size with batch > 1 + ######################################## + - Exact: [32, 32, 8, 32] + + ######################################## + # 7. Medium size that doesn't divide evenly on CUs (stream-k) + ######################################## + - Exact: [1000, 1000, 1, 256] + - Exact: [1500, 1500, 1, 512] + - Exact: [1024, 1024, 1, 333] + - BiasTypeArgs: ['s'] + - ActivationArgs: + - [Enum: none] + - [Enum: gelu] + - [Enum: relu] + - [Enum: sigmoid] + - [Enum: Silu] + - [Enum: Clamp] diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_mxfp4_bf16_tn_act_groupgemm.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_mxfp4_bf16_tn_act_groupgemm.yaml new file mode 100644 index 00000000000..56e2900ca82 --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_mxfp4_bf16_tn_act_groupgemm.yaml @@ -0,0 +1,128 @@ +TestParameters: + marks: [xfail-gfx950, skip-gfx900, skip-gfx906, skip-gfx908, skip-gfx90a, skip-gfx942, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102, skip-gfx1200, skip-gfx1201, skip-gfx1250, skip-gfx940, skip-gfx941] # Only for gfx950 + +GlobalParameters: + NumElementsToValidate: -1 + MinimumRequiredVersion: 5.0.0 + PrintLevel: 1 + PrintSolutionRejectionReason: True + Device: 0 + CMakeBuildType: Release + KernelTime: True + MaxWorkspaceSize: 13421772800 + DataInitTypeA: 21 + DataInitTypeB: 21 + DataInitTypeC: 21 + DataInitTypeAlpha: 1 + DataInitTypeBeta: 1 + BoundsCheck: 2 + +BenchmarkProblems: + ######################################## + # MXFP4 A + MXFP4 B -> BF16 C/D (TN (TransposeA=True, TransposeB=False)) + # Both inputs MXFP4, BF16 output, All tile sizes (64x64, 128x128, 256x256), with Activation, GroupedGemm + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: F4 + DestDataType: B + ComputeDataType: S + HighPrecisionAccumulate: True + MXBlockA: 32 + MXBlockB: 32 + TransposeA: True # TN configuration + TransposeB: False # TN configuration + UseBeta: True + Batched: True + Activation: True + ActivationType: hipblaslt_all + UseBias: 1 + BiasDataTypeList: [s] + GroupedGemm: True + - # BenchmarkProblemSizeGroup + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 128, 1, 1, 1, 1, 1, 1] + - [16, 16, 128, 1, 1, 4, 2, 2, 2] + - [16, 16, 128, 1, 1, 2, 4, 2, 2] + - [16, 16, 128, 1, 1, 8, 8, 2, 2] + - [32, 32, 64, 1, 1, 4, 4, 2, 2] + - [32, 32, 64, 1, 1, 8, 8, 2, 2] + - DepthU: [32, 64, 128] + - AssertFree0ElementMultiple: [1] + - AssertFree1ElementMultiple: [1] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - DirectToLds: [0] + - GlobalReadVectorWidthA: [32] + - GlobalReadVectorWidthB: [32] + - LocalReadVectorWidth: [32] + - VectorWidthA: [1] + - VectorWidthB: [1] + - ClusterLocalRead: [1] + - 1LDSBuffer: [0] + - GlobalSplitU: [1] + - GlobalSplitUAlgorithm: ["MultipleBuffer"] + - GlobalReadPerMfma: [1] + - LocalWritePerMfma: [-1] + - StoreVectorWidth: [4] + - InnerUnroll: [1] + - ScheduleIterAlg: [3] + - LdsPadA: [4] + - LdsPadB: [4] + - WorkGroupMapping: [64] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + ######################################## + # 1. Small, Power-of-2 - to test optimized code path + ######################################## + - Exact: [256, 256, 1, 256] + + ######################################## + # 2. Odd M - to test edge M + ######################################## + - Exact: [63, 64, 1, 64] + - Exact: [255, 256, 1, 256] + + ######################################## + # 3. Odd N - to test edge N + ######################################## + - Exact: [64, 63, 1, 64] + - Exact: [256, 255, 1, 256] + + ######################################## + # 4. Odd M and N - to test both edge dimensions + ######################################## + - Exact: [63, 63, 1, 64] + - Exact: [255, 255, 1, 256] + + ######################################## + # 5. Odd K - to test tail loop + ######################################## + - Exact: [64, 64, 1, 63] + - Exact: [256, 256, 1, 255] + + ######################################## + # 6. Small size with batch > 1 + ######################################## + - Exact: [32, 32, 8, 32] + + ######################################## + # 7. Medium size that doesn't divide evenly on CUs (stream-k) + ######################################## + - Exact: [1000, 1000, 1, 256] + - Exact: [1500, 1500, 1, 512] + - Exact: [1024, 1024, 1, 333] + - BiasTypeArgs: ['s'] + - ActivationArgs: + - [Enum: none] + - [Enum: gelu] + - [Enum: relu] + - [Enum: sigmoid] + - [Enum: Silu] + - [Enum: Clamp] diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_mxfp4_fp32_tn_act.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_mxfp4_fp32_tn_act.yaml new file mode 100644 index 00000000000..cde77b36a14 --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_mxfp4_fp32_tn_act.yaml @@ -0,0 +1,128 @@ +TestParameters: + marks: [skip-gfx1250, skip-gfx900, skip-gfx906, skip-gfx908, skip-gfx90a, skip-gfx942, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102, skip-gfx1200, skip-gfx1201, skip-gfx940, skip-gfx941] # Only for gfx950 + +GlobalParameters: + NumElementsToValidate: -1 + MinimumRequiredVersion: 5.0.0 + PrintLevel: 1 + PrintSolutionRejectionReason: True + Device: 0 + RotatingBufferSize: 512 + CMakeBuildType: Release + KernelTime: True + MaxWorkspaceSize: 13421772800 + DataInitTypeA: 21 + DataInitTypeB: 21 + DataInitTypeC: 21 + DataInitTypeAlpha: 1 + DataInitTypeBeta: 1 + BoundsCheck: 2 + +BenchmarkProblems: + ######################################## + # MXFP4 A + MXFP4 B -> FP32 C/D (TN (TransposeA=True, TransposeB=False)) + # Both inputs MXFP4, FP32 output, All tile sizes (64x64, 128x128, 256x256), with Activation + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: F4 + DestDataType: s + ComputeDataType: S + HighPrecisionAccumulate: True + MXBlockA: 32 + MXBlockB: 32 + TransposeA: True # TN configuration + TransposeB: False # TN configuration + UseBeta: True + Batched: True + Activation: True + ActivationType: hipblaslt_all + UseBias: 1 + BiasDataTypeList: [s] + - # BenchmarkProblemSizeGroup + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 128, 1, 1, 1, 1, 1, 1] + - [16, 16, 128, 1, 1, 4, 2, 2, 2] + - [16, 16, 128, 1, 1, 2, 4, 2, 2] + - [16, 16, 128, 1, 1, 8, 8, 2, 2] + - [32, 32, 64, 1, 1, 4, 4, 2, 2] + - [32, 32, 64, 1, 1, 8, 8, 2, 2] + - DepthU: [64, 128, 256] + - AssertFree0ElementMultiple: [1] + - AssertFree1ElementMultiple: [1] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - DirectToLds: [0] + - GlobalReadVectorWidthA: [32] + - GlobalReadVectorWidthB: [32] + - LocalReadVectorWidth: [32] + #- VectorWidthA: [1] + - VectorWidthB: [1] + - ClusterLocalRead: [1] + - 1LDSBuffer: [0] + - GlobalSplitU: [1] + - GlobalSplitUAlgorithm: ["MultipleBuffer"] + - GlobalReadPerMfma: [1] + - LocalWritePerMfma: [-1] + #- StoreVectorWidth: [4] + - InnerUnroll: [1] + - ScheduleIterAlg: [3] + - LdsPadA: [4] + - LdsPadB: [4] + - WorkGroupMapping: [64] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + ######################################## + # 1. Small, Power-of-2 - to test optimized code path + ######################################## + - Exact: [256, 256, 1, 256] + + ######################################## + # 2. Odd M - to test edge M + ######################################## + - Exact: [63, 64, 1, 64] + - Exact: [255, 256, 1, 256] + + ######################################## + # 3. Odd N - to test edge N + ######################################## + - Exact: [64, 63, 1, 64] + - Exact: [256, 255, 1, 256] + + ######################################## + # 4. Odd M and N - to test both edge dimensions + ######################################## + #- Exact: [63, 63, 1, 64] + #- Exact: [255, 255, 1, 256] + + ######################################## + # 5. Odd K - to test tail loop + ######################################## + #- Exact: [64, 64, 1, 63] + #- Exact: [256, 256, 1, 255] + + ######################################## + # 6. Small size with batch > 1 + ######################################## + - Exact: [32, 32, 8, 32] + + ######################################## + # 7. Medium size that doesn't divide evenly on CUs (stream-k) + ######################################## + - Exact: [1000, 1000, 1, 256] + - Exact: [1500, 1500, 1, 512] + #- Exact: [1024, 1024, 1, 333] + - BiasTypeArgs: ['s'] + - ActivationArgs: + - [Enum: none] + - [Enum: gelu] + - [Enum: relu] + - [Enum: sigmoid] + - [Enum: Silu] + - [Enum: Clamp] diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_mxfp4_fp32_tn_act_groupgemm.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_mxfp4_fp32_tn_act_groupgemm.yaml new file mode 100644 index 00000000000..d638c6880c0 --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/mxfp4_mxfp4_fp32_tn_act_groupgemm.yaml @@ -0,0 +1,128 @@ +TestParameters: + marks: [skip-gfx1250, skip-gfx900, skip-gfx906, skip-gfx908, skip-gfx90a, skip-gfx942, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102, skip-gfx1200, skip-gfx1201, skip-gfx940, skip-gfx941] # Only for gfx950 + +GlobalParameters: + NumElementsToValidate: -1 + MinimumRequiredVersion: 5.0.0 + PrintLevel: 1 + PrintSolutionRejectionReason: True + Device: 0 + CMakeBuildType: Release + KernelTime: True + MaxWorkspaceSize: 13421772800 + DataInitTypeA: 21 + DataInitTypeB: 21 + DataInitTypeC: 21 + DataInitTypeAlpha: 1 + DataInitTypeBeta: 1 + BoundsCheck: 2 + +BenchmarkProblems: + ######################################## + # MXFP4 A + MXFP4 B -> FP32 C/D (TN (TransposeA=True, TransposeB=False)) + # Both inputs MXFP4, FP32 output, All tile sizes (64x64, 128x128, 256x256), with Activation, GroupedGemm + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: F4 + DestDataType: s + ComputeDataType: S + HighPrecisionAccumulate: True + MXBlockA: 32 + MXBlockB: 32 + TransposeA: True # TN configuration + TransposeB: False # TN configuration + UseBeta: True + Batched: True + Activation: True + ActivationType: hipblaslt_all + UseBias: 1 + BiasDataTypeList: [s] + GroupedGemm: True + - # BenchmarkProblemSizeGroup + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + - [16, 16, 128, 1, 1, 1, 1, 1, 1] + - [16, 16, 128, 1, 1, 4, 2, 2, 2] + - [16, 16, 128, 1, 1, 2, 4, 2, 2] + - [16, 16, 128, 1, 1, 8, 8, 2, 2] + - [32, 32, 64, 1, 1, 4, 4, 2, 2] + - [32, 32, 64, 1, 1, 8, 8, 2, 2] + - DepthU: [32, 64, 128] + - AssertFree0ElementMultiple: [1] + - AssertFree1ElementMultiple: [1] + - PrefetchGlobalRead: [2] + - PrefetchLocalRead: [1] + - DirectToLds: [0] + - GlobalReadVectorWidthA: [32] + - GlobalReadVectorWidthB: [32] + - LocalReadVectorWidth: [32] + - VectorWidthA: [1] + - VectorWidthB: [1] + - ClusterLocalRead: [1] + - 1LDSBuffer: [0] + - GlobalSplitU: [1] + - GlobalSplitUAlgorithm: ["MultipleBuffer"] + - GlobalReadPerMfma: [1] + - LocalWritePerMfma: [-1] + - StoreVectorWidth: [4] + - InnerUnroll: [1] + - ScheduleIterAlg: [3] + - LdsPadA: [4] + - LdsPadB: [4] + - WorkGroupMapping: [64] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + ######################################## + # 1. Small, Power-of-2 - to test optimized code path + ######################################## + - Exact: [256, 256, 1, 256] + + ######################################## + # 2. Odd M - to test edge M + ######################################## + - Exact: [63, 64, 1, 64] + - Exact: [255, 256, 1, 256] + + ######################################## + # 3. Odd N - to test edge N + ######################################## + - Exact: [64, 63, 1, 64] + - Exact: [256, 255, 1, 256] + + ######################################## + # 4. Odd M and N - to test both edge dimensions + ######################################## + - Exact: [63, 63, 1, 64] + - Exact: [255, 255, 1, 256] + + ######################################## + # 5. Odd K - to test tail loop + ######################################## + - Exact: [64, 64, 1, 63] + - Exact: [256, 256, 1, 255] + + ######################################## + # 6. Small size with batch > 1 + ######################################## + - Exact: [32, 32, 8, 32] + + ######################################## + # 7. Medium size that doesn't divide evenly on CUs (stream-k) + ######################################## + - Exact: [1000, 1000, 1, 256] + - Exact: [1500, 1500, 1, 512] + - Exact: [1024, 1024, 1, 333] + - BiasTypeArgs: ['s'] + - ActivationArgs: + - [Enum: none] + - [Enum: gelu] + - [Enum: relu] + - [Enum: sigmoid] + - [Enum: Silu] + - [Enum: Clamp] diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/streamk/sk_mx32f4_quick.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/streamk/sk_mx32f4_quick.yaml new file mode 100644 index 00000000000..95efebb8370 --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/streamk/sk_mx32f4_quick.yaml @@ -0,0 +1,109 @@ +TestParameters: + marks: [skip-gfx1250, skip-gfx900, skip-gfx906, skip-gfx908, skip-gfx90a, skip-gfx940, skip-gfx941, skip-gfx942, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102, skip-gfx1200, skip-gfx1201] + +GlobalParameters: + NumElementsToValidate: -1 + MinimumRequiredVersion: 5.0.0 + PrintLevel: 1 + # PrintSolutionRejectionReason: True + Device: 0 + CMakeBuildType: Debug +# MergeFiles: False + KernelTime: True + MaxWorkspaceSize: 13421772800 + DataInitTypeA: 3 + DataInitTypeB: 3 + DataInitTypeC: 0 + DataInitTypeMXSA: 3 + DataInitTypeMXSB: 3 + DataInitTypeAlpha: 1 + DataInitTypeBeta: 0 + NumElementsToValidate: -1 + BoundsCheck: 0 + KeepBuildTmp: True + MaxFileName: 128 + DeviceLDS: 163840 + MaxLDS: 163840 + +BenchmarkProblems: + ######################################## + # FP4SS + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: F4 + DestDataType: S + ComputeDataType: S + HighPrecisionAccumulate: True + MXBlockA: 32 + MXBlockB: 32 + TransposeA: 1 + TransposeB: 0 + UseBeta: True + Batched: True + Activation: True + ActivationType: hipblaslt_all +# UseScaleAB: "Scalar" +# UseScaleCD: True + UseScaleAlphaVec: 1 + UseBias: 1 + BiasDataTypeList: [s] + + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + # - [16, 16, 128, 1, 1, 1,1, 1,1] + # - [16, 16, 128, 1, 1, 4,2, 2,2] + - [16, 16, 128, 1, 1, 2,4, 2,2] # 64x128 + # - [16, 16, 128, 1, 1, 8,8, 2,2] + # - [32, 32, 64, 1, 1, 1,1, 1,1] + - [32, 32, 64, 1, 1, 2,2, 2,2] # 128x128 + # - [32, 32, 64, 1, 1, 4,2, 2,2] + # - [32, 32, 64, 1, 1, 2,4, 2,2] + - ForceDisableShadowInit: [True, False] + # - UseSgprForGRO: [0,1] + - UseSgprForGRO: [0] + # - DepthU: [64, 128] + - DepthU: [128] + - AssertFree0ElementMultiple: [1] + - AssertFree1ElementMultiple: [1] + - AssertSummationElementMultiple: [64] + - LocalReadVectorWidth: [16] + # - PrefetchGlobalRead: [0,1,2] + - PrefetchGlobalRead: [2] + # - PrefetchLocalRead: [0,1] + - PrefetchLocalRead: [1] + # - PreloadKernArgs: [0,1] + - PreloadKernArgs: [1] + # - ClusterLocalRead: [0,1] + - ClusterLocalRead: [1] + - VectorWidthA: [1] + - VectorWidthB: [1] + - GlobalReadVectorWidthA: [16] + - GlobalReadVectorWidthB: [16] + - ScheduleIterAlg: [3] + - InnerUnroll: [1] + - TransposeLDS: [1] + - WaveSeparateGlobalReadA: [0] + - WaveSeparateGlobalReadB: [0] + - 1LDSBuffer: [0] + - GlobalReadPerMfma: [1] + - LocalWritePerMfma: [-1] + - StoreVectorWidth: [-1] + - SourceSwap: [1] + - StreamK: [3] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Exact: [1025, 513, 1, 2048] + - Exact: [1044, 532, 1, 2048] + - Exact: [127, 127, 1, 640] #special cleanup case + - Exact: [128, 128, 1, 128] + - Exact: [129, 129, 1, 640] + - BiasTypeArgs: ['s'] + - ActivationArgs: + - [Enum: none] diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/streamk/sk_mx32f8_quick.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/streamk/sk_mx32f8_quick.yaml new file mode 100644 index 00000000000..b2965f8bc46 --- /dev/null +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/streamk/sk_mx32f8_quick.yaml @@ -0,0 +1,109 @@ +TestParameters: + marks: [skip-gfx1250, skip-gfx900, skip-gfx906, skip-gfx908, skip-gfx90a, skip-gfx940, skip-gfx941, skip-gfx942, skip-gfx1010, skip-gfx1011, skip-gfx1012, skip-gfx1030, skip-gfx1100, skip-gfx1101, skip-gfx1102, skip-gfx1200, skip-gfx1201] + +GlobalParameters: + NumElementsToValidate: -1 + MinimumRequiredVersion: 5.0.0 + PrintLevel: 1 + # PrintSolutionRejectionReason: True + Device: 0 + CMakeBuildType: Debug +# MergeFiles: False + KernelTime: True + MaxWorkspaceSize: 13421772800 + DataInitTypeA: 3 + DataInitTypeB: 3 + DataInitTypeC: 0 + DataInitTypeMXSA: 3 + DataInitTypeMXSB: 3 + DataInitTypeAlpha: 1 + DataInitTypeBeta: 0 + NumElementsToValidate: -1 + BoundsCheck: 0 + KeepBuildTmp: True + MaxFileName: 128 + DeviceLDS: 163840 + MaxLDS: 163840 + +BenchmarkProblems: + ######################################## + # FP8SS + ######################################## + - + - # ProblemType + OperationType: GEMM + DataType: F8 + DestDataType: S + ComputeDataType: S + HighPrecisionAccumulate: True + MXBlockA: 32 + MXBlockB: 32 + TransposeA: 1 + TransposeB: 0 + UseBeta: True + Batched: True + Activation: True + ActivationType: hipblaslt_all +# UseScaleAB: "Scalar" +# UseScaleCD: True + UseScaleAlphaVec: 1 + UseBias: 1 + BiasDataTypeList: [s] + + - # BenchmarkProblemSizeGroup - Standard + InitialSolutionParameters: + BenchmarkCommonParameters: + - KernelLanguage: ["Assembly"] + ForkParameters: + - MatrixInstruction: + # - [16, 16, 128, 1, 1, 1,1, 1,1] + # - [16, 16, 128, 1, 1, 4,2, 2,2] + - [16, 16, 128, 1, 1, 2,4, 2,2] # 64x128 + # - [16, 16, 128, 1, 1, 8,8, 2,2] + # - [32, 32, 64, 1, 1, 1,1, 1,1] + - [32, 32, 64, 1, 1, 2,2, 2,2] # 128x128 + # - [32, 32, 64, 1, 1, 4,2, 2,2] + # - [32, 32, 64, 1, 1, 2,4, 2,2] + - ForceDisableShadowInit: [True, False] + # - UseSgprForGRO: [0,1] + - UseSgprForGRO: [0] + # - DepthU: [64, 128] + - DepthU: [128] + - AssertFree0ElementMultiple: [1] + - AssertFree1ElementMultiple: [1] + - AssertSummationElementMultiple: [64] + - LocalReadVectorWidth: [16] + # - PrefetchGlobalRead: [0,1,2] + - PrefetchGlobalRead: [2] + # - PrefetchLocalRead: [0,1] + - PrefetchLocalRead: [1] + # - PreloadKernArgs: [0,1] + - PreloadKernArgs: [1] + # - ClusterLocalRead: [0,1] + - ClusterLocalRead: [1] + - VectorWidthA: [1] + - VectorWidthB: [1] + - GlobalReadVectorWidthA: [16] + - GlobalReadVectorWidthB: [16] + - ScheduleIterAlg: [3] + - InnerUnroll: [1] + - TransposeLDS: [1] + - WaveSeparateGlobalReadA: [0] + - WaveSeparateGlobalReadB: [0] + - 1LDSBuffer: [0] + - GlobalReadPerMfma: [1] + - LocalWritePerMfma: [-1] + - StoreVectorWidth: [-1] + - SourceSwap: [1] + - StreamK: [3] + BenchmarkJoinParameters: + BenchmarkFinalParameters: + - ProblemSizes: + - Exact: [1025, 513, 1, 2048] + - Exact: [1044, 532, 1, 2048] + - Exact: [127, 127, 1, 640] #special cleanup case + - Exact: [128, 128, 1, 128] + - Exact: [129, 129, 1, 640] + - BiasTypeArgs: ['s'] + - ActivationArgs: + - [Enum: none] diff --git a/projects/hipblaslt/tensilelite/client/CMakeLists.txt b/projects/hipblaslt/tensilelite/client/CMakeLists.txt index c4c9901e898..b3e809f3b47 100644 --- a/projects/hipblaslt/tensilelite/client/CMakeLists.txt +++ b/projects/hipblaslt/tensilelite/client/CMakeLists.txt @@ -1,17 +1,19 @@ # Copyright Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -add_executable(tensilelite-client "${CMAKE_CURRENT_SOURCE_DIR}/main.cpp") - -add_executable(cpu-gemm-driver "${CMAKE_CURRENT_SOURCE_DIR}/cpu_gemm_driver.cpp") - -set(COMMON_CLIENT_LIBS - hip::device - hip::host - tensilelite::tensilelite-host - rocisa::rocisa-cpp - OpenMP::OpenMP_CXX -) +# Shared core for tensilelite-client and cpu-gemm-driver. +# +# Both executables consume nearly the same sources, link line, include +# directories, compile definitions, C++ standard, and sanitizer flags. Owning +# them on a single STATIC helper target removes a `foreach(target ...)` loop +# that applied identical properties to two targets, a `COMMON_CLIENT_SOURCES` +# variable that hid the source list from `target_sources` (and from IDE +# navigation), and three duplicated `target_include_directories` calls. +# +# Each executable is now a thin wrapper that adds only its `main` translation +# unit and links to this helper. +add_library(tensilelite-client-common STATIC) +add_library(tensilelite::client-common ALIAS tensilelite-client-common) if(NOT WIN32) find_package(amd_smi REQUIRED) @@ -19,39 +21,70 @@ else() find_package(amd_smi) endif() +target_link_libraries(tensilelite-client-common + PUBLIC + hip::device + hip::host + tensilelite::tensilelite-host + rocisa::rocisa-cpp + OpenMP::OpenMP_CXX +) if(amd_smi_FOUND) - list(APPEND COMMON_CLIENT_LIBS amd_smi) + target_link_libraries(tensilelite-client-common PUBLIC amd_smi) endif() -foreach(target tensilelite-client cpu-gemm-driver) - target_link_libraries(${target} PRIVATE ${COMMON_CLIENT_LIBS}) +# When mxDataGenerator is enabled, hipblaslt::mxdatagen propagates the +# HIPBLASLT_ENABLE_MXDATAGENERATOR macro, the C++20 floor, the include path +# for ``, hipblaslt::headers (so `` API headers +# resolve), and the link to roc::mxDataGenerator. PUBLIC so the executables +# below pick the macro up too. +if(HIPBLASLT_ENABLE_MXDATAGENERATOR) + target_link_libraries(tensilelite-client-common PUBLIC hipblaslt::mxdatagen) +endif() - target_include_directories(${target} PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/include") - target_compile_definitions(${target} PRIVATE TENSILE_DEFAULT_SERIALIZATION) +target_include_directories(tensilelite-client-common + PUBLIC + $ +) - set_target_properties(${target} PROPERTIES - CXX_STANDARD 17 +target_compile_definitions(tensilelite-client-common + PUBLIC + TENSILE_DEFAULT_SERIALIZATION +) + +set_target_properties(tensilelite-client-common + PROPERTIES + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD 20 CXX_STANDARD_REQUIRED ON CXX_EXTENSIONS OFF - ) +) - if(HIPBLASLT_ENABLE_ASAN) - hipblaslt_target_configure_sanitizers(${target} PRIVATE) - endif() -endforeach() +if(HIPBLASLT_ENABLE_ASAN) + hipblaslt_target_configure_sanitizers(tensilelite-client-common PUBLIC) +endif() option(TENSILELITE_CLIENT_ENABLE_ROCPROFSDK "Enable in-process ROCprofiler-SDK counters" OFF) -if (TENSILELITE_CLIENT_ENABLE_ROCPROFSDK) - add_compile_definitions(TENSILELITE_CLIENT_ENABLE_ROCPROFSDK=1) - find_package(rocprofiler-sdk) +target_compile_definitions(tensilelite-client-common + PUBLIC + TENSILELITE_CLIENT_ENABLE_ROCPROFSDK=$ +) +if(TENSILELITE_CLIENT_ENABLE_ROCPROFSDK) + find_package(rocprofiler-sdk) + if(NOT rocprofiler-sdk_FOUND) + message(FATAL_ERROR "ROCprofiler-SDK not found.") + endif() +endif() - if (rocprofiler-sdk_FOUND) +add_executable(tensilelite-client "${CMAKE_CURRENT_SOURCE_DIR}/main.cpp") +target_link_libraries(tensilelite-client PRIVATE tensilelite-client-common) + +add_executable(cpu-gemm-driver "${CMAKE_CURRENT_SOURCE_DIR}/cpu_gemm_driver.cpp") +target_link_libraries(cpu-gemm-driver PRIVATE tensilelite-client-common) + +# rocprofiler-sdk is only used by tensilelite-client (in-process counters). +if(TENSILELITE_CLIENT_ENABLE_ROCPROFSDK) target_link_libraries(tensilelite-client PRIVATE rocprofiler-sdk) - else() - message(FATAL_ERROR "ROCprofiler-SDK not found.") - endif() -else() - add_compile_definitions(TENSILELITE_CLIENT_ENABLE_ROCPROFSDK=0) endif() add_subdirectory(include) diff --git a/projects/hipblaslt/tensilelite/client/include/CMakeLists.txt b/projects/hipblaslt/tensilelite/client/include/CMakeLists.txt index 0a03559eb16..7147587b357 100644 --- a/projects/hipblaslt/tensilelite/client/include/CMakeLists.txt +++ b/projects/hipblaslt/tensilelite/client/include/CMakeLists.txt @@ -1,7 +1,8 @@ # Copyright Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -set(COMMON_CLIENT_SOURCES +target_sources(tensilelite-client-common + PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/ProgramOptions.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/TypedId.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/MetaResultReporter.hpp" @@ -30,6 +31,3 @@ set(COMMON_CLIENT_SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/HardwareMonitor.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/CSVStackFile.hpp" ) - -target_sources(tensilelite-client PRIVATE ${COMMON_CLIENT_SOURCES}) -target_sources(cpu-gemm-driver PRIVATE ${COMMON_CLIENT_SOURCES}) diff --git a/projects/hipblaslt/tensilelite/client/include/DataInitialization.hpp b/projects/hipblaslt/tensilelite/client/include/DataInitialization.hpp index b07bf7ca944..e6451a71fe4 100644 --- a/projects/hipblaslt/tensilelite/client/include/DataInitialization.hpp +++ b/projects/hipblaslt/tensilelite/client/include/DataInitialization.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -43,6 +43,17 @@ namespace TensileLite { namespace Client { + inline bool isMXFP4Tensor(const TensorDescriptor& tensor, size_t mxBlock) + { + return tensor.dataType() == rocisa::DataType::Float4 && mxBlock > 0; + } + + inline bool isMXFP4Problem(const ContractionProblemGemm& problem) + { + return isMXFP4Tensor(problem.a(), problem.mxBlockA()) + || isMXFP4Tensor(problem.b(), problem.mxBlockB()); + } + // Problem-indept. from 0~7, and 16, and 23~26 (fixed values for every problem) // And problem-dept. from 8~15 (values depend on problem) // RandomNegPosLimited: integer -128~128. fp -1.0~1.0 @@ -428,14 +439,15 @@ namespace TensileLite initArray( initMode, static_cast(array), descriptor); break; +#ifndef _WIN32 #ifdef TENSILE_USE_FP6 case rocisa::DataType::Float6: - initArray(initMode, static_cast(array), descriptor); + initArray(initMode, static_cast(array), descriptor); break; #endif // #ifdef TENSILE_USE_FP6 #ifdef TENSILE_USE_BF6 case rocisa::DataType::BFloat6: - initArray(initMode, static_cast(array), descriptor); + initArray(initMode, static_cast(array), descriptor); break; #endif // #ifdef TENSILE_USE_BF6 #ifdef TENSILE_USE_FP4 @@ -443,6 +455,7 @@ namespace TensileLite initArray(initMode, static_cast(array), descriptor); break; #endif // #ifdef TENSILE_USE_FP4 +#endif // !_WIN32 case rocisa::DataType::E8: initArray(initMode, static_cast(array), descriptor); break; @@ -463,15 +476,11 @@ namespace TensileLite case rocisa::DataType::BFloat8Float8: case rocisa::DataType::Float8BFloat8_fnuz: case rocisa::DataType::BFloat8Float8_fnuz: -#ifndef TENSILE_USE_FP6 +#ifdef _WIN32 case rocisa::DataType::Float6: -#endif -#ifndef TENSILE_USE_BF6 case rocisa::DataType::BFloat6: -#endif -#ifndef TENSILE_USE_FP4 case rocisa::DataType::Float4: -#endif +#endif // _WIN32 ; } } @@ -847,9 +856,34 @@ namespace TensileLite } virtual void preBenchmarkRun() override {} virtual void postBenchmarkRun() override {} - virtual void preProblem(ContractionProblem* const problem) override {} + virtual void preProblem(ContractionProblem* const problem) override + { + m_currentGemmProblem + = dynamic_cast(problem); + } virtual void postProblem() override {} - virtual void preSolution(ContractionSolution* const solution) override {} + virtual void preSolution(ContractionSolution* const solution) override + { + m_currentSolution = solution; + // Re-init MX FP4 inputs once the solution is known (MI-based preSwizzle when enabled). + // Do not gate on useScaleAB: MX kernels may use MXSA/MXSB with empty useScaleAB. + if(m_currentSolution != nullptr && m_currentGemmProblem != nullptr + && !m_gpuPtrs.empty()) + { + bool isMXFP4 = isMXFP4Problem(*m_currentGemmProblem); + if(isMXFP4) + { + initializeMXData(*m_currentGemmProblem); + copyValidToGPUBuffer(*m_currentGemmProblem); + copyInputs(m_gpuPtrs, + m_gpuBatchPtrs, + m_maxElements, + m_groupedOffsets, + *m_currentGemmProblem, + hipMemcpyDeviceToDevice); + } + } + } virtual void postSolution() override {} virtual bool needMoreRunsInSolution() const override { @@ -966,6 +1000,8 @@ namespace TensileLite void initializeConstantInputs(ContractionProblemGemm const& problem); + void initializeMXData(ContractionProblemGemm const& problem); + void copyInputs(std::vector& ptrs, std::vector& batchPtrs, std::vector& maxElements, @@ -1059,6 +1095,11 @@ namespace TensileLite int64_t m_rotatingBuffer = 0; std::shared_ptr m_rm; int32_t m_rotatingMode = 0; + + ContractionSolution const* m_currentSolution = nullptr; + ContractionProblemGemm const* m_currentGemmProblem = nullptr; + + int m_mxScaleFormat = 0; }; template <> @@ -2153,54 +2194,55 @@ namespace TensileLite { return std::numeric_limits::min(); } +#ifndef _WIN32 #ifdef TENSILE_USE_FP6 template <> - inline Float6x16 DataInitialization::getValue() + inline Float6x32 DataInitialization::getValue() { - return Float6x16(0.0f); + return Float6x32(0.0f); } template <> - inline Float6x16 DataInitialization::getValue() + inline Float6x32 DataInitialization::getValue() { - return Float6x16(1.0f); + return Float6x32(1.0f); } template <> - inline Float6x16 DataInitialization::getValue() + inline Float6x32 DataInitialization::getValue() { - return Float6x16(2.0f); + return Float6x32(2.0f); } template <> - inline Float6x16 DataInitialization::getValue() + inline Float6x32 DataInitialization::getValue() { - return Float6x16(-1.0f); + return Float6x32(-1.0f); } template <> - inline Float6x16 DataInitialization::getValue() + inline Float6x32 DataInitialization::getValue() { - return Float6x16(7.5f); + return Float6x32(7.5f); } template <> - inline Float6x16 DataInitialization::getValue() + inline Float6x32 DataInitialization::getValue() { - return Float6x16(0.125f); + return Float6x32(0.125f); } template <> - inline Float6x16 DataInitialization::getValue() + inline Float6x32 DataInitialization::getValue() { - return Float6x16(0.875f); + return Float6x32(0.875f); } template <> - inline Float6x16 DataInitialization::getValue() + inline Float6x32 DataInitialization::getValue() { throw std::runtime_error("NaN not available for float6."); } template <> - inline Float6x16 DataInitialization::getValue() + inline Float6x32 DataInitialization::getValue() { throw std::runtime_error("Inf not available for float6."); } template <> - inline Float6x16 DataInitialization::getValue() + inline Float6x32 DataInitialization::getValue() { float v0 = static_cast((rand() % 7) - 3); float v1 = static_cast((rand() % 7) - 3); @@ -2218,71 +2260,90 @@ namespace TensileLite float v13 = static_cast((rand() % 7) - 3); float v14 = static_cast((rand() % 7) - 3); float v15 = static_cast((rand() % 7) - 3); - - return Float6x16(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15); - } - template <> - inline Float6x16 DataInitialization::getValue() + float v16 = static_cast((rand() % 7) - 3); + float v17 = static_cast((rand() % 7) - 3); + float v18 = static_cast((rand() % 7) - 3); + float v19 = static_cast((rand() % 7) - 3); + float v20 = static_cast((rand() % 7) - 3); + float v21 = static_cast((rand() % 7) - 3); + float v22 = static_cast((rand() % 7) - 3); + float v23 = static_cast((rand() % 7) - 3); + float v24 = static_cast((rand() % 7) - 3); + float v25 = static_cast((rand() % 7) - 3); + float v26 = static_cast((rand() % 7) - 3); + float v27 = static_cast((rand() % 7) - 3); + float v28 = static_cast((rand() % 7) - 3); + float v29 = static_cast((rand() % 7) - 3); + float v30 = static_cast((rand() % 7) - 3); + float v31 = static_cast((rand() % 7) - 3); + + return Float6x32(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, + v30, v31); + } + template <> + inline Float6x32 DataInitialization::getValue() { - throw std::runtime_error("BadInput not available for float6."); + throw std::runtime_error("BadInput not available for float4."); } template <> - inline Float6x16 DataInitialization::getValue() + inline Float6x32 DataInitialization::getValue() { - throw std::runtime_error("BadOutput not available for float6."); + throw std::runtime_error("BadOutput not available for float4."); } #endif // #ifdef TENSILE_USE_FP6 #ifdef TENSILE_USE_BF6 template <> - inline BFloat6x16 DataInitialization::getValue() + inline BFloat6x32 DataInitialization::getValue() { - return BFloat6x16(0.0f); + return BFloat6x32(0.0f); } template <> - inline BFloat6x16 DataInitialization::getValue() + inline BFloat6x32 DataInitialization::getValue() { - return BFloat6x16(1.0f); + return BFloat6x32(1.0f); } template <> - inline BFloat6x16 DataInitialization::getValue() + inline BFloat6x32 DataInitialization::getValue() { - return BFloat6x16(2.0f); + return BFloat6x32(2.0f); } template <> - inline BFloat6x16 DataInitialization::getValue() + inline BFloat6x32 DataInitialization::getValue() { - return BFloat6x16(-1.0f); + return BFloat6x32(-1.0f); } template <> - inline BFloat6x16 DataInitialization::getValue() + inline BFloat6x32 DataInitialization::getValue() { - return BFloat6x16(7.5f); + return BFloat6x32(7.5f); } template <> - inline BFloat6x16 DataInitialization::getValue() + inline BFloat6x32 DataInitialization::getValue() { - return BFloat6x16(0.125f); + return BFloat6x32(0.125f); } template <> - inline BFloat6x16 DataInitialization::getValue() + inline BFloat6x32 DataInitialization::getValue() { - return BFloat6x16(0.875f); + return BFloat6x32(0.875f); } template <> - inline BFloat6x16 DataInitialization::getValue() + inline BFloat6x32 DataInitialization::getValue() { throw std::runtime_error("NaN not available for float6."); } template <> - inline BFloat6x16 DataInitialization::getValue() + inline BFloat6x32 DataInitialization::getValue() { throw std::runtime_error("Inf not available for float6."); } template <> - inline BFloat6x16 DataInitialization::getValue() + inline BFloat6x32 DataInitialization::getValue() { - BFloat6x16 ret; + BFloat6x32 ret; float v0 = static_cast((rand() % 7) - 3); float v1 = static_cast((rand() % 7) - 3); @@ -2300,28 +2361,49 @@ namespace TensileLite float v13 = static_cast((rand() % 7) - 3); float v14 = static_cast((rand() % 7) - 3); float v15 = static_cast((rand() % 7) - 3); - - return BFloat6x16(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15); - } - template <> - inline BFloat6x16 DataInitialization::getValue() + float v16 = static_cast((rand() % 7) - 3); + float v17 = static_cast((rand() % 7) - 3); + float v18 = static_cast((rand() % 7) - 3); + float v19 = static_cast((rand() % 7) - 3); + float v20 = static_cast((rand() % 7) - 3); + float v21 = static_cast((rand() % 7) - 3); + float v22 = static_cast((rand() % 7) - 3); + float v23 = static_cast((rand() % 7) - 3); + float v24 = static_cast((rand() % 7) - 3); + float v25 = static_cast((rand() % 7) - 3); + float v26 = static_cast((rand() % 7) - 3); + float v27 = static_cast((rand() % 7) - 3); + float v28 = static_cast((rand() % 7) - 3); + float v29 = static_cast((rand() % 7) - 3); + float v30 = static_cast((rand() % 7) - 3); + float v31 = static_cast((rand() % 7) - 3); + + return BFloat6x32(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, + v30, v31); + + } + template <> + inline BFloat6x32 DataInitialization::getValue() { - throw std::runtime_error("BadInput not available for bfloat6."); + throw std::runtime_error("BadInput not available for float4."); } template <> - inline BFloat6x16 DataInitialization::getValue() + inline BFloat6x32 DataInitialization::getValue() { - throw std::runtime_error("BadOutput not available for bfloat6."); + throw std::runtime_error("BadOutput not available for float4."); } #endif // #ifdef TENSILE_USE_BF6 + #ifdef TENSILE_USE_FP4 template <> inline Float4x2 DataInitialization::getValue() { union { - uint8_t bits; + uint8_t bits; Float4x2 value; } x; @@ -2333,7 +2415,7 @@ namespace TensileLite { union { - uint8_t bits; + uint8_t bits; Float4x2 value; } x; @@ -2345,7 +2427,7 @@ namespace TensileLite { union { - uint8_t bits; + uint8_t bits; Float4x2 value; } x; @@ -2357,7 +2439,7 @@ namespace TensileLite { union { - uint8_t bits; + uint8_t bits; Float4x2 value; } x; @@ -2369,7 +2451,7 @@ namespace TensileLite { union { - uint8_t bits; + uint8_t bits; Float4x2 value; } x; @@ -2381,7 +2463,7 @@ namespace TensileLite { union { - uint8_t bits; + uint8_t bits; Float4x2 value; } x; @@ -2393,7 +2475,7 @@ namespace TensileLite { union { - uint8_t bits; + uint8_t bits; Float4x2 value; } x; @@ -2415,13 +2497,13 @@ namespace TensileLite { union { - uint8_t bits; + uint8_t bits; Float4x2 value; } x; uint8_t val0 = static_cast(rand() % 15); uint8_t val1 = static_cast(rand() % 15); - x.bits = (val1 << 4) | val0; + x.bits = (val1 << 4) | val0; return x.value; } template <> @@ -2435,6 +2517,7 @@ namespace TensileLite throw std::runtime_error("BadOutput not available for float4."); } #endif // #ifdef TENSILE_USE_FP4 +#endif // !_WIN32 template <> inline E8 DataInitialization::getValue() @@ -2642,9 +2725,10 @@ namespace TensileLite return value == DataInitialization::getValue(); } +#ifndef _WIN32 #ifdef TENSILE_USE_FP6 template <> - inline bool DataInitialization::isBadInput(Float6x16 value) + inline bool DataInitialization::isBadInput(Float6x32 value) { return false; } @@ -2652,7 +2736,7 @@ namespace TensileLite #ifdef TENSILE_USE_BF6 template <> - inline bool DataInitialization::isBadInput(BFloat6x16 value) + inline bool DataInitialization::isBadInput(BFloat6x32 value) { return false; } @@ -2665,6 +2749,7 @@ namespace TensileLite return false; } #endif // #ifdef TENSILE_USE_FP4 +#endif // !_WIN32 template <> inline bool DataInitialization::isBadInput(E8 value) @@ -2757,9 +2842,10 @@ namespace TensileLite return value == DataInitialization::getValue(); } +#ifndef _WIN32 #ifdef TENSILE_USE_FP6 template <> - inline bool DataInitialization::isBadOutput(Float6x16 value) + inline bool DataInitialization::isBadOutput(Float6x32 value) { return false; } @@ -2767,7 +2853,7 @@ namespace TensileLite #ifdef TENSILE_USE_BF6 template <> - inline bool DataInitialization::isBadOutput(BFloat6x16 value) + inline bool DataInitialization::isBadOutput(BFloat6x32 value) { return false; } @@ -2780,6 +2866,7 @@ namespace TensileLite return false; } #endif // #ifdef TENSILE_USE_FP4 +#endif // !_WIN32 template <> inline bool DataInitialization::isBadOutput(E8 value) @@ -2798,7 +2885,7 @@ namespace TensileLite { float val = useCos ? cos(idx) : sin(idx); if(useAbs) - val = abs(val); + val = std::fabs(val); return val; } @@ -2807,7 +2894,7 @@ namespace TensileLite { double val = useCos ? cos(idx) : sin(idx); if(useAbs) - val = abs(val); + val = std::fabs(val); return val; } @@ -2868,10 +2955,10 @@ namespace TensileLite throw std::runtime_error("Trig not available for Int8."); } +#ifndef _WIN32 #ifdef TENSILE_USE_FP6 template <> - inline Float6x16 - DataInitialization::getTrigValue(int idx, bool useCos, bool useAbs) + inline Float6x32 DataInitialization::getTrigValue(int idx, bool useCos, bool useAbs) { float v0 = getTrigValue(idx, useCos, useAbs); float v1 = getTrigValue(idx, useCos, useAbs); @@ -2889,15 +2976,33 @@ namespace TensileLite float v13 = getTrigValue(idx, useCos, useAbs); float v14 = getTrigValue(idx, useCos, useAbs); float v15 = getTrigValue(idx, useCos, useAbs); - - return Float6x16(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15); + float v16 = getTrigValue(idx, useCos, useAbs); + float v17 = getTrigValue(idx, useCos, useAbs); + float v18 = getTrigValue(idx, useCos, useAbs); + float v19 = getTrigValue(idx, useCos, useAbs); + float v20 = getTrigValue(idx, useCos, useAbs); + float v21 = getTrigValue(idx, useCos, useAbs); + float v22 = getTrigValue(idx, useCos, useAbs); + float v23 = getTrigValue(idx, useCos, useAbs); + float v24 = getTrigValue(idx, useCos, useAbs); + float v25 = getTrigValue(idx, useCos, useAbs); + float v26 = getTrigValue(idx, useCos, useAbs); + float v27 = getTrigValue(idx, useCos, useAbs); + float v28 = getTrigValue(idx, useCos, useAbs); + float v29 = getTrigValue(idx, useCos, useAbs); + float v30 = getTrigValue(idx, useCos, useAbs); + float v31 = getTrigValue(idx, useCos, useAbs); + + return Float6x32(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, + v30, v31); } #endif // #ifdef TENSILE_USE_FP6 #ifdef TENSILE_USE_BF6 template <> - inline BFloat6x16 - DataInitialization::getTrigValue(int idx, bool useCos, bool useAbs) + inline BFloat6x32 DataInitialization::getTrigValue(int idx, bool useCos, bool useAbs) { float v0 = getTrigValue(idx, useCos, useAbs); float v1 = getTrigValue(idx, useCos, useAbs); @@ -2915,21 +3020,40 @@ namespace TensileLite float v13 = getTrigValue(idx, useCos, useAbs); float v14 = getTrigValue(idx, useCos, useAbs); float v15 = getTrigValue(idx, useCos, useAbs); - - return BFloat6x16(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15); + float v16 = getTrigValue(idx, useCos, useAbs); + float v17 = getTrigValue(idx, useCos, useAbs); + float v18 = getTrigValue(idx, useCos, useAbs); + float v19 = getTrigValue(idx, useCos, useAbs); + float v20 = getTrigValue(idx, useCos, useAbs); + float v21 = getTrigValue(idx, useCos, useAbs); + float v22 = getTrigValue(idx, useCos, useAbs); + float v23 = getTrigValue(idx, useCos, useAbs); + float v24 = getTrigValue(idx, useCos, useAbs); + float v25 = getTrigValue(idx, useCos, useAbs); + float v26 = getTrigValue(idx, useCos, useAbs); + float v27 = getTrigValue(idx, useCos, useAbs); + float v28 = getTrigValue(idx, useCos, useAbs); + float v29 = getTrigValue(idx, useCos, useAbs); + float v30 = getTrigValue(idx, useCos, useAbs); + float v31 = getTrigValue(idx, useCos, useAbs); + + return BFloat6x32(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, + v30, v31); } #endif // #ifdef TENSILE_USE_BF6 #ifdef TENSILE_USE_FP4 template <> - inline Float4x2 - DataInitialization::getTrigValue(int idx, bool useCos, bool useAbs) + inline Float4x2 DataInitialization::getTrigValue(int idx, bool useCos, bool useAbs) { float val0 = getTrigValue(idx, useCos, useAbs); float val1 = getTrigValue(idx, useCos, useAbs); return Float4x2(val0, val1); } #endif // #ifdef TENSILE_USE_FP4 +#endif // !_WIN32 template <> inline E8 DataInitialization::getTrigValue(int idx, bool useCos, bool useAbs) @@ -3236,9 +3360,10 @@ namespace TensileLite return getValue(); } +#ifndef _WIN32 #ifdef TENSILE_USE_FP6 template <> - inline Float6x16 DataInitialization::getValue() + inline Float6x32 DataInitialization::getValue() { float v0 = rocm_random_narrow_range{}(); float v1 = rocm_random_narrow_range{}(); @@ -3256,14 +3381,33 @@ namespace TensileLite float v13 = rocm_random_narrow_range{}(); float v14 = rocm_random_narrow_range{}(); float v15 = rocm_random_narrow_range{}(); - - return Float6x16(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15); + float v16 = rocm_random_narrow_range{}(); + float v17 = rocm_random_narrow_range{}(); + float v18 = rocm_random_narrow_range{}(); + float v19 = rocm_random_narrow_range{}(); + float v20 = rocm_random_narrow_range{}(); + float v21 = rocm_random_narrow_range{}(); + float v22 = rocm_random_narrow_range{}(); + float v23 = rocm_random_narrow_range{}(); + float v24 = rocm_random_narrow_range{}(); + float v25 = rocm_random_narrow_range{}(); + float v26 = rocm_random_narrow_range{}(); + float v27 = rocm_random_narrow_range{}(); + float v28 = rocm_random_narrow_range{}(); + float v29 = rocm_random_narrow_range{}(); + float v30 = rocm_random_narrow_range{}(); + float v31 = rocm_random_narrow_range{}(); + + return Float6x32(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, + v30, v31); } #endif // #ifdef TENSILE_USE_FP6 #ifdef TENSILE_USE_BF6 template <> - inline BFloat6x16 DataInitialization::getValue() + inline BFloat6x32 DataInitialization::getValue() { float v0 = rocm_random_narrow_range{}(); float v1 = rocm_random_narrow_range{}(); @@ -3281,8 +3425,27 @@ namespace TensileLite float v13 = rocm_random_narrow_range{}(); float v14 = rocm_random_narrow_range{}(); float v15 = rocm_random_narrow_range{}(); - - return BFloat6x16(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15); + float v16 = rocm_random_narrow_range{}(); + float v17 = rocm_random_narrow_range{}(); + float v18 = rocm_random_narrow_range{}(); + float v19 = rocm_random_narrow_range{}(); + float v20 = rocm_random_narrow_range{}(); + float v21 = rocm_random_narrow_range{}(); + float v22 = rocm_random_narrow_range{}(); + float v23 = rocm_random_narrow_range{}(); + float v24 = rocm_random_narrow_range{}(); + float v25 = rocm_random_narrow_range{}(); + float v26 = rocm_random_narrow_range{}(); + float v27 = rocm_random_narrow_range{}(); + float v28 = rocm_random_narrow_range{}(); + float v29 = rocm_random_narrow_range{}(); + float v30 = rocm_random_narrow_range{}(); + float v31 = rocm_random_narrow_range{}(); + + return BFloat6x32(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, + v30, v31); } #endif // #ifdef TENSILE_USE_BF6 @@ -3290,10 +3453,10 @@ namespace TensileLite template <> inline Float4x2 DataInitialization::getValue() { - return Float4x2(rocm_random_narrow_range{}(), - rocm_random_narrow_range{}()); + return getValue(); } #endif // #ifdef TENSILE_USE_FP4 +#endif // !_WIN32 template <> inline E8 DataInitialization::getValue() @@ -3412,9 +3575,10 @@ namespace TensileLite return getValueWithUpperLowerBoundInteger(); } +#ifndef _WIN32 #ifdef TENSILE_USE_FP6 template <> - inline Float6x16 DataInitialization::getValue() + inline Float6x32 DataInitialization::getValue() { float v0 = getValueWithUpperLowerBoundFP(); float v1 = getValueWithUpperLowerBoundFP(); @@ -3432,14 +3596,33 @@ namespace TensileLite float v13 = getValueWithUpperLowerBoundFP(); float v14 = getValueWithUpperLowerBoundFP(); float v15 = getValueWithUpperLowerBoundFP(); - - return Float6x16(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15); + float v16 = getValueWithUpperLowerBoundFP(); + float v17 = getValueWithUpperLowerBoundFP(); + float v18 = getValueWithUpperLowerBoundFP(); + float v19 = getValueWithUpperLowerBoundFP(); + float v20 = getValueWithUpperLowerBoundFP(); + float v21 = getValueWithUpperLowerBoundFP(); + float v22 = getValueWithUpperLowerBoundFP(); + float v23 = getValueWithUpperLowerBoundFP(); + float v24 = getValueWithUpperLowerBoundFP(); + float v25 = getValueWithUpperLowerBoundFP(); + float v26 = getValueWithUpperLowerBoundFP(); + float v27 = getValueWithUpperLowerBoundFP(); + float v28 = getValueWithUpperLowerBoundFP(); + float v29 = getValueWithUpperLowerBoundFP(); + float v30 = getValueWithUpperLowerBoundFP(); + float v31 = getValueWithUpperLowerBoundFP(); + + return Float6x32(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, + v30, v31); } #endif // #ifdef TENSILE_USE_FP6 #ifdef TENSILE_USE_BF6 template <> - inline BFloat6x16 DataInitialization::getValue() + inline BFloat6x32 DataInitialization::getValue() { float v0 = getValueWithUpperLowerBoundFP(); float v1 = getValueWithUpperLowerBoundFP(); @@ -3457,11 +3640,31 @@ namespace TensileLite float v13 = getValueWithUpperLowerBoundFP(); float v14 = getValueWithUpperLowerBoundFP(); float v15 = getValueWithUpperLowerBoundFP(); - - return BFloat6x16(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15); + float v16 = getValueWithUpperLowerBoundFP(); + float v17 = getValueWithUpperLowerBoundFP(); + float v18 = getValueWithUpperLowerBoundFP(); + float v19 = getValueWithUpperLowerBoundFP(); + float v20 = getValueWithUpperLowerBoundFP(); + float v21 = getValueWithUpperLowerBoundFP(); + float v22 = getValueWithUpperLowerBoundFP(); + float v23 = getValueWithUpperLowerBoundFP(); + float v24 = getValueWithUpperLowerBoundFP(); + float v25 = getValueWithUpperLowerBoundFP(); + float v26 = getValueWithUpperLowerBoundFP(); + float v27 = getValueWithUpperLowerBoundFP(); + float v28 = getValueWithUpperLowerBoundFP(); + float v29 = getValueWithUpperLowerBoundFP(); + float v30 = getValueWithUpperLowerBoundFP(); + float v31 = getValueWithUpperLowerBoundFP(); + + return BFloat6x32(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, + v30, v31); } #endif // #ifdef TENSILE_USE_BF6 + #ifdef TENSILE_USE_FP4 template <> inline Float4x2 DataInitialization::getValue() @@ -3470,6 +3673,7 @@ namespace TensileLite getValueWithUpperLowerBoundFP()); } #endif // #ifdef TENSILE_USE_FP4 +#endif // !_WIN32 template <> inline E8 DataInitialization::getValue() @@ -3564,9 +3768,10 @@ namespace TensileLite return static_cast(i); } +#ifndef _WIN32 #ifdef TENSILE_USE_FP6 template <> - inline Float6x16 DataInitialization::ConvertTo(size_t i) + inline Float6x32 DataInitialization::ConvertTo(size_t i) { float v0 = float(i); float v1 = float(i); @@ -3584,14 +3789,33 @@ namespace TensileLite float v13 = float(i); float v14 = float(i); float v15 = float(i); - - return Float6x16(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15); + float v16 = float(i); + float v17 = float(i); + float v18 = float(i); + float v19 = float(i); + float v20 = float(i); + float v21 = float(i); + float v22 = float(i); + float v23 = float(i); + float v24 = float(i); + float v25 = float(i); + float v26 = float(i); + float v27 = float(i); + float v28 = float(i); + float v29 = float(i); + float v30 = float(i); + float v31 = float(i); + + return Float6x32(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, + v30, v31); } #endif // #ifdef TENSILE_USE_FP6 #ifdef TENSILE_USE_BF6 template <> - inline BFloat6x16 DataInitialization::ConvertTo(size_t i) + inline BFloat6x32 DataInitialization::ConvertTo(size_t i) { float v0 = float(i); float v1 = float(i); @@ -3609,8 +3833,27 @@ namespace TensileLite float v13 = float(i); float v14 = float(i); float v15 = float(i); - - return BFloat6x16(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15); + float v16 = float(i); + float v17 = float(i); + float v18 = float(i); + float v19 = float(i); + float v20 = float(i); + float v21 = float(i); + float v22 = float(i); + float v23 = float(i); + float v24 = float(i); + float v25 = float(i); + float v26 = float(i); + float v27 = float(i); + float v28 = float(i); + float v29 = float(i); + float v30 = float(i); + float v31 = float(i); + + return BFloat6x32(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, + v30, v31); } #endif // #ifdef TENSILE_USE_BF6 @@ -3621,6 +3864,7 @@ namespace TensileLite return Float4x2(float(i), float(i)); } #endif // #ifdef TENSILE_USE_FP4 +#endif // !_WIN32 template <> inline E8 DataInitialization::ConvertTo(size_t i) @@ -3713,9 +3957,10 @@ namespace TensileLite { return static_cast(value); } +#ifndef _WIN32 #ifdef TENSILE_USE_FP6 template <> - inline Float6x16 DataInitialization::convertDoubleTo(double value) + inline Float6x32 DataInitialization::convertDoubleTo(double value) { float v0 = float(value); float v1 = float(value); @@ -3733,14 +3978,33 @@ namespace TensileLite float v13 = float(value); float v14 = float(value); float v15 = float(value); - - return Float6x16(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15); + float v16 = float(value); + float v17 = float(value); + float v18 = float(value); + float v19 = float(value); + float v20 = float(value); + float v21 = float(value); + float v22 = float(value); + float v23 = float(value); + float v24 = float(value); + float v25 = float(value); + float v26 = float(value); + float v27 = float(value); + float v28 = float(value); + float v29 = float(value); + float v30 = float(value); + float v31 = float(value); + + return Float6x32(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, + v30, v31); } #endif // #ifdef TENSILE_USE_FP6 #ifdef TENSILE_USE_BF6 template <> - inline BFloat6x16 DataInitialization::convertDoubleTo(double value) + inline BFloat6x32 DataInitialization::convertDoubleTo(double value) { float v0 = float(value); float v1 = float(value); @@ -3758,8 +4022,27 @@ namespace TensileLite float v13 = float(value); float v14 = float(value); float v15 = float(value); - - return BFloat6x16(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15); + float v16 = float(value); + float v17 = float(value); + float v18 = float(value); + float v19 = float(value); + float v20 = float(value); + float v21 = float(value); + float v22 = float(value); + float v23 = float(value); + float v24 = float(value); + float v25 = float(value); + float v26 = float(value); + float v27 = float(value); + float v28 = float(value); + float v29 = float(value); + float v30 = float(value); + float v31 = float(value); + + return BFloat6x32(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, + v30, v31); } #endif // #ifdef TENSILE_USE_BF6 @@ -3770,6 +4053,7 @@ namespace TensileLite return Float4x2(float(value), float(value)); } #endif // #ifdef TENSILE_USE_FP4 +#endif // !_WIN32 template <> inline E8 DataInitialization::convertDoubleTo(double value) diff --git a/projects/hipblaslt/tensilelite/client/include/LogReporter.hpp b/projects/hipblaslt/tensilelite/client/include/LogReporter.hpp index e4c5f604db6..f4766249778 100644 --- a/projects/hipblaslt/tensilelite/client/include/LogReporter.hpp +++ b/projects/hipblaslt/tensilelite/client/include/LogReporter.hpp @@ -336,6 +336,32 @@ namespace TensileLite reinterpret_cast(data), tensor, reinterpret_cast(ptrVal)); +#ifndef _WIN32 +#ifdef TENSILE_USE_FP6 + else if(tensor.dataType() == rocisa::DataType::Float6) + logTensorTyped(level, + name, + reinterpret_cast(data), + tensor, + reinterpret_cast(ptrVal)); +#endif // #ifdef TENSILE_USE_FP6 +#ifdef TENSILE_USE_BF6 + else if(tensor.dataType() == rocisa::DataType::BFloat6) + logTensorTyped(level, + name, + reinterpret_cast(data), + tensor, + reinterpret_cast(ptrVal)); +#endif // #ifdef TENSILE_USE_BF6 +#ifdef TENSILE_USE_FP4 + else if(tensor.dataType() == rocisa::DataType::Float4) + logTensorTyped(level, + name, + reinterpret_cast(data), + tensor, + reinterpret_cast(ptrVal)); +#endif // #ifdef TENSILE_USE_FP4 +#endif // !_WIN32 else if(tensor.dataType() == rocisa::DataType::ComplexFloat) logTensorTyped(level, name, @@ -352,17 +378,17 @@ namespace TensileLite else if(tensor.dataType() == rocisa::DataType::Float6) logTensorTyped(level, name, - reinterpret_cast(data), + reinterpret_cast(data), tensor, - reinterpret_cast(ptrVal)); + reinterpret_cast(ptrVal)); #endif // #ifdef TENSILE_USE_FP6 #ifdef TENSILE_USE_BF6 else if(tensor.dataType() == rocisa::DataType::BFloat6) logTensorTyped(level, name, - reinterpret_cast(data), + reinterpret_cast(data), tensor, - reinterpret_cast(ptrVal)); + reinterpret_cast(ptrVal)); #endif // #ifdef TENSILE_USE_BF6 #ifdef TENSILE_USE_FP4 else if(tensor.dataType() == rocisa::DataType::Float4) diff --git a/projects/hipblaslt/tensilelite/client/include/ProgramOptions.hpp b/projects/hipblaslt/tensilelite/client/include/ProgramOptions.hpp index ef481d4e7a6..fb4ac7c5f89 100644 --- a/projects/hipblaslt/tensilelite/client/include/ProgramOptions.hpp +++ b/projects/hipblaslt/tensilelite/client/include/ProgramOptions.hpp @@ -89,10 +89,6 @@ namespace TensileLite public: variable_value() = default; - explicit variable_value(std::any v) - : m_value(std::move(v)) - { - } bool empty() const { diff --git a/projects/hipblaslt/tensilelite/client/include/TypedId.hpp b/projects/hipblaslt/tensilelite/client/include/TypedId.hpp index 2fd1d8c2509..2f4f32498ad 100644 --- a/projects/hipblaslt/tensilelite/client/include/TypedId.hpp +++ b/projects/hipblaslt/tensilelite/client/include/TypedId.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2023-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2023-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -131,14 +131,16 @@ namespace TensileLite // hybrid using TypedGemm_F8B8_F8_S = TypedGemm; - using TypedGemm_F8B8_H_S = TypedGemm; + using TypedGemm_F8B8_H_S + = TypedGemm; using TypedGemm_F8B8_B_S = TypedGemm; using TypedGemm_F8B8_S_S = TypedGemm; using TypedGemm_B8F8_F8_S = TypedGemm; - using TypedGemm_B8F8_H_S = TypedGemm; + using TypedGemm_B8F8_H_S + = TypedGemm; using TypedGemm_B8F8_B_S = TypedGemm; using TypedGemm_B8F8_S_S @@ -167,26 +169,78 @@ namespace TensileLite using TypedGemm_B8N_B_S = TypedGemm; // hybrid - using TypedGemm_F8B8N_F8N_S - = TypedGemm; - using TypedGemm_F8B8N_H_S = TypedGemm; - using TypedGemm_F8B8N_B_S - = TypedGemm; - using TypedGemm_F8B8N_S_S - = TypedGemm; - using TypedGemm_B8F8N_F8N_S - = TypedGemm; - using TypedGemm_B8F8N_H_S = TypedGemm; - using TypedGemm_B8F8N_B_S = TypedGemm; - using TypedGemm_B8F8N_S_S - = TypedGemm; - using TypedGemm_F8B8N_B8N_S - = TypedGemm; - using TypedGemm_B8F8N_B8N_S - = TypedGemm; - using TypedGemm_H_F8B8N_H_S = TypedGemm; - using TypedGemm_H_B8F8N_H_S = TypedGemm; - + using TypedGemm_F8B8N_F8N_S = TypedGemm; + using TypedGemm_F8B8N_H_S + = TypedGemm; + using TypedGemm_F8B8N_B_S = TypedGemm; + using TypedGemm_F8B8N_S_S = TypedGemm; + using TypedGemm_B8F8N_F8N_S = TypedGemm; + using TypedGemm_B8F8N_H_S + = TypedGemm; + using TypedGemm_B8F8N_B_S = TypedGemm; + using TypedGemm_B8F8N_S_S = TypedGemm; + using TypedGemm_F8B8N_B8N_S = TypedGemm; + using TypedGemm_B8F8N_B8N_S = TypedGemm; + using TypedGemm_H_F8B8N_H_S + = TypedGemm; + using TypedGemm_H_B8F8N_H_S + = TypedGemm; #ifdef TENSILE_USE_HALF // Mix precision: OCPFP8 @@ -231,26 +285,27 @@ namespace TensileLite #endif // TENSILE_USE_HALF #endif // TENSILE_USE_FP8_BF8 +#ifndef _WIN32 #ifdef TENSILE_USE_FP6 - using TypedGemm_F6_S_S = TypedGemm; + using TypedGemm_F6_S_S = TypedGemm; #endif // TENSILE_USE_FP6 #ifdef TENSILE_USE_BF6 - using TypedGemm_BF6_S_S = TypedGemm; + using TypedGemm_BF6_S_S = TypedGemm; #endif // TENSILE_USE_BF6 #if defined(TENSILE_USE_FP6) && defined(TENSILE_USE_BF6) - using TypedGemm_F6B6_S_S = TypedGemm; - using TypedGemm_B6F6_S_S = TypedGemm; + using TypedGemm_F6B6_S_S = TypedGemm; + using TypedGemm_B6F6_S_S = TypedGemm; #endif // defined(TENSILE_USE_FP6) && defined(TENSILE_USE_BF6) #ifdef TENSILE_USE_FP4 using TypedGemm_F4_S_S = TypedGemm; #endif // TENSILE_USE_FP4 #if defined(TENSILE_USE_FP4) && defined(TENSILE_USE_FP6) - using TypedGemm_F4F6_S_S = TypedGemm; - using TypedGemm_F6F4_S_S = TypedGemm; + using TypedGemm_F4F6_S_S = TypedGemm; + using TypedGemm_F6F4_S_S = TypedGemm; #endif // defined(TENSILE_USE_FP4) && defined(TENSILE_USE_FP6) #if defined(TENSILE_USE_FP4) && defined(TENSILE_USE_BF6) - using TypedGemm_F4B6_S_S = TypedGemm; - using TypedGemm_B6F4_S_S = TypedGemm; + using TypedGemm_F4B6_S_S = TypedGemm; + using TypedGemm_B6F4_S_S = TypedGemm; #endif // defined(TENSILE_USE_FP4) && defined(TENSILE_USE_BF6) #if defined(TENSILE_USE_FP8_BF8) && defined(TENSILE_USE_FP4) // DestDataType: S @@ -284,15 +339,16 @@ namespace TensileLite using TypedGemm_F4B8_B_S = TypedGemm; #endif // defined(TENSILE_USE_FP8_BF8) && defined(TENSILE_USE_FP4) && defined(TENSILE_USE_BF16) #if defined(TENSILE_USE_FP6) && defined(TENSILE_USE_FP8_BF8) - using TypedGemm_F8F6_S_S = TypedGemm; - using TypedGemm_F6F8_S_S = TypedGemm; - using TypedGemm_B8F6_S_S = TypedGemm; - using TypedGemm_F6B8_S_S = TypedGemm; + using TypedGemm_F8F6_S_S = TypedGemm; + using TypedGemm_F6F8_S_S = TypedGemm; + using TypedGemm_B8F6_S_S = TypedGemm; + using TypedGemm_F6B8_S_S = TypedGemm; #endif // defined(TENSILE_USE_FP6) && defined(TENSILE_USE_FP8_BF8) #if defined(TENSILE_USE_BF6) && defined(TENSILE_USE_FP8_BF8) - using TypedGemm_F8B6_S_S = TypedGemm; - using TypedGemm_B6F8_S_S = TypedGemm; - using TypedGemm_B8B6_S_S = TypedGemm; - using TypedGemm_B6B8_S_S = TypedGemm; + using TypedGemm_F8B6_S_S = TypedGemm; + using TypedGemm_B6F8_S_S = TypedGemm; + using TypedGemm_B8B6_S_S = TypedGemm; + using TypedGemm_B6B8_S_S = TypedGemm; #endif // defined(TENSILE_USE_BF6) && defined(TENSILE_USE_FP8_BF8) +#endif // !_WIN32 } // namespace TensileLite diff --git a/projects/hipblaslt/tensilelite/client/main.cpp b/projects/hipblaslt/tensilelite/client/main.cpp index a986f7e0506..e5da637905e 100644 --- a/projects/hipblaslt/tensilelite/client/main.cpp +++ b/projects/hipblaslt/tensilelite/client/main.cpp @@ -61,6 +61,7 @@ #include "Profiler.hpp" #endif +#include #include #include #include @@ -219,6 +220,7 @@ namespace TensileLite ("mx-b-type", po::value()->default_value(rocisa::DataType::E8), "type of mx datatype input matrix B") ("swizzle-tensor-a", po::value()->default_value(false), "Swizzle input tensor A.") ("swizzle-tensor-b", po::value()->default_value(false), "Swizzle input tensor B.") + ("mx-scale-format", po::value()->default_value(0), "MX scale data format (0=none, 1=pre-swizzle for GPU kernel layout)") ("activation-compute-type", po::value()->default_value(rocisa::DataType::None), "Activation compute type.") ("high-precision-accumulate", po::value()->default_value(false), "Use high-precision accumulate.") ("sparse", po::value()->default_value(0), "A or B matrix is sparse matrix.") @@ -1079,7 +1081,7 @@ int main(int argc, const char* argv[]) size_t warmupInvocations = listeners.numWarmupRuns(); size_t syncs = listeners.numSyncs(); size_t enq = listeners.numEnqueuesPerSync(); - size_t maxRotatingBufferNum = max(warmupInvocations, syncs * enq); + size_t maxRotatingBufferNum = std::max(warmupInvocations, syncs * enq); std::vector> inputArr; { diff --git a/projects/hipblaslt/tensilelite/client/src/CMakeLists.txt b/projects/hipblaslt/tensilelite/client/src/CMakeLists.txt index 5ca2209319b..da3c3fcc54f 100644 --- a/projects/hipblaslt/tensilelite/client/src/CMakeLists.txt +++ b/projects/hipblaslt/tensilelite/client/src/CMakeLists.txt @@ -1,37 +1,35 @@ # Copyright Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -set(COMMON_CLIENT_SOURCES - "${CMAKE_CURRENT_SOURCE_DIR}/BenchmarkTimer.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/ClientProblemFactory.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/ProgramOptions.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/DataInitialization.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/PerformanceReporter.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/HardwareMonitorListener.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/TimingEvents.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/HardwareMonitor.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/ReferenceValidator.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/ResultFileReporter.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/ResultReporter.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/MetaRunListener.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/Reference.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/ProgressListener.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/TypedId.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/Rotating.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/CSVStackFile.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/SolutionIterator.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/LibraryUpdateReporter.cpp" -) - -target_sources(tensilelite-client PRIVATE ${COMMON_CLIENT_SOURCES}) -target_sources(cpu-gemm-driver PRIVATE ${COMMON_CLIENT_SOURCES}) - -if (TENSILELITE_CLIENT_ENABLE_ROCPROFSDK) - -target_sources(tensilelite-client +target_sources(tensilelite-client-common PRIVATE - "${CMAKE_CURRENT_SOURCE_DIR}/Profiler.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/RocProfiler.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/BenchmarkTimer.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/ClientProblemFactory.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/ProgramOptions.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/DataInitialization.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/PerformanceReporter.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/HardwareMonitorListener.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/TimingEvents.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/HardwareMonitor.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/ReferenceValidator.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/ResultFileReporter.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/ResultReporter.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/MetaRunListener.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/Reference.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/ProgressListener.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/TypedId.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/Rotating.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/CSVStackFile.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/SolutionIterator.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/LibraryUpdateReporter.cpp" ) +# Profiler sources only target tensilelite-client (cpu-gemm-driver does not +# expose ROCprofiler counters). +if(TENSILELITE_CLIENT_ENABLE_ROCPROFSDK) + target_sources(tensilelite-client + PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}/Profiler.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/RocProfiler.cpp" + ) endif() diff --git a/projects/hipblaslt/tensilelite/client/src/ClientProblemFactory.cpp b/projects/hipblaslt/tensilelite/client/src/ClientProblemFactory.cpp index 34d85d41987..14e0bd9b47c 100644 --- a/projects/hipblaslt/tensilelite/client/src/ClientProblemFactory.cpp +++ b/projects/hipblaslt/tensilelite/client/src/ClientProblemFactory.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -130,6 +130,15 @@ namespace TensileLite m_tensorStrides[i] = std::vector>(); } } + + // MX scale element types: use dedicated options (see main.cpp mx-a-type / mx-b-type). + // Do not rely on the generic tensor loop alone — args.count("mx-a-type") is often false + // when the value only comes from program_options default_value or from the INI merge. + m_tensorTypes[ContractionProblemGemm::TENSOR::MXSA] + = args["mx-a-type"].as(); + m_tensorTypes[ContractionProblemGemm::TENSOR::MXSB] + = args["mx-b-type"].as(); + // Get constant types for(size_t i = 0; i < constants.size(); i++) { diff --git a/projects/hipblaslt/tensilelite/client/src/DataInitialization.cpp b/projects/hipblaslt/tensilelite/client/src/DataInitialization.cpp index 533955e631a..4723a8a5d42 100644 --- a/projects/hipblaslt/tensilelite/client/src/DataInitialization.cpp +++ b/projects/hipblaslt/tensilelite/client/src/DataInitialization.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -25,6 +25,10 @@ *******************************************************************************/ #include "DataInitialization.hpp" + +#if HIPBLASLT_ENABLE_MXDATAGENERATOR +#include +#endif #include "TensorDataManipulation.hpp" #include "Utility.hpp" // #include "DataInitializationTyped.hpp" @@ -908,6 +912,7 @@ namespace TensileLite , m_keepPristineCopyOnGPU(args["pristine-on-gpu"].as()) , m_workspaceSize(problemFactory.workspaceSize()) , m_pruneMode(args["prune-mode"].as()) + , m_mxScaleFormat(args["mx-scale-format"].as()) { m_rotatingBuffer @@ -958,12 +963,10 @@ namespace TensileLite } } - bool isRMInit = false; for(auto const& p : problemFactory.problems()) { if(auto ptr = dynamic_cast(p.get())) { - std::vector vec_rm; const ContractionProblemGemm& problem = (*ptr); for(size_t i = 0; i < problem.tensors().size(); i++) { @@ -994,20 +997,6 @@ namespace TensileLite pristine.maxElements = std::max(pristine.maxElements, numAllocatedElements); - if(m_rotatingBuffer) - { - if(i <= ContractionProblemGemm::TENSOR::METADATA) - { - if(i == ContractionProblemGemm::TENSOR::C && problem.beta() == 0.0) - { - vec_rm.push_back(0); - } - else - { - vec_rm.push_back(numAllocatedBytes); - } - } - } if(m_vdata[i].name.empty()) { m_vdata[i].name = problem.tensors()[i].getName(); @@ -1020,15 +1009,6 @@ namespace TensileLite throw std::runtime_error(s.c_str()); } } - if(m_rotatingBuffer) - { - if(!isRMInit) - { - m_rm = std::make_shared(vec_rm.size()); - isRMInit = true; - } - m_rm->addRotatingSize(vec_rm); - } auto constants = problem.constants(); for(size_t i = 0; i < constants.size(); i++) { @@ -1059,12 +1039,10 @@ namespace TensileLite size_t maxElements; std::vector offsets; }; - std::vector vec_rm; - auto gElements + auto gElements = std::vector>(m_vdata.size()); for(auto const& problem : problems.gemms) { - std::vector tmp_rm; for(size_t i = 0; i < problem.tensors().size(); i++) { auto dataType = problem.tensors()[i].dataType(); @@ -1083,22 +1061,6 @@ namespace TensileLite += problem.tensors()[i].totalAllocatedElements(); gElements[i][dataType].offsets.push_back( problem.tensors()[i].totalAllocatedElements()); - if(m_rotatingBuffer) - { - if(i <= ContractionProblemGemm::TENSOR::METADATA) - { - if(i == ContractionProblemGemm::TENSOR::C - && problem.beta() == 0.0) - { - tmp_rm.push_back(0); - } - else - { - tmp_rm.push_back( - problem.tensors()[i].totalAllocatedBytes()); - } - } - } if(m_vdata[i].name.empty()) { m_vdata[i].name = problem.tensors()[i].getName(); @@ -1112,21 +1074,6 @@ namespace TensileLite throw std::runtime_error(s.c_str()); } } - if(vec_rm.empty()) - { - vec_rm = tmp_rm; - } - else - { - if(vec_rm.size() != tmp_rm.size()) - { - throw std::runtime_error("Unable to update vec_rm."); - } - for(size_t i = 0; i < tmp_rm.size(); i++) - { - vec_rm[i] += tmp_rm[i]; - } - } auto constants = problem.constants(); for(size_t i = 0; i < constants.size(); i++) { @@ -1148,15 +1095,6 @@ namespace TensileLite numOfBatch *= problem.batchSize(i); m_maxBatch = std::max(m_maxBatch, numOfBatch); } - if(m_rotatingBuffer) - { - if(!isRMInit) - { - m_rm = std::make_shared(vec_rm.size()); - isRMInit = true; - } - m_rm->addRotatingSize(vec_rm); - } // Update maxElements for(size_t i = 0; i < gElements.size(); i++) @@ -1222,8 +1160,8 @@ namespace TensileLite else if(m_curBoundsCheck == BoundsCheckMode::GuardPageFront || m_curBoundsCheck == BoundsCheckMode::GuardPageBack) { - float dataTypeSize = DataTypeInfo::Get(p->first).elementSize; - unsigned int roundUpSize = divideElementSize(pageSize, dataTypeSize); + float dataTypeSize = DataTypeInfo::Get(p->first).elementSize; + size_t roundUpSize = divideElementSize(pageSize, dataTypeSize); p->second.maxElements = RoundUpToMultiple(p->second.maxElements, roundUpSize); // No bias page guard @@ -1234,6 +1172,98 @@ namespace TensileLite << ToString(m_vdata[i].init) << std::endl; } + // Rotating buffer sizes must match post-bounds-check pristine.maxElements (e.g. guard + // page round-up). vec_rm was previously built before that adjustment, undersizing pools. + if(m_rotatingBuffer) + { + m_rm.reset(); + bool isRMInitPost = false; + for(auto const& p : problemFactory.problems()) + { + if(auto ptr = dynamic_cast(p.get())) + { + std::vector vec_rm; + const ContractionProblemGemm& problem = *ptr; + for(size_t i = 0; i < problem.tensors().size(); i++) + { + if(i > ContractionProblemGemm::TENSOR::METADATA) + continue; + auto dataType = problem.tensors()[i].dataType(); + auto it = m_vdata[i].pristine.find(dataType); + if(i == ContractionProblemGemm::TENSOR::C && problem.beta() == 0.0) + { + vec_rm.push_back(0); + continue; + } + if(it == m_vdata[i].pristine.end() || it->second.maxElements == 0) + { + vec_rm.push_back(0); + continue; + } + size_t const bytes = multiplyElementSize( + it->second.maxElements, DataTypeInfo::Get(dataType).elementSize); + vec_rm.push_back(bytes); + } + if(!isRMInitPost) + { + m_rm = std::make_shared(vec_rm.size()); + isRMInitPost = true; + } + m_rm->addRotatingSize(vec_rm); + } + else if(auto ptr = dynamic_cast(p.get())) + { + const ContractionProblemGroupedGemm& grouped = *ptr; + std::vector vec_rm; + for(auto const& problem : grouped.gemms) + { + std::vector tmp_rm; + for(size_t i = 0; i < problem.tensors().size(); i++) + { + if(i > ContractionProblemGemm::TENSOR::METADATA) + continue; + auto dataType = problem.tensors()[i].dataType(); + auto it = m_vdata[i].pristine.find(dataType); + if(i == ContractionProblemGemm::TENSOR::C && problem.beta() == 0.0) + { + tmp_rm.push_back(0); + continue; + } + if(it == m_vdata[i].pristine.end() || it->second.maxElements == 0) + { + tmp_rm.push_back(0); + continue; + } + size_t const bytes = multiplyElementSize( + it->second.maxElements, DataTypeInfo::Get(dataType).elementSize); + tmp_rm.push_back(bytes); + } + if(vec_rm.empty()) + { + vec_rm = std::move(tmp_rm); + } + else + { + if(vec_rm.size() != tmp_rm.size()) + { + throw std::runtime_error("Unable to update vec_rm."); + } + for(size_t j = 0; j < tmp_rm.size(); j++) + { + vec_rm[j] += tmp_rm[j]; + } + } + } + if(!isRMInitPost) + { + m_rm = std::make_shared(vec_rm.size()); + isRMInitPost = true; + } + m_rm->addRotatingSize(vec_rm); + } + } + } + // Init contants for(size_t i = 0; i < m_cdata.size(); i++) { @@ -1289,6 +1319,13 @@ namespace TensileLite m_problemDependentData |= (m_sparse | (args["bias-type-args"].as>().size() > 1)); + + // Force problem-dependent initialization for MX FP4 to enable mxDataGenerator + if(args.count("mx-a-block") && args["mx-a-block"].as() > 0) + m_problemDependentData = true; + if(args.count("mx-b-block") && args["mx-b-block"].as() > 0) + m_problemDependentData = true; + allocNewCPUInputs(); allocNewGPUInputs(); @@ -1600,7 +1637,8 @@ namespace TensileLite { padding = p.maxElements - t.totalAllocatedElements(); } - padding *= DataTypeInfo::Get(t.dataType()).elementSize; + padding = multiplyElementSize(padding, + DataTypeInfo::Get(t.dataType()).elementSize); return padding; }; @@ -1712,6 +1750,10 @@ namespace TensileLite void DataInitialization::initializeCPUInputs(ContractionProblemGemm const& problem) { + bool useMXGenerator = isMXFP4Problem(problem); + if(useMXGenerator) + initializeMXData(problem); + auto& tensors = problem.tensors(); for(size_t i = 0; i < m_vdata.size(); i++) { @@ -1719,6 +1761,12 @@ namespace TensileLite or i == ContractionProblemGemm::TENSOR::METADATA) continue; + if(useMXGenerator && (i == ContractionProblemGemm::TENSOR::A + || i == ContractionProblemGemm::TENSOR::B + || i == ContractionProblemGemm::TENSOR::MXSA + || i == ContractionProblemGemm::TENSOR::MXSB)) + continue; + if(m_problemDependentData) { // Should this m_cEqualsD set in ContractionProblem or boost args? @@ -1777,6 +1825,180 @@ namespace TensileLite } } +#if HIPBLASLT_ENABLE_MXDATAGENERATOR + namespace + { + /** Maps Tensile MX scale element type to hipDataType for generateMXInput (mxDataGen). */ + hipDataType hipMxScaleTypeForDataGenerator(rocisa::DataType mxType) + { + switch(mxType) + { + case rocisa::DataType::Float8: + return HIP_R_8F_E4M3; + case rocisa::DataType::E5M3: + return static_cast(HIP_R_8F_E5M3_EXT); + case rocisa::DataType::E8: + case rocisa::DataType::None: + return HIP_R_8F_UE8M0; + default: + throw std::runtime_error( + "initializeMXData: unsupported MX scale element type for generateMXInput"); + } + } + } // namespace + + void DataInitialization::initializeMXData(ContractionProblemGemm const& problem) + { + // Initializes A, B, MXSA, MXSB so the default-init loop in initializeCPUInputs + // can safely skip them. For MX-FP4 sides we drive mxDataGenerator (so the values + // are coordinated with their E8 scales); for any non-FP4 side (e.g. MX-B6 or non-MX + // mixed-mode) we fall back to the same initArray path the default loop would have + // taken, to avoid leaving the malloc'd buffers uninitialized. + auto const& tensors = problem.tensors(); + + auto initTensorFromDefault = [&](int i) { + for(auto& p : m_vdata[i].pristine) + { + if(p.second.initDescriptor[0] != tensors[i]) + { + p.second.initDescriptor[0] = tensors[i]; + initArray(p.first, + m_vdata[i].init, + p.second.cpuInput.valid.get(), + tensors[i]); + } + } + }; + + // Compute preSwizzle parameters from the solution's matrix instruction to rearrange + // the scale tensor into the GPU kernel's expected memory layout + std::vector preSwizzleA, preTileA, preSwizzleB, preTileB; + + if(m_mxScaleFormat > 0 && m_currentSolution != nullptr + && !m_currentSolution->problemType.useScaleAB.empty()) + { + auto const& mi = m_currentSolution->sizeMapping.matrixInstruction; + size_t MiK = static_cast(mi[2]); + constexpr size_t swizzleTileMN = 32; // 2 SIMDs * 16 lanes per wave for MN access + constexpr size_t tileK = 256 / swizzleTileMN; // scale blocks per wave in K + + if(MiK > 0) + { + if(problem.mxBlockA() > 0 && MiK % problem.mxBlockA() == 0) + { + // scale tensor: scaleRows = sizes[0]/mxBlock, scaleCols = sizes[1] + // preSwizzle requires both to be multiples of their tile dimensions + size_t scaleRowsA = problem.a().sizes()[0] / problem.mxBlockA(); + size_t scaleColsA = problem.a().sizes()[1]; + if(scaleRowsA % tileK == 0 && scaleColsA % swizzleTileMN == 0) + { + size_t subTileK = MiK / problem.mxBlockA(); + preSwizzleA = {swizzleTileMN, tileK, subTileK}; + preTileA = {tileK, swizzleTileMN}; + } + } + + if(problem.mxBlockB() > 0 && MiK % problem.mxBlockB() == 0) + { + size_t scaleRowsB = problem.b().sizes()[0] / problem.mxBlockB(); + size_t scaleColsB = problem.b().sizes()[1]; + if(scaleRowsB % tileK == 0 && scaleColsB % swizzleTileMN == 0) + { + size_t subTileK = MiK / problem.mxBlockB(); + preSwizzleB = {swizzleTileMN, tileK, subTileK}; + preTileB = {tileK, swizzleTileMN}; + } + } + } + } + + if(isMXFP4Tensor(problem.a(), problem.mxBlockA())) + { + auto const& tensorA = problem.a(); + auto rows = tensorA.sizes()[0]; + auto cols = tensorA.sizes()[1]; + auto stride = tensorA.strides()[1]; + + auto& pristineA + = m_vdata[ContractionProblemGemm::TENSOR::A].pristine[rocisa::DataType::Float4]; + auto& pristineE8A + = m_vdata[ContractionProblemGemm::TENSOR::MXSA].pristine[problem.mxsa().dataType()]; + + generateMXInput((hipDataType)HIP_R_4F_E2M1, + hipMxScaleTypeForDataGenerator(problem.mxTypeA()), + pristineA.cpuInput.valid.get(), + pristineE8A.cpuInput.valid.get(), + rows, + cols, + stride, + problem.transA(), + preSwizzleA, + preTileA, + problem.mxBlockA(), + 1, + true, + "Bounded", + -1.0f, + 1.0f); + } + else + { + // A is not FP4 (or mxBlockA == 0). The default-init loop will skip A and + // MXSA because useMXGenerator is true, so seed them here with the same + // initArray path the default loop would have used. + initTensorFromDefault(ContractionProblemGemm::TENSOR::A); + if(problem.mxBlockA() > 0) + initTensorFromDefault(ContractionProblemGemm::TENSOR::MXSA); + } + + if(isMXFP4Tensor(problem.b(), problem.mxBlockB())) + { + auto const& tensorB = problem.b(); + auto rows = tensorB.sizes()[0]; + auto cols = tensorB.sizes()[1]; + auto stride = tensorB.strides()[1]; + + auto& pristineB + = m_vdata[ContractionProblemGemm::TENSOR::B].pristine[rocisa::DataType::Float4]; + auto& pristineE8B + = m_vdata[ContractionProblemGemm::TENSOR::MXSB].pristine[problem.mxsb().dataType()]; + + generateMXInput((hipDataType)HIP_R_4F_E2M1, + hipMxScaleTypeForDataGenerator(problem.mxTypeB()), + pristineB.cpuInput.valid.get(), + pristineE8B.cpuInput.valid.get(), + rows, + cols, + stride, + problem.transB(), + preSwizzleB, + preTileB, + problem.mxBlockB(), + 1, + false, + "Bounded", + -1.0f, + 1.0f); + } + else + { + // B is not FP4 (or mxBlockB == 0). Same fallback rationale as the A side. + initTensorFromDefault(ContractionProblemGemm::TENSOR::B); + if(problem.mxBlockB() > 0) + initTensorFromDefault(ContractionProblemGemm::TENSOR::MXSB); + } + } +#else // HIPBLASLT_ENABLE_MXDATAGENERATOR + void DataInitialization::initializeMXData(ContractionProblemGemm const& /*problem*/) + { + // The MX data generator is disabled at build time. Reaching this + // path means a problem requiring MX FP4 initialization was issued + // against a build that doesn't include mxDataGenerator support. + throw std::runtime_error( + "MX data initialization requires HIPBLASLT_ENABLE_MXDATAGENERATOR=ON at build time"); + } +#endif // HIPBLASLT_ENABLE_MXDATAGENERATOR + void DataInitialization::initializeConstantInputs(ContractionProblemGemm const& problem) { // Update constants if needed @@ -1827,10 +2049,26 @@ namespace TensileLite case rocisa::DataType::BFloat8_fnuz: prop.value = getValue(prop.init, prop.freeValue); break; +#ifndef _WIN32 +#ifdef TENSILE_USE_FP6 case rocisa::DataType::Float6: + prop.value = getValue(prop.init, prop.freeValue); + break; +#endif // #ifdef TENSILE_USE_FP6 +#ifdef TENSILE_USE_BF6 case rocisa::DataType::BFloat6: + prop.value = getValue(prop.init, prop.freeValue); + break; +#endif // #ifdef TENSILE_USE_BF6 +#ifdef TENSILE_USE_FP4 case rocisa::DataType::Float4: + prop.value = getValue(prop.init, prop.freeValue); + break; +#endif // #ifdef TENSILE_USE_FP4 +#endif // !_WIN32 case rocisa::DataType::E8: + prop.value = getValue(prop.init, prop.freeValue); + break; case rocisa::DataType::E5M3: case rocisa::DataType::Int64: case rocisa::DataType::XFloat32: @@ -1838,7 +2076,13 @@ namespace TensileLite case rocisa::DataType::Float8BFloat8: case rocisa::DataType::BFloat8Float8: case rocisa::DataType::Float8BFloat8_fnuz: - case rocisa::DataType::BFloat8Float8_fnuz:; + case rocisa::DataType::BFloat8Float8_fnuz: +#ifdef _WIN32 + case rocisa::DataType::Float6: + case rocisa::DataType::BFloat6: + case rocisa::DataType::Float4: +#endif // _WIN32 + ; } } if(Debug::Instance().printTensorInfo() && prop.dataType != rocisa::DataType::None) @@ -2542,11 +2786,11 @@ namespace TensileLite auto castInputs = static_pointer_cast(inputs); size_t rotatingSize = getRotatingSize(*gemmProblem, *castInputs); int32_t rotatingNum - = min(maxRotatingBufferNum, ceil((float)m_rotatingBuffer / rotatingSize)) + = std::min(maxRotatingBufferNum, static_cast(ceil((float)m_rotatingBuffer / rotatingSize))) - 1; // Minus the original buffer. // <= 0 means don't rotating - rotatingNum = max(0, rotatingNum); + rotatingNum = std::max(0, rotatingNum); int32_t totalRotatingSizeNeeded = rotatingNum * rotatingSize; std::cout << "Rotating buffer set to: " << m_rotatingBuffer @@ -2604,11 +2848,11 @@ namespace TensileLite += getRotatingSize(groupedProblem->gemms[i], castInputs->grouped[i]); } int32_t rotatingNum - = min(maxRotatingBufferNum, ceil((float)m_rotatingBuffer / rotatingSize)) + = std::min(maxRotatingBufferNum, static_cast(ceil((float)m_rotatingBuffer / rotatingSize))) - 1; // Minus the original buffer. // <= 0 means don't rotating - rotatingNum = max(0, rotatingNum); + rotatingNum = std::max(0, rotatingNum); int32_t totalRotatingSizeNeeded = rotatingNum * rotatingSize; std::cout << "Rotating buffer set to: " << m_rotatingBuffer diff --git a/projects/hipblaslt/tensilelite/client/src/Reference.cpp b/projects/hipblaslt/tensilelite/client/src/Reference.cpp index 50242538971..025a1e641d5 100644 --- a/projects/hipblaslt/tensilelite/client/src/Reference.cpp +++ b/projects/hipblaslt/tensilelite/client/src/Reference.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -31,6 +31,7 @@ #include "TimingInstrumentation.hpp" #include "TypedId.hpp" +#include #include #include #include @@ -67,6 +68,35 @@ namespace TensileLite } } + /** One MX scale tensor element as float (E8 / E5M3 / Float8 E4M3). + * + * Returns the magnitude of the scale (std::fabs). MX scales are interpreted + * as positive multipliers per the OCP MX spec; for the canonical UE8M0 + * (E8) type this is a no-op, but for E5M3 / Float8 (E4M3) used as scale + * elements it preserves the prior reference behaviour that explicitly + * applied abs() before folding the scale into the accumulator. + */ + inline float mxScaleElementAsFloat(rocisa::DataType mxType, void const* base, size_t index) + { + float v; + switch(mxType) + { + case rocisa::DataType::E8: + v = static_cast(static_cast(base)[index]); + break; + case rocisa::DataType::E5M3: + v = static_cast(static_cast(base)[index]); + break; + case rocisa::DataType::Float8: + v = static_cast(static_cast(base)[index]); + break; + default: + throw std::runtime_error(concatenate( + "Reference MX scale: unsupported element type ", static_cast(mxType))); + } + return std::fabs(v); + } + // Helper class that wraps a shadow copy of input buffers in float format. // It quietly manages the indirection between directly using the input pointer // (for float) and a shadow copy (for half / bfloat16). @@ -390,7 +420,7 @@ namespace TensileLite case rocisa::DataType::Float6: case rocisa::DataType::BFloat6: case rocisa::DataType::Float4: - ; + ; } return DataInitialization::getValue(); } @@ -819,6 +849,159 @@ namespace TensileLite throw std::runtime_error("Unsupported input type."); } +#ifdef _WIN32 + template + inline Accumulator getElement(ContractionProblemGemm const& problem, + Type const* ptr, + const size_t idx, + void const* scalePtr, + const bool conjugate) +#else // _WIN32 + template ::value +#endif // #ifdef TENSILE_USE_FP6 +#ifdef TENSILE_USE_BF6 + && !std::is_same::value +#endif // #ifdef TENSILE_USE_BF6 +#ifdef TENSILE_USE_FP4 + && !std::is_same::value +#endif // #ifdef TENSILE_USE_FP4 + , + bool> + = true> + inline Accumulator getElement(ContractionProblemGemm const& problem, + Type const* ptr, + const size_t idx, + void const* scalePtr, + const bool conjugate) +#endif // _WIN32 + { + // case I8/I32/I32, I8 be implicitly cast to int. + constexpr bool needAccumCast + = !std::is_same() && !std::is_same(); + using MultT = std::conditional_t; + + constexpr bool needMathOpAccumCast = !std::is_same(); + using MathOpMultT = std::conditional_t; + + Type val = Transform::Input(ptr[idx], conjugate); + + if constexpr(sizeof(Type) > sizeof(ComputeInputType)) + { + ComputeInputType valCast; + if(problem.useScaleAB() == "Scalar") + { + Accumulator scale + = GetValue(problem.alphaType(), scalePtr, 0, conjugate); + auto tmp = multiply(val, scale); + valCast = static_cast(tmp); + } + else + { + valCast = static_cast(val); + } + return static_cast( + static_cast(static_cast(valCast))); + } + + return static_cast(static_cast(static_cast(val))); + } + +#if !defined(_WIN32) \ + && (defined(TENSILE_USE_FP6) || defined(TENSILE_USE_BF6) || defined(TENSILE_USE_FP4)) + template ::value +#endif // #ifdef TENSILE_USE_FP6 +#ifdef TENSILE_USE_BF6 + || std::is_same::value +#endif // #ifdef TENSILE_USE_BF6 +#ifdef TENSILE_USE_FP4 + || std::is_same::value +#endif // #ifdef TENSILE_USE_FP4 + , + bool> + = true> + inline Accumulator getElement(ContractionProblemGemm const& problem, + Type const* ptr, + const size_t idx, + void const* scalePtr, + const bool conjugate) + { + size_t packIdx = idx / TypeInfo::Packing; + size_t elemIdx = idx % TypeInfo::Packing; + + return static_cast(ptr[packIdx].getElement(elemIdx)); + } +#endif // !_WIN32 && (TENSILE_USE_FP6 || TENSILE_USE_BF6 || TENSILE_USE_FP4) + + template ::value + && !std::is_same::value), + bool> + = true> + Accumulator multiply(ContractionProblemGemm const& problem, + ContractionInputs const& inputs, + AType const* aPtr, + BType const* bPtr, + const size_t aIdx, + const size_t bIdx, + const bool aConjugate, + const bool bConjugate) + { + + auto aVal = getElement( + problem, aPtr, aIdx, inputs.scaleA, aConjugate); + auto bVal = getElement( + problem, bPtr, bIdx, inputs.scaleB, bConjugate); + + return multiply(aVal, bVal); + } + + template ::value + && std::is_same::value, + bool> + = true> + Accumulator multiply(ContractionProblemGemm const& problem, + ContractionInputs const& inputs, + AType const* aPtr, + BType const* bPtr, + const size_t aIdx, + const size_t bIdx, + const bool aConjugate, + const bool bConjugate) + { + AType aVal = Transform::Input(aPtr[aIdx], aConjugate); + BType bVal = Transform::Input(bPtr[bIdx], bConjugate); + return multiply(aVal, bVal); + } + + bool isFastPathEligible(ContractionProblemGemm const& problem) { @@ -1242,128 +1425,6 @@ namespace TensileLite } - template ::value -#endif // #ifdef TENSILE_USE_FP6 -#ifdef TENSILE_USE_BF6 - && !std::is_same::value -#endif // #ifdef TENSILE_USE_BF6 -#ifdef TENSILE_USE_FP4 - && !std::is_same::value -#endif // #ifdef TENSILE_USE_FP4 - , bool> - = true> - inline Accumulator getElement( - ContractionProblemGemm const& problem, - Type const* ptr, - const size_t idx, - void const* scalePtr, - const bool conjugate) - { - // case I8/I32/I32, I8 be implicitly cast to int. - constexpr bool needAccumCast = !std::is_same() && !std::is_same(); - using MultT = std::conditional_t; - - constexpr bool needMathOpAccumCast = !std::is_same(); - using MathOpMultT = std::conditional_t; - - Type val = Transform::Input(ptr[idx], conjugate); - - if constexpr(sizeof(Type) > sizeof(ComputeInputType)) - { - ComputeInputType valCast; - if(problem.useScaleAB() == "Scalar") - { - Accumulator scale = GetValue(problem.alphaType(), scalePtr, 0, conjugate); - auto tmp = multiply(val, scale); - valCast = static_cast(tmp); - } - else - { - valCast = static_cast(val); - } - return static_cast(static_cast(static_cast(valCast))); - } - - return static_cast(static_cast(static_cast(val))); - } - -#if defined(TENSILE_USE_FP6) || defined(TENSILE_USE_BF6) || defined(TENSILE_USE_FP4) - template ::value -#endif // #ifdef TENSILE_USE_FP6 -#ifdef TENSILE_USE_BF6 - || std::is_same::value -#endif // #ifdef TENSILE_USE_BF6 -#ifdef TENSILE_USE_FP4 - || std::is_same::value -#endif // #ifdef TENSILE_USE_FP4 - , bool> - = true> - inline Accumulator getElement( - ContractionProblemGemm const& problem, - Type const* ptr, - const size_t idx, - void const* scalePtr, - const bool conjugate) - { - size_t packIdx = idx / TypeInfo::Packing; - size_t elemIdx = idx % TypeInfo::Packing; - - return static_cast(ptr[packIdx].getElement(elemIdx)); - } -#endif - - template::value && !std::is_same::value), bool> = true - > - Accumulator multiply( - ContractionProblemGemm const& problem, - ContractionInputs const& inputs, - AType const* aPtr, - BType const* bPtr, - const size_t aIdx, - const size_t bIdx, - const bool aConjugate, - const bool bConjugate) - { - - auto aVal = getElement( - problem, aPtr, aIdx, inputs.scaleA, aConjugate); - auto bVal = getElement( - problem, bPtr, bIdx, inputs.scaleB, bConjugate); - - return multiply(aVal, bVal); - } - - template::value && std::is_same::value, bool> = true - > - Accumulator multiply( - ContractionProblemGemm const& problem, - ContractionInputs const& inputs, - AType const* aPtr, - BType const* bPtr, - const size_t aIdx, - const size_t bIdx, - const bool aConjugate, - const bool bConjugate) - { - AType aVal = Transform::Input(aPtr[aIdx], aConjugate); - BType bVal = Transform::Input(bPtr[bIdx], bConjugate); - return multiply(aVal, bVal); - } - template void ReferenceSolution::SolveCPU( ContractionProblemGemm const& problem, @@ -1388,10 +1449,12 @@ namespace TensileLite } // Convert void* to pointers - typename Inputs::AType const* aPtr = (typename Inputs::AType const*)inputs.a; - typename Inputs::BType const* bPtr = (typename Inputs::BType const*)inputs.b; - typename Inputs::CType const* cPtr = (typename Inputs::CType const*)inputs.c; - typename Inputs::DType* dPtr = (typename Inputs::DType*)inputs.d; + typename Inputs::AType const* aPtr = (typename Inputs::AType const*)inputs.a; + typename Inputs::BType const* bPtr = (typename Inputs::BType const*)inputs.b; + typename Inputs::CType const* cPtr = (typename Inputs::CType const*)inputs.c; + typename Inputs::DType* dPtr = (typename Inputs::DType*)inputs.d; + void const* mxsaBase = inputs.mxsa; + void const* mxsbBase = inputs.mxsb; auto const& freeIndicesA = problem.freeIndicesA(); auto const& freeIndicesB = problem.freeIndicesB(); @@ -1527,65 +1590,87 @@ namespace TensileLite aCoord[boundIndices[i].a] = bound[i]; bCoord[boundIndices[i].b] = bound[i]; - if(problem.boundIndices()[i].aMirror) + if(boundIndices[i].aMirror) aCoord[boundIndices[i].a] = boundSize[i] - aCoord[boundIndices[i].a] - 1; - if(problem.boundIndices()[i].bMirror) + if(boundIndices[i].bMirror) bCoord[boundIndices[i].b] = boundSize[i] - bCoord[boundIndices[i].b] - 1; if(problem.mxBlockA()) - mxsaCoord[boundIndices[i].a] = aCoord[boundIndices[i].a] / problem.mxBlockA(); + mxsaCoord[boundIndices[i].a] + = aCoord[boundIndices[i].a] / problem.mxBlockA(); + if(problem.mxBlockB()) - mxsbCoord[boundIndices[i].b] = bCoord[boundIndices[i].b] / problem.mxBlockB(); + mxsbCoord[boundIndices[i].b] + = bCoord[boundIndices[i].b] / problem.mxBlockB(); } - size_t aIndex = a.index(aCoord); - size_t bIndex = b.index(bCoord); + size_t aIndex = a.index(aCoord); + size_t bIndex = b.index(bCoord); size_t mxsaIndex = problem.mxBlockA() ? mxsa.index(mxsaCoord) : 0; size_t mxsbIndex = problem.mxBlockB() ? mxsb.index(mxsbCoord) : 0; - auto aStride = problem.a().strides()[boundIndices[0].a]; - auto bStride = problem.b().strides()[boundIndices[0].b]; - auto mxsaStride = problem.mxBlockA() ? mxsa.strides()[boundIndices[0].a] : 0; - auto mxsbStride = problem.mxBlockB() ? mxsb.strides()[boundIndices[0].b] : 0; + auto aStride = a.strides()[boundIndices[0].a]; + auto bStride = b.strides()[boundIndices[0].b]; + auto mxsaStride + = problem.mxBlockA() ? mxsa.strides()[boundIndices[0].a] : 0; + auto mxsbStride + = problem.mxBlockB() ? mxsb.strides()[boundIndices[0].b] : 0; // innermost bound calculation: - size_t innerMXLoop = std::max(std::max(problem.mxBlockA(), problem.mxBlockB()), 1); - for(size_t i = 0; i < boundSize[0]; i+=innerMXLoop) + size_t innerMXLoop = std::max( + std::max(problem.mxBlockA(), problem.mxBlockB()), 1); + for(size_t i = 0; i < boundSize[0]; i += innerMXLoop) { Accumulator val(0); - for(size_t j = 0; j( - problem, inputs, aPtr, bPtr, aIdx, bIdx, aConjugate, bConjugate); + val += multiply(problem, + inputs, + aPtr, + bPtr, + aIdx, + bIdx, + aConjugate, + bConjugate); } - Accumulator mxScale(1); - if (problem.mxBlockA()) + float mxScale = 1.0f; + if(problem.mxBlockA()) { - size_t mxsaI = (boundIndices[0].aMirror ? (boundSize[0] - i - 1) : i) / problem.mxBlockA(); + size_t mxsaI + = (boundIndices[0].aMirror ? (boundSize[0] - i - 1) : i) + / problem.mxBlockA(); size_t mxsaIdx = mxsaIndex + (mxsaI * mxsaStride); - Accumulator sa = abs(GetValue(mxsa.dataType(), inputs.mxsa, mxsaIdx, 0)); - mxScale = multiply(mxScale, sa); + mxScale = multiply( + mxScale, + mxScaleElementAsFloat(problem.mxTypeA(), mxsaBase, mxsaIdx)); } - if (problem.mxBlockB()) + if(problem.mxBlockB()) { - size_t mxsbI = (boundIndices[0].bMirror ? (boundSize[0] - i - 1) : i) / problem.mxBlockB(); + size_t mxsbI + = (boundIndices[0].bMirror ? (boundSize[0] - i - 1) : i) + / problem.mxBlockB(); size_t mxsbIdx = mxsbIndex + (mxsbI * mxsbStride); - Accumulator sb = abs(GetValue(mxsb.dataType(), inputs.mxsb, mxsbIdx, 0)); - mxScale = multiply(mxScale, sb); + mxScale = multiply( + mxScale, + mxScaleElementAsFloat(problem.mxTypeB(), mxsbBase, mxsbIdx)); } value += multiply(val, mxScale); } @@ -2376,6 +2461,7 @@ namespace TensileLite #endif // TENSILE_USE_HALF #endif // TENSILE_USE_FP8_BF8 +#ifndef _WIN32 #ifdef TENSILE_USE_FP6 case TypedGemm_F6_S_S::TypeId(): { @@ -2390,6 +2476,14 @@ namespace TensileLite problem, inputs, elementsToValidate); } #endif //TENSILE_USE_BF6 +#ifdef TENSILE_USE_FP4 + + case TypedGemm_F4_S_S::TypeId(): + { + return ReferenceSolution::SolveCPU( + problem, inputs, elementsToValidate); + } +#endif //TENSILE_USE_FP4 #if defined(TENSILE_USE_FP6) && defined(TENSILE_USE_BF6) case TypedGemm_F6B6_S_S::TypeId(): { @@ -2402,37 +2496,31 @@ namespace TensileLite problem, inputs, elementsToValidate); } #endif // defined(TENSILE_USE_FP6) && defined(TENSILE_USE_BF6) -#ifdef TENSILE_USE_FP4 - case TypedGemm_F4_S_S::TypeId(): +#if defined(TENSILE_USE_FP6) && defined(TENSILE_USE_FP4) + case TypedGemm_F6F4_S_S::TypeId(): { - return ReferenceSolution::SolveCPU( + return ReferenceSolution::SolveCPU( problem, inputs, elementsToValidate); } -#endif //TENSILE_USE_FP4 -#if defined(TENSILE_USE_FP4) && defined(TENSILE_USE_FP6) case TypedGemm_F4F6_S_S::TypeId(): { return ReferenceSolution::SolveCPU( problem, inputs, elementsToValidate); } - case TypedGemm_F6F4_S_S::TypeId(): +#endif // defined(TENSILE_USE_FP6) && defined(TENSILE_USE_FP4) +#if defined(TENSILE_USE_BF6) && defined(TENSILE_USE_FP4) + case TypedGemm_B6F4_S_S::TypeId(): { - return ReferenceSolution::SolveCPU( + return ReferenceSolution::SolveCPU( problem, inputs, elementsToValidate); } -#endif // defined(TENSILE_USE_FP4) && defined(TENSILE_USE_FP6) -#if defined(TENSILE_USE_FP4) && defined(TENSILE_USE_BF6) case TypedGemm_F4B6_S_S::TypeId(): { return ReferenceSolution::SolveCPU( problem, inputs, elementsToValidate); } - case TypedGemm_B6F4_S_S::TypeId(): - { - return ReferenceSolution::SolveCPU( - problem, inputs, elementsToValidate); - } -#endif // defined(TENSILE_USE_FP4) && defined(TENSILE_USE_BF6) +#endif // defined(TENSILE_USE_BF6) && defined(TENSILE_USE_FP4) +#endif // !_WIN32 #if defined(TENSILE_USE_FP8_BF8) && defined(TENSILE_USE_FP4) // DestDataType: S case TypedGemm_F8F4_S_S::TypeId(): diff --git a/projects/hipblaslt/tensilelite/client/src/ReferenceValidator.cpp b/projects/hipblaslt/tensilelite/client/src/ReferenceValidator.cpp index 3a80f2de406..a18fed44642 100644 --- a/projects/hipblaslt/tensilelite/client/src/ReferenceValidator.cpp +++ b/projects/hipblaslt/tensilelite/client/src/ReferenceValidator.cpp @@ -128,6 +128,23 @@ namespace TensileLite m_validatedSolution = false; m_errorInSolution = false; m_executedSolution = false; + + // MXFP4 data can depend on the selected solution (e.g. MI-based preSwizzle when + // enabled). prepareCPUInputs runs in preProblem before preSolution, so the initial + // SolveCPU may not match GPU inputs refreshed in DataInitialization::preSolution. + // Re-run the CPU reference after that refresh (CPU "current" buffers are synced there + // when the MX path runs). + if(!m_enabled || m_problem == nullptr || m_referenceInputs == nullptr + || solution == nullptr) + return; + + if(auto* gemm = dynamic_cast(m_problem)) + { + if(!isMXFP4Problem(*gemm)) + return; + ScopedTimer timer("cpu_reference_gemm_per_solution"); + SolveCPU(m_problem, m_referenceInputs.get(), m_elementsToValidate); + } } bool ReferenceValidator::needMoreRunsInSolution() const diff --git a/projects/hipblaslt/tensilelite/include/Tensile/Activation.hpp b/projects/hipblaslt/tensilelite/include/Tensile/Activation.hpp index 34816679429..9746a3c5e8c 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/Activation.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/Activation.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2024 Advanced Micro Devices, Inc. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -27,6 +27,7 @@ #pragma once #include +#include #include #include #include diff --git a/projects/hipblaslt/tensilelite/include/Tensile/ContractionProblem.hpp b/projects/hipblaslt/tensilelite/include/Tensile/ContractionProblem.hpp index 1bd8f94ec80..1b1a68c58f2 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/ContractionProblem.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/ContractionProblem.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -1067,11 +1067,6 @@ namespace TensileLite void setMXScaleA(rocisa::DataType mxType, int mxBlock, std::vector saStride = {}); - size_t mxBlockA() const - { - return m_mxBlockA; - } - rocisa::DataType mxTypeA() const { return m_mxTypeA; @@ -1079,11 +1074,6 @@ namespace TensileLite void setMXScaleB(rocisa::DataType mxType, int mxBlock, std::vector sbStride = {}); - size_t mxBlockB() const - { - return m_mxBlockB; - } - rocisa::DataType mxTypeB() const { return m_mxTypeB; @@ -1109,6 +1099,16 @@ namespace TensileLite m_swizzleTensorB = swizzle; } + size_t mxBlockA() const + { + return m_mxBlockA; + } + + size_t mxBlockB() const + { + return m_mxBlockB; + } + /// Allocated elements excluding batch dimensions /// Used in assembly kernels to determine buffer limits, if batch dimes not /// packed diff --git a/projects/hipblaslt/tensilelite/include/Tensile/ContractionProblemPredicates.hpp b/projects/hipblaslt/tensilelite/include/Tensile/ContractionProblemPredicates.hpp index c555756bd78..c751d26c581 100755 --- a/projects/hipblaslt/tensilelite/include/Tensile/ContractionProblemPredicates.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/ContractionProblemPredicates.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -35,6 +35,7 @@ #include #include +#include #include #include #include @@ -1428,6 +1429,56 @@ namespace TensileLite return "TypesEqual"; } + // This function checks if the computeInputType of a problem matches the computeInputType of a + // solution. Because for Float8{_fnuz}/BFloat8{_fnuz}, the computeInputTypeA and computeInputTypeB + // might be concatenated like this: computeInputTypeAcomputeInputTypeB{_fnuz}, this function + // includes special logic to inspect types at the concatenation position. + // + bool validateComputeType(rocisa::DataType const& problemComputeInputType, bool const isComputeInputTypeA) const + { + using ArrType = std::array, 4>; + + // computeInputTypeA corresponds to the first type in the concatenated type + ArrType const computeInputTypeAMapping = + {{ + {rocisa::DataType::BFloat8_fnuz, rocisa::DataType::BFloat8Float8_fnuz}, + {rocisa::DataType::Float8_fnuz, rocisa::DataType::Float8BFloat8_fnuz}, + {rocisa::DataType::BFloat8, rocisa::DataType::BFloat8Float8}, + {rocisa::DataType::Float8, rocisa::DataType::Float8BFloat8} + }}; + + // computeInputTypeB corresponds to the second type in the concatenated type + ArrType const computeInputTypeBMapping = + {{ + {rocisa::DataType::Float8_fnuz, rocisa::DataType::BFloat8Float8_fnuz}, + {rocisa::DataType::BFloat8_fnuz, rocisa::DataType::Float8BFloat8_fnuz}, + {rocisa::DataType::Float8, rocisa::DataType::BFloat8Float8}, + {rocisa::DataType::BFloat8, rocisa::DataType::Float8BFloat8} + }}; + + auto validate = [](ArrType const& arr, + rocisa::DataType const& t1, rocisa::DataType const& t2){ + if(t1 == t2) + return true; + return std::any_of(arr.begin(), arr.end(), [&](const auto& p){ + return (p == std::make_pair(t1, t2)) or (p == std::make_pair(t2, t1)); + }); + }; + + if(isComputeInputTypeA) + { + // value[4] is computeInputTypeA + return + validate(computeInputTypeAMapping, problemComputeInputType, value[4]); + } + else + { + // value[5] is computeInputTypeB + return + validate(computeInputTypeBMapping, problemComputeInputType, value[5]); + } + } + virtual bool operator()(ContractionProblemGemm const& problem) const override { return problem.a().dataType() == value[0] && problem.b().dataType() == value[1] @@ -1547,10 +1598,10 @@ namespace TensileLite virtual bool operator()(ContractionProblemGemm const& problem) const override { const uint64_t TWO_POW_32 = 4294967296; - return multiplyElementSize((problem.a().strides()[1] * min(value.depthUorMT0, problem.a().sizes()[1]) + value.shiftPtrElemA), + return multiplyElementSize((problem.a().strides()[1] * std::min(value.depthUorMT0, problem.a().sizes()[1]) + value.shiftPtrElemA), problem.a().elementBytes()) < TWO_POW_32 - && multiplyElementSize((problem.b().strides()[1] * min(value.depthUorMT1, problem.b().sizes()[1]) + value.shiftPtrElemB) + && multiplyElementSize((problem.b().strides()[1] * std::min(value.depthUorMT1, problem.b().sizes()[1]) + value.shiftPtrElemB) ,problem.b().elementBytes()) < TWO_POW_32; } @@ -1615,7 +1666,7 @@ namespace TensileLite else { const uint64_t TWO_POW_32 = 4294967296; - return multiplyElementSize(problem.c().strides()[1] * min(value, problem.c().sizes()[1]), problem.c().elementBytes()) + return multiplyElementSize(problem.c().strides()[1] * std::min(value, problem.c().sizes()[1]), problem.c().elementBytes()) < TWO_POW_32; } } @@ -1662,7 +1713,7 @@ namespace TensileLite virtual bool operator()(ContractionProblemGemm const& problem) const override { const uint64_t TWO_POW_32 = 4294967296; - return multiplyElementSize(problem.d().strides()[1] * min(value, problem.d().sizes()[1]), problem.d().elementBytes()) + return multiplyElementSize(problem.d().strides()[1] * std::min(value, problem.d().sizes()[1]), problem.d().elementBytes()) < TWO_POW_32; } diff --git a/projects/hipblaslt/tensilelite/include/Tensile/DataTypes.hpp b/projects/hipblaslt/tensilelite/include/Tensile/DataTypes.hpp index 9721ef56c5b..e98c36cf968 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/DataTypes.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/DataTypes.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -237,7 +237,7 @@ namespace TensileLite }; template <> - struct TypeInfo : public BaseTypeInfo + struct TypeInfo : public BaseTypeInfo { }; @@ -300,28 +300,45 @@ namespace TensileLite { }; +#ifdef _WIN32 #ifdef TENSILE_USE_FP6 template <> - struct TypeInfo - : public BaseTypeInfo + struct TypeInfo : public BaseTypeInfo { }; -#endif // #ifdef TENSILE_USE_FP6 +#endif // TENSILE_USE_FP6 #ifdef TENSILE_USE_BF6 template <> - struct TypeInfo - : public BaseTypeInfo + struct TypeInfo : public BaseTypeInfo { }; -#endif // #ifdef TENSILE_USE_BF6 +#endif // TENSILE_USE_BF6 #ifdef TENSILE_USE_FP4 template <> - struct TypeInfo - : public BaseTypeInfo + struct TypeInfo : public BaseTypeInfo { }; -#endif // #ifdef TENSILE_USE_FP4 - +#endif // TENSILE_USE_FP4 +#else // _WIN32 +#ifdef TENSILE_USE_FP6 + template <> + struct TypeInfo : public BaseTypeInfo + { + }; +#endif // TENSILE_USE_FP6 +#ifdef TENSILE_USE_BF6 + template <> + struct TypeInfo : public BaseTypeInfo + { + }; +#endif // TENSILE_USE_BF6 +#ifdef TENSILE_USE_FP4 + template <> + struct TypeInfo : public BaseTypeInfo + { + }; +#endif // TENSILE_USE_FP4 +#endif // _WIN32 template <> struct TypeInfo : public BaseTypeInfo @@ -348,7 +365,18 @@ namespace TensileLite BFloat8, Float8_fnuz, BFloat8_fnuz, - int8_t>; + int8_t, +#if !defined(_WIN32) && defined(TENSILE_USE_FP6) + Float6x32, +#endif // !_WIN32 && TENSILE_USE_FP6 +#if !defined(_WIN32) && defined(TENSILE_USE_BF6) + BFloat6x32, +#endif // !_WIN32 && TENSILE_USE_BF6 +#if !defined(_WIN32) && defined(TENSILE_USE_FP4) + Float4x2, +#endif // !_WIN32 && TENSILE_USE_FP4 + E8 + >; // Convert variants to type T template @@ -416,6 +444,54 @@ namespace TensileLite return static_cast(*std::get_if(&val)); } +#if !defined(_WIN32) && defined(TENSILE_USE_FP6) + // Convert variants to type T + template + typename std::enable_if::value, T>::type + constVariantCast(const ConstantVariant& val) + { + switch(val.index()) + { + case static_cast(rocisa::DataType::Float6): + return static_cast(*std::get_if(&val)); + default: + throw std::runtime_error("Unsupported variant cast type."); + } + } +#endif // !_WIN32 && TENSILE_USE_FP6 + +#if !defined(_WIN32) && defined(TENSILE_USE_BF6) + // Convert variants to type T + template + typename std::enable_if::value, T>::type + constVariantCast(const ConstantVariant& val) + { + switch(val.index()) + { + case static_cast(rocisa::DataType::BFloat6): + return static_cast(*std::get_if(&val)); + default: + throw std::runtime_error("Unsupported variant cast type."); + } + } +#endif // !_WIN32 && TENSILE_USE_BF6 + +#if !defined(_WIN32) && defined(TENSILE_USE_FP4) + // Convert variants to type T + template + typename std::enable_if::value, T>::type + constVariantCast(const ConstantVariant& val) + { + switch(val.index()) + { + case static_cast(rocisa::DataType::Float4): + return static_cast(*std::get_if(&val)); + default: + throw std::runtime_error("Unsupported variant cast type."); + } + } +#endif // !_WIN32 && TENSILE_USE_FP4 + std::string ToString(ConstantVariant d); bool CompareValue(const ConstantVariant& d, double value); diff --git a/projects/hipblaslt/tensilelite/include/Tensile/DataTypes_BFloat6.hpp b/projects/hipblaslt/tensilelite/include/Tensile/DataTypes_BFloat6.hpp index 595ed071ab2..8a65e26b6dd 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/DataTypes_BFloat6.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/DataTypes_BFloat6.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2019-2024 Advanced Micro Devices, Inc. + * Copyright (C) 2019-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -26,17 +26,23 @@ #pragma once -// Workaround: ROCm's amd_hip_ocp_host.hpp has a static_assert size mismatch -// (fp6x32_packed vs __amd_fp6x32_storage_t) in its host-fallback path, which -// is taken for all non-gfx950/gfx1250 device targets. #include -#if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx950__) || defined(__gfx1250__)) && !defined(WIN32) && !defined(_WIN32) #define TENSILE_USE_BF6 -#endif #ifdef TENSILE_USE_BF6 +#ifdef _WIN32 + +#include + +namespace TensileLite +{ + typedef struct BFloat6{ uint8_t data;} BFloat6; +} // end of namespace TensileLite + +#else // _WIN32 + #ifdef TENSILE_USE_HIP #include #endif @@ -59,8 +65,7 @@ namespace TensileLite stochastic }; - struct BFloat6x16_Storage - { + struct BFloat6x32_Storage { uint8_t v0 : 6; uint8_t v1 : 6; uint8_t v2 : 6; @@ -77,77 +82,58 @@ namespace TensileLite uint8_t v13 : 6; uint8_t v14 : 6; uint8_t v15 : 6; + uint8_t v16 : 6; + uint8_t v17 : 6; + uint8_t v18 : 6; + uint8_t v19 : 6; + uint8_t v20 : 6; + uint8_t v21 : 6; + uint8_t v22 : 6; + uint8_t v23 : 6; + uint8_t v24 : 6; + uint8_t v25 : 6; + uint8_t v26 : 6; + uint8_t v27 : 6; + uint8_t v28 : 6; + uint8_t v29 : 6; + uint8_t v30 : 6; + uint8_t v31 : 6; } __attribute__((packed)); // data type - struct BFloat6x16 + struct BFloat6x32 { - BFloat6x16_Storage data; - constexpr static size_t packed_size = 16; + BFloat6x32_Storage data; // default constructor - HIP_HOST_DEVICE BFloat6x16() = default; + HIP_HOST_DEVICE BFloat6x32() = default; // constructor from float - explicit HIP_HOST_DEVICE BFloat6x16(float val, - hip_bf6_rounding_mode rm - = hip_bf6_rounding_mode::standard, - uint32_t rng = 0) + explicit HIP_HOST_DEVICE BFloat6x32(float val, + hip_bf6_rounding_mode rm = hip_bf6_rounding_mode::standard, + uint32_t rng = 0) { - __amd_floatx16_storage_t f32x16; - - for(int i = 0; i < packed_size; i++) - f32x16[i] = val; - - if(rm == hip_bf6_rounding_mode::standard) - { -#ifdef HIPBLASLT_USE_HIP_FP6X16 - union - { - BFloat6x16_Storage real; - __amd_fp6x16_storage_t tmp; - } cvt; - cvt.tmp = __amd_cvt_floatx16_to_fp6x16_scale(f32x16, __AMD_OCP_E3M2, 0); - data = cvt.real; -#else - union - { - __amd_floatx32_storage_t fp32x32; - __amd_floatx16_storage_t fp32x16[2]; - } in; - union - { - BFloat6x16_Storage real[2]; - __amd_fp6x32_storage_t fp6x32; - } out; - in.fp32x16[0] = f32x16; - out.fp6x32 - = __amd_cvt_floatx32_to_fp6x32_scale(in.fp32x32, __AMD_OCP_E3M2, 0); - data = out.real[0]; -#endif + union { + BFloat6x32_Storage real; + __amd_fp6x32_storage_t tmp; + } cvt; + __amd_floatx32_storage_t f32x32; + + for (int i=0; i<32; i++) + f32x32[i] = val; + + if (rm == hip_bf6_rounding_mode::standard) { + cvt.tmp = __amd_cvt_floatx32_to_fp6x32_scale(f32x32, __AMD_OCP_E3M2, 0); + data = cvt.real; } - else - { - // TODO: update below code if hip_ext_ocp.h supports __amd_cvt_floatx16_to_fp6x16_sr_scale - union - { - __amd_floatx32_storage_t fp32x32; - __amd_floatx16_storage_t fp32x16[2]; - } in; - union - { - BFloat6x16_Storage real[2]; - __amd_fp6x32_storage_t fp6x32; - } out; - in.fp32x16[0] = f32x16; - out.fp6x32 - = __amd_cvt_floatx32_to_fp6x32_sr_scale(in.fp32x32, __AMD_OCP_E3M2, rng, 0); - data = out.real[0]; + else { + cvt.tmp = __amd_cvt_floatx32_to_fp6x32_sr_scale(f32x32, __AMD_OCP_E3M2, rng, 0); + data = cvt.real; } } // constructor from float - explicit HIP_HOST_DEVICE BFloat6x16(float v0, + explicit HIP_HOST_DEVICE BFloat6x32(float v0, float v1, float v2, float v3, @@ -163,88 +149,85 @@ namespace TensileLite float v13, float v14, float v15, - hip_bf6_rounding_mode rm - = hip_bf6_rounding_mode::standard, - uint32_t rng = 0) + float v16, + float v17, + float v18, + float v19, + float v20, + float v21, + float v22, + float v23, + float v24, + float v25, + float v26, + float v27, + float v28, + float v29, + float v30, + float v31, + hip_bf6_rounding_mode rm = hip_bf6_rounding_mode::standard, + uint32_t rng = 0) { - __amd_floatx16_storage_t f32x16; - - f32x16[0] = v0; - f32x16[1] = v1; - f32x16[2] = v2; - f32x16[3] = v3; - f32x16[4] = v4; - f32x16[5] = v5; - f32x16[6] = v6; - f32x16[7] = v7; - f32x16[8] = v8; - f32x16[9] = v9; - f32x16[10] = v10; - f32x16[11] = v11; - f32x16[12] = v12; - f32x16[13] = v13; - f32x16[14] = v14; - f32x16[15] = v15; - - if(rm == hip_bf6_rounding_mode::standard) - { -#ifdef HIPBLASLT_USE_HIP_FP6X16 - union - { - BFloat6x16_Storage real; - __amd_fp6x16_storage_t tmp; - } cvt; - cvt.tmp = __amd_cvt_floatx16_to_fp6x16_scale(f32x16, __AMD_OCP_E3M2, 0); - data = cvt.real; -#else - union - { - __amd_floatx32_storage_t fp32x32; - __amd_floatx16_storage_t fp32x16[2]; - } in; - union - { - BFloat6x16_Storage real[2]; - __amd_fp6x32_storage_t fp6x32; - } out; - in.fp32x16[0] = f32x16; - out.fp6x32 - = __amd_cvt_floatx32_to_fp6x32_scale(in.fp32x32, __AMD_OCP_E3M2, 0); - data = out.real[0]; -#endif + union { + BFloat6x32_Storage real; + __amd_fp6x32_storage_t tmp; + } cvt; + __amd_floatx32_storage_t f32x32; + + f32x32[0] = v0; + f32x32[1] = v1; + f32x32[2] = v2; + f32x32[3] = v3; + f32x32[4] = v4; + f32x32[5] = v5; + f32x32[6] = v6; + f32x32[7] = v7; + f32x32[8] = v8; + f32x32[9] = v9; + f32x32[10] = v10; + f32x32[11] = v11; + f32x32[12] = v12; + f32x32[13] = v13; + f32x32[14] = v14; + f32x32[15] = v15; + f32x32[16] = v16; + f32x32[17] = v17; + f32x32[18] = v18; + f32x32[19] = v19; + f32x32[20] = v20; + f32x32[21] = v21; + f32x32[22] = v22; + f32x32[23] = v23; + f32x32[24] = v24; + f32x32[25] = v25; + f32x32[26] = v26; + f32x32[27] = v27; + f32x32[28] = v28; + f32x32[29] = v29; + f32x32[30] = v30; + f32x32[31] = v31; + + if (rm == hip_bf6_rounding_mode::standard) { + cvt.tmp = __amd_cvt_floatx32_to_fp6x32_scale(f32x32, __AMD_OCP_E3M2, 0); + data = cvt.real; } - else - { - // TODO: update below code if hip_ext_ocp.h supports __amd_cvt_floatx16_to_fp6x16_sr_scale - union - { - __amd_floatx32_storage_t fp32x32; - __amd_floatx16_storage_t fp32x16[2]; - } in; - union - { - BFloat6x16_Storage real[2]; - __amd_fp6x32_storage_t fp6x32; - } out; - in.fp32x16[0] = f32x16; - out.fp6x32 - = __amd_cvt_floatx32_to_fp6x32_sr_scale(in.fp32x32, __AMD_OCP_E3M2, rng, 0); - data = out.real[0]; + else { + cvt.tmp = __amd_cvt_floatx32_to_fp6x32_sr_scale(f32x32, __AMD_OCP_E3M2, rng, 0); + data = cvt.real; } } inline HIP_HOST_DEVICE float getElement(size_t idx) const { - // TODO: update below code if hip_ext_ocp.h supports __amd_cvt_fp6x16_to_floatx16_scale - __amd_floatx32_storage_t fp32x32; - union - { - BFloat6x16_Storage real[2]; - __amd_fp6x32_storage_t fp6x32; - } in; - in.real[0] = data; - fp32x32 = __amd_cvt_fp6x32_to_floatx32_scale(in.fp6x32, __AMD_OCP_E3M2, 0); - if(idx < packed_size) + union { + BFloat6x32_Storage real; + __amd_fp6x32_storage_t tmp; + } cvt; + + cvt.real = data; + + __amd_floatx32_storage_t fp32x32 = __amd_cvt_fp6x32_to_floatx32_scale(cvt.tmp, __AMD_OCP_E3M2, 0); + if (idx < 32) return fp32x32[idx]; else return 0.0; @@ -253,10 +236,38 @@ namespace TensileLite // check for zero inline HIP_HOST_DEVICE bool is_zero() const { - return data.v0 == 0 && data.v1 == 0 && data.v2 == 0 && data.v3 == 0 && data.v4 == 0 - && data.v5 == 0 && data.v6 == 0 && data.v7 == 0 && data.v8 == 0 && data.v9 == 0 - && data.v10 == 0 && data.v11 == 0 && data.v12 == 0 && data.v13 == 0 - && data.v14 == 0 && data.v15 == 0; + return data.v0 == 0 && + data.v1 == 0 && + data.v2 == 0 && + data.v3 == 0 && + data.v4 == 0 && + data.v5 == 0 && + data.v6 == 0 && + data.v7 == 0 && + data.v8 == 0 && + data.v9 == 0 && + data.v10 == 0 && + data.v11 == 0 && + data.v12 == 0 && + data.v13 == 0 && + data.v14 == 0 && + data.v15 == 0 && + data.v16 == 0 && + data.v17 == 0 && + data.v18 == 0 && + data.v19 == 0 && + data.v20 == 0 && + data.v21 == 0 && + data.v22 == 0 && + data.v23 == 0 && + data.v24 == 0 && + data.v25 == 0 && + data.v26 == 0 && + data.v27 == 0 && + data.v28 == 0 && + data.v29 == 0 && + data.v30 == 0 && + data.v31 == 0; } // check for nan @@ -275,26 +286,26 @@ namespace TensileLite namespace std { - inline std::string to_string(const TensileLite::BFloat6x16& a) + inline std::string to_string(const TensileLite::BFloat6x32& a) { - // TODO: update below code if hip_ext_ocp.h supports __amd_cvt_fp6x16_to_floatx16_scale - union - { - TensileLite::BFloat6x16_Storage real[2]; - __amd_fp6x32_storage_t fp6x32; - } in; - in.real[0] = a.data; - auto result = __amd_cvt_fp6x32_to_floatx32_scale(in.fp6x32, __AMD_OCP_E3M2, 0); + union { + TensileLite::BFloat6x32_Storage real; + __amd_fp6x32_storage_t tmp; + } cvt; + + cvt.real = a.data; + + auto result = __amd_cvt_fp6x32_to_floatx32_scale(cvt.tmp, __AMD_OCP_E3M2, 0); std::string str = std::to_string(static_cast(result[0])); - for(int i = 1; i < a.packed_size; i++) - str = str + " " + std::to_string(static_cast(result[i])); + for (int i=1; i<32; i++) + str = str + " " + std::to_string(static_cast(result[i])); return str; } - inline ostream& operator<<(ostream& stream, const TensileLite::BFloat6x16 a) + inline ostream& operator<<(ostream& stream, const TensileLite::BFloat6x32 a) { return stream << to_string(a); } @@ -302,4 +313,6 @@ namespace std TENSILE_HIDDEN_END +#endif // _WIN32 + #endif // TENSILE_USE_BF6 diff --git a/projects/hipblaslt/tensilelite/include/Tensile/DataTypes_Float4.hpp b/projects/hipblaslt/tensilelite/include/Tensile/DataTypes_Float4.hpp index a13c5e6b9b6..8072aae676b 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/DataTypes_Float4.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/DataTypes_Float4.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2019-2024 Advanced Micro Devices, Inc. + * Copyright (C) 2019-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -26,17 +26,21 @@ #pragma once -// Workaround: ROCm's amd_hip_ocp_host.hpp has a static_assert size mismatch -// (fp6x32_packed vs __amd_fp6x32_storage_t) in its host-fallback path, which -// is taken for all non-gfx950/gfx1250 device targets. #include -#if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx950__) || defined(__gfx1250__)) && !defined(WIN32) && !defined(_WIN32) #define TENSILE_USE_FP4 -#endif #ifdef TENSILE_USE_FP4 +#ifdef _WIN32 + +namespace TensileLite +{ + typedef struct Float4{ uint8_t data;} Float4; +} // end of namespace TensileLite + +#else // _WIN32 + #ifdef TENSILE_USE_HIP #include #endif @@ -156,4 +160,6 @@ namespace std TENSILE_HIDDEN_END +#endif // _WIN32 + #endif // TENSILE_USE_FP4 diff --git a/projects/hipblaslt/tensilelite/include/Tensile/DataTypes_Float6.hpp b/projects/hipblaslt/tensilelite/include/Tensile/DataTypes_Float6.hpp index 47e70cf5e03..fdeb6317a20 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/DataTypes_Float6.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/DataTypes_Float6.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2019-2024 Advanced Micro Devices, Inc. + * Copyright (C) 2019-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -26,17 +26,23 @@ #pragma once -// Workaround: ROCm's amd_hip_ocp_host.hpp has a static_assert size mismatch -// (fp6x32_packed vs __amd_fp6x32_storage_t) in its host-fallback path, which -// is taken for all non-gfx950/gfx1250 device targets. #include -#if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx950__) || defined(__gfx1250__)) && !defined(WIN32) && !defined(_WIN32) #define TENSILE_USE_FP6 -#endif #ifdef TENSILE_USE_FP6 +#ifdef _WIN32 + +#include + +namespace TensileLite +{ + typedef struct Float6{ uint8_t data;} Float6; +} // end of namespace TensileLite + +#else // _WIN32 + #ifdef TENSILE_USE_HIP #include #endif @@ -59,8 +65,7 @@ namespace TensileLite stochastic }; - struct Float6x16_Storage - { + struct Float6x32_Storage { uint8_t v0 : 6; uint8_t v1 : 6; uint8_t v2 : 6; @@ -77,75 +82,58 @@ namespace TensileLite uint8_t v13 : 6; uint8_t v14 : 6; uint8_t v15 : 6; + uint8_t v16 : 6; + uint8_t v17 : 6; + uint8_t v18 : 6; + uint8_t v19 : 6; + uint8_t v20 : 6; + uint8_t v21 : 6; + uint8_t v22 : 6; + uint8_t v23 : 6; + uint8_t v24 : 6; + uint8_t v25 : 6; + uint8_t v26 : 6; + uint8_t v27 : 6; + uint8_t v28 : 6; + uint8_t v29 : 6; + uint8_t v30 : 6; + uint8_t v31 : 6; } __attribute__((packed)); // data type - struct Float6x16 + struct Float6x32 { - Float6x16_Storage data; - constexpr static size_t packed_size = 16; + Float6x32_Storage data; // default constructor - HIP_HOST_DEVICE Float6x16() = default; + HIP_HOST_DEVICE Float6x32() = default; // constructor from float - explicit HIP_HOST_DEVICE Float6x16(float val, - hip_f6_rounding_mode rm = hip_f6_rounding_mode::standard, + explicit HIP_HOST_DEVICE Float6x32(float val, + hip_f6_rounding_mode rm = hip_f6_rounding_mode::standard, uint32_t rng = 0) { - __amd_floatx16_storage_t f32x16; - - for(int i = 0; i < packed_size; i++) - f32x16[i] = val; - if(rm == hip_f6_rounding_mode::standard) - { -#ifdef HIPBLASLT_USE_HIP_FP6X16 - union - { - Float6x16_Storage real; - __amd_fp6x16_storage_t tmp; - } cvt; - cvt.tmp = __amd_cvt_floatx16_to_fp6x16_scale(f32x16, __AMD_OCP_E2M3, 0); - data = cvt.real; -#else - union - { - __amd_floatx32_storage_t fp32x32; - __amd_floatx16_storage_t fp32x16[2]; - } in; - union - { - Float6x16_Storage real[2]; - __amd_fp6x32_storage_t fp6x32; - } out; - in.fp32x16[0] = f32x16; - out.fp6x32 - = __amd_cvt_floatx32_to_fp6x32_scale(in.fp32x32, __AMD_OCP_E2M3, 0); - data = out.real[0]; -#endif + union { + Float6x32_Storage real; + __amd_fp6x32_storage_t tmp; + } cvt; + __amd_floatx32_storage_t f32x32; + + for (int i=0; i<32; i++) + f32x32[i] = val; + + if (rm == hip_f6_rounding_mode::standard) { + cvt.tmp = __amd_cvt_floatx32_to_fp6x32_scale(f32x32, __AMD_OCP_E2M3, 0); + data = cvt.real; } - else - { - // TODO: update below code if hip_ext_ocp.h supports __amd_cvt_floatx16_to_fp6x16_sr_scale - union - { - __amd_floatx32_storage_t fp32x32; - __amd_floatx16_storage_t fp32x16[2]; - } in; - union - { - Float6x16_Storage real[2]; - __amd_fp6x32_storage_t fp6x32; - } out; - in.fp32x16[0] = f32x16; - out.fp6x32 - = __amd_cvt_floatx32_to_fp6x32_sr_scale(in.fp32x32, __AMD_OCP_E2M3, rng, 0); - data = out.real[0]; + else { + cvt.tmp = __amd_cvt_floatx32_to_fp6x32_sr_scale(f32x32, __AMD_OCP_E2M3, rng, 0); + data = cvt.real; } } // constructor from float - explicit HIP_HOST_DEVICE Float6x16(float v0, + explicit HIP_HOST_DEVICE Float6x32(float v0, float v1, float v2, float v3, @@ -161,87 +149,85 @@ namespace TensileLite float v13, float v14, float v15, - hip_f6_rounding_mode rm = hip_f6_rounding_mode::standard, + float v16, + float v17, + float v18, + float v19, + float v20, + float v21, + float v22, + float v23, + float v24, + float v25, + float v26, + float v27, + float v28, + float v29, + float v30, + float v31, + hip_f6_rounding_mode rm = hip_f6_rounding_mode::standard, uint32_t rng = 0) { - __amd_floatx16_storage_t f32x16; - - f32x16[0] = v0; - f32x16[1] = v1; - f32x16[2] = v2; - f32x16[3] = v3; - f32x16[4] = v4; - f32x16[5] = v5; - f32x16[6] = v6; - f32x16[7] = v7; - f32x16[8] = v8; - f32x16[9] = v9; - f32x16[10] = v10; - f32x16[11] = v11; - f32x16[12] = v12; - f32x16[13] = v13; - f32x16[14] = v14; - f32x16[15] = v15; - - if(rm == hip_f6_rounding_mode::standard) - { -#ifdef HIPBLASLT_USE_HIP_FP6X16 - union - { - Float6x16_Storage real; - __amd_fp6x16_storage_t tmp; - } cvt; - cvt.tmp = __amd_cvt_floatx16_to_fp6x16_scale(f32x16, __AMD_OCP_E2M3, 0); - data = cvt.real; -#else - union - { - __amd_floatx32_storage_t fp32x32; - __amd_floatx16_storage_t fp32x16[2]; - } in; - union - { - Float6x16_Storage real[2]; - __amd_fp6x32_storage_t fp6x32; - } out; - in.fp32x16[0] = f32x16; - out.fp6x32 - = __amd_cvt_floatx32_to_fp6x32_scale(in.fp32x32, __AMD_OCP_E2M3, 0); - data = out.real[0]; -#endif + union { + Float6x32_Storage real; + __amd_fp6x32_storage_t tmp; + } cvt; + __amd_floatx32_storage_t f32x32; + + f32x32[0] = v0; + f32x32[1] = v1; + f32x32[2] = v2; + f32x32[3] = v3; + f32x32[4] = v4; + f32x32[5] = v5; + f32x32[6] = v6; + f32x32[7] = v7; + f32x32[8] = v8; + f32x32[9] = v9; + f32x32[10] = v10; + f32x32[11] = v11; + f32x32[12] = v12; + f32x32[13] = v13; + f32x32[14] = v14; + f32x32[15] = v15; + f32x32[16] = v16; + f32x32[17] = v17; + f32x32[18] = v18; + f32x32[19] = v19; + f32x32[20] = v20; + f32x32[21] = v21; + f32x32[22] = v22; + f32x32[23] = v23; + f32x32[24] = v24; + f32x32[25] = v25; + f32x32[26] = v26; + f32x32[27] = v27; + f32x32[28] = v28; + f32x32[29] = v29; + f32x32[30] = v30; + f32x32[31] = v31; + + if (rm == hip_f6_rounding_mode::standard) { + cvt.tmp = __amd_cvt_floatx32_to_fp6x32_scale(f32x32, __AMD_OCP_E2M3, 0); + data = cvt.real; } - else - { - // TODO: update below code if hip_ext_ocp.h supports __amd_cvt_floatx16_to_fp6x16_sr_scale - union - { - __amd_floatx32_storage_t fp32x32; - __amd_floatx16_storage_t fp32x16[2]; - } in; - union - { - Float6x16_Storage real[2]; - __amd_fp6x32_storage_t fp6x32; - } out; - in.fp32x16[0] = f32x16; - out.fp6x32 - = __amd_cvt_floatx32_to_fp6x32_sr_scale(in.fp32x32, __AMD_OCP_E2M3, rng, 0); - data = out.real[0]; + else { + cvt.tmp = __amd_cvt_floatx32_to_fp6x32_sr_scale(f32x32, __AMD_OCP_E2M3, rng, 0); + data = cvt.real; } } inline HIP_HOST_DEVICE float getElement(size_t idx) const { - // TODO: update below code if hip_ext_ocp.h supports __amd_cvt_fp6x16_to_floatx16_scale - __amd_floatx32_storage_t fp32x32; - union - { - Float6x16_Storage real[2]; - __amd_fp6x32_storage_t fp6x32; - } in; - in.real[0] = data; - fp32x32 = __amd_cvt_fp6x32_to_floatx32_scale(in.fp6x32, __AMD_OCP_E2M3, 0); - if(idx < packed_size) + union { + Float6x32_Storage real; + __amd_fp6x32_storage_t tmp; + } cvt; + + cvt.real = data; + + __amd_floatx32_storage_t fp32x32 = __amd_cvt_fp6x32_to_floatx32_scale(cvt.tmp, __AMD_OCP_E2M3, 0); + if (idx < 32) return fp32x32[idx]; else return 0.0; @@ -250,10 +236,38 @@ namespace TensileLite // check for zero inline HIP_HOST_DEVICE bool is_zero() const { - return data.v0 == 0 && data.v1 == 0 && data.v2 == 0 && data.v3 == 0 && data.v4 == 0 - && data.v5 == 0 && data.v6 == 0 && data.v7 == 0 && data.v8 == 0 && data.v9 == 0 - && data.v10 == 0 && data.v11 == 0 && data.v12 == 0 && data.v13 == 0 - && data.v14 == 0 && data.v15 == 0; + return data.v0 == 0 && + data.v1 == 0 && + data.v2 == 0 && + data.v3 == 0 && + data.v4 == 0 && + data.v5 == 0 && + data.v6 == 0 && + data.v7 == 0 && + data.v8 == 0 && + data.v9 == 0 && + data.v10 == 0 && + data.v11 == 0 && + data.v12 == 0 && + data.v13 == 0 && + data.v14 == 0 && + data.v15 == 0 && + data.v16 == 0 && + data.v17 == 0 && + data.v18 == 0 && + data.v19 == 0 && + data.v20 == 0 && + data.v21 == 0 && + data.v22 == 0 && + data.v23 == 0 && + data.v24 == 0 && + data.v25 == 0 && + data.v26 == 0 && + data.v27 == 0 && + data.v28 == 0 && + data.v29 == 0 && + data.v30 == 0 && + data.v31 == 0; } // check for nan @@ -272,26 +286,26 @@ namespace TensileLite namespace std { - inline std::string to_string(const TensileLite::Float6x16& a) + inline std::string to_string(const TensileLite::Float6x32& a) { - // TODO: update below code if hip_ext_ocp.h supports __amd_cvt_fp6x16_to_floatx16_scale - union - { - TensileLite::Float6x16_Storage real[2]; - __amd_fp6x32_storage_t fp6x32; - } in; - in.real[0] = a.data; - auto result = __amd_cvt_fp6x32_to_floatx32_scale(in.fp6x32, __AMD_OCP_E2M3, 0); + union { + TensileLite::Float6x32_Storage real; + __amd_fp6x32_storage_t tmp; + } cvt; + + cvt.real = a.data; + + auto result = __amd_cvt_fp6x32_to_floatx32_scale(cvt.tmp, __AMD_OCP_E2M3, 0); std::string str = std::to_string(static_cast(result[0])); - for(int i = 1; i < a.packed_size; i++) - str = str + " " + std::to_string(static_cast(result[i])); + for (int i=1; i<32; i++) + str = str + " " + std::to_string(static_cast(result[i])); return str; } - inline ostream& operator<<(ostream& stream, const TensileLite::Float6x16 a) + inline ostream& operator<<(ostream& stream, const TensileLite::Float6x32 a) { return stream << to_string(a); } @@ -299,4 +313,6 @@ namespace std TENSILE_HIDDEN_END +#endif // _WIN32 + #endif // TENSILE_USE_FP6 diff --git a/projects/hipblaslt/tensilelite/include/Tensile/Distance.hpp b/projects/hipblaslt/tensilelite/include/Tensile/Distance.hpp index 859712ccac7..aacb534a4b8 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/Distance.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/Distance.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -395,7 +395,7 @@ namespace TensileLite // and nearest K double K = p1.size() > 3 ? p1[3] : p1[2]; - distance = abs(K - gridK); + distance = std::abs(K - gridK); return distance; } diff --git a/projects/hipblaslt/tensilelite/include/Tensile/SolutionLibrary.hpp b/projects/hipblaslt/tensilelite/include/Tensile/SolutionLibrary.hpp index 6454337c84a..04e199be7ef 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/SolutionLibrary.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/SolutionLibrary.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -65,7 +65,9 @@ namespace TensileLite && solutions.problemType.cType == problem.c().dataType() && solutions.problemType.dType == problem.d().dataType() && solutions.problemType.computeType == problem.computeType() - && solutions.problemType.groupedGemm == problem.groupedGemm()) + && solutions.problemType.groupedGemm == problem.groupedGemm() + && solutions.problemType.mxBlockA == problem.mxBlockA() + && solutions.problemType.mxBlockB == problem.mxBlockB()) return true; return false; } diff --git a/projects/hipblaslt/tensilelite/include/Tensile/TensorDescriptor.hpp b/projects/hipblaslt/tensilelite/include/Tensile/TensorDescriptor.hpp index 47f44924439..b1878a5c47d 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/TensorDescriptor.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/TensorDescriptor.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -346,11 +346,33 @@ namespace TensileLite } size_t totalAllocatedBytes() const { - return multiplyElementSize(totalAllocatedElements(), elementBytes()); + // Cannot use elementSize directly as elementSize is + // packed size for MX data types. + auto const info = DataTypeInfo::Get(m_dataType); + assert(totalAllocatedElements() * info.elementSize % info.packing == 0); + return totalAllocatedElements() * info.elementSize / info.packing; } float elementBytes() const { + // TODO:Need to enhance this to support segment type in tensileLite. + // tensileLite currently maps MX data types to unsegmented + // types in DataTypeInfo, i.e., Float4 -> Float4x2, + // Float6 -> Float6x32. As a result, return elementSize is + // incorrect for MX data types because elementSize represents + // unsegmented size in bytes not segment size. + // + // To get element size (in bytes) for f4/f6/bf6, use + // + // auto const info = DataTypeInfo::Get(m_dataType); + // auto elementSize = info.elementSize / info.packing + // + // tensileLite returns sizeof(Float4x2), sizeof(Float6x32), + // sizeof(BFloat6x32) for rocisa::f4,f6,bf6. + // + assert(m_dataType != rocisa::DataType::Float6 && + m_dataType != rocisa::DataType::BFloat6 && + m_dataType != rocisa::DataType::Float4); return DataTypeInfo::Get(m_dataType).elementSize; } diff --git a/projects/hipblaslt/tensilelite/include/Tensile/UtilsOrigami.hpp b/projects/hipblaslt/tensilelite/include/Tensile/UtilsOrigami.hpp index e35264940cd..6dc0ed3ddca 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/UtilsOrigami.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/UtilsOrigami.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -25,7 +25,7 @@ *******************************************************************************/ #pragma once - +#include #include #include #include @@ -79,9 +79,15 @@ namespace TensileLite return origami::data_type_t::Float8BFloat8; case rocisa::DataType::BFloat8Float8: return origami::data_type_t::BFloat8Float8; + case rocisa::DataType::Float6: + return origami::data_type_t::Float6; + case rocisa::DataType::BFloat6: + return origami::data_type_t::BFloat6; + case rocisa::DataType::Float4: + return origami::data_type_t::Float4; default: - return origami::data_type_t::None; + throw std::runtime_error("Unsupported data type: " + std::to_string(static_cast(type))); } } } // namespace TensileLite diff --git a/projects/hipblaslt/tensilelite/rocisa/rocisa/include/enum.hpp b/projects/hipblaslt/tensilelite/rocisa/rocisa/include/enum.hpp index 1a34ac373d8..aafd7793357 100644 --- a/projects/hipblaslt/tensilelite/rocisa/rocisa/include/enum.hpp +++ b/projects/hipblaslt/tensilelite/rocisa/rocisa/include/enum.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2025 Advanced Micro Devices, Inc. + * Copyright (C) 2025-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -63,7 +63,7 @@ namespace rocisa None = Count }; - inline int dataTypeToBytes(DataType type) + inline float dataTypeToBytes(DataType type) { switch(type) { @@ -105,6 +105,12 @@ namespace rocisa return 1; case DataType::BFloat8Float8: return 1; + case DataType::Float6: + return 0.75; + case DataType::BFloat6: + return 0.75; + case DataType::Float4: + return 0.5; default: return -1; // Invalid type } diff --git a/projects/hipblaslt/tensilelite/rocisa/rocisa/include/instruction/instruction.hpp b/projects/hipblaslt/tensilelite/rocisa/rocisa/include/instruction/instruction.hpp index 7faf9eccd4a..96b13757e79 100644 --- a/projects/hipblaslt/tensilelite/rocisa/rocisa/include/instruction/instruction.hpp +++ b/projects/hipblaslt/tensilelite/rocisa/rocisa/include/instruction/instruction.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2025 Advanced Micro Devices, Inc. + * Copyright (C) 2025-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -255,7 +255,7 @@ namespace rocisa if(newVal != oriVal && !outputInlineAsm){ // Base layer WA: need to store previous msb value in [15:8] bits. //int setVal = oriVal < 0? newVal : newVal + (oriVal << 8); - // only set newVal until complier support it + // only set newVal until compiler support it int setVal = newVal; std::string msbStr = "s_set_vgpr_msb " + std::to_string(setVal); std::string msbComment = std::string("src0: " + std::to_string(msbSrc[0]) + ", src1: " + std::to_string(msbSrc[1]) + \ diff --git a/projects/hipblaslt/tensilelite/rocisa/rocisa/include/instruction/mem.hpp b/projects/hipblaslt/tensilelite/rocisa/rocisa/include/instruction/mem.hpp index 6c1e7a16391..a177d161f22 100644 --- a/projects/hipblaslt/tensilelite/rocisa/rocisa/include/instruction/mem.hpp +++ b/projects/hipblaslt/tensilelite/rocisa/rocisa/include/instruction/mem.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2025 Advanced Micro Devices, Inc. + * Copyright (C) 2025-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -1324,6 +1324,27 @@ namespace rocisa } }; + struct FlatLoadB192 : public FLATReadInstruction + { + FlatLoadB192(const std::shared_ptr& dst, + const std::shared_ptr& vaddr, + std::optional flat = std::nullopt, + const std::string& comment = "") + : FLATReadInstruction(InstType::INST_B192, dst, vaddr, flat, comment) + { + } + + FlatLoadB192(const FlatLoadB192& other) + : FLATReadInstruction(other) + { + } + + std::shared_ptr clone() const override + { + return std::make_shared(*this); + } + }; + struct GlobalLoadTR8B64 : public GLOBALLoadInstruction { GlobalLoadTR8B64(const std::shared_ptr& dst, @@ -1368,27 +1389,6 @@ namespace rocisa } }; - struct FlatLoadB192 : public FLATReadInstruction - { - FlatLoadB192(const std::shared_ptr& dst, - const std::shared_ptr& vaddr, - std::optional flat = std::nullopt, - const std::string& comment = "") - : FLATReadInstruction(InstType::INST_B192, dst, vaddr, flat, comment) - { - } - - FlatLoadB192(const FlatLoadB192& other) - : FLATReadInstruction(other) - { - } - - std::shared_ptr clone() const override - { - return std::make_shared(*this); - } - }; - struct BufferStoreB8 : public MUBUFStoreInstruction { BufferStoreB8(const std::shared_ptr& src, @@ -2234,40 +2234,47 @@ namespace rocisa return dstCopy.toString() + ", " + srcs->toString(); } + int getIssueLatency() const override + { + return issueLatency(); + } + std::string toString() const override { - std::string instStr = "ds_load_b128"; - if(kernel().isaVersion[0] < 11) - { - instStr = "ds_read_b128"; - } + // Prepare first instruction + std::string instStr = kernel().isaVersion[0] < 11 ? + "ds_read_b128": + "ds_load_b128"; std::string kStr = instStr + " " + getArgStr2(); if(ds) kStr += ds->toString(); - instStr = "ds_load_b64"; - if(kernel().isaVersion[0] < 11) - { - instStr = "ds_read_b64"; - } + // Prepare second instruction + instStr = kernel().isaVersion[0] < 11 ? + "ds_read_b64": + "ds_load_b64"; std::string kStr2 = instStr + " " + getArgStr2(true); auto dsCopy = ds ? std::make_shared(*ds) : std::make_shared(); dsCopy->offset += 16; kStr2 += dsCopy->toString(); - kStr = formatWithComment(kStr); + + // Format both instructions + kStr = formatWithComment(kStr); kStr2 = formatWithComment(kStr2); - // compute 2 different dst vgpr msb - auto dstCopyPtr - = std::make_shared(*dynamic_cast(dst.get())); - int idx = dstCopyPtr->regName->offsets.size() - 1; - int regNum = 4; + + auto dstCopyPtr = std::make_shared(*dynamic_cast(dst.get())); + + int const idx = dstCopyPtr->regName->offsets.size() - 1; + int const regNum = 4; dstCopyPtr->regNum = regNum; dstCopyPtr->setMsb(); setMsb(kStr, {srcs}, dstCopyPtr); + dstCopyPtr->regName->offsets[idx] = dstCopyPtr->regName->offsets[idx] + regNum; dstCopyPtr->regNum = 2; dstCopyPtr->setMsb(); setMsb(kStr2, {srcs}, dstCopyPtr); + return kStr + kStr2; } }; @@ -2582,6 +2589,7 @@ namespace rocisa } }; + struct DSStoreB192 : public DSStoreInstruction { DSStoreB192(const std::shared_ptr& dstAddr, @@ -2627,22 +2635,25 @@ namespace rocisa return dstAddr->toString() + ", " + srcCopy.toString(); } + int getIssueLatency() const override + { + return issueLatency(); + } + std::string toString() const override { - std::string instStr = "ds_store_b128"; - if(kernel().isaVersion[0] < 11) - { - instStr = "ds_write_b128"; - } + // Prepare first instruction + std::string instStr = kernel().isaVersion[0] < 11 ? + "ds_write_b128" : + "ds_store_b128"; std::string kStr = instStr + " " + getArgStr2(); if(ds) kStr += ds->toString(); - instStr = "ds_store_b64"; - if(kernel().isaVersion[0] < 11) - { - instStr = "ds_write_b64"; - } + // Prepare second instruction + instStr = kernel().isaVersion[0] < 11 ? + "ds_write_b64" : + "ds_store_b64"; std::string kStr2 = instStr + " " + getArgStr2(true); auto dsCopy = ds ? std::make_shared(*ds) : std::make_shared(); dsCopy->offset += 16; @@ -2661,6 +2672,7 @@ namespace rocisa srcCopyPtr->regNum = 2; srcCopyPtr->setMsb(); setMsb(kStr2, {dstAddr, srcCopyPtr}, nullptr); + return kStr + kStr2; } }; diff --git a/projects/hipblaslt/tensilelite/rocisa/rocisa/include/instruction/mfma.hpp b/projects/hipblaslt/tensilelite/rocisa/rocisa/include/instruction/mfma.hpp index 71dfe90eb73..553f781e5e5 100644 --- a/projects/hipblaslt/tensilelite/rocisa/rocisa/include/instruction/mfma.hpp +++ b/projects/hipblaslt/tensilelite/rocisa/rocisa/include/instruction/mfma.hpp @@ -43,7 +43,7 @@ namespace rocisa && matrixInstB == 1) { if(dataType == DataType::Half || dataType == DataType::BFloat16 - || dataType == DataType::Int8 || is8bitFloat(dataType)) + || dataType == DataType::Int8 || is8bitFloat(dataType) || numBytes < 1) { mi_divisor = 4; miIssueLatency = 1; @@ -251,16 +251,55 @@ namespace rocisa switch(instType) { case InstType::INST_F8: - inputPermuteStr = variant[2] > 32 ? " cbsz:0 blgp:0" : ""; + inputPermuteStr = " cbsz:0 blgp:0"; break; case InstType::INST_BF8: - inputPermuteStr = variant[2] > 32 ? " cbsz:1 blgp:1" : ""; + inputPermuteStr = " cbsz:1 blgp:1"; break; case InstType::INST_F8_BF8: - inputPermuteStr = variant[2] > 32 ? " cbsz:0 blgp:1" : ""; + inputPermuteStr = " cbsz:0 blgp:1"; break; case InstType::INST_BF8_F8: - inputPermuteStr = variant[2] > 32 ? " cbsz:1 blgp:0" : ""; + inputPermuteStr = " cbsz:1 blgp:0"; + break; + case InstType::INST_F6: + inputPermuteStr = " cbsz:2 blgp:2"; + break; + case InstType::INST_BF6: + inputPermuteStr = " cbsz:3 blgp:3"; + break; + case InstType::INST_F4: + inputPermuteStr = " cbsz:4 blgp:4"; + break; + case InstType::INST_F8_F6: + inputPermuteStr = " cbsz:0 blgp:2"; + break; + case InstType::INST_F6_F8: + inputPermuteStr = " cbsz:2 blgp:0"; + break; + case InstType::INST_F8_F4: + inputPermuteStr = " cbsz:0 blgp:4"; + break; + case InstType::INST_F4_F8: + inputPermuteStr = " cbsz:4 blgp:0"; + break; + case InstType::INST_F6_B6: + inputPermuteStr = " cbsz:2 blgp:3"; + break; + case InstType::INST_B6_F6: + inputPermuteStr = " cbsz:3 blgp:2"; + break; + case InstType::INST_F6_F4: + inputPermuteStr = " cbsz:2 blgp:4"; + break; + case InstType::INST_F4_F6: + inputPermuteStr = " cbsz:4 blgp:2"; + break; + case InstType::INST_B6_F4: + inputPermuteStr = " cbsz:3 blgp:4"; + break; + case InstType::INST_F4_B6: + inputPermuteStr = " cbsz:4 blgp:3"; break; default: break; @@ -468,17 +507,17 @@ namespace rocisa MXMFMAInstruction(InstType instType, InstType accType, - InstType mxScaleAType, - InstType mxScaleBType, const std::vector& variant, const std::shared_ptr& acc, const std::shared_ptr& a, const std::shared_ptr& b, - const std::shared_ptr& acc2, - const std::shared_ptr& mxsa, - const std::shared_ptr& mxsb, - int block, - const std::string& comment = "") + const std::shared_ptr& acc2 = nullptr, + const std::shared_ptr& mxsa = nullptr, + const std::shared_ptr& mxsb = nullptr, + InstType mxScaleAType = InstType::INST_F32, + InstType mxScaleBType = InstType::INST_F32, + int block = 0, + const std::string& comment = "") : Instruction(instType, comment) , accType(accType) , mxScaleAType(mxScaleAType) @@ -487,7 +526,7 @@ namespace rocisa , acc(acc) , a(a) , b(b) - , acc2(acc2) + , acc2(acc2 ? acc2 : acc) , mxsa(mxsa) , mxsb(mxsb) , block(block) @@ -523,6 +562,8 @@ namespace rocisa std::vector getParams() const override { + if(getAsmCaps()["HasMFMA"]) + return {acc, a, b, acc2, mxsa, mxsb}; return {acc, a, b, acc2, mxsa, mxsb, block}; } @@ -541,11 +582,65 @@ namespace rocisa { std::string variantStr = std::to_string(variant[0]) + "x" + std::to_string(variant[1]) + "x" + std::to_string(variant[2]); - std::string blkStr = (block == 16) ? "16" : ""; - return "v_wmma_scale" + blkStr + "_f32_" + variantStr + "_" + typeConvert(); + if(getAsmCaps()["HasMFMA"]) + { + return "v_mfma_scale_f32_" + variantStr + "_f8f6f4"; + } + else + { + std::string blkStr = (block == 16) ? "16" : ""; + return "v_wmma_scale" + blkStr + "_f32_" + variantStr + "_" + typeConvert(); + } } - std::string getArgStr() const + std::string mfmaInputPermuteStr() const + { + if(getAsmCaps()["HasMFMA_f8f6f4"] && variant[2] > 32) + { + switch(instType) + { + case InstType::INST_F8: + return " cbsz:0 blgp:0"; + case InstType::INST_BF8: + return " cbsz:1 blgp:1"; + case InstType::INST_F8_BF8: + return " cbsz:0 blgp:1"; + case InstType::INST_BF8_F8: + return " cbsz:1 blgp:0"; + case InstType::INST_F6: + return " cbsz:2 blgp:2"; + case InstType::INST_BF6: + return " cbsz:3 blgp:3"; + case InstType::INST_F4: + return " cbsz:4 blgp:4"; + case InstType::INST_F8_F6: + return " cbsz:0 blgp:2"; + case InstType::INST_F6_F8: + return " cbsz:2 blgp:0"; + case InstType::INST_F8_F4: + return " cbsz:0 blgp:4"; + case InstType::INST_F4_F8: + return " cbsz:4 blgp:0"; + case InstType::INST_F6_B6: + return " cbsz:2 blgp:3"; + case InstType::INST_B6_F6: + return " cbsz:3 blgp:2"; + case InstType::INST_F6_F4: + return " cbsz:2 blgp:4"; + case InstType::INST_F4_F6: + return " cbsz:4 blgp:2"; + case InstType::INST_B6_F4: + return " cbsz:3 blgp:4"; + case InstType::INST_F4_B6: + return " cbsz:4 blgp:3"; + default: + break; + } + } + return ""; + } + + std::string wmmaInputPermuteStr() const { constexpr size_t f4_t = 32; std::string inputPermuteStr = ""; @@ -706,8 +801,25 @@ namespace rocisa break; } - return acc->toString() + ", " + a->toString() + ", " + b->toString() + ", " - + acc2->toString() + ", " + mxsa->toString() + ", " + mxsb->toString() + inputPermuteStr; + return inputPermuteStr; + } + + std::string getArgStr() const + { + if(getAsmCaps()["HasMFMA"]) + { + std::string mxsaStr = mxsa ? mxsa->toString() : ""; + std::string mxsbStr = mxsb ? mxsb->toString() : ""; + return acc->toString() + ", " + a->toString() + ", " + b->toString() + ", " + + acc2->toString() + ", " + mxsaStr + ", " + mxsbStr + + mfmaInputPermuteStr(); + } + else + { + return acc->toString() + ", " + a->toString() + ", " + b->toString() + ", " + + acc2->toString() + ", " + mxsa->toString() + ", " + mxsb->toString() + + wmmaInputPermuteStr(); + } } std::string toString() const override @@ -717,6 +829,14 @@ namespace rocisa setMsb(kStr, {a, b, acc2}, acc); return formatWithComment(kStr); } + + int getIssueLatency() const override + { + auto dataType = instTypeToDataType(instType); + auto [issueLatency, miLatency] = getMFMAIssueLatency( + dataType, variant[0], variant.size() > 3 ? variant[3] : 1); + return issueLatency; + } }; struct SMFMAInstruction : public Instruction diff --git a/projects/hipblaslt/tensilelite/rocisa/rocisa/src/count.cpp b/projects/hipblaslt/tensilelite/rocisa/rocisa/src/count.cpp index 62229f7aef5..f1d9f148146 100644 --- a/projects/hipblaslt/tensilelite/rocisa/rocisa/src/count.cpp +++ b/projects/hipblaslt/tensilelite/rocisa/rocisa/src/count.cpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2025 Advanced Micro Devices, Inc. + * Copyright (C) 2025-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -52,7 +52,7 @@ namespace rocisa int count = 0; for(const auto& i : ptr->itemList) { - count += countX(i); + count += countX(i, weights); } return count; } diff --git a/projects/hipblaslt/tensilelite/rocisa/rocisa/src/enum.cpp b/projects/hipblaslt/tensilelite/rocisa/rocisa/src/enum.cpp index a8eea5aff50..ddc5814ed88 100644 --- a/projects/hipblaslt/tensilelite/rocisa/rocisa/src/enum.cpp +++ b/projects/hipblaslt/tensilelite/rocisa/rocisa/src/enum.cpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2025 Advanced Micro Devices, Inc. + * Copyright (C) 2025-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -137,6 +137,7 @@ void init_enum(nb::module_ m) .value("INST_B6_B8", rocisa::InstType::INST_B6_B8) .value("INST_CVT", rocisa::InstType::INST_CVT) .value("INST_MACRO", rocisa::InstType::INST_MACRO) + .value("INST_B192", rocisa::InstType::INST_B192) .value("INST_NOTYPE", rocisa::InstType::INST_NOTYPE) .export_values(); diff --git a/projects/hipblaslt/tensilelite/rocisa/rocisa/src/functions/f_math.cpp b/projects/hipblaslt/tensilelite/rocisa/rocisa/src/functions/f_math.cpp index 0bb0e4bc686..26daa9466c2 100644 --- a/projects/hipblaslt/tensilelite/rocisa/rocisa/src/functions/f_math.cpp +++ b/projects/hipblaslt/tensilelite/rocisa/rocisa/src/functions/f_math.cpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2025 Advanced Micro Devices, Inc. + * Copyright (C) 2025-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -307,29 +307,25 @@ namespace rocisa } return module; } - template std::shared_ptr vectorMultiplyBpe(std::string, std::string, float, const std::string&); template std::shared_ptr vectorMultiplyBpe(int, int, float, const std::string&); - template std::shared_ptr vectorMultiply64Bpe(int, int, float, int, const std::string&); - template std::shared_ptr scalarMultiplyBpe(int, int, float, const std::string&); template std::shared_ptr scalarMultiplyBpe(std::string, std::string, float, const std::string&); template std::shared_ptr scalarMultiplyBpe(int, std::string, float, const std::string&); - + template std::shared_ptr + scalarMultiplyBpe(std::string, int, float, const std::string&); template std::shared_ptr scalarMultiply64Bpe(int, int, float, int, const std::string&); template std::shared_ptr scalarMultiply64Bpe(std::string, std::string, float, int, const std::string&); - } // namespace rocisa - void math_func(nb::module_ m) { m.def("vectorStaticDivideAndRemainder", @@ -458,7 +454,6 @@ void math_func(nb::module_ m) nb::arg("tmpVgprRes") = std::nullopt, nb::arg("tmpSgprRes") = std::nullopt, nb::arg("comment") = ""); - m.def("scalarStaticDivideAndRemainder", nb::overload_cast, int>( &rocisa::scalarStaticDivideAndRemainder), @@ -725,6 +720,13 @@ void math_func(nb::module_ m) nb::arg("src"), nb::arg("bpe"), nb::arg("comment") = ""); + m.def("scalarMultiplyBpe", + nb::overload_cast( + &rocisa::scalarMultiplyBpe), + nb::arg("dst"), + nb::arg("src"), + nb::arg("bpe"), + nb::arg("comment") = ""); m.def("scalarMultiply64Bpe", nb::overload_cast( &rocisa::scalarMultiply64Bpe), @@ -741,4 +743,5 @@ void math_func(nb::module_ m) nb::arg("bpe"), nb::arg("tmp"), nb::arg("comment") = ""); -} \ No newline at end of file +} + diff --git a/projects/hipblaslt/tensilelite/rocisa/rocisa/src/instruction/mfma.cpp b/projects/hipblaslt/tensilelite/rocisa/rocisa/src/instruction/mfma.cpp index 68fd813c1b1..bb746e79245 100644 --- a/projects/hipblaslt/tensilelite/rocisa/rocisa/src/instruction/mfma.cpp +++ b/projects/hipblaslt/tensilelite/rocisa/rocisa/src/instruction/mfma.cpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2025 Advanced Micro Devices, Inc. + * Copyright (C) 2025-2026 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -161,9 +161,13 @@ void mfma_inst(nb::module_ m_mfma) }); nb::class_(m_mfma, "MXMFMAInstruction") + // The C++ constructor parameter order was reshuffled (mxScaleA/BType moved + // from positions 3-4 to after mxsb). nb::kw_only() forces every Python + // caller to spell out argument names, which prevents a silent positional + // mis-binding if the C++ signature changes again. The lone in-tree caller + // (KernelWriterAssembly.MXMFMAInstruction(...)) already uses keyword args, + // so this is non-breaking. .def(nb::init&, const std::shared_ptr&, @@ -172,21 +176,24 @@ void mfma_inst(nb::module_ m_mfma) const std::shared_ptr&, const std::shared_ptr&, const std::shared_ptr&, + rocisa::InstType, + rocisa::InstType, int, const std::string&>(), + nb::kw_only(), nb::arg("instType"), nb::arg("accType"), - nb::arg("mxScaleAType"), - nb::arg("mxScaleBType"), nb::arg("variant"), nb::arg("acc"), nb::arg("a"), nb::arg("b"), - nb::arg("acc2"), - nb::arg("mxsa"), - nb::arg("mxsb"), - nb::arg("block"), - nb::arg("comment") = "") + nb::arg("acc2") = nullptr, + nb::arg("mxsa") = nullptr, + nb::arg("mxsb") = nullptr, + nb::arg("mxScaleAType") = rocisa::InstType::INST_F32, + nb::arg("mxScaleBType") = rocisa::InstType::INST_F32, + nb::arg("block") = 0, + nb::arg("comment") = "") .def_rw("a", &rocisa::MXMFMAInstruction::a) .def_rw("b", &rocisa::MXMFMAInstruction::b) .def_rw("mxsa", &rocisa::MXMFMAInstruction::mxsa) @@ -194,6 +201,7 @@ void mfma_inst(nb::module_ m_mfma) .def_rw("acc", &rocisa::MXMFMAInstruction::acc) .def_rw("acc2", &rocisa::MXMFMAInstruction::acc2) .def("getParams", &rocisa::MXMFMAInstruction::getParams) + .def("getIssueLatency", &rocisa::MXMFMAInstruction::getIssueLatency) .def("__str__", &rocisa::MXMFMAInstruction::toString) .def("__deepcopy__", [](const rocisa::MXMFMAInstruction& self, const nb::dict&) { return new rocisa::MXMFMAInstruction(self); diff --git a/projects/hipblaslt/tensilelite/src/ContractionProblem.cpp b/projects/hipblaslt/tensilelite/src/ContractionProblem.cpp index f2b2aceddc6..0ecaca85e37 100644 --- a/projects/hipblaslt/tensilelite/src/ContractionProblem.cpp +++ b/projects/hipblaslt/tensilelite/src/ContractionProblem.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -584,6 +584,8 @@ namespace TensileLite = TensorDescriptor("scaleAlphaVec"); gemm.m_tensors[ContractionProblemGemm::TENSOR::METADATA] = TensorDescriptor("metadata"); gemm.m_tensors[ContractionProblemGemm::TENSOR::AMAXD] = TensorDescriptor("amaxD"); + gemm.m_tensors[ContractionProblemGemm::TENSOR::MXSA] = TensorDescriptor("mxScaleA"); + gemm.m_tensors[ContractionProblemGemm::TENSOR::MXSB] = TensorDescriptor("mxScaleB"); gemm.m_tensors[ContractionProblemGemm::TENSOR::COMPRESSED] = TensorDescriptor("compressed"); gemm.m_tensors[ContractionProblemGemm::TENSOR::MXSA] = TensorDescriptor("mx-a"); gemm.m_tensors[ContractionProblemGemm::TENSOR::MXSB] = TensorDescriptor("mx-b"); @@ -686,7 +688,7 @@ namespace TensileLite normalize(); calcArithmeticIntensity(); } - + void ContractionProblemGemm::setMXScaleA(rocisa::DataType mxTypeA, int mxBlockA, std::vector saStride) { m_mxBlockA = mxBlockA; @@ -1281,10 +1283,16 @@ namespace TensileLite gflop += 2 * cSize * 1e-9; // Include (+ beta * C) in gflops cSize *= 2; // Include read C and write D in gbytes } + // TODO: for MX data types, the size is smaller than a byte + // so we need to use (elementSize/packing) to derive the actual + // byte size of a segment. + auto infoA = DataTypeInfo::Get(a().dataType()); + auto infoB = DataTypeInfo::Get(b().dataType()); + auto infoC = DataTypeInfo::Get(c().dataType()); double gbyte - = (multiplyElementSize(aSize, a().elementBytes()) + - multiplyElementSize(bSize, b().elementBytes()) + - multiplyElementSize(cSize, c().elementBytes())) + = ((aSize * infoA.elementSize / infoA.packing) + + (bSize * infoB.elementSize / infoB.packing) + + (cSize * infoC.elementSize / infoC.packing)) * 1e-9; m_arithmeticIntensity = gflop / gbyte; @@ -1496,7 +1504,7 @@ namespace TensileLite rocisa::DataType typeAlpha, rocisa::DataType typeBeta, rocisa::DataType typeComputeInputA, - rocisa::DataType typeComputeInputB, + rocisa::DataType typeComputeInputB, rocisa::DataType typeCompute, double alpha, double beta, diff --git a/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp b/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp index 1b74af9dc68..e5c7eb3d74f 100644 --- a/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp +++ b/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -744,7 +744,7 @@ namespace TensileLite // Clamp minimum iters per tile to 1 to allow stream-k index calculation to work in case K==0 // In this case no actual iterations will be run, but workgroups will be mapped correctly for beta*C - auto itersPerTile = max(1, problem.getItersPerTile(sizeMapping)); + auto itersPerTile = std::max(size_t{1}, problem.getItersPerTile(sizeMapping)); auto totalIters = tiles * itersPerTile; uint32_t magicNumberItersPerTile; @@ -763,7 +763,7 @@ namespace TensileLite if(sizeMapping.streamK == 1) // Basic SK { - uint32_t itersPerWave = CeilDivide(totalIters, numWorkGroups.x); + uint32_t itersPerWave = CeilDivide(static_cast(totalIters), static_cast(numWorkGroups.x)); args.template append("SKItersPerWG", itersPerWave); } else if(sizeMapping.streamK >= 2) // Two-tile SK @@ -798,7 +798,7 @@ namespace TensileLite // dpTilesPerWG = bigEnough ? (tiles - skTiles) / skGrid : 0; skTiles = bigEnough ? sk.grid * fullTiles + tiles % sk.grid : tiles; // Cap Stream-K tiles at total number of tiles in case of large multiplier - skTiles = min(skTiles, tiles); + skTiles = std::min(skTiles, static_cast(tiles)); } uint32_t skItersPerWG = skTiles * itersPerTile / sk.grid; @@ -826,7 +826,9 @@ namespace TensileLite autoStaggerUStrideShift, autoGsuVal); - if(!problemType.useScaleAB.empty()) //kernel input data + // NOTE: an assumption here is A & B must be both MX data types or non-MX data types. + // Mixing is not supported. + if(!problemType.useScaleAB.empty()) { args.template append("scaleA", inputs.scaleA); args.template append("scaleB", inputs.scaleB); @@ -1215,20 +1217,20 @@ namespace TensileLite uint32_t N = problem.freeSizeB(0); uint32_t B = problem.batchSize(0); uint32_t K = problem.boundSize(0); - uint32_t GSULimit1 = max(1, (uint32_t)std::floor(numCUs / numWGs)); - uint32_t GSULimit2 = max(1, (uint32_t)std::floor((float)K / (float)MT2 / 3.0)); - uint32_t gsuVal = min(GSULimit2, max(1, GSULimit1)); + uint32_t GSULimit1 = std::max(1u, (uint32_t)std::floor(numCUs / numWGs)); + uint32_t GSULimit2 = std::max(1u, (uint32_t)std::floor((float)K / (float)MT2 / 3.0)); + uint32_t gsuVal = std::min(GSULimit2, std::max(1u, GSULimit1)); // WorkgroupNumberCheck #define MAX_WORKGROUP_NUMBER 16777216 if(gsuVal > 1) - gsuVal = min(gsuVal, - MAX_WORKGROUP_NUMBER / std::ceil(static_cast(M) / MT0) - / std::ceil(static_cast(N) / MT1) / B); + gsuVal = std::min(gsuVal, + static_cast(MAX_WORKGROUP_NUMBER / std::ceil(static_cast(M) / MT0) + / std::ceil(static_cast(N) / MT1) / B)); // GlobalSplitUCheckMinK if(gsuVal > 1) - gsuVal = min(gsuVal, std::ceil(static_cast(K) / MT2)); + gsuVal = std::min(gsuVal, static_cast(std::ceil(static_cast(K) / MT2))); // SynchronizerSizeCheck if(gsuVal > 1 && sizeMapping.globalAccumulation == 3) // MBSK @@ -1251,10 +1253,10 @@ namespace TensileLite uint32_t workGroupSize = sizeMapping.workGroupSize.x * sizeMapping.workGroupSize.y * sizeMapping.workGroupSize.z; uint32_t maxGsuValue = (std::numeric_limits::max() / workGroupSize) / tiles; - gsuVal = min(gsuVal, maxGsuValue); + gsuVal = std::min(gsuVal, maxGsuValue); // avoid gsu < 1 - gsuVal = max(gsuVal, 1); + gsuVal = std::max(gsuVal, 1u); static const char* envStr = std::getenv("TENSILE_AUTO_GSU_ALGO"); if(envStr != NULL) @@ -2482,7 +2484,7 @@ namespace TensileLite } int factorDim - = max(problemType.useGradient ? 0 : problemType.useBias, problemType.useScaleAlphaVec); + = std::max(problemType.useGradient ? 0 : problemType.useBias, problemType.useScaleAlphaVec); if(factorDim) { if(factorDim == 2) @@ -3504,7 +3506,7 @@ namespace TensileLite // whichever is minimum. else if(pAMDGPU->skMaxCUs > 0) { - skGrid = min(cuCount, pAMDGPU->skMaxCUs); + skGrid = std::min(cuCount, static_cast(pAMDGPU->skMaxCUs)); } // Multiply the cuCount with a constant factor (c), and launch diff --git a/projects/hipblaslt/tensilelite/src/DataTypes.cpp b/projects/hipblaslt/tensilelite/src/DataTypes.cpp index d987ae07e58..083dc2e4d30 100644 --- a/projects/hipblaslt/tensilelite/src/DataTypes.cpp +++ b/projects/hipblaslt/tensilelite/src/DataTypes.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2022-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -149,13 +149,21 @@ namespace rocisa return TensileLite::TypeInfo::ElementSize; case rocisa::DataType::BFloat8Float8_fnuz: return TensileLite::TypeInfo::ElementSize; +#ifdef _WIN32 + case rocisa::DataType::Float6: + return TensileLite::TypeInfo::ElementSize; + case rocisa::DataType::BFloat6: + return TensileLite::TypeInfo::ElementSize; + case rocisa::DataType::Float4: + return TensileLite::TypeInfo::ElementSize; +#else // _WIN32 #ifdef TENSILE_USE_FP6 case rocisa::DataType::Float6: - return TensileLite::TypeInfo::ElementSize; + return TensileLite::TypeInfo::ElementSize; #endif // #ifdef TENSILE_USE_FP6 #ifdef TENSILE_USE_BF6 case rocisa::DataType::BFloat6: - return TensileLite::TypeInfo::ElementSize; + return TensileLite::TypeInfo::ElementSize; #endif // #ifdef TENSILE_USE_BF6 #ifdef TENSILE_USE_FP4 case rocisa::DataType::Float4: @@ -163,16 +171,17 @@ namespace rocisa #endif // #ifdef TENSILE_USE_FP4 #ifndef TENSILE_USE_FP6 case rocisa::DataType::Float6: - return 12.f / 16.f; // same as TypeInfo, 16 x 6-bit in 12 bytes + return 24.f / 32.f; // same as TypeInfo, 32 x 6-bit in 24 bytes #endif #ifndef TENSILE_USE_BF6 case rocisa::DataType::BFloat6: - return 12.f / 16.f; + return 24.f / 32.f; #endif #ifndef TENSILE_USE_FP4 case rocisa::DataType::Float4: return 0.5f; // TypeInfo: 2 x fp4 in 1 byte #endif +#endif // _WIN32 case rocisa::DataType::E8: return TensileLite::TypeInfo::ElementSize; case rocisa::DataType::E5M3: @@ -296,15 +305,21 @@ namespace TensileLite registerTypeInfo(); registerTypeInfo(); registerTypeInfo(); +#ifdef _WIN32 + registerTypeInfo(); + registerTypeInfo(); + registerTypeInfo(); +#else // _WIN32 #ifdef TENSILE_USE_FP6 - registerTypeInfo(); + registerTypeInfo(); #endif // #ifdef TENSILE_USE_FP6 #ifdef TENSILE_USE_BF6 - registerTypeInfo(); + registerTypeInfo(); #endif // #ifdef TENSILE_USE_BF6 #ifdef TENSILE_USE_FP4 registerTypeInfo(); #endif // #ifdef TENSILE_USE_FP4 +#endif // _WIN32 registerTypeInfo(); registerTypeInfo(); diff --git a/projects/hipblaslt/tensilelite/tests/CMakeLists.txt b/projects/hipblaslt/tensilelite/tests/CMakeLists.txt index fd94b1d4456..0fc098762b9 100644 --- a/projects/hipblaslt/tensilelite/tests/CMakeLists.txt +++ b/projects/hipblaslt/tensilelite/tests/CMakeLists.txt @@ -49,6 +49,14 @@ target_link_libraries(tensilelite-tests ${hipblaslt_msgpack_target} ) +if(HIPBLASLT_ENABLE_MXDATAGENERATOR) + target_sources(tensilelite-tests + PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}/MXDataGen_test.cpp" + ) + target_link_libraries(tensilelite-tests PRIVATE hipblaslt::mxdatagen) +endif() + gtest_discover_tests(tensilelite-tests WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} TIMEOUT 60) target_link_libraries(tensilelite-tests PUBLIC GTest::gtest) diff --git a/projects/hipblaslt/tensilelite/tests/MXDataGen_test.cpp b/projects/hipblaslt/tensilelite/tests/MXDataGen_test.cpp new file mode 100644 index 00000000000..0fed9f0493c --- /dev/null +++ b/projects/hipblaslt/tensilelite/tests/MXDataGen_test.cpp @@ -0,0 +1,255 @@ +// Copyright Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include + +#include +#include + +/** + * @brief Returns true if a 4-bit FP4 E2M1 nibble represents zero. + * + * FP4 E2M1 values are packed two-per-byte (low nibble first). + * Both 0x0 (+0) and 0x8 (-0) decode to zero. + */ +static bool isZeroNibble(uint8_t nibble) +{ + // FP4 E2M1: 0x0 = +0.0, 0x8 = -0.0 + return (nibble == 0x0) || (nibble == 0x8); +} + +/** + * @brief Count elements that decode to zero in a packed FP4 buffer. + */ +static size_t countZerosFP4(const uint8_t* packedData, size_t numPackedBytes) +{ + size_t zeros = 0; + for(size_t i = 0; i < numPackedBytes; ++i) + { + uint8_t lo = packedData[i] & 0x0F; + uint8_t hi = (packedData[i] >> 4) & 0x0F; + if(isZeroNibble(lo)) + ++zeros; + if(isZeroNibble(hi)) + ++zeros; + } + return zeros; +} + +class MXDataGenFP4Test : public ::testing::TestWithParam> +{ +}; + +/** + * @brief Verify that generateMXInput produces FP4 data with an acceptable zero frequency. + * + * FP4 E2M1 has 16 nibble values, 2 of which are zero (0x0 = +0, 0x8 = -0), giving a + * naive baseline of 2/16 = 12.5%. MX block scaling slightly elevates this: the block + * maximum is guaranteed non-zero, pushing small elements toward zero. Empirically the + * zero frequency converges to ~12.89% for large matrices with bounded [-1, 1] input. + */ +TEST_P(MXDataGenFP4Test, ZeroFrequencyWithinBounds) +{ + auto [rows, cols, mxBlock, isTranspose] = GetParam(); + + const uint64_t numElements = rows * cols; + const uint64_t numPacked = (numElements + 1) / 2; + const size_t numScales = ((rows + mxBlock - 1) / mxBlock) * cols; + + std::vector dataBuffer(numPacked, 0); + std::vector scaleBuffer(numScales, 0); + + std::vector emptySwizzle; + std::vector emptyTile; + + generateMXInput((hipDataType)HIP_R_4F_E2M1, + dataBuffer.data(), + scaleBuffer.data(), + rows, + cols, + rows, // stride = rows (column-major) + isTranspose, + emptySwizzle, + emptyTile, + mxBlock, + 1, + true, + "Bounded", + -1.0f, + 1.0f); + + size_t zeros = countZerosFP4(dataBuffer.data(), numPacked); + double zeroPercent = 100.0 * static_cast(zeros) / static_cast(numElements); + + EXPECT_LT(zeroPercent, 13.0) + << "Zero frequency " << zeroPercent << "% exceeds 13% upper bound for " + << rows << "x" << cols << " FP4 matrix (transpose=" << isTranspose << ")"; + + // Ensure non-trivial data was actually generated (not all zeros) + EXPECT_GT(numElements - zeros, 0u) + << "All elements are zero for " << rows << "x" << cols << " FP4 matrix"; +} + +INSTANTIATE_TEST_SUITE_P( + FP4ZeroFrequency, + MXDataGenFP4Test, + ::testing::Values( + // rows, cols, mxBlock, isTranspose + std::make_tuple(128u, 128u, 32, true), + std::make_tuple(256u, 256u, 32, true), + std::make_tuple(2048u, 1026u, 32, true), + std::make_tuple(2048u, 514u, 32, false) + ) +); + +/** + * @brief Regression guard: generateMXInput must be deterministic (fixed seed). + * + * Any post-generation overwrite of the MXSA/MXSB buffers (e.g., the general + * tensor-init loop in initializeCPUInputs) desynchronises the CPU reference + * from GPU data, causing intermittent single-element validation failures. + * rows=K (must be mxBlock-aligned), cols=M/N (need not be). + */ +class MXGeneratorDeterminismTest + : public ::testing::TestWithParam> +{ +}; + +TEST_P(MXGeneratorDeterminismTest, GeneratorOutputIsDeterministic) +{ + auto [rows, cols, mxBlock, isTranspose, isMatrixA] = GetParam(); + + const size_t numPacked = (rows * cols + 1) / 2; + const size_t numScales = (rows / mxBlock) * cols; + + std::vector data1(numPacked); + std::vector data2(numPacked); + std::vector scale1(numScales, 0x00); + std::vector scale2(numScales, 0xFF); // sentinel: catches no-write if scale1==scale2 passes + + std::vector emptySwizzle, emptyTile; + + generateMXInput((hipDataType)HIP_R_4F_E2M1, + data1.data(), scale1.data(), + rows, cols, rows, isTranspose, + emptySwizzle, emptyTile, + mxBlock, 1, isMatrixA, "Bounded", -1.f, 1.f); + + generateMXInput((hipDataType)HIP_R_4F_E2M1, + data2.data(), scale2.data(), + rows, cols, rows, isTranspose, + emptySwizzle, emptyTile, + mxBlock, 1, isMatrixA, "Bounded", -1.f, 1.f); + + EXPECT_EQ(data1, data2) + << "FP4 data is non-deterministic"; + EXPECT_EQ(scale1, scale2) + << "Scale data is non-deterministic; any post-generation overwrite will corrupt validation"; + + bool allZero = std::all_of(scale1.begin(), scale1.end(), [](uint8_t b){ return b == 0; }); + bool allOnes = std::all_of(scale1.begin(), scale1.end(), [](uint8_t b){ return b == 0xFF; }); + EXPECT_FALSE(allZero) << "Scale buffer is all-zero — generator did not write"; + EXPECT_FALSE(allOnes) << "Scale buffer is all-0xFF (max UE8M0 value) — generator likely failed; bounded [-1,1] input should produce varied scales"; +} + +INSTANTIATE_TEST_SUITE_P( + GeneratorDeterminism, + MXGeneratorDeterminismTest, + ::testing::Values( + // rows=K, cols=M or N (tensorA.sizes()={K,M}, tensorB.sizes()={K,N}) + std::make_tuple(1024u, 128u, 32, true, true), // transposed A + std::make_tuple(1024u, 128u, 32, false, false), // non-transposed B + std::make_tuple(1024u, 204u, 32, true, true), // M=204, non-32-aligned (was failing) + std::make_tuple(1024u, 213u, 32, true, true) // M=213, non-32-aligned (was failing) + ) +); + +// ============================================================================ +// PreSwizzle scale tests +// +// Verify generateMXInput with preSwizzle produces scale data that is a +// permutation of the unswizzled layout. gfx950 FP4 MX kernels expect: +// preSwizzle = {swizzleTileMN=32, tileK=8, subTileK=MiK/mxBlock} +// preTile = {tileK=8, swizzleTileMN=32} +// swizzleTileMN=32 is fixed (2 SIMDs * 16 lanes); subTileK=4 for MiK=128, mxBlock=32. +// ============================================================================ + +// Params: {rows, cols, mxBlock, isTranspose, isMatrixA} +class MXPreSwizzleTest + : public ::testing::TestWithParam> +{ +}; + +/** @brief Verify preSwizzle produces a non-trivial permutation of scale data. */ +TEST_P(MXPreSwizzleTest, ScaleIsPermutationOfUnswizzled) +{ + auto [rows, cols, mxBlock, isTranspose, isMatrixA] = GetParam(); + + const std::vector preSwizzle = {32, 8, 4}; + const std::vector preTile = {8, 32}; + + const uint64_t numElements = rows * cols; + const uint64_t numPacked = (numElements + 1) / 2; + const size_t numScales = ((rows + mxBlock - 1) / mxBlock) * cols; + + std::vector dataNoShuf(numPacked, 0); + std::vector scaleNoShuf(numScales, 0); + std::vector dataShuf(numPacked, 0); + std::vector scaleShuf(numScales, 0); + + // Generate without preSwizzle + generateMXInput((hipDataType)HIP_R_4F_E2M1, + dataNoShuf.data(), + scaleNoShuf.data(), + rows, cols, rows, + isTranspose, + {}, {}, + mxBlock, 1, isMatrixA, + "Bounded", -1.0f, 1.0f); + + // Generate with preSwizzle + generateMXInput((hipDataType)HIP_R_4F_E2M1, + dataShuf.data(), + scaleShuf.data(), + rows, cols, rows, + isTranspose, + preSwizzle, preTile, + mxBlock, 1, isMatrixA, + "Bounded", -1.0f, 1.0f); + + // The scale buffers must be different + EXPECT_NE(scaleNoShuf, scaleShuf) + << "Scale data was not shuffled for " << rows << "x" << cols + << " (transpose=" << isTranspose << ", isMatrixA=" << isMatrixA << ")"; + + // The shuffled scale must be a permutation: same multiset of bytes + std::vector sortedNoShuf = scaleNoShuf; + std::vector sortedShuf = scaleShuf; + std::sort(sortedNoShuf.begin(), sortedNoShuf.end()); + std::sort(sortedShuf.begin(), sortedShuf.end()); + EXPECT_EQ(sortedNoShuf, sortedShuf) + << "Pre-shuffled scale is not a permutation of the unshuffled scale for " + << rows << "x" << cols; + + // Data buffer must be identical (preSwizzle only affects scale, not data) + EXPECT_EQ(dataNoShuf, dataShuf) + << "Data buffer changed unexpectedly with preSwizzle for " + << rows << "x" << cols; +} + +INSTANTIATE_TEST_SUITE_P( + FP4PreSwizzle, + MXPreSwizzleTest, + ::testing::Values( + // rows, cols, mxBlock, isTranspose, isMatrixA + // Test size constraints for preSwizzle {32,8,4} + preTile {8,32}: + // rows % 256 == 0 (scaleRows = rows/mxBlock must be divisible by tileK=8) + // cols % 32 == 0 (scaleCols must be divisible by swizzleTileMN=32) std::make_tuple(256u, 256u, 32, true, true), // scale A transposed + std::make_tuple(256u, 256u, 32, false, false), // scale B non-transposed + std::make_tuple(512u, 256u, 32, true, true), // larger scale A + std::make_tuple(256u, 512u, 32, false, false), // larger scale B + std::make_tuple(4096u, 16384u, 32, true, true) // benchmark-scale problem + ) +); diff --git a/projects/hipblaslt/tensilelite/tox.ini b/projects/hipblaslt/tensilelite/tox.ini index 8a23b401b0a..c309c9b3353 100644 --- a/projects/hipblaslt/tensilelite/tox.ini +++ b/projects/hipblaslt/tensilelite/tox.ini @@ -63,6 +63,14 @@ commands = pytest -v --basetemp={envtmpdir} {posargs} +[testenv:rocisa] +description = "Runs only rocisa tests" +basepython = python3 +commands = + pip install --upgrade pip + pip install --no-build-isolation {toxinidir}/rocisa/ + pytest -v --basetemp={envtmpdir} --junit-xml={toxinidir}/rocisa_tests.xml --junit-prefix=rocisa --color=yes -n 4 {toxinidir}/rocisa/test {posargs} + [testenv:lint] basepython = python3 deps =