Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions cpp/include/kvikio/remote_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ class S3Endpoint : public RemoteEndpoint {
* `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular AWS
* url scheme is used: "https://<bucket_name>.s3.<region>.amazonaws.com/<object_name>".
*/
static std::string url_from_bucket_and_object(std::string const& bucket_name,
std::string const& object_name,
static std::string url_from_bucket_and_object(std::string bucket_name,
std::string object_name,
std::optional<std::string> aws_region,
std::optional<std::string> aws_endpoint_url);

Expand Down Expand Up @@ -162,8 +162,7 @@ class S3Endpoint : public RemoteEndpoint {
/**
* @brief Create a S3 endpoint from a bucket and object name.
*
* @param bucket_name The name of the S3 bucket.
* @param object_name The name of the S3 object.
* @param bucket_and_object_names A pair holding the name of the S3 bucket (first) and the name
* of the S3 object (second).
* @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the
* `AWS_DEFAULT_REGION` environment variable is used.
* @param aws_access_key The AWS access key to use. If nullopt, the value of the
Expand All @@ -175,8 +174,7 @@ class S3Endpoint : public RemoteEndpoint {
* `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular AWS
* url scheme is used: "https://<bucket_name>.s3.<region>.amazonaws.com/<object_name>".
*/
S3Endpoint(std::string const& bucket_name,
std::string const& object_name,
S3Endpoint(std::pair<std::string, std::string> bucket_and_object_names,
std::optional<std::string> aws_region = std::nullopt,
std::optional<std::string> aws_access_key = std::nullopt,
std::optional<std::string> aws_secret_access_key = std::nullopt,
Expand Down
19 changes: 10 additions & 9 deletions cpp/src/remote_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ std::string S3Endpoint::unwrap_or_default(std::optional<std::string> aws_arg,
return std::string(env);
}

std::string S3Endpoint::url_from_bucket_and_object(std::string const& bucket_name,
std::string const& object_name,
std::string S3Endpoint::url_from_bucket_and_object(std::string bucket_name,
std::string object_name,
std::optional<std::string> aws_region,
std::optional<std::string> aws_endpoint_url)
{
Expand Down Expand Up @@ -245,17 +245,18 @@ S3Endpoint::S3Endpoint(std::string url,
}
}

S3Endpoint::S3Endpoint(std::string const& bucket_name,
std::string const& object_name,
S3Endpoint::S3Endpoint(std::pair<std::string, std::string> bucket_and_object_names,
std::optional<std::string> aws_region,
std::optional<std::string> aws_access_key,
std::optional<std::string> aws_secret_access_key,
std::optional<std::string> aws_endpoint_url)
: S3Endpoint(
url_from_bucket_and_object(bucket_name, object_name, aws_region, std::move(aws_endpoint_url)),
aws_region,
std::move(aws_access_key),
std::move(aws_secret_access_key))
: S3Endpoint(url_from_bucket_and_object(std::move(bucket_and_object_names.first),
std::move(bucket_and_object_names.second),
aws_region,
std::move(aws_endpoint_url)),
aws_region,
std::move(aws_access_key),
std::move(aws_secret_access_key))
{
KVIKIO_NVTX_FUNC_RANGE();
}
Expand Down
2 changes: 2 additions & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,6 @@ kvikio_add_test(NAME DEFAULTS_TEST SOURCES test_defaults.cpp)

kvikio_add_test(NAME ERROR_TEST SOURCES test_error.cpp)

kvikio_add_test(NAME REMOTE_HANDLE_TEST SOURCES test_remote_handle.cpp utils/env.cpp)

rapids_test_install_relocatable(INSTALL_COMPONENT_SET testing DESTINATION bin/tests/libkvikio)
40 changes: 40 additions & 0 deletions cpp/tests/test_remote_handle.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <kvikio/remote_handle.hpp>

#include <unordered_map>
#include "utils/env.hpp"

TEST(RemoteHandleTest, s3_endpoint_constructor)
{
  // Pin the AWS-related environment variables for the lifetime of this test. The context
  // restores (or unsets) them again when it goes out of scope.
  kvikio::test::EnvVarContext env_var_ctx{{{"AWS_DEFAULT_REGION", "my_aws_default_region"},
                                           {"AWS_ACCESS_KEY_ID", "my_aws_access_key_id"},
                                           {"AWS_SECRET_ACCESS_KEY", "my_aws_secrete_access_key"},
                                           {"AWS_ENDPOINT_URL", "https://my_aws_endpoint_url"}}};

  // Overload 1: the full URL plus an explicit (optional) AWS region.
  std::string const full_url = "https://my_aws_endpoint_url/bucket_name/object_name";
  std::string const region   = "my_aws_region";
  kvikio::S3Endpoint endpoint_from_url(full_url, region);

  // Overload 2: a (bucket name, object name) pair; all remaining arguments are left at their
  // defaults.
  std::string const bucket = "bucket_name";
  std::string const object = "object_name";
  kvikio::S3Endpoint endpoint_from_names(std::make_pair(bucket, object));

  // Both construction paths are expected to yield the same string representation.
  EXPECT_EQ(endpoint_from_url.str(), endpoint_from_names.str());
}
54 changes: 54 additions & 0 deletions cpp/tests/utils/env.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "env.hpp"
#include <cstdlib>
#include <sstream>
#include <utility>

#include <kvikio/error.hpp>

namespace kvikio::test {
EnvVarContext::EnvVarContext(
std::initializer_list<std::pair<std::string, std::string>> env_var_entries)
{
for (auto const& [key, current_value] : env_var_entries) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know whether we'd ever need to unset an env var instead of setting it to the specified value?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking of CUDF_BENCHMARK_DROP_CACHE. No matter what this env var's value is, as long as it exists, the page cache is to be dropped (i.e. CUDF_BENCHMARK_DROP_CACHE=<nothing>/true/false/123 all have the same effect). In this case to fully revert to a previous state, we need to unset the env var.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not referring to reverting the state, but the actual test state requiring an env var to be unset.

Copy link
Contributor Author

@kingcrimsontianyu kingcrimsontianyu Apr 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I think a thorough test of the S3 endpoint would need this feature. The constructor of the S3 endpoint has "optional" parameters, but if they are neither passed as arguments nor set via the env vars, an exception will be thrown. We may want to check that the constructor throws properly when some arguments are missing in one test, and then check the case where another combination of arguments is missing in another test. This is where the env var context could help.

// Reject a repeated key BEFORE touching the environment. Checking after `setenv` (as was done
// previously) would overwrite the variable a second time and only then report the failure.
// NOTE(review): KVIKIO_FAIL presumably throws; if so, the destructor does not run for this
// partially constructed object, and variables set by earlier entries stay modified.
if (_env_var_map.find(key) != _env_var_map.end()) {
  std::stringstream ss;
  ss << "Environment variable " << key << " has already been set in this context.";
  KVIKIO_FAIL(ss.str());
}

// Record the state of the variable as it was before this context changes it, so that the
// destructor can restore it.
EnvVarState env_var_state;
if (auto const res = std::getenv(key.c_str()); res != nullptr) {
  env_var_state.existed_before = true;
  // Copy the value now: the pointer returned by getenv may be invalidated by the setenv below.
  env_var_state.previous_value = res;
}

// Apply the user-specified value, then remember the saved state under the variable name.
SYSCALL_CHECK(setenv(key.c_str(), current_value.c_str(), 1 /* allow overwrite */));
_env_var_map.insert({key, std::move(env_var_state)});
}
}

EnvVarContext::~EnvVarContext()
{
  // Undo every change made by the constructor, one recorded variable at a time.
  for (auto const& entry : _env_var_map) {
    auto const& name  = entry.first;
    auto const& saved = entry.second;

    // The variable did not exist before the context was entered: remove it again.
    if (!saved.existed_before) {
      SYSCALL_CHECK(unsetenv(name.c_str()));
      continue;
    }

    // The variable existed before: put its previous value back.
    SYSCALL_CHECK(setenv(name.c_str(), saved.previous_value.c_str(), 1 /* allow overwrite */));
  }
}
} // namespace kvikio::test
67 changes: 67 additions & 0 deletions cpp/tests/utils/env.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <initializer_list>
#include <string>
#include <unordered_map>
#include <utility>

namespace kvikio::test {
/**
* @brief RAII class used to temporarily set environment variables to new values upon construction,
* and restore them to previous values upon destruction.
*/
class EnvVarContext {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

neat! I might borrow this for cudf compression tests, they are getting to a point where more than one env var is being set

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps in a future PR this class could be moved out of the test and become part of KvikIO source. Same for the TempDir class from utils.hpp.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO it's fine in tests, it's not a part of the functionality of the kvikIO library

private:
/**
 * @brief Snapshot of one environment variable, taken before the context overwrites it
 */
struct EnvVarState {
  /// True if the variable was already present in the environment when the context was entered.
  bool existed_before = false;
  /// Value the variable had at that time; only meaningful when `existed_before` is true.
  std::string previous_value;
};

public:
/**
* @brief Set the environment variables to new values
*
* For each entry, the state of the variable before the change (whether it existed and, if so,
* its value) is recorded so that the destructor can restore it. An existing variable is
* overwritten. Listing the same variable name more than once is reported as a failure via
* `KVIKIO_FAIL`.
*
* @param env_var_entries User-specified environment variables. Each entry includes the variable
* name and value.
*/
EnvVarContext(std::initializer_list<std::pair<std::string, std::string>> env_var_entries);

/**
* @brief Restore the environment variables to previous values
*
* Reset the environment variables to their previous states:
* - If one existed before, restore to its previous value.
* - Otherwise, remove it from the environment.
*/
~EnvVarContext();

// Neither copyable nor movable (presumably so that the restore step in the destructor runs
// exactly once per context — confirm with the author).
EnvVarContext(EnvVarContext const&) = delete;
EnvVarContext(EnvVarContext&&) = delete;
EnvVarContext& operator=(EnvVarContext const&) = delete;
EnvVarContext& operator=(EnvVarContext&&) = delete;

private:
// Maps the name of each variable set by this context to its state before the context was entered.
std::unordered_map<std::string, EnvVarState> _env_var_map;
};
} // namespace kvikio::test
15 changes: 9 additions & 6 deletions python/kvikio/kvikio/_lib/remote_handle.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved.
# See file LICENSE for terms.

# distutils: language = c++
Expand All @@ -9,6 +9,7 @@ from typing import Optional
from cython.operator cimport dereference as deref
from libc.stdint cimport uintptr_t
from libcpp.memory cimport make_unique, unique_ptr
from libcpp.pair cimport pair
from libcpp.string cimport string
from libcpp.utility cimport move, pair

Expand All @@ -25,7 +26,7 @@ cdef extern from "<kvikio/remote_handle.hpp>" nogil:

cdef cppclass cpp_S3Endpoint "kvikio::S3Endpoint"(cpp_RemoteEndpoint):
cpp_S3Endpoint(string url) except +
cpp_S3Endpoint(string bucket_name, string object_name) except +
cpp_S3Endpoint(pair[string, string] bucket_and_object_names) except +

pair[string, string] cpp_parse_s3_url \
"kvikio::S3Endpoint::parse_s3_url"(string url) except +
Expand Down Expand Up @@ -56,6 +57,10 @@ cdef string _to_string(str s):
else:
return string()

cdef pair[string, string] _to_string_pair(str s1, str s2):
    """Convert two Python strings into a C++ ``pair[string, string]``.

    Each argument is converted with ``_to_string``; ``s1`` becomes ``first``
    and ``s2`` becomes ``second``.
    """
    cdef pair[string, string] result
    result.first = _to_string(s1)
    result.second = _to_string(s2)
    return result

# Helper function to cast an endpoint to its base class `RemoteEndpoint`
cdef extern from *:
"""
Expand Down Expand Up @@ -105,7 +110,7 @@ cdef class RemoteFile:
return RemoteFile._from_endpoint(
cast_to_remote_endpoint(
make_unique[cpp_S3Endpoint](
_to_string(bucket_name), _to_string(object_name)
_to_string_pair(bucket_name, object_name)
)
),
nbytes
Expand All @@ -131,9 +136,7 @@ cdef class RemoteFile:
cdef pair[string, string] bucket_and_object = cpp_parse_s3_url(_to_string(url))
return RemoteFile._from_endpoint(
cast_to_remote_endpoint(
make_unique[cpp_S3Endpoint](
bucket_and_object.first, bucket_and_object.second
)
make_unique[cpp_S3Endpoint](bucket_and_object)
),
nbytes
)
Expand Down
Loading