Skip to content

Commit 218f54a

Browse files
Add vertex_labels keyword to WL kernel (#10)
1 parent c758a26 commit 218f54a

File tree

3 files changed

+55
-29
lines changed

3 files changed

+55
-29
lines changed

src/graph-kernels/baseline-graph-kernel.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11

22
using DataStructures: counter, inc!
33

4-
const LabelsType = Union{Auto, Colon, Tuple{Vararg{Int}}, Tuple{Vararg{Symbol}}}
54

65
"""
76
NoGraphBaselineGraphKernel(;vertex_labels=Auto(), edge_labels=Auto(), gamma=0.5)
@@ -71,20 +70,6 @@ function preprocessed_form(kernel::NoGraphBaselineGraphKernel, g::AbstractValGra
7170
return (vertex_class_sizes=vertex_class_sizes, edge_class_sizes=edge_class_sizes)
7271
end
7372

74-
_labels(::Colon, types) = fieldnames(types)
75-
76-
function _is_suitable_label_type(T)
77-
78-
return T <: Union{Integer, AbstractString, AbstractChar, Symbol}
79-
end
80-
81-
function _labels(::Auto, types)
82-
83-
return filter(i -> _is_suitable_label_type(fieldtype(types, i)), fieldnames(types))
84-
end
85-
86-
# TODO we could actually verify that the keys are valid
87-
_labels(keys::Tuple, types) = keys
8873

8974
function apply_preprocessed(kernel::NoGraphBaselineGraphKernel, sizes1, sizes2)
9075

src/graph-kernels/weisfeiler-lehman-graph-kernel.jl

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,52 @@
11

22
"""
3-
WeisfeilerLehmanGraphKernel(;base_kernel, num_iterations=5) <: AbstractGraphKernel
3+
WeisfeilerLehmanGraphKernel(;base_kernel, vertex_labels=Auto(), num_iterations=5) <: AbstractGraphKernel
4+
5+
A graph kernel based on the Weisfeiler-Lehman isomorphism test.
6+
7+
The kernel is run for multiple iterations on a graph, and adjacent vertex values for each vertex
8+
are collected and then hashed to a new vertex value. Each iteration generates a new graph with replaced
9+
vertex values. Afterwards `base_kernel` is run on all the new graphs and the results are summed up
10+
to generate a final value.
11+
12+
# Keywords
13+
- `base_kernel`: The kernel that is applied to the newly generated graphs.
14+
- `vertex_labels`: Which vertex labels to consider for generating the initial hash value of each vertex.
15+
Either `Auto`, `:` (all), or a tuple of vertex value keys.
16+
- `num_iterations`: For how many iterations we generate graphs with new vertex values.
17+
18+
# References
19+
https://www.jmlr.org/papers/volume12/shervashidze11a/shervashidze11a.pdf
420
"""
5-
struct WeisfeilerLehmanGraphKernel{BK<:AbstractGraphKernel} <: AbstractGraphKernel
21+
struct WeisfeilerLehmanGraphKernel{BK<:AbstractGraphKernel, VL <: LabelsType} <: AbstractGraphKernel
622

723
base_kernel::BK
24+
vertex_labels::VL
825
num_iterations::Int
926

10-
function WeisfeilerLehmanGraphKernel(base_kernel::AbstractGraphKernel, num_iterations::Integer)
27+
function WeisfeilerLehmanGraphKernel(;base_kernel::AbstractGraphKernel, vertex_labels=Auto(), num_iterations::Integer=5)
1128

1229
num_iterations >= 1 || throw(DomainError(num_iterations, "num_iterations must be >= 1"))
13-
return new{typeof(base_kernel)}(base_kernel, num_iterations)
30+
31+
return new{typeof(base_kernel), typeof(vertex_labels)}(base_kernel, vertex_labels, num_iterations)
1432
end
1533
end
1634

17-
function WeisfeilerLehmanGraphKernel(;base_kernel::AbstractGraphKernel, num_iterations::Integer=5)
18-
19-
return WeisfeilerLehmanGraphKernel(base_kernel, num_iterations)
20-
end
2135

36+
# TODO it might be useful to run the base_kernel (or a separate kernel) on the
37+
# initial graph without replaced vertex values. One might also consider keeping
38+
# a part of the initial vertex values for each graph.
2239
function preprocessed_form(kernel::WeisfeilerLehmanGraphKernel, g::AbstractGraph)
2340

2441
nvg = nv(g)
2542
num_iterations = kernel.num_iterations
2643

2744
vertexvals = Matrix{UInt}(undef, nvg, num_iterations)
28-
29-
for u in vertices(g)
30-
vertexvals[u, 1] = hash(get_vertexval(g, u, :))
31-
end
45+
_fill_initial_vertexvals!(vertexvals, g, kernel)
3246

3347
sort_buffer = Vector{UInt}(undef, Δ(g))
3448
for i in 2:num_iterations
35-
for u in vertices(g)
49+
@inbounds for u in vertices(g)
3650
deg_u = degree(g, u)
3751
for (j, v) in enumerate(neighbors(g, u))
3852
sort_buffer[j] = vertexvals[u, i-1]
@@ -46,7 +60,16 @@ function preprocessed_form(kernel::WeisfeilerLehmanGraphKernel, g::AbstractGraph
4660
end
4761
end
4862

49-
return [preprocessed_form(kernel.base_kernel, ReplacedVertexVals(g, vertexvals[:, i])) for i in 1:num_iterations]
63+
return [preprocessed_form(kernel.base_kernel, ReplacedVertexVals(g, @inbounds vertexvals[:, i])) for i in 1:num_iterations]
64+
end
65+
66+
function _fill_initial_vertexvals!(vertexvals, g, kernel)
67+
68+
vertex_labels = _labels(kernel.vertex_labels, vertexvals_type(g))
69+
for u in vertices(g)
70+
# TODO we can iteratively calculate the hash instead of creating a tuple
71+
@inbounds vertexvals[u, 1] = hash(tuple((get_vertexval(g, u, i) for i in vertex_labels)...))
72+
end
5073
end
5174

5275
function apply_preprocessed(kernel::WeisfeilerLehmanGraphKernel, pre1, pre2)

src/utils.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,21 @@ A singleton type to denote that some argument to a function or constructor shoul
66
be automatically determined.
77
"""
88
struct Auto end
9+
10+
const LabelsType = Union{Auto, Colon, Tuple{Vararg{Int}}, Tuple{Vararg{Symbol}}}
11+
12+
_labels(::Colon, types) = fieldnames(types)
13+
14+
function _is_suitable_label_type(T)
15+
16+
return T <: Union{Integer, AbstractString, AbstractChar, Symbol}
17+
end
18+
19+
function _labels(::Auto, types)
20+
21+
return filter(i -> _is_suitable_label_type(fieldtype(types, i)), fieldnames(types))
22+
end
23+
24+
# TODO we could actually verify that the keys are valid
25+
_labels(keys::Tuple, types) = keys
26+

0 commit comments

Comments
 (0)