Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
21313ef
api for pkce(Proof Key for Code Exchange)
zhaohuabing Jan 2, 2025
7dbbde5
implementation
zhaohuabing Jan 2, 2025
c008e80
update tests
zhaohuabing Jan 9, 2025
aac2652
minor wording
zhaohuabing Jan 9, 2025
0dd6e3a
add change log
zhaohuabing Jan 9, 2025
60e5bab
fix format
zhaohuabing Jan 10, 2025
6df85ea
fix test
zhaohuabing Jan 10, 2025
62f5537
remove enable_pkce
zhaohuabing Jan 11, 2025
73860ef
Merge branch 'main' into api-ouath-pkce
zhaohuabing Jan 15, 2025
3cf2b43
encrypt the code verifier
zhaohuabing Jan 16, 2025
009cbee
Merge remote-tracking branch 'origin/main' into api-ouath-pkce
zhaohuabing Jan 17, 2025
bfdd457
fix integration test
zhaohuabing Jan 17, 2025
06ee3f1
address comment
zhaohuabing Jan 21, 2025
e45cf89
minor change
zhaohuabing Jan 21, 2025
d0402b5
Merge remote-tracking branch 'origin/main' into api-ouath-pkce
zhaohuabing Jan 22, 2025
2eb97c5
fix test
zhaohuabing Jan 22, 2025
fa02bb3
fix test
zhaohuabing Jan 22, 2025
db8f193
Merge remote-tracking branch 'origin/main' into api-ouath-pkce
zhaohuabing Mar 7, 2025
daa9271
change log
zhaohuabing Mar 7, 2025
765b479
change log
zhaohuabing Mar 7, 2025
248af96
add test
zhaohuabing Mar 10, 2025
31da834
fix format
zhaohuabing Mar 10, 2025
67e1237
fix test
zhaohuabing Mar 10, 2025
47e9be9
fix test
zhaohuabing Mar 10, 2025
dae8f2b
add more tests
zhaohuabing Mar 10, 2025
efa59da
minor change
zhaohuabing Mar 11, 2025
7eeffc1
fix verify
zhaohuabing Mar 11, 2025
3e447b4
update patch
zhaohuabing Mar 11, 2025
2d4baf3
call EVP_CIPHER_CTX_free(ctx) when decryption fails
zhaohuabing Mar 11, 2025
ac3b502
Merge branch 'main' into api-ouath-pkce
zhaohuabing Mar 19, 2025
05a54cd
Merge remote-tracking branch 'origin/main' into api-ouath-pkce
zhaohuabing Mar 25, 2025
69211f5
Revert "fix verify"
zhaohuabing Mar 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions api/envoy/extensions/filters/http/oauth2/v3/oauth.proto
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ message CookieConfig {
SameSite same_site = 1 [(validate.rules).enum = {defined_only: true}];
}

