1 
2 module markov.counter;
3 
4 import std.algorithm;
5 import std.array;
6 import std.random;
7 import std.traits;
8 import std.typecons;
9 
10 /++
11  + Represents a set of counters for trailing (following) tokens in a markov state.
12  ++/
struct Counter(T)
{
private:
    /// Occurrence count for every token seen so far.
    uint[Key] _counts;

    /// Cached sum of `_counts`. A value of 0 means "not computed / stale";
    /// every mutator resets it to 0 so `total` recomputes lazily.
    uint _total;

    /++
     + Wraps a token, providing normalized hashing and abstracting type qualifiers.
     ++/
    struct Key
    {
        T _key;

        /++
         + Returns the natural value of the token.
         ++/
        @property
        T value()
        {
            return _key;
        }

        /++
         + Returns the hash for a token.
         + Uses the token's own `toHash` when one is callable, otherwise falls
         + back to the runtime type-info hash for `T`.
         ++/
        hash_t toHash() const nothrow @safe
        {
            static if(__traits(compiles, {
                T key = void;
                hash_t hash = key.toHash;
            }))
            {
                return _key.toHash;
            }
            else
            {
                return typeid(T).getHash(&_key);
            }
        }

        /++
         + Compares two tokens for equality.
         ++/
        bool opEquals(ref const Key other) const
        {
            return _key == other._key;
        }
    }

    /++
     + Returns the key at position `index` in the table's iteration order,
     + without allocating a copy of all keys.
     + The caller must guarantee `index < length`.
     ++/
    Key keyAt(size_t index)
    {
        foreach(key; _counts.byKey)
        {
            if(index == 0)
            {
                return key;
            }

            --index;
        }

        assert(0, "Counter.keyAt: index out of range");
    }

    /++
     + Treats each counter as a contiguous interval of weight laid end to end
     + and returns the key whose interval contains `needle`.
     + The caller must guarantee `needle < total`.
     ++/
    Key keyForWeight(uint needle)
    {
        foreach(key, count; _counts)
        {
            if(needle < count)
            {
                return key;
            }

            needle -= count;
        }

        assert(0, "Counter.keyForWeight: needle exceeds total");
    }

public:
    /++
     + Constructs a counter table with an initial token (counter value 1).
     ++/
    this(T follow)
    {
        poke(follow);
    }

    /++
     + Checks if the counter table contains a given token.
     + Tokens whose counter was set to 0 via `set` are still contained.
     ++/
    bool contains(T follow)
    {
        return !!(Key(follow) in _counts);
    }

    /++
     + Checks if the counter table is empty.
     ++/
    @property
    bool empty()
    {
        return length == 0;
    }

    /++
     + Returns the counter value for a token.
     + The token must exist; a missing token is a lookup error on the
     + underlying associative array. Use `peek` for a non-throwing variant.
     ++/
    uint get(T follow)
    {
        return _counts[Key(follow)];
    }

    /++
     + Returns a newly allocated list of the tokens in the counter table.
     ++/
    @property
    T[] keys()
    {
        return _counts.byKey.map!"a.value".array;
    }

    /++
     + Returns the number of distinct tokens in the counter table.
     ++/
    @property
    size_t length()
    {
        return _counts.length;
    }

    /++
     + Returns the counter value for a token.
     + If the token doesn't exist, 0 is returned.
     ++/
    uint peek(T follow)
    {
        auto ptr = Key(follow) in _counts;
        return ptr ? *ptr : 0;
    }

    /++
     + Pokes a token in the counter table, incrementing its counter value.
     + If the token doesn't exist, it's created and assigned a counter of 1.
     ++/
    void poke(T follow)
    {
        // Invalidate the cached total however we leave this function.
        scope(exit) _total = 0;

        auto key = Key(follow);

        if(auto ptr = key in _counts)
        {
            *ptr = *ptr + 1;
        }
        else
        {
            _counts[key] = 1;
        }
    }

    /++
     + Returns a random token with equal distribution, ignoring counter values.
     + If the counter table is empty, null is returned instead.
     ++/
    @property
    T random()()
    if(isAssignable!(T, typeof(null)))
    {
        if(empty)
        {
            return null;
        }

        return keyAt(uniform(0, length)).value;
    }

    /++
     + Ditto
     ++/
    @property
    Nullable!(Unqual!T) random()()
    if(!isAssignable!(T, typeof(null)))
    {
        Nullable!(Unqual!T) result;

        if(!empty)
        {
            result = keyAt(uniform(0, length)).value;
        }

        return result;
    }

    /++
     + Rebuilds the associative arrays used by the counter table.
     ++/
    void rehash()
    {
        _counts.rehash;
    }

    /++
     + Returns a random token, distributed based on the counter values.
     + Specifically, a token with a higher counter is more likely to be chosen
     + than a token with a counter lower than it.
     + If the counter table is empty, or every counter is 0 (so there is no
     + weight to sample from), null is returned instead.
     ++/
    @property
    T select()()
    if(isAssignable!(T, typeof(null)))
    {
        // uniform(0, 0) throws, so a weightless table must be handled here.
        if(empty || total == 0)
        {
            return null;
        }

        return keyForWeight(uniform(0, total)).value;
    }

    /++
     + Ditto
     ++/
    @property
    Nullable!(Unqual!T) select()()
    if(!isAssignable!(T, typeof(null)))
    {
        Nullable!(Unqual!T) result;

        // uniform(0, 0) throws, so a weightless table yields an empty result.
        if(!empty && total != 0)
        {
            result = keyForWeight(uniform(0, total)).value;
        }

        return result;
    }

    /++
     + Sets the counter value for a given token, creating it if needed.
     ++/
    void set(T follow, uint count)
    {
        scope(exit) _total = 0;
        _counts[Key(follow)] = count;
    }

    /++
     + Returns the sum of all counters on all tokens. The value is cached once
     + it's been computed. Because 0 doubles as the "not cached" marker, a
     + genuinely zero total is recomputed on each call.
     ++/
    @property
    uint total()
    {
        if(_total == 0)
        {
            _total = _counts.byValue.sum;
        }

        return _total;
    }
}
280 
version(unittest)
{
    /++
     + Runs the shared Counter scenario for token type `T`.
     + `first` seeds the table and is poked once more; `second` is a distinct
     + token poked afterwards.
     ++/
    private void checkCounter(T)(T first, T second)
    {
        auto counter = Counter!T(first);

        // Freshly seeded: one token, counted once.
        assert(!counter.empty);
        assert(counter.length == 1);
        assert(counter.total == 1);

        // With a single token, both sampling strategies must yield it.
        assert(counter.contains(first));
        assert(counter.random == first);
        assert(counter.select == first);

        assert(counter.peek(first) == 1);

        // Re-poking an existing token bumps its count without adding entries.
        counter.poke(first);
        assert(counter.peek(first) == 2);
        assert(counter.length == 1);
        assert(counter.total == 2);

        // Poking a new token adds an entry and leaves the old count intact.
        counter.poke(second);
        assert(counter.peek(first) == 2);
        assert(counter.peek(second) == 1);
        assert(counter.length == 2);
        assert(counter.total == 3);
    }
}

unittest
{
    checkCounter!string("1", "2");
}

unittest
{
    checkCounter!int(1, 2);
}

unittest
{
    checkCounter!(int[])([1], [2]);
}

unittest
{
    checkCounter!(const(int[]))([1], [2]);
}

unittest
{
    checkCounter!(immutable(int[]))([1], [2]);
}