diff --git a/include/envoy/network/connection.h b/include/envoy/network/connection.h index 28420a3d5f4a3..fd4eb89a5ad23 100644 --- a/include/envoy/network/connection.h +++ b/include/envoy/network/connection.h @@ -166,6 +166,30 @@ class Connection : public Event::DeferredDeletable, public FilterManager { */ virtual const Network::Address::InstanceConstSharedPtr& remoteAddress() const PURE; + /** + * Credentials of the peer of a socket as decided by SO_PEERCRED. + */ + struct UnixDomainSocketPeerCredentials { + /** + * The process id of the peer. + */ + int32_t pid; + /** + * The user id of the peer. + */ + uint32_t uid; + /** + * The group id of the peer. + */ + uint32_t gid; + }; + + /** + * @return The unix socket peer credentials of the the remote client. Note that this is only + * supported for unix socket connections. + */ + virtual absl::optional unixSocketPeerCredentials() const PURE; + /** * @return the local address of the connection. For client connections, this is the origin * address. For server connections, this is the local destination address. For server connections diff --git a/source/common/network/connection_impl.cc b/source/common/network/connection_impl.cc index e4e1753290a43..fd2c80ef4cb7e 100644 --- a/source/common/network/connection_impl.cc +++ b/source/common/network/connection_impl.cc @@ -504,6 +504,23 @@ void ConnectionImpl::onReadReady() { } } +absl::optional +ConnectionImpl::unixSocketPeerCredentials() const { + // TODO(snowp): Support non-linux platforms. +#ifndef SO_PEERCRED + return absl::nullopt; +#else + struct ucred ucred; + socklen_t ucred_size = sizeof(ucred); + int rc = getsockopt(ioHandle().fd(), SOL_SOCKET, SO_PEERCRED, &ucred, &ucred_size); + if (rc == -1) { + return absl::nullopt; + } + + return {{ucred.pid, ucred.uid, ucred.gid}}; +#endif +} + void ConnectionImpl::onWriteReady() { ENVOY_CONN_LOG(trace, "write ready", *this); diff --git a/source/common/network/connection_impl.h b/source/common/network/connection_impl.h index c2fb2584746d2..935704f51f855 100644 --- a/source/common/network/connection_impl.h +++ b/source/common/network/connection_impl.h @@ -80,6 +80,7 @@ class ConnectionImpl : public virtual Connection, const Address::InstanceConstSharedPtr& localAddress() const override { return socket_->localAddress(); } + absl::optional unixSocketPeerCredentials() const override; void setConnectionStats(const ConnectionStats& stats) override; const Ssl::ConnectionInfo* ssl() const override { return transport_socket_->ssl(); } State state() const override; diff --git a/test/integration/uds_integration_test.cc b/test/integration/uds_integration_test.cc index 8c93f738ea05a..caa7c93b9c953 100644 --- a/test/integration/uds_integration_test.cc +++ b/test/integration/uds_integration_test.cc @@ -78,6 +78,31 @@ HttpIntegrationTest::ConnectionCreationFunction UdsListenerIntegrationTest::crea }; } +TEST_P(UdsListenerIntegrationTest, TestPeerCredentials) { + fake_upstreams_count_ = 1; + initialize(); + auto client_connection = createConnectionFn()(); + codec_client_ = makeHttpConnection(std::move(client_connection)); + Http::TestHeaderMapImpl request_headers{ + {":method", "POST"}, {":path", "/test/long/url"}, {":scheme", "http"}, + {":authority", "host"}, {"x-lyft-user-id", "123"}, {"x-forwarded-for", "10.0.0.1"}}; + auto response = codec_client_->makeHeaderOnlyRequest(request_headers); + waitForNextUpstreamRequest(0); + + auto credentials = codec_client_->connection()->unixSocketPeerCredentials(); +#ifndef SO_PEERCRED + EXPECT_EQ(credentials, absl::nullopt); +#else + EXPECT_EQ(credentials->pid, getpid()); + EXPECT_EQ(credentials->uid, getuid()); + EXPECT_EQ(credentials->gid, getgid()); +#endif + + upstream_request_->encodeHeaders(Http::TestHeaderMapImpl{{":status", "200"}}, true); + + response->waitForEndStream(); +} + TEST_P(UdsListenerIntegrationTest, RouterRequestAndResponseWithBodyNoBuffer) { ConnectionCreationFunction creator = createConnectionFn(); testRouterRequestAndResponseWithBody(1024, 512, false, &creator); diff --git a/test/mocks/network/connection.h b/test/mocks/network/connection.h index 672986a77ebb0..3430fcf14d0a2 100644 --- a/test/mocks/network/connection.h +++ b/test/mocks/network/connection.h @@ -65,6 +65,8 @@ class MockConnection : public Connection, public MockConnectionBase { MOCK_METHOD1(detectEarlyCloseWhenReadDisabled, void(bool)); MOCK_CONST_METHOD0(readEnabled, bool()); MOCK_CONST_METHOD0(remoteAddress, const Address::InstanceConstSharedPtr&()); + MOCK_CONST_METHOD0(unixSocketPeerCredentials, + absl::optional()); MOCK_CONST_METHOD0(localAddress, const Address::InstanceConstSharedPtr&()); MOCK_METHOD1(setConnectionStats, void(const ConnectionStats& stats)); MOCK_CONST_METHOD0(ssl, const Ssl::ConnectionInfo*()); @@ -109,6 +111,8 @@ class MockClientConnection : public ClientConnection, public MockConnectionBase MOCK_METHOD1(detectEarlyCloseWhenReadDisabled, void(bool)); MOCK_CONST_METHOD0(readEnabled, bool()); MOCK_CONST_METHOD0(remoteAddress, const Address::InstanceConstSharedPtr&()); + MOCK_CONST_METHOD0(unixSocketPeerCredentials, + absl::optional()); MOCK_CONST_METHOD0(localAddress, const Address::InstanceConstSharedPtr&()); MOCK_METHOD1(setConnectionStats, void(const ConnectionStats& stats)); MOCK_CONST_METHOD0(ssl, const Ssl::ConnectionInfo*());