|
| 1 | +from contextlib import contextmanager |
| 2 | +import tracemalloc |
| 3 | + |
1 | 4 | import numpy as np |
2 | 5 | import pytest |
3 | 6 |
|
|
6 | 9 | import pandas._testing as tm |
7 | 10 |
|
8 | 11 |
|
| 12 | +@contextmanager |
| 13 | +def activated_tracemalloc(): |
| 14 | + tracemalloc.start() |
| 15 | + try: |
| 16 | + yield |
| 17 | + finally: |
| 18 | + tracemalloc.stop() |
| 19 | + |
| 20 | + |
| 21 | +def get_allocated_khash_memory(): |
| 22 | + snapshot = tracemalloc.take_snapshot() |
| 23 | + snapshot = snapshot.filter_traces( |
| 24 | + (tracemalloc.DomainFilter(True, ht.get_hashtable_trace_domain()),) |
| 25 | + ) |
| 26 | + return sum(map(lambda x: x.size, snapshot.traces)) |
| 27 | + |
| 28 | + |
9 | 29 | @pytest.mark.parametrize( |
10 | 30 | "table_type, dtype", |
11 | 31 | [ |
| 32 | + (ht.PyObjectHashTable, np.object_), |
12 | 33 | (ht.Int64HashTable, np.int64), |
13 | 34 | (ht.UInt64HashTable, np.uint64), |
14 | 35 | (ht.Float64HashTable, np.float64), |
@@ -53,13 +74,15 @@ def test_get_set_contains_len(self, table_type, dtype): |
53 | 74 | assert str(index + 2) in str(excinfo.value) |
54 | 75 |
|
55 | 76 | def test_map(self, table_type, dtype): |
56 | | - N = 77 |
57 | | - table = table_type() |
58 | | - keys = np.arange(N).astype(dtype) |
59 | | - vals = np.arange(N).astype(np.int64) + N |
60 | | - table.map(keys, vals) |
61 | | - for i in range(N): |
62 | | - assert table.get_item(keys[i]) == i + N |
| 77 | + # PyObjectHashTable has no map-method |
| 78 | + if table_type != ht.PyObjectHashTable: |
| 79 | + N = 77 |
| 80 | + table = table_type() |
| 81 | + keys = np.arange(N).astype(dtype) |
| 82 | + vals = np.arange(N).astype(np.int64) + N |
| 83 | + table.map(keys, vals) |
| 84 | + for i in range(N): |
| 85 | + assert table.get_item(keys[i]) == i + N |
63 | 86 |
|
64 | 87 | def test_map_locations(self, table_type, dtype): |
65 | 88 | N = 8 |
@@ -101,6 +124,53 @@ def test_unique(self, table_type, dtype): |
101 | 124 | unique = table.unique(keys) |
102 | 125 | tm.assert_numpy_array_equal(unique, expected) |
103 | 126 |
|
| 127 | + def test_tracemalloc_works(self, table_type, dtype): |
| 128 | + if dtype in (np.int8, np.uint8): |
| 129 | + N = 256 |
| 130 | + else: |
| 131 | + N = 30000 |
| 132 | + keys = np.arange(N).astype(dtype) |
| 133 | + with activated_tracemalloc(): |
| 134 | + table = table_type() |
| 135 | + table.map_locations(keys) |
| 136 | + used = get_allocated_khash_memory() |
| 137 | + my_size = table.sizeof() |
| 138 | + assert used == my_size |
| 139 | + del table |
| 140 | + assert get_allocated_khash_memory() == 0 |
| 141 | + |
| 142 | + def test_tracemalloc_for_empty(self, table_type, dtype): |
| 143 | + with activated_tracemalloc(): |
| 144 | + table = table_type() |
| 145 | + used = get_allocated_khash_memory() |
| 146 | + my_size = table.sizeof() |
| 147 | + assert used == my_size |
| 148 | + del table |
| 149 | + assert get_allocated_khash_memory() == 0 |
| 150 | + |
| 151 | + |
| 152 | +def test_tracemalloc_works_for_StringHashTable(): |
| 153 | + N = 1000 |
| 154 | + keys = np.arange(N).astype(np.compat.unicode).astype(np.object_) |
| 155 | + with activated_tracemalloc(): |
| 156 | + table = ht.StringHashTable() |
| 157 | + table.map_locations(keys) |
| 158 | + used = get_allocated_khash_memory() |
| 159 | + my_size = table.sizeof() |
| 160 | + assert used == my_size |
| 161 | + del table |
| 162 | + assert get_allocated_khash_memory() == 0 |
| 163 | + |
| 164 | + |
| 165 | +def test_tracemalloc_for_empty_StringHashTable(): |
| 166 | + with activated_tracemalloc(): |
| 167 | + table = ht.StringHashTable() |
| 168 | + used = get_allocated_khash_memory() |
| 169 | + my_size = table.sizeof() |
| 170 | + assert used == my_size |
| 171 | + del table |
| 172 | + assert get_allocated_khash_memory() == 0 |
| 173 | + |
104 | 174 |
|
105 | 175 | @pytest.mark.parametrize( |
106 | 176 | "table_type, dtype", |
@@ -157,6 +227,7 @@ def get_ht_function(fun_name, type_suffix): |
157 | 227 | @pytest.mark.parametrize( |
158 | 228 | "dtype, type_suffix", |
159 | 229 | [ |
| 230 | + (np.object_, "object"), |
160 | 231 | (np.int64, "int64"), |
161 | 232 | (np.uint64, "uint64"), |
162 | 233 | (np.float64, "float64"), |
|
0 commit comments