diff --git a/lib/srv/desktop/rdp/rdpclient/src/rdpdr.rs b/lib/srv/desktop/rdp/rdpclient/src/rdpdr.rs index a03a1368347cd..ca0efe4db83e8 100644 --- a/lib/srv/desktop/rdp/rdpclient/src/rdpdr.rs +++ b/lib/srv/desktop/rdp/rdpclient/src/rdpdr.rs @@ -49,6 +49,7 @@ use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; use std::ffi::CString; use std::io::{Read, Seek, SeekFrom}; +use std::vec; /// Client implements a device redirection (RDPDR) client, as defined in /// https://winprotocoldoc.blob.core.windows.net/productionwindowsarchives/MS-RDPEFS/%5bMS-RDPEFS%5d.pdf @@ -326,24 +327,29 @@ impl Client { return Err(Error::TryError("received a drive redirection major function when drive redirection was not allowed".to_string())); } - let output = if is_smart_card_op { + let device_control_responses = if is_smart_card_op { // Smart card control - if let Some(res) = self.scard.ioctl(ioctl.io_control_code, payload)? { - res - } else { - return Ok(vec![]); - } + self.scard.ioctl(&ioctl, payload)? } else { // Drive redirection, mimic FreeRDP's "no-op" // https://github.com/FreeRDP/FreeRDP/blob/511444a65e7aa2f537c5e531fa68157a50c1bd4d/channels/drive/client/drive_main.c#L677-L684 - Box::new(NoOp::new()) + vec![DeviceControlResponse::new( + &ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(NoOp::new()), + )] }; - let resp = DeviceControlResponse::new(&ioctl, NTSTATUS::STATUS_SUCCESS, output); - debug!("sending RDP: {:?}", resp); - let resp = self - .add_headers_and_chunkify(PacketId::PAKID_CORE_DEVICE_IOCOMPLETION, resp.encode()?)?; - Ok(resp) + let mut messages: Messages = vec![]; + for resp in device_control_responses { + debug!("sending RDP: {:?}", resp); + messages.extend(self.add_headers_and_chunkify( + PacketId::PAKID_CORE_DEVICE_IOCOMPLETION, + resp.encode()?, + )?); + } + + Ok(messages) } fn process_irp_create( @@ -2129,6 +2135,23 @@ pub struct DeviceIoRequest { } impl DeviceIoRequest { + #[cfg(test)] + fn new( + device_id: u32, + file_id: u32, + completion_id: u32, + major_function: MajorFunction, + minor_function: MinorFunction, + ) -> Self { + Self { + device_id, + file_id, + completion_id, + major_function, + minor_function, + } + } + fn decode(payload: &mut Payload) -> RdpResult { let device_id = payload.read_u32::()?; let file_id = payload.read_u32::()?; @@ -2192,6 +2215,21 @@ struct DeviceControlRequest { } impl DeviceControlRequest { + #[cfg(test)] + fn new( + header: DeviceIoRequest, + output_buffer_length: u32, + input_buffer_length: u32, + io_control_code: IoctlCode, + ) -> Self { + Self { + header, + output_buffer_length, + input_buffer_length, + io_control_code, + } + } + fn decode(header: DeviceIoRequest, payload: &mut Payload) -> RdpResult { let output_buffer_length = payload.read_u32::()?; let input_buffer_length = payload.read_u32::()?; diff --git a/lib/srv/desktop/rdp/rdpclient/src/rdpdr/scard.rs b/lib/srv/desktop/rdp/rdpclient/src/rdpdr/scard.rs index f1dadaca85e29..4a15480664029 100644 --- a/lib/srv/desktop/rdp/rdpclient/src/rdpdr/scard.rs +++ b/lib/srv/desktop/rdp/rdpclient/src/rdpdr/scard.rs @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::errors::invalid_data_error; +use crate::errors::{invalid_data_error, not_implemented_error}; +use crate::rdpdr::consts::NTSTATUS; use crate::{piv, Message}; use crate::{Encode, Payload}; use bitflags::bitflags; @@ -25,8 +26,11 @@ use std::char::{decode_utf16, REPLACEMENT_CHARACTER}; use std::collections::HashMap; use std::convert::TryInto; use std::io::{Read, Write}; +use std::vec; use uuid::Uuid; +use super::{DeviceControlRequest, DeviceControlResponse}; + // Client implements the smartcard emulator, forwarded over an RDP virtual channel. // Spec: https://winprotocoldoc.blob.core.windows.net/productionwindowsarchives/MS-RDPESC/%5bMS-RDPESC%5d.pdf // @@ -57,10 +61,10 @@ impl Client { // ioctl handles messages coming from the RDP server over the RDPDR channel. pub(super) fn ioctl( &mut self, - code: IoctlCode, + ioctl: &DeviceControlRequest, input: &mut Payload, - ) -> RdpResult>> { - debug!("got IoctlCode {:?}", &code); + ) -> RdpResult> { + debug!("got IoctlCode {:?}", &ioctl.io_control_code); // Note: this is an incomplete implementation of the scard API. // It's the bare minimum needed to make RDP authentication using a smartcard work. // @@ -68,115 +72,215 @@ impl Client { // fail, but most modern Windows hosts shouldn't call those. If you're reading this because // some SCARD_IOCTL_*A call is failing, I was wrong and you'll have to implement the Ascii // calls. - match code { - IoctlCode::SCARD_IOCTL_ACCESSSTARTEDEVENT => self.handle_access_started_event(input), - IoctlCode::SCARD_IOCTL_ESTABLISHCONTEXT => self.handle_establish_context(input), - IoctlCode::SCARD_IOCTL_RELEASECONTEXT => self.handle_release_context(input), - IoctlCode::SCARD_IOCTL_CANCEL => self.handle_cancel(input), - IoctlCode::SCARD_IOCTL_ISVALIDCONTEXT => self.handle_is_valid_context(input), - IoctlCode::SCARD_IOCTL_LISTREADERSW => self.handle_list_readers(input), - IoctlCode::SCARD_IOCTL_GETSTATUSCHANGEW => self.handle_get_status_change(input), - IoctlCode::SCARD_IOCTL_CONNECTW => self.handle_connect(input), - IoctlCode::SCARD_IOCTL_DISCONNECT => self.handle_disconnect(input), - IoctlCode::SCARD_IOCTL_BEGINTRANSACTION => self.handle_begin_transaction(input), - IoctlCode::SCARD_IOCTL_ENDTRANSACTION => self.handle_end_transaction(input), - IoctlCode::SCARD_IOCTL_STATUSA => self.handle_status(input, StringEncoding::Ascii), - IoctlCode::SCARD_IOCTL_STATUSW => self.handle_status(input, StringEncoding::Unicode), + match ioctl.io_control_code { + IoctlCode::SCARD_IOCTL_ACCESSSTARTEDEVENT => { + self.handle_access_started_event(ioctl, input) + } + IoctlCode::SCARD_IOCTL_ESTABLISHCONTEXT => self.handle_establish_context(ioctl, input), + IoctlCode::SCARD_IOCTL_RELEASECONTEXT => self.handle_release_context(ioctl, input), + IoctlCode::SCARD_IOCTL_CANCEL => self.handle_cancel(ioctl, input), + IoctlCode::SCARD_IOCTL_ISVALIDCONTEXT => self.handle_is_valid_context(ioctl, input), + IoctlCode::SCARD_IOCTL_LISTREADERSW => self.handle_list_readers(ioctl, input), + IoctlCode::SCARD_IOCTL_GETSTATUSCHANGEW => self.handle_get_status_change(ioctl, input), + IoctlCode::SCARD_IOCTL_CONNECTW => self.handle_connect(ioctl, input), + IoctlCode::SCARD_IOCTL_DISCONNECT => self.handle_disconnect(ioctl, input), + IoctlCode::SCARD_IOCTL_BEGINTRANSACTION => self.handle_begin_transaction(ioctl, input), + IoctlCode::SCARD_IOCTL_ENDTRANSACTION => self.handle_end_transaction(ioctl, input), + IoctlCode::SCARD_IOCTL_STATUSA => { + self.handle_status(ioctl, input, StringEncoding::Ascii) + } + IoctlCode::SCARD_IOCTL_STATUSW => { + self.handle_status(ioctl, input, StringEncoding::Unicode) + } // Transmit is where communication with the actual smartcard (and the PIV application // on it) happens. All other messages are managing the smartcard reader and // establishing a connection to the smartcard. - IoctlCode::SCARD_IOCTL_TRANSMIT => self.handle_transmit(input), - IoctlCode::SCARD_IOCTL_GETDEVICETYPEID => self.handle_get_device_type_id(input), + IoctlCode::SCARD_IOCTL_TRANSMIT => self.handle_transmit(ioctl, input), + IoctlCode::SCARD_IOCTL_GETDEVICETYPEID => self.handle_get_device_type_id(ioctl, input), // Note: we keep an in-memory hashmap as a cache to implement these commands. Windows // doesn't seem to like a smartcard without a functioning cache. - IoctlCode::SCARD_IOCTL_READCACHEW => self.handle_read_cache(input), - IoctlCode::SCARD_IOCTL_WRITECACHEW => self.handle_write_cache(input), - IoctlCode::SCARD_IOCTL_GETREADERICON => self.handle_get_reader_icon(input), - _ => self.handle_unimplemented_ioctl(code), + IoctlCode::SCARD_IOCTL_READCACHEW => self.handle_read_cache(ioctl, input), + IoctlCode::SCARD_IOCTL_WRITECACHEW => self.handle_write_cache(ioctl, input), + IoctlCode::SCARD_IOCTL_GETREADERICON => self.handle_get_reader_icon(ioctl, input), + _ => self.handle_unimplemented_ioctl(ioctl, ioctl.io_control_code), } } fn handle_access_started_event( &self, + ioctl: &DeviceControlRequest, input: &mut Payload, - ) -> RdpResult>> { + ) -> RdpResult> { let req = ScardAccessStartedEvent_Call::decode(input)?; debug!("got {:?}", req); let resp = Long_Return::new(ReturnCode::SCARD_S_SUCCESS); debug!("sending {:?}", resp); - Ok(Some(Box::new(resp))) + Ok(vec![DeviceControlResponse::new( + ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(resp), + )]) } fn handle_establish_context( &mut self, + ioctl: &DeviceControlRequest, input: &mut Payload, - ) -> RdpResult>> { + ) -> RdpResult> { let req = EstablishContext_Call::decode(input)?; debug!("got {:?}", req); let ctx = self.contexts.establish(); let resp = EstablishContext_Return::new(ReturnCode::SCARD_S_SUCCESS, ctx); debug!("sending {:?}", resp); - Ok(Some(Box::new(resp))) + Ok(vec![DeviceControlResponse::new( + ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(resp), + )]) } fn handle_release_context( &mut self, + ioctl: &DeviceControlRequest, input: &mut Payload, - ) -> RdpResult>> { + ) -> RdpResult> { let req = Context_Call::decode(input)?; debug!("got {:?}", req); self.contexts.release(req.context.value); let resp = Long_Return::new(ReturnCode::SCARD_S_SUCCESS); debug!("sending {:?}", resp); - Ok(Some(Box::new(resp))) + Ok(vec![DeviceControlResponse::new( + ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(resp), + )]) } - fn handle_cancel(&self, input: &mut Payload) -> RdpResult>> { + fn handle_cancel( + &mut self, + ioctl: &DeviceControlRequest, + input: &mut Payload, + ) -> RdpResult> { + let mut responses = vec![]; let req = Context_Call::decode(input)?; debug!("got {:?}", req); - let resp = Long_Return::new(ReturnCode::SCARD_S_SUCCESS); - debug!("sending {:?}", resp); - Ok(Some(Box::new(resp))) + + // Fetch the pending SCARD_IOCTL_GETSTATUSCHANGEW response and add it to the responses. + if let Some(dcr) = self + .contexts + .get(req.context.value)? + .scard_cancel_response + .take() + { + responses.push(dcr); + } else { + warn!("Received SCARD_IOCTL_CANCEL for a context without a pending SCARD_IOCTL_GETSTATUSCHANGEW.") + } + + // Also add the response to the SCARD_IOCTL_CANCEL request. + responses.push(DeviceControlResponse::new( + ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(Long_Return::new(ReturnCode::SCARD_S_SUCCESS)), + )); + debug!("sending {:?}", responses); + Ok(responses) } - fn handle_is_valid_context(&self, input: &mut Payload) -> RdpResult>> { + fn handle_is_valid_context( + &self, + ioctl: &DeviceControlRequest, + input: &mut Payload, + ) -> RdpResult> { let req = Context_Call::decode(input)?; debug!("got {:?}", req); let resp = Long_Return::new(ReturnCode::SCARD_S_SUCCESS); debug!("sending {:?}", resp); - Ok(Some(Box::new(resp))) + Ok(vec![DeviceControlResponse::new( + ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(resp), + )]) } - fn handle_list_readers(&self, input: &mut Payload) -> RdpResult>> { + fn handle_list_readers( + &self, + ioctl: &DeviceControlRequest, + input: &mut Payload, + ) -> RdpResult> { let req = ListReaders_Call::decode(input)?; debug!("got {:?}", req); let resp = ListReaders_Return::new(ReturnCode::SCARD_S_SUCCESS, vec!["Teleport".to_string()]); debug!("sending {:?}", resp); - Ok(Some(Box::new(resp))) + Ok(vec![DeviceControlResponse::new( + ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(resp), + )]) } - fn handle_get_status_change(&self, input: &mut Payload) -> RdpResult>> { + fn handle_get_status_change( + &mut self, + ioctl: &DeviceControlRequest, + input: &mut Payload, + ) -> RdpResult> { let req = GetStatusChange_Call::decode(input)?; + let timeout = req.timeout; + let context_value = req.context.value; debug!("got {:?}", req); - let resp = GetStatusChange_Return::new(ReturnCode::SCARD_S_SUCCESS, req); + + if timeout != TIMEOUT_INFINITE && timeout != TIMEOUT_IMMEDIATE { + // We've never seen one of these but we log a warning here in case we ever come + // across one and need to debug a related issue. + warn!( + "logic for a non-infinite/non-immediate timeout [{}] is not implemented", + timeout + ); + } + + let mut resp = GetStatusChange_Return::new(ReturnCode::SCARD_S_SUCCESS, req); if resp.no_change() { - debug!("blocking GetStatusChange call indefinitely, no response since our status will never change"); - Ok(None) + if timeout != TIMEOUT_INFINITE { + return Err(not_implemented_error(&format!( + "no change for non-infinite timeout [{}] is not implemented", + timeout + ))); + } + + // Received a GetStatusChange_Call with an infinite timeout, so we're adding + // a corresponding DeviceControlResponse request holding a GetStatusChange_Return + // with its return code set to SCARD_E_CANCELLED to this Context. This value will + // be returned when we get an SCARD_IOCTL_CANCEL call for this Context. + resp.set_return_code(ReturnCode::SCARD_E_CANCELLED); + self.contexts + .get(context_value)? + .set_scard_cancel_response(DeviceControlResponse::new( + ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(resp), + ))?; + debug!("blocking GetStatusChange call indefinitely (since our status never changes) until we receive an SCARD_IOCTL_CANCEL"); + Ok(vec![]) } else { debug!("sending {:?}", resp); - Ok(Some(Box::new(resp))) + Ok(vec![DeviceControlResponse::new( + ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(resp), + )]) } } - fn handle_connect(&mut self, input: &mut Payload) -> RdpResult>> { + fn handle_connect( + &mut self, + ioctl: &DeviceControlRequest, + input: &mut Payload, + ) -> RdpResult> { let req = Connect_Call::decode(input)?; debug!("got {:?}", req); - let ctx = self - .contexts - .get(req.common.context.value) - .ok_or_else(|| invalid_data_error("unknown context ID"))?; + let ctx = self.contexts.get(req.common.context.value)?; let handle = ctx.connect( req.common.context, self.uuid, @@ -187,44 +291,72 @@ impl Client { let resp = Connect_Return::new(ReturnCode::SCARD_S_SUCCESS, handle); debug!("sending {:?}", resp); - Ok(Some(Box::new(resp))) + Ok(vec![DeviceControlResponse::new( + ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(resp), + )]) } - fn handle_disconnect(&mut self, input: &mut Payload) -> RdpResult>> { + fn handle_disconnect( + &mut self, + ioctl: &DeviceControlRequest, + input: &mut Payload, + ) -> RdpResult> { let req = HCardAndDisposition_Call::decode(input)?; debug!("got {:?}", req); self.contexts - .get(req.handle.context.value) - .ok_or_else(|| invalid_data_error("unknown context ID"))? + .get(req.handle.context.value)? .disconnect(req.handle.value); let resp = Long_Return::new(ReturnCode::SCARD_S_SUCCESS); debug!("sending {:?}", resp); - Ok(Some(Box::new(resp))) + Ok(vec![DeviceControlResponse::new( + ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(resp), + )]) } - fn handle_begin_transaction(&self, input: &mut Payload) -> RdpResult>> { + fn handle_begin_transaction( + &self, + ioctl: &DeviceControlRequest, + input: &mut Payload, + ) -> RdpResult> { let req = HCardAndDisposition_Call::decode(input)?; debug!("got {:?}", req); let resp = Long_Return::new(ReturnCode::SCARD_S_SUCCESS); debug!("sending {:?}", resp); - Ok(Some(Box::new(resp))) + Ok(vec![DeviceControlResponse::new( + ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(resp), + )]) } - fn handle_end_transaction(&self, input: &mut Payload) -> RdpResult>> { + fn handle_end_transaction( + &self, + ioctl: &DeviceControlRequest, + input: &mut Payload, + ) -> RdpResult> { let req = HCardAndDisposition_Call::decode(input)?; debug!("got {:?}", req); let resp = Long_Return::new(ReturnCode::SCARD_S_SUCCESS); debug!("sending {:?}", resp); - Ok(Some(Box::new(resp))) + Ok(vec![DeviceControlResponse::new( + ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(resp), + )]) } fn handle_status( &self, + ioctl: &DeviceControlRequest, input: &mut Payload, enc: StringEncoding, - ) -> RdpResult>> { + ) -> RdpResult> { let req = Status_Call::decode(input)?; debug!("got {:?}", req); let resp = Status_Return::new( @@ -233,10 +365,18 @@ impl Client { enc, ); debug!("sending {:?}", resp); - Ok(Some(Box::new(resp))) + Ok(vec![DeviceControlResponse::new( + ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(resp), + )]) } - fn handle_transmit(&mut self, input: &mut Payload) -> RdpResult>> { + fn handle_transmit( + &mut self, + ioctl: &DeviceControlRequest, + input: &mut Payload, + ) -> RdpResult> { let req = Transmit_Call::decode(input)?; debug!("got {:?}", req); @@ -252,8 +392,7 @@ impl Client { let card = self .contexts - .get(req.handle.context.value) - .ok_or_else(|| invalid_data_error("unknown context ID"))? + .get(req.handle.context.value)? .get(req.handle.value) .ok_or_else(|| invalid_data_error("unknown handle ID"))?; @@ -261,77 +400,107 @@ impl Client { let resp = Transmit_Return::new(ReturnCode::SCARD_S_SUCCESS, resp.encode()); debug!("sending {:?}", resp); - Ok(Some(Box::new(resp))) + Ok(vec![DeviceControlResponse::new( + ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(resp), + )]) } fn handle_get_device_type_id( &mut self, + ioctl: &DeviceControlRequest, input: &mut Payload, - ) -> RdpResult>> { + ) -> RdpResult> { let req = GetDeviceTypeId_Call::decode(input)?; debug!("got {:?}", req); - let _ctx = self - .contexts - .get(req.context.value) - .ok_or_else(|| invalid_data_error("unknown context ID"))?; + let _ctx = self.contexts.get(req.context.value)?; let resp = GetDeviceTypeId_Return::new(ReturnCode::SCARD_S_SUCCESS); debug!("sending {:?}", resp); - Ok(Some(Box::new(resp))) + Ok(vec![DeviceControlResponse::new( + ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(resp), + )]) } - fn handle_read_cache(&mut self, input: &mut Payload) -> RdpResult>> { + fn handle_read_cache( + &mut self, + ioctl: &DeviceControlRequest, + input: &mut Payload, + ) -> RdpResult> { let req = ReadCache_Call::decode(input)?; debug!("got {:?}", req); let val = self .contexts - .get(req.common.context.value) - .ok_or_else(|| invalid_data_error("unknown context ID"))? + .get(req.common.context.value)? .cache_read(&req.lookup_name); let resp = ReadCache_Return::new(val); debug!("sending {:?}", resp); - Ok(Some(Box::new(resp))) + Ok(vec![DeviceControlResponse::new( + ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(resp), + )]) } - fn handle_write_cache(&mut self, input: &mut Payload) -> RdpResult>> { + fn handle_write_cache( + &mut self, + ioctl: &DeviceControlRequest, + input: &mut Payload, + ) -> RdpResult> { let req = WriteCache_Call::decode(input)?; debug!("got {:?}", req); self.contexts - .get(req.common.context.value) - .ok_or_else(|| invalid_data_error("unknown context ID"))? + .get(req.common.context.value)? .cache_write(req.lookup_name, req.common.data); let resp = Long_Return::new(ReturnCode::SCARD_S_SUCCESS); debug!("sending {:?}", resp); - Ok(Some(Box::new(resp))) + Ok(vec![DeviceControlResponse::new( + ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(resp), + )]) } fn handle_get_reader_icon( &mut self, + ioctl: &DeviceControlRequest, input: &mut Payload, - ) -> RdpResult>> { + ) -> RdpResult> { let req = GetReaderIcon_Call::decode(input)?; debug!("got {:?}", req); - let _ctx = self - .contexts - .get(req.context.value) - .ok_or_else(|| invalid_data_error("unknown context ID"))?; + let _ctx = self.contexts.get(req.context.value)?; let resp = GetReaderIcon_Return::new(ReturnCode::SCARD_E_UNSUPPORTED_FEATURE); debug!("sending {:?}", resp); - Ok(Some(Box::new(resp))) + Ok(vec![DeviceControlResponse::new( + ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(resp), + )]) } - fn handle_unimplemented_ioctl(&self, code: IoctlCode) -> RdpResult>> { + fn handle_unimplemented_ioctl( + &self, + ioctl: &DeviceControlRequest, + code: IoctlCode, + ) -> RdpResult> { warn!("unimplemented IOCTL: {:?}", code); let resp = Long_Return::new(ReturnCode::SCARD_F_INTERNAL_ERROR); debug!("sending {:?}", resp); - Ok(Some(Box::new(resp))) + Ok(vec![DeviceControlResponse::new( + ioctl, + NTSTATUS::STATUS_SUCCESS, + Box::new(resp), + )]) } } @@ -548,6 +717,9 @@ impl Encode for ScardAccessStartedEvent_Call { } } +const TIMEOUT_INFINITE: u32 = 0xffffffff; +const TIMEOUT_IMMEDIATE: u32 = 0; + #[derive(Debug, FromPrimitive, ToPrimitive)] #[allow(non_camel_case_types)] #[repr(u32)] @@ -1244,6 +1416,10 @@ impl GetStatusChange_Return { } } + fn set_return_code(&mut self, return_code: ReturnCode) { + self.return_code = return_code; + } + fn no_change(&self) -> bool { for state in &self.reader_states { if state.current_state != state.event_state { @@ -2170,8 +2346,10 @@ impl Contexts { ctx } - fn get(&mut self, id: u32) -> Option<&mut ContextInternal> { - self.contexts.get_mut(&id) + fn get(&mut self, id: u32) -> RdpResult<&mut ContextInternal> { + self.contexts + .get_mut(&id) + .ok_or_else(|| invalid_data_error(&format!("unknown context id: {}", id))) } fn release(&mut self, id: u32) { @@ -2179,11 +2357,19 @@ impl Contexts { } } -#[derive(Debug, PartialEq)] +#[derive(Debug)] struct ContextInternal { handles: HashMap>, next_id: u32, cache: HashMap>, + // If we receive a SCARD_IOCTL_GETSTATUSCHANGEW with an infinite timeout, we need to + // return a GetStatusChange_Return (embedded in a DeviceControlResponse) with + // its return code set to SCARD_E_CANCELLED in the case that we receive a + // SCARD_IOCTL_CANCEL. + // + // This value will be set during the handling of the SCARD_IOCTL_GETSTATUSCHANGEW, so that + // it can be fetched and returned in response to a SCARD_IOCTL_CANCEL. + scard_cancel_response: Option, } impl ContextInternal { @@ -2192,7 +2378,16 @@ impl ContextInternal { next_id: 1, handles: HashMap::new(), cache: HashMap::new(), + scard_cancel_response: None, + } + } + + fn set_scard_cancel_response(&mut self, response: DeviceControlResponse) -> RdpResult<()> { + if self.scard_cancel_response.is_some() { + return Err(invalid_data_error("SCARD_IOCTL_CANCEL already received")); } + self.scard_cancel_response = Some(response); + Ok(()) } fn connect( @@ -2239,7 +2434,13 @@ fn debug_print_payload(payload: &mut Payload) { #[cfg(test)] mod tests { - use crate::Encode; + use crate::{ + rdpdr::{ + consts::{MajorFunction, MinorFunction, SCARD_DEVICE_ID}, + DeviceIoRequest, + }, + Encode, + }; use super::*; fn client() -> Client { @@ -2402,10 +2603,29 @@ mod tests { } let res = c - .ioctl(ctl_code, &mut to_payload(payload)) - .unwrap() + .ioctl( + &DeviceControlRequest::new( + DeviceIoRequest::new( + SCARD_DEVICE_ID, + 0, + 0, + MajorFunction::IRP_MJ_DEVICE_CONTROL, + MinorFunction::IRP_MN_NONE, + ), + 0, + 0, + ctl_code, + ), + &mut to_payload(payload), + ) .unwrap(); - assert_eq!(expected.encode().unwrap(), res.encode().unwrap()); + assert_eq!( + expected.encode().unwrap(), + // Only SCARD_IOCTL_CANCEL every returns more than a single + // result, and it's currently not tested, so res[0] works here + // for now. + res[0].output_buffer.encode().unwrap() + ); } /// Connects a piv::Card to the client's internal context cache