Merged
184 changes: 156 additions & 28 deletions scylla/src/client/session.rs
@@ -12,8 +12,9 @@ use crate::cluster::node::CloudEndpoint;
use crate::cluster::node::{InternalKnownNode, KnownNode, NodeRef};
use crate::cluster::{Cluster, ClusterNeatDebug, ClusterState};
use crate::errors::{
BadQuery, ExecutionError, MetadataError, NewSessionError, PagerExecutionError, PrepareError,
RequestAttemptError, RequestError, SchemaAgreementError, TracingError, UseKeyspaceError,
BadQuery, BrokenConnectionError, ExecutionError, MetadataError, NewSessionError,
PagerExecutionError, PrepareError, RequestAttemptError, RequestError, SchemaAgreementError,
TracingError, UseKeyspaceError,
};
use crate::frame::response::result;
use crate::network::tls::TlsProvider;
@@ -1115,7 +1116,8 @@ impl Session {
};

self.handle_set_keyspace_response(&response).await?;
self.handle_auto_await_schema_agreement(&response).await?;
self.handle_auto_await_schema_agreement(&response, coordinator.node().host_id)
.await?;
Comment on lines +1119 to +1120

Contributor:
Exposing coordinator turned out to be useful for driver internals as well 🎉

Collaborator (Author):
Yes, without it I would probably need to add some plumbing to pass through the id.
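For context, a minimal sketch of the user-visible flow this plumbing serves. The builder and session paths are assumed from this repo's module layout, and the DDL statement is illustrative; with automatic schema-agreement waiting enabled (the default), a schema-changing request now waits for an agreement that must include the coordinator which applied the change.

// Sketch only: module paths and the DDL are illustrative, not part of this PR.
use scylla::client::session_builder::SessionBuilder;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let session = SessionBuilder::new()
        .known_node("127.0.0.1:9042")
        .build()
        .await?;

    // With schema_agreement_automatic_waiting enabled, this call returns only
    // after schema agreement is reached; after this PR, the coordinator that
    // applied the DDL must itself report its schema version.
    session
        .query_unpaged(
            "CREATE KEYSPACE IF NOT EXISTS ks WITH replication = \
             {'class': 'NetworkTopologyStrategy', 'replication_factor': 1}",
            &[],
        )
        .await?;

    // The manual check remains available; on this path no particular node is
    // required (required_node is None).
    let agreed = session.check_schema_agreement().await?;
    println!("agreed schema version: {agreed:?}");
    Ok(())
}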


let (result, paging_state_response) =
response.into_query_result_and_paging_state(coordinator)?;
@@ -1143,10 +1145,12 @@ impl Session {
async fn handle_auto_await_schema_agreement(
&self,
response: &NonErrorQueryResponse,
coordinator_id: Uuid,
) -> Result<(), ExecutionError> {
if self.schema_agreement_automatic_waiting {
if response.as_schema_change().is_some() {
self.await_schema_agreement().await?;
self.await_schema_agreement_with_required_node(Some(coordinator_id))
.await?;
}

if self.refresh_metadata_on_auto_schema_agreement
@@ -1488,7 +1492,8 @@ impl Session {
};

self.handle_set_keyspace_response(&response).await?;
self.handle_auto_await_schema_agreement(&response).await?;
self.handle_auto_await_schema_agreement(&response, coordinator.node().host_id)
.await?;

let (result, paging_state_response) =
response.into_query_result_and_paging_state(coordinator)?;
@@ -2159,10 +2164,19 @@ impl Session {
///
/// Issues an agreement check every `Session::schema_agreement_interval`.
/// Loops indefinitely until agreement is reached.
async fn await_schema_agreement_indefinitely(&self) -> Result<Uuid, SchemaAgreementError> {
///
/// If `required_node` is Some, only returns Ok if that node successfully
/// returned its schema version during the agreement process.
async fn await_schema_agreement_indefinitely(
&self,
required_node: Option<Uuid>,
) -> Result<Uuid, SchemaAgreementError> {
loop {
tokio::time::sleep(self.schema_agreement_interval).await;
if let Some(agreed_version) = self.check_schema_agreement().await? {
if let Some(agreed_version) = self
.check_schema_agreement_with_required_node(required_node)
.await?
{
return Ok(agreed_version);
}
}
@@ -2176,7 +2190,29 @@
pub async fn await_schema_agreement(&self) -> Result<Uuid, SchemaAgreementError> {
timeout(
self.schema_agreement_timeout,
self.await_schema_agreement_indefinitely(),
self.await_schema_agreement_indefinitely(None),
)
.await
.unwrap_or(Err(SchemaAgreementError::Timeout(
self.schema_agreement_timeout,
)))
}

/// Awaits schema agreement among all reachable nodes.
///
/// Issues an agreement check every `Session::schema_agreement_interval`.
/// If agreement is not reached in `Session::schema_agreement_timeout`,
/// `SchemaAgreementError::Timeout` is returned.
///
/// If `required_node` is Some, only returns Ok if that node successfully
/// returned its schema version during the agreement process.
async fn await_schema_agreement_with_required_node(
&self,
required_node: Option<Uuid>,
) -> Result<Uuid, SchemaAgreementError> {
timeout(
self.schema_agreement_timeout,
self.await_schema_agreement_indefinitely(required_node),
)
.await
.unwrap_or(Err(SchemaAgreementError::Timeout(
@@ -2188,37 +2224,123 @@
///
/// If so, returns that agreed upon version.
pub async fn check_schema_agreement(&self) -> Result<Option<Uuid>, SchemaAgreementError> {
self.check_schema_agreement_with_required_node(None).await
}

/// Checks if all reachable nodes have the same schema version.
/// If so, returns that agreed upon version.
///
/// If `required_node` is Some, only returns Ok if that node successfully
/// returned its schema version.
async fn check_schema_agreement_with_required_node(
&self,
required_node: Option<Uuid>,
) -> Result<Option<Uuid>, SchemaAgreementError> {
let cluster_state = self.get_cluster_state();
// The iterator is guaranteed to be nonempty.
let per_node_connections = cluster_state.iter_working_connections_per_node()?;

// Therefore, this iterator is guaranteed to be nonempty, too.
let handles = per_node_connections.map(|connections_to_node| async move {
// Iterate over connections to the node. Fail if fetching schema version failed on all connections.
// Else, return the first fetched schema version, because all shards have the same schema version.
let mut first_err = None;
for connection in connections_to_node {
match connection.fetch_schema_version().await {
Ok(schema_version) => return Ok(schema_version),
Err(err) => {
if first_err.is_none() {
first_err = Some(err);
}
}
}
}
// The iterator was guaranteed to be nonempty, so there must have been at least one error.
Err(first_err.unwrap())
let handles = per_node_connections.map(|(host_id, pool)| async move {
(host_id, Session::read_node_schema_version(pool).await)
});
// Hence, this is nonempty, too.
let versions = try_join_all(handles).await?;
let versions_results = join_all(handles).await;

// Verify that the required host is present and returned success.
if let Some(required_node) = required_node {
match versions_results
.iter()
.find(|(host_id, _)| *host_id == required_node)
{
Some((_, Ok(SchemaNodeResult::Success(_version)))) => (),
// For other nodes we can ignore a broken connection error, but for the
// required host we need an actual schema version.
Some((_, Ok(SchemaNodeResult::BrokenConnection(e)))) => {
return Err(SchemaAgreementError::RequestError(
RequestAttemptError::BrokenConnectionError(e.clone()),
))
}
Some((_, Err(e))) => return Err(e.clone()),
None => return Err(SchemaAgreementError::RequiredHostAbsent(required_node)),
}
}

// Now we no longer need all the errors. We can return if there is an
// irrecoverable one, and collect the Ok values otherwise.
let versions_results: Vec<_> = versions_results
.into_iter()
.map(|(_, result)| result)
.try_collect()?;

// unwrap is safe because the iterator is still nonempty.
let local_version = match versions_results
.iter()
.find_or_first(|r| matches!(r, SchemaNodeResult::Success(_)))
.unwrap()
{
SchemaNodeResult::Success(v) => *v,
SchemaNodeResult::BrokenConnection(err) => {
// There are only broken connection errors. Nothing better to do
// than to return an error.
return Err(SchemaAgreementError::RequestError(
RequestAttemptError::BrokenConnectionError(err.clone()),
));
Comment on lines +2286 to +2288

Contributor:
commit: Schema agreement: Ignore BrokenConnectionError

Do we need this clone?

Contributor:
Ah, we iterate over &SchemaNodeResult, so we probably need it.

Collaborator (Author):
That's correct. Maybe it would be possible to avoid it with some iterator magic, but that would most likely complicate the code even more, which I'd really like to avoid. This clone only happens in an error condition that should be exceedingly rare (all connections to all nodes are broken), so I don't see the need to optimize this case.
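A standalone toy example (not driver code) of why the clone is needed: `.iter().find(...)` yields references, so an owned error can only be obtained by cloning it out from behind the reference.

#[derive(Clone, Debug)]
struct BrokenConn(String);

// `results.iter()` borrows each element, so returning an owned error
// requires cloning it out from behind the reference.
fn first_broken_owned(results: &[Result<u32, BrokenConn>]) -> Option<BrokenConn> {
    results.iter().find_map(|r| r.as_ref().err().cloned())
}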

}
};

// Therefore, taking the first element is safe.
let local_version: Uuid = versions[0];
let in_agreement = versions.into_iter().all(|v| v == local_version);
let in_agreement = versions_results
.into_iter()
.filter_map(|v_r| match v_r {
SchemaNodeResult::Success(v) => Some(v),
SchemaNodeResult::BrokenConnection(_) => None,
})
.all(|v| v == local_version);
Ok(in_agreement.then_some(local_version))
}

/// Iterates over connections to the node.
/// If fetching succeeds on some connection, returns the first fetched schema version
/// as Ok(SchemaNodeResult::Success(...)).
/// Otherwise there are only errors:
/// - If, and only if, all connections returned BrokenConnectionError, the first such
/// error is returned as Ok(SchemaNodeResult::BrokenConnection(...)).
/// - Otherwise some connection failed with a different type of error. The first such
/// error is returned as Err(...).
///
/// The `connections_to_node` iterator must be non-empty!
async fn read_node_schema_version(
connections_to_node: impl Iterator<Item = Arc<Connection>>,
) -> Result<SchemaNodeResult, SchemaAgreementError> {
let mut first_broken_connection_err: Option<BrokenConnectionError> = None;
let mut first_unignorable_err: Option<SchemaAgreementError> = None;
for connection in connections_to_node {
match connection.fetch_schema_version().await {
Ok(schema_version) => return Ok(SchemaNodeResult::Success(schema_version)),
Err(SchemaAgreementError::RequestError(
RequestAttemptError::BrokenConnectionError(conn_err),
)) => {
if first_broken_connection_err.is_none() {
first_broken_connection_err = Some(conn_err);
}
}
Err(err) => {
if first_unignorable_err.is_none() {
first_unignorable_err = Some(err);
}
}
}
}
// The iterator was guaranteed to be nonempty, so there must have been at least one error.
// It means at least one of `first_broken_connection_err` and `first_unignorable_err` is Some.
if let Some(err) = first_unignorable_err {
return Err(err);
}

Ok(SchemaNodeResult::BrokenConnection(
first_broken_connection_err.unwrap(),
))
}

/// Retrieves the handle to execution profile that is used by this session
/// by default, i.e. when an executed statement does not define its own handle.
pub fn get_default_execution_profile_handle(&self) -> &ExecutionProfileHandle {
Expand Down Expand Up @@ -2300,3 +2422,9 @@ impl ExecuteRequestContext<'_> {
.log_attempt_error(*attempt_id, error, retry_decision);
}
}

#[derive(Debug)]
enum SchemaNodeResult {
Success(Uuid),
BrokenConnection(BrokenConnectionError),
}
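To summarize the control flow introduced above, a simplified, self-contained model of the agreement check. The types are toy stand-ins (`String` replaces the real error enums) and non-ignorable errors are omitted for brevity.

use uuid::Uuid;

#[derive(Debug, Clone)]
enum NodeOutcome {
    Success(Uuid),
    Broken, // stands in for BrokenConnectionError
}

fn agreed_version(
    results: &[(Uuid, NodeOutcome)],
    required_node: Option<Uuid>,
) -> Result<Option<Uuid>, String> {
    // The required node, if any, must be present and must have reported a version.
    if let Some(required) = required_node {
        match results.iter().find(|(id, _)| *id == required) {
            Some((_, NodeOutcome::Success(_))) => {}
            Some((_, NodeOutcome::Broken)) => return Err("required node broken".into()),
            None => return Err("required node absent".into()),
        }
    }
    // Broken nodes are ignored for the version comparison itself...
    let mut versions = results.iter().filter_map(|(_, outcome)| match outcome {
        NodeOutcome::Success(v) => Some(*v),
        NodeOutcome::Broken => None,
    });
    // ...but at least one node must have reported a version.
    let first = match versions.next() {
        Some(v) => v,
        None => return Err("all connections to all nodes are broken".into()),
    };
    // Agreement means every reported version equals the first one.
    Ok(versions.all(|v| v == first).then_some(first))
}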
28 changes: 16 additions & 12 deletions scylla/src/cluster/state.rs
@@ -324,17 +324,21 @@ impl ClusterState {
/// The inner iterator iterates over working connections to all shards of a given node.
pub(crate) fn iter_working_connections_per_node(
&self,
) -> Result<impl Iterator<Item = impl Iterator<Item = Arc<Connection>>> + '_, ConnectionPoolError>
{
) -> Result<
impl Iterator<Item = (Uuid, impl Iterator<Item = Arc<Connection>>)> + '_,
ConnectionPoolError,
> {
// The returned iterator is nonempty by nonemptiness invariant of `self.known_peers`.
assert!(!self.known_peers.is_empty());
let nodes_iter = self.known_peers.values();
let mut connection_pool_per_node_iter =
nodes_iter.map(|node| node.get_working_connections());
let mut connection_pool_per_node_iter = nodes_iter.map(|node| {
node.get_working_connections()
.map(|pool| (node.host_id, pool))
});

// First we try to find the first working pool of connections.
// If none is found, return error.
let first_working_pool_or_error: Result<Vec<Arc<Connection>>, ConnectionPoolError> =
let first_working_pool_or_error: Result<(Uuid, Vec<Arc<Connection>>), ConnectionPoolError> =
connection_pool_per_node_iter
.by_ref()
.find_or_first(Result::is_ok)
@@ -344,19 +348,19 @@
// 1. either consumed the whole iterator without success and got the first error,
// in which case we propagate it;
// 2. or found the first working pool of connections.
let first_working_pool: Vec<Arc<Connection>> = first_working_pool_or_error?;
let first_working_pool: (Uuid, Vec<Arc<Connection>>) = first_working_pool_or_error?;

// We retrieve connection pools for remaining nodes (those that are left in the iterator
// once the first working pool has been found).
let remaining_pools_iter = connection_pool_per_node_iter;
// Errors (non-working pools) are filtered out.
let remaining_working_pools_iter = remaining_pools_iter.filter_map(Result::ok);
// Pools are made iterators, so now we have `impl Iterator<Item = impl Iterator<Item = Arc<Connection>>>`.
let remaining_working_per_node_connections_iter =
remaining_working_pools_iter.map(IntoIterator::into_iter);

Ok(std::iter::once(first_working_pool.into_iter())
.chain(remaining_working_per_node_connections_iter))
// First pool is chained with the rest.
// Then, pools are made iterators, so now we have `impl Iterator<Item = (Uuid, impl Iterator<Item = Arc<Connection>>)>`.
Ok(std::iter::once(first_working_pool)
.chain(remaining_working_pools_iter)
.map(|(host_id, pool)| (host_id, IntoIterator::into_iter(pool))))
// By an invariant `self.known_peers` is nonempty, so the returned iterator
// is nonempty, too.
}
@@ -366,7 +370,7 @@
&self,
) -> Result<impl Iterator<Item = Arc<Connection>> + '_, ConnectionPoolError> {
self.iter_working_connections_per_node()
.map(|iter| iter.flatten())
.map(|outer_iter| outer_iter.flat_map(|(_, inner_iter)| inner_iter))
}

/// Returns nonempty iterator of working connections to all nodes.
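The `find_or_first` call above carries the nonemptiness reasoning; here is a standalone illustration of its semantics (itertools), independent of driver types: scan the results, return the first Ok if any exists, otherwise the first Err.

use itertools::Itertools;

// First Ok if any exists, otherwise the first Err; None only for empty input.
// This mirrors how the first working pool (or the first pool error) is chosen.
fn first_ok_or_first_err<T, E>(
    results: impl Iterator<Item = Result<T, E>>,
) -> Option<Result<T, E>> {
    results.find_or_first(|r| r.is_ok())
}

fn main() {
    let mixed = vec![Err("a"), Ok(1), Ok(2)];
    assert!(matches!(first_ok_or_first_err(mixed.into_iter()), Some(Ok(1))));

    let all_err: Vec<Result<i32, &str>> = vec![Err("a"), Err("b")];
    assert!(matches!(first_ok_or_first_err(all_err.into_iter()), Some(Err("a"))));
}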
7 changes: 7 additions & 0 deletions scylla/src/errors.rs
@@ -6,6 +6,7 @@ use std::net::{AddrParseError, IpAddr, SocketAddr};
use std::num::ParseIntError;
use std::sync::Arc;
use thiserror::Error;
use uuid::Uuid;

use crate::frame::response;

@@ -209,6 +210,12 @@ pub enum SchemaAgreementError {
/// Schema agreement timed out.
#[error("Schema agreement exceeded {}ms", std::time::Duration::as_millis(.0))]
Timeout(std::time::Duration),

#[error(
"Host with id {} required for schema agreement is not present in connection pool",
.0
)]
RequiredHostAbsent(Uuid),
}

/// An error that occurred during tracing info fetch.
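A hedged sketch of how a caller might handle the new variant. Module paths are assumed from this repo's layout; note that in this PR `RequiredHostAbsent` is produced on the internal auto-wait path after a schema change, so the match below only illustrates the shape of the error type.

use scylla::client::session::Session;
use scylla::errors::SchemaAgreementError;

// Illustrative handler, not driver code: distinguish a timeout from the new
// "required host missing from the pool" condition.
async fn report_agreement(session: &Session) {
    match session.await_schema_agreement().await {
        Ok(version) => println!("schema agreed at version {version}"),
        Err(SchemaAgreementError::Timeout(elapsed)) => {
            eprintln!("no agreement within {elapsed:?}");
        }
        Err(SchemaAgreementError::RequiredHostAbsent(host)) => {
            eprintln!("node {host} left the pool before reporting its schema version");
        }
        Err(other) => eprintln!("schema agreement failed: {other}"),
    }
}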
1 change: 1 addition & 0 deletions scylla/tests/integration/session/mod.rs
@@ -5,6 +5,7 @@ mod history;
mod new_session;
mod pager;
mod retries;
mod schema_agreement;
mod self_identity;
mod tracing;
mod use_keyspace;