1 
2 module markov.state;
3 
4 import std.algorithm;
5 import std.array;
6 import std.exception;
7 import std.random;
8 import std.traits;
9 import std.typecons;
10 
11 import markov.counter;
12 
13 /++
14  + Represents a table of token sequences bound to counter tables in a markov chain.
15  ++/
struct State(T)
{
private:
    size_t _size;
    Counter!T[Key] _counters;

    /++
     + Wraps a token sequence so that it can be used as an associative array
     + key, abstracting type qualifiers.
     ++/
    struct Key
    {
        const T[] _key;

        /++
         + Returns a freshly duplicated copy of the token sequence, cast back
         + to the natural array type `T[]`.
         ++/
        @property
        T[] value() const
        {
            return cast(T[]) _key.dup;
        }

        /++
         + Compares two token sequences for equality, element by element.
         ++/
        bool opEquals(ref const Key other) const
        {
            return _key == other._key;
        }
    }

    /++
     + Builds a key that owns a private, shallow copy of the token sequence.
     +
     + Keys that are stored in the table must not alias caller memory:
     + if the caller later overwrote its buffer, the stored key would change
     + after insertion and the entry could no longer be found.
     ++/
    static Key ownedKey(T[] first)
    {
        return Key(first.dup);
    }

public:
    @disable
    this();

    /++
     + Constructs a markov state with the given size.
     +
     + Params:
     +   size = Length of the token sequences used as keys; must be greater
     +          than 0.
     +
     + Throws: Exception if `size` is 0.
     ++/
    this(size_t size)
    {
        _size = enforce(size, "State size cannot be 0.");
    }

    /++
     + Checks if a counter table exists in the markov state.
     +
     + Returns: true if a counter table is bound to `first`; false if there is
     + none, or if the length of `first` doesn't match the size of the state.
     ++/
    bool contains(T[] first)
    {
        if(first.length != size)
        {
            return false;
        }

        return (Key(first) in _counters) !is null;
    }

    /++
     + Checks if a token exists in a counter table in the markov state.
     +
     + Returns: true if the counter table bound to `first` contains `follow`;
     + false if it doesn't, if no counter table is bound to `first`, or if the
     + length of `first` doesn't match the size of the state.
     ++/
    bool contains(T[] first, T follow)
    {
        if(first.length != size)
        {
            return false;
        }

        auto ptr = Key(first) in _counters;
        return ptr ? ptr.contains(follow) : false;
    }

    /++
     + Checks if the markov state is empty.
     ++/
    @property
    bool empty() const
    {
        return length == 0;
    }

    /++
     + Returns the counter table that corresponds to the token sequence.
     +
     + The table is looked up by plain associative array indexing, so asking
     + for a sequence that has no counter table is a range violation; use
     + `contains` first when the sequence may be absent.
     ++/
    Counter!T get(T[] first)
    {
        return _counters[Key(first)];
    }

    /++
     + Returns a list of the token sequences that have a counter table.
     + Each sequence in the result is a fresh copy.
     ++/
    @property
    T[][] keys()
    {
        return _counters.keys.map!"a.value".array;
    }

    /++
     + Returns the length of the markov state, i.e. the number of token
     + sequences that have a counter table.
     ++/
    @property
    size_t length() const
    {
        return _counters.length;
    }

    /++
     + Return the counter value of a token, from the counter table that
     + corresponds to the given token sequence.
     + If the token doesn't exist in the counter table, or the leading sequence
     + doesn't exist in the markov state, or the length of the sequence doesn't
     + match the size of the markov state, 0 is returned instead.
     ++/
    uint peek(T[] first, T follow)
    {
        if(first.length != size)
        {
            return 0;
        }

        auto ptr = Key(first) in _counters;
        return ptr ? ptr.peek(follow) : 0;
    }

    /++
     + Pokes a token in the counter table that corresponds to the given leading
     + sequence of tokens, incrementing its counter value.
     + If the token doesn't have an entry in the counter table, it is created
     + and assigned a value of 1, and if no counter table exists for the leading
     + token sequence, one is created as well.
     + The length of the token sequence must match the size of the markov state.
     +
     + Throws: Exception if the length of `first` doesn't match the size.
     ++/
    void poke(T[] first, T follow)
    {
        // Ensure that first length is equal to this state's size.
        enforce(first.length == size, "Length of input doesn't match size.");

        // The lookup key may alias the caller's array; it is only read here.
        if(auto ptr = Key(first) in _counters)
        {
            ptr.poke(follow);
        }
        else
        {
            Counter!T counter;
            counter.poke(follow);

            // The stored key gets its own copy of the sequence, so that later
            // changes to the caller's array cannot alter the table's key.
            _counters[ownedKey(first)] = counter;
        }
    }

