diff --git a/Cargo.lock b/Cargo.lock index 8766cb5d..44f9e4a0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -741,8 +741,10 @@ version = "0.8.2" dependencies = [ "byteordered", "bytes", + "dicom-core", "dicom-dictionary-std", "dicom-encoding", + "dicom-object", "dicom-transfer-syntax-registry", "matches", "rstest", diff --git a/findscu/src/main.rs b/findscu/src/main.rs index d520cafd..009329d6 100644 --- a/findscu/src/main.rs +++ b/findscu/src/main.rs @@ -79,7 +79,7 @@ fn main() { enum Error { /// Could not initialize SCU InitScu { - source: dicom_ul::association::client::Error, + source: dicom_ul::association::Error, }, /// Could not construct DICOM command diff --git a/scpproxy/src/main.rs b/scpproxy/src/main.rs index f66c4067..1a8fb5fc 100644 --- a/scpproxy/src/main.rs +++ b/scpproxy/src/main.rs @@ -1,6 +1,6 @@ use bytes::BytesMut; use clap::{crate_version, value_parser, Arg, ArgAction, Command}; -use dicom_ul::association::client::get_client_pdu; +use dicom_ul::association::read_pdu_from_wire; use dicom_ul::pdu::writer::write_pdu; use dicom_ul::pdu::Pdu; use snafu::{Backtrace, OptionExt, Report, ResultExt, Snafu, Whatever}; @@ -65,7 +65,7 @@ pub enum ThreadMessage { }, ReadErr { from: ProviderType, - err: dicom_ul::association::client::Error, + err: dicom_ul::association::Error, }, WriteErr { from: ProviderType, @@ -99,7 +99,7 @@ fn run( let message_tx = message_tx.clone(); scu_reader_thread = thread::spawn(move || { loop { - match get_client_pdu(&mut reader, &mut buf, max_pdu_length, strict) { + match read_pdu_from_wire(&mut reader, &mut buf, max_pdu_length, strict) { Ok(pdu) => { message_tx .send(ThreadMessage::SendPdu { @@ -108,7 +108,7 @@ fn run( }) .context(SendMessageSnafu)?; } - Err(dicom_ul::association::client::Error::ConnectionClosed) => { + Err(dicom_ul::association::Error::ConnectionClosed) => { message_tx .send(ThreadMessage::Shutdown { initiator: ProviderType::Scu, @@ -137,7 +137,7 @@ fn run( let mut buf = BytesMut::with_capacity(max_pdu_length as usize); scp_reader_thread = thread::spawn(move || { loop { - match get_client_pdu(&mut reader, &mut buf, max_pdu_length, strict) { + match read_pdu_from_wire(&mut reader, &mut buf, max_pdu_length, strict) { Ok(pdu) => { message_tx .send(ThreadMessage::SendPdu { @@ -146,7 +146,7 @@ fn run( }) .context(SendMessageSnafu)?; } - Err(dicom_ul::association::client::Error::ConnectionClosed) => { + Err(dicom_ul::association::Error::ConnectionClosed) => { message_tx .send(ThreadMessage::Shutdown { initiator: ProviderType::Scp, diff --git a/storescp/src/store_async.rs b/storescp/src/store_async.rs index 6c380b89..57daad50 100644 --- a/storescp/src/store_async.rs +++ b/storescp/src/store_async.rs @@ -244,7 +244,7 @@ pub async fn run_store_async( _ => {} } } - Err(err @ dicom_ul::association::server::Error::Receive { .. }) => { + Err(err @ dicom_ul::association::Error::ReceivePdu { .. }) => { if verbose { info!("{}", Report::from_error(err)); } else { diff --git a/storescp/src/store_sync.rs b/storescp/src/store_sync.rs index a16dd2c6..45456437 100644 --- a/storescp/src/store_sync.rs +++ b/storescp/src/store_sync.rs @@ -241,7 +241,7 @@ pub fn run_store_sync(scu_stream: TcpStream, args: &App) -> Result<(), Whatever> _ => {} } } - Err(err @ dicom_ul::association::server::Error::Receive { .. }) => { + Err(err @ dicom_ul::association::Error::ReceivePdu { .. }) => { if verbose { info!("{}", Report::from_error(err)); } else { diff --git a/storescu/src/main.rs b/storescu/src/main.rs index 77a5372e..cbb0664d 100644 --- a/storescu/src/main.rs +++ b/storescu/src/main.rs @@ -119,7 +119,7 @@ struct DicomFile { enum Error { /// Could not initialize SCU Scu { - source: Box, + source: Box, }, /// Could not construct DICOM command diff --git a/ul/Cargo.toml b/ul/Cargo.toml index 544381d9..d64d0b65 100644 --- a/ul/Cargo.toml +++ b/ul/Cargo.toml @@ -35,6 +35,8 @@ dicom-dictionary-std = { path = "../dictionary-std" } matches = "0.1.8" rstest = "0.25" tokio = { version = "^1.38", features = ["io-util", "macros", "net", "rt", "rt-multi-thread"] } +dicom-core = { path = '../core' } +dicom-object = { path = '../object' } [features] async = ["dep:tokio"] diff --git a/ul/src/association/client.rs b/ul/src/association/client.rs index 2846790e..e5989fbc 100644 --- a/ul/src/association/client.rs +++ b/ul/src/association/client.rs @@ -8,176 +8,25 @@ use bytes::BytesMut; use std::{ borrow::Cow, convert::TryInto, - io::{BufRead, BufReader, Cursor, Read, Write}, + io::Write, net::{TcpStream, ToSocketAddrs}, time::Duration, }; use crate::{ - pdu::{ - read_pdu, write_pdu, AbortRQSource, AssociationAC, AssociationRJ, AssociationRQ, Pdu, - PresentationContextProposed, PresentationContextNegotiated, PresentationContextResultReason, - ReadPduSnafu, UserIdentity, UserIdentityType, UserVariableItem, DEFAULT_MAX_PDU, - MAXIMUM_PDU_SIZE, - }, - AeAddr, IMPLEMENTATION_CLASS_UID, IMPLEMENTATION_VERSION_NAME, + AeAddr, IMPLEMENTATION_CLASS_UID, IMPLEMENTATION_VERSION_NAME, association::{ + ConnectSnafu, MissingAbstractSyntaxSnafu, NegotiatedOptions, NoAcceptedPresentationContextsSnafu, ProtocolVersionMismatchSnafu, RejectedSnafu, SendPduSnafu, SendTooLongPduSnafu, SetReadTimeoutSnafu, SetWriteTimeoutSnafu, ToAddressSnafu, UnexpectedPduSnafu, UnknownPduSnafu, WireSendSnafu, read_pdu_from_wire + }, pdu::{ + AbortRQSource, AssociationAC, AssociationRQ, DEFAULT_MAX_PDU, MAXIMUM_PDU_SIZE, Pdu, PresentationContextNegotiated, PresentationContextProposed, PresentationContextResultReason, UserIdentity, UserIdentityType, UserVariableItem, write_pdu + } }; -use snafu::{ensure, Backtrace, ResultExt, Snafu}; - -use bytes::Buf; +use snafu::{ensure, ResultExt}; use super::{ pdata::{PDataReader, PDataWriter}, - uid::trim_uid, + uid::trim_uid, Result }; -#[derive(Debug, Snafu)] -#[non_exhaustive] -pub enum Error { - /// missing abstract syntax to begin negotiation - MissingAbstractSyntax { backtrace: Backtrace }, - - /// could not convert to socket address - ToAddress { - source: std::io::Error, - backtrace: Backtrace, - }, - - /// could not connect to server - Connect { - source: std::io::Error, - backtrace: Backtrace, - }, - - /// Could not set tcp read timeout - SetReadTimeout { - source: std::io::Error, - backtrace: Backtrace, - }, - - /// Could not set tcp write timeout - SetWriteTimeout { - source: std::io::Error, - backtrace: Backtrace, - }, - - /// failed to send association request - SendRequest { - #[snafu(backtrace)] - source: crate::pdu::WriteError, - }, - - /// failed to receive association response - ReceiveResponse { - #[snafu(backtrace)] - source: crate::pdu::ReadError, - }, - - #[snafu(display("unexpected response from server `{:?}`", pdu))] - #[non_exhaustive] - UnexpectedResponse { - /// the PDU obtained from the server - pdu: Box, - }, - - #[snafu(display("unknown response from server `{:?}`", pdu))] - #[non_exhaustive] - UnknownResponse { - /// the PDU obtained from the server, of variant Unknown - pdu: Box, - }, - - #[snafu(display("protocol version mismatch: expected {}, got {}", expected, got))] - ProtocolVersionMismatch { - expected: u16, - got: u16, - backtrace: Backtrace, - }, - - #[snafu(display("association rejected by the server: {}", association_rj.source))] - Rejected { - association_rj: AssociationRJ, - backtrace: Backtrace, - }, - - /// no presentation contexts accepted by the server - NoAcceptedPresentationContexts { backtrace: Backtrace }, - - /// failed to send PDU message on wire - #[non_exhaustive] - WireSend { - source: std::io::Error, - backtrace: Backtrace, - }, - - /// Operation timed out - #[non_exhaustive] - Timeout { - source: std::io::Error, - backtrace: Backtrace, - }, - - #[snafu(display( - "PDU is too large ({} bytes) to be sent to the remote application entity", - length - ))] - #[non_exhaustive] - SendTooLongPdu { length: usize, backtrace: Backtrace }, - - /// failed to receive PDU message - #[non_exhaustive] - Receive { - #[snafu(backtrace)] - source: crate::pdu::ReadError, - }, - - #[snafu(display("Connection closed by peer"))] - ConnectionClosed, -} - -pub type Result = std::result::Result; - -/// Helper function to get a PDU from a reader. -/// -/// Chunks of data are read into `read_buffer`, -/// which should be passed in subsequent calls -/// to receive more PDUs from the same stream. -pub fn get_client_pdu( - reader: &mut R, - read_buffer: &mut BytesMut, - max_pdu_length: u32, - strict: bool, -) -> Result -where - R: Read, -{ - let mut reader = BufReader::new(reader); - let msg = loop { - let mut buf = Cursor::new(&read_buffer[..]); - // try to read a PDU according to what's in the buffer - match read_pdu(&mut buf, max_pdu_length, strict).context(ReceiveResponseSnafu)? { - Some(pdu) => { - read_buffer.advance(buf.position() as usize); - break pdu; - } - None => { - // Reset position - buf.set_position(0) - } - } - // Use BufReader to get similar behavior to AsyncRead read_buf - let recv = reader - .fill_buf() - .context(ReadPduSnafu) - .context(ReceiveSnafu)?; - let bytes_read = recv.len(); - read_buffer.extend_from_slice(recv); - reader.consume(bytes_read); - ensure!(bytes_read != 0, ConnectionClosedSnafu); - }; - Ok(msg) -} - /// A DICOM association builder for a client node. /// The final outcome is a [`ClientAssociation`]. /// @@ -188,8 +37,8 @@ where /// You can create either a blocking or non-blocking client by calling either /// `establish` or `establish_async` respectively. /// -/// > **⚠️ Warning:** It is highly recommended to set `timeout` to a reasonable value for the -/// > async client since there is _no_ default timeout on +/// > **⚠️ Warning:** It is highly recommended to set `read_timeout` and `write_timeout` to a reasonable +/// > value for the async client since there is _no_ default timeout on /// > [`tokio::net::TcpStream`] /// /// ## Basic usage @@ -571,13 +420,8 @@ impl<'a> ClientAssociationOptions<'a> { } } - fn establish_impl( - self, - ae_address: AeAddr, - ) -> Result> - where - T: ToSocketAddrs, - { + /// Construct the A-ASSOCIATE-RQ PDU given the options and the AE title. + fn create_a_associate_req(&'a self, ae_title: Option<&str>) -> Result<(Vec, Pdu)> { let ClientAssociationOptions { calling_ae_title, called_ae_title, @@ -585,17 +429,13 @@ impl<'a> ClientAssociationOptions<'a> { presentation_contexts, protocol_version, max_pdu_length, - strict, username, password, kerberos_service_ticket, saml_assertion, jwt, - read_timeout, - write_timeout, - connection_timeout, + .. } = self; - // fail if no presentation contexts were provided: they represent intent, // should not be omitted by the user ensure!( @@ -604,7 +444,7 @@ impl<'a> ClientAssociationOptions<'a> { ); // choose called AE title - let called_ae_title: &str = match (&called_ae_title, ae_address.ae_title()) { + let called_ae_title: &str = match (&called_ae_title, ae_title){ (Some(aec), Some(aet)) => { if aec != aet { tracing::warn!( @@ -619,7 +459,7 @@ impl<'a> ClientAssociationOptions<'a> { }; let presentation_contexts_proposed: Vec<_> = presentation_contexts - .into_iter() + .iter() .enumerate() .map(|(i, presentation_context)| PresentationContextProposed { id: (2 * i + 1) as u8, @@ -633,72 +473,35 @@ impl<'a> ClientAssociationOptions<'a> { .collect(); let mut user_variables = vec![ - UserVariableItem::MaxLength(max_pdu_length), + UserVariableItem::MaxLength(*max_pdu_length), UserVariableItem::ImplementationClassUID(IMPLEMENTATION_CLASS_UID.to_string()), UserVariableItem::ImplementationVersionName(IMPLEMENTATION_VERSION_NAME.to_string()), ]; if let Some(user_identity) = Self::determine_user_identity( - username, - password, - kerberos_service_ticket, - saml_assertion, - jwt, + username.as_deref(), + password.as_deref(), + kerberos_service_ticket.as_deref(), + saml_assertion.as_deref(), + jwt.as_deref(), ) { user_variables.push(UserVariableItem::UserIdentityItem(user_identity)); } - let msg = Pdu::AssociationRQ(AssociationRQ { - protocol_version, + Ok((presentation_contexts_proposed.clone(), Pdu::AssociationRQ(AssociationRQ { + protocol_version: *protocol_version, calling_ae_title: calling_ae_title.to_string(), called_ae_title: called_ae_title.to_string(), application_context_name: application_context_name.to_string(), - presentation_contexts: presentation_contexts_proposed.clone(), + presentation_contexts: presentation_contexts_proposed, user_variables, - }); - - let conn_result: Result = if let Some(timeout) = connection_timeout { - let addresses = ae_address.to_socket_addrs().context(ToAddressSnafu)?; - - let mut result: Result = - Result::Err(std::io::Error::from(std::io::ErrorKind::AddrNotAvailable)); - - for address in addresses { - result = std::net::TcpStream::connect_timeout(&address, timeout); - if result.is_ok() { - break; - } - } - result.context(ConnectSnafu) - } else { - std::net::TcpStream::connect(ae_address).context(ConnectSnafu) - }; - - let mut socket = conn_result?; - socket - .set_read_timeout(read_timeout) - .context(SetReadTimeoutSnafu)?; - socket - .set_write_timeout(write_timeout) - .context(SetWriteTimeoutSnafu)?; - let mut buffer: Vec = Vec::with_capacity(max_pdu_length as usize); - // send request - - write_pdu(&mut buffer, &msg).context(SendRequestSnafu)?; - socket.write_all(&buffer).context(WireSendSnafu)?; - buffer.clear(); - - // !!!(#589) Soundness issue: if the SCP sends more PDUs in quick succession, - // more data may live in `buf` which may be lost, - // corrupting the PDU reader stream. - let mut buf = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); - let msg = get_client_pdu(&mut socket, &mut buf, MAXIMUM_PDU_SIZE, self.strict)?; - if !buf.is_empty() { - tracing::warn!( - "Received more data than expected in the first PDU, further issues may arise" - ); - } + }))) + } + /// Process the A-ASSOCIATE-AC PDU received from the SCP. + /// + /// Returns the negotiated options for the association + fn process_a_association_resp(&self, msg: Pdu, presentation_contexts_proposed: &[PresentationContextProposed]) -> Result { match msg { Pdu::AssociationAC(AssociationAC { protocol_version: protocol_version_scp, @@ -709,9 +512,9 @@ impl<'a> ClientAssociationOptions<'a> { user_variables, }) => { ensure!( - protocol_version == protocol_version_scp, + self.protocol_version == protocol_version_scp, ProtocolVersionMismatchSnafu { - expected: protocol_version, + expected: self.protocol_version, got: protocol_version_scp, } ); @@ -749,27 +552,11 @@ impl<'a> ClientAssociationOptions<'a> { }) .collect(); if presentation_contexts.is_empty() { - // abort connection - let _ = write_pdu( - &mut buffer, - &Pdu::AbortRQ { - source: AbortRQSource::ServiceUser, - }, - ); - let _ = socket.write_all(&buffer); - buffer.clear(); return NoAcceptedPresentationContextsSnafu.fail(); } - Ok(ClientAssociation { + Ok(NegotiatedOptions{ presentation_contexts, - requestor_max_pdu_length: max_pdu_length, - acceptor_max_pdu_length, - socket, - buffer, - strict, - read_buffer: BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize), - read_timeout, - write_timeout, + peer_max_pdu_length: acceptor_max_pdu_length, user_variables, }) } @@ -778,18 +565,65 @@ impl<'a> ClientAssociationOptions<'a> { | pdu @ Pdu::ReleaseRQ | pdu @ Pdu::AssociationRQ { .. } | pdu @ Pdu::PData { .. } - | pdu @ Pdu::ReleaseRP => { - // abort connection - let _ = write_pdu( - &mut buffer, - &Pdu::AbortRQ { - source: AbortRQSource::ServiceUser, - }, - ); - let _ = socket.write_all(&buffer); - UnexpectedResponseSnafu { pdu }.fail() + | pdu @ Pdu::ReleaseRP => UnexpectedPduSnafu { pdu }.fail(), + pdu @ Pdu::Unknown { .. } => UnknownPduSnafu { pdu }.fail() + } + } + + fn simple_tcp_connection(&self, + ae_address: AeAddr, + ) -> Result where T: ToSocketAddrs{ + let conn_result: Result = if let Some(timeout) = self.connection_timeout { + let addresses = ae_address.to_socket_addrs().context(ToAddressSnafu)?; + + let mut result: Result = + Result::Err(std::io::Error::from(std::io::ErrorKind::AddrNotAvailable)); + + for address in addresses { + result = std::net::TcpStream::connect_timeout(&address, timeout); + if result.is_ok() { + break; + } } - pdu @ Pdu::Unknown { .. } => { + result.context(ConnectSnafu) + } else { + std::net::TcpStream::connect(ae_address).context(ConnectSnafu) + }; + + let socket = conn_result?; + socket + .set_read_timeout(self.read_timeout) + .context(SetReadTimeoutSnafu)?; + socket + .set_write_timeout(self.write_timeout) + .context(SetWriteTimeoutSnafu)?; + + Ok(socket) + + } + + /// Establish the association with the given AE address. + fn establish_impl( + self, + ae_address: AeAddr, + ) -> Result> + where + T: ToSocketAddrs, + { + let (pc_proposed, a_associate) = self.create_a_associate_req(ae_address.ae_title())?; + let mut socket = self.simple_tcp_connection(ae_address)?; + let mut buffer: Vec = Vec::with_capacity(self.max_pdu_length as usize); + // send request + + write_pdu(&mut buffer, &a_associate).context(SendPduSnafu)?; + socket.write_all(&buffer).context(WireSendSnafu)?; + buffer.clear(); + + let mut buf = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); + let resp = read_pdu_from_wire(&mut socket, &mut buf, MAXIMUM_PDU_SIZE, self.strict)?; + let negotiated_options = self.process_a_association_resp(resp, &pc_proposed); + match negotiated_options { + Err(e) => { // abort connection let _ = write_pdu( &mut buffer, @@ -798,7 +632,23 @@ impl<'a> ClientAssociationOptions<'a> { }, ); let _ = socket.write_all(&buffer); - UnknownResponseSnafu { pdu }.fail() + buffer.clear(); + Err(e) + }, + Ok(NegotiatedOptions{presentation_contexts, peer_max_pdu_length, user_variables}) => { + Ok(ClientAssociation { + presentation_contexts, + requestor_max_pdu_length: self.max_pdu_length, + acceptor_max_pdu_length: peer_max_pdu_length, + socket, + buffer, + strict: self.strict, + // Fixes #589, instead of creating a new buffer, we pass the existing buffer into the Association object. + read_buffer: buf, + read_timeout: self.read_timeout, + write_timeout: self.write_timeout, + user_variables, + }) } } } @@ -980,7 +830,7 @@ where /// Send a PDU message to the other intervenient. pub fn send(&mut self, msg: &Pdu) -> Result<()> { self.buffer.clear(); - write_pdu(&mut self.buffer, msg).context(SendRequestSnafu)?; + write_pdu(&mut self.buffer, msg).context(SendPduSnafu)?; if self.buffer.len() > self.acceptor_max_pdu_length as usize { return SendTooLongPduSnafu { length: self.buffer.len(), @@ -992,34 +842,7 @@ where /// Read a PDU message from the other intervenient. pub fn receive(&mut self) -> Result { - use std::io::{BufRead, BufReader, Cursor}; - - let mut reader = BufReader::new(&mut self.socket); - - loop { - let mut buf = Cursor::new(&self.read_buffer[..]); - match read_pdu(&mut buf, self.acceptor_max_pdu_length, self.strict) - .context(ReceiveResponseSnafu)? - { - Some(pdu) => { - self.read_buffer.advance(buf.position() as usize); - return Ok(pdu); - } - None => { - // Reset position - buf.set_position(0) - } - } - // Use BufReader to get similar behavior to AsyncRead read_buf - let recv = reader - .fill_buf() - .context(ReadPduSnafu) - .context(ReceiveSnafu)? - .to_vec(); - reader.consume(recv.len()); - self.read_buffer.extend_from_slice(&recv); - ensure!(!recv.is_empty(), ConnectionClosedSnafu); - } + read_pdu_from_wire(&mut self.socket, &mut self.read_buffer, self.acceptor_max_pdu_length, self.strict) } /// Gracefully terminate the association by exchanging release messages @@ -1100,8 +923,8 @@ where | pdu @ Pdu::AssociationRJ { .. } | pdu @ Pdu::AssociationRQ { .. } | pdu @ Pdu::PData { .. } - | pdu @ Pdu::ReleaseRQ => return UnexpectedResponseSnafu { pdu }.fail(), - pdu @ Pdu::Unknown { .. } => return UnknownResponseSnafu { pdu }.fail(), + | pdu @ Pdu::ReleaseRQ => return UnexpectedPduSnafu { pdu }.fail(), + pdu @ Pdu::Unknown { .. } => return UnknownPduSnafu { pdu }.fail(), } Ok(()) } @@ -1121,64 +944,30 @@ where #[cfg(feature = "async")] pub mod non_blocking { - use std::{convert::TryInto, future::Future, io::Cursor, time::Duration}; + use std::{convert::TryInto, future::Future, time::Duration}; use crate::{ association::{ client::{ - ConnectSnafu, ConnectionClosedSnafu, MissingAbstractSyntaxSnafu, - NoAcceptedPresentationContextsSnafu, ProtocolVersionMismatchSnafu, - ReceiveResponseSnafu, ReceiveSnafu, RejectedSnafu, SendRequestSnafu, - ToAddressSnafu, UnexpectedResponseSnafu, UnknownResponseSnafu, WireSendSnafu, - }, - pdata::non_blocking::{AsyncPDataWriter, PDataReader}, + ConnectSnafu, NegotiatedOptions, ToAddressSnafu, + WireSendSnafu + }, pdata::non_blocking::{AsyncPDataWriter, PDataReader}, read_pdu_from_wire_async, SendPduSnafu, UnexpectedPduSnafu, UnknownPduSnafu }, pdu::{ - AbortRQSource, AssociationAC, AssociationRQ, PresentationContextProposed, - PresentationContextNegotiated, PresentationContextResultReason, ReadPduSnafu, - UserVariableItem, DEFAULT_MAX_PDU, MAXIMUM_PDU_SIZE, + AbortRQSource, + MAXIMUM_PDU_SIZE, }, - read_pdu, write_pdu, AeAddr, Pdu, IMPLEMENTATION_CLASS_UID, IMPLEMENTATION_VERSION_NAME, + write_pdu, AeAddr, Pdu, }; - - use super::{ - ClientAssociation, ClientAssociationOptions, CloseSocket, Release, Result, - SendTooLongPduSnafu, TimeoutSnafu, + use super::{CloseSocket, Release, Result}; + use crate::association::{ + ClientAssociation, ClientAssociationOptions, + SendTooLongPduSnafu, TimeoutSnafu }; - use bytes::{Buf, BytesMut}; - use snafu::{ensure, ResultExt}; - use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; - - pub async fn get_client_pdu_async( - reader: &mut R, - max_pdu_length: u32, - strict: bool, - ) -> Result { - // receive response - use tokio::io::AsyncReadExt; - let mut read_buffer = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); - - let msg = loop { - let mut buf = Cursor::new(&read_buffer[..]); - match read_pdu(&mut buf, max_pdu_length, strict).context(ReceiveResponseSnafu)? { - Some(pdu) => { - read_buffer.advance(buf.position() as usize); - break pdu; - } - None => { - // Reset position - buf.set_position(0) - } - } - let recv = reader - .read_buf(&mut read_buffer) - .await - .context(ReadPduSnafu) - .context(ReceiveSnafu)?; - ensure!(recv > 0, ConnectionClosedSnafu); - }; - Ok(msg) - } + + use bytes::BytesMut; + use snafu::ResultExt; + use tokio::io::AsyncWriteExt; // Helper function to perform an operation with timeout async fn timeout( @@ -1195,96 +984,13 @@ pub mod non_blocking { } } - impl ClientAssociationOptions<'_> { - async fn establish_impl_async( - self, - ae_address: AeAddr, - ) -> Result> + impl<'a> ClientAssociationOptions<'a> { + pub(crate) async fn async_simple_tcp_connection(&self, ae_address: AeAddr) -> Result where - T: tokio::net::ToSocketAddrs, + T: tokio::net::ToSocketAddrs { - let ClientAssociationOptions { - calling_ae_title, - called_ae_title, - application_context_name, - presentation_contexts, - protocol_version, - max_pdu_length, - strict, - username, - password, - kerberos_service_ticket, - saml_assertion, - jwt, - read_timeout, - write_timeout, - connection_timeout, - } = self; - - // fail if no presentation contexts were provided: they represent intent, - // should not be omitted by the user - ensure!( - !presentation_contexts.is_empty(), - MissingAbstractSyntaxSnafu - ); - - // choose called AE title - let called_ae_title: &str = match (&called_ae_title, ae_address.ae_title()) { - (Some(aec), Some(aet)) => { - if aec != aet { - tracing::warn!( - "Option `called_ae_title` overrides the AE title from `{aet}` to `{aec}`" - ); - } - aec - } - (Some(aec), None) => aec, - (None, Some(aec)) => aec, - (None, None) => "ANY-SCP", - }; - - let presentation_contexts_proposed: Vec<_> = presentation_contexts - .into_iter() - .enumerate() - .map(|(i, presentation_context)| PresentationContextProposed { - id: (2 * i + 1) as u8, - abstract_syntax: presentation_context.0.to_string(), - transfer_syntaxes: presentation_context - .1 - .iter() - .map(|uid| uid.to_string()) - .collect(), - }) - .collect(); - - let mut user_variables = vec![ - UserVariableItem::MaxLength(max_pdu_length), - UserVariableItem::ImplementationClassUID(IMPLEMENTATION_CLASS_UID.to_string()), - UserVariableItem::ImplementationVersionName( - IMPLEMENTATION_VERSION_NAME.to_string(), - ), - ]; - - if let Some(user_identity) = Self::determine_user_identity( - username, - password, - kerberos_service_ticket, - saml_assertion, - jwt, - ) { - user_variables.push(UserVariableItem::UserIdentityItem(user_identity)); - } - - let msg = Pdu::AssociationRQ(AssociationRQ { - protocol_version, - calling_ae_title: calling_ae_title.to_string(), - called_ae_title: called_ae_title.to_string(), - application_context_name: application_context_name.to_string(), - presentation_contexts: presentation_contexts_proposed.clone(), - user_variables, - }); let conn_result: Result = - if let Some(timeout) = connection_timeout { + if let Some(timeout) = self.connection_timeout { let addresses = tokio::net::lookup_host(ae_address.socket_addr()) .await .context(ToAddressSnafu)?; @@ -1313,119 +1019,38 @@ pub mod non_blocking { .context(ConnectSnafu) }; - let mut socket = conn_result?; - let mut buffer: Vec = Vec::with_capacity(max_pdu_length as usize); + conn_result + + } + + async fn establish_impl_async( + self, + ae_address: AeAddr, + ) -> Result> + where + T: tokio::net::ToSocketAddrs, + { + let (pc_proposed, a_associate) = self.create_a_associate_req(ae_address.ae_title())?; + let mut socket = self.async_simple_tcp_connection(ae_address).await?; + let mut buffer: Vec = Vec::with_capacity(self.max_pdu_length as usize); // send request - write_pdu(&mut buffer, &msg).context(SendRequestSnafu)?; - timeout(write_timeout, async { + write_pdu(&mut buffer, &a_associate).context(SendPduSnafu)?; + timeout(self.write_timeout, async { socket.write_all(&buffer).await.context(WireSendSnafu)?; Ok(()) }) .await?; buffer.clear(); - let msg = timeout(read_timeout, async { - get_client_pdu_async(&mut socket, MAXIMUM_PDU_SIZE, strict).await + + let mut buf = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); + let resp = timeout(self.read_timeout, async { + read_pdu_from_wire_async(&mut socket, &mut buf, MAXIMUM_PDU_SIZE, self.strict).await }) .await?; - - match msg { - Pdu::AssociationAC(AssociationAC { - protocol_version: protocol_version_scp, - application_context_name: _, - presentation_contexts: presentation_contexts_scp, - calling_ae_title: _, - called_ae_title: _, - user_variables, - }) => { - ensure!( - protocol_version == protocol_version_scp, - ProtocolVersionMismatchSnafu { - expected: protocol_version, - got: protocol_version_scp, - } - ); - - let acceptor_max_pdu_length = user_variables - .iter() - .find_map(|item| match item { - UserVariableItem::MaxLength(len) => Some(*len), - _ => None, - }) - .unwrap_or(DEFAULT_MAX_PDU); - - // treat 0 as the maximum size admitted by the standard - let acceptor_max_pdu_length = if acceptor_max_pdu_length == 0 { - MAXIMUM_PDU_SIZE - } else { - acceptor_max_pdu_length - }; - - let presentation_contexts: Vec<_> = presentation_contexts_scp - .into_iter() - .filter(|c| c.reason == PresentationContextResultReason::Acceptance - && presentation_contexts_proposed.iter().any(|p| p.id == c.id)) - .map(|c| { - let pcp = presentation_contexts_proposed - .iter() - .find(|pc| pc.id == c.id) - .unwrap(); - PresentationContextNegotiated { - id: c.id, - reason: c.reason, - transfer_syntax: c.transfer_syntax, - abstract_syntax: pcp.abstract_syntax.clone(), - } - }) - .collect(); - if presentation_contexts.is_empty() { - // abort connection - let _ = write_pdu( - &mut buffer, - &Pdu::AbortRQ { - source: AbortRQSource::ServiceUser, - }, - ); - let _ = timeout(write_timeout, async { - socket.write_all(&buffer).await.context(WireSendSnafu) - }) - .await; - buffer.clear(); - return NoAcceptedPresentationContextsSnafu.fail(); - } - Ok(ClientAssociation { - presentation_contexts, - requestor_max_pdu_length: max_pdu_length, - acceptor_max_pdu_length, - socket, - buffer, - strict, - read_timeout, - write_timeout, - read_buffer: BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize), - user_variables - }) - } - Pdu::AssociationRJ(association_rj) => RejectedSnafu { association_rj }.fail(), - pdu @ Pdu::AbortRQ { .. } - | pdu @ Pdu::ReleaseRQ - | pdu @ Pdu::AssociationRQ { .. } - | pdu @ Pdu::PData { .. } - | pdu @ Pdu::ReleaseRP => { - // abort connection - let _ = write_pdu( - &mut buffer, - &Pdu::AbortRQ { - source: AbortRQSource::ServiceUser, - }, - ); - let _ = timeout(write_timeout, async { - socket.write_all(&buffer).await.context(WireSendSnafu) - }) - .await; - UnexpectedResponseSnafu { pdu }.fail() - } - pdu @ Pdu::Unknown { .. } => { + let negotiated_options = self.process_a_association_resp(resp, &pc_proposed); + match negotiated_options { + Err(e) => { // abort connection let _ = write_pdu( &mut buffer, @@ -1433,11 +1058,25 @@ pub mod non_blocking { source: AbortRQSource::ServiceUser, }, ); - let _ = timeout(write_timeout, async { - socket.write_all(&buffer).await.context(WireSendSnafu) + socket.write_all(&buffer).await + .context(WireSendSnafu)?; + buffer.clear(); + Err(e) + }, + Ok(NegotiatedOptions{presentation_contexts, peer_max_pdu_length, user_variables}) => { + Ok(ClientAssociation { + presentation_contexts, + requestor_max_pdu_length: self.max_pdu_length, + acceptor_max_pdu_length: peer_max_pdu_length, + socket, + buffer, + strict: self.strict, + // Fixes #589, instead of creating a new buffer, we pass the existing buffer into the Association object. + read_buffer: buf, + read_timeout: self.read_timeout, + write_timeout: self.write_timeout, + user_variables, }) - .await; - UnknownResponseSnafu { pdu }.fail() } } } @@ -1500,7 +1139,7 @@ pub mod non_blocking { /// Send a PDU message to the other intervenient. pub async fn send(&mut self, msg: &Pdu) -> Result<()> { self.buffer.clear(); - write_pdu(&mut self.buffer, msg).context(SendRequestSnafu)?; + write_pdu(&mut self.buffer, msg).context(SendPduSnafu)?; if self.buffer.len() > self.acceptor_max_pdu_length as usize { return SendTooLongPduSnafu { length: self.buffer.len(), @@ -1519,28 +1158,12 @@ pub mod non_blocking { /// Read a PDU message from the other intervenient. pub async fn receive(&mut self) -> Result { timeout(self.read_timeout, async { - loop { - let mut buf = Cursor::new(&self.read_buffer[..]); - match read_pdu(&mut buf, self.requestor_max_pdu_length, self.strict) - .context(ReceiveResponseSnafu)? - { - Some(pdu) => { - self.read_buffer.advance(buf.position() as usize); - return Ok(pdu); - } - None => { - // Reset position - buf.set_position(0) - } - } - let recv = self - .socket - .read_buf(&mut self.read_buffer) - .await - .context(ReadPduSnafu) - .context(ReceiveSnafu)?; - ensure!(recv > 0, ConnectionClosedSnafu); - } + read_pdu_from_wire_async( + &mut self.socket, + &mut self.read_buffer, + self.acceptor_max_pdu_length, + self.strict + ).await }) .await } @@ -1608,21 +1231,7 @@ pub mod non_blocking { async fn release_impl(&mut self) -> Result<()> { let pdu = Pdu::ReleaseRQ; self.send(&pdu).await?; - use tokio::io::AsyncReadExt; - let mut read_buffer = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); - - let pdu = loop { - if let Ok(Some(pdu)) = read_pdu(&mut read_buffer, MAXIMUM_PDU_SIZE, self.strict) { - break pdu; - } - let recv = self - .socket - .read_buf(&mut read_buffer) - .await - .context(ReadPduSnafu) - .context(ReceiveSnafu)?; - ensure!(recv > 0, ConnectionClosedSnafu); - }; + let pdu = self.receive().await?; match pdu { Pdu::ReleaseRP => {} pdu @ Pdu::AbortRQ { .. } @@ -1630,8 +1239,8 @@ pub mod non_blocking { | pdu @ Pdu::AssociationRJ { .. } | pdu @ Pdu::AssociationRQ { .. } | pdu @ Pdu::PData { .. } - | pdu @ Pdu::ReleaseRQ => return UnexpectedResponseSnafu { pdu }.fail(), - pdu @ Pdu::Unknown { .. } => return UnknownResponseSnafu { pdu }.fail(), + | pdu @ Pdu::ReleaseRQ => return UnexpectedPduSnafu { pdu }.fail(), + pdu @ Pdu::Unknown { .. } => return UnknownPduSnafu { pdu }.fail(), } Ok(()) } @@ -1665,3 +1274,170 @@ pub mod non_blocking { } } } + + +#[cfg(test)] +mod tests { + use super::*; + #[cfg(feature = "async")] + use tokio::io::AsyncWriteExt; + #[cfg(feature = "async")] + use crate::association::read_pdu_from_wire_async; + + + impl<'a> ClientAssociationOptions<'a> { + pub(crate) fn establish_with_extra_pdus( + &self, ae_address: AeAddr, extra_pdus: Vec + ) -> Result> + where T: ToSocketAddrs + { + let (pc_proposed, a_associate) = self.create_a_associate_req(ae_address.ae_title())?; + let mut socket = self.simple_tcp_connection(ae_address)?; + let mut buffer: Vec = Vec::with_capacity(self.max_pdu_length as usize); + // send request + + write_pdu(&mut buffer, &a_associate).context(SendPduSnafu)?; + for pdu in extra_pdus { + write_pdu(&mut buffer, &pdu).context(SendPduSnafu)?; + } + socket.write_all(&buffer).context(WireSendSnafu)?; + buffer.clear(); + + let mut buf = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); + let resp = read_pdu_from_wire(&mut socket, &mut buf, MAXIMUM_PDU_SIZE, self.strict)?; + let NegotiatedOptions{ + presentation_contexts, + peer_max_pdu_length, + user_variables + } = self.process_a_association_resp(resp, &pc_proposed) + .expect("Failed to process a associate response"); + Ok(ClientAssociation { + presentation_contexts, + requestor_max_pdu_length: self.max_pdu_length, + acceptor_max_pdu_length: peer_max_pdu_length, + socket, + buffer, + strict: self.strict, + // Fixes #589, instead of creating a new buffer, we pass the existing buffer into the Association object. + read_buffer: buf, + read_timeout: self.read_timeout, + write_timeout: self.write_timeout, + user_variables, + }) + } + + #[cfg(feature = "async")] + pub(crate) async fn establish_with_extra_pdus_async( + &self, ae_address: AeAddr, extra_pdus: Vec + ) -> Result> + where T: tokio::net::ToSocketAddrs + { + let (pc_proposed, a_associate) = self.create_a_associate_req(ae_address.ae_title())?; + let mut socket = self.async_simple_tcp_connection(ae_address).await?; + let mut buffer: Vec = Vec::with_capacity(self.max_pdu_length as usize); + // send request + + write_pdu(&mut buffer, &a_associate).context(SendPduSnafu)?; + for pdu in extra_pdus { + write_pdu(&mut buffer, &pdu).context(SendPduSnafu)?; + } + socket.write_all(&buffer).await.context(WireSendSnafu)?; + buffer.clear(); + + let mut buf = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); + let resp = read_pdu_from_wire_async(&mut socket, &mut buf, MAXIMUM_PDU_SIZE, self.strict).await?; + let NegotiatedOptions{ + presentation_contexts, + peer_max_pdu_length, + user_variables + } = self.process_a_association_resp(resp, &pc_proposed) + .expect("Failed to process a associate response"); + Ok(ClientAssociation { + presentation_contexts, + requestor_max_pdu_length: self.max_pdu_length, + acceptor_max_pdu_length: peer_max_pdu_length, + socket, + buffer, + strict: self.strict, + // Fixes #589, instead of creating a new buffer, we pass the existing buffer into the Association object. + read_buffer: buf, + read_timeout: self.read_timeout, + write_timeout: self.write_timeout, + user_variables, + }) + } + + // Broken implementation of server establish which reproduces behavior that #589 introduced + pub fn broken_establish( + &self, ae_address: AeAddr + ) -> Result> + where T: ToSocketAddrs + { + let (pc_proposed, a_associate) = self.create_a_associate_req(ae_address.ae_title())?; + let mut socket = self.simple_tcp_connection(ae_address)?; + let mut buffer: Vec = Vec::with_capacity(self.max_pdu_length as usize); + // send request + write_pdu(&mut buffer, &a_associate).context(SendPduSnafu)?; + socket.write_all(&buffer).context(WireSendSnafu)?; + buffer.clear(); + + let mut buf = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); + let resp = read_pdu_from_wire(&mut socket, &mut buf, MAXIMUM_PDU_SIZE, self.strict)?; + let NegotiatedOptions{ + presentation_contexts, + peer_max_pdu_length, + user_variables + } = self.process_a_association_resp(resp, &pc_proposed) + .expect("Failed to process a associate response"); + Ok(ClientAssociation { + presentation_contexts, + requestor_max_pdu_length: self.max_pdu_length, + acceptor_max_pdu_length: peer_max_pdu_length, + socket, + buffer, + strict: self.strict, + read_buffer: BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize), + read_timeout: self.read_timeout, + write_timeout: self.write_timeout, + user_variables, + }) + } + + #[cfg(feature = "async")] + // Broken implementation of server establish which reproduces behavior that #589 introduced + pub async fn broken_establish_async( + &self, ae_address: AeAddr + ) -> Result> + where T: tokio::net::ToSocketAddrs + { + let (pc_proposed, a_associate) = self.create_a_associate_req(ae_address.ae_title())?; + let mut socket = self.async_simple_tcp_connection(ae_address).await?; + let mut buffer: Vec = Vec::with_capacity(self.max_pdu_length as usize); + // send request + write_pdu(&mut buffer, &a_associate).context(SendPduSnafu)?; + socket.write_all(&buffer).await.context(WireSendSnafu)?; + buffer.clear(); + + let mut buf = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); + let resp = read_pdu_from_wire_async(&mut socket, &mut buf, MAXIMUM_PDU_SIZE, self.strict).await?; + let NegotiatedOptions{ + presentation_contexts, + peer_max_pdu_length, + user_variables + } = self.process_a_association_resp(resp, &pc_proposed) + .expect("Failed to process a associate response"); + Ok(ClientAssociation { + presentation_contexts, + requestor_max_pdu_length: self.max_pdu_length, + acceptor_max_pdu_length: peer_max_pdu_length, + socket, + buffer, + strict: self.strict, + read_buffer: BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize), + read_timeout: self.read_timeout, + write_timeout: self.write_timeout, + user_variables, + }) + } + } +} diff --git a/ul/src/association/mod.rs b/ul/src/association/mod.rs index 58aa2bc1..a10716d2 100644 --- a/ul/src/association/mod.rs +++ b/ul/src/association/mod.rs @@ -16,6 +16,8 @@ //! //! //! [1]: std::net::TcpStream +#[cfg(test)] +mod tests; pub mod client; pub mod server; @@ -23,8 +25,212 @@ mod uid; pub(crate) mod pdata; +use std::{backtrace::Backtrace, io::{BufRead, BufReader, Cursor, Read}}; + +use bytes::{Buf, BytesMut}; pub use client::{ClientAssociation, ClientAssociationOptions}; #[cfg(feature = "async")] pub use pdata::non_blocking::AsyncPDataWriter; pub use pdata::{PDataReader, PDataWriter}; pub use server::{ServerAssociation, ServerAssociationOptions}; +use snafu::{ensure, Snafu, ResultExt}; + +use crate::{Pdu, pdu::{self, AssociationRJ, PresentationContextNegotiated, ReadPduSnafu, UserVariableItem}}; + +pub(crate) struct NegotiatedOptions{ + peer_max_pdu_length: u32, + user_variables: Vec, + presentation_contexts: Vec, +} + +type Result = std::result::Result; + +#[derive(Debug, Snafu)] +#[non_exhaustive] +pub enum Error { + /// missing abstract syntax to begin negotiation + MissingAbstractSyntax { backtrace: Backtrace }, + + /// could not convert to sockeDUt address + ToAddress { + source: std::io::Error, + backtrace: Backtrace, + }, + + /// could not connect to server + Connect { + source: std::io::Error, + backtrace: Backtrace, + }, + + /// Could not set tcp read timeout + SetReadTimeout { + source: std::io::Error, + backtrace: Backtrace, + }, + + /// Could not set tcp write timeout + SetWriteTimeout { + source: std::io::Error, + backtrace: Backtrace, + }, + + /// failed to send association request + #[snafu(display("failed to send pdu: {}", source))] + SendPdu{ + #[snafu(backtrace)] + source: crate::pdu::WriteError, + }, + + /// failed to receive association response + #[snafu(display("failed to receive pdu: {}", source))] + ReceivePdu { + #[snafu(backtrace)] + source: crate::pdu::ReadError, + }, + + #[snafu(display("unexpected response from peer `{:?}`", pdu))] + #[non_exhaustive] + UnexpectedPdu { + /// the PDU obtained from the server + pdu: Box, + }, + + #[snafu(display("unknown response from peer `{:?}`", pdu))] + #[non_exhaustive] + UnknownPdu { + /// the PDU obtained from the server, of variant Unknown + pdu: Box, + }, + + #[snafu(display("protocol version mismatch: expected {}, got {}", expected, got))] + ProtocolVersionMismatch { + expected: u16, + got: u16, + backtrace: Backtrace, + }, + + // Association rejected by the server + #[snafu(display("association rejected {}", association_rj.source))] + Rejected { + association_rj: AssociationRJ, + backtrace: Backtrace, + }, + + /// association aborted + Aborted { backtrace: Backtrace }, + + /// no presentation contexts accepted by the server + NoAcceptedPresentationContexts { backtrace: Backtrace }, + + /// failed to send PDU message on wire + #[non_exhaustive] + WireSend { + source: std::io::Error, + backtrace: Backtrace, + }, + + /// failed to read PDU message from wire + #[non_exhaustive] + WireRead { + source: std::io::Error, + backtrace: Backtrace, + }, + + /// Operation timed out + #[non_exhaustive] + Timeout { + source: std::io::Error, + backtrace: Backtrace, + }, + + #[snafu(display( + "PDU is too large ({} bytes) to be sent to the remote application entity", + length + ))] + #[non_exhaustive] + SendTooLongPdu { length: usize, backtrace: Backtrace }, + + #[snafu(display("Connection closed by peer"))] + ConnectionClosed, +} + +/// Helper function to get a PDU from a reader. +/// +/// Chunks of data are read into `read_buffer`, +/// which should be passed in subsequent calls +/// to receive more PDUs from the same stream. +pub fn read_pdu_from_wire( + reader: &mut R, + read_buffer: &mut BytesMut, + max_pdu_length: u32, + strict: bool, +) -> Result +where + R: Read, +{ + let mut reader = BufReader::new(reader); + let msg = loop { + let mut buf = Cursor::new(&read_buffer[..]); + // try to read a PDU according to what's in the buffer + match pdu::read_pdu(&mut buf, max_pdu_length, strict).context(ReceivePduSnafu)? { + Some(pdu) => { + read_buffer.advance(buf.position() as usize); + break pdu; + } + None => { + // Reset position + buf.set_position(0) + } + } + // Use BufReader to get similar behavior to AsyncRead read_buf + let recv = reader + .fill_buf() + .context(ReadPduSnafu) + .context(ReceivePduSnafu)?; + let bytes_read = recv.len(); + read_buffer.extend_from_slice(recv); + reader.consume(bytes_read); + ensure!(bytes_read != 0, ConnectionClosedSnafu); + }; + Ok(msg) +} + +#[cfg(feature = "async")] +use tokio::io::{AsyncRead, AsyncReadExt}; + +/// Helper function to get a PDU from an async reader. +/// +/// Chunks of data are read into `read_buffer`, +/// which should be passed in subsequent calls +/// to receive more PDUs from the same stream. +#[cfg(feature = "async")] +pub async fn read_pdu_from_wire_async( + reader: &mut R, + read_buffer: &mut BytesMut, + max_pdu_length: u32, + strict: bool, +) -> Result { + // receive response + + let msg = loop { + let mut buf = Cursor::new(&read_buffer[..]); + match pdu::read_pdu(&mut buf, max_pdu_length, strict).context(ReceivePduSnafu)? { + Some(pdu) => { + read_buffer.advance(buf.position() as usize); + break pdu; + } + None => { + // Reset position + buf.set_position(0) + } + } + let recv = reader + .read_buf(read_buffer) + .await + .context(ReadPduSnafu) + .context(ReceivePduSnafu)?; + ensure!(recv > 0, ConnectionClosedSnafu); + }; + Ok(msg) +} \ No newline at end of file diff --git a/ul/src/association/server.rs b/ul/src/association/server.rs index d1eeebdf..8d3298c8 100644 --- a/ul/src/association/server.rs +++ b/ul/src/association/server.rs @@ -4,23 +4,29 @@ //! in which this application entity listens to incoming association requests. //! See [`ServerAssociationOptions`] //! for details and examples on how to create an association. -use bytes::{Buf, BytesMut}; -use std::io::{BufRead, BufReader}; +use bytes::BytesMut; use std::time::Duration; -use std::{borrow::Cow, io::Cursor}; +use std::borrow::Cow; use std::{io::Write, net::TcpStream}; use dicom_encoding::transfer_syntax::TransferSyntaxIndex; use dicom_transfer_syntax_registry::TransferSyntaxRegistry; -use snafu::{ensure, Backtrace, ResultExt, Snafu}; +use snafu::{ensure, ResultExt}; +use crate::association::{ + read_pdu_from_wire, AbortedSnafu, + MissingAbstractSyntaxSnafu, RejectedSnafu, SendPduSnafu, + SendTooLongPduSnafu, SetReadTimeoutSnafu, SetWriteTimeoutSnafu, + UnexpectedPduSnafu, UnknownPduSnafu, WireSendSnafu +}; +use crate::association::NegotiatedOptions; +use crate::pdu::PresentationContextNegotiated; use crate::{ pdu::{ - read_pdu, write_pdu, AbortRQServiceProviderReason, AbortRQSource, AssociationAC, + write_pdu, AbortRQServiceProviderReason, AbortRQSource, AssociationAC, AssociationRJ, AssociationRJResult, AssociationRJServiceUserReason, AssociationRJSource, - AssociationRQ, Pdu, PresentationContextResult, PresentationContextNegotiated, - PresentationContextResultReason, ReadPduSnafu, UserIdentity, UserVariableItem, - DEFAULT_MAX_PDU, MAXIMUM_PDU_SIZE, + AssociationRQ, Pdu, PresentationContextResult, PresentationContextResultReason, + UserIdentity, UserVariableItem, DEFAULT_MAX_PDU, MAXIMUM_PDU_SIZE, }, IMPLEMENTATION_CLASS_UID, IMPLEMENTATION_VERSION_NAME, }; @@ -28,90 +34,9 @@ use crate::{ use super::{ pdata::{PDataReader, PDataWriter}, uid::trim_uid, + Error }; -#[derive(Debug, Snafu)] -#[non_exhaustive] -pub enum Error { - /// missing at least one abstract syntax to accept negotiations - MissingAbstractSyntax { backtrace: Backtrace }, - - /// failed to receive association request - ReceiveRequest { - #[snafu(backtrace)] - source: crate::pdu::ReadError, - }, - - /// failed to send association response - SendResponse { - #[snafu(backtrace)] - source: crate::pdu::WriteError, - }, - - /// failed to prepare PDU - Send { - #[snafu(backtrace)] - source: crate::pdu::WriteError, - }, - /// Failed to read from the wire - WireRead { - source: std::io::Error, - backtrace: Backtrace, - }, - /// failed to send PDU over the wire - WireSend { - source: std::io::Error, - backtrace: Backtrace, - }, - - /// failed to receive PDU - Receive { - #[snafu(backtrace)] - source: crate::pdu::ReadError, - }, - - #[snafu(display("unexpected request from SCU `{:?}`", pdu))] - #[non_exhaustive] - UnexpectedRequest { - /// the PDU obtained from the server - pdu: Box, - }, - - #[snafu(display("unknown request from SCU `{:?}`", pdu))] - #[non_exhaustive] - UnknownRequest { - /// the PDU obtained from the server, of variant Unknown - pdu: Box, - }, - - /// association rejected - Rejected { backtrace: Backtrace }, - - /// association aborted - Aborted { backtrace: Backtrace }, - - #[snafu(display( - "PDU is too large ({} bytes) to be sent to the remote application entity", - length - ))] - #[non_exhaustive] - SendTooLongPdu { length: usize, backtrace: Backtrace }, - #[snafu(display("Connection closed by peer"))] - ConnectionClosed, - - /// Could not set tcp read timeout - SetReadTimeout { - source: std::io::Error, - backtrace: Backtrace, - }, - - /// Could not set tcp write timeout - SetWriteTimeout { - source: std::io::Error, - backtrace: Backtrace, - }, -} - pub type Result = std::result::Result; /// Common interface for application entity access control policies. @@ -295,8 +220,10 @@ pub struct ServerAssociationOptions<'a, A> { strict: bool, /// whether to accept unknown abstract syntaxes promiscuous: bool, - /// Timeout for individual send/receive operations - timeout: Option, + /// TCP read timeout + read_timeout: Option, + /// TCP write timeout + write_timeout: Option, } impl Default for ServerAssociationOptions<'_, AcceptAny> { @@ -311,7 +238,8 @@ impl Default for ServerAssociationOptions<'_, AcceptAny> { max_pdu_length: DEFAULT_MAX_PDU, strict: true, promiscuous: false, - timeout: None, + read_timeout: None, + write_timeout: None, } } } @@ -362,7 +290,8 @@ where strict, promiscuous, ae_access_control: _, - timeout, + read_timeout, + write_timeout, } = self; ServerAssociationOptions { @@ -375,7 +304,8 @@ where max_pdu_length, strict, promiscuous, - timeout, + read_timeout, + write_timeout, } } @@ -432,55 +362,36 @@ where self } - /// Set the timeout for the underlying TCP socket - pub fn timeout(self, timeout: Duration) -> Self { + /// Set the read timeout for the underlying TCP socket + /// + /// This is used to set both the read and write timeout. + pub fn read_timeout(self, timeout: Duration) -> Self { Self { - timeout: Some(timeout), + read_timeout: Some(timeout), ..self } } - /// Negotiate an association with the given TCP stream. - pub fn establish(&self, mut socket: TcpStream) -> Result> { - ensure!( - !self.abstract_syntax_uids.is_empty() || self.promiscuous, - MissingAbstractSyntaxSnafu - ); - - let max_pdu_length = self.max_pdu_length; - socket - .set_read_timeout(self.timeout) - .context(SetReadTimeoutSnafu)?; - socket - .set_write_timeout(self.timeout) - .context(SetWriteTimeoutSnafu)?; - - let mut read_buffer = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); - let mut reader = BufReader::new(&mut socket); + /// Set the write timeout for the underlying TCP socket + pub fn write_timeout(self, timeout: Duration) -> Self { + Self { + write_timeout: Some(timeout), + ..self + } + } - let msg = loop { - let mut buf = Cursor::new(&read_buffer[..]); - match read_pdu(&mut buf, MAXIMUM_PDU_SIZE, self.strict).context(ReceiveRequestSnafu)? { - Some(pdu) => { - read_buffer.advance(buf.position() as usize); - break pdu; - } - None => { - // Reset position - buf.set_position(0) - } - } - // Use BufReader to get similar behavior to AsyncRead read_buf - let recv = reader - .fill_buf() - .context(ReadPduSnafu) - .context(ReceiveSnafu)? - .to_vec(); - reader.consume(recv.len()); - read_buffer.extend_from_slice(&recv); - ensure!(!recv.is_empty(), ConnectionClosedSnafu); - }; - let mut buffer: Vec = Vec::with_capacity(max_pdu_length as usize); + /// Process an association request PDU + /// + /// In the success case, returns + /// * Pdu to be written back to client + /// * Negotiated options + /// * Calling AE title + /// + /// In the error case, returns + /// * Pdu to be written back to client + /// * Error + #[allow(clippy::result_large_err)] + fn process_a_association_rq(&self, msg: Pdu) -> std::result::Result<(Pdu, NegotiatedOptions, String),(Pdu, Error)>{ match msg { Pdu::AssociationRQ(AssociationRQ { protocol_version, @@ -491,33 +402,25 @@ where user_variables, }) => { if protocol_version != self.protocol_version { - write_pdu( - &mut buffer, - &Pdu::AssociationRJ(AssociationRJ { - result: AssociationRJResult::Permanent, - source: AssociationRJSource::ServiceUser( - AssociationRJServiceUserReason::NoReasonGiven, - ), - }), - ) - .context(SendResponseSnafu)?; - socket.write_all(&buffer).context(WireSendSnafu)?; - return RejectedSnafu.fail(); + let association_rj= AssociationRJ { + result: AssociationRJResult::Permanent, + source: AssociationRJSource::ServiceUser( + AssociationRJServiceUserReason::NoReasonGiven, + ), + }; + let pdu = Pdu::AssociationRJ(association_rj.clone()); + return Err((pdu, RejectedSnafu{association_rj}.build())); } if application_context_name != self.application_context_name { - write_pdu( - &mut buffer, - &Pdu::AssociationRJ(AssociationRJ { - result: AssociationRJResult::Permanent, - source: AssociationRJSource::ServiceUser( - AssociationRJServiceUserReason::ApplicationContextNameNotSupported, - ), - }), - ) - .context(SendResponseSnafu)?; - socket.write_all(&buffer).context(WireSendSnafu)?; - return RejectedSnafu.fail(); + let association_rj = AssociationRJ { + result: AssociationRJResult::Permanent, + source: AssociationRJSource::ServiceUser( + AssociationRJServiceUserReason::ApplicationContextNameNotSupported, + ), + }; + let pdu = Pdu::AssociationRJ(association_rj.clone()); + return Err((pdu, RejectedSnafu{association_rj}.build())); } self.ae_access_control @@ -536,16 +439,12 @@ where ) .map(Ok) .unwrap_or_else(|reason| { - write_pdu( - &mut buffer, - &Pdu::AssociationRJ(AssociationRJ { - result: AssociationRJResult::Permanent, - source: AssociationRJSource::ServiceUser(reason), - }), - ) - .context(SendResponseSnafu)?; - socket.write_all(&buffer).context(WireSendSnafu)?; - RejectedSnafu.fail() + let association_rj = AssociationRJ { + result: AssociationRJResult::Permanent, + source: AssociationRJSource::ServiceUser(reason), + }; + let pdu = Pdu::AssociationRJ(association_rj.clone()); + Err((pdu, RejectedSnafu{ association_rj}.build())) })?; // fetch requested maximum PDU length @@ -600,58 +499,93 @@ where }) .collect(); - write_pdu( - &mut buffer, - &Pdu::AssociationAC(AssociationAC { - protocol_version: self.protocol_version, - application_context_name, - presentation_contexts: presentation_contexts_negotiated - .iter() - .map(|pc| PresentationContextResult { - id: pc.id, - reason: pc.reason.clone(), - transfer_syntax: pc.transfer_syntax.clone(), - }) - .collect(), - calling_ae_title: calling_ae_title.clone(), - called_ae_title, - user_variables: vec![ - UserVariableItem::MaxLength(max_pdu_length), - UserVariableItem::ImplementationClassUID( - IMPLEMENTATION_CLASS_UID.to_string(), - ), - UserVariableItem::ImplementationVersionName( - IMPLEMENTATION_VERSION_NAME.to_string(), - ), - ], - }), - ) - .context(SendResponseSnafu)?; - socket.write_all(&buffer).context(WireSendSnafu)?; - - Ok(ServerAssociation { + let pdu = Pdu::AssociationAC(AssociationAC { + protocol_version: self.protocol_version, + application_context_name, + presentation_contexts: presentation_contexts_negotiated + .iter() + .map(|pc| PresentationContextResult { + id: pc.id, + reason: pc.reason.clone(), + transfer_syntax: pc.transfer_syntax.clone(), + }) + .collect(), + calling_ae_title: calling_ae_title.clone(), + called_ae_title, + user_variables: vec![ + UserVariableItem::MaxLength(self.max_pdu_length), + UserVariableItem::ImplementationClassUID( + IMPLEMENTATION_CLASS_UID.to_string(), + ), + UserVariableItem::ImplementationVersionName( + IMPLEMENTATION_VERSION_NAME.to_string(), + ), + ], + }); + Ok((pdu, NegotiatedOptions{ + peer_max_pdu_length: requestor_max_pdu_length, + user_variables, presentation_contexts: presentation_contexts_negotiated, - requestor_max_pdu_length, + }, calling_ae_title)) + }, + Pdu::ReleaseRQ => Err((Pdu::ReleaseRP, AbortedSnafu.build())), + pdu @ Pdu::AssociationAC { .. } + | pdu @ Pdu::AssociationRJ { .. } + | pdu @ Pdu::PData { .. } + | pdu @ Pdu::ReleaseRP + | pdu @ Pdu::AbortRQ { .. } => Err(( + Pdu::AbortRQ {source: AbortRQSource::ServiceProvider(AbortRQServiceProviderReason::UnexpectedPdu)}, + UnexpectedPduSnafu { pdu }.build() + )), + pdu @ Pdu::Unknown { .. } => Err(( + Pdu::AbortRQ {source: AbortRQSource::ServiceProvider(AbortRQServiceProviderReason::UnrecognizedPdu)}, + UnknownPduSnafu { pdu }.build() + )), + } + + } + + /// Negotiate an association with the given TCP stream. + pub fn establish(&self, mut socket: TcpStream) -> Result> { + ensure!( + !self.abstract_syntax_uids.is_empty() || self.promiscuous, + MissingAbstractSyntaxSnafu + ); + + let max_pdu_length = self.max_pdu_length; + socket + .set_read_timeout(self.read_timeout) + .context(SetReadTimeoutSnafu)?; + socket + .set_write_timeout(self.write_timeout) + .context(SetWriteTimeoutSnafu)?; + + let mut buf = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); + let msg = read_pdu_from_wire(&mut socket, &mut buf, MAXIMUM_PDU_SIZE, self.strict)?; + let mut buffer: Vec = Vec::with_capacity(max_pdu_length as usize); + match self.process_a_association_rq(msg) { + Ok((pdu, NegotiatedOptions{ user_variables: _, presentation_contexts , peer_max_pdu_length}, calling_ae_title)) => { + write_pdu(&mut buffer, &pdu).context(SendPduSnafu)?; + socket.write_all(&buffer).context(WireSendSnafu)?; + Ok(ServerAssociation { + presentation_contexts, + requestor_max_pdu_length: peer_max_pdu_length, acceptor_max_pdu_length: max_pdu_length, socket, client_ae_title: calling_ae_title, buffer, strict: self.strict, - read_buffer: BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize), - timeout: self.timeout, + read_buffer: buf, + read_timeout: self.read_timeout, + write_timeout: self.write_timeout, }) - } - Pdu::ReleaseRQ => { - write_pdu(&mut buffer, &Pdu::ReleaseRP).context(SendResponseSnafu)?; + }, + Err((pdu, err)) => { + // send the rejection/abort PDU + write_pdu(&mut buffer, &pdu).context(SendPduSnafu)?; socket.write_all(&buffer).context(WireSendSnafu)?; - AbortedSnafu.fail() + Err(err) } - pdu @ Pdu::AssociationAC { .. } - | pdu @ Pdu::AssociationRJ { .. } - | pdu @ Pdu::PData { .. } - | pdu @ Pdu::ReleaseRP - | pdu @ Pdu::AbortRQ { .. } => UnexpectedRequestSnafu { pdu }.fail(), - pdu @ Pdu::Unknown { .. } => UnknownRequestSnafu { pdu }.fail(), } } @@ -711,8 +645,10 @@ pub struct ServerAssociation { strict: bool, /// Read buffer from the socket read_buffer: bytes::BytesMut, - /// Timeout for individual send/receive operations - timeout: Option, + /// Timeout for individual receive operations + read_timeout: Option, + /// Timeout for individual send operations + write_timeout: Option, } impl ServerAssociation { @@ -731,7 +667,7 @@ impl ServerAssociation { /// Send a PDU message to the other intervenient. pub fn send(&mut self, msg: &Pdu) -> Result<()> { self.buffer.clear(); - write_pdu(&mut self.buffer, msg).context(SendSnafu)?; + write_pdu(&mut self.buffer, msg).context(SendPduSnafu)?; if self.buffer.len() > self.requestor_max_pdu_length as usize { return SendTooLongPduSnafu { length: self.buffer.len(), @@ -743,34 +679,7 @@ impl ServerAssociation { /// Read a PDU message from the other intervenient. pub fn receive(&mut self) -> Result { - use std::io::{BufRead, BufReader, Cursor}; - - let mut reader = BufReader::new(&mut self.socket); - - loop { - let mut buf = Cursor::new(&self.read_buffer[..]); - match read_pdu(&mut buf, self.acceptor_max_pdu_length, self.strict) - .context(ReceiveRequestSnafu)? - { - Some(pdu) => { - self.read_buffer.advance(buf.position() as usize); - return Ok(pdu); - } - None => { - // Reset position - buf.set_position(0) - } - } - // Use BufReader to get similar behavior to AsyncRead read_buf - let recv = reader - .fill_buf() - .context(ReadPduSnafu) - .context(ReceiveSnafu)? - .to_vec(); - reader.consume(recv.len()); - self.read_buffer.extend_from_slice(&recv); - ensure!(!recv.is_empty(), ConnectionClosedSnafu); - } + read_pdu_from_wire(&mut self.socket, &mut self.read_buffer, self.acceptor_max_pdu_length, self.strict) } /// Send a provider initiated abort message @@ -886,36 +795,26 @@ where #[cfg(feature = "async")] pub mod non_blocking { - use std::{borrow::Cow, io::Cursor}; - use bytes::{Buf, BytesMut}; + use bytes::BytesMut; use snafu::{ensure, ResultExt}; - use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::TcpStream, - }; + use tokio::{io::AsyncWriteExt, net::TcpStream}; use super::{ - AccessControl, Result, SendSnafu, SendTooLongPduSnafu, ServerAssociation, + AccessControl, Result, SendTooLongPduSnafu, ServerAssociation, ServerAssociationOptions, WireSendSnafu, }; use crate::{ association::{ - server::{ - AbortedSnafu, ConnectionClosedSnafu, MissingAbstractSyntaxSnafu, - ReceiveRequestSnafu, ReceiveSnafu, RejectedSnafu, SendResponseSnafu, - UnexpectedRequestSnafu, UnknownRequestSnafu, WireReadSnafu, - }, - uid::trim_uid, + read_pdu_from_wire_async, server::{ + MissingAbstractSyntaxSnafu, + }, NegotiatedOptions, ReceivePduSnafu, SendPduSnafu, TimeoutSnafu }, pdu::{ - AbortRQServiceProviderReason, AbortRQSource, AssociationAC, AssociationRJ, - AssociationRJResult, AssociationRJServiceUserReason, AssociationRJSource, - AssociationRQ, PresentationContextResult, PresentationContextNegotiated, - PresentationContextResultReason, ReadPduSnafu, UserVariableItem, - DEFAULT_MAX_PDU, MAXIMUM_PDU_SIZE, + AbortRQServiceProviderReason, AbortRQSource, + ReadPduSnafu, MAXIMUM_PDU_SIZE, }, - read_pdu, write_pdu, Pdu, IMPLEMENTATION_CLASS_UID, IMPLEMENTATION_VERSION_NAME, + write_pdu, Pdu, }; impl ServerAssociationOptions<'_, A> @@ -931,212 +830,44 @@ pub mod non_blocking { !self.abstract_syntax_uids.is_empty() || self.promiscuous, MissingAbstractSyntaxSnafu ); - let timeout = self.timeout; + let read_timeout = self.read_timeout; let task = async { let max_pdu_length = self.max_pdu_length; - let mut read_buffer = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); - - let pdu = loop { - let mut buf = Cursor::new(&read_buffer[..]); - match read_pdu(&mut buf, MAXIMUM_PDU_SIZE, self.strict) - .context(ReceiveRequestSnafu)? - { - Some(pdu) => { - read_buffer.advance(buf.position() as usize); - break pdu; - } - None => { - // Reset position - buf.set_position(0) - } - } - let recv = socket - .read_buf(&mut read_buffer) - .await - .context(ReadPduSnafu) - .context(ReceiveSnafu)?; - ensure!(recv > 0, ConnectionClosedSnafu); - }; + let mut buf = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); + let pdu = read_pdu_from_wire_async(&mut socket, &mut buf, MAXIMUM_PDU_SIZE, self.strict).await?; let mut buffer: Vec = Vec::with_capacity(max_pdu_length as usize); - match pdu { - Pdu::AssociationRQ(AssociationRQ { - protocol_version, - calling_ae_title, - called_ae_title, - application_context_name, - presentation_contexts, - user_variables, - }) => { - if protocol_version != self.protocol_version { - write_pdu( - &mut buffer, - &Pdu::AssociationRJ(AssociationRJ { - result: AssociationRJResult::Permanent, - source: AssociationRJSource::ServiceUser( - AssociationRJServiceUserReason::NoReasonGiven, - ), - }), - ) - .context(SendResponseSnafu)?; - socket.write_all(&buffer).await.context(WireSendSnafu)?; - return RejectedSnafu.fail(); - } - - if application_context_name != self.application_context_name { - write_pdu( - &mut buffer, - &Pdu::AssociationRJ(AssociationRJ { - result: AssociationRJResult::Permanent, - source: AssociationRJSource::ServiceUser( - AssociationRJServiceUserReason::ApplicationContextNameNotSupported, - ), - }), - ) - .context(SendResponseSnafu)?; - socket.write_all(&buffer).await.context(WireSendSnafu)?; - return RejectedSnafu.fail(); - } - - match self.ae_access_control.check_access( - &self.ae_title, - &calling_ae_title, - &called_ae_title, - user_variables - .iter() - .find_map(|user_variable| match user_variable { - UserVariableItem::UserIdentityItem(user_identity) => { - Some(user_identity) - } - _ => None, - }), - ) { - Ok(()) => {} - Err(reason) => { - write_pdu( - &mut buffer, - &Pdu::AssociationRJ(AssociationRJ { - result: AssociationRJResult::Permanent, - source: AssociationRJSource::ServiceUser(reason), - }), - ) - .context(SendResponseSnafu)?; - socket.write_all(&buffer).await.context(WireSendSnafu)?; - return Err(RejectedSnafu.build()); - } - } - - // fetch requested maximum PDU length - let requestor_max_pdu_length = user_variables - .iter() - .find_map(|item| match item { - UserVariableItem::MaxLength(len) => Some(*len), - _ => None, - }) - .unwrap_or(DEFAULT_MAX_PDU); - - // treat 0 as the maximum size admitted by the standard - let requestor_max_pdu_length = if requestor_max_pdu_length == 0 { - MAXIMUM_PDU_SIZE - } else { - requestor_max_pdu_length - }; - - let presentation_contexts_negotiated: Vec<_> = presentation_contexts - .into_iter() - .map(|pc| { - let abstract_syntax = trim_uid(Cow::from(pc.abstract_syntax)); - if !self - .abstract_syntax_uids - .contains(&abstract_syntax) - && !self.promiscuous - { - return PresentationContextNegotiated { - id: pc.id, - reason: PresentationContextResultReason::AbstractSyntaxNotSupported, - transfer_syntax: "1.2.840.10008.1.2".to_string(), - abstract_syntax: abstract_syntax.to_string(), - }; - } - - let (transfer_syntax, reason) = self - .choose_ts(pc.transfer_syntaxes) - .map(|ts| (ts, PresentationContextResultReason::Acceptance)) - .unwrap_or_else(|| { - ( - "1.2.840.10008.1.2".to_string(), - PresentationContextResultReason::TransferSyntaxesNotSupported, - ) - }); - - PresentationContextNegotiated { - id: pc.id, - reason, - transfer_syntax, - abstract_syntax: abstract_syntax.to_string(), - } - }) - .collect(); - - write_pdu( - &mut buffer, - &Pdu::AssociationAC(AssociationAC { - protocol_version: self.protocol_version, - application_context_name, - presentation_contexts: presentation_contexts_negotiated - .iter() - .map(|pc| PresentationContextResult { - id: pc.id, - reason: pc.reason.clone(), - transfer_syntax: pc.transfer_syntax.clone(), - }) - .collect(), - calling_ae_title: calling_ae_title.clone(), - called_ae_title, - user_variables: vec![ - UserVariableItem::MaxLength(max_pdu_length), - UserVariableItem::ImplementationClassUID( - IMPLEMENTATION_CLASS_UID.to_string(), - ), - UserVariableItem::ImplementationVersionName( - IMPLEMENTATION_VERSION_NAME.to_string(), - ), - ], - }), - ) - .context(SendResponseSnafu)?; + match self.process_a_association_rq(pdu) { + Ok((pdu, NegotiatedOptions{ user_variables: _, presentation_contexts , peer_max_pdu_length}, calling_ae_title)) => { + write_pdu(&mut buffer, &pdu).context(SendPduSnafu)?; socket.write_all(&buffer).await.context(WireSendSnafu)?; - - Ok(ServerAssociation { - presentation_contexts: presentation_contexts_negotiated, - requestor_max_pdu_length, + Ok(ServerAssociation { + presentation_contexts, + requestor_max_pdu_length: peer_max_pdu_length, acceptor_max_pdu_length: max_pdu_length, socket, client_ae_title: calling_ae_title, buffer, strict: self.strict, - read_buffer: BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize), - timeout, + read_buffer: buf, + read_timeout: self.read_timeout, + write_timeout: self.write_timeout, }) - } - Pdu::ReleaseRQ => { - write_pdu(&mut buffer, &Pdu::ReleaseRP).context(SendResponseSnafu)?; + }, + Err((pdu, err)) => { + // send the rejection/abort PDU + write_pdu(&mut buffer, &pdu).context(SendPduSnafu)?; socket.write_all(&buffer).await.context(WireSendSnafu)?; - AbortedSnafu.fail() + Err(err) } - pdu @ Pdu::AssociationAC { .. } - | pdu @ Pdu::AssociationRJ { .. } - | pdu @ Pdu::PData { .. } - | pdu @ Pdu::ReleaseRP - | pdu @ Pdu::AbortRQ { .. } => UnexpectedRequestSnafu { pdu }.fail(), - pdu @ Pdu::Unknown { .. } => UnknownRequestSnafu { pdu }.fail(), } + }; - if let Some(timeout) = timeout { + if let Some(timeout) = read_timeout { tokio::time::timeout(timeout, task) .await .map_err(|err| std::io::Error::new(std::io::ErrorKind::TimedOut, err)) - .context(WireReadSnafu)? + .context(TimeoutSnafu)? } else { task.await } @@ -1146,10 +877,10 @@ pub mod non_blocking { impl ServerAssociation { /// Send a PDU message to the other intervenient. pub async fn send(&mut self, msg: &Pdu) -> Result<()> { - let timeout = self.timeout; + let timeout = self.write_timeout; let task = async { self.buffer.clear(); - write_pdu(&mut self.buffer, msg).context(SendSnafu)?; + write_pdu(&mut self.buffer, msg).context(SendPduSnafu)?; if self.buffer.len() > self.requestor_max_pdu_length as usize { return SendTooLongPduSnafu { length: self.buffer.len(), @@ -1173,37 +904,21 @@ pub mod non_blocking { /// Read a PDU message from the other intervenient. pub async fn receive(&mut self) -> Result { - let timeout = self.timeout; + let timeout = self.read_timeout; let task = async { - loop { - let mut buf = Cursor::new(&self.read_buffer[..]); - match read_pdu(&mut buf, self.requestor_max_pdu_length, self.strict) - .context(ReceiveRequestSnafu)? - { - Some(pdu) => { - self.read_buffer.advance(buf.position() as usize); - return Ok(pdu); - } - None => { - // Reset position - buf.set_position(0) - } - } - let recv = self - .socket - .read_buf(&mut self.read_buffer) - .await - .context(ReadPduSnafu) - .context(ReceiveSnafu)?; - ensure!(recv > 0, ConnectionClosedSnafu); - } + read_pdu_from_wire_async( + &mut self.socket, + &mut self.read_buffer, + self.acceptor_max_pdu_length, + self.strict + ).await }; if let Some(timeout) = timeout { tokio::time::timeout(timeout, task) .await .map_err(|err| std::io::Error::new(std::io::ErrorKind::TimedOut, err)) .context(ReadPduSnafu) - .context(ReceiveSnafu)? + .context(ReceivePduSnafu)? } else { task.await } @@ -1213,7 +928,7 @@ pub mod non_blocking { /// and shut down the TCP connection, /// terminating the association. pub async fn abort(mut self) -> Result<()> { - let timeout = self.timeout; + let timeout = self.write_timeout; let task = async { let pdu = Pdu::AbortRQ { source: AbortRQSource::ServiceProvider( @@ -1242,7 +957,7 @@ pub mod non_blocking { #[cfg(test)] mod tests { - use super::choose_supported; + use super::*; #[test] fn test_choose_supported() { @@ -1263,4 +978,138 @@ mod tests { Some("1.2.840.10008.1.2.1".to_string()), ); } + + impl<'a, A> ServerAssociationOptions<'a, A> + where + A: AccessControl + { + // Broken implementation of server establish which sends an extra pdu during establish + pub(crate) fn establish_with_extra_pdus(&self, mut socket: std::net::TcpStream, extra_pdus: Vec) -> Result> { + let mut buf = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); + let pdu = read_pdu_from_wire(&mut socket, &mut buf, MAXIMUM_PDU_SIZE, self.strict)?; + let ( + pdu, + NegotiatedOptions{ user_variables: _, presentation_contexts , peer_max_pdu_length}, + calling_ae_title + ) = self.process_a_association_rq(pdu) + .expect("Could not parse association req"); + + let mut buffer: Vec = Vec::with_capacity(self.max_pdu_length as usize); + write_pdu(&mut buffer, &pdu).context(SendPduSnafu)?; + for extra_pdu in extra_pdus { + write_pdu(&mut buffer, &extra_pdu).context(SendPduSnafu)?; + } + socket.write_all(&buffer).context(WireSendSnafu)?; + + Ok(ServerAssociation { + presentation_contexts, + requestor_max_pdu_length: peer_max_pdu_length, + acceptor_max_pdu_length: self.max_pdu_length, + socket, + client_ae_title: calling_ae_title, + buffer, + strict: self.strict, + read_buffer: BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize), + read_timeout: self.read_timeout, + write_timeout: self.write_timeout, + }) + } + + // Broken implementation of server establish which sends an extra pdu during establish + #[cfg(feature = "async")] + pub(crate) async fn establish_with_extra_pdus_async( + &self, mut socket: tokio::net::TcpStream, extra_pdus: Vec + ) -> Result> { + use tokio::io::AsyncWriteExt; + + use crate::association::read_pdu_from_wire_async; + + let mut buf = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); + let pdu = read_pdu_from_wire_async(&mut socket, &mut buf, MAXIMUM_PDU_SIZE, self.strict).await?; + let ( + pdu, + NegotiatedOptions{ user_variables: _, presentation_contexts , peer_max_pdu_length}, + calling_ae_title + ) = self.process_a_association_rq(pdu) + .expect("Could not parse association req"); + + let mut buffer: Vec = Vec::with_capacity(self.max_pdu_length as usize); + write_pdu(&mut buffer, &pdu).context(SendPduSnafu)?; + for extra_pdu in extra_pdus { + write_pdu(&mut buffer, &extra_pdu).context(SendPduSnafu)?; + } + socket.write_all(&buffer).await.context(WireSendSnafu)?; + + Ok(ServerAssociation { + presentation_contexts, + requestor_max_pdu_length: peer_max_pdu_length, + acceptor_max_pdu_length: self.max_pdu_length, + socket, + client_ae_title: calling_ae_title, + buffer, + strict: self.strict, + read_buffer: BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize), + read_timeout: self.read_timeout, + write_timeout: self.write_timeout, + }) + } + + // Broken implementation of server establish which reproduces behavior that #589 introduced + pub fn broken_establish(&self, mut socket: TcpStream) -> Result> { + let mut buf = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); + let msg = read_pdu_from_wire(&mut socket, &mut buf, MAXIMUM_PDU_SIZE, self.strict)?; + let mut buffer: Vec = Vec::with_capacity(self.max_pdu_length as usize); + let ( + pdu, + NegotiatedOptions{user_variables: _, presentation_contexts , peer_max_pdu_length}, + calling_ae_title + ) = self.process_a_association_rq(msg).expect("Could not parse association req"); + write_pdu(&mut buffer, &pdu).context(SendPduSnafu)?; + socket.write_all(&buffer).context(WireSendSnafu)?; + Ok(ServerAssociation { + presentation_contexts, + requestor_max_pdu_length: peer_max_pdu_length, + acceptor_max_pdu_length: self.max_pdu_length, + socket, + client_ae_title: calling_ae_title, + buffer, + strict: self.strict, + read_buffer: BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize), + read_timeout: self.read_timeout, + write_timeout: self.write_timeout, + }) + } + + // Broken implementation of server establish which reproduces behavior that #589 introduced + #[cfg(feature = "async")] + pub async fn broken_establish_async(&self, mut socket: tokio::net::TcpStream) -> Result> { + use tokio::io::AsyncWriteExt; + + use crate::association::read_pdu_from_wire_async; + + let mut buf = BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize); + let msg = read_pdu_from_wire_async(&mut socket, &mut buf, MAXIMUM_PDU_SIZE, self.strict).await?; + let mut buffer: Vec = Vec::with_capacity(self.max_pdu_length as usize); + let ( + pdu, + NegotiatedOptions{user_variables: _, presentation_contexts , peer_max_pdu_length}, + calling_ae_title + ) = self.process_a_association_rq(msg).expect("Could not parse association req"); + write_pdu(&mut buffer, &pdu).context(SendPduSnafu)?; + socket.write_all(&buffer).await.context(WireSendSnafu)?; + Ok(ServerAssociation { + presentation_contexts, + requestor_max_pdu_length: peer_max_pdu_length, + acceptor_max_pdu_length: self.max_pdu_length, + socket, + client_ae_title: calling_ae_title, + buffer, + strict: self.strict, + read_buffer: BytesMut::with_capacity(MAXIMUM_PDU_SIZE as usize), + read_timeout: self.read_timeout, + write_timeout: self.write_timeout, + }) + } + } + } diff --git a/ul/src/association/tests.rs b/ul/src/association/tests.rs new file mode 100644 index 00000000..73f770d5 --- /dev/null +++ b/ul/src/association/tests.rs @@ -0,0 +1,744 @@ + +use dicom_core::{dicom_value, DataElement, VR}; +use dicom_dictionary_std::{tags, uids::VERIFICATION}; +use dicom_object::InMemDicomObject; +use dicom_transfer_syntax_registry::entries::IMPLICIT_VR_LITTLE_ENDIAN; +// Helper funtion to create a C-ECHO command +fn create_c_echo_command(message_id: u16) -> Vec { + let obj = InMemDicomObject::command_from_element_iter([ + // Affected SOP Class UID - Verification SOP Class + DataElement::new(tags::AFFECTED_SOP_CLASS_UID, VR::UI, VERIFICATION), + // Command Field - C-ECHO-RQ + DataElement::new(tags::COMMAND_FIELD, VR::US, dicom_value!(U16, [0x0030])), + // Message ID + DataElement::new(tags::MESSAGE_ID, VR::US, dicom_value!(U16, [message_id])), + // Command Data Set Type - No data set present + DataElement::new( + tags::COMMAND_DATA_SET_TYPE, + VR::US, + dicom_value!(U16, [0x0101]), + ), + ]); + + let mut data = Vec::new(); + let ts = IMPLICIT_VR_LITTLE_ENDIAN.erased(); + obj.write_dataset_with_ts(&mut data, &ts) + .expect("Failed to serialize C-ECHO command"); + + data +} +mod successive_pdus_during_client_association { + use std::net::TcpListener; + use super::*; + use crate::{pdu::{PDataValue, PDataValueType}, ClientAssociationOptions, Pdu}; + + use crate::association::server::*; + + #[test] + fn test_baseline_sync() { + // Immediately _after_ association, the server sends a C-ECHO command + // This will be received by the client + + // Setup a mock server that will send multiple PDUs consecutively + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let server_addr = listener.local_addr().unwrap(); + + // Create a second PDU (C-ECHO command) to send immediately after + let echo_pdu = Pdu::PData { data: vec![ + PDataValue { + presentation_context_id: 1, + data: create_c_echo_command(1), + value_type: PDataValueType::Command, + is_last: true + } + ]}; + let server_pdu = echo_pdu.clone(); + + // Spawn server thread that sends multiple PDUs back-to-back + let server_handle = std::thread::spawn(move || { + let (stream, _) = listener.accept().unwrap(); + + // Use ServerAssociationOptions to establish the association + let server_options = ServerAssociationOptions::new() + .accept_any() + .with_abstract_syntax(VERIFICATION) + .ae_title("THIS-SCP"); + + let mut association = server_options.establish(stream).unwrap(); + + // Send the second PDU (C-ECHO command) immediately after establishment + association.send(&server_pdu).unwrap(); + }); + + // Give server time to start + std::thread::sleep(std::time::Duration::from_millis(10)); + + // Create client and attempt association + let scu_options = ClientAssociationOptions::new() + .with_abstract_syntax(VERIFICATION) + .calling_ae_title("RANDOM") + .called_ae_title("THIS-SCP") + .read_timeout(std::time::Duration::from_secs(5)); + + // This should succeed in establishing the association despite multiple PDUs + let mut association = scu_options.establish(server_addr).unwrap(); + + // Client should be able to receive the release request that was sent consecutively + let received_pdu = association.receive().unwrap(); + assert_eq!(received_pdu, echo_pdu); + + // Clean shutdown + drop(association); + server_handle.join().unwrap(); + } + + // Tests edge case where the server sends an extra PDU during association + // client should be able to handle this gracefully. + #[test] + fn test_association_sends_extra_pdu_fails() { + // During association, the server sends a C-ECHO command + // This will be received by the client + + // Setup a mock server that will send multiple PDUs consecutively + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let server_addr = listener.local_addr().unwrap(); + + // Create a second PDU (C-ECHO command) to send immediately after + let echo_pdu = Pdu::PData { data: vec![ + PDataValue { + presentation_context_id: 1, + data: create_c_echo_command(1), + value_type: PDataValueType::Command, + is_last: true + } + ]}; + let server_pdu = echo_pdu.clone(); + + // Spawn server thread that sends multiple PDUs back-to-back + let server_handle = std::thread::spawn(move || { + let (stream, _) = listener.accept().unwrap(); + + // Use ServerAssociationOptions to establish the association + let server_options = ServerAssociationOptions::new() + .accept_any() + .with_abstract_syntax(VERIFICATION) + .ae_title("THIS-SCP"); + + server_options.establish_with_extra_pdus(stream, vec![server_pdu]).unwrap(); + }); + + // Give server time to start + std::thread::sleep(std::time::Duration::from_millis(10)); + + // Create client and attempt association + let scu_options = ClientAssociationOptions::new() + .with_abstract_syntax(VERIFICATION) + .calling_ae_title("RANDOM") + .called_ae_title("THIS-SCP") + .read_timeout(std::time::Duration::from_secs(5)); + + // This should succeed in establishing the association despite multiple PDUs + let mut association = scu_options.establish(server_addr).unwrap(); + + // Client should be able to receive the release request that was sent consecutively + let received_pdu = association.receive().unwrap(); + assert_eq!(received_pdu, echo_pdu); + + // Clean shutdown + drop(association); + server_handle.join().unwrap(); + } + + #[cfg(feature = "async")] + #[tokio::test(flavor = "multi_thread")] + async fn test_baseline_async() { + // Immediately _after_ association, the server sends a C-ECHO command + // This will be received by the client + + // Setup a mock server that will send multiple PDUs consecutively + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let server_addr = listener.local_addr().unwrap(); + + // Create a second PDU (C-ECHO command) to send immediately after + let echo_pdu = Pdu::PData { data: vec![ + PDataValue { + presentation_context_id: 1, + data: create_c_echo_command(1), + value_type: PDataValueType::Command, + is_last: true + } + ]}; + let server_pdu = echo_pdu.clone(); + + // Spawn server task that sends multiple PDUs back-to-back + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + + // Use ServerAssociationOptions to establish the association + let server_options = ServerAssociationOptions::new() + .accept_any() + .with_abstract_syntax(VERIFICATION) + .ae_title("THIS-SCP"); + + let mut association = server_options.establish_async(stream).await.unwrap(); + + // Send the second PDU (C-ECHO command) immediately after establishment + association.send(&server_pdu).await.unwrap(); + }); + + // Give server time to start + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + // Create client and attempt association + let scu_options = ClientAssociationOptions::new() + .with_abstract_syntax(VERIFICATION) + .calling_ae_title("RANDOM") + .called_ae_title("THIS-SCP") + .read_timeout(std::time::Duration::from_secs(5)); + + // This should succeed in establishing the association despite multiple PDUs + let mut association = scu_options.establish_async(server_addr).await.unwrap(); + + // Client should be able to receive the release request that was sent consecutively + let received_pdu = association.receive().await.unwrap(); + assert_eq!(received_pdu, echo_pdu); + + // Clean shutdown + drop(association); + server_handle.await.unwrap(); + } + + // Tests edge case where the server sends an extra PDU during association + // client should be able to handle this gracefully. + #[cfg(feature = "async")] + #[tokio::test(flavor = "multi_thread")] + async fn test_association_sends_extra_pdu_fails_async() { + // During association, the server sends a C-ECHO command + // This will be received by the client + + // Setup a mock server that will send multiple PDUs consecutively + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let server_addr = listener.local_addr().unwrap(); + + // Create a second PDU (C-ECHO command) to send immediately after + let echo_pdu = Pdu::PData { data: vec![ + PDataValue { + presentation_context_id: 1, + data: create_c_echo_command(1), + value_type: PDataValueType::Command, + is_last: true + } + ]}; + let server_pdu = echo_pdu.clone(); + + // Spawn server task that sends multiple PDUs back-to-back + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + + // Use ServerAssociationOptions to establish the association + let server_options = ServerAssociationOptions::new() + .accept_any() + .with_abstract_syntax(VERIFICATION) + .ae_title("THIS-SCP"); + + server_options.establish_with_extra_pdus_async(stream, vec![server_pdu]).await.unwrap(); + }); + + // Give server time to start + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + // Create client and attempt association + let scu_options = ClientAssociationOptions::new() + .with_abstract_syntax(VERIFICATION) + .calling_ae_title("RANDOM") + .called_ae_title("THIS-SCP") + .read_timeout(std::time::Duration::from_secs(5)); + + // This should succeed in establishing the association despite multiple PDUs + let mut association = scu_options.establish_async(server_addr).await.unwrap(); + + // Client should be able to receive the release request that was sent consecutively + let received_pdu = association.receive().await.unwrap(); + assert_eq!(received_pdu, echo_pdu); + + // Clean shutdown + drop(association); + server_handle.await.unwrap(); + } + + // Tests edge case where the client sends an extra PDU during association + // using a broken client implementation that creates a new buffer instead of reusing it + #[test] + fn test_client_association_sends_extra_pdu_589_impl() { + // During association, the client's broken implementation drops extra PDUs + // This reproduces the behavior that #589 fixed + + // Setup a mock server + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let server_addr = listener.local_addr().unwrap(); + + // Create a PDU (C-ECHO command) that should be lost + let echo_pdu = Pdu::PData { data: vec![ + PDataValue { + presentation_context_id: 1, + data: create_c_echo_command(1), + value_type: PDataValueType::Command, + is_last: true + } + ]}; + + // Spawn server thread that sends extra PDUs during association + let server_handle = std::thread::spawn(move || { + let (stream, _) = listener.accept().unwrap(); + + // Use ServerAssociationOptions with extra PDUs during association + let server_options = ServerAssociationOptions::new() + .accept_any() + .with_abstract_syntax(VERIFICATION) + .ae_title("THIS-SCP"); + + server_options.establish_with_extra_pdus(stream, vec![echo_pdu]).unwrap(); + }); + // Give server time to start + std::thread::sleep(std::time::Duration::from_millis(10)); + + // Create client and attempt association + let scu_options = ClientAssociationOptions::new() + .with_abstract_syntax(VERIFICATION) + .calling_ae_title("RANDOM") + .called_ae_title("THIS-SCP") + .read_timeout(std::time::Duration::from_secs(5)); + + // This should succeed in establishing the association despite multiple PDUs + let mut association = scu_options.broken_establish(server_addr.into()).unwrap(); + + // Client should not have anything to receive + let received_pdu = association.receive(); + assert!(received_pdu.is_err()); + + // Client cannot receive the PDU that was sent during association + // Clean shutdown + drop(association); + server_handle.join().unwrap(); + } + + #[cfg(feature = "async")] + #[tokio::test(flavor = "multi_thread")] + async fn test_client_association_sends_extra_pdu_589_impl_async() { + // During association, the client's broken implementation drops extra PDUs + // This reproduces the behavior that #589 fixed + + // Setup a mock server + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let server_addr = listener.local_addr().unwrap(); + + // Create a PDU (C-ECHO command) that should be lost + let echo_pdu = Pdu::PData { data: vec![ + PDataValue { + presentation_context_id: 1, + data: create_c_echo_command(1), + value_type: PDataValueType::Command, + is_last: true + } + ]}; + + // Spawn server task that sends extra PDUs during association + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + + // Use ServerAssociationOptions with extra PDUs during association + let server_options = ServerAssociationOptions::new() + .accept_any() + .with_abstract_syntax(VERIFICATION) + .ae_title("THIS-SCP"); + + server_options.establish_with_extra_pdus_async(stream, vec![echo_pdu]).await.unwrap(); + }); + + // Give server time to start + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + // Create client using broken implementation (creates new buffer) + let scu_options = ClientAssociationOptions::new() + .with_abstract_syntax(VERIFICATION) + .calling_ae_title("RANDOM") + .called_ae_title("THIS-SCP") + .read_timeout(std::time::Duration::from_secs(5)); + + // Client's broken implementation will miss the extra PDU from server + let mut association = scu_options.broken_establish_async( + server_addr.into() + ).await.unwrap(); + + // Client should be able to receive the release request that was sent consecutively + let received_pdu = association.receive().await; + assert!(received_pdu.is_err()); + + // Client cannot receive the PDU that was sent during association + // Clean shutdown + drop(association); + server_handle.await.unwrap(); + } + +} + +mod successive_pdus_during_server_association { + use super::*; + use std::net::TcpListener; + use crate::{pdu::{PDataValue, PDataValueType}, ClientAssociationOptions, Pdu, AeAddr}; + use crate::association::server::*; + + + #[test] + fn test_server_baseline_sync() { + // Immediately _after_ association, the client sends a C-ECHO command + // This will be received by the server + + // Setup server listener + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let server_addr = listener.local_addr().unwrap(); + + // Create a PDU (C-ECHO command) to send immediately after association + let echo_pdu = Pdu::PData { data: vec![ + PDataValue { + presentation_context_id: 1, + data: create_c_echo_command(1), + value_type: PDataValueType::Command, + is_last: true + } + ]}; + let client_pdu = echo_pdu.clone(); + + // Spawn server thread + let server_handle = std::thread::spawn(move || { + let (stream, _) = listener.accept().unwrap(); + + // Use ServerAssociationOptions to establish the association + let server_options = ServerAssociationOptions::new() + .accept_any() + .with_abstract_syntax(VERIFICATION) + .ae_title("THIS-SCP"); + + let mut association = server_options.establish(stream).unwrap(); + + // Server should be able to receive the PDU sent by client after association + let received_pdu = association.receive().unwrap(); + assert_eq!(received_pdu, echo_pdu); + }); + + // Give server time to start + std::thread::sleep(std::time::Duration::from_millis(10)); + + // Create client and attempt association, then send PDU immediately after + let scu_options = ClientAssociationOptions::new() + .with_abstract_syntax(VERIFICATION) + .calling_ae_title("RANDOM") + .called_ae_title("THIS-SCP") + .read_timeout(std::time::Duration::from_secs(5)); + + // Establish association and send PDU immediately after + let mut association = scu_options.establish(server_addr).unwrap(); + + // Send the PDU immediately after establishment + association.send(&client_pdu).unwrap(); + + // Clean shutdown + drop(association); + server_handle.join().unwrap(); + } + + #[test] + fn test_server_association_receives_extra_pdu() { + // During association, the client sends an extra C-ECHO command + // Server should be able to handle this gracefully + + // Setup server listener + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let server_addr = listener.local_addr().unwrap(); + + // Create a PDU (C-ECHO command) to send during association + let echo_pdu = Pdu::PData { data: vec![ + PDataValue { + presentation_context_id: 1, + data: create_c_echo_command(1), + value_type: PDataValueType::Command, + is_last: true + } + ]}; + let client_pdu = echo_pdu.clone(); + + // Spawn server thread + let server_handle = std::thread::spawn(move || { + let (stream, _) = listener.accept().unwrap(); + + // Use ServerAssociationOptions to establish the association + let server_options = ServerAssociationOptions::new() + .accept_any() + .with_abstract_syntax(VERIFICATION) + .ae_title("THIS-SCP"); + + let mut association = server_options.establish(stream).unwrap(); + + // Server should be able to receive the extra PDU that was sent during association + let received_pdu = association.receive().unwrap(); + assert_eq!(received_pdu, echo_pdu); + }); + + // Give server time to start + std::thread::sleep(std::time::Duration::from_millis(10)); + + // Create client that sends extra PDU during association + let scu_options = ClientAssociationOptions::new() + .with_abstract_syntax(VERIFICATION) + .calling_ae_title("RANDOM") + .called_ae_title("THIS-SCP") + .read_timeout(std::time::Duration::from_secs(5)); + + // Use the test method that sends extra PDUs during association + let association = scu_options.establish_with_extra_pdus( + AeAddr::new_socket_addr(server_addr), + vec![client_pdu] + ).unwrap(); + + // Clean shutdown + drop(association); + server_handle.join().unwrap(); + } + + #[test] + fn test_server_association_receives_extra_pdu_589_impl() { + // Reproduce behavior that #589 introduced + // During association, the client sends an extra C-ECHO command + // Server should _not_ be able to handle this gracefully + + // Setup server listener + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let server_addr = listener.local_addr().unwrap(); + + // Create a PDU (C-ECHO command) to send during association + let echo_pdu = Pdu::PData { data: vec![ + PDataValue { + presentation_context_id: 1, + data: create_c_echo_command(1), + value_type: PDataValueType::Command, + is_last: true + } + ]}; + let client_pdu = echo_pdu.clone(); + + // Spawn server thread + let server_handle = std::thread::spawn(move || { + let (stream, _) = listener.accept().unwrap(); + + // Use ServerAssociationOptions to establish the association + let server_options = ServerAssociationOptions::new() + .accept_any() + .with_abstract_syntax(VERIFICATION) + .ae_title("THIS-SCP"); + + let mut association = server_options.broken_establish(stream).unwrap(); + + // Server misses the echo request entirely + let received_pdu = association.receive().unwrap(); + assert_eq!(received_pdu, Pdu::ReleaseRQ); + }); + + // Give server time to start + std::thread::sleep(std::time::Duration::from_millis(10)); + + // Create client that sends extra PDU during association + let scu_options = ClientAssociationOptions::new() + .with_abstract_syntax(VERIFICATION) + .calling_ae_title("RANDOM") + .called_ae_title("THIS-SCP") + .read_timeout(std::time::Duration::from_secs(5)); + + // Use the test method that sends extra PDUs during association + let association = scu_options.establish_with_extra_pdus( + AeAddr::new_socket_addr(server_addr), + vec![client_pdu] + ).unwrap(); + + // Clean shutdown + drop(association); + server_handle.join().unwrap(); + } + + #[cfg(feature = "async")] + #[tokio::test(flavor = "multi_thread")] + async fn test_server_baseline_async() { + // Immediately _after_ association, the client sends a C-ECHO command + // This will be received by the server + + // Setup server listener + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let server_addr = listener.local_addr().unwrap(); + + // Create a PDU (C-ECHO command) to send immediately after association + let echo_pdu = Pdu::PData { data: vec![ + PDataValue { + presentation_context_id: 1, + data: create_c_echo_command(1), + value_type: PDataValueType::Command, + is_last: true + } + ]}; + let client_pdu = echo_pdu.clone(); + + // Spawn server task + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + + // Use ServerAssociationOptions to establish the association + let server_options = ServerAssociationOptions::new() + .accept_any() + .with_abstract_syntax(VERIFICATION) + .ae_title("THIS-SCP"); + + let mut association = server_options.establish_async(stream).await.unwrap(); + + // Server should be able to receive the PDU sent by client after association + let received_pdu = association.receive().await.unwrap(); + assert_eq!(received_pdu, echo_pdu); + }); + + // Give server time to start + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + // Create client and attempt association, then send PDU immediately after + let scu_options = ClientAssociationOptions::new() + .with_abstract_syntax(VERIFICATION) + .calling_ae_title("RANDOM") + .called_ae_title("THIS-SCP") + .read_timeout(std::time::Duration::from_secs(5)); + + // Establish association and send PDU immediately after + let mut association = scu_options.establish_async(server_addr).await.unwrap(); + + // Send the PDU immediately after establishment + association.send(&client_pdu).await.unwrap(); + + // Clean shutdown + drop(association); + server_handle.await.unwrap(); + } + + #[cfg(feature = "async")] + #[tokio::test(flavor = "multi_thread")] + async fn test_server_association_receives_extra_pdu_async() { + // During association, the client sends an extra C-ECHO command + // Server should be able to handle this gracefully + + // Setup server listener + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let server_addr = listener.local_addr().unwrap(); + + // Create a PDU (C-ECHO command) to send during association + let echo_pdu = Pdu::PData { data: vec![ + PDataValue { + presentation_context_id: 1, + data: create_c_echo_command(1), + value_type: PDataValueType::Command, + is_last: true + } + ]}; + let client_pdu = echo_pdu.clone(); + + // Spawn server task + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + + // Use ServerAssociationOptions to establish the association + let server_options = ServerAssociationOptions::new() + .accept_any() + .with_abstract_syntax(VERIFICATION) + .ae_title("THIS-SCP"); + + let mut association = server_options.establish_async(stream).await.unwrap(); + + // Server should be able to receive the extra PDU that was sent during association + let received_pdu = association.receive().await.unwrap(); + assert_eq!(received_pdu, echo_pdu); + }); + + // Give server time to start + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + // Create client that sends extra PDU during association + let scu_options = ClientAssociationOptions::new() + .with_abstract_syntax(VERIFICATION) + .calling_ae_title("RANDOM") + .called_ae_title("THIS-SCP") + .read_timeout(std::time::Duration::from_secs(5)); + + // Use the test method that sends extra PDUs during association + let association = scu_options.establish_with_extra_pdus_async( + AeAddr::new_socket_addr(server_addr), + vec![client_pdu] + ).await.unwrap(); + + // Clean shutdown + drop(association); + server_handle.await.unwrap(); + } + + #[cfg(feature = "async")] + #[tokio::test(flavor = "multi_thread")] + async fn test_server_association_receives_extra_pdu_589_impl_async() { + // Reproduce behavior that #589 introduced + // During association, the client sends an extra C-ECHO command + // Server should _not_ be able to handle this gracefully + + // Setup server listener + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let server_addr = listener.local_addr().unwrap(); + + // Create a PDU (C-ECHO command) to send during association + let echo_pdu = Pdu::PData { data: vec![ + PDataValue { + presentation_context_id: 1, + data: create_c_echo_command(1), + value_type: PDataValueType::Command, + is_last: true + } + ]}; + let client_pdu = echo_pdu.clone(); + + // Spawn server task + let server_handle = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + + // Use ServerAssociationOptions to establish the association + let server_options = ServerAssociationOptions::new() + .accept_any() + .with_abstract_syntax(VERIFICATION) + .ae_title("THIS-SCP"); + + let mut association = server_options.broken_establish_async(stream).await.unwrap(); + + // Server misses the echo request entirely + let received_pdu = association.receive().await.unwrap(); + assert_eq!(received_pdu, Pdu::ReleaseRQ); + }); + + // Give server time to start + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + // Create client that sends extra PDU during association + let scu_options = ClientAssociationOptions::new() + .with_abstract_syntax(VERIFICATION) + .calling_ae_title("RANDOM") + .called_ae_title("THIS-SCP") + .read_timeout(std::time::Duration::from_secs(5)); + + // Use the test method that sends extra PDUs during association + let association = scu_options.establish_with_extra_pdus_async( + AeAddr::new_socket_addr(server_addr), + vec![client_pdu] + ).await.unwrap(); + + // Clean shutdown + drop(association); + server_handle.await.unwrap(); + } +} diff --git a/ul/tests/association.rs b/ul/tests/association.rs index e5cd11c5..cae9c2c4 100644 --- a/ul/tests/association.rs +++ b/ul/tests/association.rs @@ -5,6 +5,7 @@ use std::time::Instant; const TIMEOUT_TOLERANCE: u64 = 25; + #[rstest] #[case(100)] #[case(500)] @@ -52,4 +53,4 @@ async fn test_slow_association_async(#[case] timeout: u64) { elapsed.as_millis(), timeout ); -} +} \ No newline at end of file diff --git a/ul/tests/association_promiscuous.rs b/ul/tests/association_promiscuous.rs index a30a432c..6e6c9a1b 100644 --- a/ul/tests/association_promiscuous.rs +++ b/ul/tests/association_promiscuous.rs @@ -1,6 +1,6 @@ use std::net::SocketAddr; -use dicom_ul::association::client::Error::NoAcceptedPresentationContexts; +use dicom_ul::association::Error::NoAcceptedPresentationContexts; use dicom_ul::pdu::PresentationContextResultReason::Acceptance; use dicom_ul::pdu::{PresentationContextNegotiated, PresentationContextResultReason, UserVariableItem}; use dicom_ul::{