11
"""
    WeisfeilerLehmanGraphKernel(; base_kernel, vertex_labels=Auto(), num_iterations=5) <: AbstractGraphKernel

A graph kernel based on the Weisfeiler-Lehman isomorphism test.

The kernel is run for multiple iterations on a graph. In each iteration, the values of
the vertices adjacent to each vertex are collected and then hashed to a new vertex value,
so each iteration generates a new graph with replaced vertex values. Afterwards
`base_kernel` is run on all the new graphs and the results are summed up to generate a
final value.

# Keywords
- `base_kernel`: The kernel that is applied to the newly generated graphs. Required.
- `vertex_labels=Auto()`: Which vertex labels to consider for generating the initial hash
  value of each vertex. Either `Auto`, `:` (all), or a tuple of vertex value keys.
- `num_iterations=5`: For how many iterations we generate graphs with new vertex values.
  Must be at least 1.

# Throws
- `DomainError`: if `num_iterations` is smaller than 1.

# References
https://www.jmlr.org/papers/volume12/shervashidze11a/shervashidze11a.pdf
"""
struct WeisfeilerLehmanGraphKernel{BK<:AbstractGraphKernel,VL<:LabelsType} <: AbstractGraphKernel

    # Kernel that is applied to every graph produced by the relabeling iterations.
    base_kernel::BK
    # Selector for the vertex labels that feed the initial per-vertex hash.
    vertex_labels::VL
    # Number of relabeling iterations; validated to be >= 1 by the constructor.
    num_iterations::Int

    function WeisfeilerLehmanGraphKernel(;
        base_kernel::AbstractGraphKernel, vertex_labels=Auto(), num_iterations::Integer=5
    )
        # The first column of per-vertex values is always filled during preprocessing,
        # so at least one iteration is required.
        num_iterations >= 1 || throw(DomainError(num_iterations, " num_iterations must be >= 1"))

        # `new` converts `num_iterations` to the declared `Int` field type.
        return new{typeof(base_kernel),typeof(vertex_labels)}(
            base_kernel, vertex_labels, num_iterations
        )
    end
end
1634
17- function WeisfeilerLehmanGraphKernel (;base_kernel:: AbstractGraphKernel , num_iterations:: Integer = 5 )
18-
19- return WeisfeilerLehmanGraphKernel (base_kernel, num_iterations)
20- end
2135
36+ # TODO it might be useful to run the base_kernel (or a separate kernel) on the
37+ # initial graph without replaced vertex values. One might also consider keeping
38+ # a part of the initial vertex values for each graph.
2239function preprocessed_form (kernel:: WeisfeilerLehmanGraphKernel , g:: AbstractGraph )
2340
2441 nvg = nv (g)
2542 num_iterations = kernel. num_iterations
2643
2744 vertexvals = Matrix {UInt} (undef, nvg, num_iterations)
28-
29- for u in vertices (g)
30- vertexvals[u, 1 ] = hash (get_vertexval (g, u, :))
31- end
45+ _fill_initial_vertexvals! (vertexvals, g, kernel)
3246
3347 sort_buffer = Vector {UInt} (undef, Δ (g))
3448 for i in 2 : num_iterations
35- for u in vertices (g)
49+ @inbounds for u in vertices (g)
3650 deg_u = degree (g, u)
3751 for (j, v) in enumerate (neighbors (g, u))
3852 sort_buffer[j] = vertexvals[u, i- 1 ]
@@ -46,7 +60,16 @@ function preprocessed_form(kernel::WeisfeilerLehmanGraphKernel, g::AbstractGraph
4660 end
4761 end
4862
49- return [preprocessed_form (kernel. base_kernel, ReplacedVertexVals (g, vertexvals[:, i])) for i in 1 : num_iterations]
63+ return [preprocessed_form (kernel. base_kernel, ReplacedVertexVals (g, @inbounds vertexvals[:, i])) for i in 1 : num_iterations]
64+ end
65+
# Fill the first column of `vertexvals` with one hash per vertex of `g`.
#
# The labels to read are obtained from `_labels(kernel.vertex_labels, vertexvals_type(g))`.
# For every vertex `u`, the values returned by `get_vertexval(g, u, key)` for each of those
# labels are gathered into a tuple, and the hash of that tuple is stored in
# `vertexvals[u, 1]`. The other columns of `vertexvals` are not touched.
#
# NOTE(review): indexing is done under `@inbounds`, so this presumably relies on
# `vertices(g)` being valid row indices of `vertexvals` - confirm against the caller.
function _fill_initial_vertexvals!(vertexvals, g, kernel)

    label_keys = _labels(kernel.vertex_labels, vertexvals_type(g))
    for u in vertices(g)
        # TODO we can iteratively calculate the hash instead of creating a tuple
        @inbounds begin
            selected_vals = tuple((get_vertexval(g, u, key) for key in label_keys)...)
            vertexvals[u, 1] = hash(selected_vals)
        end
    end
    return nothing
end
5174
5275function apply_preprocessed (kernel:: WeisfeilerLehmanGraphKernel , pre1, pre2)
0 commit comments