Skip to content

Commit 585f297

Browse files
Add Weisfeiler-Lehman graph kernel (#8)
1 parent 352704b commit 585f297

File tree

3 files changed

+96
-0
lines changed

3 files changed

+96
-0
lines changed

src/GraphKernels.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ export
1717
BaselineGraphKernel,
1818
ShortestPathGraphKernel,
1919
PyramidMatchGraphKernel,
20+
WeisfeilerLehmanGraphKernel,
2021

2122
NormalizeGraphKernel,
2223

@@ -33,12 +34,14 @@ export
3334
svmtrain,
3435
svmpredict
3536

37+
include("replacedvertexvals.jl")
3638
include("vertex_kernels.jl")
3739
include("graph-kernels/abstract-graph-kernel.jl")
3840
include("graph-kernels/baseline-graph-kernel.jl")
3941
include("graph-kernels/normalize-graph-kernel.jl")
4042
include("graph-kernels/pyramid-match-graph-kernel.jl")
4143
include("graph-kernels/shortest-path-graph-kernel.jl")
44+
include("graph-kernels/weisfeiler-lehman-graph-kernel.jl")
4245
include("integrations/LIBSVM.jl")
4346

4447
# ================================================================
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
2+
"""
3+
WeisfeilerLehmanGraphKernel(;base_kernel, num_iterations=5) <: AbstractGraphKernel
4+
"""
5+
struct WeisfeilerLehmanGraphKernel{BK<:AbstractGraphKernel} <: AbstractGraphKernel
6+
7+
base_kernel::BK
8+
num_iterations::Int
9+
10+
function WeisfeilerLehmanGraphKernel(base_kernel::AbstractGraphKernel, num_iterations::Integer)
11+
12+
num_iterations >= 1 || throw(DomainError(num_iterations, "num_iterations must be >= 1"))
13+
return new{typeof(base_kernel)}(base_kernel, num_iterations)
14+
end
15+
end
16+
17+
function WeisfeilerLehmanGraphKernel(;base_kernel::AbstractGraphKernel, num_iterations::Integer=5)
18+
19+
return WeisfeilerLehmanGraphKernel(base_kernel, num_iterations)
20+
end
21+
22+
function preprocessed_form(kernel::WeisfeilerLehmanGraphKernel, g::AbstractGraph)
23+
24+
nvg = nv(g)
25+
num_iterations = kernel.num_iterations
26+
27+
vertexvals = Matrix{UInt}(undef, nvg, num_iterations)
28+
29+
for u in vertices(g)
30+
vertexvals[u, 1] = hash(get_vertexval(g, u, :))
31+
end
32+
33+
sort_buffer = Vector{UInt}(undef, Δ(g))
34+
for i in 2:num_iterations
35+
for u in vertices(g)
36+
deg_u = degree(g, u)
37+
for (j, v) in enumerate(neighbors(g, u))
38+
sort_buffer[j] = vertexvals[u, i-1]
39+
end
40+
sort!(@view sort_buffer[1:deg_u])
41+
h = vertexvals[u, i-1]
42+
for j in 1:deg_u
43+
h = hash(sort_buffer[j], h)
44+
end
45+
vertexvals[u, i] = h
46+
end
47+
end
48+
49+
return [preprocessed_form(kernel.base_kernel, ReplacedVertexVals(g, vertexvals[:, i])) for i in 1:num_iterations]
50+
end
51+
52+
function apply_preprocessed(kernel::WeisfeilerLehmanGraphKernel, pre1, pre2)
53+
54+
return sum(t -> apply_preprocessed(kernel.base_kernel, t[1], t[2]), zip(pre1, pre2))
55+
end
56+

src/replacedvertexvals.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
2+
struct ReplacedVertexVals{V, V_VAL, E_VALS, G_VALS, G <: AbstractValGraph} <: AbstractValGraph{V, Tuple{V_VAL}, E_VALS, G_VALS}
3+
4+
graph::G
5+
vertexvals::Vector{V_VAL}
6+
7+
function ReplacedVertexVals(graph::AbstractValGraph, vertexvals::Vector)
8+
9+
# TODO throw proper error
10+
@assert nv(graph) == length(vertexvals)
11+
12+
V = eltype(graph)
13+
V_VAL = eltype(vertexvals)
14+
E_VALS = edgevals_type(graph)
15+
G_VALS = graphvals_type(graph)
16+
G = typeof(graph)
17+
return new{V, V_VAL, E_VALS, G_VALS, G}(graph, vertexvals)
18+
end
19+
end
20+
21+
22+
SimpleValueGraphs.nv(g::ReplacedVertexVals) = nv(g.graph)
23+
SimpleValueGraphs.has_edge(g::ReplacedVertexVals, s::Integer, d::Integer) = has_edge(g.graph, s, d)
24+
SimpleValueGraphs.is_directed(::Type{<:ReplacedVertexVals{V, V_VAL, E_VALS, G_VALS, G}}) where {V, V_VAL, E_VALS, G_VALS, G} = is_directed(G)
25+
26+
SimpleValueGraphs.ne(g::ReplacedVertexVals) = ne(g.graph)
27+
28+
SimpleValueGraphs.get_edgeval(g::ReplacedVertexVals, s::Integer, d::Integer, i::Integer) = get_edgeval(g.graph, s, d, i)
29+
30+
function SimpleValueGraphs.get_vertexval(g::ReplacedVertexVals, v::Integer, i::Integer)
31+
32+
# TODO might verify that this is a correct vertex
33+
# TODO might verify that i is 1
34+
return g.vertexvals[v]
35+
end
36+
37+

0 commit comments

Comments
 (0)