Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/repos-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@
"url": "ROCm/rocPRIM",
"branch": "develop",
"category": "projects",
"auto_subtree_pull": true,
"auto_subtree_push": false,
"auto_subtree_pull": false,
"auto_subtree_push": true,
"enable_pr_fanout": false
},
{
Expand Down
13 changes: 9 additions & 4 deletions .github/workflows/pr-import.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ jobs:
run: |
PR_JSON=$(gh pr view ${{ github.event.inputs.subrepo-pr-number }} \
--repo ${{ github.event.inputs.subrepo-upstream }} \
--json title,body,headRefName,headRepository \
--jq '{title: .title, body: .body, head_ref: .headRefName, head_repo: .headRepository.cloneUrl}')
--json title,body,headRefName,headRepository,isDraft \
--jq '{title: .title, body: .body, head_ref: .headRefName, head_repo: .headRepository.cloneUrl, is_draft: .isDraft}')

echo "$PR_JSON" > pr.json

Expand All @@ -102,6 +102,7 @@ jobs:

echo "head_ref=$(jq -r .head_ref pr.json)" >> $GITHUB_OUTPUT
echo "head_repo=$(jq -r .head_repo pr.json)" >> $GITHUB_OUTPUT
echo "is_draft=$(jq -r .is_draft pr.json)" >> $GITHUB_OUTPUT

- name: Create new branch for import
id: import-branch
Expand All @@ -126,21 +127,25 @@ jobs:
SUBREPO_REPO="${{ github.event.inputs.subrepo-repo }}"
SUBREPO_PR_NUMBER="${{ github.event.inputs.subrepo-pr-number }}"
SUBREPO_URL="https://github.com/$SUBREPO_REPO/pull/$SUBREPO_PR_NUMBER"

AUTHOR=$(gh pr view "$SUBREPO_PR_NUMBER" --repo "$UPSTREAM_REPO" --json author --jq .author.login)

echo "${{ steps.prdata.outputs.body }}" > pr_body.txt

{
echo ""
echo "---"
echo "🔁 Imported from [$SUBREPO_REPO#$SUBREPO_PR_NUMBER]($SUBREPO_URL)"
echo "🧑‍💻 Originally authored by @$AUTHOR"
} >> pr_body.txt

DRAFT_FLAG=""
if [[ "${{ steps.prdata.outputs.is_draft }}" == "true" ]]; then
DRAFT_FLAG="--draft"
fi

gh pr create \
--base develop \
--head "${{ steps.import-branch.outputs.import_branch }}" \
--title "$PR_TITLE" \
--label "imported pr" \
$DRAFT_FLAG \
--body-file pr_body.txt
14 changes: 13 additions & 1 deletion .github/workflows/update-subtrees.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ env:
MONOREPO_URL: github.com/ROCm/rocm-libraries.git
MONOREPO_BRANCH: develop

concurrency:
group: pr-update-subtrees-develop
cancel-in-progress: false

jobs:
synchronize-subtrees:
runs-on: ubuntu-24.04
Expand All @@ -32,6 +36,10 @@ jobs:
git config user.name "assistant-librarian[bot]"
git config user.email "assistant-librarian[bot]@users.noreply.github.com"

- name: Switch to the Monorepo branch
run: |
git checkout -B "${{ env.MONOREPO_BRANCH }}" "origin/${{ env.MONOREPO_BRANCH }}"

- name: Update Repositories in the Monorepo
run: |
has_errors=false
Expand All @@ -47,9 +55,13 @@ jobs:
}
fi
if [ "$enable_push" = true ]; then
git subtree push --prefix "${category}/${repo}" https://github.com/${url}.git $branch --quiet || {
git fetch origin subtrees/${repo}/${branch}
git branch -f subtrees/${repo}/${branch} origin/subtrees/${repo}/${branch}
git subtree split --prefix="${category}/${repo}" -b subtrees/${repo}/${branch} --quiet --rejoin || {
has_errors=true
}
git push origin subtrees/${repo}/${branch}
git push https://github.com/${url}.git subtrees/${repo}/${branch}:${branch}
fi
done

Expand Down
11 changes: 6 additions & 5 deletions projects/hipblaslt/clients/benchmarks/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,8 @@ try
std::vector<int64_t> stride_a, stride_b, stride_c, stride_d, stride_e;
std::vector<uint32_t> gsu_vector, wgm_vector;
arg.init(); // set all defaults
const char* tuningEnv = getenv("HIPBLASLT_TUNING_FILE");
const char* tuningEnv = getenv("HIPBLASLT_TUNING_FILE");
const char* tuningMaxWorkSpace = getenv("HIPBLASLT_TUNING_USER_MAX_WORKSPACE");
if(tuningEnv)
{
bool tuning_success = tuning_path_compare_git_version(tuningEnv);
Expand Down Expand Up @@ -466,7 +467,7 @@ try
"Cold Iterations to run before entering the timing loop")

