1 /**
2  * Associative Array implementation
3  *
4  * Compiler implementation of the
5  * $(LINK2 https://www.dlang.org, D programming language).
6  *
7  * Copyright:   Copyright (C) 2000-2023 by The D Language Foundation, All Rights Reserved
8  * Authors:     $(LINK2 https://www.digitalmars.com, Walter Bright), Dave Fladebo
9  * License:     Distributed under the Boost Software License, Version 1.0.
10  *              https://www.boost.org/LICENSE_1_0.txt
11  * Source:      https://github.com/dlang/dmd/blob/master/src/dmd/backend/aarray.d
12  */
13 
14 module dmd.backend.aarray;
15 
16 import core.stdc.stdio;
17 import core.stdc.stdlib;
18 import core.stdc.string;
19 
20 alias hash_t = size_t;
21 
22 version (MARS)
23 {
24     import dmd.root.hash;
25     import dmd.backend.global : err_nomem;
26 }
27 
28 nothrow:
29 @safe:
30 
/*********************
 * This is the "bucket" used by the AArray.
 *
 * It is only the fixed-size header of a chain node. AArray allocates each
 * node as aaA.sizeof + aligned key size + Value.sizeof bytes, placing the key
 * immediately after this header (at `e + 1`) and the value after the
 * aligned key.
 */
private struct aaA
{
    aaA *next;          // next node in the same bucket's chain, null at the end
    hash_t hash;        // hash of the key
    /* key   */         // key value goes here
    /* value */         // value value goes here
}
41 
/**************************
 * Associative Array type.
 *
 * A hash table with separate chaining. The bucket array and every node are
 * allocated with calloc()/malloc() and released with free(). Each node is an
 * `aaA` header followed directly by the key and then, at an offset rounded
 * up by aligntsize(), the value.
 *
 * Params:
 *      TKey = type that has members Key, getHash(), and equals()
 *      Value = value type
 */

struct AArray(TKey, Value)
{
nothrow:
    alias Key = TKey.Key;       // key type

    ~this()
    {
        destroy();
    }

    /****
     * Frees all the data used by AArray
     *
     * Every node and the bucket array are free()'d and the table is reset
     * to the empty state, so it may be used again afterwards. Pointers
     * previously returned by get()/isIn() are dangling after this call.
     */
    @trusted
    void destroy()
    {
        if (buckets)
        {
            foreach (e; buckets)
            {
                while (e)
                {
                    // Fetch the successor before releasing the node
                    auto en = e;
                    e = e.next;
                    free(en);
                }
            }
            free(buckets.ptr);
            buckets = null;
            nodes = 0;
        }
    }

    /********
     * Returns:
     *   Number of entries in the AArray
     */
    size_t length()
    {
        return nodes;
    }

    /*************************************************
     * Get pointer to value in associative array indexed by key.
     * Add entry for key if it is not already there; the value of a newly
     * added entry is zero-filled and the key is copied bitwise into the node.
     * Calls err_nomem() if an allocation fails.
     * Params:
     *  pkey = pointer to key
     * Returns:
     *  pointer to Value
     */
    @trusted
    Value* get(Key* pkey)
    {
        //printf("AArray::get()\n");
        const aligned_keysize = aligntsize(Key.sizeof);

        if (!buckets.length)
        {
            // First insertion: allocate the smallest bucket array
            alias aaAp = aaA*;
            const len = prime_list[0];
            auto p = cast(aaAp*)calloc(len, aaAp.sizeof);
            if (!p)
                err_nomem();
            buckets = p[0 .. len];
        }

        hash_t key_hash = tkey.getHash(pkey);
        const i = key_hash % buckets.length;
        //printf("key_hash = %x, buckets.length = %d, i = %d\n", key_hash, buckets.length, i);
        aaA* e;
        auto pe = &buckets[i];
        // Walk the chain; pe ends up addressing the link to append to
        while ((e = *pe) != null)
        {
            if (key_hash == e.hash &&
                tkey.equals(pkey, cast(Key*)(e + 1)))
            {
                goto Lret;
            }
            pe = &e.next;
        }

        // Not found, create new elem
        //printf("create new one\n");
        e = cast(aaA *) malloc(aaA.sizeof + aligned_keysize + Value.sizeof);
        if (!e)
            err_nomem();
        memcpy(e + 1, pkey, Key.sizeof);
        memset(cast(void *)(e + 1) + aligned_keysize, 0, Value.sizeof);
        e.hash = key_hash;
        e.next = null;
        *pe = e;

        ++nodes;
        //printf("length = %d, nodes = %d\n", buckets_length, nodes);
        if (nodes > buckets.length * 4)
        {
            // rehash() only relinks nodes, so `e` stays valid
            //printf("rehash()\n");
            rehash();
        }

    Lret:
        return cast(Value*)(cast(void*)(e + 1) + aligned_keysize);
    }