// [#next-free-field: 7]
// [#next-free-field: 8]
message CookieConfigs {
// Configuration for the bearer token cookie.
CookieConfig bearer_token_cookie_config = 1;
Expand All @@ -58,11 +58,14 @@ message CookieConfigs {

// Configuration for the OAuth nonce cookie.
CookieConfig oauth_nonce_cookie_config = 6;

// Configuration for the code verifier cookie.
CookieConfig code_verifier_cookie_config = 7;
}

// [#next-free-field: 6]
message OAuth2Credentials {
// [#next-free-field: 7]
// [#next-free-field: 8]
message CookieNames {
// Cookie name to hold OAuth bearer token value. When the authentication server validates the
// client and returns an authorization token back to the OAuth filter, no matter what format
Expand Down Expand Up @@ -91,6 +94,10 @@ message OAuth2Credentials {
// Cookie name to hold the nonce value. Defaults to ``OauthNonce``.
string oauth_nonce = 6
[(validate.rules).string = {well_known_regex: HTTP_HEADER_NAME ignore_empty: true}];

// Cookie name to hold the PKCE code verifier. Defaults to ``OauthCodeVerifier``.
string code_verifier = 7
[(validate.rules).string = {well_known_regex: HTTP_HEADER_NAME ignore_empty: true}];
}

// The client_id to be used in the authorize calls. This value will be URL encoded when sent to the OAuth server.
Expand Down
3 changes: 3 additions & 0 deletions changelogs/current.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ minor_behavior_changes:
change: |
The formatter ``%CEL%`` and ``%METADATA%`` will be treated as built-in formatters and could be used directly in the
substitution format string if the related extensions are linked.
- area: oauth2
change: |
Introduced PKCE(Proof Key for Code Exchange) support for OAuth2 authorization code flow.

bug_fixes:
# *Changes expected to improve the state of the world and are unlikely to have negative effects*
Expand Down
217 changes: 202 additions & 15 deletions source/extensions/filters/http/oauth2/filter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "absl/strings/str_split.h"
#include "jwt_verify_lib/jwt.h"
#include "jwt_verify_lib/status.h"
#include "openssl/rand.h"

using namespace std::chrono_literals;

Expand All @@ -50,6 +51,8 @@ constexpr absl::string_view queryParamsError = "error";
constexpr absl::string_view queryParamsCode = "code";
constexpr absl::string_view queryParamsState = "state";
constexpr absl::string_view queryParamsRedirectUri = "redirect_uri";
constexpr absl::string_view queryParamsCodeChallenge = "code_challenge";
constexpr absl::string_view queryParamsCodeChallengeMethod = "code_challenge_method";

constexpr absl::string_view stateParamsUrl = "url";
constexpr absl::string_view stateParamsCsrfToken = "csrf_token";
Expand Down Expand Up @@ -227,6 +230,30 @@ bool validateCsrfTokenHmac(const std::string& hmac_secret, const std::string& cs
return generateHmacBase64(hmac_secret_vec, token) == hmac;
}

// Generates a PKCE code verifier with 32 octets of randomness.
// This follows recommendations in RFC 7636:
// https://datatracker.ietf.org/doc/html/rfc7636#section-7.1
std::string generateCodeVerifier(Random::RandomGenerator& random) {
MemBlockBuilder<uint64_t> mem_block(4);
// create 4 random uint64_t values to fill the buffer because RFC 7636 recommends 32 octets of
// randomness.
for (size_t i = 0; i < 4; i++) {
mem_block.appendOne(random.random());
}

std::unique_ptr<uint64_t[]> data = mem_block.release();
return Base64Url::encode(reinterpret_cast<char*>(data.get()), 4 * sizeof(uint64_t));
}

// Generates a PKCE code challenge from a code verifier.
std::string generateCodeChallenge(const std::string& code_verifier) {
auto& crypto_util = Envoy::Common::Crypto::UtilitySingleton::get();
std::vector<uint8_t> sha256_digest =
crypto_util.getSha256Digest(Buffer::OwnedImpl(code_verifier));
std::string sha256_string(sha256_digest.begin(), sha256_digest.end());
return Base64Url::encode(sha256_string.data(), sha256_string.size());
}

/**
* Encodes the state parameter for the OAuth2 flow.
* The state parameter is a base64Url encoded JSON object containing the original request URL and a
Expand All @@ -241,6 +268,129 @@ std::string encodeState(const std::string& original_request_url, const std::stri
return Base64Url::encode(json.data(), json.size());
}

/**
* Encrypt a plaintext string using AES-256-CBC.
*/
std::string encrypt(const std::string& plaintext, const std::string& secret,
Copy link
Copy Markdown
Member Author

@zhaohuabing zhaohuabing Jan 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The encrpt and decrpt method can be also applied to the ID and Access tokens to provide an addtional layer of protection. We can address this in a follow-up PR.

Copy link
Copy Markdown
Contributor

@arminabf arminabf Apr 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there already any active development efforts related to this?

Random::RandomGenerator& random) {
// Generate the key from the secret using SHA-256
std::vector<unsigned char> key(SHA256_DIGEST_LENGTH); // AES-256 requires 256-bit (32 bytes) key
SHA256(reinterpret_cast<const unsigned char*>(secret.c_str()), secret.size(), key.data());

// Generate a random IV
MemBlockBuilder<uint64_t> mem_block(4);
// create 2 random uint64_t values to fill the buffer because AES-256-CBC requires 16 bytes IV
for (size_t i = 0; i < 2; i++) {
mem_block.appendOne(random.random());
}

std::unique_ptr<uint64_t[]> data = mem_block.release();
const unsigned char* raw_data = reinterpret_cast<const unsigned char*>(data.get());

// AES uses 16-byte IV
std::vector<unsigned char> iv(16);
iv.assign(raw_data, raw_data + 16);

EVP_CIPHER_CTX* ctx = EVP_CIPHER_CTX_new();
RELEASE_ASSERT(ctx, "Failed to create context");

std::vector<unsigned char> ciphertext(plaintext.size() + EVP_MAX_BLOCK_LENGTH);
int len = 0, ciphertext_len = 0;

// Initialize encryption operation
int result = EVP_EncryptInit_ex(ctx, EVP_aes_256_cbc(), nullptr, key.data(), iv.data());
RELEASE_ASSERT(result == 1, "Encryption initialization failed");

// Encrypt the plaintext
result = EVP_EncryptUpdate(ctx, ciphertext.data(), &len,
reinterpret_cast<const unsigned char*>(plaintext.c_str()),
plaintext.size());
RELEASE_ASSERT(result == 1, "Encryption update failed");

ciphertext_len += len;

// Finalize encryption
result = EVP_EncryptFinal_ex(ctx, ciphertext.data() + len, &len);
RELEASE_ASSERT(result == 1, "Encryption finalization failed");

ciphertext_len += len;

EVP_CIPHER_CTX_free(ctx);

// AES uses 16-byte IV
ciphertext.resize(ciphertext_len);

// Prepend the IV to the ciphertext
std::vector<unsigned char> combined(iv.size() + ciphertext.size());
std::copy(iv.begin(), iv.end(), combined.begin());
std::copy(ciphertext.begin(), ciphertext.end(), combined.begin() + iv.size());

// Base64Url encode the IV + ciphertext
return Base64Url::encode(reinterpret_cast<const char*>(combined.data()), combined.size());
}

struct DecryptResult {
std::string plaintext;
absl::optional<std::string> error;
};

/**
* Decrypt an AES-256-CBC encrypted string.
*/
DecryptResult decrypt(const std::string& encrypted, const std::string& secret) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OpenSSL document says that the high level API EVP_XXX should be used instead of the low level API AES_XXX.

xref: https://wiki.openssl.org/index.php/OpenSSL_3.0#Low_Level_APIs

// Decode the Base64Url-encoded input
std::string decoded = Base64Url::decode(encrypted);
std::vector<unsigned char> combined(decoded.begin(), decoded.end());

if (combined.size() <= 16) {
return {"", "Invalid encrypted data"};
}

// Extract the IV (first 16 bytes)
std::vector<unsigned char> iv(combined.begin(), combined.begin() + 16);

// Extract the ciphertext (remaining bytes)
std::vector<unsigned char> ciphertext(combined.begin() + 16, combined.end());

// Generate the key from the secret using SHA-256
std::vector<unsigned char> key(SHA256_DIGEST_LENGTH);
SHA256(reinterpret_cast<const unsigned char*>(secret.c_str()), secret.size(), key.data());

EVP_CIPHER_CTX* ctx = EVP_CIPHER_CTX_new();
RELEASE_ASSERT(ctx, "Failed to create context");

std::vector<unsigned char> plaintext(ciphertext.size() + EVP_MAX_BLOCK_LENGTH);
int len = 0, plaintext_len = 0;

// Initialize decryption operation
if (EVP_DecryptInit_ex(ctx, EVP_aes_256_cbc(), nullptr, key.data(), iv.data()) != 1) {
EVP_CIPHER_CTX_free(ctx);
return {"", "failed to initialize decryption"};
}

// Decrypt the ciphertext
if (EVP_DecryptUpdate(ctx, plaintext.data(), &len, ciphertext.data(), ciphertext.size()) != 1) {
EVP_CIPHER_CTX_free(ctx);
return {"", "failed to decrypt data"};
}
plaintext_len += len;

// Finalize decryption
if (EVP_DecryptFinal_ex(ctx, plaintext.data() + len, &len) != 1) {
EVP_CIPHER_CTX_free(ctx);
return {"", "failed to finalize decryption"};
}

plaintext_len += len;

EVP_CIPHER_CTX_free(ctx);

// Resize to actual plaintext length
plaintext.resize(plaintext_len);

return {std::string(plaintext.begin(), plaintext.end()), std::nullopt};
}

} // namespace

FilterConfig::FilterConfig(
Expand Down Expand Up @@ -518,8 +668,26 @@ Http::FilterHeadersStatus OAuth2Filter::decodeHeaders(Http::RequestHeaderMap& he
Formatter::FormatterImpl::create(config_->redirectUri()), Formatter::FormatterPtr);
const auto redirect_uri =
formatter->formatWithContext({&headers}, decoder_callbacks_->streamInfo());

std::string encrypted_code_verifier =
Http::Utility::parseCookieValue(headers, config_->cookieNames().code_verifier_);
if (encrypted_code_verifier.empty()) {
ENVOY_LOG(error, "code verifier cookie is missing in the request");
sendUnauthorizedResponse();
return Http::FilterHeadersStatus::StopIteration;
}

DecryptResult decrypt_result = decrypt(encrypted_code_verifier, config_->hmacSecret());
if (decrypt_result.error.has_value()) {
ENVOY_LOG(error, "decryption failed: {}", decrypt_result.error.value());
sendUnauthorizedResponse();
return Http::FilterHeadersStatus::StopIteration;
}

std::string code_verifier = decrypt_result.plaintext;

oauth_client_->asyncGetAccessToken(auth_code_, config_->clientId(), config_->clientSecret(),
redirect_uri, config_->authType());
redirect_uri, code_verifier, config_->authType());

// pause while we await the next step from the OAuth server
return Http::FilterHeadersStatus::StopAllIterationAndBuffer;
Expand Down Expand Up @@ -580,22 +748,13 @@ void OAuth2Filter::redirectToOAuthServer(Http::RequestHeaderMap& headers) {
// The CSRF token cookie contains the CSRF token that is used to prevent CSRF attacks for the
// OAuth flow. It was named "oauth_nonce" because the CSRF token contains a generated nonce.
// "oauth_csrf_token" would be a more accurate name for the cookie.
std::string csrf_token;
bool csrf_token_cookie_exists = false;
const auto csrf_token_cookie =
Http::Utility::parseCookies(headers, [this](absl::string_view key) {
return key == config_->cookieNames().oauth_nonce_;
});
if (csrf_token_cookie.find(config_->cookieNames().oauth_nonce_) != csrf_token_cookie.end()) {
csrf_token = csrf_token_cookie.at(config_->cookieNames().oauth_nonce_);
csrf_token_cookie_exists = true;
} else {
// Generate a CSRF token to prevent CSRF attacks.
csrf_token = generateCsrfToken(config_->hmacSecret(), random_);
}

std::string csrf_token =
Http::Utility::parseCookieValue(headers, config_->cookieNames().oauth_nonce_);
bool csrf_token_cookie_exists = !csrf_token.empty();
// Set the CSRF token cookie if it does not exist.
if (!csrf_token_cookie_exists) {
// Generate a CSRF token to prevent CSRF attacks.
csrf_token = generateCsrfToken(config_->hmacSecret(), random_);
// Expire the CSRF token cookie in 10 minutes.
// This should be enough time for the user to complete the OAuth flow.
std::string csrf_expires = std::to_string(10 * 60);
Expand Down Expand Up @@ -629,6 +788,30 @@ void OAuth2Filter::redirectToOAuthServer(Http::RequestHeaderMap& headers) {
const std::string escaped_redirect_uri = Http::Utility::PercentEncoding::urlEncode(redirect_uri);
query_params.overwrite(queryParamsRedirectUri, escaped_redirect_uri);

// Generate a PKCE code verifier and challenge for the OAuth flow.
const std::string code_verifier = generateCodeVerifier(random_);
// Encrypt the code verifier, using HMAC secret as the symmetric key.
const std::string encrypted_code_verifier =
encrypt(code_verifier, config_->hmacSecret(), random_);

// Expire the code verifier cookie in 10 minutes.
// This should be enough time for the user to complete the OAuth flow.
std::string expire_in = std::to_string(10 * 60);
std::string same_site = getSameSiteString(config_->codeVerifierCookieSettings().same_site_);
std::string cookie_tail_http_only =
fmt::format(CookieTailHttpOnlyFormatString, expire_in, same_site);
if (!config_->cookieDomain().empty()) {
cookie_tail_http_only = absl::StrCat(
fmt::format(CookieDomainFormatString, config_->cookieDomain()), cookie_tail_http_only);
}
response_headers->addReferenceKey(Http::Headers::get().SetCookie,
absl::StrCat(config_->cookieNames().code_verifier_, "=",
encrypted_code_verifier, cookie_tail_http_only));

const std::string code_challenge = generateCodeChallenge(code_verifier);
query_params.overwrite(queryParamsCodeChallenge, code_challenge);
query_params.overwrite(queryParamsCodeChallengeMethod, "S256");

// Copy the authorization endpoint URL to replace its query params.
auto authorization_endpoint_url = config_->authorizationEndpointUrl();
const std::string path_and_query_params = query_params.replaceQueryString(
Expand Down Expand Up @@ -676,6 +859,10 @@ Http::FilterHeadersStatus OAuth2Filter::signOutUser(const Http::RequestHeaderMap
Http::Headers::get().SetCookie,
absl::StrCat(fmt::format(CookieDeleteFormatString, config_->cookieNames().oauth_nonce_),
cookie_domain));
response_headers->addReferenceKey(
Http::Headers::get().SetCookie,
absl::StrCat(fmt::format(CookieDeleteFormatString, config_->cookieNames().code_verifier_),
cookie_domain));
response_headers->setLocation(new_path);
decoder_callbacks_->encodeHeaders(std::move(response_headers), true, SIGN_OUT);

Expand Down
15 changes: 12 additions & 3 deletions source/extensions/filters/http/oauth2/filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,31 +90,36 @@ struct CookieNames {
cookie_names)
: CookieNames(cookie_names.bearer_token(), cookie_names.oauth_hmac(),
cookie_names.oauth_expires(), cookie_names.id_token(),
cookie_names.refresh_token(), cookie_names.oauth_nonce()) {}
cookie_names.refresh_token(), cookie_names.oauth_nonce(),
cookie_names.code_verifier()) {}

CookieNames(const std::string& bearer_token, const std::string& oauth_hmac,
const std::string& oauth_expires, const std::string& id_token,
const std::string& refresh_token, const std::string& oauth_nonce)
const std::string& refresh_token, const std::string& oauth_nonce,
const std::string& code_verifier)
: bearer_token_(bearer_token.empty() ? BearerToken : bearer_token),
oauth_hmac_(oauth_hmac.empty() ? OauthHMAC : oauth_hmac),
oauth_expires_(oauth_expires.empty() ? OauthExpires : oauth_expires),
id_token_(id_token.empty() ? IdToken : id_token),
refresh_token_(refresh_token.empty() ? RefreshToken : refresh_token),
oauth_nonce_(oauth_nonce.empty() ? OauthNonce : oauth_nonce) {}
oauth_nonce_(oauth_nonce.empty() ? OauthNonce : oauth_nonce),
code_verifier_(code_verifier.empty() ? CodeVerifier : code_verifier) {}

const std::string bearer_token_;
const std::string oauth_hmac_;
const std::string oauth_expires_;
const std::string id_token_;
const std::string refresh_token_;
const std::string oauth_nonce_;
const std::string code_verifier_;

static constexpr absl::string_view OauthExpires = "OauthExpires";
static constexpr absl::string_view BearerToken = "BearerToken";
static constexpr absl::string_view OauthHMAC = "OauthHMAC";
static constexpr absl::string_view OauthNonce = "OauthNonce";
static constexpr absl::string_view IdToken = "IdToken";
static constexpr absl::string_view RefreshToken = "RefreshToken";
static constexpr absl::string_view CodeVerifier = "CodeVerifier";
};

/**
Expand Down Expand Up @@ -188,6 +193,9 @@ class FilterConfig {
return refresh_token_cookie_settings_;
}
const CookieSettings& nonceCookieSettings() const { return nonce_cookie_settings_; }
const CookieSettings& codeVerifierCookieSettings() const {
return code_verifier_cookie_settings_;
}

private:
static FilterStats generateStats(const std::string& prefix, Stats::Scope& scope);
Expand Down Expand Up @@ -225,6 +233,7 @@ class FilterConfig {
const CookieSettings id_token_cookie_settings_;
const CookieSettings refresh_token_cookie_settings_;
const CookieSettings nonce_cookie_settings_;
const CookieSettings code_verifier_cookie_settings_;
};

using FilterConfigSharedPtr = std::shared_ptr<FilterConfig>;
Expand Down
Loading