begriffs open source - libderp/blob - src/treemap.c
Handy wrappers for std lib functions
[libderp] / src / treemap.c
1 #include "derp/list.h"
2 #include "derp/treemap.h"
3
4 #include <stdlib.h>
5
6 /* AA (Arne Andersson) Tree:
7  * Balanced Search Trees Made Simple
8  * https://user.it.uu.se/~arnea/ps/simp.pdf
9  *
10  * As fast as an RB-tree, but free of those
11  * ugly special cases. */
12
/* One node of the AA tree.  Links that would be NULL in a plain binary
 * search tree point at the map's shared bottom sentinel instead. */
struct tm_node
{
	int level;              /* AA level; the bottom sentinel has level 0 */
	struct map_pair *pair;  /* key/value entry, freed together with the node */
	struct tm_node *left;   /* subtree of keys ordered before pair->k */
	struct tm_node *right;  /* subtree of keys ordered after pair->k */
};
19
20 struct treemap
21 {
22         struct tm_node *root, *bottom;
23         struct tm_node *deleted, *last;
24
25         dtor *key_dtor;
26         dtor *val_dtor;
27         comparator *cmp;
28         void *cmp_aux;
29         void *dtor_aux;
30 };
31
32 struct tm_iter
33 {
34         list *stack;
35         struct tm_node *n, *bottom;
36 };
37
38 treemap *
39 tm_new(comparator *cmp, void *cmp_aux)
40 {
41         treemap *t = malloc(sizeof *t);
42         struct tm_node *bottom = malloc(sizeof *bottom);
43         if (!t || !bottom)
44         {
45                 free(t);
46                 free(bottom);
47                 return NULL;
48         }
49         /* sentinel living below all leaves */
50         *bottom = (struct tm_node){
51                 .left = bottom, .right = bottom, .level = 0
52         };
53         *t = (treemap){
54                 .root = bottom,
55                 .bottom = bottom,
56                 .deleted = bottom,
57                 .last = bottom,
58                 .cmp = cmp,
59                 .cmp_aux = cmp_aux
60         };
61         return t;
62 }
63
64 void
65 tm_free(treemap *t)
66 {
67         if (!t)
68                 return;
69         tm_clear(t);
70         free(t->bottom);
71         free(t);
72 }
73
74 void
75 tm_dtor(treemap *t, dtor *key_dtor, dtor *val_dtor, void *dtor_aux)
76 {
77         if (!t)
78                 return;
79         t->key_dtor = key_dtor;
80         t->val_dtor = val_dtor;
81         t->dtor_aux = dtor_aux;
82 }
83
84 static size_t
85 internal_tm_length(const struct tm_node *n,
86                    const struct tm_node *bottom)
87 {
88         if (n == bottom)
89                 return 0;
90         return 1 +
91                 internal_tm_length(n->left, bottom) +
92                 internal_tm_length(n->right, bottom);
93 }
94
95 size_t
96 tm_length(const treemap *t)
97 {
98         return t ? internal_tm_length(t->root, t->bottom) : 0;
99 }
100
101 bool
102 tm_is_empty(const treemap *t)
103 {
104         return tm_length(t) == 0;
105 }
106
107 static void *
108 internal_tm_at(const treemap *t, const struct tm_node *n,
109                const void *key)
110 {
111         if (n == t->bottom)
112                 return NULL;
113         int x = t->cmp(key, n->pair->k, t->cmp_aux);
114         if (x == 0)
115                 return n->pair->v;
116         else if (x < 0)
117                 return internal_tm_at(t, n->left, key);
118         return internal_tm_at(t, n->right, key);
119 }
120
121 void *
122 tm_at(const treemap *t, const void *key)
123 {
124         return t ? internal_tm_at(t, t->root, key) : NULL;
125 }
126
127 static struct tm_node *
128 internal_tm_skew(struct tm_node *n) {
129         if (n->level != n->left->level)
130                 return n;
131         struct tm_node *left = n->left;
132         n->left = left->right;
133         left->right = n;
134         n = left;
135         return n;
136 }
137
138 static struct tm_node *
139 internal_tm_split(struct tm_node *n) {
140         if(n->right->right->level != n->level)
141                 return n;
142         struct tm_node *right = n->right;
143         n->right = right->left;
144         right->left = n;
145         n = right;
146         n->level++;
147         return n;
148 }
149
150 static struct tm_node *
151 internal_tm_insert(treemap *t, struct tm_node *n,
152                    struct tm_node *prealloc)
153 {
154         if (n == t->bottom)
155                 return prealloc;
156         int x = t->cmp(prealloc->pair->k, n->pair->k, t->cmp_aux);
157         if (x < 0)
158                 n->left = internal_tm_insert(t, n->left, prealloc);
159         else if (x > 0)
160                 n->right = internal_tm_insert(t, n->right, prealloc);
161         else
162         {
163                 /* prealloc was for naught, but we'll use its value */
164                 if (n->pair->v != prealloc->pair->v && t->val_dtor)
165                         t->val_dtor(n->pair->v, t->dtor_aux);
166                 if (n->pair->k != prealloc->pair->k && t->key_dtor)
167                         t->key_dtor(n->pair->k, t->dtor_aux);
168                 *n->pair = *prealloc->pair;
169                 free(prealloc->pair);
170                 free(prealloc);
171                 return n;
172         }
173         return internal_tm_split(internal_tm_skew(n));
174 }
175
176 bool
177 tm_insert(treemap *t, void *key, void *val)
178 {
179         if (!t)
180                 return false;
181         /* attempt the malloc before potentially splitting
182          * and skewing the tree, so the insertion can be a
183          * no-op on failure */
184         struct tm_node *prealloc = malloc(sizeof *prealloc);
185         struct map_pair *p = malloc(sizeof *p);
186         if (!prealloc || !p)
187         {
188                 free(prealloc);
189                 free(p);
190                 return false;
191         }
192         *p = (struct map_pair){.k = key, .v = val};
193         *prealloc = (struct tm_node){
194                 .level = 1, .pair = p, .left = t->bottom, .right = t->bottom
195         };
196
197         t->root = internal_tm_insert(t, t->root, prealloc);
198         return true;
199 }
200
201 static struct tm_node *
202 internal_tm_remove(treemap *t, struct tm_node *n, void *key)
203 {
204         if (n == t->bottom)
205                 return n;
206
207         /* 1: search down the tree and set pointers last and deleted */
208
209         t->last = n;
210         if (t->cmp(key, n->pair->k, t->cmp_aux) < 0)
211                 n->left = internal_tm_remove(t, n->left, key);
212         else
213         {
214                 t->deleted = n;
215                 n->right = internal_tm_remove(t, n->right, key);
216         }
217
218         /* 2: At the bottom of the tree, remove element if present */
219
220         if (n == t->last && t->deleted != t->bottom &&
221             t->cmp(key, t->deleted->pair->k, t->cmp_aux) == 0)
222         {
223                 if (t->key_dtor)
224                         t->key_dtor(t->deleted->pair->k, t->dtor_aux);
225                 if (t->val_dtor)
226                         t->val_dtor(t->deleted->pair->v, t->dtor_aux);
227
228                 *t->deleted->pair = *n->pair;
229                 t->deleted = t->bottom;
230                 n = n->right;
231
232                 free(t->last->pair);
233                 free(t->last);
234         } /* 3: on the way back up, rebalance */
235         else if (n->left->level  < n->level-1 ||
236                  n->right->level < n->level-1) {
237                 n->level--;
238                 if (n->right->level > n->level)
239                         n->right->level = n->level;
240                 n               = internal_tm_skew(n);
241                 n->right        = internal_tm_skew(n->right);
242                 n->right->right = internal_tm_skew(n->right->right);
243                 n               = internal_tm_split(n);
244                 n->right        = internal_tm_split(n->right);
245         }
246         return n;
247 }
248
249 bool
250 tm_remove(treemap *t, void *key)
251 {
252         if (!t)
253                 return false;
254         t->root = internal_tm_remove(t, t->root, key);
255         return true; // TODO: return false if key wasn't found
256 }
257
258 static void
259 internal_tm_clear(treemap *t, struct tm_node *n)
260 {
261         if (n == t->bottom)
262                 return;
263         internal_tm_clear(t, n->left);
264         internal_tm_clear(t, n->right);
265         if (t->key_dtor)
266                 t->key_dtor(n->pair->k, t->dtor_aux);
267         if (t->val_dtor)
268                 t->val_dtor(n->pair->v, t->dtor_aux);
269         free(n->pair);
270         free(n);
271 }
272
273 void
274 tm_clear(treemap *t)
275 {
276         if (!t)
277                 return;
278         internal_tm_clear(t, t->root);
279         t->root = t->deleted = t->last = t->bottom;
280 }
281
282 tm_iter *
283 tm_iter_begin(treemap *t)
284 {
285         if (!t)
286                 return NULL;
287         struct tm_iter *i = malloc(sizeof *i);
288         list *l = l_new();
289         if (!i || !l)
290         {
291                 free(i);
292                 l_free(l);
293                 return NULL;
294         }
295         *i = (struct tm_iter){
296                 .stack = l,
297                 .n = t->root,
298                 .bottom = t->bottom
299         };
300         return i;
301 }
302
303 struct map_pair *
304 tm_iter_next(tm_iter *i)
305 {
306         if (!i)
307                 return NULL;
308         if (l_is_empty(i->stack) && i->n == i->bottom)
309                 return NULL; /* done */
310         if (i->n != i->bottom)
311         {
312                 if (!l_append(i->stack, i->n))
313                         return NULL; /* OOM */
314                 i->n = i->n->left;
315                 return tm_iter_next(i);
316         }
317         struct tm_node *result = l_remove_last(i->stack);
318         i->n = result->right;
319         return result->pair;
320 }
321
322 void
323 tm_iter_free(tm_iter *i)
324 {
325         if (i)
326                 l_free(i->stack);
327         free(i);
328 }