    /*************************************************
     * Determine if key is in aa.
     * Params:
     *  pkey = pointer to key
     * Returns:
     *  null    not in aa
     *  !=null  in aa, return pointer to value
     */

    @trusted
    Value* isIn(Key* pkey)
    {
        //printf("AArray.isIn(), .length = %d, .ptr = %p\n", nodes, buckets.ptr);
        if (!nodes)
            return null;

        const key_hash = tkey.getHash(pkey);
        //printf("hash = %d\n", key_hash);
        const i = key_hash % buckets.length;
        auto e = buckets[i];
        while (e != null)
        {
            if (key_hash == e.hash &&
                tkey.equals(pkey, cast(Key*)(e + 1)))
            {
                return cast(Value*)(cast(void*)(e + 1) + aligntsize(Key.sizeof));
            }

            e = e.next;
        }

        // Not found
        return null;
    }


    /*************************************************
     * Delete key entry in aa[].
     * If key is not in aa[], do nothing.
     * The node is free()'d, so any pointer to its value becomes invalid.
     * Params:
     *  pkey = pointer to key
     */

    @trusted
    void del(Key *pkey)
    {
        if (!nodes)
            return;

        const key_hash = tkey.getHash(pkey);
        //printf("hash = %d\n", key_hash);
        const i = key_hash % buckets.length;
        auto pe = &buckets[i];
        aaA* e;
        while ((e = *pe) != null)       // null means not found
        {
            if (key_hash == e.hash &&
                tkey.equals(pkey, cast(Key*)(e + 1)))
            {
                // Unlink the node from its chain, then release it
                *pe = e.next;
                --nodes;
                free(e);
                break;
            }
            pe = &e.next;
        }
    }


    /********************************************
     * Produce array of keys from aa.
     * Keys are copied bitwise, in bucket order.
     * Calls err_nomem() if the array cannot be allocated.
     * Returns:
     *  malloc'd array of keys (not freed by AArray),
     *  or null if the aa is empty
     */

    @trusted
    Key[] keys()
    {
        if (!nodes)
            return null;

        // Guard the multiplication below against overflow
        if (nodes >= size_t.max / Key.sizeof)
            err_nomem();
        auto p = cast(Key *)malloc(nodes * Key.sizeof);
        if (!p)
            err_nomem();
        auto q = p;
        foreach (e; buckets)
        {
            while (e)
            {
                memcpy(q, e + 1, Key.sizeof);
                ++q;
                e = e.next;
            }
        }
        return p[0 .. nodes];
    }

    /********************************************
     * Produce array of values from aa.
     * Values are copied bitwise, in bucket order.
     * Calls err_nomem() if the array cannot be allocated.
     * Returns:
     *  malloc'd array of values (not freed by AArray),
     *  or null if the aa is empty
     */

    @trusted
    Value[] values()
    {
        if (!nodes)
            return null;

        const aligned_keysize = aligntsize(Key.sizeof);
        // Guard the multiplication below against overflow; the element
        // size being allocated is Value.sizeof, not Key.sizeof
        if (nodes >= size_t.max / Value.sizeof)
            err_nomem();
        auto p = cast(Value *)malloc(nodes * Value.sizeof);
        if (!p)
            err_nomem();
        auto q = p;
        foreach (e; buckets)
        {
            while (e)
            {
                memcpy(q, cast(void*)(e + 1) + aligned_keysize, Value.sizeof);
                ++q;
                e = e.next;
            }
        }
        return p[0 .. nodes];
    }

