diff --git a/Cargo.toml b/Cargo.toml index ee44521..b5b734f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ default-run = "imdl" [features] default = [] -bench = ["rand"] +bench = [] [dependencies] ansi_term = "0.12.0" @@ -29,6 +29,7 @@ lexiclean = "0.0.1" libc = "0.2.0" log = "0.4.8" md5 = "0.7.0" +rand = "0.7.3" open = "1.4.0" pretty_assertions = "0.6.0" pretty_env_logger = "0.4.0" @@ -65,10 +66,6 @@ features = ["default", "wrap_help"] version = "2.1.1" features = ["serde"] -[dependencies.rand] -version = "0.7.3" -optional = true - [dev-dependencies] criterion = "0.3.0" temptree = "0.0.0" diff --git a/bin/gen/config.yaml b/bin/gen/config.yaml index 9df907a..2d958b0 100644 --- a/bin/gen/config.yaml +++ b/bin/gen/config.yaml @@ -11,6 +11,10 @@ examples: text: "BitTorrent metainfo related functionality is under the `torrent` subcommand:" code: "imdl torrent --help" +- command: imdl torrent announce + text: "Announce the infohash to all trackers in the supplied `.torrent` file, and print the peer lists that come back:" + code: "imdl torrent announce --input foo.torrent" + - command: imdl torrent create text: "Intermodal can be used to create `.torrent` files:" code: "imdl torrent create --input foo" diff --git a/src/common.rs b/src/common.rs index 6726f7e..891e868 100644 --- a/src/common.rs +++ b/src/common.rs @@ -12,6 +12,7 @@ pub(crate) use std::{ hash::Hash, io::{self, BufRead, BufReader, Cursor, Read, Write}, iter::{self, Sum}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs, UdpSocket}, num::{ParseFloatError, ParseIntError, TryFromIntError}, ops::{AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}, path::{self, Path, PathBuf}, @@ -19,7 +20,7 @@ pub(crate) use std::{ str::{self, FromStr}, string::FromUtf8Error, sync::Once, - time::{SystemTime, SystemTimeError}, + time::{Duration, SystemTime, SystemTimeError}, usize, }; @@ -31,6 +32,7 @@ pub(crate) use ignore::WalkBuilder; pub(crate) use indicatif::{ProgressBar, ProgressStyle}; pub(crate) use lexiclean::Lexiclean; pub(crate) use libc::EXIT_FAILURE; +pub(crate) use rand::Rng; pub(crate) use regex::{Regex, RegexSet}; pub(crate) use serde::{de::Error as _, Deserialize, Deserializer, Serialize, Serializer}; pub(crate) use serde_hex::SerHex; @@ -52,7 +54,7 @@ pub(crate) use url::{Host, Url}; pub(crate) use log::trace; // modules -pub(crate) use crate::{consts, error, host_port_parse_error, magnet_link_parse_error}; +pub(crate) use crate::{consts, error, host_port_parse_error, magnet_link_parse_error, tracker}; // functions pub(crate) use crate::xor_args::xor_args; diff --git a/src/error.rs b/src/error.rs index 9993e7e..e59539d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -54,6 +54,8 @@ pub(crate) enum Error { source: bendy::serde::Error, input: InputTarget, }, + #[snafu(display("Torrent metainfo does not specify any usable trackers"))] + MetainfoMissingTrackers, #[snafu(display("Failed to serialize torrent metainfo: {}", source))] MetainfoSerialize { source: bendy::serde::Error }, #[snafu(display("Failed to decode metainfo bencode from {}: {}", input, error))] @@ -136,6 +138,43 @@ pub(crate) enum Error { SymlinkRoot { root: PathBuf }, #[snafu(display("Failed to retrieve system time: {}", source))] SystemTime { source: SystemTimeError }, + #[snafu(display("Compact peer list is not the expected length"))] + TrackerCompactPeerList, + #[snafu(display("Tracker exchange to `udp://{}` timed out.", tracker_addr))] + TrackerExchange { tracker_addr: SocketAddr }, + #[snafu(display( + "Cannot connect to tracker `{}`: URL does not specify a valid host port", + tracker_url + ))] + TrackerHostPort { + source: HostPortParseError, + tracker_url: Url, + }, + #[snafu(display("Tracker client cannot announce without a connection id"))] + TrackerNoConnectionId, + #[snafu(display("Tracker resolved to no useable addresses"))] + TrackerNoHosts, + #[snafu(display("Malformed response from tracker"))] + TrackerResponse, + #[snafu(display("Response from tracker has wrong length: got {}; want {}", got, want))] + TrackerResponseLength { want: usize, got: usize }, + #[snafu(display("Tracker failed to send datagram: {}", source))] + TrackerSend { source: io::Error }, + #[snafu(display("Failed to resolve socket addrs: {}", source))] + TrackerSocketAddrs { source: io::Error }, + #[snafu(display( + "Cannot connect to tracker `{}`: only UDP trackers are supported", + tracker_url + ))] + TrackerUdpOnly { tracker_url: Url }, + #[snafu(display("Failed to bind to UDP socket: {}", source))] + UdpSocketBind { source: io::Error }, + #[snafu(display("Failed to connect to `udp://{}`: {}", addr, source))] + UdpSocketConnect { addr: SocketAddr, source: io::Error }, + #[snafu(display("Failed to get local UDP socket address: {}", source))] + UdpSocketLocalAddress { source: io::Error }, + #[snafu(display("Failed to set read timeout: {}", source))] + UdpSocketReadTimeout { source: io::Error }, #[snafu(display( "Feature `{}` cannot be used without passing the `--unstable` flag", feature diff --git a/src/host_port.rs b/src/host_port.rs index 6e2f55f..b53e249 100644 --- a/src/host_port.rs +++ b/src/host_port.rs @@ -57,6 +57,42 @@ impl Display for HostPort { } } +impl TryFrom<&Url> for HostPort { + type Error = HostPortParseError; + + fn try_from(url: &Url) -> Result { + match (url.host(), url.port()) { + (Some(host), Some(port)) => Ok(HostPort { + host: host.to_owned(), + port, + }), + (Some(_), None) => Err(HostPortParseError::PortMissing { + text: url.as_str().to_owned(), + }), + (None, Some(_)) => Err(HostPortParseError::HostMissing { + text: url.as_str().to_owned(), + }), + (None, None) => Err(HostPortParseError::HostPortMissing { + text: url.as_str().to_owned(), + }), + } + } +} + +impl ToSocketAddrs for HostPort { + type Iter = std::vec::IntoIter; + + fn to_socket_addrs(&self) -> io::Result { + let address = match &self.host { + Host::Domain(domain) => return (domain.clone(), self.port).to_socket_addrs(), + Host::Ipv4(address) => IpAddr::V4(*address), + Host::Ipv6(address) => IpAddr::V6(*address), + }; + + Ok(vec![SocketAddr::new(address, self.port)].into_iter()) + } +} + #[derive(Serialize, Deserialize)] struct Tuple(String, u16); @@ -156,4 +192,18 @@ mod tests { "l39:1234:5678:9abc:def0:1234:5678:9abc:def0i65000ee", ); } + + #[test] + fn test_from_url() { + let url = Url::parse("udp://imdl.io:12345").unwrap(); + let host_port = HostPort::try_from(&url).unwrap(); + assert_eq!(host_port.host, Host::Domain("imdl.io".into())); + assert_eq!(host_port.port, 12345); + } + + #[test] + fn test_from_url_no_port() { + let url = Url::parse("udp://imdl.io").unwrap(); + assert!(HostPort::try_from(&url).is_err()); + } } diff --git a/src/host_port_parse_error.rs b/src/host_port_parse_error.rs index 1902a9d..1de430b 100644 --- a/src/host_port_parse_error.rs +++ b/src/host_port_parse_error.rs @@ -12,4 +12,8 @@ pub(crate) enum HostPortParseError { Port { text: String, source: ParseIntError }, #[snafu(display("Port missing: `{}`", text))] PortMissing { text: String }, + #[snafu(display("Host missing: `{}`", text))] + HostMissing { text: String }, + #[snafu(display("Host and port missing: `{}`", text))] + HostPortMissing { text: String }, } diff --git a/src/infohash.rs b/src/infohash.rs index c2d5c73..9425787 100644 --- a/src/infohash.rs +++ b/src/infohash.rs @@ -62,6 +62,12 @@ impl From for Infohash { } } +impl From for [u8; 20] { + fn from(infohash: Infohash) -> Self { + infohash.inner.bytes() + } +} + impl From for Sha1Digest { fn from(infohash: Infohash) -> Sha1Digest { infohash.inner diff --git a/src/lib.rs b/src/lib.rs index d50d375..a812a26 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -89,6 +89,7 @@ mod style; mod subcommand; mod table; mod torrent_summary; +mod tracker; mod use_color; mod verifier; mod walker; diff --git a/src/subcommand/torrent.rs b/src/subcommand/torrent.rs index fa895a1..4977540 100644 --- a/src/subcommand/torrent.rs +++ b/src/subcommand/torrent.rs @@ -1,5 +1,6 @@ use crate::common::*; +mod announce; mod create; mod link; mod piece_length; @@ -14,6 +15,7 @@ mod verify; about("Subcommands related to the BitTorrent protocol.") )] pub(crate) enum Torrent { + Announce(announce::Announce), Create(create::Create), Link(link::Link), #[structopt(alias = "piece-size")] @@ -26,6 +28,7 @@ pub(crate) enum Torrent { impl Torrent { pub(crate) fn run(self, env: &mut Env, options: &Options) -> Result<(), Error> { match self { + Self::Announce(announce) => announce.run(env), Self::Create(create) => create.run(env, options), Self::Link(link) => link.run(env), Self::PieceLength(piece_length) => piece_length.run(env), diff --git a/src/subcommand/torrent/announce.rs b/src/subcommand/torrent/announce.rs new file mode 100644 index 0000000..0ef1404 --- /dev/null +++ b/src/subcommand/torrent/announce.rs @@ -0,0 +1,272 @@ +use crate::common::*; + +const INPUT_HELP: &str = + "Read torrent metainfo from `INPUT`. If `INPUT` is `-`, read metainfo from standard input."; + +const INPUT_FLAG: &str = "input-flag"; + +const INPUT_POSITIONAL: &str = ""; + +#[derive(StructOpt)] +#[structopt( + help_message(consts::HELP_MESSAGE), + version_message(consts::VERSION_MESSAGE), + about("Announce a .torrent file.") +)] +pub(crate) struct Announce { + #[structopt( + name = INPUT_FLAG, + long = "input", + short = "i", + value_name = "INPUT", + empty_values(false), + parse(try_from_os_str = InputTarget::try_from_os_str), + help = INPUT_HELP, + )] + input_flag: Option, + #[structopt( + name = INPUT_POSITIONAL, + value_name = "INPUT", + empty_values(false), + parse(try_from_os_str = InputTarget::try_from_os_str), + required_unless = INPUT_FLAG, + conflicts_with = INPUT_FLAG, + help = INPUT_HELP, + )] + input_positional: Option, +} + +impl Announce { + pub(crate) fn run(self, env: &mut Env) -> Result<(), Error> { + let target = xor_args( + "input_flag", + &self.input_flag, + "input_positional", + &self.input_positional, + )?; + + let input = env.read(target)?; + let infohash = Infohash::from_input(&input)?; + let metainfo = Metainfo::from_input(&input)?; + let mut peers = HashSet::new(); + let mut usable_trackers = 0; + + for tracker_url in metainfo.trackers() { + let tracker_url = match tracker_url { + Ok(tracker_url) => tracker_url, + Err(err) => { + errln!(env, "Skipping tracker: {}", err)?; + continue; + } + }; + + let client = match tracker::Client::from_url(tracker_url) { + Ok(client) => client, + Err(err) => { + errln!(env, "Couldn't build tracker client. {}", err)?; + continue; + } + }; + + usable_trackers += 1; + match client.announce_exchange(infohash) { + Ok(peer_list) => peers.extend(peer_list), + Err(err) => errln!(env, "Announce failed: {}", err)?, + } + } + + if usable_trackers == 0 { + return Err(Error::MetainfoMissingTrackers); + } + + for peer in &peers { + outln!(env, "{}", peer)?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(test)] + pub(crate) fn new_dummy_metainfo() -> Metainfo { + Metainfo { + announce: None, + announce_list: None, + nodes: None, + comment: None, + created_by: None, + creation_date: None, + encoding: None, + info: Info { + private: None, + piece_length: Bytes(16 * 1024), + source: None, + name: "testing".into(), + pieces: PieceList::from_pieces(["test", "data"]), + mode: Mode::Single { + length: Bytes(2 * 16 * 1024), + md5sum: None, + }, + update_url: None, + }, + } + } + + #[test] + fn input_required() { + test_env! { + args: [ + "torrent", + "announce", + ], + tree: { + }, + matches: Err(Error::Clap { .. }), + }; + } + + #[test] + fn input_arguments_positional() { + let mut env = test_env! { + args: [ + "torrent", + "announce", + "foo", + ], + tree: {}, + }; + assert_matches!(env.run(), Err(error::Error::Filesystem { .. })); + } + + #[test] + fn input_arguments_flag() { + let mut env = test_env! { + args: [ + "torrent", + "announce", + "--input", + "foo", + ], + tree: {}, + }; + assert_matches!(env.run(), Err(error::Error::Filesystem { .. })); + } + + #[test] + fn input_arguments_conflict() { + let mut env = test_env! { + args: [ + "torrent", + "announce", + "--input", + "foo", + "bar", + ], + tree: {}, + }; + assert_matches!(env.run(), Err(Error::Clap { .. })); + } + + #[test] + fn metainfo_missing_trackers() { + let mut env = test_env! { + args: [ + "torrent", + "announce", + "--input", + "test.torrent", + ], + tree: {}, + }; + let metainfo = new_dummy_metainfo(); + + env.write("test.torrent", metainfo.serialize().unwrap()); + assert_matches!(env.run(), Err(Error::MetainfoMissingTrackers)); + } + + #[test] + fn metainfo_no_udp_trackers() { + let mut env = test_env! { + args: [ + "torrent", + "announce", + "--input", + "test.torrent", + ], + tree: {}, + }; + let https_tracker_url = "utp://intermodal.io:443/tracker/announce"; + let metainfo = Metainfo { + announce: None, + announce_list: Some(vec![vec![https_tracker_url.into()]]), + nodes: None, + comment: None, + created_by: None, + creation_date: None, + encoding: None, + info: Info { + private: None, + piece_length: Bytes(16 * 1024), + source: None, + name: "testing".into(), + pieces: PieceList::from_pieces(["test", "data"]), + mode: Mode::Single { + length: Bytes(2 * 16 * 1024), + md5sum: None, + }, + update_url: None, + }, + }; + + env.write("test.torrent", metainfo.serialize().unwrap()); + assert_matches!(env.run(), Err(Error::MetainfoMissingTrackers)); + assert_eq!( + env.err(), + format!( + "Couldn't build tracker client. Cannot connect to tracker `{}`: only UDP trackers are supported\n", + https_tracker_url + ) + ); + } + + #[test] + fn tracker_host_port_not_well_formed() { + let mut env = test_env! { + args: [ + "torrent", + "announce", + "--input", + "test.torrent", + ], + tree: {}, + }; + let tracker_url = "udp://1.2.3.4:1333337/announce"; + let metainfo = Metainfo { + announce: None, + announce_list: Some(vec![vec![tracker_url.into()]]), + nodes: None, + comment: None, + created_by: None, + creation_date: None, + encoding: None, + info: Info { + private: None, + piece_length: Bytes(16 * 1024), + source: None, + name: "testing".into(), + pieces: PieceList::from_pieces(["test", "data"]), + mode: Mode::Single { + length: Bytes(2 * 16 * 1024), + md5sum: None, + }, + update_url: None, + }, + }; + env.write("test.torrent", metainfo.serialize().unwrap()); + assert_matches!(env.run(), Err(Error::MetainfoMissingTrackers)); + } +} diff --git a/src/tracker.rs b/src/tracker.rs new file mode 100644 index 0000000..c000cb1 --- /dev/null +++ b/src/tracker.rs @@ -0,0 +1,13 @@ +use request::Request; +use response::Response; + +pub(crate) use action::Action; +pub(crate) use client::Client; + +mod client; +mod request; +mod response; + +mod action; +mod announce; +mod connect; diff --git a/src/tracker/action.rs b/src/tracker/action.rs new file mode 100644 index 0000000..a09d4a4 --- /dev/null +++ b/src/tracker/action.rs @@ -0,0 +1,29 @@ +#[derive(Debug)] +pub enum Action { + Connect, + Announce, + Scrape, + Unsupported, +} + +impl From for u32 { + fn from(a: Action) -> Self { + match a { + Action::Connect => 0, + Action::Announce => 1, + Action::Scrape => 2, + Action::Unsupported => 0xffff, + } + } +} + +impl From for Action { + fn from(x: u32) -> Self { + match x { + 0 => Action::Connect, + 1 => Action::Announce, + 2 => Action::Scrape, + _ => Action::Unsupported, + } + } +} diff --git a/src/tracker/announce.rs b/src/tracker/announce.rs new file mode 100644 index 0000000..a0bd7a3 --- /dev/null +++ b/src/tracker/announce.rs @@ -0,0 +1,294 @@ +use crate::common::*; + +#[derive(Debug, PartialEq)] +pub(crate) struct Request { + pub(crate) connection_id: u64, // 8 bytes + pub(crate) action: u32, // 12 + pub(crate) transaction_id: u32, // 16 + pub(crate) infohash: [u8; 20], // 36 + pub(crate) peer_id: [u8; 20], // 56 + pub(crate) downloaded: u64, // 64 + pub(crate) left: u64, // 72 + pub(crate) uploaded: u64, // 80 + pub(crate) event: u64, // 88 + pub(crate) ip_address: u32, // 92 + pub(crate) num_want: u32, // 96 + pub(crate) port: u16, // 98 +} + +impl Request { + pub(crate) const LENGTH: usize = 98; + + pub(crate) fn new(connection_id: u64, btinh: Infohash, peer_id: [u8; 20], port: u16) -> Self { + let mut rng = rand::thread_rng(); + Self { + connection_id, + action: tracker::Action::Announce.into(), + transaction_id: rng.gen(), + infohash: btinh.into(), + peer_id, + downloaded: 0x0000, + left: u64::MAX, + uploaded: 0x0000, + event: 0x0000, + ip_address: 0x0000, + num_want: u32::MAX, + port, + } + } +} + +#[derive(Debug, PartialEq)] +pub(crate) struct Response { + pub(crate) action: u32, // 4 bytes + pub(crate) transaction_id: u32, // 8 + pub(crate) interval: u32, // 12 + pub(crate) leechers: u32, // 16 + pub(crate) seeders: u32, // 20 +} + +impl Response { + pub(crate) const LENGTH: usize = 20; +} + +impl super::Request for Request { + type Response = Response; + + fn serialize(&self) -> Vec { + let mut msg = Vec::new(); + + msg.extend_from_slice(&self.connection_id.to_be_bytes()); + msg.extend_from_slice(&self.action.to_be_bytes()); + msg.extend_from_slice(&self.transaction_id.to_be_bytes()); + msg.extend_from_slice(&self.infohash); + msg.extend_from_slice(&self.peer_id); + msg.extend_from_slice(&self.downloaded.to_be_bytes()); + msg.extend_from_slice(&self.left.to_be_bytes()); + msg.extend_from_slice(&self.uploaded.to_be_bytes()); + msg.extend_from_slice(&self.event.to_be_bytes()); + msg.extend_from_slice(&self.ip_address.to_be_bytes()); + msg.extend_from_slice(&self.num_want.to_be_bytes()); + msg.extend_from_slice(&self.port.to_be_bytes()); + + msg + } + + fn transaction_id(&self) -> u32 { + self.transaction_id + } + + fn action(&self) -> u32 { + self.action + } +} + +impl super::Response for Request { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8])> { + if buf.len() != Request::LENGTH { + return Err(Error::TrackerResponseLength { + got: buf.len(), + want: Request::LENGTH, + }); + } + + Ok(( + Request { + connection_id: u64::from_be_bytes( + buf[0..8] + .try_into() + .invariant_unwrap("buf size is at least Request::LENGTH"), + ), + action: u32::from_be_bytes( + buf[8..12] + .try_into() + .invariant_unwrap("buf size is at least Request::LENGTH"), + ), + transaction_id: u32::from_be_bytes( + buf[12..16] + .try_into() + .invariant_unwrap("buf size is at least Request::LENGTH"), + ), + infohash: buf[16..36] + .try_into() + .invariant_unwrap("buf size is at least Request::LENGTH"), + peer_id: buf[36..56] + .try_into() + .invariant_unwrap("buf size is at least Request::LENGTH"), + downloaded: u64::from_be_bytes( + buf[56..64] + .try_into() + .invariant_unwrap("buf size is at least Request::LENGTH"), + ), + left: u64::from_be_bytes( + buf[64..72] + .try_into() + .invariant_unwrap("buf size is at least Request::LENGTH"), + ), + uploaded: u64::from_be_bytes( + buf[72..80] + .try_into() + .invariant_unwrap("buf size is at least Request::LENGTH"), + ), + event: u64::from_be_bytes( + buf[80..88] + .try_into() + .invariant_unwrap("buf size is at least Request::LENGTH"), + ), + ip_address: u32::from_be_bytes( + buf[88..92] + .try_into() + .invariant_unwrap("buf size is at least Request::LENGTH"), + ), + num_want: u32::from_be_bytes( + buf[92..96] + .try_into() + .invariant_unwrap("buf size is at least Request::LENGTH"), + ), + port: u16::from_be_bytes( + buf[96..98] + .try_into() + .invariant_unwrap("buf size is at least Request::LENGTH"), + ), + }, + &buf[Self::LENGTH..], + )) + } + + fn transaction_id(&self) -> u32 { + self.transaction_id + } + + fn action(&self) -> u32 { + self.action + } +} + +impl super::Response for Response { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8])> { + if buf.len() < Response::LENGTH { + return Err(Error::TrackerResponseLength { + want: Response::LENGTH, + got: buf.len(), + }); + } + + Ok(( + Response { + action: u32::from_be_bytes( + buf[0..4] + .try_into() + .invariant_unwrap("buf size is at least Response::LENGTH"), + ), + transaction_id: u32::from_be_bytes( + buf[4..8] + .try_into() + .invariant_unwrap("buf size is at least Response::LENGTH"), + ), + interval: u32::from_be_bytes( + buf[8..12] + .try_into() + .invariant_unwrap("buf size is at least Response::LENGTH"), + ), + leechers: u32::from_be_bytes( + buf[12..16] + .try_into() + .invariant_unwrap("buf size is at least Response::LENGTH"), + ), + seeders: u32::from_be_bytes( + buf[16..20] + .try_into() + .invariant_unwrap("buf size is at least Response::LENGTH"), + ), + }, + &buf[Self::LENGTH..], + )) + } + + fn transaction_id(&self) -> u32 { + self.transaction_id + } + + fn action(&self) -> u32 { + self.action + } +} + +impl super::Request for Response { + type Response = Request; + + #[allow(dead_code)] + fn serialize(&self) -> Vec { + let mut msg = Vec::new(); + + msg.extend_from_slice(&self.action.to_be_bytes()); + msg.extend_from_slice(&self.transaction_id.to_be_bytes()); + msg.extend_from_slice(&self.interval.to_be_bytes()); + msg.extend_from_slice(&self.leechers.to_be_bytes()); + msg.extend_from_slice(&self.seeders.to_be_bytes()); + + msg + } + + fn transaction_id(&self) -> u32 { + self.transaction_id + } + + fn action(&self) -> u32 { + self.action + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tracker::{announce, request::Request, response::Response}; + + #[test] + pub(crate) fn announce_request_roundtrip() { + let req = announce::Request { + connection_id: 0x01, + action: 0x02, + transaction_id: 0x03, + infohash: [0x04; 20], + peer_id: [0x05; 20], + downloaded: 0x06, + left: 0x07, + uploaded: 0x08, + event: 0x09, + ip_address: 0x0a, + num_want: 0x0b, + port: 0x0c, + }; + let buf = req.serialize(); + let (req2, _) = announce::Request::deserialize(&buf).unwrap(); + assert_eq!(req, req2); + } + + #[test] + pub(crate) fn announce_response_roundtrip() { + let resp = announce::Response { + action: 0x01, + transaction_id: 0x02, + interval: 0x03, + leechers: 0x04, + seeders: 0x05, + }; + let buf = resp.serialize(); + let (resp2, _) = announce::Response::deserialize(&buf).unwrap(); + assert_eq!(resp, resp2); + } + + #[test] + pub(crate) fn announce_request_bad_deserialize() { + let buf = [0x01, 0x02, 0x03, 0x04, 0x05]; + let err = announce::Request::deserialize(&buf); + assert_matches!(err, Err(Error::TrackerResponseLength { .. })); + } + + #[test] + pub(crate) fn announce_response_bad_deserialize() { + let buf = [0x01, 0x02, 0x03, 0x04, 0x05]; + let err = announce::Response::deserialize(&buf); + assert_matches!(err, Err(Error::TrackerResponseLength { .. })); + } +} diff --git a/src/tracker/client.rs b/src/tracker/client.rs new file mode 100644 index 0000000..6d518ce --- /dev/null +++ b/src/tracker/client.rs @@ -0,0 +1,350 @@ +use super::*; +use crate::common::*; + +#[derive(Debug)] +pub(crate) struct Client { + peer_id: [u8; 20], + tracker_addr: SocketAddr, + sock: UdpSocket, + connection_id: Option, +} + +impl Client { + const RX_BUF_LEN: usize = 8192; + const UDP_SOCKET_READ_TIMEOUT_S: u64 = 3; + const UDP_SOCKET_READ_TIMEOUT_NS: u32 = 0; + + pub fn connect(address: A) -> Result { + let addrs = address + .to_socket_addrs() // XXX: this may cause DNS look-ups! + .context(error::TrackerSocketAddrs)?; + + for tracker_addr in addrs { + let sock = match Self::new_udp_socket(tracker_addr) { + Ok(sock) => sock, + Err(_) => continue, // TODO: log these as warnings + }; + let mut client = Client { + peer_id: rand::thread_rng().gen(), + tracker_addr, + sock, + connection_id: None, + }; + if let Ok(()) = client.connect_exchange() { + return Ok(client); + } + } + Err(Error::TrackerNoHosts) + } + + fn new_udp_socket(addr: SocketAddr) -> Result { + let sock = match addr { + SocketAddr::V4(_) => UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0)), + SocketAddr::V6(_) => UdpSocket::bind((Ipv6Addr::UNSPECIFIED, 0)), + } + .context(error::UdpSocketBind)?; + sock + .connect(addr) + .context(error::UdpSocketConnect { addr })?; + sock + .set_read_timeout(Some(Duration::new( + Self::UDP_SOCKET_READ_TIMEOUT_S, + Self::UDP_SOCKET_READ_TIMEOUT_NS, + ))) + .context(error::UdpSocketReadTimeout)?; + Ok(sock) + } + + pub fn from_url(tracker_url: Url) -> Result { + if tracker_url.scheme() != "udp" { + return Err(Error::TrackerUdpOnly { tracker_url }); + } + Self::connect(HostPort::try_from(&tracker_url).context(error::TrackerHostPort { tracker_url })?) + } + + fn connect_exchange(&mut self) -> Result<()> { + let req = connect::Request::new(); + let mut buf = [0u8; connect::Response::LENGTH]; + let (resp, _) = self.exchange(&req, &mut buf)?; + self.connection_id.replace(resp.connection_id); + Ok(()) + } + + pub fn announce_exchange(&self, btinh: Infohash) -> Result> { + let connection_id = match self.connection_id { + Some(id) => id, + None => return Err(Error::TrackerNoConnectionId), + }; + + let local_addr = self + .sock + .local_addr() + .context(error::UdpSocketLocalAddress)?; + let req = announce::Request::new(connection_id, btinh, self.peer_id, local_addr.port()); + let mut buf = [0u8; Self::RX_BUF_LEN]; + let (_, payload) = self.exchange(&req, &mut buf)?; + + Client::parse_compact_peer_list(payload, local_addr.is_ipv6()) + } + + fn exchange<'a, T: Request>( + &self, + req: &T, + buf: &'a mut [u8], + ) -> Result<(T::Response, &'a [u8])> { + let msg = req.serialize(); + let mut len_read: usize = 0; + + for _ in 0..3 { + self.sock.send(&msg).context(error::TrackerSend)?; + if let Ok(len) = self.sock.recv(buf) { + len_read = len; + break; + } + } + + if len_read == 0 { + return Err(Error::TrackerExchange { + tracker_addr: self.tracker_addr, + }); + } + + let (resp, payload) = T::Response::deserialize(&buf[..len_read])?; + if resp.transaction_id() != req.transaction_id() || resp.action() != req.action() { + return Err(Error::TrackerResponse); + } + + Ok((resp, payload)) + } + + fn parse_compact_peer_list(buf: &[u8], is_ipv6: bool) -> Result> { + let mut peer_list = Vec::::new(); + let stride = if is_ipv6 { 18 } else { 6 }; + + let chunks = buf.chunks_exact(stride); + if !chunks.remainder().is_empty() { + return Err(Error::TrackerCompactPeerList); + } + + for hostpost in chunks { + let (ip, port) = hostpost.split_at(stride - 2); + let ip = if is_ipv6 { + let octets: [u8; 16] = ip[0..16] + .try_into() + .invariant_unwrap("iterator guarantees bounds are OK"); + IpAddr::from(std::net::Ipv6Addr::from(octets)) + } else { + IpAddr::from(std::net::Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3])) + }; + + let port = u16::from_be_bytes( + port + .try_into() + .invariant_unwrap("iterator guarantees bounds are OK"), + ); + + peer_list.push((ip, port).into()); + } + + Ok(peer_list) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::thread; + + struct TestServer { + sock: UdpSocket, + peer_list: Vec, + } + + impl TestServer { + fn new_ipv4() -> (Self, SocketAddr, Vec) { + TestServer::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED)) + } + + fn new_ipv6() -> (Self, SocketAddr, Vec) { + TestServer::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED)) + } + + fn new(addr: IpAddr) -> (Self, SocketAddr, Vec) { + let sock = UdpSocket::bind((addr, 0)).unwrap(); + sock.set_read_timeout(None).unwrap(); + + let server_addr = sock.local_addr().unwrap(); + let stride = if server_addr.is_ipv6() { 18 } else { 6 }; + let peer_list: Vec = (0..10 * stride) + .map(|_| rand::thread_rng().gen::()) + .collect::>(); + + let local_addr = if server_addr.is_ipv6() { + (Ipv6Addr::LOCALHOST, server_addr.port()).into() + } else { + (Ipv4Addr::LOCALHOST, server_addr.port()).into() + }; + + ( + TestServer { + sock, + peer_list: peer_list.clone(), + }, + local_addr, + peer_list, + ) + } + + fn connect_exchange(&self) { + let mut buf = [0u8; 8192]; + let mut rng = rand::thread_rng(); + + let (n, peer) = self.sock.recv_from(&mut buf).unwrap(); + let (req, _) = connect::Request::deserialize(buf[..n].try_into().unwrap()).unwrap(); + let req = connect::Response { + action: Action::Connect.into(), + transaction_id: req.transaction_id, + connection_id: rng.gen(), + } + .serialize(); + self.sock.send_to(&req, peer).unwrap(); + } + + fn announce_exchange(&self) { + let mut buf = [0u8; 8192]; + + let (n, peer) = self.sock.recv_from(&mut buf).unwrap(); + let (req, _) = announce::Request::deserialize(&buf[..n]).unwrap(); + let mut req: Vec = announce::Response { + action: Action::Announce.into(), + transaction_id: req.transaction_id, + interval: 0x1337_1337, + leechers: 0xcafe_babe, + seeders: 0xdead_beef, + } + .serialize(); + req.extend_from_slice(&self.peer_list); + self.sock.send_to(&req, peer).unwrap(); + } + } + + #[test] + fn client_from_url_no_port() { + let tracker_url = Url::parse("udp://intermodal.io/announce").unwrap(); + assert_matches!( + Client::from_url(tracker_url), + Err(Error::TrackerHostPort { .. }) + ); + } + + #[test] + fn client_from_url_no_host() { + let tracker_url = Url::parse("udp://magnet:?announce=no_host").unwrap(); + assert_matches!( + Client::from_url(tracker_url), + Err(Error::TrackerHostPort { .. }) + ); + } + + #[test] + fn client_from_url_not_udp() { + let tracker_url = Url::parse("https://intermodal.io:100/announce").unwrap(); + assert_matches!( + Client::from_url(tracker_url), + Err(Error::TrackerUdpOnly { .. }) + ); + } + + #[test] + fn client_connect_v4() { + let (server, addr, _) = TestServer::new_ipv4(); + thread::spawn(move || { + server.connect_exchange(); + }); + Client::connect(addr).unwrap(); + } + + #[test] + fn client_connect_v6() { + let (server, addr, _) = TestServer::new_ipv6(); + thread::spawn(move || { + server.connect_exchange(); + }); + Client::connect(addr).unwrap(); + } + + #[test] + fn client_connect_timeout_ipv4() { + let (_, addr, _) = TestServer::new_ipv4(); + assert_matches!(Client::connect(addr), Err(Error::TrackerNoHosts { .. })); + } + + #[test] + fn client_connect_timeout_ipv6() { + let (_, addr, _) = TestServer::new_ipv6(); + assert_matches!(Client::connect(addr), Err(Error::TrackerNoHosts { .. })); + } + + #[test] + fn client_announce_without_connection_id() {} + + #[test] + fn client_announce_timeout_ipv4() { + let (server, addr, _) = TestServer::new_ipv4(); + thread::spawn(move || { + server.connect_exchange(); + }); + + let c = Client::connect(addr).unwrap(); + let addrs = c.announce_exchange(Sha1Digest::from_bytes([0u8; 20]).into()); + assert_matches!(addrs, Err(Error::TrackerExchange { .. })); + } + + #[test] + fn client_announce_timeout_ipv6() { + let (server, addr, _) = TestServer::new_ipv4(); + thread::spawn(move || { + server.connect_exchange(); + }); + + let c = Client::connect(addr).unwrap(); + let addrs = c.announce_exchange(Sha1Digest::from_bytes([0u8; 20]).into()); + assert_matches!(addrs, Err(Error::TrackerExchange { .. })); + } + + #[test] + fn client_announce_ipv4() { + let (server, addr, expected_targets) = TestServer::new_ipv4(); + thread::spawn(move || { + server.connect_exchange(); + server.announce_exchange(); + }); + + let c = Client::connect(addr).unwrap(); + let addrs = c + .announce_exchange(Sha1Digest::from_bytes([0u8; 20]).into()) + .unwrap(); + assert_eq!( + addrs, + Client::parse_compact_peer_list(&expected_targets, addr.is_ipv6()).unwrap() + ); + } + + #[test] + fn client_announce_ipv6() { + let (server, addr, expected_targets) = TestServer::new_ipv6(); + thread::spawn(move || { + server.connect_exchange(); + server.announce_exchange(); + }); + + let c = Client::connect(addr).unwrap(); + let addrs = c + .announce_exchange(Sha1Digest::from_bytes([0u8; 20]).into()) + .unwrap(); + assert_eq!( + addrs, + Client::parse_compact_peer_list(&expected_targets, addr.is_ipv6()).unwrap() + ); + } +} diff --git a/src/tracker/connect.rs b/src/tracker/connect.rs new file mode 100644 index 0000000..9c0651e --- /dev/null +++ b/src/tracker/connect.rs @@ -0,0 +1,197 @@ +use crate::common::*; + +#[derive(Clone, Copy, Debug, PartialEq)] +pub(crate) struct Request { + pub(crate) protocol_id: u64, + pub(crate) action: u32, + pub(crate) transaction_id: u32, +} + +impl Request { + pub(crate) const LENGTH: usize = 16; + + const UDP_TRACKER_MAGIC: u64 = 0x0000_0417_2710_1980; + + pub(crate) fn new() -> Self { + Self { + protocol_id: Self::UDP_TRACKER_MAGIC, + action: tracker::Action::Connect.into(), + transaction_id: rand::thread_rng().gen(), + } + } +} + +#[derive(Debug, PartialEq)] +pub(crate) struct Response { + pub(crate) action: u32, + pub(crate) transaction_id: u32, + pub(crate) connection_id: u64, +} + +impl Response { + pub(crate) const LENGTH: usize = 16; +} + +impl super::Request for Request { + type Response = Response; + + fn serialize(&self) -> Vec { + let mut msg = Vec::new(); + + msg.extend_from_slice(&self.protocol_id.to_be_bytes()); + msg.extend_from_slice(&self.action.to_be_bytes()); + msg.extend_from_slice(&self.transaction_id.to_be_bytes()); + + msg + } + + fn transaction_id(&self) -> u32 { + self.transaction_id + } + + fn action(&self) -> u32 { + self.action + } +} + +impl super::Response for Request { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8])> { + if buf.len() != Self::LENGTH { + return Err(Error::TrackerResponse); + } + + Ok(( + Request { + protocol_id: u64::from_be_bytes( + buf[0..8] + .try_into() + .invariant_unwrap("incoming type guarantees bounds are OK"), + ), + action: u32::from_be_bytes( + buf[8..12] + .try_into() + .invariant_unwrap("incoming type guarantees bounds are OK"), + ), + transaction_id: u32::from_be_bytes( + buf[12..16] + .try_into() + .invariant_unwrap("incoming type guarantees bounds are OK"), + ), + }, + &buf[Self::LENGTH..], + )) + } + + fn transaction_id(&self) -> u32 { + self.transaction_id + } + + fn action(&self) -> u32 { + self.action + } +} + +impl super::Request for Response { + type Response = Request; + + fn serialize(&self) -> Vec { + let mut msg = Vec::new(); + + msg.extend_from_slice(&self.action.to_be_bytes()); + msg.extend_from_slice(&self.transaction_id.to_be_bytes()); + msg.extend_from_slice(&self.connection_id.to_be_bytes()); + + msg + } + + fn transaction_id(&self) -> u32 { + self.transaction_id + } + + fn action(&self) -> u32 { + self.action + } +} + +impl super::Response for Response { + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8])> { + if buf.len() < Self::LENGTH { + return Err(Error::TrackerResponse); + } + + Ok(( + Self { + action: u32::from_be_bytes( + buf[0..4] + .try_into() + .invariant_unwrap("bounds are checked manually above"), + ), + transaction_id: u32::from_be_bytes( + buf[4..8] + .try_into() + .invariant_unwrap("bounds are checked manually above"), + ), + connection_id: u64::from_be_bytes( + buf[8..16] + .try_into() + .invariant_unwrap("bounds are checked manually above"), + ), + }, + &buf[Self::LENGTH..], + )) + } + + fn transaction_id(&self) -> u32 { + self.transaction_id + } + + fn action(&self) -> u32 { + self.action + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tracker::{connect, request::Request, response::Response}; + + #[test] + pub(crate) fn connect_request_roundtrip() { + let req = connect::Request { + protocol_id: 0x1337_beef_babe_cafe, + action: 50, + transaction_id: 1234, + }; + + let buf = req.serialize(); + let (req2, _) = connect::Request::deserialize(&buf).unwrap(); + assert_eq!(req, req2); + } + + #[test] + pub(crate) fn connect_response_roundtrip() { + let resp = connect::Response { + action: 50, + transaction_id: 1234, + connection_id: 0x1337_beef_babe_cafe, + }; + + let buf = resp.serialize(); + let (resp2, _) = connect::Response::deserialize(&buf).unwrap(); + assert_eq!(resp, resp2); + } + + #[test] + pub(crate) fn connect_request_datagram_size() { + let buf = [0x01, 0x02, 0x03]; + let err = connect::Request::deserialize(&buf); + assert_matches!(err, Err(Error::TrackerResponse)); + } + + #[test] + pub(crate) fn connect_response_datagram_size() { + let buf = [0x01, 0x02, 0x03]; + let err = connect::Response::deserialize(&buf); + assert_matches!(err, Err(Error::TrackerResponse)); + } +} diff --git a/src/tracker/request.rs b/src/tracker/request.rs new file mode 100644 index 0000000..615889b --- /dev/null +++ b/src/tracker/request.rs @@ -0,0 +1,9 @@ +use super::Response; + +pub(crate) trait Request { + type Response: Response; + fn serialize(&self) -> Vec; + + fn transaction_id(&self) -> u32; + fn action(&self) -> u32; +} diff --git a/src/tracker/response.rs b/src/tracker/response.rs new file mode 100644 index 0000000..366ddcf --- /dev/null +++ b/src/tracker/response.rs @@ -0,0 +1,11 @@ +use crate::common::*; + +pub(crate) trait Response { + // Deserialize the response into a Response object and payload. + fn deserialize(buf: &[u8]) -> Result<(Self, &[u8])> + where + Self: std::marker::Sized; + + fn transaction_id(&self) -> u32; + fn action(&self) -> u32; +}