1 """
2 wchoice.py -- by bearophile, V.1.0 Oct 30 2006
3
4 Weighted choice: like the random.choice() when the probabilities of
5 the single elements aren't the same.
6 """
7 import psyco
8 from random import random
9 from bisect import bisect
10 from itertools import izip
11
12 psyco.profile()
def wchoice(objects, frequences, filter=True, normalize=True):
    """
    wchoice(objects, frequences, filter=True, normalize=True): return
    a function that, each time it is called, returns one of the given
    objects chosen at random with the specified frequency distribution.
    If no objects with frequency > 0 are given, return a constant
    function that returns None.

    Input:
      objects: ordered sequence of elements to choose from.
      frequences: ordered sequence of their frequences (same length).
      filter=False disables the input checks and the removal of
        (near-)zero frequences, speeding up the creation of the chooser,
        but fewer bad cases are caught. Frequences must then be
        numbers > 0.
      normalize=False disables the probability normalization. The choice
        becomes faster, but sum(frequences) must be 1.

    Output:
      A callable. When two or more objects remain it accepts the optional
      keyword arguments rnd (a function returning a float in [0, 1)) and
      bis (a bisect-like function); they are bound as defaults only to
      make the name lookups fast. The constant functions returned for
      zero or one remaining object take no arguments.

    Raises (only when filter is true):
      TypeError: objects or frequences is a set, frozenset or dict.
      ValueError: the two sequences differ in length, a frequency is
        negative, or a frequency cannot be converted to float.
    """
    if filter:
        unordered = (set, frozenset, dict)
        if isinstance(frequences, unordered):
            raise TypeError("in wchoice: frequences: only ordered sequences.")
        if isinstance(objects, unordered):
            raise TypeError("in wchoice: objects: only ordered sequences.")
        if len(frequences) != len(objects):
            raise ValueError("in wchoice: objects and frequences must have the same lenght.")

        # Convert everything up front so that a bad value is reported
        # before any sign check, on Python 2 and Python 3 alike.
        frequences = [float(freq) for freq in frequences]

        kept_freqs = []
        kept_objs = []
        for freq, obj in zip(frequences, objects):
            if freq < 0:
                raise ValueError("in wchoice: only positive frequences.")
            elif freq > 1e-8:
                # Frequences at or below 1e-8 are treated as zero and dropped.
                kept_freqs.append(freq)
                kept_objs.append(obj)

        if not kept_freqs:
            return lambda: None
        if len(kept_freqs) == 1:
            return lambda: kept_objs[0]
        frequences = kept_freqs
        objects = kept_objs
    else:
        if len(objects) == 0:
            # Honour the documented contract even without filtering.
            return lambda: None
        if len(objects) == 1:
            return lambda: objects[0]

    # Running totals: cumulative[i] is the upper bound of the slice of
    # [0, total) assigned to objects[i].
    cumulative = []
    total = 0
    for freq in frequences:
        total += freq
        cumulative.append(total)

    # Rounding can leave the last running total a hair below the drawn
    # value (e.g. sum([0.1] * 10) < 1 with normalize=False); clamping keeps
    # the index in range instead of raising IndexError.
    last = len(cumulative) - 1

    if normalize:
        return lambda rnd=random, bis=bisect: objects[min(bis(cumulative, rnd() * total), last)]
    else:
        return lambda rnd=random, bis=bisect: objects[min(bis(cumulative, rnd()), last)]
70
71
72 if __name__ == '__main__':
73 print "wchoice tests:"
74 objs = "ABCDE"
75 freqs = [1, 3, 1.1, 0, 5]
76 sumf = sum(freqs)
77 wc = wchoice(objs, freqs)
78 freq1 = dict.fromkeys(objs, 0)
79 nestractions = 100000
80 for i in xrange(nestractions):
81 freq1[wc()] += 1
82
83 freq2 = sorted(freq1.items())
84 freq3 = [sumf*float(v)/nestractions for (k,v) in freq2]
85
86 for (f1,f2) in zip(freq3, freqs):
87 print abs(f1-f2),
88 assert abs(f1-f2) < 0.05
89 print "\n"
90
91 wc = wchoice(["a"], [1])
92 assert set(wc() for i in xrange(20000)) == set(["a"])
93
94 wc = wchoice(["a"], [0])
95 assert set(wc() for i in xrange(20000)) == set([None])
96
97 wc = wchoice(["a","b"], [0,0])
98 assert set(wc() for i in xrange(20000)) == set([None])
99
100 objs = ["A"]
101 freqs = [1.5]
102 wc = wchoice(objs, freqs, filter=False)
103 assert [wc() for _ in xrange(10)] == ["A"] * 10
104
105 objs = "ABCDE"
106 freqs = [1, 3, 1.1, 0.1, 5]
107 wc = wchoice(objs, freqs, filter=False)
108 print [wc() for _ in xrange(50)]
109
110 print "Tests done."
111