1 module soa;
2 
3 import std.range : enumerate, isInputRange, take;
4 import std.traits : FieldNameTuple;
5 
/**
 * Implement Struct Of Arrays from a struct type and array size.
 *
 * Inspired by code at https://github.com/economicmodeling/soa/blob/master/soa.d
 *
 * For every field of `T`, `SOA!(T, N)` declares an array with the same name and
 * element type. When `N > 0` the arrays are static arrays of length `N`, each
 * element initialized with that field's value in `T.init`. When `N == 0` (the
 * default) the arrays are dynamic arrays that must be provided or grown by the
 * user, and all of them must keep the same length (checked by an `invariant`).
 *
 * ---
 * // Transforms a struct definition like this
 * struct Vector2 {
 *   float x = 0;
 *   float y = 0;
 * }
 * Vector2[100] arrayOfStructs;
 *
 * // To a struct definition like this
 * struct Vector2_SOA {
 *   float[100] x = 0;
 *   float[100] y = 0;
 * }
 * // alias Vector2_SOA = SOA!(Vector2, 100);
 * Vector2_SOA structOfArrays;
 * ---
 *
 * Indexing returns a dispatching object that forwards member access,
 * comparison, assignment and casting to the right arrays. Slicing returns a
 * Random Access Finite Range of those objects. Together they allow Array Of
 * Structs and Struct Of Arrays types to be substituted for each other seamlessly.
 *
 * ---
 * SOA!(Vector2, 100) vectors;
 * vectors[0].x = 10;
 * assert(vectors[0].x == 10);
 * assert(vectors[0].x == vectors.x[0]);
 * vectors[1] = Vector2(2, 2);
 * assert(vectors[1] == Vector2(2, 2));
 *
 * foreach(v; vectors[0 .. 2])
 * {
 *     import std.stdio;
 *     writeln(v.x, " ", v.y);
 * }
 * ---
 */
struct SOA(T, size_t N = 0)
if (is(T == struct))
{
    alias ElementType = T;
    alias Dispatcher = .Dispatcher!(T, N);
    alias DispatcherRange = .DispatcherRange!(T, N);

    private enum usesStaticArrays = N > 0;

    // Generate one array for each field of `T` with the same name.
    // Static arrays get `T.init`'s field value as initializer; dynamic arrays start empty.
    static foreach (i, field; FieldNameTuple!T)
    {
        mixin("typeof(T." ~ field ~ ")[" ~ (usesStaticArrays ? N.stringof : "") ~ "] " ~ field ~ (usesStaticArrays ? " = T.init." ~ field : "") ~ ";\n");
    }

    /**
     * Construct SOA with initial elements copied from `range`.
     *
     * With static arrays, the first `N` elements of `range` overwrite the
     * default values; any extra elements are ignored, and missing ones keep
     * the defaults. With dynamic arrays, every element of `range` is appended.
     */
    this(R)(auto ref R range)
    if (isInputRange!R)
    {
        static if (usesStaticArrays)
        {
            this[] = range;
        }
        else
        {
            // Dynamic arrays start empty, so slice-assignment would copy
            // nothing: grow the arrays instead.
            this ~= range;
        }
    }

    @nogc @safe pure nothrow
    {
        /// Returns a Dispatcher object to the pseudo-indexed `T` instance.
        inout(Dispatcher) opIndex(size_t index) inout
        {
            return typeof(return)(&this, index);
        }

        /// Returns the full range of Dispatcher objects.
        inout(DispatcherRange) opIndex() inout
        {
            return typeof(return)(&this, 0, length);
        }

        /// Returns the range of Dispatcher objects in `[beginIndex, pastTheEndIndex)`.
        inout(DispatcherRange) opSlice(size_t beginIndex, size_t pastTheEndIndex) inout
        {
            return typeof(return)(&this, beginIndex, pastTheEndIndex);
        }
    }

    static if (usesStaticArrays)
    {
        /// Length of the arrays.
        enum length = N;
    }
    else
    {
        /// Length of the arrays, assumed to be the same between all of them.
        @property size_t length() const @nogc @safe pure nothrow
        {
            return __traits(getMember, this, FieldNameTuple!T[0]).length;
        }

        /// Concatenate a value, appending each field to the array of the same name.
        void opOpAssign(string op : "~")(auto ref T value)
        {
            static foreach (i, field; FieldNameTuple!T)
            {
                __traits(getMember, this, field) ~= __traits(getMember, value, field);
            }
        }

        /// Concatenate the value a Dispatcher (into a SOA of any size) refers to.
        void opOpAssign(string op : "~", size_t M)(auto ref .Dispatcher!(T, M) dispatcher)
        {
            static foreach (i, field; FieldNameTuple!T)
            {
                __traits(getMember, this, field) ~= __traits(getMember, dispatcher, field);
            }
        }

        /// Concatenate every value from `range`, in order.
        void opOpAssign(string op : "~", R)(auto ref R range)
        if (isInputRange!R)
        {
            foreach (v; range)
            {
                this ~= v;
            }
        }

        // All field arrays must always have the same length.
        invariant
        {
            size_t firstLength = __traits(getMember, this, FieldNameTuple!T[0]).length;
            static foreach (i, field; FieldNameTuple!T[1 .. $])
            {
                assert(__traits(getMember, this, field).length == firstLength);
            }
        }
    }

    alias opDollar = length;
}
143 
/**
 * Proxy object that makes it possible to use a SOA just as if it were an AOS.
 *
 * A Dispatcher stores a pointer to a `SOA!(T, N)` and an element index. It
 * forwards field access, assignment, comparison and casting to the element at
 * that index in each of the SOA's field arrays.
 */
