@@ -27,54 +27,86 @@ function preprocessed_form(kernel::PyramidMatchGraphKernel, g::AbstractGraph)
2727 # nv(g) x embedding_dim matrix
2828 embedding = _embedding (g, d)
2929
30- # hist[l, d][i] contains number of points in the i-th bucket
31- # for layer l and dimension d
32- hists = Matrix {SparseVector{Int, Int}} (undef, L + 1 , d)
33- for l in 0 : L
34- for dd in 1 : d
35- counts = zeros (Int, 2 ^ l)
36- for i in 1 : 2 ^ l
37- lo = - 1.0 + (i - 1 ) * 2 / 2 ^ l
38- hi = - 1.0 + i * 2 / 2 ^ l
39-
40- for x in embedding[:, dd]
41- if lo <= x < hi # this currently misses 1.0
42- counts[i] += 1
43- end
30+ # TODO consider inplace sort!
31+
32+ # the j-th column of this matrix contains the j-th coordinates
33+ # of the embedding vectors sorted in ascending order
34+ ordered_points = sort (embedding; dims= 1 )
35+
36+ return ordered_points
37+ end
38+
39+ function apply_preprocessed (kernel:: PyramidMatchGraphKernel , points1, points2)
40+
41+ # TODO the embedding dim might actually be smaller for some graphs;
42+ # we should maybe consider some kind of scaling in such a case
43+ d = min (size (points1, 2 ), size (points2, 2 ))
44+ L = kernel. histogram_levels
45+ n1 = size (points1, 1 )
46+ n2 = size (points2, 1 )
47+
48+ hist_intersect = zeros (Int64, L + 1 )
49+
50+ # TODO ensure no int overflow, same result on 32 bit platform
51+
52+ # TODO we don't need to store hist_intersect as a vector
53+ # TODO add @inbounds
54+ hist_intersect[1 ] = d * min (n1, n2)
55+ for l in 1 : L
56+ cell_boundaries = range (0.0 , 1.0 , length= 2 ^ l + 1 )
57+ for j in 1 : d
58+ i1 = 1
59+ i2 = 1
60+ while i1 <= n1
61+ # TODO it is possible that searchsortedlast is not implemented
62+ # efficiently on a StepRangeLen
63+ cell_num = searchsortedlast (cell_boundaries, points1[i1, j])
64+ # TODO maybe we should verify here, that cell_num is a valid index
65+ cell_lower = cell_boundaries[cell_num]
66+ cell_upper = cell_boundaries[cell_num + 1 ]
67+ # TODO maybe we need some correction for rounding errors here
68+
69+ i1 += 1
70+ num1 = 1
71+ # the following loop is just for safety; we probably don't need it
72+ while i1 <= n1 && points1[i1, j] < cell_lower
73+ i1 += 1
74+ end
75+ # count number of vertices from g1 that fall into the bucket
76+ # specified by [cell_lower, cell_upper) along the j-th dimension
77+ # of the hypercube
78+ while i1 <= n1 && points1[i1, j] < cell_upper
79+ num1 += 1
80+ i1 += 1
81+ end
82+
83+ # count number of vertices from g2 that fall into that bucket
84+ # We could also use binary search here
85+ while i2 <= n2 && points2[i2, j] < cell_lower
86+ i2 += 1
4487 end
88+ num2 = 0
89+ while i2 <= n2 && points2[i2, j] < cell_upper
90+ num2 += 1
91+ i2 += 1
92+ end
93+ hist_intersect[l + 1 ] += min (num1, num2)
4594 end
46- hists[l + 1 , dd] = sparsevec (counts)
4795 end
4896 end
49- return hists
50- end
5197
52- function apply_preprocessed (kernel:: PyramidMatchGraphKernel , hists1, hists2)
53-
54- d = kernel. embedding_dim
55- L = kernel. histogram_levels
56- return _I (hists1, hists2, d, L + 1 ) +
57- sum (l -> 1 / 2 ^ (L - l) * (_I (hists1, hists2, d, l + 1 ) -
58- _I (hists1, hists2, d, l + 2 )), 0 : L- 1 )
98+ return hist_intersect[L + 1 ] + sum (l -> (hist_intersect[l+ 1 ] - hist_intersect[l+ 2 ]) / 2 ^ (L - l), 0 : L- 1 )
5999end
60100
61101
62102function _embedding (g:: AbstractGraph , embedding_dim:: Int )
63103
64- n = nv (g)
65- A = zeros (max (embedding_dim, n), max (embedding_dim, n))
66- A[1 : n, 1 : n] = adjacency_matrix (g)
67- # TODO should we scale?
68- embedding = eigvecs (A; sortby= x -> - abs (x))[:, 1 : embedding_dim]
104+ evs = eigvecs (adjacency_matrix (g); sortby= x -> - abs (x))
105+ embedding = abs .(@view evs[:, 1 : min (size (evs, 2 ), embedding_dim)])
69106 # theoretically clamping is not necessary, this is just a precaution
70- # against rounding errors
71- clamp! (embedding, - 1 .0 , 1.0 )
107+ # against rounding errors. Clamping the upper bound to prevfloat(1.0) should make some calculations easier
108+ clamp! (embedding, 0 .0 , prevfloat ( 1.0 ) )
72109 return embedding
73110end
74111
75- function _I (hists1, hists2, embedding_dim, level)
76-
77- return sum (dd -> sum (min .(hists1[level, dd], hists2[level, dd])), 1 : embedding_dim)
78- end
79-
80112
0 commit comments