    /********************************************
     * Rehash an array.
     *
     * Allocates a new bucket array whose length is the first entry of
     * prime_list that is >= the number of nodes (or the last entry if none
     * is), and relinks the existing nodes into it. Nodes are not moved in
     * memory, so pointers to keys and values remain valid.
     * Does nothing if the aa is empty.
     */

    @trusted
    void rehash()
    {
        //printf("Rehash\n");
        if (!nodes)
            return;

        size_t newbuckets_length = prime_list[$ - 1];

        foreach (prime; prime_list[0 .. $ - 1])
        {
            if (nodes <= prime)
            {
                newbuckets_length = prime;
                break;
            }
        }
        auto newbuckets = cast(aaA**)calloc(newbuckets_length, (aaA*).sizeof);
        if (!newbuckets)
            err_nomem();

        foreach (e; buckets)
        {
            while (e)
            {
                // Push the node onto the front of its new chain
                auto en = e.next;
                auto b = &newbuckets[e.hash % newbuckets_length];
                e.next = *b;
                *b = e;
                e = en;
            }
        }

        free(buckets.ptr);
        buckets = newbuckets[0 .. newbuckets_length];
    }

    alias applyDg = nothrow int delegate(Key*, Value*);
    /*********************************************
     * For each element in the AArray,
     * call dg(Key* pkey, Value* pvalue)
     * If dg returns !=0, stop and return that value.
     * Params:
     *  dg = delegate to call for each key/value pair
     * Returns:
     *  !=0 : value returned by first dg() call that returned non-zero
     *  0   : no entries in aa, or all dg() calls returned 0
     */

    @trusted
    int apply(applyDg dg)
    {
        if (!nodes)
            return 0;

        //printf("AArray.apply(aa = %p, keysize = %d, dg = %p)\n", &this, Key.sizeof, dg);

        const aligned_keysize = aligntsize(Key.sizeof);

        foreach (e; buckets)
        {
            while (e)
            {
                auto result = dg(cast(Key*)(e + 1), cast(Value*)(cast(void*)(e + 1) + aligned_keysize));
                if (result)
                    return result;
                e = e.next;
            }
        }

        return 0;
    }

  private:

    aaA*[] buckets;             // bucket array; each slot heads a chain of nodes
    size_t nodes;               // number of nodes
    TKey tkey;                  // supplies getHash() and equals() for keys
}
367 
368 private:
369 
/**********************************
 * Round an offset up to the next multiple of size_t.sizeof, so that a
 * value stored at the resulting offset is aligned.
 * Params:
 *      tsize = offset to be aligned
 * Returns:
 *      aligned offset
 */

size_t aligntsize(size_t tsize)
{
    // size_t.sizeof is a power of two, so sizeof - 1 masks the low bits
    enum size_t mask = size_t.sizeof - 1;
    return (tsize + mask) & ~mask;
}
384 
/**********************************
 * Bucket-array lengths used by AArray.
 * AArray.get() allocates prime_list[0] buckets for an empty table, and
 * AArray.rehash() picks the first entry that is >= the number of nodes,
 * falling back to the last entry. Each entry is roughly four times the
 * previous one.
 */
immutable uint[14] prime_list =
[
               97,           389,
             1543,          6151,
           24_593,        98_317,
          393_241,     1_572_869,
        6_291_469,    25_165_843,
      100_663_319,   402_653_189,
    1_610_612_741, 4_294_967_291U,
];
395 
396 /***************************************************************/
397 
/***
 * A TKey for basic types: the key's own value serves as its hash and
 * keys are compared with `==`.
 * Params:
 *      K = a basic type
 */
public struct Tinfo(K)
{
nothrow:
    alias Key = K;

    /// Returns: the key value converted to hash_t
    static hash_t getHash(Key* pk)
    {
        const keyValue = *pk;
        return cast(hash_t)keyValue;
    }

