diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index 6204a037b1..091d55bcd1 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -126,8 +126,8 @@ class S3Endpoint : public RemoteEndpoint { * `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular AWS * url scheme is used: "https://.s3..amazonaws.com/". */ - static std::string url_from_bucket_and_object(std::string const& bucket_name, - std::string const& object_name, + static std::string url_from_bucket_and_object(std::string bucket_name, + std::string object_name, std::optional aws_region, std::optional aws_endpoint_url); @@ -162,8 +162,7 @@ class S3Endpoint : public RemoteEndpoint { /** * @brief Create a S3 endpoint from a bucket and object name. * - * @param bucket_name The name of the S3 bucket. - * @param object_name The name of the S3 object. + * @param bucket_and_object_names A pair of the S3 bucket name (first) and object name (second). * @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the * `AWS_DEFAULT_REGION` environment variable is used. * @param aws_access_key The AWS access key to use. If nullopt, the value of the @@ -175,8 +174,7 @@ class S3Endpoint : public RemoteEndpoint { * `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular AWS * url scheme is used: "https://.s3..amazonaws.com/". 
*/ - S3Endpoint(std::string const& bucket_name, - std::string const& object_name, + S3Endpoint(std::pair bucket_and_object_names, std::optional aws_region = std::nullopt, std::optional aws_access_key = std::nullopt, std::optional aws_secret_access_key = std::nullopt, diff --git a/cpp/src/remote_handle.cpp b/cpp/src/remote_handle.cpp index c25d918853..4df2f48898 100644 --- a/cpp/src/remote_handle.cpp +++ b/cpp/src/remote_handle.cpp @@ -168,8 +168,8 @@ std::string S3Endpoint::unwrap_or_default(std::optional aws_arg, return std::string(env); } -std::string S3Endpoint::url_from_bucket_and_object(std::string const& bucket_name, - std::string const& object_name, +std::string S3Endpoint::url_from_bucket_and_object(std::string bucket_name, + std::string object_name, std::optional aws_region, std::optional aws_endpoint_url) { @@ -245,17 +245,18 @@ S3Endpoint::S3Endpoint(std::string url, } } -S3Endpoint::S3Endpoint(std::string const& bucket_name, - std::string const& object_name, +S3Endpoint::S3Endpoint(std::pair bucket_and_object_names, std::optional aws_region, std::optional aws_access_key, std::optional aws_secret_access_key, std::optional aws_endpoint_url) - : S3Endpoint( - url_from_bucket_and_object(bucket_name, object_name, aws_region, std::move(aws_endpoint_url)), - aws_region, - std::move(aws_access_key), - std::move(aws_secret_access_key)) + : S3Endpoint(url_from_bucket_and_object(std::move(bucket_and_object_names.first), + std::move(bucket_and_object_names.second), + aws_region, + std::move(aws_endpoint_url)), + aws_region, + std::move(aws_access_key), + std::move(aws_secret_access_key)) { KVIKIO_NVTX_FUNC_RANGE(); } diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index ae1febebb9..4b4ad1049a 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -76,4 +76,6 @@ kvikio_add_test(NAME DEFAULTS_TEST SOURCES test_defaults.cpp) kvikio_add_test(NAME ERROR_TEST SOURCES test_error.cpp) +kvikio_add_test(NAME REMOTE_HANDLE_TEST SOURCES 
test_remote_handle.cpp utils/env.cpp) + rapids_test_install_relocatable(INSTALL_COMPONENT_SET testing DESTINATION bin/tests/libkvikio) diff --git a/cpp/tests/test_remote_handle.cpp b/cpp/tests/test_remote_handle.cpp new file mode 100644 index 0000000000..650f1500f3 --- /dev/null +++ b/cpp/tests/test_remote_handle.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include "utils/env.hpp" + +TEST(RemoteHandleTest, s3_endpoint_constructor) +{ + kvikio::test::EnvVarContext env_var_ctx{{{"AWS_DEFAULT_REGION", "my_aws_default_region"}, + {"AWS_ACCESS_KEY_ID", "my_aws_access_key_id"}, + {"AWS_SECRET_ACCESS_KEY", "my_aws_secrete_access_key"}, + {"AWS_ENDPOINT_URL", "https://my_aws_endpoint_url"}}}; + std::string url = "https://my_aws_endpoint_url/bucket_name/object_name"; + std::string aws_region = "my_aws_region"; + // Use the overload that takes the full URL together with an explicit aws_region. + kvikio::S3Endpoint s1(url, aws_region); + + std::string bucket_name = "bucket_name"; + std::string object_name = "object_name"; + // Use the other overload, which takes the bucket and object names as a single pair. 
+ kvikio::S3Endpoint s2(std::make_pair(bucket_name, object_name)); + + EXPECT_EQ(s1.str(), s2.str()); +} diff --git a/cpp/tests/utils/env.cpp b/cpp/tests/utils/env.cpp new file mode 100644 index 0000000000..5e713dca04 --- /dev/null +++ b/cpp/tests/utils/env.cpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "env.hpp" +#include +#include +#include + +#include + +namespace kvikio::test { +EnvVarContext::EnvVarContext( + std::initializer_list> env_var_entries) +{ + for (auto const& [key, current_value] : env_var_entries) { + // Check for a duplicate key first, so that the failure is reported before setenv() runs. + if (_env_var_map.find(key) != _env_var_map.end()) { + std::stringstream ss; + ss << "Environment variable " << key << " has already been set in this context."; + KVIKIO_FAIL(ss.str()); + } + EnvVarState env_var_state; + if (auto const res = std::getenv(key.c_str()); res != nullptr) { + env_var_state.existed_before = true; + env_var_state.previous_value = res; + } + SYSCALL_CHECK(setenv(key.c_str(), current_value.c_str(), 1 /* allow overwrite */)); + _env_var_map.insert({key, std::move(env_var_state)}); + } +} + +EnvVarContext::~EnvVarContext() +{ + for (auto const& [key, state] : _env_var_map) { + if (state.existed_before) { + SYSCALL_CHECK(setenv(key.c_str(), state.previous_value.c_str(), 1 /* allow overwrite */)); + } else { + SYSCALL_CHECK(unsetenv(key.c_str())); + } + } +} +} // namespace kvikio::test diff --git a/cpp/tests/utils/env.hpp 
b/cpp/tests/utils/env.hpp new file mode 100644 index 0000000000..8ec51ad27b --- /dev/null +++ b/cpp/tests/utils/env.hpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace kvikio::test { +/** + * @brief RAII class used to temporarily set environment variables to new values upon construction, + * and restore them to previous values upon destruction. + */ +class EnvVarContext { + private: + /** + * @brief The state of an environment variable + */ + struct EnvVarState { + // Whether the environment variable existed before entering the context + bool existed_before{}; + // The previous value of the environment variable, if it existed before + std::string previous_value{}; + }; + + public: + /** + * @brief Set the environment variables to new values + * + * @param env_var_entries User-specified environment variables. Each entry includes the variable + * name and value. + */ + EnvVarContext(std::initializer_list> env_var_entries); + + /** + * @brief Restore the environment variables to previous values + * + * Reset the environment variables to their previous states: + * - If one existed before, restore to its previous value. + * - Otherwise, remove it from the environment. 
+ */ + ~EnvVarContext(); + + EnvVarContext(EnvVarContext const&) = delete; + EnvVarContext(EnvVarContext&&) = delete; + EnvVarContext& operator=(EnvVarContext const&) = delete; + EnvVarContext& operator=(EnvVarContext&&) = delete; + + private: + std::unordered_map _env_var_map; +}; +} // namespace kvikio::test diff --git a/python/kvikio/kvikio/_lib/remote_handle.pyx b/python/kvikio/kvikio/_lib/remote_handle.pyx index 1e0b14acb9..5a7ba2c846 100644 --- a/python/kvikio/kvikio/_lib/remote_handle.pyx +++ b/python/kvikio/kvikio/_lib/remote_handle.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved. # See file LICENSE for terms. # distutils: language = c++ @@ -9,6 +9,7 @@ from typing import Optional from cython.operator cimport dereference as deref from libc.stdint cimport uintptr_t from libcpp.memory cimport make_unique, unique_ptr +from libcpp.pair cimport pair from libcpp.string cimport string from libcpp.utility cimport move, pair @@ -25,7 +26,7 @@ cdef extern from "" nogil: cdef cppclass cpp_S3Endpoint "kvikio::S3Endpoint"(cpp_RemoteEndpoint): cpp_S3Endpoint(string url) except + - cpp_S3Endpoint(string bucket_name, string object_name) except + + cpp_S3Endpoint(pair[string, string] bucket_and_object_names) except + pair[string, string] cpp_parse_s3_url \ "kvikio::S3Endpoint::parse_s3_url"(string url) except + @@ -56,6 +57,10 @@ cdef string _to_string(str s): else: return string() +cdef pair[string, string] _to_string_pair(str s1, str s2): + """Convert two Python strings via _to_string and wrap them in a C++ pair[string, string]""" + return pair[string, string](_to_string(s1), _to_string(s2)) + # Helper function to cast an endpoint to its base class `RemoteEndpoint` cdef extern from *: """ @@ -105,7 +110,7 @@ cdef class RemoteFile: return RemoteFile._from_endpoint( cast_to_remote_endpoint( make_unique[cpp_S3Endpoint]( - _to_string(bucket_name), _to_string(object_name) + _to_string_pair(bucket_name, object_name) 
) ), nbytes @@ -131,9 +136,7 @@ cdef class RemoteFile: cdef pair[string, string] bucket_and_object = cpp_parse_s3_url(_to_string(url)) return RemoteFile._from_endpoint( cast_to_remote_endpoint( - make_unique[cpp_S3Endpoint]( - bucket_and_object.first, bucket_and_object.second - ) + make_unique[cpp_S3Endpoint](bucket_and_object) ), nbytes )