private struct Dispatcher(T, size_t N)
{
    SOA!(T, N)* soa;
    size_t index;

    invariant
    {
        // A default-initialized Dispatcher is not bound to any SOA: report it
        // clearly instead of dereferencing a null pointer below.
        assert(soa !is null, "Dispatcher is not bound to a SOA");
        assert(index < soa.length, "Dispatcher index is out of bounds");
    }

    /// Get a reference to the element at `index` of the SOA array named `field`.
    private auto ref getFieldRef(string field)() inout
    {
        return __traits(getMember, soa, field)[index];
    }

    /// Returns whether `other` dispatches to the same element of the same SOA.
    private bool isSame(size_t M)(auto ref Dispatcher!(T, M) other) const
    {
        static if (N == M)
        {
            // Same type: identical iff both the SOA pointer and the index match.
            return other is this;
        }
        else
        {
            // SOAs of different types are necessarily different objects.
            return false;
        }
    }

    /// Get a reference to fields by name, dispatching to the right array at SOA instance.
    auto ref opDispatch(string op)() inout
    {
        return getFieldRef!op;
    }

    /**
     * Assign values from a Dispatcher to another, copying each field by name to
     * the right array. Assigning a Dispatcher to itself does nothing.
     */
    void opAssign(size_t M)(auto ref Dispatcher!(T, M) other)
    {
        if (!isSame(other))
        {
            static foreach (i, field; FieldNameTuple!T)
            {
                getFieldRef!field = other.getFieldRef!field;
            }
        }
    }

    /// Assign values from an instance of `T`, copying each field by name to the right array.
    void opAssign()(auto ref T value)
    {
        static foreach (i, field; FieldNameTuple!T)
        {
            getFieldRef!field = __traits(getMember, value, field);
        }
    }

    /**
     * Compare for equality with another Dispatcher.
     *
     * Returns true immediately when both refer to the same element; otherwise
     * compares every field with `!=`.
     */
    bool opEquals(size_t M)(auto ref Dispatcher!(T, M) other) const
    {
        if (isSame(other))
        {
            return true;
        }
        static foreach (i, field; FieldNameTuple!T)
        {
            if (other.getFieldRef!field != getFieldRef!field)
            {
                return false;
            }
        }
        return true;
    }

    /// Compare for equality with an instance of `T`, field by field using `!=`.
    bool opEquals()(auto ref const T value) const
    {
        static foreach (i, field; FieldNameTuple!T)
        {
            if (__traits(getMember, value, field) != getFieldRef!field)
            {
                return false;
            }
        }
        return true;
    }

    /// Pack the dispatched fields into a `T` instance, then cast it to `U`.
    U opCast(U)() const
    {
        T value;
        static foreach (i, field; FieldNameTuple!T)
        {
            __traits(getMember, value, field) = getFieldRef!field;
        }
        return cast(U) value;
    }
}
243 
/**
 * Random Access Finite Range of Dispatcher objects.
 *
 * Covers the half-open interval `[beginIndex, pastTheEndIndex)` of the
 * `SOA!(T, N)` pointed to by `soa`. Indices passed to `opIndex` and `opSlice`
 * are relative to `beginIndex`.
 */
private struct DispatcherRange(T, size_t N)
{
    SOA!(T, N)* soa;
    size_t beginIndex;
    size_t pastTheEndIndex;

    invariant
    {
        assert(beginIndex <= pastTheEndIndex);
        assert(pastTheEndIndex <= soa.length, "DispatcherRange pastTheEndIndex is out of bounds");
    }