    /// Returns: true if the two keys pointed to compare equal
    static bool equals(Key* pk1, Key* pk2)
    {
        const same = (*pk1 == *pk2);
        return same;
    }
}
418 
419 /***************************************************************/
420 
/****
 * A TKey that is a string (a `const(char)[]` slice).
 * The key stores only the slice, not a copy of the characters.
 */
public struct TinfoChars
{
nothrow:
    alias Key = const(char)[];

    /// Returns: hash of the characters of the string pointed to by pk
    static hash_t getHash(Key* pk)
    {
        auto text = *pk;
        version (MARS)
        {
            return calcHash(cast(const(ubyte[]))text);
        }
        else
        {
            // Simple multiplicative hash over the characters
            hash_t result = 0;
            foreach (ch; text)
                result = result * 11 + ch;
            return result;
        }
    }

    /// Returns: true if both strings have the same length and contents
    @trusted
    static bool equals(Key* pk1, Key* pk2)
    {
        auto lhs = *pk1;
        auto rhs = *pk2;
        if (lhs.length != rhs.length)
            return false;
        return memcmp(lhs.ptr, rhs.ptr, lhs.length) == 0;
    }
}
455 
// Interface for C++ code: a table keyed by strings, holding uint values
public extern (C++) struct AAchars
{
nothrow:
    alias AA = AArray!(TinfoChars, uint);
    AA aa;

    /// Allocate a zero-filled AAchars with calloc(); err_nomem() on failure
    @trusted
    static AAchars* create()
    {
        auto table = cast(AAchars*)calloc(1, AAchars.sizeof);
        if (table is null)
            err_nomem();
        return table;
    }

    /// Release the table's contents and then the AAchars itself
    @trusted
    static void destroy(AAchars* aac)
    {
        aac.aa.destroy();
        free(aac);
    }

    /// Returns: pointer to the value for buf, adding an entry if needed
    @trusted
    extern(D) uint* get(const(char)[] buf)
    {
        auto pkey = &buf;
        return aa.get(pkey);
    }

    /// Returns: number of entries, converted to uint
    uint length()
    {
        const count = aa.length();
        return cast(uint)count;
    }
}
490 
491 /***************************************************************/
492 
/**
 * Key is the slice specified by (*TinfoPair.pbase)[Pair.start .. Pair.end]
 */
public struct Pair
{
    uint start;     // index of the first byte of the slice
    uint end;       // index one past the last byte of the slice
}
496 
/****
 * A TKey whose keys are Pair index ranges into a shared byte buffer.
 * A key denotes the bytes (*pbase)[start .. end]; hashing and comparison
 * operate on those bytes, not on the indices themselves.
 */
public struct TinfoPair
{
nothrow:
    alias Key = Pair;

    ubyte** pbase;      // points at the base pointer of the buffer the keys index

    /// Returns: hash of the bytes (*pbase)[pk.start .. pk.end]
    @trusted
    hash_t getHash(Key* pk)
    {
        auto buf = (*pbase)[pk.start .. pk.end];
        version (MARS)
        {
            return calcHash(buf);
        }
        else
        {
            // Simple multiplicative hash over the bytes
            hash_t hash = 0;
            foreach (v; buf)
                hash = hash * 11 + v;
            return hash;
        }
    }

