Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add socket injection into client #147

Merged
merged 11 commits into from
Feb 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions clickhouse/base/socket.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#include "socket.h"
#include "singleton.h"
#include "../client.h"

#include <assert.h>
#include <stdexcept>
#include <system_error>
#include <unordered_set>
#include <memory.h>
#include <thread>

#if !defined(_win_)
# include <errno.h>
Expand Down Expand Up @@ -213,6 +215,15 @@ const std::string & NetworkAddress::Host() const {
}


SocketBase::~SocketBase() = default;

SocketFactory::~SocketFactory() = default;

void SocketFactory::sleepFor(const std::chrono::milliseconds& duration) {
std::this_thread::sleep_for(duration);
}


Socket::Socket(const NetworkAddress& addr)
: handle_(SocketConnect(addr))
{}
Expand Down Expand Up @@ -286,6 +297,33 @@ std::unique_ptr<OutputStream> Socket::makeOutputStream() const {
return std::make_unique<SocketOutput>(handle_);
}

NonSecureSocketFactory::~NonSecureSocketFactory() {}

std::unique_ptr<SocketBase> NonSecureSocketFactory::connect(const ClientOptions &opts) {
const auto address = NetworkAddress(opts.host, std::to_string(opts.port));

auto socket = doConnect(address);
setSocketOptions(*socket, opts);

return socket;
}

std::unique_ptr<Socket> NonSecureSocketFactory::doConnect(const NetworkAddress& address) {
return std::make_unique<Socket>(address);
}

void NonSecureSocketFactory::setSocketOptions(Socket &socket, const ClientOptions &opts) {
if (opts.tcp_keepalive) {
socket.SetTcpKeepAlive(
static_cast<int>(opts.tcp_keepalive_idle.count()),
static_cast<int>(opts.tcp_keepalive_intvl.count()),
static_cast<int>(opts.tcp_keepalive_cnt));
}
if (opts.tcp_nodelay) {
socket.SetTcpNoDelay(opts.tcp_nodelay);
}
}

