Skip to content

Commit

Permalink
internal api: add new api to poll client_hello callback (#3230)
Browse files Browse the repository at this point in the history
  • Loading branch information
toidiu authored Mar 15, 2022
1 parent d731648 commit 70884dc
Show file tree
Hide file tree
Showing 8 changed files with 275 additions and 21 deletions.
3 changes: 3 additions & 0 deletions bindings/rust/s2n-tls-sys/src/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@ extern "C" {
config: *mut *mut s2n_config,
) -> ::libc::c_int;
}
extern "C" {
pub fn s2n_config_client_hello_cb_enable_poll(config: *mut s2n_config) -> ::libc::c_int;
}
104 changes: 93 additions & 11 deletions tests/unit/s2n_self_talk_client_hello_cb_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include "api/s2n.h"
#include "tls/s2n_connection.h"
#include "tls/s2n_internal.h"

struct client_hello_context {
int invoked;
Expand All @@ -38,6 +39,7 @@ struct client_hello_context {
* this flag tests the previous behavior from blocking callbacks
*/
int legacy_rc_for_server_name_used;
bool mark_done;
};

int mock_client(struct s2n_test_io_pair *io_pair, int expect_failure, int expect_server_name_used)
Expand Down Expand Up @@ -181,17 +183,35 @@ int client_hello_fail_handshake(struct s2n_connection *conn, void *ctx)
/* Return negative value to terminate the handshake */
return -1;

}
}

int s2n_client_hello_poll_cb(struct s2n_connection *conn, void *ctx)
{
struct client_hello_context *client_hello_ctx;
if (ctx == NULL) {
return -1;
}
client_hello_ctx = ctx;
/* Increment counter to ensure that callback was invoked */
client_hello_ctx->invoked++;

if (client_hello_ctx->mark_done) {
EXPECT_SUCCESS(s2n_client_hello_cb_done(conn));
return S2N_SUCCESS;
}

return S2N_SUCCESS;
}

int s2n_negotiate_nonblocking_ch_cb(struct s2n_connection *conn,
int s2n_negotiate_nonblocking_ch_cb(struct s2n_connection *conn,
struct client_hello_context *ch_ctx, bool server_name_used)
{
s2n_blocked_status blocked;
EXPECT_NOT_NULL(conn);
/* negotiate handshake, we should pause after the nonblocking callback is invoked */
EXPECT_FAILURE_WITH_ERRNO(s2n_negotiate(conn, &blocked), S2N_ERR_ASYNC_BLOCKED);
EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_APPLICATION_INPUT);

/* verify client hello cb has been invoked */
EXPECT_EQUAL(ch_ctx->invoked, 1);

Expand All @@ -201,7 +221,7 @@ int s2n_negotiate_nonblocking_ch_cb(struct s2n_connection *conn,
}
/* unless explicitly unblocked we should stay paused */
EXPECT_FAILURE_WITH_ERRNO(s2n_negotiate(conn, &blocked), S2N_ERR_ASYNC_BLOCKED);
EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_APPLICATION_INPUT);
EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_APPLICATION_INPUT);

/* mark the client hello cb complete */
EXPECT_SUCCESS(s2n_client_hello_cb_done(conn));
Expand All @@ -211,11 +231,39 @@ int s2n_negotiate_nonblocking_ch_cb(struct s2n_connection *conn,
return s2n_negotiate(conn, &blocked);
}

int s2n_negotiate_nonblocking_poll(struct s2n_connection *conn,
struct client_hello_context *ch_ctx)
{
EXPECT_NOT_NULL(conn);
EXPECT_NOT_NULL(ch_ctx);
int invoked = 0;
s2n_blocked_status blocked = S2N_NOT_BLOCKED;

EXPECT_EQUAL(ch_ctx->invoked, 0);

do {
/* negotiate handshake, we should pause after the nonblocking callback is invoked */
EXPECT_FAILURE_WITH_ERRNO(s2n_negotiate(conn, &blocked), S2N_ERR_ASYNC_BLOCKED);
EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_APPLICATION_INPUT);
invoked++;
EXPECT_EQUAL(ch_ctx->invoked, invoked);
} while(invoked < 10);
EXPECT_EQUAL(ch_ctx->invoked, invoked);

ch_ctx->mark_done = true;

/* Expect the callback to complete after 2nd iteration */
EXPECT_SUCCESS(s2n_negotiate(conn, &blocked));
EXPECT_EQUAL(ch_ctx->invoked, invoked + 1);

