Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions plugins/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -559,15 +559,15 @@ impl<'a> ServiceGenerator<'a> {
)| {
quote! {
#( #attrs )*
async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output;
fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> impl ::std::future::Future<Output = #output> + ::core::marker::Send;
}
},
);

let stub_doc = format!("The stub trait for service [`{service_ident}`].");
quote! {
#( #attrs )*
#vis trait #service_ident: ::core::marker::Sized {
#vis trait #service_ident: ::core::marker::Sized + ::core::marker::Send {
#( #rpc_fns )*

/// Returns a serving function to use with
Expand Down
18 changes: 13 additions & 5 deletions tarpc/src/client/stub.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! Provides a Stub trait, implemented by types that can call remote services.

use std::future::Future;

use crate::{
client::{Channel, RpcError},
context,
Expand All @@ -16,21 +18,25 @@ mod mock;
/// A connection to a remote service.
/// Calls the service with requests of type `Req` and receives responses of type `Resp`.
#[allow(async_fn_in_trait)]
pub trait Stub {
pub trait Stub: Send {
/// The service request type.
type Req: RequestName;

/// The service response type.
type Resp;

/// Calls a remote service.
async fn call(&self, ctx: context::Context, request: Self::Req)
-> Result<Self::Resp, RpcError>;
fn call(
&self,
ctx: context::Context,
request: Self::Req,
) -> impl Future<Output = Result<Self::Resp, RpcError>> + Send;
}

impl<Req, Resp> Stub for Channel<Req, Resp>
where
Req: RequestName,
Req: RequestName + Send,
Resp: Send,
{
type Req = Req;
type Resp = Resp;
Expand All @@ -42,7 +48,9 @@ where

impl<S> Stub for S
where
S: Serve + Clone,
S: Serve + Clone + Send + Sync,
S::Req: Send + Sync,
S::Resp: Send + Sync,
{
type Req = S::Req;
type Resp = S::Resp;
Expand Down
17 changes: 9 additions & 8 deletions tarpc/src/client/stub/load_balance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ mod round_robin {

impl<Stub> stub::Stub for RoundRobin<Stub>
where
Stub: stub::Stub,
Stub: stub::Stub + Sync,
Stub::Req: Send,
{
type Req = Stub::Req;
type Resp = Stub::Resp;
Expand Down Expand Up @@ -110,9 +111,9 @@ mod consistent_hash {

impl<Stub, S> stub::Stub for ConsistentHash<Stub, S>
where
Stub: stub::Stub,
Stub::Req: Hash,
S: BuildHasher,
Stub: stub::Stub + Sync,
Stub::Req: Hash + Send,
S: BuildHasher + Send + Sync,
{
type Req = Stub::Req;
type Resp = Stub::Resp;
Expand Down Expand Up @@ -188,7 +189,7 @@ mod consistent_hash {
use std::{
collections::HashMap,
hash::{BuildHasher, Hash, Hasher},
rc::Rc,
sync::Arc,
};

#[tokio::test]
Expand Down Expand Up @@ -230,11 +231,11 @@ mod consistent_hash {
}

struct FakeHasherBuilder {
recorded_hashes: Rc<HashMap<Vec<u8>, u64>>,
recorded_hashes: Arc<HashMap<Vec<u8>, u64>>,
}

struct FakeHasher {
recorded_hashes: Rc<HashMap<Vec<u8>, u64>>,
recorded_hashes: Arc<HashMap<Vec<u8>, u64>>,
output: u64,
}

Expand All @@ -258,7 +259,7 @@ mod consistent_hash {
recorded_hashes.insert(recorder.0, fake_hash);
}
Self {
recorded_hashes: Rc::new(recorded_hashes),
recorded_hashes: Arc::new(recorded_hashes),
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions tarpc/src/client/stub/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ where

impl<Req, Resp> Stub for Mock<Req, Resp>
where
Req: Eq + Hash + RequestName,
Resp: Clone,
Req: Eq + Hash + RequestName + Send + Sync,
Resp: Clone + Send + Sync,
{
type Req = Req;
type Resp = Resp;
Expand Down
6 changes: 3 additions & 3 deletions tarpc/src/client/stub/retry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ use std::sync::Arc;

impl<Stub, Req, F> stub::Stub for Retry<F, Stub>
where
Req: RequestName,
Stub: stub::Stub<Req = Arc<Req>>,
F: Fn(&Result<Stub::Resp, RpcError>, u32) -> bool,
Req: RequestName + Send + Sync,
Stub: Sync + stub::Stub<Req = Arc<Req>>,
F: Send + Sync + Fn(&Result<Stub::Resp, RpcError>, u32) -> bool,
{
type Req = Req;
type Resp = Stub::Resp;
Expand Down
9 changes: 3 additions & 6 deletions tarpc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ use std::{any::Any, error::Error, io, sync::Arc, time::Instant};
/// A message from a client to a server.
#[derive(Debug)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum ClientMessage<T> {
/// A request initiated by a user. The server responds to a request by invoking a
/// service-provided request handler. The handler completes with a [`response`](Response), which
Expand All @@ -280,7 +279,6 @@ pub enum ClientMessage<T> {

/// A request from a client to a server.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Request<T> {
/// Trace context, deadline, and other cross-cutting concerns.
Expand All @@ -294,14 +292,14 @@ pub struct Request<T> {
/// Implemented by the request types generated by tarpc::service.
pub trait RequestName {
/// The name of a request.
fn name(&self) -> &'static str;
fn name(&self) -> &str;
}

impl<Req> RequestName for Arc<Req>
where
Req: RequestName,
{
fn name(&self) -> &'static str {
fn name(&self) -> &str {
self.as_ref().name()
}
}
Expand All @@ -310,7 +308,7 @@ impl<Req> RequestName for Box<Req>
where
Req: RequestName,
{
fn name(&self) -> &'static str {
fn name(&self) -> &str {
self.as_ref().name()
}
}
Expand Down Expand Up @@ -360,7 +358,6 @@ impl RequestName for u64 {

/// A response from a server to a client.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Response<T> {
/// The ID of the request being responded to.
Expand Down
20 changes: 12 additions & 8 deletions tarpc/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,19 @@ impl Config {

/// Equivalent to a `FnOnce(Req) -> impl Future<Output = Resp>`.
#[allow(async_fn_in_trait)]
pub trait Serve {
pub trait Serve: Send {
/// Type of request.
type Req: RequestName;

/// Type of response.
type Resp;

/// Responds to a single request.
async fn serve(self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, ServerError>;
fn serve(
self,
ctx: context::Context,
req: Self::Req,
) -> impl Future<Output = Result<Self::Resp, ServerError>> + Send;
}

/// A Serve wrapper around a Fn.
Expand Down Expand Up @@ -115,9 +119,9 @@ where

impl<Req, Resp, Fut, F> Serve for ServeFn<Req, Resp, F>
where
Req: RequestName,
F: FnOnce(context::Context, Req) -> Fut,
Fut: Future<Output = Result<Resp, ServerError>>,
Req: RequestName + Send,
F: FnOnce(context::Context, Req) -> Fut + Send,
Fut: Future<Output = Result<Resp, ServerError>> + Send,
{
type Req = Req;
type Resp = Resp;
Expand Down Expand Up @@ -1046,7 +1050,7 @@ mod tests {
#[tokio::test]
async fn serve_before_mutates_context() -> anyhow::Result<()> {
struct SetDeadline(Instant);
impl<Req> BeforeRequest<Req> for SetDeadline {
impl<Req: Send + Sync> BeforeRequest<Req> for SetDeadline {
async fn before(
&mut self,
ctx: &mut context::Context,
Expand Down Expand Up @@ -1085,7 +1089,7 @@ mod tests {
}
}
}
impl<Req> BeforeRequest<Req> for PrintLatency {
impl<Req: Send + Sync> BeforeRequest<Req> for PrintLatency {
async fn before(
&mut self,
_: &mut context::Context,
Expand All @@ -1095,7 +1099,7 @@ mod tests {
Ok(())
}
}
impl<Resp> AfterRequest<Resp> for PrintLatency {
impl<Resp: Send> AfterRequest<Resp> for PrintLatency {
async fn after(&mut self, _: &mut context::Context, _: &mut Result<Resp, ServerError>) {
tracing::info!("Elapsed: {:?}", self.start.elapsed());
}
Expand Down
11 changes: 7 additions & 4 deletions tarpc/src/server/request_hook/after.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@ use futures::prelude::*;

/// A hook that runs after request execution.
#[allow(async_fn_in_trait)]
pub trait AfterRequest<Resp> {
pub trait AfterRequest<Resp>: Send {
/// The function that is called after request execution.
///
/// The hook can modify the request context and the response.
async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result<Resp, ServerError>);
fn after(&mut self, ctx: &mut context::Context, resp: &mut Result<Resp, ServerError>) -> impl Future<Output = ()> + Send;
}

impl<F, Fut, Resp> AfterRequest<Resp> for F
where
F: FnMut(&mut context::Context, &mut Result<Resp, ServerError>) -> Fut,
Fut: Future<Output = ()>,
F: Send + FnMut(&mut context::Context, &mut Result<Resp, ServerError>) -> Fut,
Fut: Send + Future<Output = ()>,
Resp: Send,
{
async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result<Resp, ServerError>) {
self(ctx, resp).await
Expand Down Expand Up @@ -53,6 +54,8 @@ impl<Serv, Hook> Serve for ServeThenHook<Serv, Hook>
where
Serv: Serve,
Hook: AfterRequest<Serv::Resp>,
Serv::Req: Send,
Serv::Resp: Send,
{
type Req = Serv::Req;
type Resp = Serv::Resp;
Expand Down
Loading