module markov.counter;

import std.algorithm;
import std.array;
import std.random;
import std.traits;
import std.typecons;

/++
 + Represents a set of counters for trailing (following) tokens in a markov state.
 +
 + Every distinct token is associated with an unsigned counter. Tokens can be
 + picked at random either uniformly (random) or weighted by their counters
 + (select).
 ++/
struct Counter(T)
{
private:
    // Counter value for every known token, keyed by the wrapped token.
    uint[Key] _counts;

    // Cached sum of all counters. A value of 0 means "not computed yet";
    // every mutation (poke, set) resets it to 0 so total recomputes it.
    uint _total;

    /++
     + Wraps a token, providing normalized hashing and abstracting type qualifiers.
     ++/
    struct Key
    {
        T _key;

        /++
         + Returns the natural value of the token.
         ++/
        @property
        T value()
        {
            return _key;
        }

        /++
         + Returns the hash for a token.
         + The token's own toHash is used when it has one, otherwise the
         + hash is taken from the token type's TypeInfo.
         ++/
        hash_t toHash() const nothrow @safe
        {
            static if(__traits(compiles, {
                T key = void;
                hash_t hash = key.toHash;
            }))
            {
                return _key.toHash;
            }
            else
            {
                return typeid(T).getHash(&_key);
            }
        }

        /++
         + Compares two tokens for equality.
         ++/
        bool opEquals(ref const Key other) const
        {
            return _key == other._key;
        }
    }

public:
    /++
     + Constructs a counter table with an initial token.
     + The token starts with a counter of 1.
     ++/
    this(T follow)
    {
        poke(follow);
    }

    /++
     + Checks if the counter table contains a given token.
     ++/
    bool contains(T follow)
    {
        return !!(Key(follow) in _counts);
    }

    /++
     + Checks if the counter table is empty.
     ++/
    @property
    bool empty()
    {
        return length == 0;
    }

    /++
     + Returns the counter value for a token.
     + The token is looked up by indexing, so it must already be in the
     + table. Use peek for a lookup that tolerates missing tokens.
     ++/
    uint get(T follow)
    {
        return _counts[Key(follow)];
    }

    /++
     + Returns a list of tokens in the counter table.
     ++/
    @property
    T[] keys()
    {
        return _counts.keys.map!"a.value".array;
    }

    /++
     + Returns the length of the counter table.
     ++/
    @property
    size_t length()
    {
        return _counts.length;
    }

    /++
     + Returns the counter value for a token.
     + If the token doesn't exist, 0 is returned.
     ++/
    uint peek(T follow)
    {
        if(auto ptr = Key(follow) in _counts)
        {
            return *ptr;
        }

        return 0;
    }

    /++
     + Pokes a token in the counter table, incrementing its counter value.
     + If the token doesn't exist, it's created and assigned a counter of 1.
     ++/
    void poke(T follow)
    {
        // The cached total is stale after any change to the counters.
        scope(exit) _total = 0;

        if(auto ptr = Key(follow) in _counts)
        {
            *ptr += 1;
        }
        else
        {
            _counts[Key(follow)] = 1;
        }
    }

    /++
     + Returns a random token with equal distribution.
     + If the counter table is empty, null is returned instead.
     ++/
    @property
    T random()()
    if(isAssignable!(T, typeof(null)))
    {
        if(empty)
        {
            return null;
        }

        auto index = uniform(0, length);

        // Walk the table in place rather than copying every key into a
        // temporary array just to read one element.
        foreach(key, count; _counts)
        {
            if(index == 0)
            {
                return key.value;
            }

            --index;
        }

        // Unreachable, index is always less than the number of tokens.
        assert(0);
    }

    /++
     + Ditto
     +
     + For token types that cannot hold null, an empty Nullable is returned
     + when the counter table is empty.
     ++/
    @property
    Nullable!(Unqual!T) random()()
    if(!isAssignable!(T, typeof(null)))
    {
        Nullable!(Unqual!T) result;

        if(!empty)
        {
            auto index = uniform(0, length);

            // Walk the table in place rather than copying every key into a
            // temporary array just to read one element.
            foreach(key, count; _counts)
            {
                if(index == 0)
                {
                    result = key.value;
                    return result;
                }

                --index;
            }

            // Unreachable, index is always less than the number of tokens.
            assert(0);
        }

        return result;
    }

    /++
     + Rebuilds the associative array used by the counter table.
     ++/
    @property
    void rehash()
    {
        _counts.rehash;
    }

    /++
     + Returns a random token, distributed based on the counter values.
     + Specifically, a token with a higher counter is more likely to be chosen
     + than a token with a counter lower than it.
     + If the counter table is empty, or every counter in it is 0, null is
     + returned instead.
     ++/
    @property
    T select()()
    if(isAssignable!(T, typeof(null)))
    {
        // A zero total (possible after set(token, 0)) leaves nothing to
        // choose from, and uniform(0, 0) is not a valid interval.
        if(empty || total == 0)
        {
            return null;
        }

        auto needle = uniform(0, total);

        // Each token owns a slice of [0, total) as wide as its counter;
        // find the slice the needle landed in.
        foreach(key, count; _counts)
        {
            if(needle < count)
            {
                return key.value;
            }
            else
            {
                needle -= count;
            }
        }

        // Unreachable, the needle is always less than the sum of the counters.
        assert(0);
    }

