Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,10 @@ prepare_environment_for_run(parser_data_t& _data)
rocprofsys::argparse::add_ld_preload(_data);
rocprofsys::argparse::add_ld_library_path(_data);
}

rocprofsys::argparse::add_torch_library_path(_data, _data.verbose > 0);

rocprofsys::common::consolidate_env_entries(_data.current);
}

void
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -933,3 +933,9 @@ parse_args(int argc, char** argv, std::vector<char*>& _env)

return _outv;
}

void
add_torch_library_path(std::vector<char*>& envp, const std::vector<char*>& argv)
{
rocprofsys::common::add_torch_library_path(envp, argv, verbose > 0, updated_envs);
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ main(int argc, char** argv)
_argv.emplace_back(argv[i]);
}

add_torch_library_path(_env, _argv);

print_updated_environment(_env);

if(!_argv.empty())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,6 @@ get_initial_environment();

std::vector<char*>
parse_args(int argc, char** argv, std::vector<char*>& envp);

void
add_torch_library_path(std::vector<char*>& envp, const std::vector<char*>& argv);
254 changes: 252 additions & 2 deletions projects/rocprofiler-systems/source/lib/common/environment.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@

#include "common/join.hpp"
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
Expand Down Expand Up @@ -197,7 +200,7 @@ remove_env(std::vector<char*>& _environ, std::string_view _env_var,
{
if(match(itr))
{
free(itr);
std::free(itr);
itr = nullptr;
}
}
Expand Down Expand Up @@ -266,6 +269,113 @@ discover_llvm_libdir_for_ompt(bool verbose = false)
return {};
}

inline bool
is_python_interpreter(std::string_view executable)
{
if(executable.empty()) return false;

const auto slash_pos = executable.rfind('/');
const auto basename = (slash_pos != std::string_view::npos)
? executable.substr(slash_pos + 1)
: executable;

if(basename == "python" || basename == "python3") return true;

constexpr std::string_view python3_prefix = "python3.";

const bool has_valid_prefix =
basename.size() > python3_prefix.size() &&
basename.substr(0, python3_prefix.size()) == python3_prefix;
if(!has_valid_prefix) return false;

const auto version_digits = basename.substr(python3_prefix.size());

return std::all_of(version_digits.begin(), version_digits.end(),
[](unsigned char c) { return std::isdigit(c); });
}

inline std::string
discover_torch_libpath(const std::string& python_binary, bool verbose = false)
{
if(python_binary.empty()) return {};

Comment thread
mradosav-amd marked this conversation as resolved.
const auto is_safe_executable_path = [](const std::string& path) {
// Allow only a conservative set of characters in the executable path to
// avoid injection when used in a shell command.
for(unsigned char c : path)
{
if(std::isalnum(c) != 0) continue;
switch(c)
{
case '/':
case '.':
case '_':
case '-':
case '+': break;
default: return false;
}
}
return true;
};

if(!is_safe_executable_path(python_binary))
{
ROCPROFSYS_ENVIRON_LOG(
verbose, "Unsafe characters detected in Python interpreter path: %s\n",
python_binary.c_str());
return {};
}

const auto cmd = "\"" + python_binary +
"\" -c \"import torch; print(torch.__path__[0])\" 2>/dev/null";

FILE* pipe = popen(cmd.c_str(), "r");
if(!pipe)
{
ROCPROFSYS_ENVIRON_LOG(verbose, "Failed to execute command: %s\n", cmd.c_str());
return {};
}

char buffer[1024];
std::string result;
while(fgets(buffer, sizeof(buffer), pipe))
{
result.append(buffer);
// stop if we've read the full line (torch path is printed on a single line)
if(!result.empty() && result.back() == '\n') break;
}

int status = pclose(pipe);

if(status != 0 || result.empty())
{
ROCPROFSYS_ENVIRON_LOG(verbose, "torch not found for Python interpreter: %s\n",
python_binary.c_str());
return {};
}

while(!result.empty() &&
(result.back() == '\n' || result.back() == '\r' || result.back() == ' '))
{
result.pop_back();
}

if(result.empty()) return {};

std::string torch_libdir = result + "/lib";

if(!::tim::filepath::direxists(torch_libdir))
{
ROCPROFSYS_ENVIRON_LOG(verbose, "torch lib directory does not exist: %s\n",
torch_libdir.c_str());
return {};
}

ROCPROFSYS_ENVIRON_LOG(verbose, "Discovered torch library path: %s\n",
torch_libdir.c_str());
return torch_libdir;
}

