begriffs open source - libderp/blob - src/treemap.c
Adequate treemap test coverage
[libderp] / src / treemap.c
1 #include "treemap.h"
2
3 #include <stdlib.h>
4
5 /* AA (Arne Andersson) Tree:
6  * Balanced Search Trees Made Simple
7  * https://user.it.uu.se/~arnea/ps/simp.pdf
8  *
9  * As fast as an RB-tree, but free of those
10  * ugly special cases. */
11
/* One node of the AA tree. Every real node owns a heap-allocated
 * key/value pair; empty subtrees are represented by the tree's shared
 * sentinel node rather than by NULL. */
struct tm_node
{
	int level;              /* AA level; the sentinel is 0, new nodes start at 1 */
	struct map_pair *pair;  /* key/value stored at this node */
	struct tm_node *left;   /* lesser keys (or the sentinel) */
	struct tm_node *right;  /* greater keys (or the sentinel) */
};
18
19 struct treemap
20 {
21         struct tm_node *root, *bottom;
22         struct tm_node *deleted, *last;
23
24         dtor *key_dtor;
25         dtor *val_dtor;
26         comparator *cmp;
27         void *cmp_aux;
28         void *dtor_aux;
29 };
30
31 treemap *
32 tm_new(comparator *cmp, void *cmp_aux)
33 {
34         treemap *t = malloc(sizeof *t);
35         struct tm_node *bottom = malloc(sizeof *bottom);
36         if (!t || !bottom)
37         {
38                 free(t);
39                 free(bottom);
40                 return NULL;
41         }
42         /* sentinel living below all leaves */
43         *bottom = (struct tm_node){
44                 .left = bottom, .right = bottom, .level = 0
45         };
46         *t = (treemap){
47                 .root = bottom,
48                 .bottom = bottom,
49                 .deleted = bottom,
50                 .last = bottom,
51                 .cmp = cmp,
52                 .cmp_aux = cmp_aux
53         };
54         return t;
55 }
56
57 void
58 tm_free(treemap *t)
59 {
60         if (!t)
61                 return;
62         tm_clear(t);
63         free(t->bottom);
64         free(t);
65 }
66
67 void
68 tm_dtor(treemap *t, dtor *key_dtor, dtor *val_dtor, void *dtor_aux)
69 {
70         if (!t)
71                 return;
72         t->key_dtor = key_dtor;
73         t->val_dtor = val_dtor;
74         t->dtor_aux = dtor_aux;
75 }
76
77 static size_t
78 _tm_length(const struct tm_node *n, const struct tm_node *bottom)
79 {
80         if (n == bottom)
81                 return 0;
82         return 1 +
83                 _tm_length(n->left, bottom) + _tm_length(n->right, bottom);
84 }
85
86 size_t
87 tm_length(const treemap *t)
88 {
89         return t ? _tm_length(t->root, t->bottom) : 0;
90 }
91
92 bool
93 tm_is_empty(const treemap *t)
94 {
95         return tm_length(t) == 0;
96 }
97
98 static void *
99 _tm_at(const treemap *t, const struct tm_node *n, const void *key)
100 {
101         if (n == t->bottom)
102                 return NULL;
103         int x = t->cmp(key, n->pair->k, t->cmp_aux);
104         if (x == 0)
105                 return n->pair->v;
106         else if (x < 0)
107                 return _tm_at(t, n->left, key);
108         return _tm_at(t, n->right, key);
109 }
110
111 void *
112 tm_at(const treemap *t, const void *key)
113 {
114         return t ? _tm_at(t, t->root, key) : NULL;
115 }
116
117 static struct tm_node *
118 _tm_skew(struct tm_node *n) {
119    if (n->level != n->left->level)
120            return n;
121    struct tm_node *left = n->left;
122    n->left = left->right;
123    left->right = n;
124    n = left;
125    return n;
126 }
127
128 static struct tm_node *
129 _tm_split(struct tm_node *n) {
130    if(n->right->right->level != n->level)
131            return n;
132    struct tm_node *right = n->right;
133    n->right = right->left;
134    right->left = n;
135    n = right;
136    n->level++;
137    return n;
138 }
139
140 static struct tm_node *
141 _tm_insert(treemap *t, struct tm_node *n, struct tm_node *prealloc)
142 {
143         if (n == t->bottom)
144                 return prealloc;
145         int x = t->cmp(prealloc->pair->k, n->pair->k, t->cmp_aux);
146         if (x < 0)
147                 n->left = _tm_insert(t, n->left, prealloc);
148         else if (x > 0)
149                 n->right = _tm_insert(t, n->right, prealloc);
150         else
151         {
152                 /* prealloc was for naught, but we'll use its value */
153                 if (n->pair->v != prealloc->pair->v && t->val_dtor)
154                         t->val_dtor(n->pair->v, t->dtor_aux);
155                 if (n->pair->k != prealloc->pair->k && t->key_dtor)
156                         t->key_dtor(n->pair->k, t->dtor_aux);
157                 *n->pair = *prealloc->pair;
158                 free(prealloc->pair);
159                 free(prealloc);
160                 return n;
161         }
162         return _tm_split(_tm_skew(n));
163 }
164
165 bool
166 tm_insert(treemap *t, void *key, void *val)
167 {
168         if (!t)
169                 return false;
170         /* attempt the malloc before potentially splitting
171          * and skewing the tree, so the insertion can be a
172          * no-op on failure */
173         struct tm_node *prealloc = malloc(sizeof *prealloc);
174         struct map_pair *p = malloc(sizeof *p);
175         if (!prealloc || !p)
176         {
177                 free(prealloc);
178                 free(p);
179                 return false;
180         }
181         *p = (struct map_pair){.k = key, .v = val};
182         *prealloc = (struct tm_node){
183                 .level = 1, .pair = p, .left = t->bottom, .right = t->bottom
184         };
185
186         t->root = _tm_insert(t, t->root, prealloc);
187         return true;
188 }
189
190 static struct tm_node *
191 _tm_remove(treemap *t, struct tm_node *n, void *key)
192 {
193         if (n == t->bottom)
194                 return n;
195
196         /* 1: search down the tree and set pointers last and deleted */
197
198         t->last = n;
199         if (t->cmp(key, n->pair->k, t->cmp_aux) < 0)
200                 n->left = _tm_remove(t, n->left, key);
201         else
202         {
203                 t->deleted = n;
204                 n->right = _tm_remove(t, n->right, key);
205         }
206
207         /* 2: At the bottom of the tree, remove element if present */
208
209         if (n == t->last && t->deleted != t->bottom &&
210             t->cmp(key, t->deleted->pair->k, t->cmp_aux) == 0)
211         {
212                 if (t->key_dtor)
213                         t->key_dtor(t->deleted->pair->k, t->dtor_aux);
214                 if (t->val_dtor)
215                         t->val_dtor(t->deleted->pair->v, t->dtor_aux);
216
217                 *t->deleted->pair = *n->pair;
218                 t->deleted = t->bottom;
219                 n = n->right;
220
221                 free(t->last->pair);
222                 free(t->last);
223         } /* 3: on the way back up, rebalance */
224         else if (n->left->level  < n->level-1 ||
225                  n->right->level < n->level-1) {
226                 n->level--;
227                 if (n->right->level > n->level)
228                         n->right->level = n->level;
229                 n               = _tm_skew(n);
230                 n->right        = _tm_skew(n->right);
231                 n->right->right = _tm_skew(n->right->right);
232                 n               = _tm_split(n);
233                 n->right        = _tm_split(n->right);
234         }
235         return n;
236 }
237
238 bool
239 tm_remove(treemap *t, void *key)
240 {
241         if (!t)
242                 return false;
243         t->root = _tm_remove(t, t->root, key);
244         return true; // TODO: return false if key wasn't found
245 }
246
247 static void
248 _tm_clear(treemap *t, struct tm_node *n)
249 {
250         if (n == t->bottom)
251                 return;
252         _tm_clear(t, n->left);
253         _tm_clear(t, n->right);
254         if (t->key_dtor)
255                 t->key_dtor(n->pair->k, t->dtor_aux);
256         if (t->val_dtor)
257                 t->val_dtor(n->pair->v, t->dtor_aux);
258         free(n->pair);
259         free(n);
260 }
261
262 void
263 tm_clear(treemap *t)
264 {
265         if (!t)
266                 return;
267         _tm_clear(t, t->root);
268         t->root = t->deleted = t->last = t->bottom;
269 }