SocketInput::SocketInput(SOCKET s)
: s_(s)
{
Expand Down
46 changes: 42 additions & 4 deletions clickhouse/base/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <cstddef>
#include <string>
#include <chrono>

#if defined(_win_)
# include <winsock2.h>
Expand All @@ -28,6 +29,8 @@ struct addrinfo;

namespace clickhouse {

struct ClientOptions;

/** Address of a host to establish connection to.
*
*/
Expand Down Expand Up @@ -57,13 +60,35 @@ class windowsErrorCategory : public std::error_category {

#endif

class Socket {

class SocketBase {
public:
virtual ~SocketBase();

virtual std::unique_ptr<InputStream> makeInputStream() const = 0;
virtual std::unique_ptr<OutputStream> makeOutputStream() const = 0;
};


class SocketFactory {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To simplify issue with reconnecting and timeouts (and to provide a framework for upcoming multiple-hosts #139 ). Please consider updating the interface:

class SocketFactory {
        public:
    virtual ~SocketFactory();

    // TODO: move connection-related options to ConnectionOptions structure.

    virtual std::unique_ptr<SocketBase> Connect(const NetworkAddress &, const ClientOptions &) = 0;
    // reconnects to a given address, maybe with a timeout or with any special procedures.
    virtual std::unique_ptr<SocketBase> ReConnect(const NetworkAddress &, const ClientOptions &) = 0;
};

Copy link
Contributor Author

@itrofimow itrofimow Feb 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a problem with NetworkAddress - its constructor uses blocking getaddrinfo; one way of soling this is to make some base class for NetworkAddress, but i feel like this would complicate the interface too much (and we already have some Java-style abstract factories here)

Introducing ConnectionOptions seems like a good enhancement for me, but i believe it has to be done in #139, or at least on the receiving end of upcoming merge conflict. Since SocketFactory is a new feature, there shouldn't be a concern with backwards compatibility

public:
virtual ~SocketFactory();

// TODO: move connection-related options to ConnectionOptions structure.

virtual std::unique_ptr<SocketBase> connect(const ClientOptions& opts) = 0;

virtual void sleepFor(const std::chrono::milliseconds& duration);
};


class Socket : public SocketBase {
public:
Socket(const NetworkAddress& addr);
Socket(Socket&& other) noexcept;
Socket& operator=(Socket&& other) noexcept;

virtual ~Socket();
~Socket() override;

/// @params idle the time (in seconds) the connection needs to remain
/// idle before TCP starts sending keepalive probes.
Expand All @@ -75,8 +100,8 @@ class Socket {
/// @params nodelay whether to enable TCP_NODELAY
void SetTcpNoDelay(bool nodelay) noexcept;

virtual std::unique_ptr<InputStream> makeInputStream() const;
virtual std::unique_ptr<OutputStream> makeOutputStream() const;
std::unique_ptr<InputStream> makeInputStream() const override;
std::unique_ptr<OutputStream> makeOutputStream() const override;

protected:
Socket(const Socket&) = delete;
Expand All @@ -87,6 +112,19 @@ class Socket {
};


class NonSecureSocketFactory : public SocketFactory {
public:
~NonSecureSocketFactory() override;

std::unique_ptr<SocketBase> connect(const ClientOptions& opts) override;

protected:
virtual std::unique_ptr<Socket> doConnect(const NetworkAddress& address);

void setSocketOptions(Socket& socket, const ClientOptions& opts);
};


class SocketInput : public InputStream {
public:
explicit SocketInput(SOCKET s);
Expand Down
33 changes: 32 additions & 1 deletion clickhouse/base/sslsocket.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "sslsocket.h"
#include "../client.h"

#include <stdexcept>

Expand Down Expand Up @@ -100,6 +101,19 @@ SSL_CTX * prepareSSLContext(const clickhouse::SSLParams & context_params) {
#undef HANDLE_SSL_CTX_ERROR
}

clickhouse::SSLParams GetSSLParams(const clickhouse::ClientOptions& opts) {
const auto& ssl_options = opts.ssl_options;
return clickhouse::SSLParams{
ssl_options.path_to_ca_files,
ssl_options.path_to_ca_directory,
ssl_options.use_default_ca_locations,
ssl_options.context_options,
ssl_options.min_protocol_version,
ssl_options.max_protocol_version,
ssl_options.use_sni
};
}

}

namespace clickhouse {
Expand Down Expand Up @@ -135,7 +149,8 @@ SSL_CTX * SSLContext::getContext() {
<< "\n\t handshake state: " << SSL_get_state(ssl_) \
<< std::endl
*/
SSLSocket::SSLSocket(const NetworkAddress& addr, const SSLParams & ssl_params, SSLContext& context)
SSLSocket::SSLSocket(const NetworkAddress& addr, const SSLParams & ssl_params,
SSLContext& context)
: Socket(addr)
, ssl_(SSL_new(context.getContext()), &SSL_free)
{
Expand Down Expand Up @@ -181,6 +196,22 @@ SSLSocket::SSLSocket(const NetworkAddress& addr, const SSLParams & ssl_params, S
}
}

SSLSocketFactory::SSLSocketFactory(const ClientOptions& opts)
: NonSecureSocketFactory()
, ssl_params_(GetSSLParams(opts)) {
if (opts.ssl_options.ssl_context) {
ssl_context_ = std::make_unique<SSLContext>(*opts.ssl_options.ssl_context);
} else {
ssl_context_ = std::make_unique<SSLContext>(ssl_params_);
}
}

SSLSocketFactory::~SSLSocketFactory() = default;

std::unique_ptr<Socket> SSLSocketFactory::doConnect(const NetworkAddress& address) {
return std::make_unique<SSLSocket>(address, ssl_params_, *ssl_context_);
}

std::unique_ptr<InputStream> SSLSocket::makeInputStream() const {
return std::make_unique<SSLSocketInput>(ssl_.get());
}
Expand Down
18 changes: 16 additions & 2 deletions clickhouse/base/sslsocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ class SSLContext

class SSLSocket : public Socket {
public:
explicit SSLSocket(const NetworkAddress& addr, const SSLParams & ssl_params, SSLContext& context);
explicit SSLSocket(const NetworkAddress& addr, const SSLParams & ssl_params,
SSLContext& context);
SSLSocket(SSLSocket &&) = default;
~SSLSocket() = default;
~SSLSocket() override = default;

SSLSocket(const SSLSocket & ) = delete;
SSLSocket& operator=(const SSLSocket & ) = delete;
Expand All @@ -56,6 +57,19 @@ class SSLSocket : public Socket {
std::unique_ptr<SSL, void (*)(SSL *s)> ssl_;
};

class SSLSocketFactory : public NonSecureSocketFactory {
public:
explicit SSLSocketFactory(const ClientOptions& opts);
~SSLSocketFactory() override;

protected:
std::unique_ptr<Socket> doConnect(const NetworkAddress& address) override;

private:
const SSLParams ssl_params_;
std::unique_ptr<SSLContext> ssl_context_;
};

class SSLSocketInput : public InputStream {
public:
explicit SSLSocketInput(SSL *ssl);
Expand Down
Loading