Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 85 additions & 2 deletions integration/connection_string_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,16 @@ class ConnectionString {
m_connect_timeout(-1), m_network_timeout(-1), m_host_pattern(""),
m_enable_failure_detection(true), m_failure_detection_time(-1), m_failure_detection_timeout(-1),
m_failure_detection_interval(-1), m_failure_detection_count(-1), m_monitor_disposal_time(-1),
m_read_timeout(-1), m_write_timeout(-1),
m_read_timeout(-1), m_write_timeout(-1), m_auth_mode(""), m_auth_region(""), m_auth_host(""),
m_auth_port(-1), m_auth_expiration(-1), m_secret_id(""),

is_set_uid(false), is_set_pwd(false), is_set_db(false), is_set_log_query(false),
is_set_allow_reader_connections(false), is_set_multi_statements(false), is_set_enable_cluster_failover(false),
is_set_failover_timeout(false), is_set_connect_timeout(false), is_set_network_timeout(false), is_set_host_pattern(false),
is_set_enable_failure_detection(false), is_set_failure_detection_time(false), is_set_failure_detection_timeout(false),
is_set_failure_detection_interval(false), is_set_failure_detection_count(false), is_set_monitor_disposal_time(false),
is_set_read_timeout(false), is_set_write_timeout(false) {};
is_set_read_timeout(false), is_set_write_timeout(false), is_set_auth_mode(false), is_set_auth_region(false),
is_set_auth_host(false), is_set_auth_port(false), is_set_auth_expiration(false), is_set_secret_id(false) {};

std::string get_connection_string() const {
char conn_in[4096] = "\0";
Expand Down Expand Up @@ -115,6 +117,24 @@ class ConnectionString {
if (is_set_write_timeout) {
length += sprintf(conn_in + length, "WRITETIMEOUT=%d;", m_write_timeout);
}
if (is_set_auth_mode) {
length += sprintf(conn_in + length, "AUTHENTICATION_MODE=%s;", m_auth_mode.c_str());
}
if (is_set_auth_region) {
length += sprintf(conn_in + length, "AWS_REGION=%s;", m_auth_region.c_str());
}
if (is_set_auth_host) {
length += sprintf(conn_in + length, "IAM_HOST=%s;", m_auth_host.c_str());
}
if (is_set_auth_port) {
length += sprintf(conn_in + length, "IAM_PORT=%d;", m_auth_port);
}
if (is_set_auth_expiration) {
length += sprintf(conn_in + length, "IAM_EXPIRATION_TIME=%d;", m_auth_expiration);
}
if (is_set_secret_id) {
length += sprintf(conn_in + length, "SECRET_ID=%s;", m_secret_id.c_str());
}
snprintf(conn_in + length, sizeof(conn_in) - length, "\0");

std::string connection_string(conn_in);
Expand All @@ -133,6 +153,8 @@ class ConnectionString {
std::string m_host_pattern;
bool m_enable_failure_detection;
int m_failure_detection_time, m_failure_detection_timeout, m_failure_detection_interval, m_failure_detection_count, m_monitor_disposal_time, m_read_timeout, m_write_timeout;
std::string m_auth_mode, m_auth_region, m_auth_host, m_secret_id;
int m_auth_port, m_auth_expiration;

bool is_set_uid, is_set_pwd, is_set_db;
bool is_set_log_query, is_set_allow_reader_connections, is_set_multi_statements;
Expand All @@ -143,6 +165,7 @@ class ConnectionString {
bool is_set_failure_detection_time, is_set_failure_detection_timeout, is_set_failure_detection_interval, is_set_failure_detection_count;
bool is_set_monitor_disposal_time;
bool is_set_read_timeout, is_set_write_timeout;
bool is_set_auth_mode, is_set_auth_region, is_set_auth_host, is_set_auth_port, is_set_auth_expiration, is_set_secret_id;

void set_dsn(const std::string& dsn) {
m_dsn = dsn;
Expand Down Expand Up @@ -250,6 +273,36 @@ class ConnectionString {
m_write_timeout = write_timeout;
is_set_write_timeout = true;
}

void set_auth_mode(const std::string& auth_mode) {
m_auth_mode = auth_mode;
is_set_auth_mode = true;
}

void set_auth_region(const std::string& auth_region) {
m_auth_region = auth_region;
is_set_auth_region = true;
}

void set_auth_host(const std::string& auth_host) {
m_auth_host = auth_host;
is_set_auth_host = true;
}

void set_auth_port(const int& auth_port) {
m_auth_port = auth_port;
is_set_auth_port = true;
}

void set_auth_expiration(const int& auth_expiration) {
m_auth_expiration = auth_expiration;
is_set_auth_expiration = true;
}

void set_secret_id(const std::string& secret_id) {
m_secret_id = secret_id;
is_set_secret_id = true;
}
};

class ConnectionStringBuilder {
Expand Down Expand Up @@ -368,6 +421,36 @@ class ConnectionStringBuilder {
return *this;
}

ConnectionStringBuilder& withAuthMode(const std::string& auth_mode) {
connection_string->set_auth_mode(auth_mode);
return *this;
}

ConnectionStringBuilder& withAuthRegion(const std::string& auth_region) {
connection_string->set_auth_region(auth_region);
return *this;
}

ConnectionStringBuilder& withAuthHost(const std::string& auth_host) {
connection_string->set_auth_host(auth_host);
return *this;
}

ConnectionStringBuilder& withAuthPort(const int& auth_port) {
connection_string->set_auth_port(auth_port);
return *this;
}

ConnectionStringBuilder& withAuthExpiration(const int& auth_expiration) {
connection_string->set_auth_expiration(auth_expiration);
return *this;
}

ConnectionStringBuilder& withSecretId(const std::string& secret_id) {
connection_string->set_secret_id(secret_id);
return *this;
}

std::string build() const {
if (connection_string->m_dsn.empty()) {
throw std::runtime_error("DSN is a required field in a connection string.");
Expand Down
4 changes: 2 additions & 2 deletions util/installer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,10 +231,10 @@ static SQLWCHAR W_SSL_CRLPATH[] =

/* AWS Authentication */
static SQLWCHAR W_AUTH_MODE[] = { 'A', 'U', 'T', 'H', 'E', 'N', 'T', 'I', 'C', 'A', 'T', 'I', 'O', 'N', '_', 'M', 'O', 'D', 'E', 0};
static SQLWCHAR W_AUTH_REGION[] = { 'I', 'A', 'M', '_', 'R', 'E', 'G', 'I', 'O', 'N', 0 };
static SQLWCHAR W_AUTH_REGION[] = { 'A', 'W', 'S', '_', 'R', 'E', 'G', 'I', 'O', 'N', 0 };
static SQLWCHAR W_AUTH_HOST[] = { 'I', 'A', 'M', '_', 'H', 'O', 'S', 'T', 0 };
static SQLWCHAR W_AUTH_PORT[] = { 'I', 'A', 'M', '_', 'P', 'O', 'R', 'T', 0 };
static SQLWCHAR W_AUTH_EXPIRATION[] = { 'E', 'X', 'P', 'I', 'R', 'A', 'T', 'I', 'O', 'N', '_', 'T', 'I', 'M', 'E', 0 };
static SQLWCHAR W_AUTH_EXPIRATION[] = { 'I', 'A', 'M', '_', 'E', 'X', 'P', 'I', 'R', 'A', 'T', 'I', 'O', 'N', '_', 'T', 'I', 'M', 'E', 0 };
static SQLWCHAR W_AUTH_SECRET_ID[] = { 'S', 'E', 'C', 'R', 'E', 'T', '_', 'I', 'D', 0 };

/* Failover */
Expand Down