Skip to content

Commit 3002e01

Browse files
vitalybukakazutakahirataCopilot
authored andcommitted
[NFC][ADT] Add RadixTree (llvm#164524)
This commit introduces a RadixTree implementation to LLVM. RadixTree, as a Trie, is very efficient by searching for prefixes. A Radix Tree is more efficient implementation of Trie. The tree will be used to optimize Glob matching in SpecialCaseList: * llvm#164531 * llvm#164543 * llvm#164545 --------- Co-authored-by: Kazu Hirata <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 85c15c2 commit 3002e01

File tree

4 files changed

+740
-0
lines changed

4 files changed

+740
-0
lines changed

llvm/docs/ProgrammersManual.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2161,6 +2161,16 @@ that are not simple pointers (use :ref:`SmallPtrSet <dss_smallptrset>` for
21612161
pointers). Note that ``DenseSet`` has the same requirements for the value type that
21622162
:ref:`DenseMap <dss_densemap>` has.
21632163

2164+
.. _dss_radixtree:
2165+
2166+
llvm/ADT/RadixTree.h
2167+
^^^^^^^^^^^^^^^^^^^^
2168+
2169+
``RadixTree`` is a trie-based data structure that stores range-like keys and
2170+
their associated values. It is particularly efficient for storing keys that
2171+
share common prefixes, as it can compress these prefixes to save memory. It
2172+
supports efficient search of matching prefixes.
2173+
21642174
.. _dss_sparseset:
21652175

21662176
llvm/ADT/SparseSet.h

llvm/include/llvm/ADT/RadixTree.h

Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
//===-- llvm/ADT/RadixTree.h - Radix Tree implementation --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//===----------------------------------------------------------------------===//
7+
//
8+
// This file implements a Radix Tree.
9+
//
10+
//===----------------------------------------------------------------------===//
11+
12+
#ifndef LLVM_ADT_RADIXTREE_H
13+
#define LLVM_ADT_RADIXTREE_H
14+
15+
#include "llvm/ADT/ADL.h"
16+
#include "llvm/ADT/STLExtras.h"
17+
#include "llvm/ADT/iterator.h"
18+
#include "llvm/ADT/iterator_range.h"
19+
#include <cassert>
20+
#include <cstddef>
21+
#include <iterator>
22+
#include <limits>
23+
#include <list>
24+
#include <utility>
25+
26+
namespace llvm {
27+
28+
/// \brief A Radix Tree implementation.
29+
///
30+
/// A Radix Tree (also known as a compact prefix tree or radix trie) is a
31+
/// data structure that stores a dynamic set or associative array where keys
32+
/// are strings and values are associated with these keys. Unlike a regular
33+
/// trie, the edges of a radix tree can be labeled with sequences of characters
34+
/// as well as single characters. This makes radix trees more efficient for
35+
/// storing sparse data sets, where many nodes in a regular trie would have
36+
/// only one child.
37+
///
38+
/// This implementation supports arbitrary key types that can be iterated over
39+
/// (e.g., `std::string`, `std::vector<char>`, `ArrayRef<char>`). The key type
40+
/// must provide `begin()` and `end()` for iteration.
41+
///
42+
/// The tree stores `std::pair<const KeyType, T>` as its value type.
43+
///
44+
/// Example usage:
45+
/// \code
46+
/// llvm::RadixTree<StringRef, int> Tree;
47+
/// Tree.emplace("apple", 1);
48+
/// Tree.emplace("grapefruit", 2);
49+
/// Tree.emplace("grape", 3);
50+
///
51+
/// // Find prefixes
52+
/// for (const auto &[Key, Value] : Tree.find_prefixes("grapefruit juice")) {
53+
/// // pair will be {"grape", 3}
54+
/// // pair will be {"grapefruit", 2}
55+
/// llvm::outs() << Key << ": " << Value << "\n";
56+
/// }
57+
///
58+
/// // Iterate over all elements
59+
/// for (const auto &[Key, Value] : Tree)
60+
/// llvm::outs() << Key << ": " << Value << "\n";
61+
/// \endcode
62+
///
63+
/// \note
64+
/// The `RadixTree` takes ownership of the `KeyType` and `T` objects
65+
/// inserted into it. When an element is removed or the tree is destroyed,
66+
/// these objects will be destructed.
67+
/// However, if `KeyType` is a reference-like type, e.g., StringRef or range,
68+
/// the user must guarantee that the referenced data has a lifetime longer than
69+
/// the tree.
70+
template <typename KeyType, typename T> class RadixTree {
71+
public:
72+
using key_type = KeyType;
73+
using mapped_type = T;
74+
using value_type = std::pair<const KeyType, mapped_type>;
75+
76+
private:
77+
using KeyConstIteratorType =
78+
decltype(adl_begin(std::declval<const key_type &>()));
79+
using KeyConstIteratorRangeType = iterator_range<KeyConstIteratorType>;
80+
using KeyValueType =
81+
remove_cvref_t<decltype(*adl_begin(std::declval<key_type &>()))>;
82+
using ContainerType = std::list<value_type>;
83+
84+
/// Represents an internal node in the Radix Tree.
85+
struct Node {
86+
KeyConstIteratorRangeType Key{KeyConstIteratorType{},
87+
KeyConstIteratorType{}};
88+
std::vector<Node> Children;
89+
90+
/// An iterator to the value associated with this node.
91+
///
92+
/// If this node does not have a value (i.e., it's an internal node that
93+
/// only serves as a path to other values), this iterator will be equal
94+
/// to default constructed `ContainerType::iterator()`.
95+
typename ContainerType::iterator Value;
96+
97+
/// The first character of the Key. Used for fast child lookup.
98+
KeyValueType KeyFront;
99+
100+
Node() = default;
101+
Node(const KeyConstIteratorRangeType &Key)
102+
: Key(Key), KeyFront(*Key.begin()) {
103+
assert(!Key.empty());
104+
}
105+
106+
Node(Node &&) = default;
107+
Node &operator=(Node &&) = default;
108+
109+
Node(const Node &) = delete;
110+
Node &operator=(const Node &) = delete;
111+
112+
const Node *findChild(const KeyConstIteratorRangeType &Key) const {
113+
if (Key.empty())
114+
return nullptr;
115+
for (const Node &Child : Children) {
116+
assert(!Child.Key.empty()); // Only root can be empty.
117+
if (Child.KeyFront == *Key.begin())
118+
return &Child;
119+
}
120+
return nullptr;
121+
}
122+
123+
Node *findChild(const KeyConstIteratorRangeType &Query) {
124+
const Node *This = this;
125+
return const_cast<Node *>(This->findChild(Query));
126+
}
127+
128+
size_t countNodes() const {
129+
size_t R = 1;
130+
for (const Node &C : Children)
131+
R += C.countNodes();
132+
return R;
133+
}
134+
135+
///
136+
/// Splits the current node into two.
137+
///
138+
/// This function is used when a new key needs to be inserted that shares
139+
/// a common prefix with the current node's key, but then diverges.
140+
/// The current `Key` is truncated to the common prefix, and a new child
141+
/// node is created for the remainder of the original node's `Key`.
142+
///
143+
/// \param SplitPoint An iterator pointing to the character in the current
144+
/// `Key` where the split should occur.
145+
void split(KeyConstIteratorType SplitPoint) {
146+
Node Child(make_range(SplitPoint, Key.end()));
147+
Key = make_range(Key.begin(), SplitPoint);
148+
149+
Children.swap(Child.Children);
150+
std::swap(Value, Child.Value);
151+
152+
Children.emplace_back(std::move(Child));
153+
}
154+
};
155+
156+
/// Root always corresponds to the empty key, which is the shortest possible
157+
/// prefix for everything.
158+
Node Root;
159+
ContainerType KeyValuePairs;
160+
161+
/// Finds or creates a new tail or leaf node corresponding to the `Key`.
162+
Node &findOrCreate(KeyConstIteratorRangeType Key) {
163+
Node *Curr = &Root;
164+
if (Key.empty())
165+
return *Curr;
166+
167+
for (;;) {
168+
auto [I1, I2] = llvm::mismatch(Key, Curr->Key);
169+
Key = make_range(I1, Key.end());
170+
171+
if (I2 != Curr->Key.end()) {
172+
// Match is partial. Either query is too short, or there is mismatching
173+
// character. Split either way, and put new node in between of the
174+
// current and its children.
175+
Curr->split(I2);
176+
177+
// Split was caused by mismatch, so `findChild` would fail.
178+
break;
179+
}
180+
181+
Node *Child = Curr->findChild(Key);
182+
if (!Child)
183+
break;
184+
185+
// Move to child with the same first character.
186+
Curr = Child;
187+
}
188+
189+
if (Key.empty()) {
190+
// The current node completely matches the key, return it.
191+
return *Curr;
192+
}
193+
194+
// `Key` is a suffix of original `Key` unmatched by path from the `Root` to
195+
// the `Curr`, and we have no candidate in the children to match more.
196+
// Create a new one.
197+
return Curr->Children.emplace_back(Key);
198+
}
199+
200+
///
201+
/// An iterator for traversing prefixes search results.
202+
///
203+
/// This iterator is used by `find_prefixes` to traverse the tree and find
204+
/// elements that are prefixes to the given key. It's a forward iterator.
205+
///
206+
/// \tparam MappedType The type of the value pointed to by the iterator.
207+
/// This will be `value_type` for non-const iterators
208+
/// and `const value_type` for const iterators.
209+
template <typename MappedType>
210+
class IteratorImpl
211+
: public iterator_facade_base<IteratorImpl<MappedType>,
212+
std::forward_iterator_tag, MappedType> {
213+
const Node *Curr = nullptr;
214+
KeyConstIteratorRangeType Query{KeyConstIteratorType{},
215+
KeyConstIteratorType{}};
216+
217+
void findNextValid() {
218+
while (Curr && Curr->Value == typename ContainerType::iterator())
219+
advance();
220+
}
221+
222+
void advance() {
223+
assert(Curr);
224+
if (Query.empty()) {
225+
Curr = nullptr;
226+
return;
227+
}
228+
229+
Curr = Curr->findChild(Query);
230+
if (!Curr) {
231+
Curr = nullptr;
232+
return;
233+
}
234+
235+
auto [I1, I2] = llvm::mismatch(Query, Curr->Key);
236+
if (I2 != Curr->Key.end()) {
237+
Curr = nullptr;
238+
return;
239+
}
240+
Query = make_range(I1, Query.end());
241+
}
242+
243+
friend class RadixTree;
244+
IteratorImpl(const Node *C, const KeyConstIteratorRangeType &Q)
245+
: Curr(C), Query(Q) {
246+
findNextValid();
247+
}
248+
249+
public:
250+
IteratorImpl() = default;
251+
252+
MappedType &operator*() const { return *Curr->Value; }
253+
254+
IteratorImpl &operator++() {
255+
advance();
256+
findNextValid();
257+
return *this;
258+
}
259+
260+
bool operator==(const IteratorImpl &Other) const {
261+
return Curr == Other.Curr;
262+
}
263+
};
264+
265+
public:
266+
RadixTree() = default;
267+
RadixTree(RadixTree &&) = default;
268+
RadixTree &operator=(RadixTree &&) = default;
269+
270+
using prefix_iterator = IteratorImpl<value_type>;
271+
using const_prefix_iterator = IteratorImpl<const value_type>;
272+
273+
using iterator = typename ContainerType::iterator;
274+
using const_iterator = typename ContainerType::const_iterator;
275+
276+
/// Returns true if the tree is empty.
277+
bool empty() const { return KeyValuePairs.empty(); }
278+
279+
/// Returns the number of elements in the tree.
280+
size_t size() const { return KeyValuePairs.size(); }
281+
282+
/// Returns the number of nodes in the tree.
283+
///
284+
/// This function counts all internal nodes in the tree. It can be useful for
285+
/// understanding the memory footprint or complexity of the tree structure.
286+
size_t countNodes() const { return Root.countNodes(); }
287+
288+
/// Returns an iterator to the first element.
289+
iterator begin() { return KeyValuePairs.begin(); }
290+
const_iterator begin() const { return KeyValuePairs.begin(); }
291+
292+
/// Returns an iterator to the end of the tree.
293+
iterator end() { return KeyValuePairs.end(); }
294+
const_iterator end() const { return KeyValuePairs.end(); }
295+
296+
/// Constructs and inserts a new element into the tree.
297+
///
298+
/// This function constructs an element in place within the tree. If an
299+
/// element with the same key already exists, the insertion fails and the
300+
/// function returns an iterator to the existing element along with `false`.
301+
/// Otherwise, the new element is inserted and the function returns an
302+
/// iterator to the new element along with `true`.
303+
///
304+
/// \param Key The key of the element to construct.
305+
/// \param Args Arguments to forward to the constructor of the mapped_type.
306+
/// \return A pair consisting of an iterator to the inserted element (or to
307+
/// the element that prevented insertion) and a boolean value
308+
/// indicating whether the insertion took place.
309+
template <typename... Ts>
310+
std::pair<iterator, bool> emplace(key_type &&Key, Ts &&...Args) {
311+
// We want to make new `Node` to refer key in the container, not the one
312+
// from the argument.
313+
// FIXME: Determine that we need a new node, before expanding
314+
// `KeyValuePairs`.
315+
const value_type &NewValue = KeyValuePairs.emplace_front(
316+
std::move(Key), T(std::forward<Ts>(Args)...));
317+
Node &Node = findOrCreate(NewValue.first);
318+
bool HasValue = Node.Value != typename ContainerType::iterator();
319+
if (!HasValue)
320+
Node.Value = KeyValuePairs.begin();
321+
else
322+
KeyValuePairs.pop_front();
323+
return {Node.Value, !HasValue};
324+
}
325+
326+
///
327+
/// Finds all elements whose keys are prefixes of the given `Key`.
328+
///
329+
/// This function returns an iterator range over all elements in the tree
330+
/// whose keys are prefixes of the provided `Key`. For example, if the tree
331+
/// contains "abcde", "abc", "abcdefgh", and `Key` is "abcde", this function
332+
/// would return iterators to "abcde" and "abc".
333+
///
334+
/// \param Key The key to search for prefixes of.
335+
/// \return An `iterator_range` of `const_prefix_iterator`s, allowing
336+
/// iteration over the found prefix elements.
337+
/// \note The returned iterators reference the `Key` provided by the caller.
338+
/// The caller must ensure that `Key` remains valid for the lifetime
339+
/// of the iterators.
340+
iterator_range<const_prefix_iterator>
341+
find_prefixes(const key_type &Key) const {
342+
return iterator_range<const_prefix_iterator>{
343+
const_prefix_iterator(&Root, KeyConstIteratorRangeType(Key)),
344+
const_prefix_iterator{}};
345+
}
346+
};
347+
348+
} // namespace llvm
349+
350+
#endif // LLVM_ADT_RADIXTREE_H

llvm/unittests/ADT/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ add_llvm_unittest(ADTTests
6363
PointerUnionTest.cpp
6464
PostOrderIteratorTest.cpp
6565
PriorityWorklistTest.cpp
66+
RadixTreeTest.cpp
6667
RangeAdapterTest.cpp
6768
RewriteBufferTest.cpp
6869
SCCIteratorTest.cpp

0 commit comments

Comments
 (0)