enum class update_mode : uint8_t
{
REPLACE = 0,
Expand Down Expand Up @@ -335,13 +445,153 @@ update_env(std::vector<char*>& _environ, std::string_view _env_var, Tp&& _env_va
}
else
{
free(itr);
std::free(itr);
itr = strdup(join('=', _env_var, _env_val_str).c_str());
}
return;
}
_environ.emplace_back(strdup(join('=', _env_var, _env_val_str).c_str()));
}

template <typename UpdatedEnvsT>
inline void
add_torch_library_path(std::vector<char*>& envp, const std::vector<char*>& argv,
bool verbose, UpdatedEnvsT& updated_envs)
{
if(argv.empty() || argv.front() == nullptr) return;
if(!is_python_interpreter(argv.front())) return;

auto torch_libpath = discover_torch_libpath(argv.front(), verbose);
if(torch_libpath.empty()) return;

std::unordered_set<std::string> seen{ torch_libpath };
std::string result = torch_libpath;

constexpr std::string_view ld_prefix = "LD_LIBRARY_PATH=";

auto is_ld_path = [&](char* entry) {
return entry &&
std::string_view{ entry }.substr(0, ld_prefix.length()) == ld_prefix;
};

for(auto& entry : envp)
{
if(!is_ld_path(entry)) continue;

std::istringstream stream{ std::string{ entry + ld_prefix.length() } };
for(std::string path; std::getline(stream, path, ':');)
{
if(!path.empty() && seen.insert(path).second) result += ":" + path;
}

std::free(entry);
entry = nullptr;
}

envp.erase(std::remove(envp.begin(), envp.end(), nullptr), envp.end());
envp.emplace_back(strdup(join("", ld_prefix, result).c_str()));

updated_envs.emplace(ld_prefix.substr(0, ld_prefix.length() - 1));
}

inline void
consolidate_env_entries(std::vector<char*>& envp)
{
constexpr char delim = ':';

struct key_data
{
std::vector<std::string> parts;
std::unordered_set<std::string> seen;

void add_unique(std::string part)
{
if(!part.empty() && seen.insert(part).second)
parts.emplace_back(std::move(part));
}
};

auto parse_entry = [](std::string_view entry)
-> std::optional<std::pair<std::string_view, std::string_view>> {
auto eq_pos = entry.find('=');
if(eq_pos == std::string_view::npos) return std::nullopt;
return std::make_pair(entry.substr(0, eq_pos), entry.substr(eq_pos + 1));
};

auto join_parts = [delim](std::string_view key,
const std::vector<std::string>& parts) {
std::string result;

const auto total_parts_length = std::accumulate(
Comment thread
mradosav-amd marked this conversation as resolved.
parts.begin(), parts.end(), std::size_t{ 0 },
[](std::size_t acc, const std::string& part) { return acc + part.size(); });

const auto delim_count = parts.size() - 1;
const auto equal_sign_length = 1;

result.reserve(key.size() + equal_sign_length + total_parts_length + delim_count);
result.append(key);
result += '=';

result =
std::accumulate(parts.begin(), parts.end(), std::move(result),
[delim, &parts](std::string acc, const std::string& part) {
if(part != parts.front()) acc += delim;
acc.append(part);
return acc;
});

return result;
};

std::unordered_map<std::string_view, key_data> key_map;
std::vector<std::string_view> key_order;

for(auto* entry : envp)
{
if(!entry)
{
continue;
}

auto parsed = parse_entry(entry);
if(!parsed)
{
continue;
}

auto [key, value] = *parsed;

auto [it, inserted] = key_map.try_emplace(key);
if(inserted)
{
key_order.emplace_back(key);
}

auto& data = it->second;
std::istringstream stream{ std::string{ value } };
for(std::string part; std::getline(stream, part, delim);)
{
data.add_unique(part);
}
}

std::vector<char*> result;
result.reserve(key_order.size());

for(auto key : key_order)
{
result.emplace_back(strdup(join_parts(key, key_map[key].parts).c_str()));
}

for(auto* entry : envp)
{
std::free(entry);
entry = nullptr;
}

envp = std::move(result);
}

} // namespace common
} // namespace rocprofsys
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_library(
lib-common-tests
OBJECT
test_discover_llvm_libdir.cpp
test_environment.cpp
test_path.cpp
test_remove_env.cpp
test_update_env.cpp
Expand Down
Loading