Skip to content

Commit 352704b

Browse files
One file per graph kernel (#7)
1 parent 3d4509b commit 352704b

File tree

7 files changed

+366
-315
lines changed

7 files changed

+366
-315
lines changed

src/GraphKernels.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ using ThreadsX
1212
import LIBSVM: svmtrain, svmpredict
1313

1414
export
15+
AbstractGraphKernel,
16+
1517
BaselineGraphKernel,
1618
ShortestPathGraphKernel,
1719
PyramidMatchGraphKernel,
@@ -32,7 +34,11 @@ export
3234
svmpredict
3335

3436
include("vertex_kernels.jl")
35-
include("graph_kernels.jl")
37+
include("graph-kernels/abstract-graph-kernel.jl")
38+
include("graph-kernels/baseline-graph-kernel.jl")
39+
include("graph-kernels/normalize-graph-kernel.jl")
40+
include("graph-kernels/pyramid-match-graph-kernel.jl")
41+
include("graph-kernels/shortest-path-graph-kernel.jl")
3642
include("integrations/LIBSVM.jl")
3743

3844
# ================================================================
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# ================================================================
2+
# AbstractGraphKernel
3+
# ================================================================
4+
5+
"""
6+
abstract type AbstractGraphKernel
7+
8+
A kernel function between two graphs.
9+
10+
Subtypes of `AbstractGraphKernel` should implement `preprocessed_form` and `apply_preprocessed`.
11+
When `(k::AbstractGraphKernel)(g1, g2)` is invoked on two graphs, then
12+
```
13+
apply_preprocessed(k, preprocessed_form(k, g1), preprocessed_form(k, g2))
14+
```
15+
is called to calculate the kernel function. Therefore one should implement `preprocessed_form`
16+
that transforms a single graph into a suitable representation and `apply_preprocessed` that
17+
takes the representations for both graphs and calculates the kernel function.
18+
19+
### See also
20+
[`preprocessed_form`](@ref), [`apply_preprocessed`](@ref), [`kernel_matrix`](@ref), [`kernel_matrix_diag`](@ref)
21+
22+
"""
23+
abstract type AbstractGraphKernel end
24+
25+
"""
26+
preprocessed_form(k::AbstractGraphKernel, g::AbstractGraph) = g
27+
28+
Transform a graph `g` into a suitable form for a graph kernel `k`
29+
30+
When calculating a pairwise kernel matrix for multiple graphs, this preprocessed form
31+
allows us to calculate the transformation only once for each graph, so that we can cache
32+
the result. By default this simply returns `g` without any transformation.
33+
34+
When implementing a custom graph kernel, it might be a good idea to implement this
35+
method.
36+
37+
### See also
38+
[`AbstractGraphKernel`](@ref), [`apply_preprocessed`](@ref)
39+
"""
40+
preprocessed_form(::AbstractGraphKernel, g::AbstractGraph) = g
41+
42+
function (kernel::AbstractGraphKernel)(g1, g2)
43+
44+
return apply_preprocessed(kernel, preprocessed_form(kernel, g1), preprocessed_form(kernel, g2))
45+
end
46+
47+
## ---------------------------------------------------------------
48+
## kernelmatrix & kernelmatrix_diag
49+
## ---------------------------------------------------------------
50+
51+
function _map_preprocessed_form(kernel::AbstractGraphKernel, graphs)
52+
53+
# TODO we should be able to avoid collecting the graphs
54+
# but currently ThreadX cannot split them otherwise,
55+
# maybe we can create some wrapper type that is splitable around graphs
56+
return ThreadsX.map(g -> preprocessed_form(kernel, g), collect(graphs))
57+
end
58+
59+
"""
60+
kernelmatrix(kernel, graphs)
61+
Return a matrix of running the kernel on all pairs of graphs.
62+
63+
### See also
64+
[`kernelmatrix_diag`](@ref)
65+
"""
66+
function kernelmatrix(kernel::AbstractGraphKernel, graphs)
67+
68+
pre = _map_preprocessed_form(kernel, graphs)
69+
70+
# this simply a guard to make the code more type save, maybe we can get
71+
# rid of it at some point
72+
return _kernelmatrix_from_preprocessed(kernel, pre)
73+
end
74+
75+
function _kernelmatrix_from_preprocessed(kernel, pre)
76+
77+
n = length(pre)
78+
79+
# TODO maybe we should make the matrix only symmetric afterwards
80+
# so that we avoid false sharing when using multiple threads
81+
# TODO create some triangle generator instead of allocating a vector
82+
# TODO apparently ThreadsX can do load balancing so we should consider that here
83+
G = Matrix{Float64}(undef, n, n)
84+
indices = [(i, j) for i in 1:n for j in i:n]
85+
Threads.@threads for idx in indices
86+
i, j = idx
87+
@inbounds v = apply_preprocessed(kernel, pre[i], pre[j])
88+
@inbounds G[i, j] = v
89+
@inbounds G[j, i] = v
90+
end
91+
92+
return G
93+
end
94+
95+
"""
96+
kernelmatrix_diag(kernel::AbstractGraphKernel, graphs)
97+
98+
Calculate the diagonal of the kernelmatrix matrix of the graphs.
99+
100+
### See also
101+
[`kernelmatrix`](@ref)
102+
"""
103+
function kernelmatrix_diag(kernel::AbstractGraphKernel, graphs)
104+
105+
n = length(graphs)
106+
pre = _map_preprocessed_form(kernel, graphs)
107+
108+
D = Vector{Float64}(undef, n)
109+
Threads.@threads for i in 1:n
110+
@inbounds D[i] = apply_preprocessed(kernel, pre[i], pre[i])
111+
end
112+
return D
113+
end
114+
115+
"""
116+
kernelmatrix(kernel::AbstractGraphKernel, graphs1, graphs2)
117+
118+
Calculate a matrix of invoking the kernel on all pairs.
119+
Entry `(i, j)` of the resulting matrix contains `kernel(graphs1[i], graphs2[j]`.
120+
"""
121+
function kernelmatrix(kernel::AbstractGraphKernel, graphs1, graphs2)
122+
123+
n_rows = length(graphs1)
124+
n_cols = length(graphs2)
125+
126+
M = Matrix{Float64}(undef, n_rows, n_cols)
127+
128+
pre1 = _map_preprocessed_form(kernel, graphs1)
129+
pre2 = _map_preprocessed_form(kernel, graphs2)
130+
131+
Threads.@threads for i in 1:n_rows
132+
for j in 1:n_cols
133+
@inbounds M[i, j] = apply_preprocessed(kernel, pre1[i], pre2[j])
134+
end
135+
end
136+
137+
return M
138+
end
139+
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#
2+
# very similar to the No-Graph Baseline Kernel from https://mlai.cs.uni-bonn.de/publications/schulz2019gem.pdf
3+
# but even simpler in that we do not consider any metadata
4+
struct BaselineGraphKernel <: AbstractGraphKernel end
5+
6+
function apply_preprocessed(::BaselineGraphKernel, g1, g2)
7+
8+
return exp(-( (ne(g1) - ne(g2))^2 + (Float64(nv(g1)) - Float64(nv(g2)))^2) )
9+
end
10+
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
2+
3+
"""
4+
NormalizeGraphKernel(k::AbstractGraphKernel) <: AbstractGraphKernel
5+
6+
Graph kernel `k̃` that wraps around a kernel `k` and scales the output such
7+
`k̃(g1, g2) = k(g1, g2) / sqrt(k(g1, g1), k(g2, g2)).
8+
"""
9+
struct NormalizeGraphKernel{IK<:AbstractGraphKernel} <: AbstractGraphKernel
10+
11+
inner_kernel::IK
12+
end
13+
14+
function preprocessed_form(kernel::NormalizeGraphKernel, g::AbstractGraph)
15+
16+
inner = kernel.inner_kernel
17+
pre_inner = preprocessed_form(inner, g)
18+
k_ii = apply_preprocessed(inner, pre_inner, pre_inner)
19+
return (k_ii, pre_inner)
20+
end
21+
22+
function apply_preprocessed(kernel::NormalizeGraphKernel, pre1, pre2)
23+
24+
inner = kernel.inner_kernel
25+
k_11, pre_inner1 = pre1
26+
k_22, pre_inner2 = pre2
27+
28+
k_12 = apply_preprocessed(inner, pre_inner1, pre_inner2)
29+
30+
return k_12 / sqrt(k_11 * k_22)
31+
end
32+
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
2+
"""
3+
PyramidMatchGraphKernel <: AbstractGraphKernel
4+
"""
5+
struct PyramidMatchGraphKernel <: AbstractGraphKernel
6+
7+
embedding_dim::Int
8+
histogram_levels::Int
9+
10+
function PyramidMatchGraphKernel(;embedding_dim::Integer=6, histogram_levels::Integer=4)
11+
12+
embedding_dim >= 1 || throw(DomainError(embedding_dim, "embedding_dim must be >= 1"))
13+
histogram_levels >= 0 || throw(DomainError(embedding_dim, "histogram_levels must be >= 0"))
14+
15+
return new(embedding_dim, histogram_levels)
16+
end
17+
end
18+
19+
function preprocessed_form(kernel::PyramidMatchGraphKernel, g::AbstractGraph)
20+
21+
d = kernel.embedding_dim
22+
L = kernel.histogram_levels
23+
24+
# TODO consider using a 3 dimensional SparseArray instead
25+
# or just calculate the whole histogram on the fly
26+
27+
# nv(g) x embedding_dim matrix
28+
embedding = _embedding(g, d)
29+
30+
# hist[l, d][i] contains number of points in the i-th bucket
31+
# for layer l and dimension d
32+
hists = Matrix{SparseVector{Int, Int}}(undef, L + 1, d)
33+
for l in 0:L
34+
for dd in 1:d
35+
counts = zeros(Int, 2^l)
36+
for i in 1:2^l
37+
lo = -1.0 + (i - 1) * 2 / 2^l
38+
hi = -1.0 + i * 2 / 2^l
39+
40+
for x in embedding[:, dd]
41+
if lo <= x < hi # this currently misses 1.0
42+
counts[i] += 1
43+
end
44+
end
45+
end
46+
hists[l + 1, dd] = sparsevec(counts)
47+
end
48+
end
49+
return hists
50+
end
51+
52+
function apply_preprocessed(kernel::PyramidMatchGraphKernel, hists1, hists2)
53+
54+
d = kernel.embedding_dim
55+
L = kernel.histogram_levels
56+
return _I(hists1, hists2, d, L + 1) +
57+
sum(l -> 1 / 2^(L - l) * (_I(hists1, hists2, d, l + 1) -
58+
_I(hists1, hists2, d, l + 2)), 0:L-1)
59+
end
60+
61+
62+
function _embedding(g::AbstractGraph, embedding_dim::Int)
63+
64+
n = nv(g)
65+
A = zeros(max(embedding_dim, n), max(embedding_dim, n))
66+
A[1:n, 1:n] = adjacency_matrix(g)
67+
# TODO should we scale?
68+
embedding = eigvecs(A; sortby=x -> -abs(x))[:, 1:embedding_dim]
69+
# theoretically clamping is not necessary, this is just a precaution
70+
# against rounding errors
71+
clamp!(embedding, -1.0, 1.0)
72+
return embedding
73+
end
74+
75+
function _I(hists1, hists2, embedding_dim, level)
76+
77+
return sum(dd -> sum(min.(hists1[level, dd], hists2[level, dd])), 1:embedding_dim)
78+
end
79+
80+
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
2+
using LinearAlgebra: eigvecs
3+
using SparseArrays: SparseVector, sparsevec
4+
5+
6+
"""
7+
ShortestPathGraphKernel <: AbstractGraphKernel
8+
9+
A graph kernel that compares two graphs `g` and `g'` by comparing all pairs
10+
of vertices `(u, v)` of the `g` and `(u', v')` of `g'` if their shortest distance
11+
is smaller than `tol`. In that case, the vertices `u` and `u'`, as well as `v`, `v'` are
12+
compared with `vertex_kernel`.
13+
14+
# Keywords
15+
- `tol=0.0`: Only pairs of vertices where the shortest distance is at most `tol` are
16+
compared.
17+
- `dist_key=:`: The key for the edge values to compute the shortest distance with. Can be either
18+
an `Integer` or a `Symbol` for a key to a specific edge value, `nothing` to use a default distance
19+
of `1` for each edge, or `:` in which case the default edge weight for that graph type
20+
is used.
21+
- `vertex_kernel=ConstVertexKernel(1.0)`: The kernel used to compare two vertices.
22+
23+
# References
24+
[Borgwardt, K. M., & Kriegel, H. P.: Shortest-path kernels on graphs](https://www.dbs.ifi.lmu.de/~borgward/papers/BorKri05.pdf)
25+
"""
26+
struct ShortestPathGraphKernel{VK <: AbstractVertexKernel} <: AbstractGraphKernel
27+
28+
tol::Float64
29+
vertex_kernel::VK
30+
dist_key::Union{Int, Symbol, Colon, Nothing}
31+
end
32+
33+
function ShortestPathGraphKernel(;tol=0.0, vertex_kernel=ConstVertexKernel(1.0), dist_key=Colon())
34+
35+
return ShortestPathGraphKernel(tol, vertex_kernel, dist_key)
36+
end
37+
38+
function preprocessed_form(kernel::ShortestPathGraphKernel, g::AbstractGraph)
39+
40+
dists = _make_dists(g, kernel.dist_key)
41+
42+
ds = map(t -> t.dist, dists)
43+
us = map(t -> t.u, dists)
44+
vs = map(t -> t.v, dists)
45+
46+
return (g=g, ds=ds, us=us, vs=vs)
47+
end
48+
49+
function apply_preprocessed(kernel::ShortestPathGraphKernel, pre1, pre2)
50+
51+
g1, ds1, us1, vs1 = pre1
52+
g2, ds2, us2, vs2 = pre2
53+
54+
# TODO there might be some issues with unsigned types here
55+
ε = kernel.tol
56+
vertex_kernel = kernel.vertex_kernel
57+
58+
result = 0.0
59+
60+
len1 = length(ds1)
61+
len2 = length(ds2)
62+
63+
i2 = 1
64+
@inbounds for i1 in Base.OneTo(length(ds1))
65+
d1 = ds1[i1]
66+
while i2 <= len2 && d1 > (ds2[i2] + ε)
67+
i2 += 1
68+
end
69+
j2 = i2
70+
while j2 <= len2 && ds2[j2] <= (d1 + ε)
71+
result += vertex_kernel(g1, us1[i1], g2, us2[j2])
72+
result += vertex_kernel(g1, vs1[i1], g2, vs2[j2])
73+
j2 += 1
74+
end
75+
end
76+
77+
return result
78+
79+
end
80+
81+
function _make_dists(g, dist_key)
82+
83+
dists = if dist_key === Colon()
84+
floyd_warshall_shortest_paths(g).dists
85+
elseif dist_key === nothing
86+
floyd_warshall_shortest_paths(g, LightGraphs.DefaultDistance(nv(g))).dists
87+
else
88+
floyd_warshall_shortest_paths(g, weights(g, dist_key)).dists
89+
end
90+
91+
verts = vertices(g)
92+
tm = typemax(eltype(dists))
93+
dists_list = [(dist=dists[u, v], u=u, v=v) for u in verts for v in verts if u != v && dists[u, v] != tm]
94+
sort!(dists_list, by=t->t.dist)
95+
return dists_list
96+
end
97+
98+

0 commit comments

Comments
 (0)