]> begriffs open source - libderp/blob - src/treemap.c
Remove gcda during tests, not as part of make
[libderp] / src / treemap.c
1 #include "list.h"
2 #include "treemap.h"
3
4 #include <stdlib.h>
5
6 /* AA (Arne Andersson) Tree:
7  * Balanced Search Trees Made Simple
8  * https://user.it.uu.se/~arnea/ps/simp.pdf
9  *
10  * As fast as an RB-tree, but free of those
11  * ugly special cases. */
12
/* One node of the AA tree.  Empty subtrees are represented by the map's
 * shared sentinel node (treemap.bottom) rather than by NULL. */
struct tm_node
{
	/* AA level: the sentinel is level 0, newly inserted nodes level 1 */
	int level;
	/* heap-allocated key/value pair, freed together with the node */
	struct map_pair *pair;
	struct tm_node *left;
	struct tm_node *right;
};
19
20 struct treemap
21 {
22         struct tm_node *root, *bottom;
23         struct tm_node *deleted, *last;
24
25         dtor *key_dtor;
26         dtor *val_dtor;
27         comparator *cmp;
28         void *cmp_aux;
29         void *dtor_aux;
30 };
31
32 struct tm_iter
33 {
34         list *stack;
35         struct tm_node *n, *bottom;
36 };
37
38 treemap *
39 tm_new(comparator *cmp, void *cmp_aux)
40 {
41         treemap *t = malloc(sizeof *t);
42         struct tm_node *bottom = malloc(sizeof *bottom);
43         if (!t || !bottom)
44         {
45                 free(t);
46                 free(bottom);
47                 return NULL;
48         }
49         /* sentinel living below all leaves */
50         *bottom = (struct tm_node){
51                 .left = bottom, .right = bottom, .level = 0
52         };
53         *t = (treemap){
54                 .root = bottom,
55                 .bottom = bottom,
56                 .deleted = bottom,
57                 .last = bottom,
58                 .cmp = cmp,
59                 .cmp_aux = cmp_aux
60         };
61         return t;
62 }
63
64 void
65 tm_free(treemap *t)
66 {
67         if (!t)
68                 return;
69         tm_clear(t);
70         free(t->bottom);
71         free(t);
72 }
73
74 void
75 tm_dtor(treemap *t, dtor *key_dtor, dtor *val_dtor, void *dtor_aux)
76 {
77         if (!t)
78                 return;
79         t->key_dtor = key_dtor;
80         t->val_dtor = val_dtor;
81         t->dtor_aux = dtor_aux;
82 }
83
84 static size_t
85 _tm_length(const struct tm_node *n, const struct tm_node *bottom)
86 {
87         if (n == bottom)
88                 return 0;
89         return 1 +
90                 _tm_length(n->left, bottom) + _tm_length(n->right, bottom);
91 }
92
93 size_t
94 tm_length(const treemap *t)
95 {
96         return t ? _tm_length(t->root, t->bottom) : 0;
97 }
98
99 bool
100 tm_is_empty(const treemap *t)
101 {
102         return tm_length(t) == 0;
103 }
104
105 static void *
106 _tm_at(const treemap *t, const struct tm_node *n, const void *key)
107 {
108         if (n == t->bottom)
109                 return NULL;
110         int x = t->cmp(key, n->pair->k, t->cmp_aux);
111         if (x == 0)
112                 return n->pair->v;
113         else if (x < 0)
114                 return _tm_at(t, n->left, key);
115         return _tm_at(t, n->right, key);
116 }
117
118 void *
119 tm_at(const treemap *t, const void *key)
120 {
121         return t ? _tm_at(t, t->root, key) : NULL;
122 }
123
124 static struct tm_node *
125 _tm_skew(struct tm_node *n) {
126    if (n->level != n->left->level)
127            return n;
128    struct tm_node *left = n->left;
129    n->left = left->right;
130    left->right = n;
131    n = left;
132    return n;
133 }
134
135 static struct tm_node *
136 _tm_split(struct tm_node *n) {
137    if(n->right->right->level != n->level)
138            return n;
139    struct tm_node *right = n->right;
140    n->right = right->left;
141    right->left = n;
142    n = right;
143    n->level++;
144    return n;
145 }
146
147 static struct tm_node *
148 _tm_insert(treemap *t, struct tm_node *n, struct tm_node *prealloc)
149 {
150         if (n == t->bottom)
151                 return prealloc;
152         int x = t->cmp(prealloc->pair->k, n->pair->k, t->cmp_aux);
153         if (x < 0)
154                 n->left = _tm_insert(t, n->left, prealloc);
155         else if (x > 0)
156                 n->right = _tm_insert(t, n->right, prealloc);
157         else
158         {
159                 /* prealloc was for naught, but we'll use its value */
160                 if (n->pair->v != prealloc->pair->v && t->val_dtor)
161                         t->val_dtor(n->pair->v, t->dtor_aux);
162                 if (n->pair->k != prealloc->pair->k && t->key_dtor)
163                         t->key_dtor(n->pair->k, t->dtor_aux);
164                 *n->pair = *prealloc->pair;
165                 free(prealloc->pair);
166                 free(prealloc);
167                 return n;
168         }
169         return _tm_split(_tm_skew(n));
170 }
171
172 bool
173 tm_insert(treemap *t, void *key, void *val)
174 {
175         if (!t)
176                 return false;
177         /* attempt the malloc before potentially splitting
178          * and skewing the tree, so the insertion can be a
179          * no-op on failure */
180         struct tm_node *prealloc = malloc(sizeof *prealloc);
181         struct map_pair *p = malloc(sizeof *p);
182         if (!prealloc || !p)
183         {
184                 free(prealloc);
185                 free(p);
186                 return false;
187         }
188         *p = (struct map_pair){.k = key, .v = val};
189         *prealloc = (struct tm_node){
190                 .level = 1, .pair = p, .left = t->bottom, .right = t->bottom
191         };
192
193         t->root = _tm_insert(t, t->root, prealloc);
194         return true;
195 }
196
197 static struct tm_node *
198 _tm_remove(treemap *t, struct tm_node *n, void *key)
199 {
200         if (n == t->bottom)
201                 return n;
202
203         /* 1: search down the tree and set pointers last and deleted */
204
205         t->last = n;
206         if (t->cmp(key, n->pair->k, t->cmp_aux) < 0)
207                 n->left = _tm_remove(t, n->left, key);
208         else
209         {
210                 t->deleted = n;
211                 n->right = _tm_remove(t, n->right, key);
212         }
213
214         /* 2: At the bottom of the tree, remove element if present */
215
216         if (n == t->last && t->deleted != t->bottom &&
217             t->cmp(key, t->deleted->pair->k, t->cmp_aux) == 0)
218         {
219                 if (t->key_dtor)
220                         t->key_dtor(t->deleted->pair->k, t->dtor_aux);
221                 if (t->val_dtor)
222                         t->val_dtor(t->deleted->pair->v, t->dtor_aux);
223
224                 *t->deleted->pair = *n->pair;
225                 t->deleted = t->bottom;
226                 n = n->right;
227
228                 free(t->last->pair);
229                 free(t->last);
230         } /* 3: on the way back up, rebalance */
231         else if (n->left->level  < n->level-1 ||
232                  n->right->level < n->level-1) {
233                 n->level--;
234                 if (n->right->level > n->level)
235                         n->right->level = n->level;
236                 n               = _tm_skew(n);
237                 n->right        = _tm_skew(n->right);
238                 n->right->right = _tm_skew(n->right->right);
239                 n               = _tm_split(n);
240                 n->right        = _tm_split(n->right);
241         }
242         return n;
243 }
244
245 bool
246 tm_remove(treemap *t, void *key)
247 {
248         if (!t)
249                 return false;
250         t->root = _tm_remove(t, t->root, key);
251         return true; // TODO: return false if key wasn't found
252 }
253
254 static void
255 _tm_clear(treemap *t, struct tm_node *n)
256 {
257         if (n == t->bottom)
258                 return;
259         _tm_clear(t, n->left);
260         _tm_clear(t, n->right);
261         if (t->key_dtor)
262                 t->key_dtor(n->pair->k, t->dtor_aux);
263         if (t->val_dtor)
264                 t->val_dtor(n->pair->v, t->dtor_aux);
265         free(n->pair);
266         free(n);
267 }
268
269 void
270 tm_clear(treemap *t)
271 {
272         if (!t)
273                 return;
274         _tm_clear(t, t->root);
275         t->root = t->deleted = t->last = t->bottom;
276 }
277
278 tm_iter *
279 tm_iter_begin(treemap *t)
280 {
281         if (!t)
282                 return NULL;
283         struct tm_iter *i = malloc(sizeof *i);
284         list *l = l_new();
285         if (!i || !l)
286         {
287                 free(i);
288                 l_free(l);
289                 return NULL;
290         }
291         *i = (struct tm_iter){
292                 .stack = l,
293                 .n = t->root,
294                 .bottom = t->bottom
295         };
296         return i;
297 }
298
299 struct map_pair *
300 tm_iter_next(tm_iter *i)
301 {
302         if (!i)
303                 return NULL;
304         if (l_is_empty(i->stack) && i->n == i->bottom)
305                 return NULL; /* done */
306         if (i->n != i->bottom)
307         {
308                 if (!l_append(i->stack, i->n))
309                         return NULL; /* OOM */
310                 i->n = i->n->left;
311                 return tm_iter_next(i);
312         }
313         struct tm_node *result = l_remove_last(i->stack);
314         i->n = result->right;
315         return result->pair;
316 }
317
318 void
319 tm_iter_free(tm_iter *i)
320 {
321         if (i)
322                 l_free(i->stack);
323         free(i);
324 }