1 
2 module markov.chain;
3 
4 import std.algorithm;
5 import std.exception;
6 import std.range;
7 import std.traits;
8 import std.typecons;
9 
10 import markov.state;
11 
/++
 + A markov chain made of one or more `State!T` instances keyed by their size,
 + plus a history buffer holding the most recently generated or seeded tokens.
 +
 + The history buffer is exactly as long as the largest state size; each state
 + is queried with the trailing `state.size` tokens of that buffer.
 ++/
struct MarkovChain(T)
{
private:
    T[] _history;
    State!T[size_t] _states;

    /++
     + Returns a fresh array of the chain's states ordered from the largest size
     + to the smallest, so `select` tries the longest history context first.
     ++/
    State!T[] statesByDescendingSize()
    {
        return _states.values.sort!"a.size > b.size".release;
    }

public:
    @disable
    this();

    /++
     + Constructs a markov chain with empty states of the given sizes.
     +
     + Throws:
     +   Exception if no sizes are given.
     ++/
    this(size_t[] sizes...)
    {
        // Same guard as the State!T[] constructor; without it reduce!max
        // fails with an unrelated "empty range" error.
        enforce(sizes.length, "Cannot construct markov chain with 0 states.");

        _history.length = sizes.reduce!max;

        foreach(size; sizes)
        {
            _states[size] = State!T(size);
        }
    }

    /++
     + Constructs a markov chain using a list of existing states. States with
     + equal sizes overwrite each other; the last one given wins.
     +
     + Throws:
     +   Exception if no states are given.
     ++/
    this(State!T[] states...)
    {
        enforce(states.length, "Cannot construct markov chain with 0 states.");

        foreach(state; states)
        {
            _states[state.size] = state;
        }

        _history.length = _states.byValue.map!"a.size".reduce!max;
    }

    /++
     + Checks if all of the markov chain's states are empty.
     ++/
    @property
    bool empty()
    {
        return _states.byValue.all!"a.empty";
    }

    /++
     + Trains the state whose size equals `first.length` with the pair
     + (`first`, `follow`). Does nothing if the chain has no state of that size.
     ++/
    void feed(T[] first, T follow)
    {
        if(auto ptr = first.length in _states)
        {
            ptr.poke(first, follow);
        }
    }

    /++
     + Returns a token generated from the internal set of states, based on the
     + tokens previously generated. If no token can be selected, the result of
     + `random` is returned instead (which may itself be null / a null Nullable
     + when no state can produce a token).
     ++/
    T generate()()
    if(isAssignable!(T, typeof(null)))
    {
        T result = select;
        return result ? result : random;
    }

    /++
     + Ditto
     ++/
    Nullable!(Unqual!T) generate()()
    if(!isAssignable!(T, typeof(null)))
    {
        Nullable!(Unqual!T) result = select;
        return !result.isNull ? result : random;
    }

    /++
     + Ditto, but produces an array of tokens with the given length.
     +
     + Returns:
     +   The generated tokens, or null if fewer than `length` tokens could be
     +   produced.
     ++/
    Unqual!T[] generate()(size_t length)
    {
        Unqual!T[] output;

        if(generate(length, output) == length)
        {
            return output;
        }
        else
        {
            return null;
        }
    }

    /++
     + Ditto, but the array is given as an out-parameter.
     +
     + `output` is always allocated with `length` elements; if generation stops
     + early, the entries from the returned count onward keep their default
     + value.
     +
     + Returns:
     +   The number of tokens that were generated.
     ++/
    size_t generate()(size_t length, out Unqual!T[] output)
    {
        output = new Unqual!T[length];

        foreach(i; 0 .. length)
        {
            auto result = generate;

            static if(isAssignable!(T, typeof(null)))
            {
                if(result is null)
                {
                    return i;
                }
                else
                {
                    output[i] = result;
                }
            }
            else
            {
                if(result.isNull)
                {
                    return i;
                }
                else
                {
                    output[i] = result.get;
                }
            }
        }

        return length;
    }

    /++
     + Returns the number of states used by the markov chain.
     ++/
    @property
    size_t length()
    {
        return _states.length;
    }

    /++
     + Returns the lengths of the markov chain's states in an unknown order.
     ++/
    @property
    size_t[] lengths()
    {
        return _states.byValue.map!"a.length".array;
    }

    /++
     + Pushes a token to the markov chain's history buffer, discarding the
     + oldest token. Does nothing if the buffer has no room at all (i.e. the
     + largest state has size 0).
     ++/
    void push(T follow)
    {
        // Without this guard the slices below go out of bounds on an empty
        // buffer, crashing select/random/seed for size-0-only chains.
        if(_history.length == 0)
        {
            return;
        }

        static if(isMutable!T)
        {
            // Shift every token one slot to the left, then append at the end.
            copy(_history[1 .. $], _history[0 .. $ - 1]);
            _history[$ - 1] = follow;
        }
        else
        {
            // Elements cannot be reassigned, so build a new buffer instead.
            _history = _history[1 .. $] ~ [ follow ];
        }
    }

