Skip to content

Commit 6cab505

Browse files
Add vertex_labels keyword to PM kernel (#12)
1 parent f1d33af commit 6cab505

File tree

1 file changed

+118
-67
lines changed

1 file changed

+118
-67
lines changed

src/graph-kernels/pyramid-match-graph-kernel.jl

Lines changed: 118 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,101 +1,152 @@
11

22
"""
33
PyramidMatchGraphKernel <: AbstractGraphKernel
4+
5+
A graph kernel that compares the vertex embeddings of two graphs using pyramid match.
6+
7+
For each graph, the vertices are mapped into a hypercube of dimension `embedding_dim`.
8+
The hypercube is then split into buckets of a certain granularity, controlled by `histogram_levels`,
9+
and the numbers of vertices that fall into each bucket are compared between the two graphs.
10+
If vertex labels are provided, only vertices with the same label in a bucket are compared.
11+
12+
# Keywords
13+
- `embedding_dim=6`: The number of dimensions of the space where the vertices are embedded.
14+
This is an upper bound - in some cases only an embedding of lower dimension is created.
15+
- `histogram_levels=4`: The number of levels for which histograms are compared. For
16+
each level `l ∈ 0:histogram_levels`, the embeddings are split into 2^l buckets and
17+
the number of vertices in each bucket is compared separately.
18+
- `vertex_labels`: By which vertex labels the vertices are grouped together. The embedding
19+
is created on all vertices, but then the histograms are separately compared for each
20+
group of vertex labels. Finally the results of each group are added together.
21+
Either `Auto`, `:` (all), or a tuple of vertex value keys.
22+
23+
# References
24+
<https://ojs.aaai.org/index.php/AAAI/article/view/10839>
425
"""
5-
struct PyramidMatchGraphKernel <: AbstractGraphKernel
26+
struct PyramidMatchGraphKernel{VL <: LabelsType} <: AbstractGraphKernel
627

728
embedding_dim::Int
829
histogram_levels::Int
30+
vertex_labels::VL
931

10-
function PyramidMatchGraphKernel(;embedding_dim::Integer=6, histogram_levels::Integer=4)
32+
function PyramidMatchGraphKernel(;embedding_dim::Integer=6, histogram_levels::Integer=4, vertex_labels=Auto())
1133

1234
embedding_dim >= 1 || throw(DomainError(embedding_dim, "embedding_dim must be >= 1"))
1335
histogram_levels >= 0 || throw(DomainError(histogram_levels, "histogram_levels must be >= 0"))
1436

15-
return new(embedding_dim, histogram_levels)
37+
return new{typeof(vertex_labels)}(Int(embedding_dim), Int(histogram_levels), vertex_labels)
1638
end
1739
end
1840

1941
function preprocessed_form(kernel::PyramidMatchGraphKernel, g::AbstractGraph)
2042

21-
d = kernel.embedding_dim
22-
L = kernel.histogram_levels
43+
embedding = _embedding(g, kernel.embedding_dim)
44+
d = size(embedding, 2) # d might sometimes be smaller than embedding_dim
2345

24-
# TODO consider using a 3 dimensional SparseArray instead
25-
# or just calculate the whole histogram on the fly
46+
vertex_labels = _labels(kernel.vertex_labels, vertexvals_type(g))
47+
48+
# TODO, we should deduce the correct type instead of using Any
49+
vertex_class_sizes = counter(Any)
50+
for v in vertices(g)
51+
class_key = tuple((get_vertexval(g, v, i) for i in vertex_labels)...)
52+
inc!(vertex_class_sizes, class_key)
53+
end
2654

27-
# nv(g) x embedding_dim matrix
28-
embedding = _embedding(g, d)
55+
# The rows of embedding correspond to vertices. We group vertices together
56+
# by their labels and associate a sub matrix of the corresponding rows
57+
# with each label class
58+
insert_row_index = Dict{Any, Int}()
59+
points_by_vertex_class = Dict{Any, Matrix{Float64}}()
60+
for (k, n_rows) in vertex_class_sizes
61+
insert_row_index[k] = 0
62+
points_by_vertex_class[k] = Matrix{Float64}(undef, n_rows, d)
63+
end
2964

30-
# TODO consider inplace sort!
65+
for (v, row) in enumerate(eachrow(embedding))
66+
k = tuple((get_vertexval(g, v, i) for i in vertex_labels)...)
67+
row_index = (insert_row_index[k] += 1)
68+
points_by_vertex_class[k][row_index, :] = row
69+
end
3170

32-
# the j-th column of this matrix contains the j-th coordinates
33-
# of the embedding vectors sorted in ascending order
34-
ordered_points = sort(embedding; dims=1)
71+
for points in values(points_by_vertex_class)
72+
# the j-th column of this matrix contains the j-th coordinates
73+
# of the embedding vectors sorted in ascending order
74+
sort!(points; dims=1)
75+
end
3576

36-
return ordered_points
77+
return points_by_vertex_class
3778
end
3879

39-
function apply_preprocessed(kernel::PyramidMatchGraphKernel, points1, points2)
80+
function apply_preprocessed(kernel::PyramidMatchGraphKernel, points_by_class_1, points_by_class_2)
4081

41-
# TODO the embedding dim might actually be smaller for some graph
42-
# we should maybe consider some kind of scaling in such a case
43-
d = min(size(points1, 2), size(points2, 2))
4482
L = kernel.histogram_levels
45-
n1 = size(points1, 1)
46-
n2 = size(points2, 1)
47-
48-
hist_intersect = zeros(Int64, L + 1)
49-
50-
# TODO ensure no int overflow, same result on 32 bit platform
51-
52-
# TODO we don't need to store hist_intersect as a vector
53-
# TODO add @inbounds
54-
hist_intersect[1] = d * min(n1, n2)
55-
for l in 1:L
56-
cell_boundaries = range(0.0, 1.0, length=2^l + 1)
57-
for j in 1:d
58-
i1 = 1
59-
i2 = 1
60-
while i1 <= n1
61-
# TODO it is possible that searchsortedlast is not implemented
62-
# efficiently on a StepRangeLen
63-
cell_num = searchsortedlast(cell_boundaries, points1[i1, j])
64-
# TODO maybe we should verify here, that cell_num is a valid index
65-
cell_lower = cell_boundaries[cell_num]
66-
cell_upper = cell_boundaries[cell_num + 1]
67-
# TODO maybe we need some correction for rounding errors here
68-
69-
i1 += 1
70-
num1 = 1
71-
# the first loop is just for safety we probably don't need it
72-
while i1 <= n1 && points1[i1, j] < cell_lower
73-
i1 += 1
74-
end
75-
# count number of vertices from g1 that fall into the bucket
76-
# specified [cell_lower, cell_upper) along the j-th dimension
77-
# of the hypercube
78-
while i1 <= n1 && points1[i1, j] < cell_upper
79-
num1 += 1
80-
i1 += 1
81-
end
83+
result = 0.0
84+
85+
for (class, points1) in points_by_class_1
86+
87+
points2 = get(points_by_class_2, class, nothing)
88+
points2 === nothing && continue
89+
90+
# TODO the embedding dim might actually be smaller for some graph
91+
# we should maybe consider some kind of scaling in such a case
92+
d = min(size(points1, 2), size(points2, 2))
93+
n1 = size(points1, 1)
94+
n2 = size(points2, 1)
95+
96+
hist_intersect = zeros(Int64, L + 1)
97+
98+
# TODO ensure no int overflow, same result on 32 bit platform
99+
100+
# TODO we don't need to store hist_intersect as a vector
101+
# TODO add @inbounds
102+
hist_intersect[1] = d * min(n1, n2)
103+
for l in 1:L
104+
cell_boundaries = range(0.0, 1.0, length=2^l + 1)
105+
for j in 1:d
106+
i1 = 1
107+
i2 = 1
108+
while i1 <= n1
109+
# TODO it is possible that searchsortedlast is not implemented
110+
# efficiently on a StepRangeLen
111+
cell_num = searchsortedlast(cell_boundaries, points1[i1, j])
112+
# TODO maybe we should verify here, that cell_num is a valid index
113+
cell_lower = cell_boundaries[cell_num]
114+
cell_upper = cell_boundaries[cell_num + 1]
115+
# TODO maybe we need some correction for rounding errors here
82116

83-
# count number of vertices from g2 that fall into that bucket
84-
# We could also use binary search here
85-
while i2 <= n2 && points2[i2, j] < cell_lower
86-
i2 += 1
87-
end
88-
num2 = 0
89-
while i2 <= n2 && points2[i2, j] < cell_upper
90-
num2 += 1
91-
i2 += 1
117+
i1 += 1
118+
num1 = 1
119+
# the first loop is just for safety; we probably don't need it
120+
while i1 <= n1 && points1[i1, j] < cell_lower
121+
i1 += 1
122+
end
123+
# count number of vertices from g1 that fall into the bucket
124+
# specified [cell_lower, cell_upper) along the j-th dimension
125+
# of the hypercube
126+
while i1 <= n1 && points1[i1, j] < cell_upper
127+
num1 += 1
128+
i1 += 1
129+
end
130+
131+
# count number of vertices from g2 that fall into that bucket
132+
# We could also use binary search here
133+
while i2 <= n2 && points2[i2, j] < cell_lower
134+
i2 += 1
135+
end
136+
num2 = 0
137+
while i2 <= n2 && points2[i2, j] < cell_upper
138+
num2 += 1
139+
i2 += 1
140+
end
141+
hist_intersect[l + 1] += min(num1, num2)
92142
end
93-
hist_intersect[l + 1] += min(num1, num2)
94143
end
95144
end
96-
end
97145

98-
return hist_intersect[L + 1] + sum(l -> (hist_intersect[l+1] - hist_intersect[l+2]) / 2^(L - l), 0:L-1)
146+
result += (hist_intersect[L + 1] + sum(l -> (hist_intersect[l+1] - hist_intersect[l+2]) / 2^(L - l), 0:L-1))
147+
148+
end
149+
return result
99150
end
100151

101152

0 commit comments

Comments
 (0)