    @nogc @safe pure nothrow
    {
        // Input Range
        @property bool empty() const
        {
            return beginIndex >= pastTheEndIndex;
        }

        /// First element. Must not be called on an empty range.
        auto front() inout
        {
            return this[0];
        }

        void popFront()
        {
            beginIndex++;
        }

        // Forward Range
        inout(DispatcherRange) save() inout
        {
            return this;
        }

        // Bidirectional Range
        /// Last element. Must not be called on an empty range.
        auto back() inout
        {
            return this[$ - 1];
        }

        void popBack()
        {
            pastTheEndIndex--;
        }

        // Random Access Finite Range
        /// Dispatcher for the element at `index`, relative to the start of this range.
        inout(Dispatcher!(T, N)) opIndex(size_t index) inout
        in
        {
            // Without this check, indices past the end (or `back` on an empty
            // range, where `$ - 1` wraps around) would silently dispatch to
            // SOA elements outside this range.
            assert(index < length, "DispatcherRange index is out of bounds");
        }
        do
        {
            return typeof(return)(soa, beginIndex + index);
        }

        /// Returns a subrange. Bounds are relative to the start of this range.
        inout(DispatcherRange) opSlice(size_t beginIndex, size_t pastTheEndIndex) inout
        in
        {
            assert(beginIndex <= pastTheEndIndex);
            assert(pastTheEndIndex <= length, "DispatcherRange slice is out of bounds");
        }
        do
        {
            return typeof(return)(soa, this.beginIndex + beginIndex, this.beginIndex + pastTheEndIndex);
        }

        /// Returns the range length.
        @property size_t length() const
        {
            return pastTheEndIndex - beginIndex;
        }

        alias opDollar = length;
    }

    // Assignments

    /// Assign `value` to every element in the range.
    void opAssign()(auto ref T value)
    {
        foreach (i; 0 .. length)
        {
            this[i] = value;
        }
    }

    /// Assign the element `dispatcher` refers to to every element in the range.
    void opAssign(size_t M)(auto ref Dispatcher!(T, M) dispatcher)
    {
        foreach (i; 0 .. length)
        {
            this[i] = dispatcher;
        }
    }

