Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions src/multistream_select/dialer_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ use crate::{
error::{self, Error, ParseError},
multistream_select::{
protocol::{
encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, ProtocolError,
webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol,
ProtocolError,
},
Negotiated, NegotiationError, Version,
},
Expand Down Expand Up @@ -305,7 +306,7 @@ pub enum HandshakeResult {
/// Handshake state.
#[derive(Debug)]
enum HandshakeState {
/// Wainting to receive any response from remote peer.
/// Waiting to receive any response from remote peer.
WaitingResponse,

/// Waiting to receive the actual application protocol from remote peer.
Expand All @@ -314,7 +315,7 @@ enum HandshakeState {

/// `multistream-select` dialer handshake state.
#[derive(Debug)]
pub struct DialerState {
pub struct WebRtcDialerState {
/// Proposed main protocol.
protocol: ProtocolName,

Expand All @@ -325,16 +326,16 @@ pub struct DialerState {
state: HandshakeState,
}

impl DialerState {
impl WebRtcDialerState {
/// Propose protocol to remote peer.
///
/// Return [`DialerState`] which is used to drive forward the negotiation and an encoded
/// Return [`WebRtcDialerState`] which is used to drive forward the negotiation and an encoded
/// `multistream-select` message that contains the protocol proposal for the substream.
pub fn propose(
protocol: ProtocolName,
fallback_names: Vec<ProtocolName>,
) -> crate::Result<(Self, Vec<u8>)> {
let message = encode_multistream_message(
let message = webrtc_encode_multistream_message(
std::iter::once(protocol.clone())
.chain(fallback_names.clone())
.filter_map(|protocol| Protocol::try_from(protocol.as_ref()).ok())
Expand All @@ -353,7 +354,7 @@ impl DialerState {
))
}

/// Register response to [`DialerState`].
/// Register response to [`WebRtcDialerState`].
pub fn register_response(
&mut self,
payload: Vec<u8>,
Expand Down Expand Up @@ -755,7 +756,7 @@ mod tests {
#[test]
fn propose() {
let (mut dialer_state, message) =
DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
let message = bytes::BytesMut::from(&message[..]).freeze();

let Message::Protocols(protocols) = Message::decode(message).unwrap() else {
Expand All @@ -777,7 +778,7 @@ mod tests {

#[test]
fn propose_with_fallback() {
let (mut dialer_state, message) = DialerState::propose(
let (mut dialer_state, message) = WebRtcDialerState::propose(
ProtocolName::from("/13371338/proto/1"),
vec![ProtocolName::from("/sup/proto/1")],
)
Expand Down Expand Up @@ -813,7 +814,7 @@ mod tests {
let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();

let (mut dialer_state, _message) =
DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();

match dialer_state.register_response(bytes.freeze().to_vec()) {
Err(error::NegotiationError::MultistreamSelectError(NegotiationError::Failed)) => {}
Expand All @@ -832,7 +833,7 @@ mod tests {
let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();

let (mut dialer_state, _message) =
DialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();

match dialer_state.register_response(bytes.freeze().to_vec()) {
Err(error::NegotiationError::MultistreamSelectError(NegotiationError::Failed)) => {}
Expand All @@ -842,7 +843,7 @@ mod tests {

#[test]
fn negotiate_main_protocol() {
let message = encode_multistream_message(
let message = webrtc_encode_multistream_message(
vec![Message::Protocol(
Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
)]
Expand All @@ -851,7 +852,7 @@ mod tests {
.unwrap()
.freeze();

let (mut dialer_state, _message) = DialerState::propose(
let (mut dialer_state, _message) = WebRtcDialerState::propose(
ProtocolName::from("/13371338/proto/1"),
vec![ProtocolName::from("/sup/proto/1")],
)
Expand All @@ -860,13 +861,13 @@ mod tests {
match dialer_state.register_response(message.to_vec()) {
Ok(HandshakeResult::Succeeded(negotiated)) =>
assert_eq!(negotiated, ProtocolName::from("/13371338/proto/1")),
_ => panic!("invalid event"),
event => panic!("invalid event {event:?}"),
}
}

#[test]
fn negotiate_fallback_protocol() {
let message = encode_multistream_message(
let message = webrtc_encode_multistream_message(
vec![Message::Protocol(
Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
)]
Expand All @@ -875,7 +876,7 @@ mod tests {
.unwrap()
.freeze();

let (mut dialer_state, _message) = DialerState::propose(
let (mut dialer_state, _message) = WebRtcDialerState::propose(
ProtocolName::from("/13371338/proto/1"),
vec![ProtocolName::from("/sup/proto/1")],
)
Expand Down
36 changes: 19 additions & 17 deletions src/multistream_select/listener_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ use crate::{
error::{self, Error},
multistream_select::{
protocol::{
encode_multistream_message, HeaderLine, Message, MessageIO, Protocol, ProtocolError,
webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol,
ProtocolError,
},
Negotiated, NegotiationError,
},
Expand Down Expand Up @@ -324,7 +325,7 @@ where
}
}

/// Result of [`listener_negotiate()`].
/// Result of [`webrtc_listener_negotiate()`].
#[derive(Debug)]
pub enum ListenerSelectResult {
/// Requested protocol is available and substream can be accepted.
Expand All @@ -348,7 +349,7 @@ pub enum ListenerSelectResult {
/// Parse protocols offered by the remote peer and check if any of the offered protocols match
/// locally available protocols. If a match is found, return an encoded multistream-select
/// response and the negotiated protocol. If parsing fails or no match is found, return an error.
pub fn listener_negotiate<'a>(
pub fn webrtc_listener_negotiate<'a>(
supported_protocols: &'a mut impl Iterator<Item = &'a ProtocolName>,
payload: Bytes,
) -> crate::Result<ListenerSelectResult> {
Expand Down Expand Up @@ -382,9 +383,9 @@ pub fn listener_negotiate<'a>(
if protocol.as_ref() == supported.as_bytes() {
return Ok(ListenerSelectResult::Accepted {
protocol: supported.clone(),
message: encode_multistream_message(std::iter::once(Message::Protocol(
protocol,
)))?,
message: webrtc_encode_multistream_message(std::iter::once(
Message::Protocol(protocol),
))?,
});
}
}
Expand All @@ -396,7 +397,7 @@ pub fn listener_negotiate<'a>(
);

Ok(ListenerSelectResult::Rejected {
message: encode_multistream_message(std::iter::once(Message::NotAvailable))?,
message: webrtc_encode_multistream_message(std::iter::once(Message::NotAvailable))?,
})
}

Expand All @@ -405,15 +406,15 @@ mod tests {
use super::*;

#[test]
fn listener_negotiate_works() {
fn webrtc_listener_negotiate_works() {
let mut local_protocols = vec![
ProtocolName::from("/13371338/proto/1"),
ProtocolName::from("/sup/proto/1"),
ProtocolName::from("/13371338/proto/2"),
ProtocolName::from("/13371338/proto/3"),
ProtocolName::from("/13371338/proto/4"),
];
let message = encode_multistream_message(
let message = webrtc_encode_multistream_message(
vec![
Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()),
Message::Protocol(Protocol::try_from(&b"/sup/proto/1"[..]).unwrap()),
Expand All @@ -423,7 +424,7 @@ mod tests {
.unwrap()
.freeze();

match listener_negotiate(&mut local_protocols.iter(), message) {
match webrtc_listener_negotiate(&mut local_protocols.iter(), message) {
Err(error) => panic!("error received: {error:?}"),
Ok(ListenerSelectResult::Rejected { .. }) => panic!("message rejected"),
Ok(ListenerSelectResult::Accepted { protocol, message }) => {
Expand All @@ -441,14 +442,14 @@ mod tests {
ProtocolName::from("/13371338/proto/3"),
ProtocolName::from("/13371338/proto/4"),
];
let message = encode_multistream_message(std::iter::once(Message::Protocols(vec![
let message = webrtc_encode_multistream_message(std::iter::once(Message::Protocols(vec![
Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
])))
.unwrap()
.freeze();

match listener_negotiate(&mut local_protocols.iter(), message) {
match webrtc_listener_negotiate(&mut local_protocols.iter(), message) {
Err(error) => assert!(std::matches!(error, Error::InvalidData)),
_ => panic!("invalid event"),
}
Expand All @@ -469,7 +470,7 @@ mod tests {
let message = Message::Header(HeaderLine::V1);
let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();

match listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
match webrtc_listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
Err(error) => assert!(std::matches!(
error,
Error::NegotiationError(error::NegotiationError::MultistreamSelectError(
Expand Down Expand Up @@ -498,7 +499,7 @@ mod tests {
]);
let _ = message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();

match listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
match webrtc_listener_negotiate(&mut local_protocols.iter(), bytes.freeze()) {
Err(error) => assert!(std::matches!(
error,
Error::NegotiationError(error::NegotiationError::MultistreamSelectError(
Expand All @@ -518,7 +519,7 @@ mod tests {
ProtocolName::from("/13371338/proto/3"),
ProtocolName::from("/13371338/proto/4"),
];
let message = encode_multistream_message(
let message = webrtc_encode_multistream_message(
vec![Message::Protocol(
Protocol::try_from(&b"/13371339/proto/1"[..]).unwrap(),
)]
Expand All @@ -527,12 +528,13 @@ mod tests {
.unwrap()
.freeze();

match listener_negotiate(&mut local_protocols.iter(), message) {
match webrtc_listener_negotiate(&mut local_protocols.iter(), message) {
Err(error) => panic!("error received: {error:?}"),
Ok(ListenerSelectResult::Rejected { message }) => {
assert_eq!(
message,
encode_multistream_message(std::iter::once(Message::NotAvailable)).unwrap()
webrtc_encode_multistream_message(std::iter::once(Message::NotAvailable))
.unwrap()
);
}
Ok(ListenerSelectResult::Accepted { protocol, message }) => panic!("message accepted"),
Expand Down
5 changes: 3 additions & 2 deletions src/multistream_select/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,10 @@ mod negotiated;
mod protocol;

pub use crate::multistream_select::{
dialer_select::{dialer_select_proto, DialerSelectFuture, DialerState, HandshakeResult},
dialer_select::{dialer_select_proto, DialerSelectFuture, HandshakeResult, WebRtcDialerState},
listener_select::{
listener_negotiate, listener_select_proto, ListenerSelectFuture, ListenerSelectResult,
listener_select_proto, webrtc_listener_negotiate, ListenerSelectFuture,
ListenerSelectResult,
},
negotiated::{Negotiated, NegotiatedComplete, NegotiationError},
protocol::{HeaderLine, Message, Protocol, ProtocolError},
Expand Down
81 changes: 78 additions & 3 deletions src/multistream_select/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,7 @@ impl Message {
let mut remaining: &[u8] = &msg;
loop {
// A well-formed message must be terminated with a newline.
// TODO: don't do this
if remaining == [b'\n'] || remaining.is_empty() {
if remaining == [b'\n'] {
break;
} else if protocols.len() == MAX_PROTOCOLS {
return Err(ProtocolError::TooManyProtocols);
Expand All @@ -228,7 +227,12 @@ impl Message {
}

/// Create `multistream-select` message from an iterator of `Message`s.
pub fn encode_multistream_message(
///
/// # Note
///
/// This is implementation is not compliant with the multistream-select protocol spec.
/// The only purpose of this was to get the `multistream-select` protocol working with smoldot.
pub fn webrtc_encode_multistream_message(
messages: impl IntoIterator<Item = Message>,
) -> crate::Result<BytesMut> {
// encode `/multistream-select/1.0.0` header
Expand All @@ -245,6 +249,9 @@ pub fn encode_multistream_message(
header.append(&mut proto_bytes);
}

// For the `Message::Protocols` to be interpreted correctly, it must be followed by a newline.
header.push(b'\n');

Ok(BytesMut::from(&header[..]))
}

Expand Down Expand Up @@ -468,3 +475,71 @@ impl From<uvi::decode::Error> for ProtocolError {
Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_decode_main_messages() {
// Decode main messages.
let bytes = Bytes::from_static(MSG_MULTISTREAM_1_0);
assert_eq!(
Message::decode(bytes).unwrap(),
Message::Header(HeaderLine::V1)
);

let bytes = Bytes::from_static(MSG_PROTOCOL_NA);
assert_eq!(Message::decode(bytes).unwrap(), Message::NotAvailable);

let bytes = Bytes::from_static(MSG_LS);
assert_eq!(Message::decode(bytes).unwrap(), Message::ListProtocols);
}

#[test]
fn test_decode_empty_message() {
// Empty message should decode to an IoError, not Header::Protocols.
let bytes = Bytes::from_static(b"");
match Message::decode(bytes).unwrap_err() {
ProtocolError::IoError(io) => assert_eq!(io.kind(), io::ErrorKind::InvalidData),
err => panic!("Unexpected error: {:?}", err),
};
}

#[test]
fn test_decode_protocols() {
// Single protocol.
let bytes = Bytes::from_static(b"/protocol-v1\n");
assert_eq!(
Message::decode(bytes).unwrap(),
Message::Protocol(Protocol::try_from(Bytes::from_static(b"/protocol-v1")).unwrap())
);

// Multiple protocols.
let expected = Message::Protocols(vec![
Protocol::try_from(Bytes::from_static(b"/protocol-v1")).unwrap(),
Protocol::try_from(Bytes::from_static(b"/protocol-v2")).unwrap(),
]);
let mut encoded = BytesMut::new();
expected.encode(&mut encoded).unwrap();

// `\r` is the length of the protocol names.
let bytes = Bytes::from_static(b"\r/protocol-v1\n\r/protocol-v2\n\n");
assert_eq!(encoded, bytes);

assert_eq!(
Message::decode(bytes).unwrap(),
Message::Protocols(vec![
Protocol::try_from(Bytes::from_static(b"/protocol-v1")).unwrap(),
Protocol::try_from(Bytes::from_static(b"/protocol-v2")).unwrap(),
])
);

// Check invalid length.
let bytes = Bytes::from_static(b"\r/v1\n\n");
assert_eq!(
Message::decode(bytes).unwrap_err(),
ProtocolError::InvalidMessage
);
}
}
Loading
Loading