|  | 
|  | 1 | +//===-- llvm/ADT/RadixTree.h - Radix Tree implementation --------*- C++ -*-===// | 
|  | 2 | +// | 
|  | 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | 
|  | 4 | +// See https://llvm.org/LICENSE.txt for license information. | 
|  | 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  | 6 | +//===----------------------------------------------------------------------===// | 
|  | 7 | +// | 
|  | 8 | +// This file implements a Radix Tree. | 
|  | 9 | +// | 
|  | 10 | +//===----------------------------------------------------------------------===// | 
|  | 11 | + | 
|  | 12 | +#ifndef LLVM_ADT_RADIXTREE_H | 
|  | 13 | +#define LLVM_ADT_RADIXTREE_H | 
|  | 14 | + | 
|  | 15 | +#include "llvm/ADT/ADL.h" | 
|  | 16 | +#include "llvm/ADT/STLExtras.h" | 
|  | 17 | +#include "llvm/ADT/iterator.h" | 
|  | 18 | +#include "llvm/ADT/iterator_range.h" | 
|  | 19 | +#include <cassert> | 
|  | 20 | +#include <cstddef> | 
|  | 21 | +#include <iterator> | 
|  | 22 | +#include <limits> | 
|  | 23 | +#include <list> | 
|  | 24 | +#include <utility> | 
|  | 25 | + | 
|  | 26 | +namespace llvm { | 
|  | 27 | + | 
|  | 28 | +/// \brief A Radix Tree implementation. | 
|  | 29 | +/// | 
|  | 30 | +/// A Radix Tree (also known as a compact prefix tree or radix trie) is a | 
|  | 31 | +/// data structure that stores a dynamic set or associative array where keys | 
|  | 32 | +/// are strings and values are associated with these keys. Unlike a regular | 
|  | 33 | +/// trie, the edges of a radix tree can be labeled with sequences of characters | 
|  | 34 | +/// as well as single characters. This makes radix trees more efficient for | 
|  | 35 | +/// storing sparse data sets, where many nodes in a regular trie would have | 
|  | 36 | +/// only one child. | 
|  | 37 | +/// | 
|  | 38 | +/// This implementation supports arbitrary key types that can be iterated over | 
|  | 39 | +/// (e.g., `std::string`, `std::vector<char>`, `ArrayRef<char>`). The key type | 
|  | 40 | +/// must provide `begin()` and `end()` for iteration. | 
|  | 41 | +/// | 
|  | 42 | +/// The tree stores `std::pair<const KeyType, T>` as its value type. | 
|  | 43 | +/// | 
|  | 44 | +/// Example usage: | 
|  | 45 | +/// \code | 
|  | 46 | +///   llvm::RadixTree<StringRef, int> Tree; | 
|  | 47 | +///   Tree.emplace("apple", 1); | 
|  | 48 | +///   Tree.emplace("grapefruit", 2); | 
|  | 49 | +///   Tree.emplace("grape", 3); | 
|  | 50 | +/// | 
|  | 51 | +///   // Find prefixes | 
|  | 52 | +///   for (const auto &[Key, Value] : Tree.find_prefixes("grapefruit juice")) { | 
|  | 53 | +///     // pair will be {"grape", 3} | 
|  | 54 | +///     // pair will be {"grapefruit", 2} | 
|  | 55 | +///     llvm::outs() << Key << ": " << Value << "\n"; | 
|  | 56 | +///   } | 
|  | 57 | +/// | 
|  | 58 | +///   // Iterate over all elements | 
|  | 59 | +///   for (const auto &[Key, Value] : Tree) | 
|  | 60 | +///     llvm::outs() << Key << ": " << Value << "\n"; | 
|  | 61 | +/// \endcode | 
|  | 62 | +/// | 
|  | 63 | +/// \note | 
|  | 64 | +/// The `RadixTree` takes ownership of the `KeyType` and `T` objects | 
|  | 65 | +/// inserted into it. When an element is removed or the tree is destroyed, | 
|  | 66 | +/// these objects will be destructed. | 
|  | 67 | +/// However, if `KeyType` is a reference-like type, e.g., StringRef or range, | 
|  | 68 | +/// the user must guarantee that the referenced data has a lifetime longer than | 
|  | 69 | +/// the tree. | 
|  | 70 | +template <typename KeyType, typename T> class RadixTree { | 
|  | 71 | +public: | 
|  | 72 | +  using key_type = KeyType; | 
|  | 73 | +  using mapped_type = T; | 
|  | 74 | +  using value_type = std::pair<const KeyType, mapped_type>; | 
|  | 75 | + | 
|  | 76 | +private: | 
|  | 77 | +  using KeyConstIteratorType = | 
|  | 78 | +      decltype(adl_begin(std::declval<const key_type &>())); | 
|  | 79 | +  using KeyConstIteratorRangeType = iterator_range<KeyConstIteratorType>; | 
|  | 80 | +  using KeyValueType = | 
|  | 81 | +      remove_cvref_t<decltype(*adl_begin(std::declval<key_type &>()))>; | 
|  | 82 | +  using ContainerType = std::list<value_type>; | 
|  | 83 | + | 
|  | 84 | +  /// Represents an internal node in the Radix Tree. | 
|  | 85 | +  struct Node { | 
|  | 86 | +    KeyConstIteratorRangeType Key{KeyConstIteratorType{}, | 
|  | 87 | +                                  KeyConstIteratorType{}}; | 
|  | 88 | +    std::vector<Node> Children; | 
|  | 89 | + | 
|  | 90 | +    /// An iterator to the value associated with this node. | 
|  | 91 | +    /// | 
|  | 92 | +    /// If this node does not have a value (i.e., it's an internal node that | 
|  | 93 | +    /// only serves as a path to other values), this iterator will be equal | 
|  | 94 | +    /// to default constructed `ContainerType::iterator()`. | 
|  | 95 | +    typename ContainerType::iterator Value; | 
|  | 96 | + | 
|  | 97 | +    /// The first character of the Key. Used for fast child lookup. | 
|  | 98 | +    KeyValueType KeyFront; | 
|  | 99 | + | 
|  | 100 | +    Node() = default; | 
|  | 101 | +    Node(const KeyConstIteratorRangeType &Key) | 
|  | 102 | +        : Key(Key), KeyFront(*Key.begin()) { | 
|  | 103 | +      assert(!Key.empty()); | 
|  | 104 | +    } | 
|  | 105 | + | 
|  | 106 | +    Node(Node &&) = default; | 
|  | 107 | +    Node &operator=(Node &&) = default; | 
|  | 108 | + | 
|  | 109 | +    Node(const Node &) = delete; | 
|  | 110 | +    Node &operator=(const Node &) = delete; | 
|  | 111 | + | 
|  | 112 | +    const Node *findChild(const KeyConstIteratorRangeType &Key) const { | 
|  | 113 | +      if (Key.empty()) | 
|  | 114 | +        return nullptr; | 
|  | 115 | +      for (const Node &Child : Children) { | 
|  | 116 | +        assert(!Child.Key.empty()); // Only root can be empty. | 
|  | 117 | +        if (Child.KeyFront == *Key.begin()) | 
|  | 118 | +          return &Child; | 
|  | 119 | +      } | 
|  | 120 | +      return nullptr; | 
|  | 121 | +    } | 
|  | 122 | + | 
|  | 123 | +    Node *findChild(const KeyConstIteratorRangeType &Query) { | 
|  | 124 | +      const Node *This = this; | 
|  | 125 | +      return const_cast<Node *>(This->findChild(Query)); | 
|  | 126 | +    } | 
|  | 127 | + | 
|  | 128 | +    size_t countNodes() const { | 
|  | 129 | +      size_t R = 1; | 
|  | 130 | +      for (const Node &C : Children) | 
|  | 131 | +        R += C.countNodes(); | 
|  | 132 | +      return R; | 
|  | 133 | +    } | 
|  | 134 | + | 
|  | 135 | +    /// | 
|  | 136 | +    /// Splits the current node into two. | 
|  | 137 | +    /// | 
|  | 138 | +    /// This function is used when a new key needs to be inserted that shares | 
|  | 139 | +    /// a common prefix with the current node's key, but then diverges. | 
|  | 140 | +    /// The current `Key` is truncated to the common prefix, and a new child | 
|  | 141 | +    /// node is created for the remainder of the original node's `Key`. | 
|  | 142 | +    /// | 
|  | 143 | +    /// \param SplitPoint An iterator pointing to the character in the current | 
|  | 144 | +    ///                   `Key` where the split should occur. | 
|  | 145 | +    void split(KeyConstIteratorType SplitPoint) { | 
|  | 146 | +      Node Child(make_range(SplitPoint, Key.end())); | 
|  | 147 | +      Key = make_range(Key.begin(), SplitPoint); | 
|  | 148 | + | 
|  | 149 | +      Children.swap(Child.Children); | 
|  | 150 | +      std::swap(Value, Child.Value); | 
|  | 151 | + | 
|  | 152 | +      Children.emplace_back(std::move(Child)); | 
|  | 153 | +    } | 
|  | 154 | +  }; | 
|  | 155 | + | 
|  | 156 | +  /// Root always corresponds to the empty key, which is the shortest possible | 
|  | 157 | +  /// prefix for everything. | 
|  | 158 | +  Node Root; | 
|  | 159 | +  ContainerType KeyValuePairs; | 
|  | 160 | + | 
|  | 161 | +  /// Finds or creates a new tail or leaf node corresponding to the `Key`. | 
|  | 162 | +  Node &findOrCreate(KeyConstIteratorRangeType Key) { | 
|  | 163 | +    Node *Curr = &Root; | 
|  | 164 | +    if (Key.empty()) | 
|  | 165 | +      return *Curr; | 
|  | 166 | + | 
|  | 167 | +    for (;;) { | 
|  | 168 | +      auto [I1, I2] = llvm::mismatch(Key, Curr->Key); | 
|  | 169 | +      Key = make_range(I1, Key.end()); | 
|  | 170 | + | 
|  | 171 | +      if (I2 != Curr->Key.end()) { | 
|  | 172 | +        // Match is partial. Either query is too short, or there is mismatching | 
|  | 173 | +        // character. Split either way, and put new node in between of the | 
|  | 174 | +        // current and its children. | 
|  | 175 | +        Curr->split(I2); | 
|  | 176 | + | 
|  | 177 | +        // Split was caused by mismatch, so `findChild` would fail. | 
|  | 178 | +        break; | 
|  | 179 | +      } | 
|  | 180 | + | 
|  | 181 | +      Node *Child = Curr->findChild(Key); | 
|  | 182 | +      if (!Child) | 
|  | 183 | +        break; | 
|  | 184 | + | 
|  | 185 | +      // Move to child with the same first character. | 
|  | 186 | +      Curr = Child; | 
|  | 187 | +    } | 
|  | 188 | + | 
|  | 189 | +    if (Key.empty()) { | 
|  | 190 | +      // The current node completely matches the key, return it. | 
|  | 191 | +      return *Curr; | 
|  | 192 | +    } | 
|  | 193 | + | 
|  | 194 | +    // `Key` is a suffix of original `Key` unmatched by path from the `Root` to | 
|  | 195 | +    // the `Curr`, and we have no candidate in the children to match more. | 
|  | 196 | +    // Create a new one. | 
|  | 197 | +    return Curr->Children.emplace_back(Key); | 
|  | 198 | +  } | 
|  | 199 | + | 
|  | 200 | +  /// | 
|  | 201 | +  /// An iterator for traversing prefixes search results. | 
|  | 202 | +  /// | 
|  | 203 | +  /// This iterator is used by `find_prefixes` to traverse the tree and find | 
|  | 204 | +  /// elements that are prefixes to the given key. It's a forward iterator. | 
|  | 205 | +  /// | 
|  | 206 | +  /// \tparam MappedType The type of the value pointed to by the iterator. | 
|  | 207 | +  ///                    This will be `value_type` for non-const iterators | 
|  | 208 | +  ///                    and `const value_type` for const iterators. | 
|  | 209 | +  template <typename MappedType> | 
|  | 210 | +  class IteratorImpl | 
|  | 211 | +      : public iterator_facade_base<IteratorImpl<MappedType>, | 
|  | 212 | +                                    std::forward_iterator_tag, MappedType> { | 
|  | 213 | +    const Node *Curr = nullptr; | 
|  | 214 | +    KeyConstIteratorRangeType Query{KeyConstIteratorType{}, | 
|  | 215 | +                                    KeyConstIteratorType{}}; | 
|  | 216 | + | 
|  | 217 | +    void findNextValid() { | 
|  | 218 | +      while (Curr && Curr->Value == typename ContainerType::iterator()) | 
|  | 219 | +        advance(); | 
|  | 220 | +    } | 
|  | 221 | + | 
|  | 222 | +    void advance() { | 
|  | 223 | +      assert(Curr); | 
|  | 224 | +      if (Query.empty()) { | 
|  | 225 | +        Curr = nullptr; | 
|  | 226 | +        return; | 
|  | 227 | +      } | 
|  | 228 | + | 
|  | 229 | +      Curr = Curr->findChild(Query); | 
|  | 230 | +      if (!Curr) { | 
|  | 231 | +        Curr = nullptr; | 
|  | 232 | +        return; | 
|  | 233 | +      } | 
|  | 234 | + | 
|  | 235 | +      auto [I1, I2] = llvm::mismatch(Query, Curr->Key); | 
|  | 236 | +      if (I2 != Curr->Key.end()) { | 
|  | 237 | +        Curr = nullptr; | 
|  | 238 | +        return; | 
|  | 239 | +      } | 
|  | 240 | +      Query = make_range(I1, Query.end()); | 
|  | 241 | +    } | 
|  | 242 | + | 
|  | 243 | +    friend class RadixTree; | 
|  | 244 | +    IteratorImpl(const Node *C, const KeyConstIteratorRangeType &Q) | 
|  | 245 | +        : Curr(C), Query(Q) { | 
|  | 246 | +      findNextValid(); | 
|  | 247 | +    } | 
|  | 248 | + | 
|  | 249 | +  public: | 
|  | 250 | +    IteratorImpl() = default; | 
|  | 251 | + | 
|  | 252 | +    MappedType &operator*() const { return *Curr->Value; } | 
|  | 253 | + | 
|  | 254 | +    IteratorImpl &operator++() { | 
|  | 255 | +      advance(); | 
|  | 256 | +      findNextValid(); | 
|  | 257 | +      return *this; | 
|  | 258 | +    } | 
|  | 259 | + | 
|  | 260 | +    bool operator==(const IteratorImpl &Other) const { | 
|  | 261 | +      return Curr == Other.Curr; | 
|  | 262 | +    } | 
|  | 263 | +  }; | 
|  | 264 | + | 
|  | 265 | +public: | 
|  | 266 | +  RadixTree() = default; | 
|  | 267 | +  RadixTree(RadixTree &&) = default; | 
|  | 268 | +  RadixTree &operator=(RadixTree &&) = default; | 
|  | 269 | + | 
|  | 270 | +  using prefix_iterator = IteratorImpl<value_type>; | 
|  | 271 | +  using const_prefix_iterator = IteratorImpl<const value_type>; | 
|  | 272 | + | 
|  | 273 | +  using iterator = typename ContainerType::iterator; | 
|  | 274 | +  using const_iterator = typename ContainerType::const_iterator; | 
|  | 275 | + | 
|  | 276 | +  /// Returns true if the tree is empty. | 
|  | 277 | +  bool empty() const { return KeyValuePairs.empty(); } | 
|  | 278 | + | 
|  | 279 | +  /// Returns the number of elements in the tree. | 
|  | 280 | +  size_t size() const { return KeyValuePairs.size(); } | 
|  | 281 | + | 
|  | 282 | +  /// Returns the number of nodes in the tree. | 
|  | 283 | +  /// | 
|  | 284 | +  /// This function counts all internal nodes in the tree. It can be useful for | 
|  | 285 | +  /// understanding the memory footprint or complexity of the tree structure. | 
|  | 286 | +  size_t countNodes() const { return Root.countNodes(); } | 
|  | 287 | + | 
|  | 288 | +  /// Returns an iterator to the first element. | 
|  | 289 | +  iterator begin() { return KeyValuePairs.begin(); } | 
|  | 290 | +  const_iterator begin() const { return KeyValuePairs.begin(); } | 
|  | 291 | + | 
|  | 292 | +  /// Returns an iterator to the end of the tree. | 
|  | 293 | +  iterator end() { return KeyValuePairs.end(); } | 
|  | 294 | +  const_iterator end() const { return KeyValuePairs.end(); } | 
|  | 295 | + | 
|  | 296 | +  /// Constructs and inserts a new element into the tree. | 
|  | 297 | +  /// | 
|  | 298 | +  /// This function constructs an element in place within the tree. If an | 
|  | 299 | +  /// element with the same key already exists, the insertion fails and the | 
|  | 300 | +  /// function returns an iterator to the existing element along with `false`. | 
|  | 301 | +  /// Otherwise, the new element is inserted and the function returns an | 
|  | 302 | +  /// iterator to the new element along with `true`. | 
|  | 303 | +  /// | 
|  | 304 | +  /// \param Key The key of the element to construct. | 
|  | 305 | +  /// \param Args Arguments to forward to the constructor of the mapped_type. | 
|  | 306 | +  /// \return A pair consisting of an iterator to the inserted element (or to | 
|  | 307 | +  ///         the element that prevented insertion) and a boolean value | 
|  | 308 | +  ///         indicating whether the insertion took place. | 
|  | 309 | +  template <typename... Ts> | 
|  | 310 | +  std::pair<iterator, bool> emplace(key_type &&Key, Ts &&...Args) { | 
|  | 311 | +    // We want to make new `Node` to refer key in the container, not the one | 
|  | 312 | +    // from the argument. | 
|  | 313 | +    // FIXME: Determine that we need a new node, before expanding | 
|  | 314 | +    // `KeyValuePairs`. | 
|  | 315 | +    const value_type &NewValue = KeyValuePairs.emplace_front( | 
|  | 316 | +        std::move(Key), T(std::forward<Ts>(Args)...)); | 
|  | 317 | +    Node &Node = findOrCreate(NewValue.first); | 
|  | 318 | +    bool HasValue = Node.Value != typename ContainerType::iterator(); | 
|  | 319 | +    if (!HasValue) | 
|  | 320 | +      Node.Value = KeyValuePairs.begin(); | 
|  | 321 | +    else | 
|  | 322 | +      KeyValuePairs.pop_front(); | 
|  | 323 | +    return {Node.Value, !HasValue}; | 
|  | 324 | +  } | 
|  | 325 | + | 
|  | 326 | +  /// | 
|  | 327 | +  /// Finds all elements whose keys are prefixes of the given `Key`. | 
|  | 328 | +  /// | 
|  | 329 | +  /// This function returns an iterator range over all elements in the tree | 
|  | 330 | +  /// whose keys are prefixes of the provided `Key`. For example, if the tree | 
|  | 331 | +  /// contains "abcde", "abc", "abcdefgh", and `Key` is "abcde", this function | 
|  | 332 | +  /// would return iterators to "abcde" and "abc". | 
|  | 333 | +  /// | 
|  | 334 | +  /// \param Key The key to search for prefixes of. | 
|  | 335 | +  /// \return An `iterator_range` of `const_prefix_iterator`s, allowing | 
|  | 336 | +  ///         iteration over the found prefix elements. | 
|  | 337 | +  /// \note The returned iterators reference the `Key` provided by the caller. | 
|  | 338 | +  ///       The caller must ensure that `Key` remains valid for the lifetime | 
|  | 339 | +  ///       of the iterators. | 
|  | 340 | +  iterator_range<const_prefix_iterator> | 
|  | 341 | +  find_prefixes(const key_type &Key) const { | 
|  | 342 | +    return iterator_range<const_prefix_iterator>{ | 
|  | 343 | +        const_prefix_iterator(&Root, KeyConstIteratorRangeType(Key)), | 
|  | 344 | +        const_prefix_iterator{}}; | 
|  | 345 | +  } | 
|  | 346 | +}; | 
|  | 347 | + | 
|  | 348 | +} // namespace llvm | 
|  | 349 | + | 
|  | 350 | +#endif // LLVM_ADT_RADIXTREE_H | 
0 commit comments