    /++
     + Ditto
     +
     + For token types that cannot hold null, an empty Nullable is returned
     + when there is nothing to select.
     ++/
    @property
    Nullable!(Unqual!T) select()()
    if(!isAssignable!(T, typeof(null)))
    {
        Nullable!(Unqual!T) result;

        // A zero total (possible after set(token, 0)) leaves nothing to
        // choose from, and uniform(0, 0) is not a valid interval.
        if(!empty && total != 0)
        {
            auto needle = uniform(0, total);

            // Each token owns a slice of [0, total) as wide as its counter;
            // find the slice the needle landed in.
            foreach(key, count; _counts)
            {
                if(needle < count)
                {
                    result = key.value;
                    return result;
                }
                else
                {
                    needle -= count;
                }
            }

            // Unreachable, the needle is always less than the sum of the counters.
            assert(0);
        }

        return result;
    }

    /++
     + Sets the counter value for a given token.
     + The token is created if it doesn't exist yet.
     ++/
    void set(T follow, uint count)
    {
        // The cached total is stale after any change to the counters.
        scope(exit) _total = 0;
        _counts[Key(follow)] = count;
    }

    /++
     + Returns the sum of all counters on all tokens. The value is cached once
     + it's been computed, and recomputed after the table is modified.
     ++/
    @property
    uint total()
    {
        if(_total == 0)
        {
            // Sum the counters in place instead of copying them to an array.
            _total = _counts.byValue.sum;
        }

        return _total;
    }
}

unittest
{
    auto counter = Counter!string("1");

    assert(counter.empty == false);
    assert(counter.length == 1);
    assert(counter.total == 1);

    assert(counter.contains("1") == true);
    assert(counter.random == "1");
    assert(counter.select == "1");

    assert(counter.peek("1") == 1);

    counter.poke("1");
    assert(counter.peek("1") == 2);
    assert(counter.length == 1);
    assert(counter.total == 2);

    counter.poke("2");
    assert(counter.peek("1") == 2);
    assert(counter.peek("2") == 1);
    assert(counter.length == 2);
    assert(counter.total == 3);
}

unittest
{
    auto counter = Counter!int(1);

    assert(counter.empty == false);
    assert(counter.length == 1);
    assert(counter.total == 1);

    assert(counter.contains(1) == true);
    assert(counter.random == 1);
    assert(counter.select == 1);

    assert(counter.peek(1) == 1);

    counter.poke(1);
    assert(counter.peek(1) == 2);
    assert(counter.length == 1);
    assert(counter.total == 2);

    counter.poke(2);
    assert(counter.peek(1) == 2);
    assert(counter.peek(2) == 1);
    assert(counter.length == 2);
    assert(counter.total == 3);
}

unittest
{
    auto counter = Counter!(int[])([1]);

    assert(counter.empty == false);
    assert(counter.length == 1);
    assert(counter.total == 1);

    assert(counter.contains([1]) == true);
    assert(counter.random == [1]);
    assert(counter.select == [1]);

    assert(counter.peek([1]) == 1);

    counter.poke([1]);
    assert(counter.peek([1]) == 2);
    assert(counter.length == 1);
    assert(counter.total == 2);

    counter.poke([2]);
    assert(counter.peek([1]) == 2);
    assert(counter.peek([2]) == 1);
    assert(counter.length == 2);
    assert(counter.total == 3);
}

unittest
{
    auto counter = Counter!(const(int[]))([1]);

    assert(counter.empty == false);
    assert(counter.length == 1);
    assert(counter.total == 1);

    assert(counter.contains([1]) == true);
    assert(counter.random == [1]);
    assert(counter.select == [1]);

    assert(counter.peek([1]) == 1);

    counter.poke([1]);
    assert(counter.peek([1]) == 2);
    assert(counter.length == 1);
    assert(counter.total == 2);

    counter.poke([2]);
    assert(counter.peek([1]) == 2);
    assert(counter.peek([2]) == 1);
    assert(counter.length == 2);
    assert(counter.total == 3);
}

unittest
{
    auto counter = Counter!(immutable(int[]))([1]);

    assert(counter.empty == false);
    assert(counter.length == 1);
    assert(counter.total == 1);

    assert(counter.contains([1]) == true);
    assert(counter.random == [1]);
    assert(counter.select == [1]);

    assert(counter.peek([1]) == 1);

    counter.poke([1]);
    assert(counter.peek([1]) == 2);
    assert(counter.length == 1);
    assert(counter.total == 2);

    counter.poke([2]);
    assert(counter.peek([1]) == 2);
    assert(counter.peek([2]) == 1);
    assert(counter.length == 2);
    assert(counter.total == 3);
}

unittest
{
    // A table whose counters are all zero has nothing to select by weight,
    // but its tokens can still be picked uniformly.
    auto counter = Counter!string("1");
    counter.set("1", 0);

    assert(counter.empty == false);
    assert(counter.total == 0);
    assert(counter.select is null);
    assert(counter.random == "1");

    // Zero-weight tokens are never selected once another token has weight.
    counter.poke("2");
    assert(counter.total == 1);
    assert(counter.select == "2");

    auto numbers = Counter!int(1);
    numbers.set(1, 0);

    Nullable!int picked = numbers.select;
    assert(picked.isNull);
}