diff --git a/libs/util/include/psemek/util/find_path.hpp b/libs/util/include/psemek/util/find_path.hpp new file mode 100644 index 00000000..4455b85d --- /dev/null +++ b/libs/util/include/psemek/util/find_path.hpp @@ -0,0 +1,142 @@ +#pragma once + +#include +#include +#include +#include + +namespace psemek::util +{ + + template + struct pathfinder + { + struct node_compare + { + std::unordered_map const & priority; + + bool operator()(Node const & n1, Node const & n2) const + { + auto const p1 = priority.at(n1); + auto const p2 = priority.at(n2); + + if (p1 < p2) + return true; + if (p1 > p2) + return false; + return n1 < n2; + } + }; + + NeighboursFn node_neighbours; + HeuristicFn heuristic; + + std::unordered_map cost; + std::unordered_map priority; + std::unordered_map previous; + std::set queue; + + pathfinder(NeighboursFn && node_neighbours, HeuristicFn && heuristic) + : node_neighbours(std::forward(node_neighbours)) + , heuristic(std::forward(heuristic)) + , queue(node_compare{priority}) + {} + + void reset() + { + queue.clear(); + previous.clear(); + priority.clear(); + cost.clear(); + } + + void init(Node const & start) + { + cost[start] = Cost{}; + priority[start] = cost[start] + heuristic(start); + queue.insert(start); + } + + bool finished() const + { + return queue.empty(); + } + + Node step() + { + Node const node = *queue.begin(); + queue.erase(queue.begin()); + + auto const node_cost = cost[node]; + + node_neighbours(node, [&](Node const & neighbour, Cost const & edge_cost){ + Cost const new_cost = node_cost + edge_cost; + + auto it = cost.find(neighbour); + + if (it == cost.end() || new_cost < it->second) + { + if (it != cost.end()) + queue.erase(neighbour); + + cost[neighbour] = new_cost; + priority[neighbour] = new_cost + heuristic(neighbour); + previous[neighbour] = node; + queue.insert(neighbour); + } + }); + + return node; + } + + bool found(Node const & end) const + { + return cost.contains(end); + } + + // Doesn't return the starting node! + template + Iterator path(Node const & start, Node const & end, Iterator it) const + { + Node node = end; + while (true) + { + *it++ = node; + node = previous.at(node); + if (node == start) + break; + } + return it; + } + + std::deque path(Node const & start, Node const & end) const + { + std::deque result; + path(start, end, std::front_inserter(result)); + return result; + } + }; + + template + std::optional> find_path(Node const & start, Node const & end, NeighboursFn && node_neighbours, HeuristicFn && heuristic) + { + pathfinder helper(std::forward(node_neighbours), std::forward(heuristic)); + + helper.init(start); + bool found = false; + while (!helper.finished()) + { + if (helper.step() == end) + { + found = true; + break; + } + } + + if (found) + return helper.path(start, end); + + return std::nullopt; + } + +}