diff --git a/src/libstd/collections/hash/set.rs b/src/libstd/collections/hash/set.rs index 92145907b95f6..c55dd049ec60f 100644 --- a/src/libstd/collections/hash/set.rs +++ b/src/libstd/collections/hash/set.rs @@ -410,9 +410,16 @@ impl HashSet /// ``` #[stable(feature = "rust1", since = "1.0.0")] pub fn intersection<'a>(&'a self, other: &'a HashSet) -> Intersection<'a, T, S> { - Intersection { - iter: self.iter(), - other, + if self.len() <= other.len() { + Intersection { + iter: self.iter(), + other, + } + } else { + Intersection { + iter: other.iter(), + other: self, + } } } @@ -436,7 +443,15 @@ impl HashSet /// ``` #[stable(feature = "rust1", since = "1.0.0")] pub fn union<'a>(&'a self, other: &'a HashSet) -> Union<'a, T, S> { - Union { iter: self.iter().chain(other.difference(self)) } + if self.len() <= other.len() { + Union { + iter: self.iter().chain(other.difference(self)), + } + } else { + Union { + iter: other.iter().chain(self.difference(other)), + } + } } /// Returns the number of elements in the set. @@ -584,7 +599,11 @@ impl HashSet /// ``` #[stable(feature = "rust1", since = "1.0.0")] pub fn is_disjoint(&self, other: &HashSet) -> bool { - self.iter().all(|v| !other.contains(v)) + if self.len() <= other.len() { + self.iter().all(|v| !other.contains(v)) + } else { + other.iter().all(|v| !self.contains(v)) + } } /// Returns `true` if the set is a subset of another, @@ -1494,6 +1513,7 @@ mod test_set { fn test_intersection() { let mut a = HashSet::new(); let mut b = HashSet::new(); + assert!(a.intersection(&b).next().is_none()); assert!(a.insert(11)); assert!(a.insert(1)); @@ -1518,6 +1538,22 @@ mod test_set { i += 1 } assert_eq!(i, expected.len()); + + assert!(a.insert(9)); // make a bigger than b + + i = 0; + for x in a.intersection(&b) { + assert!(expected.contains(x)); + i += 1 + } + assert_eq!(i, expected.len()); + + i = 0; + for x in b.intersection(&a) { + assert!(expected.contains(x)); + i += 1 + } + assert_eq!(i, expected.len()); } #[test] @@ -1573,11 +1609,11 @@ mod test_set { fn test_union() { let mut a = HashSet::new(); let mut b = HashSet::new(); + assert!(a.union(&b).next().is_none()); + assert!(b.union(&a).next().is_none()); assert!(a.insert(1)); assert!(a.insert(3)); - assert!(a.insert(5)); - assert!(a.insert(9)); assert!(a.insert(11)); assert!(a.insert(16)); assert!(a.insert(19)); @@ -1597,6 +1633,23 @@ mod test_set { i += 1 } assert_eq!(i, expected.len()); + + assert!(a.insert(9)); // make a bigger than b + assert!(a.insert(5)); + + i = 0; + for x in a.union(&b) { + assert!(expected.contains(x)); + i += 1 + } + assert_eq!(i, expected.len()); + + i = 0; + for x in b.union(&a) { + assert!(expected.contains(x)); + i += 1 + } + assert_eq!(i, expected.len()); } #[test]