diff --git a/aoc21/src/main.rs b/aoc21/src/main.rs index 3dfce01..d837361 100644 --- a/aoc21/src/main.rs +++ b/aoc21/src/main.rs @@ -1,4 +1,4 @@ -use std::collections::{HashSet, VecDeque}; +use std::collections::{HashMap, HashSet, VecDeque}; use std::error::Error; use std::io::{self, Read, Write}; use std::time::Instant; @@ -12,8 +12,9 @@ type Result = ::std::result::Result>; type Coord = (isize, isize); -fn parse_input>(input: T) -> (Coord, HashSet) { +fn parse_input>(input: T) -> (Coord, Coord, HashSet) { let mut start = (-1, -1); + let mut bound = (-1, -1); let mut map = HashSet::new(); for (i, l) in input.as_ref().split_whitespace().enumerate() { for (j, c) in l.chars().enumerate() { @@ -23,49 +24,85 @@ fn parse_input>(input: T) -> (Coord, HashSet) { if c == 'S' { start = (i as isize, j as isize); } + bound.1 = bound.1.max(j as isize + 1); } + bound.0 = bound.0.max(i as isize + 1); } - (start, map) + (start, bound, map) } -fn bfs(start: Coord, mut step: usize, map: &HashSet) -> usize { - let mut queue = VecDeque::new(); - queue.push_back(start); +fn normalize_coord(pos: Coord, bound: Coord) -> (Coord, Coord) { + let origin = (pos.0.rem_euclid(bound.0), pos.1.rem_euclid(bound.1)); + let dis = (pos.0 - origin.0, pos.1 - origin.1); + (origin, dis) +} - while !queue.is_empty() && step > 0 { - step -= 1; - let l = queue.len(); - let mut visited = HashSet::new(); - for _ in 0..l { - let curr = queue.pop_front().unwrap(); +fn bfs(start: Coord, mut step: usize, bound: Coord, map: &HashSet) -> usize { + let mut cache: HashMap> = HashMap::new(); + + for i in 0..bound.0 { + for j in 0..bound.1 { + if !map.contains(&(i, j)) { + continue; + } + let e = cache.entry((i, j)).or_default(); for (dx, dy) in [(1, 0), (-1, 0), (0, 1), (0, -1)] { - let (nx, ny) = (curr.0 + dx, curr.1 + dy); - if map.contains(&(nx, ny)) && visited.insert((nx, ny)) { - queue.push_back((nx, ny)) + let (nx, ny) = (i + dx, j + dy); + let (origin, _) = normalize_coord((nx, ny), bound); + if map.contains(&origin) { + e.push((nx, ny)); } } } } + + let mut queue = HashSet::new(); + queue.insert(start); + while step > 0 { + step -= 1; + + queue = queue + .iter() + .flat_map(|&curr| { + let (origin, dis) = normalize_coord(curr, bound); + cache + .get(&origin) + .unwrap() + .iter() + .map(move |(x, y)| (x + dis.0, y + dis.1)) + }) + .collect(); + } queue.len() } -fn part1(start: Coord, map: &HashSet) -> Result { +fn part1(start: Coord, bound: Coord, map: &HashSet) -> Result { let _start = Instant::now(); - let result = bfs(start, 64, map); + let result = bfs(start, 64, bound, map); writeln!(io::stdout(), "Part 1: {result}")?; writeln!(io::stdout(), "> Time elapsed is: {:?}", _start.elapsed())?; Ok(result) } +fn part2(start: Coord, bound: Coord, map: &HashSet) -> Result { + let _start = Instant::now(); + + let result = bfs(start, 26501365, bound, map); + + writeln!(io::stdout(), "Part 2: {result}")?; + writeln!(io::stdout(), "> Time elapsed is: {:?}", _start.elapsed())?; + Ok(result) +} + fn main() -> Result<()> { let mut input = String::new(); io::stdin().read_to_string(&mut input)?; - let (start, map) = parse_input(input); - part1(start, &map)?; - // part2()?; + let (start, bound, map) = parse_input(input); + part1(start, bound, &map)?; + part2(start, bound, &map)?; Ok(()) } @@ -82,13 +119,20 @@ fn example_input() { .##.#.####. .##..##.##. ..........."; - let (start, map) = parse_input(input); - assert_eq!(bfs(start, 6, &map), 16); + let (start, bound, map) = parse_input(input); + assert_eq!(bfs(start, 6, bound, &map), 16); + assert_eq!(bfs(start, 10, bound, &map), 50); + assert_eq!(bfs(start, 50, bound, &map), 1594); + assert_eq!(bfs(start, 100, bound, &map), 6536); + assert_eq!(bfs(start, 500, bound, &map), 167004); + assert_eq!(bfs(start, 1000, bound, &map), 668697); + assert_eq!(bfs(start, 5000, bound, &map), 16733044); } #[test] fn real_input() { let input = std::fs::read_to_string("input/input.txt").unwrap(); - let (start, map) = parse_input(input); - assert_eq!(part1(start, &map).unwrap(), 3600); + let (start, bound, map) = parse_input(input); + assert_eq!(part1(start, bound, &map).unwrap(), 3600); + assert_eq!(part2(start, bound, &map).unwrap(), 3600); }