("algo_method",
value<std::string>(&algo_method_str)->default_value(tuningEnv? "all" : "heuristic"),
value<std::string>(&algo_method_str)->default_value("heuristic"),
"Use different algorithm search API. Options: heuristic, all, index.")

("solution_index",
Expand Down Expand Up @@ -578,7 +579,7 @@ try
"C and D are stored in same memory")

("workspace",
value<size_t>(&arg.user_allocated_workspace)->default_value(128 * 1024 * 1024),
value<size_t>(&arg.user_allocated_workspace)->default_value(tuningEnv && tuningMaxWorkSpace ? atoi(tuningMaxWorkSpace) : 128 * 1024 * 1024),
"Set fixed workspace memory size (bytes) instead of using hipblaslt managed memory")

("log_function_name",
Expand Down Expand Up @@ -671,11 +672,11 @@ try
}
else if(algo_method_str.compare("all") == 0)
{
arg.algo_method = 1;
arg.algo_method = tuningEnv ? 0 : 1;
}
else if(algo_method_str.compare("index") == 0)
{
arg.algo_method = tuningEnv ? 1 : 2;
arg.algo_method = tuningEnv ? 0 : 2;
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,18 @@ To find and use the best GEMM kernel for a problem, follow these steps:

export HIPBLASLT_TUNING_FILE=tuning.txt

Additionally, you can set the following environment variable to restrict the solutions found in the tuning stage to those that fit within a maximum workspace size (in bytes):

.. code-block:: bash

export HIPBLASLT_TUNING_USER_MAX_WORKSPACE=<value> (Default value is: 128 * 1024 * 1024)

The default settings for the following parameters in ``hipblaslt-bench`` are changed in the tuning environment.

.. code-block:: bash

--iters |-i <value> (Default value is: 1000)
--cold_iters |-j <value> (Default value is: 1000)
--algo_method <value> (Default value is: all)
--requested_solution <value> (Default value is: -1)
--rotating <value> (Default value is: 512)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
*
* MIT License
*
* Copyright (C) 2022-2024 Advanced Micro Devices, Inc.
* Copyright (C) 2022-2025 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -387,13 +387,13 @@ rocblaslt_status rocblaslt_is_algo_supported_cpp(rocblaslt_handle h
std::shared_ptr<void> gemmData,
rocblaslt_matmul_algo& algo,
const rocblaslt::RocTuningV2* tuning,
size_t& workspaceSizeInBytes);
size_t& workspaceSizeInBytes);

rocblaslt_status
rocblaslt_algo_get_heuristic_cpp(rocblaslt_handle handle,
rocblaslt::RocGemmType gemmType,
std::shared_ptr<void> gemmData,
const int workspaceBytes,
const size_t maxWorkspaceBytes,
const int requestedAlgoCount,
std::vector<rocblaslt_matmul_heuristic_result>& results);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,11 @@ inline bool

