Skip to content

Commit f1d33af

Browse files
Make PyramidMatchingGraphKernel faster (#11)
1 parent 218f54a commit f1d33af

File tree

1 file changed

+68
-36
lines changed

1 file changed

+68
-36
lines changed

src/graph-kernels/pyramid-match-graph-kernel.jl

Lines changed: 68 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -27,54 +27,86 @@ function preprocessed_form(kernel::PyramidMatchGraphKernel, g::AbstractGraph)
2727
# nv(g) x embedding_dim matrix
2828
embedding = _embedding(g, d)
2929

30-
# hists[l + 1, dd][i] contains the number of points in the i-th bucket
31-
# for layer l and dimension dd
32-
hists = Matrix{SparseVector{Int, Int}}(undef, L + 1, d)
33-
for l in 0:L
34-
for dd in 1:d
35-
counts = zeros(Int, 2^l)
36-
for i in 1:2^l
37-
lo = -1.0 + (i - 1) * 2 / 2^l
38-
hi = -1.0 + i * 2 / 2^l
39-
40-
for x in embedding[:, dd]
41-
if lo <= x < hi # this currently misses 1.0
42-
counts[i] += 1
43-
end
30+
# TODO consider inplace sort!
31+
32+
# the j-th column of this matrix contains the j-th coordinates
33+
# of the embedding vectors sorted in ascending order
34+
ordered_points = sort(embedding; dims=1)
35+
36+
return ordered_points
37+
end
38+
39+
function apply_preprocessed(kernel::PyramidMatchGraphKernel, points1, points2)
40+
41+
# TODO the embedding dim might actually be smaller for some graphs
42+
# we should maybe consider some kind of scaling in such a case
43+
d = min(size(points1, 2), size(points2, 2))
44+
L = kernel.histogram_levels
45+
n1 = size(points1, 1)
46+
n2 = size(points2, 1)
47+
48+
hist_intersect = zeros(Int64, L + 1)
49+
50+
# TODO ensure no int overflow, same result on 32 bit platform
51+
52+
# TODO we don't need to store hist_intersect as a vector
53+
# TODO add @inbounds
54+
hist_intersect[1] = d * min(n1, n2)
55+
for l in 1:L
56+
cell_boundaries = range(0.0, 1.0, length=2^l + 1)
57+
for j in 1:d
58+
i1 = 1
59+
i2 = 1
60+
while i1 <= n1
61+
# TODO it is possible that searchsortedlast is not implemented
62+
# efficiently on a StepRangeLen
63+
cell_num = searchsortedlast(cell_boundaries, points1[i1, j])
64+
# TODO maybe we should verify here, that cell_num is a valid index
65+
cell_lower = cell_boundaries[cell_num]
66+
cell_upper = cell_boundaries[cell_num + 1]
67+
# TODO maybe we need some correction for rounding errors here
68+
69+
i1 += 1
70+
num1 = 1
71+
# the first loop is just for safety; we probably don't need it
72+
while i1 <= n1 && points1[i1, j] < cell_lower
73+
i1 += 1
74+
end
75+
# count number of vertices from g1 that fall into the bucket
76+
# specified by [cell_lower, cell_upper) along the j-th dimension
77+
# of the hypercube
78+
while i1 <= n1 && points1[i1, j] < cell_upper
79+
num1 += 1
80+
i1 += 1
81+
end
82+
83+
# count number of vertices from g2 that fall into that bucket
84+
# We could also use binary search here
85+
while i2 <= n2 && points2[i2, j] < cell_lower
86+
i2 += 1
4487
end
88+
num2 = 0
89+
while i2 <= n2 && points2[i2, j] < cell_upper
90+
num2 += 1
91+
i2 += 1
92+
end
93+
hist_intersect[l + 1] += min(num1, num2)
4594
end
46-
hists[l + 1, dd] = sparsevec(counts)
4795
end
4896
end
49-
return hists
50-
end
5197

52-
function apply_preprocessed(kernel::PyramidMatchGraphKernel, hists1, hists2)
53-
54-
d = kernel.embedding_dim
55-
L = kernel.histogram_levels
56-
return _I(hists1, hists2, d, L + 1) +
57-
sum(l -> 1 / 2^(L - l) * (_I(hists1, hists2, d, l + 1) -
58-
_I(hists1, hists2, d, l + 2)), 0:L-1)
98+
return hist_intersect[L + 1] + sum(l -> (hist_intersect[l+1] - hist_intersect[l+2]) / 2^(L - l), 0:L-1)
5999
end
60100

61101

62102
function _embedding(g::AbstractGraph, embedding_dim::Int)
63103

64-
n = nv(g)
65-
A = zeros(max(embedding_dim, n), max(embedding_dim, n))
66-
A[1:n, 1:n] = adjacency_matrix(g)
67-
# TODO should we scale?
68-
embedding = eigvecs(A; sortby=x -> -abs(x))[:, 1:embedding_dim]
104+
evs = eigvecs(adjacency_matrix(g); sortby=x -> -abs(x))
105+
embedding = abs.(@view evs[:, 1:min(size(evs, 2), embedding_dim)])
69106
# theoretically clamping is not necessary, this is just a precaution
70-
# against rounding errors
71-
clamp!(embedding, -1.0, 1.0)
107+
# against rounding errors. Clamping to prevfloat(1.0) should make some calculations easier
108+
clamp!(embedding, 0.0, prevfloat(1.0))
72109
return embedding
73110
end
74111

75-
function _I(hists1, hists2, embedding_dim, level)
76-
77-
return sum(dd -> sum(min.(hists1[level, dd], hists2[level, dd])), 1:embedding_dim)
78-
end
79-
80112

0 commit comments

Comments
 (0)