Skip to content

Commit ff14260

Browse files
committed
Optimize bits argument to add_rand_var
Achieves 10x speedup where simple constraints are used for the `bits` argument to `add_rand_var`. * Use `get_and_call` workaround for deepcopy only when no constraint optimization is required. * Use same constraint optimization for bits as other args.
1 parent db5b49a commit ff14260

File tree

2 files changed

+100
-53
lines changed

2 files changed

+100
-53
lines changed

constrainedrandom/internal/randvar.py

Lines changed: 71 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -138,61 +138,80 @@ def create_randomizer(self) -> Callable:
138138
return partial(self.fn, *self.args)
139139
else:
140140
return self.fn
141-
elif self.bits is not None:
141+
if self.bits is not None:
142+
# Convert this to a range-based domain.
142143
self.domain = range(0, 1 << self.bits)
143-
# This is still faster than doing self._get_random().randrange(self.bits << 1),
144-
# it seems that getrandbits is 10x faster than randrange.
144+
# If sufficiently small, let this fall through to the general case,
145+
# to optimize randomization w.r.t. constraints.
146+
# The maximum size of range that Python can handle (in CPython)
147+
# when using size_tis 62 bits, as it uses signed 64-bit integers
148+
# and the top of the range is expressed as 1 << bits, i.e.
149+
# requiring one extra bit to store.
150+
if self.bits >= 63:
151+
# Ideally here we would use:
152+
# return partial(self._get_random().getrandbits, self.bits)
153+
# as it seems that getrandbits is 10x faster than randrange.
154+
# However, there is a very strange interaction between deepcopy
155+
# and random that prevents this. See get_and_call for details.
156+
# This solution is still faster than a partial with randrange.
157+
return partial(get_and_call, self._get_random, 'getrandbits', self.bits)
158+
# Handle possible types of domain.
159+
is_range = isinstance(self.domain, range)
160+
is_list_or_tuple = isinstance(self.domain, list) or isinstance(self.domain, tuple)
161+
is_dict = isinstance(self.domain, dict)
162+
# Range, list and tuple are handled nicely by the constraint package.
163+
# Other Iterables may not be, e.g. enum.Enum isn't, despite being an Iterable.
164+
is_iterable = isinstance(self.domain, Iterable)
165+
if is_iterable and not (is_range or is_list_or_tuple or is_dict):
166+
# Convert non-dict iterables to a tuple as we don't expect them to need to be mutable,
167+
# and tuple ought to be slightly more performant than list.
168+
try:
169+
self.domain = tuple(self.domain)
170+
except TypeError:
171+
raise TypeError(f'RandVar was passed a domain of bad type - {self.domain}. '
172+
'This was an Iterable but could not be converted to tuple.')
173+
is_list_or_tuple = True
174+
if self.check_constraints and (is_range or is_list_or_tuple) and len(self.domain) < self.max_domain_size:
175+
# If we are provided a sufficiently small domain and we have constraints, simply construct a
176+
# constraint solution problem instead.
177+
problem = constraint.Problem()
178+
problem.addVariable(self.name, self.domain)
179+
for con in self.constraints:
180+
problem.addConstraint(con, (self.name,))
181+
# Produces a list of dictionaries - index it up front for very marginal
182+
# performance gains
183+
solutions = problem.getSolutions()
184+
if len(solutions) == 0:
185+
debug_fail = RandomizationFail([self.name],
186+
[(c, (self.name,)) for c in self.constraints])
187+
debug_info = RandomizationDebugInfo()
188+
debug_info.add_failure(debug_fail)
189+
raise utils.RandomizationError("Variable was unsolvable. Check constraints.", debug_info)
190+
solution_list = [s[self.name] for s in solutions]
191+
self.check_constraints = False
192+
return partial(self._get_random().choice, solution_list)
193+
elif self.bits is not None:
194+
# Ideally here we would use:
195+
# return partial(self._get_random().getrandbits, self.bits)
196+
# as it seems that getrandbits is 10x faster than randrange.
197+
# However, there is a very strange interaction between deepcopy
198+
# and random that prevents this. See get_and_call for details.
199+
# This solution is still faster than a partial with randrange.
145200
return partial(get_and_call, self._get_random, 'getrandbits', self.bits)
201+
elif is_range:
202+
return partial(self._get_random().randrange, self.domain.start, self.domain.stop)
203+
elif is_list_or_tuple:
204+
return partial(self._get_random().choice, self.domain)
205+
elif is_dict:
206+
rand = self._get_random()
207+
if rand is random:
208+
# Don't store a module in a partial as this can't be copied.
209+
# dist defaults to using the global random module.
210+
return partial(dist, self.domain)
211+
return partial(dist, self.domain, rand)
146212
else:
147-
# Handle possible types of domain.
148-
is_range = isinstance(self.domain, range)
149-
is_list_or_tuple = isinstance(self.domain, list) or isinstance(self.domain, tuple)
150-
is_dict = isinstance(self.domain, dict)
151-
# Range, list and tuple are handled nicely by the constraint package.
152-
# Other Iterables may not be, e.g. enum.Enum isn't, despite being an Iterable.
153-
is_iterable = isinstance(self.domain, Iterable)
154-
if is_iterable and not (is_range or is_list_or_tuple or is_dict):
155-
# Convert non-dict iterables to a tuple as we don't expect them to need to be mutable,
156-
# and tuple ought to be slightly more performant than list.
157-
try:
158-
self.domain = tuple(self.domain)
159-
except TypeError:
160-
raise TypeError(f'RandVar was passed a domain of bad type - {self.domain}. '
161-
'This was an Iterable but could not be converted to tuple.')
162-
is_list_or_tuple = True
163-
if self.check_constraints and (is_range or is_list_or_tuple) and len(self.domain) < self.max_domain_size:
164-
# If we are provided a sufficiently small domain and we have constraints, simply construct a
165-
# constraint solution problem instead.
166-
problem = constraint.Problem()
167-
problem.addVariable(self.name, self.domain)
168-
for con in self.constraints:
169-
problem.addConstraint(con, (self.name,))
170-
# Produces a list of dictionaries - index it up front for very marginal
171-
# performance gains
172-
solutions = problem.getSolutions()
173-
if len(solutions) == 0:
174-
debug_fail = RandomizationFail([self.name],
175-
[(c, (self.name,)) for c in self.constraints])
176-
debug_info = RandomizationDebugInfo()
177-
debug_info.add_failure(debug_fail)
178-
raise utils.RandomizationError("Variable was unsolvable. Check constraints.", debug_info)
179-
solution_list = [s[self.name] for s in solutions]
180-
self.check_constraints = False
181-
return partial(self._get_random().choice, solution_list)
182-
elif is_range:
183-
return partial(self._get_random().randrange, self.domain.start, self.domain.stop)
184-
elif is_list_or_tuple:
185-
return partial(self._get_random().choice, self.domain)
186-
elif is_dict:
187-
rand = self._get_random()
188-
if rand is random:
189-
# Don't store a module in a partial as this can't be copied.
190-
# dist defaults to using the global random module.
191-
return partial(dist, self.domain)
192-
return partial(dist, self.domain, rand)
193-
else:
194-
raise TypeError(f'RandVar was passed a domain of a bad type - {self.domain}. '
195-
'Domain should be a range, list, tuple, dictionary or other Iterable.')
213+
raise TypeError(f'RandVar was passed a domain of a bad type - {self.domain}. '
214+
'Domain should be a range, list, tuple, dictionary or other Iterable.')
196215