// Preload problem/solution mappings
bool problem_override_from_file(rocblaslt_handle& handle,
rocblaslt_matmul_preference& pref,
RocblasltContractionProblem& problem,
rocblaslt_matmul_desc& matmul_desc,
rocblaslt_matmul_heuristic_result heuristicResultsArray[],
const std::string& file_path)
const std::string& file_path,
size_t max_workspace_bytes)
{

bool success = false;
Expand All @@ -170,15 +170,13 @@ bool problem_override_from_file(rocblaslt_handle& handle,
TensileLite::ProblemOverride prob_key(RocblasltContractionProblem2ProblemOverride(problem));
auto sol_iter = m_override.find(prob_key);

for(auto sol_idx = std::make_reverse_iterator(sol_iter.second);
!success && sol_idx != std::make_reverse_iterator(sol_iter.first);
sol_idx++)
for(auto sol_idx = sol_iter.first; !success && sol_idx != sol_iter.second; sol_idx++)
{
solutionIndex[0] = sol_idx->second;

if(rocblaslt_status_success
== getSolutionsFromIndex(
handle, solutionIndex, overrideResults, pref->max_workspace_bytes))
handle, solutionIndex, overrideResults, max_workspace_bytes))
{

size_t required_workspace_size = 0;
Expand Down Expand Up @@ -218,7 +216,7 @@ bool problem_override_from_file(rocblaslt_handle& handle,

heuristicResult_copy(&heuristicResultsArray[0],
&overrideResults[0],
pref->max_workspace_bytes,
max_workspace_bytes,
required_workspace_size);
}
}
Expand All @@ -243,9 +241,9 @@ bool problem_override_from_file_cpp(
rocblaslt_handle& handle,
rocblaslt::RocGemmType& gemmType,
std::shared_ptr<void> gemmData,
size_t workspaceSizeInBytes,
std::vector<rocblaslt_matmul_heuristic_result>& heuristicResultsArray,
const std::string& file_path)
const std::string& file_path,
size_t max_workspace_bytes)
{

bool success = false;
Expand All @@ -263,23 +261,24 @@ bool problem_override_from_file_cpp(
TensileLite::ProblemOverride prob_key(TensileDataGemm2ProblemOverride(gemmData));
auto sol_iter = m_override.find(prob_key);

for(auto sol_idx = std::make_reverse_iterator(sol_iter.second);
!success && sol_idx != std::make_reverse_iterator(sol_iter.first);
sol_idx++)
for(auto sol_idx = sol_iter.first; !success && sol_idx != sol_iter.second; sol_idx++)
{
solutionIndex[0] = sol_idx->second;
size_t maxWorkspaceSize = std::numeric_limits<size_t>::max();
solutionIndex[0] = sol_idx->second;
if(rocblaslt_status_success
== getSolutionsFromIndex(handle, solutionIndex, overrideResults, maxWorkspaceSize))
== getSolutionsFromIndex(
handle, solutionIndex, overrideResults, max_workspace_bytes))
{
rocblaslt::RocTuningV2* tuning = nullptr;

size_t required_workspace_size = 0;
rocblaslt::RocTuningV2* tuning = nullptr;

if(rocblaslt_status_success
== isSolutionSupported(handle,
static_cast<const rocblaslt::RocGemmType>(gemmType),
gemmData,
overrideResults[0].algo,
tuning,
workspaceSizeInBytes))
required_workspace_size))
{
success = true;
}
Expand All @@ -296,7 +295,7 @@ bool problem_override_from_file_cpp(
gemmData,
overrideResults[0].algo,
tuning,
workspaceSizeInBytes))
required_workspace_size))
{
success = true;
log_info(__func__, "Use the fallback fp32 solution");
Expand All @@ -306,7 +305,7 @@ bool problem_override_from_file_cpp(

if(success)
{
overrideResults[0].workspaceSize = workspaceSizeInBytes;
overrideResults[0].workspaceSize = required_workspace_size;
heuristicResultsArray.push_back(overrideResults[0]);
}
}
Expand Down Expand Up @@ -1802,8 +1801,12 @@ rocblaslt_status
bool override_success = false;
if(override.env_mode)
{
override_success = problem_override_from_file(
handle, pref, prob, matmul_desc, heuristicResultsArray, override.file_path);
override_success = problem_override_from_file(handle,
prob,
matmul_desc,
heuristicResultsArray,
override.file_path,
pref->max_workspace_bytes);
if(override_success)
requestedAlgoCount--;

Expand Down Expand Up @@ -2037,7 +2040,7 @@ rocblaslt_status
rocblaslt_algo_get_heuristic_cpp(rocblaslt_handle handle,
rocblaslt::RocGemmType gemmType,
std::shared_ptr<void> gemmData,
const int workspaceBytes,
const size_t maxWorkspaceBytes,
const int requestedAlgoCount,
std::vector<rocblaslt_matmul_heuristic_result>& results)
{
Expand All @@ -2062,7 +2065,7 @@ rocblaslt_status
if(override.env_mode)
{
override_success = problem_override_from_file_cpp(
handle, gemmType, gemmData, workspaceBytes, override_result, override.file_path);
handle, gemmType, gemmData, override_result, override.file_path, maxWorkspaceBytes);

log_api(__func__, "OverrideAlgoCount", override_success ? 1 : 0);
}
Expand All @@ -2072,7 +2075,7 @@ rocblaslt_status
= getBestSolutions(handle,
gemmType,
gemmData,
workspaceBytes,
maxWorkspaceBytes,
override_success ? requestedAlgoCount - 1 : requestedAlgoCount,
results);

Expand All @@ -2091,10 +2094,10 @@ rocblaslt_status
if(requestedAlgoCount > results.size())
{
std::vector<rocblaslt_matmul_heuristic_result> allSolutionsResults;
size_t workspaceSizeInBytes = workspaceBytes;
size_t workspaceSizeInBytes = 0;
if(rocblaslt_status_success
== getAllSolutions(
gemmData, handle, gemmType, allSolutionsResults, workspaceSizeInBytes))
gemmData, handle, gemmType, allSolutionsResults, maxWorkspaceBytes))
{
int oriReturnAlgoCount = results.size();
for(int i = 0;
Expand Down
2 changes: 1 addition & 1 deletion projects/rocprim/rmake.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/python3
""" Copyright (c) 2021-2025 Advanced Micro Devices, Inc. All rights reserved.
""" Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
Manage build and installation"""

import re
Expand Down
Loading