    /**
     * Copy elements from `range` in order, stopping at whichever ends first.
     * Elements past the end of a shorter `range` are left untouched.
     */
    void opAssign(R)(auto ref R range)
    if (isInputRange!R)
    {
        size_t i = 0;
        foreach (v; range.take(length))
        {
            this[i] = v;
            i++;
        }
    }
}
345 
346 
unittest
{
    struct Color
    {
        float r = 1;
        float g = 1;
        float b = 1;
        float a = 1;

        enum red = Color(1, 0, 0, 1);
        enum green = Color(0, 1, 0, 1);
        enum blue = Color(0, 0, 1, 1);
        enum black = Color(0, 0, 0, 1);
        enum white = Color(1, 1, 1, 1);
    }

    // Static SOA: same layout size as the equivalent Array Of Structs
    alias Color16 = SOA!(Color, 16);
    static assert(Color16.sizeof == (Color[16]).sizeof);
    static assert(Color16.sizeof == 16 * Color.sizeof);
    static assert(is(typeof(Color16.r) == float[16]));
    static assert(is(typeof(Color16.g) == float[16]));
    static assert(is(typeof(Color16.b) == float[16]));
    static assert(is(typeof(Color16.a) == float[16]));
    assert(Color16.init.r[0] is Color.init.r);
    assert(Color16.init.g[0] is Color.init.g);
    assert(Color16.init.b[0] is Color.init.b);
    assert(Color16.init.a[0] is Color.init.a);

    Color16 colors;
    assert(colors[0].r is Color.init.r);
    colors[0].r = 5;
    assert(colors[0].r == 5);
    assert(colors[0] == colors[0]);
    assert(colors[0] != colors[1]);
    assert(colors[1] == colors[2]);
    assert(colors[1] == Color.init);

    colors[0] = Color.white;
    assert(colors[0] == Color.white);

    colors[3] = Color.red;
    colors[2] = colors[3];
    assert(colors[2] == Color.red);
    assert(colors.r[2] == colors[2].r);
    assert(colors.g[2] == colors[2].g);
    assert(colors.b[2] == colors[2].b);
    assert(colors.a[2] == colors[2].a);

    Color c2 = cast(Color) colors[2];
    assert(c2 == Color.red);

    alias Color8 = SOA!(Color, 8);
    Color8 otherColors;
    otherColors[0] = colors[2];
    assert(otherColors[0] == Color.red);
    assert(otherColors[0] == colors[3]);

    // construction from range
    import std.algorithm : map;
    import std.range : enumerate, iota;
    auto alphaGradient = Color8(iota(42).map!(i => Color(1, 1, 1, i / 8f)));
    foreach (i, c; alphaGradient[].enumerate)
    {
        assert(c.a == i / 8f);
    }

    alphaGradient = Color8(iota(5).map!(i => Color(1, 1, 1, i / 8f)));
    foreach (i, c; alphaGradient[].enumerate)
    {
        if (i < 5)
        {
            assert(c.a == i / 8f);
        }
        else
        {
            assert(c == Color.init);
        }
    }

    // assignment from range
    otherColors = Color8([Color.red, Color.blue, Color.black]);
    otherColors[3 .. $] = alphaGradient[];
    assert(otherColors[0] == Color.red);
    assert(otherColors[1] == Color.blue);
    assert(otherColors[2] == Color.black);
    assert(otherColors[3] == alphaGradient[0]);
    assert(otherColors[4] == alphaGradient[1]);
    assert(otherColors[5] == alphaGradient[2]);
    assert(otherColors[6] == alphaGradient[3]);
    assert(otherColors[7] == alphaGradient[4]);

    // assignment from Color
    otherColors[$-2 .. $] = Color.green;
    assert(otherColors[$-2] == Color.green);
    assert(otherColors[$-1] == Color.green);

    // assignment from Dispatcher
    otherColors[$-2 .. $] = otherColors[0];
    assert(otherColors[$-2] == Color.red);
    assert(otherColors[$-1] == Color.red);

    // assignment from DispatcherRange
    otherColors[$-2 .. $] = otherColors[1 .. 3];
    assert(otherColors[$-2] == Color.blue);
    assert(otherColors[$-1] == Color.black);

    // Dynamic SOA: arrays are provided by the user and are GC-managed
    alias SeveralColors = SOA!(Color);
    SeveralColors several;
    several.r = new float[5];
    several.g = new float[5];
    several.b = new float[5];
    several.a = new float[5];

    several[0] = Color.red;
    assert(several[0] == Color.red);
    // `new float[5]` fills the arrays with float.init (NaN), not Color.init values
    assert(several[1] != Color.init);

    import std.range : repeat;
    several ~= repeat(Color.red, 3);
    assert(several.length == 8);
    foreach (c; several[5 .. $])
    {
        assert(c == Color.red);
    }
}
473 
unittest
{
    import soa;

    // Transforms a struct definition like this
    struct Vector2
    {
        float x = 0;
        float y = 0;
    }
    Vector2[100] arrayOfStructs;

    // To a struct definition like this
    struct Vector2_SOA
    {
        float[100] x = 0;
        float[100] y = 0;
    }
    // alias Vector2_SOA = SOA!(Vector2, 100);
    Vector2_SOA structOfArrays;

    SOA!(Vector2, 100) vectors;
    // Assignment with object type
    vectors[0] = Vector2(10, 0);
    // Dispatcher object handles indexing the right arrays
    assert(vectors[0].x == 10);
    assert(vectors[0].y == 0);
    assert(vectors[0].x == vectors.x[0]);
    assert(vectors[0].y == vectors.y[0]);
    // Slicing works, including assignment with single value or Range
    vectors[1 .. 3] = Vector2(2, 2);
    assert(vectors[1] == Vector2(2, 2));
    assert(vectors[2] == Vector2(2, 2));

    // Also does other Range functionality
    import std.stdio : writeln;
    import std.range : retro;
    foreach(v; vectors[0 .. 2].retro)
    {
        writeln("[", v.x, ", ", v.y, "]");
    }

    // It is possible to also use dynamic arrays, but they must be provided or
    // grown manually. All arrays must have the same length (SOA with dynamic
    // arrays have an `invariant` block with this condition).
    // Arrays allocated with `new` are owned by the garbage collector, so no
    // manual cleanup is needed (`destroy` on a slice would only reset it to null).
    SOA!(Vector2) dynamicVectors;
    dynamicVectors.x = new float[5];
    dynamicVectors.y = new float[5];
    assert(dynamicVectors.length == 5);

    import std.algorithm : map;
    import std.range : iota, enumerate;
    dynamicVectors[] = iota(5).map!(x => Vector2(x, 0));
    foreach (i, v; dynamicVectors[].enumerate)
    {
        assert(v == Vector2(i, 0));
    }

    // In-place concatenate operator is available, although not available in betterC
    dynamicVectors ~= Vector2(5, 0);
    foreach (i, v; dynamicVectors[].enumerate)
    {
        assert(v == Vector2(i, 0));
    }
    assert(dynamicVectors.length == 6);
}