diff --git a/presto-native-execution/presto_cpp/main/Announcer.cpp b/presto-native-execution/presto_cpp/main/Announcer.cpp index b5a504ebc131f..d7a716e7c026f 100644 --- a/presto-native-execution/presto_cpp/main/Announcer.cpp +++ b/presto-native-execution/presto_cpp/main/Announcer.cpp @@ -130,6 +130,7 @@ void Announcer::makeAnnouncement() { eventBaseThread_.getEventBase(), address_, std::chrono::milliseconds(10'000), + pool_, clientCertAndKeyPath_, ciphers_); } @@ -139,7 +140,7 @@ void Announcer::makeAnnouncement() { return; } - client_->sendRequest(announcementRequest_, pool_.get(), announcementBody_) + client_->sendRequest(announcementRequest_, announcementBody_) .via(eventBaseThread_.getEventBase()) .thenValue([](auto response) { auto message = response->headers(); diff --git a/presto-native-execution/presto_cpp/main/PrestoExchangeSource.cpp b/presto-native-execution/presto_cpp/main/PrestoExchangeSource.cpp index 7f46f832099f6..29122b47eef8f 100644 --- a/presto-native-execution/presto_cpp/main/PrestoExchangeSource.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoExchangeSource.cpp @@ -72,6 +72,7 @@ PrestoExchangeSource::PrestoExchangeSource( eventBase, address, std::chrono::milliseconds(10'000), + pool_, clientCertAndKeyPath_, ciphers_, [](size_t bufferBytes) { @@ -107,7 +108,7 @@ void PrestoExchangeSource::doRequest() { .method(proxygen::HTTPMethod::GET) .url(path) .header(protocol::PRESTO_MAX_SIZE_HTTP_HEADER, "32MB") - .send(httpClient_.get(), pool_.get()) + .send(httpClient_.get()) .via(driverCPUExecutor()) .thenValue([path, self](std::unique_ptr response) { velox::common::testutil::TestValue::adjust( @@ -273,7 +274,7 @@ void PrestoExchangeSource::acknowledgeResults(int64_t ackSequence) { http::RequestBuilder() .method(proxygen::HTTPMethod::GET) .url(ackPath) - .send(httpClient_.get(), pool_.get()) + .send(httpClient_.get()) .via(driverCPUExecutor()) .thenValue([self](std::unique_ptr response) { VLOG(1) << "Ack " << response->headers()->getStatusCode(); @@ -292,7 +293,7 @@ void PrestoExchangeSource::abortResults() { http::RequestBuilder() .method(proxygen::HTTPMethod::DELETE) .url(basePath_) - .send(httpClient_.get(), pool_.get()) + .send(httpClient_.get()) .via(driverCPUExecutor()) .thenValue([queue, self](std::unique_ptr response) { auto statusCode = response->headers()->getStatusCode(); diff --git a/presto-native-execution/presto_cpp/main/http/HttpClient.cpp b/presto-native-execution/presto_cpp/main/http/HttpClient.cpp index ef08b8bd19dfe..c7fb11fe8e357 100644 --- a/presto-native-execution/presto_cpp/main/http/HttpClient.cpp +++ b/presto-native-execution/presto_cpp/main/http/HttpClient.cpp @@ -21,6 +21,7 @@ HttpClient::HttpClient( folly::EventBase* eventBase, const folly::SocketAddress& address, std::chrono::milliseconds timeout, + std::shared_ptr pool, const std::string& clientCertAndKeyPath, const std::string& ciphers, std::function&& reportOnBodyStatsFunc) @@ -31,6 +32,7 @@ HttpClient::HttpClient( std::chrono::milliseconds(folly::HHWheelTimer::DEFAULT_TICK_INTERVAL), folly::AsyncTimeout::InternalEnum::NORMAL, timeout)), + pool_(std::move(pool)), clientCertAndKeyPath_(clientCertAndKeyPath), ciphers_(ciphers), reportOnBodyStatsFunc_(std::move(reportOnBodyStatsFunc)), @@ -52,11 +54,11 @@ HttpClient::~HttpClient() { HttpResponse::HttpResponse( std::unique_ptr headers, - velox::memory::MemoryPool* pool, + std::shared_ptr pool, uint64_t minResponseAllocBytes, uint64_t maxResponseAllocBytes) : headers_(std::move(headers)), - pool_(pool), + pool_(std::move(pool)), minResponseAllocBytes_(minResponseAllocBytes), maxResponseAllocBytes_(maxResponseAllocBytes) { VELOX_CHECK_NOT_NULL(pool_); @@ -141,7 +143,6 @@ class ResponseHandler : public proxygen::HTTPTransactionHandler { public: ResponseHandler( const proxygen::HTTPMessage& request, - velox::memory::MemoryPool* pool, uint64_t maxResponseAllocBytes, const std::string& body, std::function reportOnBodyStatsFunc, @@ -149,13 +150,12 @@ class ResponseHandler : public proxygen::HTTPTransactionHandler { : request_(request), body_(body), reportOnBodyStatsFunc_(std::move(reportOnBodyStatsFunc)), - pool_(pool), minResponseAllocBytes_(velox::memory::AllocationTraits::pageBytes( - pool_->sizeClasses().front())), + client->memoryPool()->sizeClasses().front())), maxResponseAllocBytes_( std::max(minResponseAllocBytes_, maxResponseAllocBytes)), client_(std::move(client)) { - VELOX_CHECK_NOT_NULL(pool_); + VELOX_CHECK_NOT_NULL(client_->memoryPool().get()); } folly::SemiFuture> initialize( @@ -175,7 +175,10 @@ class ResponseHandler : public proxygen::HTTPTransactionHandler { void onHeadersComplete( std::unique_ptr msg) noexcept override { response_ = std::make_unique( - std::move(msg), pool_, minResponseAllocBytes_, maxResponseAllocBytes_); + std::move(msg), + client_->memoryPool(), + minResponseAllocBytes_, + maxResponseAllocBytes_); } void onBody(std::unique_ptr chain) noexcept override { @@ -225,7 +228,6 @@ class ResponseHandler : public proxygen::HTTPTransactionHandler { const proxygen::HTTPMessage request_; const std::string body_; const std::function reportOnBodyStatsFunc_; - velox::memory::MemoryPool* const pool_; const uint64_t minResponseAllocBytes_; const uint64_t maxResponseAllocBytes_; std::unique_ptr response_; @@ -297,11 +299,9 @@ class ConnectionHandler : public proxygen::HTTPConnector::Callback { folly::SemiFuture> HttpClient::sendRequest( const proxygen::HTTPMessage& request, - velox::memory::MemoryPool* pool, const std::string& body) { auto responseHandler = std::make_shared( request, - pool, maxResponseAllocBytes_, body, reportOnBodyStatsFunc_, diff --git a/presto-native-execution/presto_cpp/main/http/HttpClient.h b/presto-native-execution/presto_cpp/main/http/HttpClient.h index 5f9459b883186..8ecb8688742a4 100644 --- a/presto-native-execution/presto_cpp/main/http/HttpClient.h +++ b/presto-native-execution/presto_cpp/main/http/HttpClient.h @@ -27,7 +27,7 @@ class HttpResponse { public: HttpResponse( std::unique_ptr headers, - velox::memory::MemoryPool* pool, + std::shared_ptr pool, uint64_t minResponseAllocBytes, uint64_t maxResponseAllocBytes); @@ -86,7 +86,7 @@ class HttpResponse { FOLLY_ALWAYS_INLINE size_t nextAllocationSize(uint64_t dataLength) const; const std::unique_ptr headers_; - velox::memory::MemoryPool* const pool_; + const std::shared_ptr pool_; const uint64_t minResponseAllocBytes_; const uint64_t maxResponseAllocBytes_; @@ -105,6 +105,7 @@ class HttpClient : public std::enable_shared_from_this { folly::EventBase* FOLLY_NONNULL eventBase, const folly::SocketAddress& address, std::chrono::milliseconds timeout, + std::shared_ptr pool, const std::string& clientCertAndKeyPath = "", const std::string& ciphers = "", std::function&& reportOnBodyStatsFunc = nullptr); @@ -114,13 +115,17 @@ class HttpClient : public std::enable_shared_from_this { // TODO Avoid copy by using IOBuf for body folly::SemiFuture> sendRequest( const proxygen::HTTPMessage& request, - velox::memory::MemoryPool* pool, const std::string& body = ""); + const std::shared_ptr& memoryPool() { + return pool_; + } + private: folly::EventBase* const eventBase_; const folly::SocketAddress address_; const folly::HHWheelTimer::UniquePtr timer_; + const std::shared_ptr pool_; // clientCertAndKeyPath_ Points to a file (usually with pem extension) which // contains certificate and key concatenated together const std::string clientCertAndKeyPath_; @@ -163,11 +168,10 @@ class RequestBuilder { folly::SemiFuture> send( HttpClient* client, - velox::memory::MemoryPool* pool, const std::string& body = "") { header(proxygen::HTTP_HEADER_CONTENT_LENGTH, std::to_string(body.size())); headers_.ensureHostHeader(); - return client->sendRequest(headers_, pool, body); + return client->sendRequest(headers_, body); } private: diff --git a/presto-native-execution/presto_cpp/main/http/tests/HttpTest.cpp b/presto-native-execution/presto_cpp/main/http/tests/HttpTest.cpp index 311d852d1d7e2..70713967edc73 100644 --- a/presto-native-execution/presto_cpp/main/http/tests/HttpTest.cpp +++ b/presto-native-execution/presto_cpp/main/http/tests/HttpTest.cpp @@ -184,6 +184,7 @@ class HttpClientFactory { const folly::SocketAddress& address, const std::chrono::milliseconds& timeout, bool useHttps, + std::shared_ptr pool, std::function&& reportOnBodyStatsFunc = nullptr) { if (useHttps) { std::string clientCaPath = getCertsPath("client_ca.pem"); @@ -192,6 +193,7 @@ class HttpClientFactory { eventBase_.get(), address, timeout, + pool, clientCaPath, ciphers, std::move(reportOnBodyStatsFunc)); @@ -200,6 +202,7 @@ class HttpClientFactory { eventBase_.get(), address, timeout, + pool, "", "", std::move(reportOnBodyStatsFunc)); @@ -211,12 +214,13 @@ class HttpClientFactory { std::unique_ptr eventBaseThread_; }; -folly::SemiFuture> -sendGet(http::HttpClient* client, const std::string& url, MemoryPool* pool) { +folly::SemiFuture> sendGet( + http::HttpClient* client, + const std::string& url) { return http::RequestBuilder() .method(proxygen::HTTPMethod::GET) .url(url) - .send(client, pool); + .send(client); } static std::unique_ptr getServer(bool useHttps) { @@ -280,41 +284,39 @@ TEST_P(HttpTestSuite, basic) { HttpClientFactory clientFactory; auto client = clientFactory.newClient( - serverAddress, std::chrono::milliseconds(1'000), useHttps); + serverAddress, std::chrono::milliseconds(1'000), useHttps, memoryPool); { - auto response = sendGet(client.get(), "/ping", memoryPool.get()).get(); + auto response = sendGet(client.get(), "/ping").get(); ASSERT_EQ(response->headers()->getStatusCode(), http::kHttpOk); - response = - sendGet(client.get(), "/echo/good-morning", memoryPool.get()).get(); + response = sendGet(client.get(), "/echo/good-morning").get(); ASSERT_EQ(response->headers()->getStatusCode(), http::kHttpOk); ASSERT_EQ(bodyAsString(*response, memoryPool.get()), "/echo/good-morning"); response = http::RequestBuilder() .method(proxygen::HTTPMethod::POST) .url("/echo") - .send(client.get(), memoryPool.get(), "Good morning!") + .send(client.get(), "Good morning!") .get(); ASSERT_EQ(response->headers()->getStatusCode(), http::kHttpOk); ASSERT_EQ(bodyAsString(*response, memoryPool.get()), "Good morning!"); - response = sendGet(client.get(), "/wrong/path", memoryPool.get()).get(); + response = sendGet(client.get(), "/wrong/path").get(); ASSERT_EQ(response->headers()->getStatusCode(), http::kHttpNotFound); - auto tryResponse = - sendGet(client.get(), "/blackhole", memoryPool.get()).getTry(); + auto tryResponse = sendGet(client.get(), "/blackhole").getTry(); ASSERT_TRUE(tryResponse.hasException()); auto httpException = dynamic_cast( tryResponse.tryGetExceptionObject()); ASSERT_EQ(httpException->getProxygenError(), proxygen::kErrorTimeout); - response = sendGet(client.get(), "/ping", memoryPool.get()).get(); + response = sendGet(client.get(), "/ping").get(); ASSERT_EQ(response->headers()->getStatusCode(), http::kHttpOk); } wrapper.stop(); - auto tryResponse = sendGet(client.get(), "/ping", memoryPool.get()).getTry(); + auto tryResponse = sendGet(client.get(), "/ping").getTry(); ASSERT_TRUE(tryResponse.hasException()); auto socketException = dynamic_cast( @@ -338,14 +340,12 @@ TEST_P(HttpTestSuite, httpResponseAllocationFailure) { HttpClientFactory clientFactory; auto client = clientFactory.newClient( - serverAddress, std::chrono::milliseconds(1'000), useHttps); + serverAddress, std::chrono::milliseconds(1'000), useHttps, leafPool); { const std::string echoMessage(memoryCapBytes * 4, 'C'); auto response = - sendGet( - client.get(), fmt::format("/echo/{}", echoMessage), leafPool.get()) - .get(); + sendGet(client.get(), fmt::format("/echo/{}", echoMessage)).get(); ASSERT_EQ(response->headers()->getStatusCode(), http::kHttpOk); ASSERT_TRUE(response->hasError()); VELOX_ASSERT_THROW(response->consumeBody(), ""); @@ -366,9 +366,9 @@ TEST_P(HttpTestSuite, serverRestart) { HttpClientFactory clientFactory; auto client = clientFactory.newClient( - serverAddress, std::chrono::milliseconds(1'000), useHttps); + serverAddress, std::chrono::milliseconds(1'000), useHttps, memoryPool); - auto response = sendGet(client.get(), "/ping", memoryPool.get()).get(); + auto response = sendGet(client.get(), "/ping").get(); ASSERT_EQ(response->headers()->getStatusCode(), http::kHttpOk); wrapper->stop(); @@ -381,8 +381,8 @@ TEST_P(HttpTestSuite, serverRestart) { serverAddress = wrapper->start().get(); client = clientFactory.newClient( - serverAddress, std::chrono::milliseconds(1'000), useHttps); - response = sendGet(client.get(), "/ping", memoryPool.get()).get(); + serverAddress, std::chrono::milliseconds(1'000), useHttps, memoryPool); + response = sendGet(client.get(), "/ping").get(); ASSERT_EQ(response->headers()->getStatusCode(), http::kHttpOk); wrapper->stop(); } @@ -478,12 +478,12 @@ TEST_P(HttpTestSuite, asyncRequests) { HttpClientFactory clientFactory; auto client = clientFactory.newClient( - serverAddress, std::chrono::milliseconds(1'000), useHttps); + serverAddress, std::chrono::milliseconds(1'000), useHttps, memoryPool); auto [reqPromise, reqFuture] = folly::makePromiseContract(); request->requestPromise = std::move(reqPromise); - auto responseFuture = sendGet(client.get(), "/async/msg", memoryPool.get()); + auto responseFuture = sendGet(client.get(), "/async/msg"); // Wait until the request reaches to the server. std::move(reqFuture).wait(); @@ -513,13 +513,13 @@ TEST_P(HttpTestSuite, timedOutRequests) { HttpClientFactory clientFactory; auto client = clientFactory.newClient( - serverAddress, std::chrono::milliseconds(1'000), useHttps); + serverAddress, std::chrono::milliseconds(1'000), useHttps, memoryPool); request->maxWaitMillis = 100; auto [reqPromise, reqFuture] = folly::makePromiseContract(); request->requestPromise = std::move(reqPromise); - auto responseFuture = sendGet(client.get(), "/async/msg", memoryPool.get()); + auto responseFuture = sendGet(client.get(), "/async/msg"); // Wait until the request reaches to the server. std::move(reqFuture).wait(); @@ -549,13 +549,13 @@ TEST_P(HttpTestSuite, DISABLED_outstandingRequests) { HttpClientFactory clientFactory; auto client = clientFactory.newClient( - serverAddress, std::chrono::milliseconds(10'000), useHttps); + serverAddress, std::chrono::milliseconds(10'000), useHttps, memoryPool); request->maxWaitMillis = 0; auto [reqPromise, reqFuture] = folly::makePromiseContract(); request->requestPromise = std::move(reqPromise); - auto responseFuture = sendGet(client.get(), "/async/msg", memoryPool.get()); + auto responseFuture = sendGet(client.get(), "/async/msg"); // Wait until the request reaches to the server. std::move(reqFuture).wait(); @@ -585,12 +585,13 @@ TEST_P(HttpTestSuite, testReportOnBodyStatsFunc) { serverAddress, std::chrono::milliseconds(1'000), useHttps, + memoryPool, [&](size_t bufferBytes) { reportedCount.fetch_add(bufferBytes); }); auto [reqPromise, reqFuture] = folly::makePromiseContract(); request->requestPromise = std::move(reqPromise); - auto responseFuture = sendGet(client.get(), "/async/msg", memoryPool.get()); + auto responseFuture = sendGet(client.get(), "/async/msg"); // Wait until the request reaches to the server. std::string responseData = "Success";