Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: query flight rpc connection leaking #13956

Merged
merged 13 commits into from
Dec 14, 2023
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions src/common/base/src/base/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ mod stoppable;
mod string;
mod take_mut;
mod uniq_id;
mod watch_notify;

pub use net::get_free_tcp_port;
pub use net::get_free_udp_port;
Expand Down Expand Up @@ -52,3 +53,4 @@ pub use tokio;
pub use uniq_id::GlobalSequence;
pub use uniq_id::GlobalUniqName;
pub use uuid;
pub use watch_notify::WatchNotify;
71 changes: 71 additions & 0 deletions src/common/base/src/base/watch_notify.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright 2021 Datafuse Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use tokio::sync::watch;

/// A notification primitive built on top of `tokio::sync::watch`.
///
/// Unlike a plain notify, `notify_waiters` may be called before `notified`
/// is called (or awaited) without the notification being lost: the signal is
/// stored as the value of a watch channel instead of being delivered only to
/// tasks that are already waiting.
pub struct WatchNotify {
// Receiving half of the channel; `notified` clones it for every waiter
// instead of awaiting on this instance directly.
rx: watch::Receiver<bool>,
// Sending half of the channel; `notify_waiters` replaces the value with `true`.
tx: watch::Sender<bool>,
}

impl Default for WatchNotify {
fn default() -> Self {
Self::new()
}
}

impl WatchNotify {
    /// Creates a notifier backed by a watch channel whose initial value is `false`.
    pub fn new() -> Self {
        let (sender, receiver) = watch::channel(false);
        Self {
            tx: sender,
            rx: receiver,
        }
    }

    /// Waits until the watched value is reported as changed.
    ///
    /// Every call awaits on its own clone of the stored receiver, so any
    /// number of tasks can wait through a shared `&self`.
    pub async fn notified(&self) {
        // The result is intentionally discarded: both a value change and a
        // channel error are valid reasons to wake up the waiting task.
        let _ = self.rx.clone().changed().await;
    }

    /// Signals all current and future waiters by storing `true` in the channel.
    pub fn notify_waiters(&self) {
        // The previous value is irrelevant; only the change matters.
        let _previous = self.tx.send_replace(true);
    }
}

#[cfg(test)]
mod tests {
    use super::WatchNotify;

    /// The waiter future is created first, then the notification is sent,
    /// and only afterwards the future is awaited.
    #[tokio::test]
    async fn test_notify() {
        let notifier = WatchNotify::new();
        let pending = notifier.notified();
        notifier.notify_waiters();
        pending.await;
    }

    /// The notification is sent before the waiter future even exists.
    #[tokio::test]
    async fn test_notify_waiters_ahead() {
        let notifier = WatchNotify::new();

        // Fire the notification ahead of `notified` being instantiated and awaited.
        notifier.notify_waiters();

        // Must complete instead of waiting forever.
        notifier.notified().await;
    }
}
2 changes: 1 addition & 1 deletion src/query/service/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ socket2 = "0.5.3"
strength_reduce = "0.2.4"
tempfile = "3.4.0"
time = "0.3.14"
tokio = { workspace = true }
tokio-stream = { workspace = true, features = ["net"] }
toml = { version = "0.7.3", default-features = false }
tonic = { workspace = true }
Expand Down Expand Up @@ -185,7 +186,6 @@ temp-env = "0.3.0"
tempfile = "3.4.0"
tower = "0.4.13"
url = "2.3.1"
walkdir = { workspace = true }
wiremock = "0.5.14"

[build-dependencies]
Expand Down
32 changes: 21 additions & 11 deletions src/query/service/src/api/rpc/flight_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ use common_arrow::arrow_format::flight::data::FlightData;
use common_arrow::arrow_format::flight::data::Ticket;
use common_arrow::arrow_format::flight::service::flight_service_client::FlightServiceClient;
use common_base::base::tokio;
use common_base::base::tokio::sync::Notify;
use common_base::base::tokio::time::Duration;
use common_base::runtime::GlobalIORuntime;
use common_base::runtime::TrySpawn;
use common_exception::ErrorCode;
use common_exception::Result;
use futures::StreamExt;
use futures_util::future::Either;
use minitrace::full_name;
use minitrace::future::FutureExt;
use minitrace::Span;
use tonic::transport::channel::Channel;
use tonic::Request;
use tonic::Status;
Expand All @@ -39,6 +39,7 @@ use tonic::Streaming;
use crate::api::rpc::flight_actions::FlightAction;
use crate::api::rpc::packets::DataPacket;
use crate::api::rpc::request_builder::RequestBuilder;
use crate::pipelines::executor::WatchNotify;