return S2N_SUCCESS;
}

int s2n_negotiate_blocking_ch_cb(struct s2n_connection *conn, struct client_hello_context *ch_ctx)
{
s2n_blocked_status blocked;
EXPECT_NOT_NULL(conn);

int rc = s2n_negotiate(conn, &blocked);
/* verify client hello cb has been invoked */
EXPECT_EQUAL(ch_ctx->invoked, 1);
Expand Down Expand Up @@ -329,7 +377,7 @@ int run_test_config_swap_ch_cb(s2n_client_hello_cb_mode cb_mode,

EXPECT_SUCCESS(start_client_conn(&io_pair, &pid, 0 , 1));
EXPECT_SUCCESS(init_server_conn(&conn, &io_pair, config));

/* do the handshake */
if ( cb_mode == S2N_CLIENT_HELLO_CB_NONBLOCKING && !ch_ctx->mark_done_during_callback) {
/* swap the config and mark server_name_used in the async context */
Expand All @@ -340,7 +388,7 @@ int run_test_config_swap_ch_cb(s2n_client_hello_cb_mode cb_mode,
*/
EXPECT_SUCCESS(s2n_negotiate_blocking_ch_cb(conn, ch_ctx));
}

/* Server name and error are as expected with null connection */
EXPECT_NULL(s2n_get_server_name(NULL));
EXPECT_EQUAL(s2n_errno, S2N_ERR_NULL);
Expand All @@ -349,7 +397,7 @@ int run_test_config_swap_ch_cb(s2n_client_hello_cb_mode cb_mode,
EXPECT_STRING_EQUAL(s2n_get_application_protocol(conn), protocols[0]);

EXPECT_SUCCESS(server_recv(conn));

EXPECT_SUCCESS(test_case_clean(conn, pid, config, &io_pair, ch_ctx));
EXPECT_SUCCESS(s2n_config_free(swap_config));
return S2N_SUCCESS;
Expand All @@ -372,7 +420,7 @@ int run_test_no_config_swap_ch_cb(s2n_client_hello_cb_mode cb_mode,
EXPECT_SUCCESS(s2n_config_set_client_hello_cb_mode(config, cb_mode));
EXPECT_SUCCESS(start_client_conn(&io_pair, &pid, 0 , 0));
EXPECT_SUCCESS(init_server_conn(&conn, &io_pair, config));

/* do the handshake */
if ( cb_mode == S2N_CLIENT_HELLO_CB_NONBLOCKING ) {
/* swap the config and mark server_name_used in the async context */
Expand All @@ -387,7 +435,7 @@ int run_test_no_config_swap_ch_cb(s2n_client_hello_cb_mode cb_mode,
EXPECT_EQUAL(s2n_errno, S2N_ERR_NULL);

EXPECT_SUCCESS(server_recv(conn));

EXPECT_SUCCESS(test_case_clean(conn, pid, config, &io_pair, ch_ctx));
return S2N_SUCCESS;
}
Expand Down Expand Up @@ -425,7 +473,7 @@ int run_test_reject_handshake_ch_cb(s2n_client_hello_cb_mode cb_mode,

/* Ensure that callback was invoked */
EXPECT_EQUAL(ch_ctx->invoked, 1);

/* shutdown to flush alert, expext failure as client doesn't send close notify */
EXPECT_FAILURE(s2n_shutdown(conn, &blocked));
EXPECT_SUCCESS(s2n_connection_free(conn));
Expand All @@ -434,6 +482,37 @@ int run_test_reject_handshake_ch_cb(s2n_client_hello_cb_mode cb_mode,
return S2N_SUCCESS;
}

int run_test_poll_ch_cb(s2n_client_hello_cb_mode cb_mode,
struct s2n_cert_chain_and_key *chain_and_key,
struct client_hello_context *ch_ctx)
{
struct s2n_test_io_pair io_pair = { 0 };
struct s2n_config *config = s2n_config_new();
EXPECT_NOT_NULL(config);
struct s2n_connection *conn;
pid_t pid = 0;

EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, chain_and_key));

/* Setup ClientHello callback */
EXPECT_SUCCESS(s2n_config_set_client_hello_cb(config, s2n_client_hello_poll_cb, ch_ctx));
EXPECT_SUCCESS(s2n_config_set_client_hello_cb_mode(config, cb_mode));

/* Enable callback polling mode */
EXPECT_SUCCESS(s2n_config_client_hello_cb_enable_poll(config));

EXPECT_SUCCESS(start_client_conn(&io_pair, &pid, 0 , 0));
EXPECT_SUCCESS(init_server_conn(&conn, &io_pair, config));

/* negotiate and make assertions */
EXPECT_SUCCESS(s2n_negotiate_nonblocking_poll(conn, ch_ctx));

EXPECT_SUCCESS(server_recv(conn));

EXPECT_SUCCESS(test_case_clean(conn, pid, config, &io_pair, ch_ctx));
return S2N_SUCCESS;
}

int main(int argc, char **argv)
{
struct client_hello_context client_hello_ctx = {0};
Expand Down Expand Up @@ -487,6 +566,9 @@ int main(int argc, char **argv)
EXPECT_SUCCESS(run_test_reject_handshake_ch_cb(S2N_CLIENT_HELLO_CB_NONBLOCKING,
chain_and_key, &client_hello_ctx));

EXPECT_SUCCESS(run_test_poll_ch_cb(S2N_CLIENT_HELLO_CB_NONBLOCKING,
chain_and_key, &client_hello_ctx));

EXPECT_SUCCESS(s2n_cert_chain_and_key_free(chain_and_key));
free(cert_chain_pem);
free(private_key_pem);
Expand Down
124 changes: 123 additions & 1 deletion tests/unit/s2n_server_hello_retry_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include "tls/extensions/s2n_server_key_share.h"

#include "error/s2n_errno.h"
#include "utils/s2n_result.h"
#include "tls/s2n_internal.h"

#define HELLO_RETRY_MSG_NO 1

Expand All @@ -37,8 +39,45 @@ const uint8_t COMPRESSION_METHOD_SIZE = 1;
struct client_hello_context {
int invocations;
s2n_client_hello_cb_mode mode;
bool mark_done;
bool enable_poll;
};

int s2n_negotiate_poll_hello_retry(struct s2n_connection *server_conn,
struct s2n_connection *client_conn,
struct client_hello_context *client_hello_ctx)
{
s2n_blocked_status blocked = S2N_NOT_BLOCKED;
EXPECT_FAILURE_WITH_ERRNO(s2n_negotiate(client_conn, &blocked), S2N_ERR_IO_BLOCKED);

int expected_invocation = 0;

/* if polling is enabled then confirm that the callback is incremented each time */
if (client_hello_ctx->enable_poll) {
do {
/* invocation should increase each time s2n_negotiate is called */
EXPECT_FAILURE_WITH_ERRNO(s2n_negotiate(server_conn, &blocked), S2N_ERR_ASYNC_BLOCKED);
EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_APPLICATION_INPUT);
expected_invocation++;
EXPECT_EQUAL(client_hello_ctx->invocations, expected_invocation);
} while (expected_invocation < 10);
}
EXPECT_EQUAL(client_hello_ctx->invocations, expected_invocation);

/* complete the callback on the next call */
client_hello_ctx->mark_done = true;
EXPECT_SUCCESS(s2n_negotiate_test_server_and_client(server_conn, client_conn));

/*
* hello retry will invoke the s2n_negotiate twice but the callback should
* be called once regardless of polling
*/
expected_invocation++;
EXPECT_EQUAL(client_hello_ctx->invocations, expected_invocation);

return S2N_SUCCESS;
}

static int client_hello_detect_duplicate_calls(struct s2n_connection *conn, void *ctx)
{
if (ctx == NULL) {
Expand All @@ -56,6 +95,78 @@ static int client_hello_detect_duplicate_calls(struct s2n_connection *conn, void
return 0;
}

int s2n_client_hello_poll_cb(struct s2n_connection *conn, void *ctx)
{
struct client_hello_context *client_hello_ctx;
if (ctx == NULL) {
return -1;
}
client_hello_ctx = ctx;
/* Increment counter to ensure that callback was invoked */
client_hello_ctx->invocations++;

if (client_hello_ctx->mark_done) {
EXPECT_SUCCESS(s2n_client_hello_cb_done(conn));
return S2N_SUCCESS;
}

return S2N_SUCCESS;
}

S2N_RESULT hello_retry_client_hello_cb_test(bool enable_poll) {
struct s2n_cert_chain_and_key *tls13_chain_and_key = NULL;
EXPECT_SUCCESS(s2n_test_cert_chain_and_key_new(&tls13_chain_and_key,
S2N_ECDSA_P384_PKCS1_CERT_CHAIN, S2N_ECDSA_P384_PKCS1_KEY));
EXPECT_NOT_NULL(tls13_chain_and_key);

DEFER_CLEANUP(struct s2n_config *config = s2n_config_new(), s2n_config_ptr_free);
EXPECT_NOT_NULL(config);

EXPECT_SUCCESS(s2n_config_set_unsafe_for_testing(config));
EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(config, tls13_chain_and_key));
EXPECT_SUCCESS(s2n_config_set_cipher_preferences(config, "default_tls13"));

DEFER_CLEANUP(struct s2n_connection *server_conn = s2n_connection_new(S2N_SERVER), s2n_connection_ptr_free);
DEFER_CLEANUP(struct s2n_connection *client_conn = s2n_connection_new(S2N_CLIENT), s2n_connection_ptr_free);
EXPECT_NOT_NULL(server_conn);
EXPECT_NOT_NULL(client_conn);

EXPECT_SUCCESS(s2n_connection_set_config(server_conn, config));
EXPECT_SUCCESS(s2n_connection_set_config(client_conn, config));

struct s2n_test_io_pair io_pair = { 0 };
EXPECT_SUCCESS(s2n_io_pair_init_non_blocking(&io_pair));
EXPECT_SUCCESS(s2n_connections_set_io_pair(client_conn, server_conn, &io_pair));

/* Force HRR path */
client_conn->security_policy_override = &security_policy_test_tls13_retry;

/* setup the client hello callback */
struct client_hello_context client_hello_ctx = {.invocations = 0,
.mode = S2N_CLIENT_HELLO_CB_NONBLOCKING, .mark_done = false,
.enable_poll = enable_poll };
EXPECT_SUCCESS(s2n_config_set_client_hello_cb(config,
s2n_client_hello_poll_cb, &client_hello_ctx));
EXPECT_SUCCESS(s2n_config_set_client_hello_cb_mode(config,
S2N_CLIENT_HELLO_CB_NONBLOCKING));

if (enable_poll) {
/* Enable callback polling mode */
EXPECT_SUCCESS(s2n_config_client_hello_cb_enable_poll(config));
}

/* negotiate and make assertions */
EXPECT_SUCCESS(s2n_negotiate_poll_hello_retry(server_conn, client_conn, &client_hello_ctx));

/* check hello retry state */
EXPECT_TRUE(IS_HELLO_RETRY_HANDSHAKE(client_conn));
EXPECT_TRUE(IS_HELLO_RETRY_HANDSHAKE(server_conn));

/* cleanup */
EXPECT_SUCCESS(s2n_cert_chain_and_key_free(tls13_chain_and_key));
EXPECT_SUCCESS(s2n_io_pair_close(&io_pair));
return S2N_RESULT_OK;
}

int main(int argc, char **argv)
{
Expand Down Expand Up @@ -450,6 +561,9 @@ int main(int argc, char **argv)
EXPECT_TRUE(server_conn->handshake.handshake_type & HELLO_RETRY_REQUEST);
EXPECT_EQUAL(client_hello_ctx.invocations, 1);

EXPECT_TRUE(IS_HELLO_RETRY_HANDSHAKE(client_conn));
EXPECT_TRUE(IS_HELLO_RETRY_HANDSHAKE(server_conn));

EXPECT_SUCCESS(s2n_connection_free(server_conn));
EXPECT_SUCCESS(s2n_connection_free(client_conn));
EXPECT_SUCCESS(s2n_config_free(client_config));
Expand All @@ -458,6 +572,14 @@ int main(int argc, char **argv)
EXPECT_SUCCESS(s2n_io_pair_close(&io_pair));
}

/* Hello Retry Request + (poll and no-poll) client hello callback */
{
/* enable polling */
EXPECT_OK(hello_retry_client_hello_cb_test(true));
/* disable polling */
EXPECT_OK(hello_retry_client_hello_cb_test(false));
}

/* Test s2n_set_hello_retry_required correctly sets the handshake type to HELLO_RETRY_REQUEST,
* when conn->actual_protocol_version is set to TLS1.3 version */
{
Expand All @@ -471,7 +593,7 @@ int main(int argc, char **argv)
EXPECT_SUCCESS(s2n_connection_free(conn));
}

/* Test s2n_set_hello_retry_required raises a S2N_ERR_INVALID_HELLO_RETRY error
/* Test s2n_set_hello_retry_required raises a S2N_ERR_INVALID_HELLO_RETRY error
* when conn->actual_protocol_version is less than TLS1.3 */
{
struct s2n_connection *conn;
Expand Down
Loading

0 comments on commit 70884dc

Please sign in to comment.