diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index 091d55bcd1..b67ae5560e 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -32,6 +32,8 @@ #include #include +struct curl_slist; + namespace kvikio { class CurlHandle; // Prototype @@ -92,6 +94,7 @@ class S3Endpoint : public RemoteEndpoint { std::string _url; std::string _aws_sigv4; std::string _aws_userpwd; + curl_slist* _curl_header_list{}; /** * @brief Unwrap an optional parameter, obtaining a default from the environment. @@ -153,11 +156,14 @@ class S3Endpoint : public RemoteEndpoint { * `AWS_ACCESS_KEY_ID` environment variable is used. * @param aws_secret_access_key The AWS secret access key to use. If nullopt, the value of the * `AWS_SECRET_ACCESS_KEY` environment variable is used. + * @param aws_session_token The AWS session token to use. If nullopt, the value of the + * `AWS_SESSION_TOKEN` environment variable is used. */ S3Endpoint(std::string url, std::optional aws_region = std::nullopt, std::optional aws_access_key = std::nullopt, - std::optional aws_secret_access_key = std::nullopt); + std::optional aws_secret_access_key = std::nullopt, + std::optional aws_session_token = std::nullopt); /** * @brief Create a S3 endpoint from a bucket and object name. @@ -173,16 +179,19 @@ class S3Endpoint : public RemoteEndpoint { * the scheme: "//". If nullopt, the value of the * `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular AWS * url scheme is used: "https://.s3..amazonaws.com/". + * @param aws_session_token The AWS session token to use. If nullopt, the value of the + * `AWS_SESSION_TOKEN` environment variable is used. */ S3Endpoint(std::pair bucket_and_object_names, std::optional aws_region = std::nullopt, std::optional aws_access_key = std::nullopt, std::optional aws_secret_access_key = std::nullopt, - std::optional aws_endpoint_url = std::nullopt); + std::optional aws_endpoint_url = std::nullopt, + std::optional aws_session_token = std::nullopt); void setopt(CurlHandle& curl) override; std::string str() const override; - ~S3Endpoint() override = default; + ~S3Endpoint() override; }; /** diff --git a/cpp/src/remote_handle.cpp b/cpp/src/remote_handle.cpp index 4df2f48898..485e0739ac 100644 --- a/cpp/src/remote_handle.cpp +++ b/cpp/src/remote_handle.cpp @@ -151,6 +151,7 @@ void S3Endpoint::setopt(CurlHandle& curl) curl.setopt(CURLOPT_URL, _url.c_str()); curl.setopt(CURLOPT_AWS_SIGV4, _aws_sigv4.c_str()); curl.setopt(CURLOPT_USERPWD, _aws_userpwd.c_str()); + if (_curl_header_list) { curl.setopt(CURLOPT_HTTPHEADER, _curl_header_list); } } std::string S3Endpoint::unwrap_or_default(std::optional aws_arg, @@ -203,7 +204,8 @@ std::pair S3Endpoint::parse_s3_url(std::string const& S3Endpoint::S3Endpoint(std::string url, std::optional aws_region, std::optional aws_access_key, - std::optional aws_secret_access_key) + std::optional aws_secret_access_key, + std::optional aws_session_token) : _url{std::move(url)} { KVIKIO_NVTX_FUNC_RANGE(); @@ -243,24 +245,45 @@ S3Endpoint::S3Endpoint(std::string url, ss << access_key << ":" << secret_access_key; _aws_userpwd = ss.str(); } + // Access key IDs beginning with ASIA are temporary credentials that are created using AWS STS + // operations. They need a session token to work. + if (access_key.compare(0, 4, std::string("ASIA")) == 0) { + // Create a Custom Curl header for the session token. + // The _curl_header_list created by curl_slist_append must be manually freed + // (see https://curl.se/libcurl/c/CURLOPT_HTTPHEADER.html) + auto session_token = + unwrap_or_default(std::move(aws_session_token), + "AWS_SESSION_TOKEN", + "When using temporary credentials, AWS_SESSION_TOKEN must be set."); + std::stringstream ss; + ss << "x-amz-security-token: " << session_token; + _curl_header_list = curl_slist_append(NULL, ss.str().c_str()); + KVIKIO_EXPECT(_curl_header_list != nullptr, + "Failed to create curl header for AWS token", + std::runtime_error); + } } S3Endpoint::S3Endpoint(std::pair bucket_and_object_names, std::optional aws_region, std::optional aws_access_key, std::optional aws_secret_access_key, - std::optional aws_endpoint_url) + std::optional aws_endpoint_url, + std::optional aws_session_token) : S3Endpoint(url_from_bucket_and_object(std::move(bucket_and_object_names.first), std::move(bucket_and_object_names.second), aws_region, std::move(aws_endpoint_url)), aws_region, std::move(aws_access_key), - std::move(aws_secret_access_key)) + std::move(aws_secret_access_key), + std::move(aws_session_token)) { KVIKIO_NVTX_FUNC_RANGE(); } +S3Endpoint::~S3Endpoint() { curl_slist_free_all(_curl_header_list); } + std::string S3Endpoint::str() const { return _url; } RemoteHandle::RemoteHandle(std::unique_ptr endpoint, std::size_t nbytes) diff --git a/python/kvikio/kvikio/remote_file.py b/python/kvikio/kvikio/remote_file.py index f10f4b49f9..55cce53115 100644 --- a/python/kvikio/kvikio/remote_file.py +++ b/python/kvikio/kvikio/remote_file.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved. # See file LICENSE for terms. from __future__ import annotations @@ -81,6 +81,7 @@ def open_s3( - `AWS_DEFAULT_REGION` - `AWS_ACCESS_KEY_ID` - `AWS_SECRET_ACCESS_KEY` + - `AWS_SESSION_TOKEN` (when using temporary credentials) Additionally, to overwrite the AWS endpoint, set `AWS_ENDPOINT_URL`. See @@ -115,6 +116,7 @@ def open_s3_url( - `AWS_DEFAULT_REGION` - `AWS_ACCESS_KEY_ID` - `AWS_SECRET_ACCESS_KEY` + - `AWS_SESSION_TOKEN` (when using temporary credentials) Additionally, if `url` is a S3 url, it is possible to overwrite the AWS endpoint by setting `AWS_ENDPOINT_URL`.