pub struct FlightClient {
inner: FlightServiceClient<Channel>,
Expand Down Expand Up @@ -107,10 +108,10 @@ impl FlightClient {
fn streaming_receiver(
query_id: &str,
mut streaming: Streaming<FlightData>,
) -> (Arc<Notify>, Receiver<Result<FlightData>>) {
) -> (Arc<WatchNotify>, Receiver<Result<FlightData>>) {
let (tx, rx) = async_channel::bounded(1);
let notify = Arc::new(tokio::sync::Notify::new());
GlobalIORuntime::instance().spawn(query_id, {
let notify = Arc::new(WatchNotify::new());
let fut = {
let notify = notify.clone();
async move {
let mut notified = Box::pin(notify.notified());
Expand Down Expand Up @@ -143,7 +144,10 @@ impl FlightClient {
drop(streaming);
tx.close();
}
});
}
.in_span(Span::enter_with_local_parent(full_name!()));

tokio::spawn(async_backtrace::location!(String::from(query_id)).frame(fut));

(notify, rx)
}
Expand Down Expand Up @@ -179,15 +183,21 @@ impl FlightClient {
}

pub struct FlightReceiver {
notify: Arc<Notify>,
notify: Arc<WatchNotify>,
rx: Receiver<Result<FlightData>>,
}

// Closes the receiver automatically when it is dropped, so a receiver that
// goes out of scope without an explicit `close` call is still shut down.
impl Drop for FlightReceiver {
fn drop(&mut self) {
// Delegates to `FlightReceiver::close`, which is defined outside this
// hunk. NOTE(review): presumably it is safe to call more than once,
// since drop also runs after an explicit close — confirm.
self.close();
}
}

impl FlightReceiver {
pub fn create(rx: Receiver<Result<FlightData>>) -> FlightReceiver {
FlightReceiver {
rx,
notify: Arc::new(Notify::new()),
notify: Arc::new(WatchNotify::new()),
}
}

Expand Down Expand Up @@ -238,7 +248,7 @@ impl FlightSender {
pub enum FlightExchange {
Dummy,
Receiver {
notify: Arc<Notify>,
notify: Arc<WatchNotify>,
receiver: Receiver<Result<FlightData>>,
},
Sender(Sender<Result<FlightData, Status>>),
Expand All @@ -250,7 +260,7 @@ impl FlightExchange {
}

pub fn create_receiver(
notify: Arc<Notify>,
notify: Arc<WatchNotify>,
receiver: Receiver<Result<FlightData>>,
) -> FlightExchange {
FlightExchange::Receiver { notify, receiver }
Expand Down
8 changes: 4 additions & 4 deletions src/query/service/src/pipelines/executor/executor_tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,28 @@ use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;

use common_base::base::tokio::sync::Notify;
use common_exception::Result;
use parking_lot::Mutex;
use petgraph::prelude::NodeIndex;

use crate::pipelines::executor::ExecutorTask;
use crate::pipelines::executor::ExecutorWorkerContext;
use crate::pipelines::executor::WatchNotify;
use crate::pipelines::executor::WorkersCondvar;
use crate::pipelines::executor::WorkersWaitingStatus;
use crate::pipelines::processors::ProcessorPtr;

pub struct ExecutorTasksQueue {
finished: Arc<AtomicBool>,
finished_notify: Arc<Notify>,
finished_notify: Arc<WatchNotify>,
workers_tasks: Mutex<ExecutorTasks>,
}

impl ExecutorTasksQueue {
pub fn create(workers_size: usize) -> Arc<ExecutorTasksQueue> {
Arc::new(ExecutorTasksQueue {
finished: Arc::new(AtomicBool::new(false)),
finished_notify: Arc::new(Notify::new()),
finished_notify: Arc::new(WatchNotify::new()),
workers_tasks: Mutex::new(ExecutorTasks::create(workers_size)),
})
}
Expand Down Expand Up @@ -183,7 +183,7 @@ impl ExecutorTasksQueue {
}
}

pub fn get_finished_notify(&self) -> Arc<Notify> {
pub fn get_finished_notify(&self) -> Arc<WatchNotify> {
self.finished_notify.clone()
}

Expand Down
1 change: 1 addition & 0 deletions src/query/service/src/pipelines/executor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ mod pipeline_pulling_executor;
mod pipeline_pushing_executor;
mod processor_async_task;

pub use common_base::base::WatchNotify;
pub use executor_condvar::WorkersCondvar;
pub use executor_condvar::WorkersWaitingStatus;
pub use executor_graph::RunningGraph;
Expand Down
6 changes: 3 additions & 3 deletions src/query/service/src/pipelines/executor/pipeline_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use std::sync::Arc;
use std::time::Instant;

use common_base::base::tokio;
use common_base::base::tokio::sync::Notify;
use common_base::runtime::catch_unwind;
use common_base::runtime::GlobalIORuntime;
use common_base::runtime::Runtime;
Expand Down Expand Up @@ -46,6 +45,7 @@ use crate::pipelines::executor::ExecutorSettings;
use crate::pipelines::executor::ExecutorTasksQueue;
use crate::pipelines::executor::ExecutorWorkerContext;
use crate::pipelines::executor::RunningGraph;
use crate::pipelines::executor::WatchNotify;
use crate::pipelines::executor::WorkersCondvar;

pub type InitCallback = Box<dyn FnOnce() -> Result<()> + Send + Sync + 'static>;
Expand All @@ -62,7 +62,7 @@ pub struct PipelineExecutor {
on_init_callback: Mutex<Option<InitCallback>>,
on_finished_callback: Mutex<Option<FinishedCallback>>,
settings: ExecutorSettings,
finished_notify: Arc<Notify>,
finished_notify: Arc<WatchNotify>,
finished_error: Mutex<Option<ErrorCode>>,
#[allow(unused)]
lock_guards: Vec<LockGuard>,
Expand Down Expand Up @@ -195,7 +195,7 @@ impl PipelineExecutor {
async_runtime: GlobalIORuntime::instance(),
settings,
finished_error: Mutex::new(None),
finished_notify: Arc::new(Notify::new()),
finished_notify: Arc::new(WatchNotify::new()),
lock_guards,
}))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ use std::sync::atomic;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;

use common_base::base::tokio::sync::Notify;
use common_catalog::table_context::TableContext;
use common_exception::Result;
use common_expression::types::DataType;
Expand All @@ -35,6 +34,7 @@ use common_sql::executor::physical_plans::RangeJoinType;
use parking_lot::Mutex;
use parking_lot::RwLock;

use crate::pipelines::executor::WatchNotify;
use crate::pipelines::processors::transforms::range_join::IEJoinState;
use crate::sessions::QueryContext;

Expand All @@ -51,7 +51,7 @@ pub struct RangeJoinState {
pub(crate) other_conditions: Vec<RemoteExpr>,
// Pipeline event related
pub(crate) partition_finished: Mutex<bool>,
pub(crate) finished_notify: Arc<Notify>,
pub(crate) finished_notify: Arc<WatchNotify>,
pub(crate) left_sinker_count: RwLock<usize>,
pub(crate) right_sinker_count: RwLock<usize>,
// Task that need to be executed, pair.0 is left table block, pair.1 is right table block
Expand Down Expand Up @@ -81,7 +81,7 @@ impl RangeJoinState {
// join_type: range_join.join_type.clone(),
other_conditions: range_join.other_conditions.clone(),
partition_finished: Mutex::new(false),
finished_notify: Arc::new(Notify::new()),
finished_notify: Arc::new(WatchNotify::new()),
left_sinker_count: RwLock::new(0),
right_sinker_count: RwLock::new(0),
tasks: RwLock::new(vec![]),
Expand Down
Loading