module markov.state;

import std.algorithm;
import std.array;
import std.exception;
import std.random;
import std.range;
import std.traits;
import std.typecons;

import markov.counter;

/++
 + Represents a table of token sequences bound to counter tables in a markov chain.
 +
 + Every token sequence (key) stored in the table has exactly `size` tokens;
 + both `poke` and `set` enforce this.
 ++/
struct State(T)
{
private:
    size_t _size;
    Counter!T[Key] _counters;

    /++
     + Wraps a token sequence, abstracting type qualifiers.
     ++/
    struct Key
    {
        const T[] _key;

        /++
         + Returns a shallow, mutable copy of the wrapped token sequence.
         ++/
        @property
        T[] value() const
        {
            return cast(T[]) _key.dup;
        }

        /++
         + Compares two token sequences for equality.
         ++/
        bool opEquals(ref const Key other) const
        {
            return _key == other._key;
        }
    }

    /++
     + Builds a key that owns its own (shallow) copy of the token sequence.
     +
     + Keys that are stored in `_counters` must not alias the caller's array:
     + if the caller later mutated it (e.g. reusing a sliding-window buffer),
     + the stored key would change after being hashed and become unreachable.
     ++/
    static Key ownedKey(T[] first)
    {
        return Key(first.dup);
    }

    /++
     + Returns a pointer to the counter table for the token sequence, or null
     + if the sequence's length doesn't match the size of the markov state or
     + no counter table exists for it.
     ++/
    Counter!T* lookup(T[] first)
    {
        if(first.length != size)
        {
            return null;
        }

        return Key(first) in _counters;
    }

public:
    @disable
    this();

    /++
     + Constructs a markov state with the given size.
     + The size must be greater than 0.
     +
     + Throws: `Exception` if `size` is 0.
     ++/
    this(size_t size)
    {
        _size = enforce(size, "State size cannot be 0.");
    }

    /++
     + Checks if a counter table exists in the markov state for the given
     + token sequence. Returns false if the length of the sequence doesn't
     + match the size of the markov state.
     ++/
    bool contains(T[] first)
    {
        return lookup(first) !is null;
    }

    /++
     + Checks if a token exists in the counter table that corresponds to the
     + given token sequence. Returns false if the sequence has no counter
     + table, or its length doesn't match the size of the markov state.
     ++/
    bool contains(T[] first, T follow)
    {
        auto ptr = lookup(first);
        return ptr ? ptr.contains(follow) : false;
    }

    /++
     + Checks if the markov state is empty.
     ++/
    @property
    bool empty() const
    {
        return length == 0;
    }

    /++
     + Returns the counter table that corresponds to the token sequence.
     +
     + The sequence must have a counter table; otherwise the associative
     + array lookup fails (a `RangeError` when bounds checks are enabled).
     ++/
    Counter!T get(T[] first)
    {
        return _counters[Key(first)];
    }

    /++
     + Returns a list of the token sequences stored in the markov state.
     + Each returned sequence is a fresh copy.
     ++/
    @property
    T[][] keys()
    {
        return _counters.byKey.map!(key => key.value).array;
    }

    /++
     + Returns the number of counter tables in the markov state.
     ++/
    @property
    size_t length() const
    {
        return _counters.length;
    }

    /++
     + Returns the counter value of a token, from the counter table that
     + corresponds to the given token sequence.
     + If the token doesn't exist in the counter table, or the leading sequence
     + doesn't exist in the markov state, or the length of the sequence doesn't
     + match the size of the markov state, 0 is returned instead.
     ++/
    uint peek(T[] first, T follow)
    {
        auto ptr = lookup(first);
        return ptr ? ptr.peek(follow) : 0;
    }

    /++
     + Pokes a token in the counter table that corresponds to the given leading
     + sequence of tokens, incrementing its counter value.
     + If the token doesn't have an entry in the counter table yet, one is
     + created and assigned a value of 1; likewise, if no counter table exists
     + for the leading token sequence, one is created as well.
     + The length of the token sequence must match the size of the markov state.
     +
     + The token sequence is copied before being stored, so the caller may
     + freely reuse or modify `first` afterwards.
     +
     + Throws: `Exception` if the length of `first` doesn't match `size`.
     ++/
    void poke(T[] first, T follow)
    {
        // Ensure that first length is equal to this state's size.
        enforce(first.length == size, "Length of input doesn't match size.");

        if(auto ptr = Key(first) in _counters)
        {
            ptr.poke(follow);
        }
        else
        {
            Counter!T counter;
            counter.poke(follow);

            // Store a private copy of the sequence, never the caller's slice.
            _counters[ownedKey(first)] = counter;
        }
    }

