Skip to content

Commit

Permalink
ExeUnit: fix run -> term -> PID response race condition (#1629)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfranciszkiewicz authored Oct 4, 2021
1 parent df9d13a commit 7538c12
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 45 deletions.
150 changes: 109 additions & 41 deletions exe-unit/src/runtime/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ use std::future::Future;
use std::ops::Not;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::task::{Context, Poll, Waker};

use futures::channel::mpsc::SendError;
use futures::channel::oneshot;
use futures::future::{BoxFuture, Fuse, Shared};
use futures::future::{BoxFuture, Shared};
use futures::{FutureExt, SinkExt, TryFutureExt};

use crate::message::{CommandContext, RuntimeEvent};
Expand All @@ -17,46 +17,68 @@ use ya_runtime_api::server::{ProcessStatus, RuntimeStatus};

#[derive(Default, Clone)]
pub(crate) struct EventMonitor {
processes: Arc<Mutex<HashMap<u64, Channel>>>,
fallback: Arc<Mutex<Option<Channel>>>,
inner: Arc<Mutex<Inner>>,
}

#[derive(Default)]
struct Inner {
next_process: Option<Channel>,
processes: HashMap<u64, Channel>,
fallback: Option<Channel>,
}

impl EventMonitor {
pub fn any_process<'a>(&mut self, ctx: CommandContext) -> Handle<'a> {
let mut inner = self.fallback.lock().unwrap();
inner.replace(Channel::simple(ctx));
let mut inner = self.inner.lock().unwrap();
inner.fallback.replace(Channel::fallback(ctx.clone()));

Handle::Fallback {
monitor: self.clone(),
}
Handle::Fallback {}
}

pub fn process<'a>(&mut self, ctx: CommandContext, pid: u64) -> Handle<'a> {
let entry = Channel::new(ctx);
let done_rx = entry.done_rx().unwrap();
pub fn next_process<'a>(&mut self, ctx: CommandContext) -> Handle<'a> {
let mut inner = self.inner.lock().unwrap();
let channel = Channel::new(ctx, Default::default());
let handle = Handle::process(&self, &channel);
inner.next_process.replace(channel);

let mut inner = self.processes.lock().unwrap();
inner.insert(pid, entry);
handle
}

Handle::Process {
monitor: self.clone(),
pid,
done_rx,
}
#[allow(unused)]
pub fn process<'a>(&mut self, ctx: CommandContext, pid: u64) -> Handle<'a> {
let mut inner = self.inner.lock().unwrap();
let channel = Channel::new(ctx, pid);
let handle = Handle::process(&self, &channel);
inner.processes.insert(pid, channel);

handle
}
}

impl ya_runtime_api::server::RuntimeHandler for EventMonitor {
fn on_process_status<'a>(&self, status: ProcessStatus) -> BoxFuture<'a, ()> {
let running = status.running;
let (mut ctx, done_tx) = {
let mut proc_map = self.processes.lock().unwrap();
let mut fallback = self.fallback.lock().unwrap();
let mut inner = self.inner.lock().unwrap();

let entry = match proc_map.get_mut(&status.pid).or(fallback.as_mut()) {
if !inner.processes.contains_key(&status.pid) {
if let Some(channel) = inner.next_process.take() {
channel.waker.lock().unwrap().pid.replace(status.pid);
inner.processes.insert(status.pid, channel);
}
}

let entry = match inner.processes.get_mut(&status.pid) {
Some(entry) => entry,
None => return futures::future::ready(()).boxed(),
None => match inner.fallback.as_mut() {
Some(entry) => entry,
None => return futures::future::ready(()).boxed(),
},
};
let done_tx = status.running.not().then(|| entry.done_tx()).flatten();

let done_tx = running.not().then(|| entry.done_tx()).flatten();
entry.wake();

(entry.ctx.clone(), done_tx)
};

Expand Down Expand Up @@ -85,8 +107,8 @@ impl ya_runtime_api::server::RuntimeHandler for EventMonitor {
use ya_runtime_api::server::proto::response::runtime_status::Kind;

let mut ctx = {
let channel = self.fallback.lock().unwrap();
match channel.as_ref() {
let inner = self.inner.lock().unwrap();
match inner.fallback.as_ref() {
Some(c) => c.ctx.clone(),
None => return futures::future::ready(()).boxed(),
}
Expand Down Expand Up @@ -122,20 +144,37 @@ impl ya_runtime_api::server::RuntimeHandler for EventMonitor {
pub(crate) enum Handle<'a> {
Process {
monitor: EventMonitor,
pid: u64,
done_rx: BoxFuture<'a, Result<i32, ()>>,
waker: Arc<Mutex<ProcessWaker>>,
},
Fallback {
#[allow(unused)]
monitor: EventMonitor,
},
Fallback {},
}

#[derive(Default)]
pub(crate) struct ProcessWaker {
pid: Option<u64>,
waker: Option<Waker>,
}

impl<'a> Handle<'a> {
fn process(monitor: &EventMonitor, channel: &Channel) -> Self {
Handle::Process {
monitor: monitor.clone(),
done_rx: channel.done_rx().unwrap(),
waker: channel.waker.clone(),
}
}
}

impl<'a> Drop for Handle<'a> {
fn drop(&mut self) {
match self {
Handle::Process { monitor, pid, .. } => {
monitor.processes.lock().unwrap().remove(pid);
Handle::Process { monitor, waker, .. } => {
if let Some(pid) = { waker.lock().unwrap().pid } {
let mut inner = monitor.inner.lock().unwrap();
inner.processes.remove(&pid);
inner.next_process.take();
}
}
_ => {
// ignore
Expand All @@ -149,31 +188,60 @@ impl<'a> Future for Handle<'a> {

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.get_mut() {
Handle::Process { done_rx, .. } => match Pin::new(done_rx).poll(cx) {
Handle::Process { done_rx, waker, .. } => match Pin::new(done_rx).poll(cx) {
Poll::Ready(Ok(c)) => Poll::Ready(c),
Poll::Ready(Err(_)) => Poll::Ready(1),
Poll::Pending => Poll::Pending,
Poll::Pending => {
let mut guard = waker.lock().unwrap();

if let Some(waker) = &guard.waker {
let cx_waker = cx.waker();
if !waker.will_wake(cx_waker) {
guard.waker.replace(cx_waker.clone());
}
} else {
guard.waker.replace(cx.waker().clone());
}

Poll::Pending
}
},
Handle::Fallback { .. } => Poll::Ready(0),
}
}
}

struct Channel {
pub(crate) struct Channel {
ctx: CommandContext,
done: Option<DoneChannel>,
waker: Arc<Mutex<ProcessWaker>>,
}

impl Channel {
fn new(ctx: CommandContext) -> Self {
fn new(ctx: CommandContext, pid: u64) -> Self {
Channel {
ctx,
done: Some(Default::default()),
waker: Arc::new(Mutex::new(ProcessWaker {
pid: Some(pid),
waker: None,
})),
}
}

fn fallback(ctx: CommandContext) -> Self {
Channel {
ctx,
done: None,
waker: Default::default(),
}
}

fn simple(ctx: CommandContext) -> Self {
Channel { ctx, done: None }
fn wake(&self) {
let guard = self.waker.lock().unwrap();
if let Some(waker) = &guard.waker {
waker.wake_by_ref();
}
}

fn done_tx(&mut self) -> Option<oneshot::Sender<i32>> {
Expand All @@ -189,15 +257,15 @@ impl Channel {

struct DoneChannel {
tx: Option<oneshot::Sender<i32>>,
rx: Shared<Fuse<oneshot::Receiver<i32>>>,
rx: Shared<oneshot::Receiver<i32>>,
}

impl Default for DoneChannel {
fn default() -> Self {
let (tx, rx) = oneshot::channel();
Self {
tx: Some(tx),
rx: rx.fuse().shared(),
rx: rx.shared(),
}
}
}
8 changes: 4 additions & 4 deletions exe-unit/src/runtime/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,12 +314,12 @@ impl RuntimeProcess {
run_process.bin = entry_point;
run_process.args = args;

let process = match service.run_process(run_process).await {
Ok(result) => result,
Err(error) => return Err(Error::RuntimeError(format!("{:?}", error))),
let handle = monitor.next_process(ctx);
if let Err(error) = service.run_process(run_process).await {
return Err(Error::RuntimeError(format!("{:?}", error)));
};

Ok(monitor.process(ctx, process.pid).await)
Ok(handle.await)
};

async move {
Expand Down

0 comments on commit 7538c12

Please sign in to comment.