    /++
     + Returns a random token from a random counter table.
     + If either the markov state or the counter table is empty,
     + null is returned instead.
     ++/
    @property
    T random()()
    if(isAssignable!(T, typeof(null)))
    {
        import std.range : drop;

        if(empty)
        {
            return null;
        }

        // Step through the values lazily rather than allocating an array of
        // every counter table just to pick one of them.
        auto index = uniform(0, length);
        return _counters.byValue.drop(index).front.random;
    }

    /++
     + Ditto
     ++/
    @property
    Nullable!(Unqual!T) random()()
    if(!isAssignable!(T, typeof(null)))
    {
        import std.range : drop;

        Nullable!(Unqual!T) result;

        if(!empty)
        {
            // Same lazy walk as the nullable-token overload.
            auto index = uniform(0, length);
            return _counters.byValue.drop(index).front.random;
        }

        return result;
    }

    /++
     + Rebuilds the associative arrays used by the markov table.
     +
     + Params:
     +   deep = If true, all the counter tables are rebuilt as well.
     ++/
    @property
    void rehash(bool deep = false)
    {
        _counters.rehash;

        if(deep)
        {
            foreach(ref counter; _counters)
            {
                counter.rehash;
            }
        }
    }

    /++
     + Returns a random token that might follow the given sequence of tokens
     + based on the markov state and the counter table that corresponds to the
     + token sequence.
     + If either the markov state or the corresponding counter table is empty,
     + or the token sequence doesn't have a counter table assigned to it,
     + null is returned instead.
     ++/
    @property
    T select()(T[] first)
    if(isAssignable!(T, typeof(null)))
    {
        if(empty)
        {
            return null;
        }

        auto ptr = Key(first) in _counters;
        return ptr ? ptr.select : null;
    }

    /++
     + Ditto
     ++/
    @property
    Nullable!(Unqual!T) select()(T[] first)
    if(!isAssignable!(T, typeof(null)))
    {
        Nullable!(Unqual!T) result;

        if(!empty)
        {
            if(auto ptr = Key(first) in _counters)
            {
                return ptr.select;
            }
        }

        return result;
    }

    /++
     + Sets the counter table for a given sequence of tokens, replacing any
     + counter table already bound to it.
     + The length of the token sequence must match the size of the markov state.
     +
     + Throws: Exception if the length of `first` doesn't match the size.
     ++/
    void set(T[] first, Counter!T counter)
    {
        // Same rule as poke: an entry of any other length could never be
        // reached through contains or peek.
        enforce(first.length == size, "Length of input doesn't match size.");

        // Store a private copy of the sequence (see ownedKey).
        _counters[ownedKey(first)] = counter;
    }