    /++
     + Returns a random token from a random counter table.
     + If either the markov state or the counter table is empty,
     + null is returned instead.
     ++/
    @property
    T random()()
    if(isAssignable!(T, typeof(null)))
    {
        if(!empty)
        {
            // Walk to the chosen counter table instead of allocating the
            // whole `values` array on every call.
            auto index = uniform(0, length);
            return _counters.byValue.dropExactly(index).front.random;
        }
        else
        {
            return null;
        }
    }

    /++
     + Ditto
     ++/
    @property
    Nullable!(Unqual!T) random()()
    if(!isAssignable!(T, typeof(null)))
    {
        Nullable!(Unqual!T) result;

        if(!empty)
        {
            auto index = uniform(0, length);
            return _counters.byValue.dropExactly(index).front.random;
        }

        return result;
    }

    /++
     + Rebuilds the associative arrays used by the markov table.
     +
     + Params:
     +   deep = If true, all the counter tables are rebuilt as well.
     ++/
    void rehash(bool deep = false)
    {
        _counters.rehash;

        if(deep)
        {
            foreach(ref counter; _counters)
            {
                counter.rehash;
            }
        }
    }

    /++
     + Returns a random token that might follow the given sequence of tokens
     + based on the markov state and the counter table that corresponds to the
     + token sequence.
     + If either the markov state or the corresponding counter table is empty,
     + or the token sequence doesn't have a counter table assigned to it,
     + null is returned instead.
     ++/
    T select()(T[] first)
    if(isAssignable!(T, typeof(null)))
    {
        if(!empty)
        {
            auto ptr = lookup(first);
            return ptr ? ptr.select : null;
        }
        else
        {
            return null;
        }
    }

    /++
     + Ditto
     ++/
    Nullable!(Unqual!T) select()(T[] first)
    if(!isAssignable!(T, typeof(null)))
    {
        Nullable!(Unqual!T) result;

        if(!empty)
        {
            if(auto ptr = lookup(first))
            {
                return ptr.select;
            }
        }

        return result;
    }

    /++
     + Sets the counter table for a given sequence of tokens, replacing any
     + existing counter table for that sequence.
     + The length of the token sequence must match the size of the markov state.
     +
     + The token sequence is copied before being stored, so the caller may
     + freely reuse or modify `first` afterwards.
     +
     + Throws: `Exception` if the length of `first` doesn't match `size`.
     ++/
    void set(T[] first, Counter!T counter)
    {
        // Same invariant as poke: a key of the wrong length could never be
        // found again by contains/peek/select.
        enforce(first.length == size, "Length of input doesn't match size.");

        if(auto ptr = Key(first) in _counters)
        {
            *ptr = counter;
        }
        else
        {
            _counters[ownedKey(first)] = counter;
        }
    }

    /++
     + Returns the size of the markov state, i.e. the number of tokens in
     + every token sequence it stores.
     ++/
    @property
    size_t size() const
    {
        return _size;
    }
}

unittest
{
    // A markov state of size 0 is rejected.
    assertThrown(State!string(0));
}

unittest
{
    // Poking a sequence whose length doesn't match the size is rejected.
    auto state = State!string(1);
    assertThrown(state.poke(["1", "2"], "3"));
}

unittest
{
    // Stored keys are copies: reusing the caller's buffer must not corrupt them.
    auto state = State!int(2);
    auto window = [1, 2];

    state.poke(window, 3);
    window[0] = 4;

    assert(state.contains([1, 2]));
    assert(!state.contains([4, 2]));
    assert(state.peek([1, 2], 3) == 1);
    assert(state.keys == [[1, 2]]);

    state.rehash(true);
    assert(state.peek([1, 2], 3) == 1);
}

unittest
{
    // set enforces the size invariant and copies its key as well.
    auto state = State!int(1);

    Counter!int counter;
    counter.poke(2);

    assertThrown(state.set([1, 2], counter));
    assert(state.empty);

    auto buffer = [1];
    state.set(buffer, counter);
    buffer[0] = 7;

    assert(state.length == 1);
    assert(state.peek([1], 2) == 1);
    assert(!state.contains([7]));
}