197216
def add_constraint(self, constr: utils.Constraint) -> None:
198217
'''

tests/features/basic.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
'''
77

88
from enum import Enum, IntEnum
9-
from random import Random
109

1110
from constrainedrandom import RandObj
1211
from examples.ldinstr import ldInstr
@@ -82,6 +81,35 @@ def get_randobj(self, *args):
8281
return randobj
8382

8483

84+
class BitsLargeWidth(testutils.RandObjTestBase):
85+
'''
86+
Test that large values of bits work.
87+
'''
88+
89+
ITERATIONS = 1000
90+
91+
def get_randobj(self, *args):
92+
r = RandObj(*args)
93+
def not_1(x):
94+
return x != 1
95+
r.add_rand_var('uint64_t', bits=64, constraints=[not_1])
96+
r.add_rand_var('uint128_t', bits=128, constraints=[not_1])
97+
r.add_rand_var('uint256_t', bits=256, constraints=[not_1])
98+
return r
99+
100+
def check(self, results):
101+
for result in results:
102+
self.assertGreaterEqual(result['uint64_t'], 0)
103+
self.assertLess(result['uint64_t'], 1 << 64)
104+
self.assertNotEqual(result['uint64_t'], 1)
105+
self.assertGreaterEqual(result['uint128_t'], 0)
106+
self.assertLess(result['uint128_t'], 1 << 128)
107+
self.assertNotEqual(result['uint128_t'], 1)
108+
self.assertGreaterEqual(result['uint256_t'], 0)
109+
self.assertLess(result['uint256_t'], 1 << 256)
110+
self.assertNotEqual(result['uint256_t'], 1)
111+
112+
85113
class MultiBasic(testutils.RandObjTestBase):
86114
'''
87115
Test a basic multi-variable constraint (easy to randomly fulfill the constraint).

0 commit comments

Comments
 (0)