Skip to content

Commit 6c34eb8

Browse files
committed
z3:
forking solver; Wrapped missing native throwable functions; Tracks of assertion clear on pop
1 parent e89ec8f commit 6c34eb8

File tree

10 files changed

+1133
-500
lines changed

10 files changed

+1133
-500
lines changed

ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt

Lines changed: 281 additions & 213 deletions
Large diffs are not rendered by default.

ksmt-z3/src/main/kotlin/com/microsoft/z3/UnsafeApi.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package com.microsoft.z3
22

3+
import it.unimi.dsi.fastutil.longs.LongSet
4+
35
fun incRefUnsafe(ctx: Long, ast: Long) {
46
// Invoke incRef directly without status check
57
Native.INTERNALincRef(ctx, ast)
@@ -9,3 +11,7 @@ fun decRefUnsafe(ctx: Long, ast: Long) {
911
// Invoke decRef directly without status check
1012
Native.INTERNALdecRef(ctx, ast)
1113
}
14+
15+
fun LongSet.decRefUnsafeAll(ctx: Long) = longIterator().forEachRemaining {
16+
decRefUnsafe(ctx, it)
17+
}

ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesTracker.kt

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,18 @@ import io.ksmt.sort.KUninterpretedSort
1919
* 2. Assert distinct constraints ([assertPendingUninterpretedValueConstraints])
2020
* that may be introduced during internalization.
2121
* */
22-
class ExpressionUninterpretedValuesTracker(val ctx: KContext, val z3Ctx: KZ3Context) {
22+
class ExpressionUninterpretedValuesTracker private constructor(
23+
val ctx: KContext,
24+
val z3Ctx: KZ3Context,
25+
private val registeredUninterpretedSortValues: HashMap<KUninterpretedSortValue, UninterpretedSortValueDescriptor>
26+
) {
27+
constructor(ctx: KContext, z3Ctx: KZ3Context) : this(ctx, z3Ctx, hashMapOf())
28+
constructor(ctx: KContext, z3Ctx: KZ3Context, forkingSolverManager: KZ3ForkingSolverManager) : this(
29+
ctx,
30+
z3Ctx,
31+
with(forkingSolverManager) { z3Ctx.findRegisteredUninterpretedSortValues() }
32+
)
33+
2334
private val expressionLevels = Object2IntOpenHashMap<KExpr<*>>().apply {
2435
defaultReturnValue(Int.MAX_VALUE) // Level which is greater than any possible level
2536
}
@@ -32,9 +43,6 @@ class ExpressionUninterpretedValuesTracker(val ctx: KContext, val z3Ctx: KZ3Cont
3243

3344
private val valueTrackerFrames = arrayListOf(currentFrame)
3445

35-
private val registeredUninterpretedSortValues =
36-
hashMapOf<KUninterpretedSortValue, UninterpretedSortValueDescriptor>()
37-
3846
/**
3947
* Skip any value tracking related actions until
4048
* we have uninterpreted values registered.
@@ -49,6 +57,11 @@ class ExpressionUninterpretedValuesTracker(val ctx: KContext, val z3Ctx: KZ3Cont
4957
body()
5058
}
5159

60+
fun fork(parent: ExpressionUninterpretedValuesTracker) = also {
61+
expressionLevels += parent.expressionLevels
62+
repeat(parent.valueTrackerFrames.size - 1) { pushAssertionLevel() }
63+
}
64+
5265
fun expressionUse(expr: KExpr<*>) = ifTrackingEnabled {
5366
currentFrame.analyzeUsedExpression(expr)
5467
}
@@ -121,7 +134,7 @@ class ExpressionUninterpretedValuesTracker(val ctx: KContext, val z3Ctx: KZ3Cont
121134
z3Ctx.releaseTemporaryAst(constraintLhs)
122135
}
123136

124-
private data class UninterpretedSortValueDescriptor(
137+
internal data class UninterpretedSortValueDescriptor(
125138
val value: KUninterpretedSortValue,
126139
val nativeUniqueValueDescriptor: Long,
127140
val nativeValueExpr: Long

ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Context.kt

Lines changed: 90 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,48 +2,88 @@ package io.ksmt.solver.z3
22

33
import com.microsoft.z3.Context
44
import com.microsoft.z3.Solver
5+
import com.microsoft.z3.Z3Exception
56
import com.microsoft.z3.decRefUnsafe
7+
import com.microsoft.z3.decRefUnsafeAll
68
import com.microsoft.z3.incRefUnsafe
79
import io.ksmt.KContext
810
import io.ksmt.decl.KDecl
911
import io.ksmt.expr.KExpr
1012
import io.ksmt.expr.KUninterpretedSortValue
13+
import io.ksmt.solver.KSolverException
1114
import io.ksmt.solver.util.KExprLongInternalizerBase.Companion.NOT_INTERNALIZED
1215
import io.ksmt.sort.KSort
1316
import io.ksmt.sort.KUninterpretedSort
1417
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap
1518
import it.unimi.dsi.fastutil.longs.LongOpenHashSet
16-
import it.unimi.dsi.fastutil.longs.LongSet
1719
import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap
1820

1921
@Suppress("TooManyFunctions")
20-
class KZ3Context(
22+
class KZ3Context internal constructor(
2123
ksmtCtx: KContext,
22-
private val ctx: Context
24+
private val ctx: Context,
25+
forkingSolverManager: KZ3ForkingSolverManager?,
2326
) : AutoCloseable {
24-
constructor(ksmtCtx: KContext) : this(ksmtCtx, Context())
27+
constructor(ksmtCtx: KContext, ctx: Context) : this(ksmtCtx, ctx, null)
28+
constructor(ksmtCtx: KContext) : this(ksmtCtx, Context(), null)
2529

2630
private var isClosed = false
27-
28-
private val expressions = Object2LongOpenHashMap<KExpr<*>>().apply {
29-
defaultReturnValue(NOT_INTERNALIZED)
30-
}
31-
32-
private val sorts = Object2LongOpenHashMap<KSort>().apply {
33-
defaultReturnValue(NOT_INTERNALIZED)
34-
}
35-
36-
private val decls = Object2LongOpenHashMap<KDecl<*>>().apply {
37-
defaultReturnValue(NOT_INTERNALIZED)
31+
private val isForking = forkingSolverManager != null
32+
33+
// common for parent and child structures
34+
private val expressions: Object2LongOpenHashMap<KExpr<*>>
35+
private val sorts: Object2LongOpenHashMap<KSort>
36+
private val decls: Object2LongOpenHashMap<KDecl<*>>
37+
38+
private val z3Expressions: Long2ObjectOpenHashMap<KExpr<*>>
39+
private val z3Sorts: Long2ObjectOpenHashMap<KSort>
40+
private val z3Decls: Long2ObjectOpenHashMap<KDecl<*>>
41+
private val tmpNativeObjects: LongOpenHashSet
42+
private val converterNativeObjects: LongOpenHashSet
43+
44+
private val uninterpretedSortValueInterpreter: HashMap<KUninterpretedSort, Long>
45+
private val uninterpretedSortValueDecls: Long2ObjectOpenHashMap<KUninterpretedSortValue>
46+
private val uninterpretedSortValueInterpreters: LongOpenHashSet
47+
48+
49+
val uninterpretedValuesTracker: ExpressionUninterpretedValuesTracker
50+
51+
init {
52+
if (forkingSolverManager != null) {
53+
with(forkingSolverManager) {
54+
expressions = findExpressionsCache()
55+
sorts = findSortsCache()
56+
decls = findDeclsCache()
57+
58+
z3Expressions = findExpressionsReversedCache()
59+
z3Sorts = findSortsReversedCache()
60+
z3Decls = findDeclsReversedCache()
61+
tmpNativeObjects = findTmpNativeObjectsCache()
62+
converterNativeObjects = findConverterNativeObjectsCache()
63+
uninterpretedSortValueInterpreter = findUninterpretedSortValueInterpreter()
64+
uninterpretedSortValueDecls = findUninterpretedSortValueDecls()
65+
uninterpretedSortValueInterpreters = findUninterpretedSortValueInterpreters()
66+
}
67+
uninterpretedValuesTracker = ExpressionUninterpretedValuesTracker(ksmtCtx, this, forkingSolverManager)
68+
} else {
69+
expressions = Object2LongOpenHashMap<KExpr<*>>().apply { defaultReturnValue(NOT_INTERNALIZED) }
70+
sorts = Object2LongOpenHashMap<KSort>().apply { defaultReturnValue(NOT_INTERNALIZED) }
71+
decls = Object2LongOpenHashMap<KDecl<*>>().apply { defaultReturnValue(NOT_INTERNALIZED) }
72+
73+
z3Expressions = Long2ObjectOpenHashMap<KExpr<*>>()
74+
z3Sorts = Long2ObjectOpenHashMap<KSort>()
75+
z3Decls = Long2ObjectOpenHashMap<KDecl<*>>()
76+
tmpNativeObjects = LongOpenHashSet()
77+
converterNativeObjects = LongOpenHashSet()
78+
79+
uninterpretedSortValueInterpreter = hashMapOf()
80+
uninterpretedSortValueDecls = Long2ObjectOpenHashMap()
81+
uninterpretedSortValueInterpreters = LongOpenHashSet()
82+
83+
uninterpretedValuesTracker = ExpressionUninterpretedValuesTracker(ksmtCtx, this)
84+
}
3885
}
3986

40-
private val z3Expressions = Long2ObjectOpenHashMap<KExpr<*>>()
41-
private val z3Sorts = Long2ObjectOpenHashMap<KSort>()
42-
private val z3Decls = Long2ObjectOpenHashMap<KDecl<*>>()
43-
private val tmpNativeObjects = LongOpenHashSet()
44-
private val converterNativeObjects = LongOpenHashSet()
45-
46-
val uninterpretedValuesTracker = ExpressionUninterpretedValuesTracker(ksmtCtx, this)
4787

4888
@JvmField
4989
val nCtx: Long = ctx.nCtx()
@@ -54,17 +94,24 @@ class KZ3Context(
5494
val isActive: Boolean
5595
get() = !isClosed
5696

97+
internal fun fork(ksmtCtx: KContext, manager: KZ3ForkingSolverManager): KZ3Context {
98+
require(isForking) { "Can't fork non-forking context" }
99+
return KZ3Context(ksmtCtx, ctx, manager).also {
100+
it.uninterpretedValuesTracker.fork(uninterpretedValuesTracker)
101+
}
102+
}
103+
104+
internal fun findInternalizedExprWithoutAnalysis(expr: KExpr<*>): Long {
105+
val result = expressions.getLong(expr)
106+
return if (result == NOT_INTERNALIZED) NOT_INTERNALIZED else result
107+
}
108+
57109
/**
58110
* Find internalized expr.
59111
* Returns [NOT_INTERNALIZED] if expression was not found.
60112
* */
61-
fun findInternalizedExpr(expr: KExpr<*>): Long {
62-
val result = expressions.getLong(expr)
63-
if (result == NOT_INTERNALIZED) return NOT_INTERNALIZED
64-
65-
uninterpretedValuesTracker.expressionUse(expr)
66-
67-
return result
113+
fun findInternalizedExpr(expr: KExpr<*>): Long = findInternalizedExprWithoutAnalysis(expr).also {
114+
if (it != NOT_INTERNALIZED) uninterpretedValuesTracker.expressionUse(expr)
68115
}
69116

70117
fun saveInternalizedExpr(expr: KExpr<*>, internalized: Long) {
@@ -148,11 +195,6 @@ class KZ3Context(
148195
return ast
149196
}
150197

151-
private val uninterpretedSortValueInterpreter = hashMapOf<KUninterpretedSort, Long>()
152-
153-
private val uninterpretedSortValueDecls = Long2ObjectOpenHashMap<KUninterpretedSortValue>()
154-
private val uninterpretedSortValueInterpreters = LongOpenHashSet()
155-
156198
fun saveUninterpretedSortValueDecl(decl: Long, value: KUninterpretedSortValue): Long {
157199
if (uninterpretedSortValueDecls.putIfAbsent(decl, value) == null) {
158200
incRefUnsafe(nCtx, decl)
@@ -264,37 +306,38 @@ class KZ3Context(
264306
if (isClosed) return
265307
isClosed = true
266308

309+
if (isForking) return
310+
267311
uninterpretedSortValueInterpreter.clear()
268312

269-
uninterpretedSortValueDecls.keys.decRefAll()
313+
uninterpretedSortValueDecls.keys.decRefUnsafeAll(nCtx)
270314
uninterpretedSortValueDecls.clear()
271315

272-
uninterpretedSortValueInterpreters.decRefAll()
316+
uninterpretedSortValueInterpreters.decRefUnsafeAll(nCtx)
273317
uninterpretedSortValueInterpreters.clear()
274318

275-
converterNativeObjects.decRefAll()
319+
converterNativeObjects.decRefUnsafeAll(nCtx)
276320
converterNativeObjects.clear()
277321

278-
z3Expressions.keys.decRefAll()
322+
z3Expressions.keys.decRefUnsafeAll(nCtx)
279323
expressions.clear()
280324
z3Expressions.clear()
281325

282-
tmpNativeObjects.decRefAll()
326+
tmpNativeObjects.decRefUnsafeAll(nCtx)
283327
tmpNativeObjects.clear()
284328

285-
z3Decls.keys.decRefAll()
329+
z3Decls.keys.decRefUnsafeAll(nCtx)
286330
decls.clear()
287331
z3Decls.clear()
288332

289-
z3Sorts.keys.decRefAll()
333+
z3Sorts.keys.decRefUnsafeAll(nCtx)
290334
sorts.clear()
291335
z3Sorts.clear()
292336

293-
ctx.close()
294-
}
295-
296-
private fun LongSet.decRefAll() =
297-
longIterator().forEachRemaining {
298-
decRefUnsafe(nCtx, it)
337+
try {
338+
ctx.close()
339+
} catch (e: Z3Exception) {
340+
throw KSolverException(e)
299341
}
342+
}
300343
}

0 commit comments

Comments
 (0)