unittest
{
    auto state = State!string(1);

    assert(state.empty == true);
    assert(state.length == 0);
    assert(state.size == 1);

    assert(state.random is null);
    assert(state.select(["1"]) is null);
    assert(state.peek(["1"], "2") == 0);

    state.poke(["1"], "2");
    assert(state.empty == false);
    assert(state.length == 1);
    assert(state.size == 1);

    assert(state.random == "2");
    assert(state.select(["1"]) == "2");
    assert(state.peek(["1"], "2") == 1);

    state.poke(["1"], "2");
    assert(state.peek(["1"], "2") == 2);
    assert(state.peek(["1"], "3") == 0);

    state.poke(["1"], "3");
    assert(state.length == 1);
    assert(state.peek(["1"], "2") == 2);
    assert(state.peek(["1"], "3") == 1);
}

unittest
{
    auto state = State!int(1);

    assert(state.empty == true);
    assert(state.length == 0);
    assert(state.size == 1);

    assert(state.random.isNull);
    assert(state.select([1]).isNull);
    assert(state.peek([1], 2) == 0);

    state.poke([1], 2);
    assert(state.empty == false);
    assert(state.length == 1);
    assert(state.size == 1);

    assert(state.random == 2);
    assert(state.select([1]) == 2);
    assert(state.peek([1], 2) == 1);

    state.poke([1], 2);
    assert(state.peek([1], 2) == 2);
    assert(state.peek([1], 3) == 0);

    state.poke([1], 3);
    assert(state.length == 1);
    assert(state.peek([1], 2) == 2);
    assert(state.peek([1], 3) == 1);
}

unittest
{
    auto state = State!(int[])(1);

    assert(state.empty == true);
    assert(state.length == 0);
    assert(state.size == 1);

    assert(state.random is null);
    assert(state.select([[1]]) is null);
    assert(state.peek([[1]], [2]) == 0);

    state.poke([[1]], [2]);
    assert(state.empty == false);
    assert(state.length == 1);
    assert(state.size == 1);

    assert(state.random == [2]);
    assert(state.select([[1]]) == [2]);
    assert(state.peek([[1]], [2]) == 1);

    state.poke([[1]], [2]);
    assert(state.peek([[1]], [2]) == 2);
    assert(state.peek([[1]], [3]) == 0);

    state.poke([[1]], [3]);
    assert(state.length == 1);
    assert(state.peek([[1]], [2]) == 2);
    assert(state.peek([[1]], [3]) == 1);
}

unittest
{
    auto state = State!(const(int[]))(1);

    assert(state.empty == true);
    assert(state.length == 0);
    assert(state.size == 1);

    assert(state.random.isNull);
    assert(state.select([[1]]).isNull);
    assert(state.peek([[1]], [2]) == 0);

    state.poke([[1]], [2]);
    assert(state.empty == false);
    assert(state.length == 1);
    assert(state.size == 1);

    assert(state.random == [2]);
    assert(state.select([[1]]) == [2]);
    assert(state.peek([[1]], [2]) == 1);

    state.poke([[1]], [2]);
    assert(state.peek([[1]], [2]) == 2);
    assert(state.peek([[1]], [3]) == 0);

    state.poke([[1]], [3]);
    assert(state.length == 1);
    assert(state.peek([[1]], [2]) == 2);
    assert(state.peek([[1]], [3]) == 1);
}

unittest
{
    auto state = State!(immutable(int[]))(1);

    assert(state.empty == true);
    assert(state.length == 0);
    assert(state.size == 1);

    assert(state.random.isNull);
    assert(state.select([[1]]).isNull);
    assert(state.peek([[1]], [2]) == 0);

    state.poke([[1]], [2]);
    assert(state.empty == false);
    assert(state.length == 1);
    assert(state.size == 1);

    assert(state.random == [2]);
    assert(state.select([[1]]) == [2]);
    assert(state.peek([[1]], [2]) == 1);

    state.poke([[1]], [2]);
    assert(state.peek([[1]], [2]) == 2);
    assert(state.peek([[1]], [3]) == 0);

    state.poke([[1]], [3]);
    assert(state.length == 1);
    assert(state.peek([[1]], [2]) == 2);
    assert(state.peek([[1]], [3]) == 1);
}