    /// Returns: true if both keys denote byte ranges of equal length and contents
    @trusted
    bool equals(Key* pk1, Key* pk2)
    {
        const len1 = pk1.end - pk1.start;
        const len2 = pk2.end - pk2.start;

        return len1 == len2 &&
               memcmp(*pbase + pk1.start, *pbase + pk2.start, len1) == 0;
    }
}
534 
// Interface for C++ code: a table keyed by byte ranges of a shared
// buffer, holding uint values
public extern (C++) struct AApair
{
nothrow:
    alias AA = AArray!(TinfoPair, uint);
    AA aa;

    /// Allocate a zero-filled AApair with calloc() (err_nomem() on failure)
    /// and record pbase as the buffer its keys index into
    @trusted
    static AApair* create(ubyte** pbase)
    {
        auto table = cast(AApair*)calloc(1, AApair.sizeof);
        if (table is null)
            err_nomem();
        table.aa.tkey.pbase = pbase;
        return table;
    }

    /// Release the table's contents and then the AApair itself
    @trusted
    static void destroy(AApair* aap)
    {
        aap.aa.destroy();
        free(aap);
    }

    /// Returns: pointer to the value for the range start .. end,
    /// adding an entry if needed
    @trusted
    uint* get(uint start, uint end)
    {
        Pair key = Pair(start, end);
        return aa.get(&key);
    }

    /// Returns: number of entries, converted to uint
    uint length()
    {
        const count = aa.length();
        return cast(uint)count;
    }
}
571 
// Interface for C++ code: a table keyed by byte ranges of a shared
// buffer, holding Pair values
public extern (C++) struct AApair2
{
nothrow:
    alias AA = AArray!(TinfoPair, Pair);
    AA aa;

    /// Allocate a zero-filled AApair2 with calloc() (err_nomem() on failure)
    /// and record pbase as the buffer its keys index into
    @trusted
    static AApair2* create(ubyte** pbase)
    {
        auto table = cast(AApair2*)calloc(1, AApair2.sizeof);
        if (table is null)
            err_nomem();
        table.aa.tkey.pbase = pbase;
        return table;
    }

    /// Release the table's contents and then the AApair2 itself
    @trusted
    static void destroy(AApair2* aap)
    {
        aap.aa.destroy();
        free(aap);
    }

    /// Returns: pointer to the value for the range start .. end,
    /// adding an entry if needed
    @trusted
    Pair* get(uint start, uint end)
    {
        Pair key = Pair(start, end);
        return aa.get(&key);
    }

    /// Returns: number of entries, converted to uint
    uint length()
    {
        const count = aa.length();
        return cast(uint)count;
    }
}
608 
609 /*************************************************************/
610 
// Exercises AArray with int keys and bool values: empty-table behaviour,
// insertion, lookup, rehash, apply, deletion, and keys()/values().
@system unittest
{
    int dg(int* pk, bool* pv) { return 3; }
    int dgz(int* pk, bool* pv) { return 0; }

    // Operations on an empty table
    AArray!(Tinfo!int, bool) aa;
    aa.rehash();
    assert(aa.keys() == null);
    assert(aa.values() == null);
    assert(aa.apply(&dg) == 0);

    assert(aa.length == 0);
    int k = 8;
    aa.del(&k);
    assert(!aa.isIn(&k));

    // Insert two entries and make sure they survive a rehash
    bool *pv = aa.get(&k);
    *pv = true;
    int j = 9;
    pv = aa.get(&j);
    *pv = false;
    aa.rehash();

    assert(aa.length() == 2);
    assert(*aa.get(&k) == true);
    assert(*aa.get(&j) == false);

    assert(aa.apply(&dg) == 3);
    assert(aa.apply(&dgz) == 0);

    aa.del(&k);
    assert(aa.length() == 1);
    assert(!aa.isIn(&k));
    assert(*aa.isIn(&j) == false);

    auto keys = aa.keys();
    assert(keys.length == 1);
    assert(keys[0] == 9);
    free(keys.ptr);         // keys() returns a malloc'd array

    auto values = aa.values();
    assert(values.length == 1);
    assert(values[0] == false);
    free(values.ptr);       // values() returns a malloc'd array

    // apply() must hand out the same value pointer that get() returned
    AArray!(Tinfo!int, bool) aa2;
    int key = 10;
    bool* getpv = aa2.get(&key);
    aa2.apply(delegate(int* pk, bool* pv) @trusted {
        assert(pv is getpv);
        return 0;
    });
}
662 
// Ranges 1..2 and 3..4 of "abcb" both denote the byte 'b', so AApair must
// treat them as one and the same key.
@system unittest
{
    const(char)* base = "abcb";
    auto table = AApair.create(cast(ubyte**)&base);

    auto slot = table.get(1, 2);
    *slot = 10;
    assert(table.length == 1);

    slot = table.get(3, 4);
    assert(*slot == 10);

    AApair.destroy(table);
}