Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions cpp/include/kvikio/remote_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
#include <kvikio/posix_io.hpp>
#include <kvikio/utils.hpp>

struct curl_slist;

namespace kvikio {

class CurlHandle; // Prototype
Expand Down Expand Up @@ -92,6 +94,7 @@ class S3Endpoint : public RemoteEndpoint {
std::string _url;
std::string _aws_sigv4;
std::string _aws_userpwd;
curl_slist* _curl_header_list{};

/**
* @brief Unwrap an optional parameter, obtaining a default from the environment.
Expand Down Expand Up @@ -153,11 +156,14 @@ class S3Endpoint : public RemoteEndpoint {
* `AWS_ACCESS_KEY_ID` environment variable is used.
* @param aws_secret_access_key The AWS secret access key to use. If nullopt, the value of the
* `AWS_SECRET_ACCESS_KEY` environment variable is used.
* @param aws_session_token The AWS session token to use. If nullopt, the value of the
* `AWS_SESSION_TOKEN` environment variable is used.
*/
S3Endpoint(std::string url,
std::optional<std::string> aws_region = std::nullopt,
std::optional<std::string> aws_access_key = std::nullopt,
std::optional<std::string> aws_secret_access_key = std::nullopt);
std::optional<std::string> aws_secret_access_key = std::nullopt,
std::optional<std::string> aws_session_token = std::nullopt);

/**
* @brief Create an S3 endpoint from a bucket name and an object name.
Expand All @@ -173,16 +179,19 @@ class S3Endpoint : public RemoteEndpoint {
* the scheme: "<aws_endpoint_url>/<bucket_name>/<object_name>". If nullopt, the value of the
* `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular AWS
* url scheme is used: "https://<bucket_name>.s3.<region>.amazonaws.com/<object_name>".
* @param aws_session_token The AWS session token to use. If nullopt, the value of the
* `AWS_SESSION_TOKEN` environment variable is used.
*/
S3Endpoint(std::pair<std::string, std::string> bucket_and_object_names,
std::optional<std::string> aws_region = std::nullopt,
std::optional<std::string> aws_access_key = std::nullopt,
std::optional<std::string> aws_secret_access_key = std::nullopt,
std::optional<std::string> aws_endpoint_url = std::nullopt);
std::optional<std::string> aws_endpoint_url = std::nullopt,
std::optional<std::string> aws_session_token = std::nullopt);

void setopt(CurlHandle& curl) override;
std::string str() const override;
~S3Endpoint() override = default;
~S3Endpoint() override;
};

/**
Expand Down
29 changes: 26 additions & 3 deletions cpp/src/remote_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ void S3Endpoint::setopt(CurlHandle& curl)
curl.setopt(CURLOPT_URL, _url.c_str());
curl.setopt(CURLOPT_AWS_SIGV4, _aws_sigv4.c_str());
curl.setopt(CURLOPT_USERPWD, _aws_userpwd.c_str());
if (_curl_header_list) { curl.setopt(CURLOPT_HTTPHEADER, _curl_header_list); }
}

std::string S3Endpoint::unwrap_or_default(std::optional<std::string> aws_arg,
Expand Down Expand Up @@ -203,7 +204,8 @@ std::pair<std::string, std::string> S3Endpoint::parse_s3_url(std::string const&
S3Endpoint::S3Endpoint(std::string url,
std::optional<std::string> aws_region,
std::optional<std::string> aws_access_key,
std::optional<std::string> aws_secret_access_key)
std::optional<std::string> aws_secret_access_key,
std::optional<std::string> aws_session_token)
: _url{std::move(url)}
{
KVIKIO_NVTX_FUNC_RANGE();
Expand Down Expand Up @@ -243,24 +245,45 @@ S3Endpoint::S3Endpoint(std::string url,
ss << access_key << ":" << secret_access_key;
_aws_userpwd = ss.str();
}
// Access key IDs beginning with ASIA are temporary credentials that are created using AWS STS
// operations. They need a session token to work.
if (access_key.compare(0, 4, std::string("ASIA")) == 0) {
// Create a custom curl header that carries the session token.
// The list returned by curl_slist_append is stored in _curl_header_list and must be freed
// manually with curl_slist_free_all
// (see https://curl.se/libcurl/c/CURLOPT_HTTPHEADER.html)
auto session_token =
unwrap_or_default(std::move(aws_session_token),
"AWS_SESSION_TOKEN",
"When using temporary credentials, AWS_SESSION_TOKEN must be set.");
std::stringstream ss;
ss << "x-amz-security-token: " << session_token;
_curl_header_list = curl_slist_append(NULL, ss.str().c_str());
KVIKIO_EXPECT(_curl_header_list != nullptr,
"Failed to create curl header for AWS token",
std::runtime_error);
}
}

// Bucket/object-name overload: builds the endpoint URL with url_from_bucket_and_object() and
// delegates all remaining setup (region, credentials, session token) to the URL-based
// constructor. `aws_region` is handed to both the URL helper and the delegated constructor, so
// it is passed without std::move; the other optionals are moved into the delegated constructor.
S3Endpoint::S3Endpoint(std::pair<std::string, std::string> bucket_and_object_names,
std::optional<std::string> aws_region,
std::optional<std::string> aws_access_key,
std::optional<std::string> aws_secret_access_key,
std::optional<std::string> aws_endpoint_url)
std::optional<std::string> aws_endpoint_url,
std::optional<std::string> aws_session_token)
: S3Endpoint(url_from_bucket_and_object(std::move(bucket_and_object_names.first),
std::move(bucket_and_object_names.second),
aws_region,
std::move(aws_endpoint_url)),
aws_region,
std::move(aws_access_key),
std::move(aws_secret_access_key))
std::move(aws_secret_access_key),
std::move(aws_session_token))
{
KVIKIO_NVTX_FUNC_RANGE();
}

// Release the curl header list that the URL-based constructor may have created for the
// session token. The pointer stays null when no such header was built.
// NOTE(review): no copy/move handling for this raw pointer is visible here — confirm the class
// cannot be copied, otherwise two objects would free the same list.
S3Endpoint::~S3Endpoint()
{
  curl_slist_free_all(_curl_header_list);
}

// Return a copy of the URL stored in this endpoint.
std::string S3Endpoint::str() const
{
  return _url;
}

RemoteHandle::RemoteHandle(std::unique_ptr<RemoteEndpoint> endpoint, std::size_t nbytes)
Expand Down
4 changes: 3 additions & 1 deletion python/kvikio/kvikio/remote_file.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved.
# See file LICENSE for terms.

from __future__ import annotations
Expand Down Expand Up @@ -81,6 +81,7 @@ def open_s3(
- `AWS_DEFAULT_REGION`
- `AWS_ACCESS_KEY_ID`
- `AWS_SECRET_ACCESS_KEY`
- `AWS_SESSION_TOKEN` (when using temporary credentials)

Additionally, to override the AWS endpoint, set `AWS_ENDPOINT_URL`.
See <https://docs.aws.amazon.com/cli/v1/userguide/cli-configure-envvars.html>
Expand Down Expand Up @@ -115,6 +116,7 @@ def open_s3_url(
- `AWS_DEFAULT_REGION`
- `AWS_ACCESS_KEY_ID`
- `AWS_SECRET_ACCESS_KEY`
- `AWS_SESSION_TOKEN` (when using temporary credentials)

Additionally, if `url` is an S3 URL, it is possible to override the AWS endpoint
by setting `AWS_ENDPOINT_URL`.
Expand Down