    /++
     + Returns the size of the markov state, i.e. the required length of every
     + token sequence used as a key.
     ++/
    @property
    size_t size() const
    {
        return _size;
    }
}
290 
unittest
{
    // A state of size 0 is invalid: construction has to throw.
    assertThrown!Exception(State!string(0));
}
303 
unittest
{
    // Poking a sequence of length 2 into a state of size 1 has to throw.
    // Construction stays inside the checked expression.
    assertThrown!Exception({
        auto table = State!string(1);
        table.poke(["1", "2"], "3");
    }());
}
317 
unittest
{
    // Tokens are strings, so "nothing found" is reported as null.
    auto table = State!string(1);

    // A fresh state has no counter tables.
    assert(table.empty);
    assert(table.length == 0);
    assert(table.size == 1);

    // Every query on the empty state comes back empty-handed.
    assert(table.random is null);
    assert(table.select(["1"]) is null);
    assert(table.peek(["1"], "2") == 0);

    // First poke: creates the counter table and the token entry.
    table.poke(["1"], "2");
    assert(!table.empty);
    assert(table.length == 1);
    assert(table.size == 1);

    // With a single token recorded, selection can only yield that token.
    assert(table.random == "2");
    assert(table.select(["1"]) == "2");
    assert(table.peek(["1"], "2") == 1);

    // Second poke of the same token only bumps its counter.
    table.poke(["1"], "2");
    assert(table.peek(["1"], "2") == 2);
    assert(table.peek(["1"], "3") == 0);

    // A different follower lands in the same counter table.
    table.poke(["1"], "3");
    assert(table.length == 1);
    assert(table.peek(["1"], "2") == 2);
    assert(table.peek(["1"], "3") == 1);
}
348 
unittest
{
    // Tokens are ints, so "nothing found" is reported as a null Nullable.
    auto table = State!int(1);

    // A fresh state has no counter tables.
    assert(table.empty);
    assert(table.length == 0);
    assert(table.size == 1);

    // Every query on the empty state comes back empty-handed.
    assert(table.random.isNull);
    assert(table.select([1]).isNull);
    assert(table.peek([1], 2) == 0);

    // First poke: creates the counter table and the token entry.
    table.poke([1], 2);
    assert(!table.empty);
    assert(table.length == 1);
    assert(table.size == 1);

    // With a single token recorded, selection can only yield that token.
    assert(table.random == 2);
    assert(table.select([1]) == 2);
    assert(table.peek([1], 2) == 1);

    // Second poke of the same token only bumps its counter.
    table.poke([1], 2);
    assert(table.peek([1], 2) == 2);
    assert(table.peek([1], 3) == 0);

    // A different follower lands in the same counter table.
    table.poke([1], 3);
    assert(table.length == 1);
    assert(table.peek([1], 2) == 2);
    assert(table.peek([1], 3) == 1);
}
379 
unittest
{
    // Tokens are mutable int arrays, so "nothing found" is reported as null.
    auto table = State!(int[])(1);

    // A fresh state has no counter tables.
    assert(table.empty);
    assert(table.length == 0);
    assert(table.size == 1);

    // Every query on the empty state comes back empty-handed.
    assert(table.random is null);
    assert(table.select([[1]]) is null);
    assert(table.peek([[1]], [2]) == 0);

    // First poke: creates the counter table and the token entry.
    table.poke([[1]], [2]);
    assert(!table.empty);
    assert(table.length == 1);
    assert(table.size == 1);

    // With a single token recorded, selection can only yield that token.
    assert(table.random == [2]);
    assert(table.select([[1]]) == [2]);
    assert(table.peek([[1]], [2]) == 1);

    // Second poke of the same token only bumps its counter.
    table.poke([[1]], [2]);
    assert(table.peek([[1]], [2]) == 2);
    assert(table.peek([[1]], [3]) == 0);

    // A different follower lands in the same counter table.
    table.poke([[1]], [3]);
    assert(table.length == 1);
    assert(table.peek([[1]], [2]) == 2);
    assert(table.peek([[1]], [3]) == 1);
}
410 
unittest
{
    // Tokens are const int arrays, so "nothing found" is reported as a null
    // Nullable.
    auto table = State!(const(int[]))(1);

    // A fresh state has no counter tables.
    assert(table.empty);
    assert(table.length == 0);
    assert(table.size == 1);

    // Every query on the empty state comes back empty-handed.
    assert(table.random.isNull);
    assert(table.select([[1]]).isNull);
    assert(table.peek([[1]], [2]) == 0);

    // First poke: creates the counter table and the token entry.
    table.poke([[1]], [2]);
    assert(!table.empty);
    assert(table.length == 1);
    assert(table.size == 1);

    // With a single token recorded, selection can only yield that token.
    assert(table.random == [2]);
    assert(table.select([[1]]) == [2]);
    assert(table.peek([[1]], [2]) == 1);

    // Second poke of the same token only bumps its counter.
    table.poke([[1]], [2]);
    assert(table.peek([[1]], [2]) == 2);
    assert(table.peek([[1]], [3]) == 0);

    // A different follower lands in the same counter table.
    table.poke([[1]], [3]);
    assert(table.length == 1);
    assert(table.peek([[1]], [2]) == 2);
    assert(table.peek([[1]], [3]) == 1);
}
441 
unittest
{
    // Tokens are immutable int arrays, so "nothing found" is reported as a
    // null Nullable.
    auto table = State!(immutable(int[]))(1);

    // A fresh state has no counter tables.
    assert(table.empty);
    assert(table.length == 0);
    assert(table.size == 1);

    // Every query on the empty state comes back empty-handed.
    assert(table.random.isNull);
    assert(table.select([[1]]).isNull);
    assert(table.peek([[1]], [2]) == 0);

    // First poke: creates the counter table and the token entry.
    table.poke([[1]], [2]);
    assert(!table.empty);
    assert(table.length == 1);
    assert(table.size == 1);

    // With a single token recorded, selection can only yield that token.
    assert(table.random == [2]);
    assert(table.select([[1]]) == [2]);
    assert(table.peek([[1]], [2]) == 1);

    // Second poke of the same token only bumps its counter.
    table.poke([[1]], [2]);
    assert(table.peek([[1]], [2]) == 2);
    assert(table.peek([[1]], [3]) == 0);

    // A different follower lands in the same counter table.
    table.poke([[1]], [3]);
    assert(table.length == 1);
    assert(table.peek([[1]], [2]) == 2);
    assert(table.peek([[1]], [3]) == 1);
}