    /++
     + Returns a random token from the first state (in associative-array
     + iteration order) whose `random` yields one, and pushes that token to the
     + history buffer. Returns null if no state yields a token.
     ++/
    @property
    T random()()
    if(isAssignable!(T, typeof(null)))
    {
        foreach(ref state; _states)
        {
            T current = state.random;
            if(current)
            {
                push(current);
                return current;
            }
        }

        return null;
    }

    /++
     + Ditto, returning a null Nullable if no state yields a token.
     ++/
    @property
    Nullable!(Unqual!T) random()()
    if(!isAssignable!(T, typeof(null)))
    {
        Nullable!(Unqual!T) result;

        foreach(ref state; _states)
        {
            result = state.random;

            if(!result.isNull)
            {
                push(result.get);
                return result;
            }
        }

        return result;
    }

    /++
     + Resets the markov chain's history buffer to default-initialized tokens,
     + keeping its length.
     ++/
    @property
    void reset()
    {
        // The buffer length always equals the largest state size (set by the
        // constructors and preserved by push), so reuse it directly.
        immutable capacity = _history.length;
        _history = null;
        _history.length = capacity;
    }

    /++
     + Rehashes the associative arrays used in the markov chain's states.
     ++/
    @property
    void rehash()
    {
        foreach(ref state; _states)
        {
            state.rehash;
        }
    }

    /++
     + Pushes tokens to the markov chain's history buffer, seeding it for
     + subsequent calls to `select()` or `generate()`.
     +
     + Note that any tokens that would exceed the space of the history buffer
     + (which is equal to the size of the largest state) are discarded; only
     + the trailing tokens of `seed` are kept.
     ++/
    void seed(T[] seed...)
    {
        seed.retro.take(_history.length).retro.each!(f => push(f));
    }

    /++
     + Returns a token generated from the internal set of states, based on the
     + tokens previously generated. States are tried from the largest size to
     + the smallest; the first token found is pushed to the history buffer.
     + If no token can be produced, null is returned.
     ++/
    T select()()
    if(isAssignable!(T, typeof(null)))
    {
        if(!empty)
        {
            foreach(ref state; statesByDescendingSize())
            {
                T current = state.select(_history[$ - state.size .. $]);
                if(current)
                {
                    push(current);
                    return current;
                }
            }
        }

        return null;
    }

    /++
     + Ditto, returning a null Nullable if no token can be produced.
     ++/
    Nullable!(Unqual!T) select()()
    if(!isAssignable!(T, typeof(null)))
    {
        Nullable!(Unqual!T) result;

        if(!empty)
        {
            foreach(ref state; statesByDescendingSize())
            {
                result = state.select(_history[$ - state.size .. $]);

                if(!result.isNull)
                {
                    push(result.get);
                    return result;
                }
            }
        }

        return result;
    }

    /++
     + Returns the sizes of the markov chain's states in an unknown order.
     ++/
    @property
    size_t[] sizes()
    {
        return _states.byValue.map!"a.size".array;
    }

    /++
     + Returns an array representing the markov chain's internal set of states.
     ++/
    @property
    State!T[] states()
    {
        return _states.values;
    }

    /++
     + Trains the markov chain from a sequence of input tokens: for every
     + position, each state of size `n` learns that the `n` preceding tokens
     + are followed by the token at that position. The history buffer is not
     + modified.
     ++/
    void train(T[] input...)
    {
        foreach(index, follow; input)
        {
            foreach(size, ref state; _states)
            {
                // Skip positions with fewer than `size` preceding tokens.
                if(size <= index)
                {
                    T[] first = input[index - size .. index];
                    state.poke(first, follow);
                }
            }
        }
    }
}
342 
// Chain over mutable array tokens: training works, and seeding keeps only as
// many trailing tokens as the largest state size (exercises push's in-place
// shifting path).
unittest
{
    auto chain = MarkovChain!(int[])(1);

    chain.train([1, 2, 3], [4, 5, 6], [7, 8, 9]);
    assert(chain.length == 1);

    // The largest state has size 1, so the history holds a single token.
    chain.seed([4, 5, 6], [7, 8, 9]);
    assert(chain._history == [[7, 8, 9]]);
}
349 
// Chain over immutable array tokens: training works, and seeding keeps only as
// many trailing tokens as the largest state size (exercises push's
// concatenation path for non-mutable tokens).
unittest
{
    auto chain = MarkovChain!(immutable(int[]))(1);

    chain.train([1, 2, 3], [4, 5, 6], [7, 8, 9]);
    assert(chain.length == 1);

    // The largest state has size 1, so the history holds a single token.
    chain.seed([4, 5, 6], [7, 8, 9]);
    assert(chain._history == [[7, 8, 9]]);
}