Skip to content

Commit 46e44c9

Browse files
committed
Pass random length list tests with sparse/thorough solver
Naive solver worked already, this commit makes the sparse solver work for the basic test cases already added. Refactor vargroup to make random length list work cleanly. Fully delegate control of CSP to VarGroup, not half in one class and half in another as it was before.
1 parent d1178c8 commit 46e44c9

File tree

4 files changed

+233
-122
lines changed

4 files changed

+233
-122
lines changed

constrainedrandom/internal/multivar.py

Lines changed: 39 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
# SPDX-License-Identifier: MIT
22
# Copyright (c) 2023 Imagination Technologies Ltd. All Rights Reserved
33

4-
import constraint
54
from collections import defaultdict
6-
from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Tuple, Union
5+
from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Union
76

87
from .vargroup import VarGroup
98

109
from .. import utils
11-
from ..debug import RandomizationDebugInfo, RandomizationFail
10+
from ..debug import RandomizationDebugInfo
1211

1312
if TYPE_CHECKING:
1413
from ..randobj import RandObj
@@ -138,74 +137,65 @@ def solve_groups(
138137
'''
139138
constraints = self.constraints
140139
sparse_solver = solutions_per_group is not None
141-
solutions = []
142-
solved_vars = []
140+
solutions : List[Dict[str, Any]] = []
141+
solved_vars : List[str] = []
143142

144-
# Respect assigned temporary values
143+
# Respect assigned temporary values.
145144
if len(with_values) > 0:
146145
for var_name in with_values.keys():
147146
solved_vars.append(var_name)
148147
solutions.append(with_values)
149148

150-
# If solving sparsely, we'll create a new problem for each group.
151-
# If not solving sparsely, just create one big problem that we add to
152-
# as we go along.
153-
if not sparse_solver:
154-
problem = constraint.Problem()
155-
for var_name, value in with_values.items():
156-
problem.addVariable(var_name, (value,))
157-
149+
# For each group, construct a problem and solve it.
158150
for group in groups:
159-
if sparse_solver:
160-
# Construct one problem per group, add solved variables from previous groups.
161-
problem = constraint.Problem()
162-
# Construct the appropriate group variable problem
163-
group_problem = VarGroup(
164-
group,
165-
solved_vars,
166-
problem,
167-
constraints,
168-
self.max_domain_size,
169-
self.debug,
170-
)
171-
172151
group_solutions = None
152+
group_problem = None
173153
attempts = 0
174154
while group_solutions is None or len(group_solutions) == 0:
155+
# Early loop exit cases
175156
if attempts >= max_iterations:
176157
# We have failed, give up
177158
return None
178159
if attempts > 0 and not group_problem.can_retry():
179160
# Not worth retrying - the same result will be obtained.
180161
return None
181-
if sparse_solver:
182-
if len(solutions) > 0:
183-
# Respect a proportion of the solution space, determined
184-
# by the sparsity/solutions_per_group.
162+
163+
# Determine what the starting state space for this group
164+
# should be.
165+
if sparse_solver and len(solutions) > 0:
166+
# Respect a proportion of the solution space, determined
167+
# by the sparsity/solutions_per_group.
168+
# Start by choosing a subset of the possible solutions.
185169
if solutions_per_group >= len(solutions):
186-
solution_subset = solutions
170+
solution_subset = list(solutions)
187171
else:
188172
solution_subset = self.parent._get_random().choices(
189173
solutions,
190174
k=solutions_per_group
191175
)
192-
if solutions_per_group == 1:
193-
for var_name, value in solution_subset[0].items():
194-
if var_name in problem._variables:
195-
del problem._variables[var_name]
196-
problem.addVariable(var_name, (value,))
197-
else:
198-
solution_space = defaultdict(list)
199-
for soln in solution_subset:
200-
for var_name, value in soln.items():
201-
# List is ~2x slower than set for 'in',
202-
# but variables might be non-hashable.
203-
if value not in solution_space[var_name]:
204-
solution_space[var_name].append(value)
205-
for var_name, values in solution_space.items():
206-
if var_name in problem._variables:
207-
del problem._variables[var_name]
208-
problem.addVariable(var_name, values)
176+
else:
177+
# If not sparse, maintain the entire list of possible solutions.
178+
solution_subset = list(solutions)
179+
180+
# Translate this subset into a dictionary of the
181+
# possible values for each variable.
182+
solution_space = defaultdict(list)
183+
for soln in solution_subset:
184+
for var_name, value in soln.items():
185+
# List is ~2x slower than set for 'in',
186+
# but variables might be non-hashable.
187+
if value not in solution_space[var_name]:
188+
solution_space[var_name].append(value)
189+
190+
# Construct the appropriate group variable problem.
191+
# Must be done after selecting the solution space.
192+
group_problem = VarGroup(
193+
group,
194+
solution_space,
195+
constraints,
196+
self.max_domain_size,
197+
self.debug,
198+
)
209199

210200
# Attempt to solve the group
211201
group_solutions = group_problem.solve(

constrainedrandom/internal/randvar.py

Lines changed: 67 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,14 @@ def get_length(self) -> int:
253253
" but none was given when get_length was called.")
254254
return self.rand_length_val
255255

256+
def is_list(self) -> bool:
257+
'''
258+
Returns ``True`` if this is a list variable.
259+
260+
:return: ``True`` if this is a list variable, otherwise ``False``.
261+
'''
262+
return self.length is not None or self.rand_length is not None
263+
256264
def set_rand_length(self, length: int) -> None:
257265
'''
258266
Function to set the random length.
@@ -282,31 +290,48 @@ def _get_random(self) -> random.Random:
282290
return random
283291
return self._random
284292

285-
def get_domain_size(self) -> int:
293+
def get_domain_size(self, possible_lengths: Optional[List[int]]=None) -> int:
286294
'''
287295
Return total domain size, accounting for length of this random variable.
288296
297+
:param possible_lengths: Optional, when there is more than one possiblity
298+
for the value of the random length, specifies a list of the
299+
possibilities.
289300
:return: domain size, integer.
290301
'''
291302
if self.domain is None:
292303
# If there's no domain, it means we can't estimate the complexity
293304
# of this variable. Return 1.
294305
return 1
295306
else:
296-
length = self.get_length()
297-
if length is None:
298-
# length is None implies a scalar variable.
299-
return len(self.domain)
300-
elif length == 0:
301-
# This is a zero-length list, adding no complexity.
302-
return 1
303-
elif length == 1:
304-
return len(self.domain)
307+
# possible_lengths is used when the variable has a random
308+
# length and that length is not yet fully determined.
309+
if possible_lengths is None:
310+
# Normal, fixed length of some description.
311+
length = self.get_length()
312+
if length is None:
313+
# length is None implies a scalar variable.
314+
return len(self.domain)
315+
elif length == 0:
316+
# This is a zero-length list, adding no complexity.
317+
return 1
318+
elif length == 1:
319+
return len(self.domain)
320+
else:
321+
# In this case it is effectively cartesian product, i.e.
322+
# n ** k, where n is the size of the domain and k is the length
323+
# of the list.
324+
return len(self.domain) ** length
305325
else:
306-
# In this case it is effectively cartesian product, i.e.
307-
# n ** k, where n is the size of the domain and k is the length
308-
# of the list.
309-
return len(self.domain) ** length
326+
# Random length which could be one of a number of values.
327+
assert self.rand_length is not None, "Cannot use possible_lengths " \
328+
"for a variable with non-random length."
329+
# For each possible length, the domain is the cartesian
330+
# product as above, but added together.
331+
total = 0
332+
for poss_len in possible_lengths:
333+
total += len(self.domain) ** poss_len
334+
return total
310335

311336
def can_use_with_constraint(self) -> bool:
312337
'''
@@ -321,30 +346,42 @@ def can_use_with_constraint(self) -> bool:
321346
# and the domain isn't a dictionary.
322347
return self.domain is not None and not isinstance(self.domain, dict)
323348

324-
def get_constraint_domain(self) -> utils.Domain:
349+
def get_constraint_domain(self, possible_lengths: Optional[List[int]]=None) -> utils.Domain:
325350
'''
326351
Get a ``constraint`` package friendly version of the domain
327352
of this random variable.
328353
354+
:param possible_lengths: Optional, when there is more than one possiblity
355+
for the value of the random length, specifies a list of the
356+
possibilities.
329357
:return: the variable's domain in a format that will work
330358
with the ``constraint`` package.
331359
'''
332-
length = self.get_length()
333-
if length is None:
334-
# Straightforward, scalar
335-
return self.domain
336-
elif length == 0:
337-
# List of length zero - an empty list is only correct choice.
338-
return [[]]
339-
elif length == 1:
340-
# List of length one
341-
return [[x] for x in self.domain]
360+
if possible_lengths is None:
361+
length = self.get_length()
362+
if length is None:
363+
# Straightforward, scalar
364+
return self.domain
365+
elif length == 0:
366+
# List of length zero - an empty list is only correct choice.
367+
return [[]]
368+
elif length == 1:
369+
# List of length one
370+
return [[x] for x in self.domain]
371+
else:
372+
# List of greater length, cartesian product.
373+
# Beware that this may be an extremely large domain.
374+
# Ensure each element is of type list, which is what
375+
# we want to return.
376+
return [list(x) for x in product(self.domain, repeat=length)]
342377
else:
343-
# List of greater length, cartesian product.
344-
# Beware that this may be an extremely large domain.
345-
# Ensure each element is of type list, which is what
346-
# we want to return.
347-
return [list(x) for x in product(self.domain, repeat=length)]
378+
# For each possible length, return the possible domains.
379+
# This can get extremely large, even more so than
380+
# the regular product.
381+
result = []
382+
for poss_len in possible_lengths:
383+
result += [list(x) for x in product(self.domain, repeat=poss_len)]
384+
return result
348385

349386
def randomize_once(self, constraints: Iterable[utils.Constraint], check_constraints: bool, debug: bool) -> Any:
350387
'''

0 commit comments

Comments
 (0)