Skip to content

Commit b188629

Browse files
Refactor BaselineGraphKernel to NoGraphBaselineGraphKernel (#9)
1 parent 585f297 commit b188629

File tree

4 files changed

+127
-7
lines changed

4 files changed

+127
-7
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Simon Schoelly <sischoel@gmail.com> and contributors"]
44
version = "0.1.0"
55

66
[deps]
7+
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
78
LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
89
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -14,6 +15,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1415
ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
1516

1617
[compat]
18+
DataStructures = "0.18.9"
1719
LIBSVM = "0.6"
1820
LightGraphs = "1.3"
1921
SimpleValueGraphs = "0.3"

src/GraphKernels.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import LIBSVM: svmtrain, svmpredict
1414
export
1515
AbstractGraphKernel,
1616

17-
BaselineGraphKernel,
17+
NoGraphBaselineGraphKernel,
1818
ShortestPathGraphKernel,
1919
PyramidMatchGraphKernel,
2020
WeisfeilerLehmanGraphKernel,
@@ -34,6 +34,7 @@ export
3434
svmtrain,
3535
svmpredict
3636

37+
include("utils.jl")
3738
include("replacedvertexvals.jl")
3839
include("vertex_kernels.jl")
3940
include("graph-kernels/abstract-graph-kernel.jl")
Lines changed: 115 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,119 @@
1-
#
2-
# very similar to the No-Graph Baseline Kernel from https://mlai.cs.uni-bonn.de/publications/schulz2019gem.pdf
3-
# but even simpler in that we do not consider any metadata
4-
struct BaselineGraphKernel <: AbstractGraphKernel end
51

6-
function apply_preprocessed(::BaselineGraphKernel, g1, g2)
2+
using DataStructures: counter, inc!
73

8-
return exp(-( (ne(g1) - ne(g2))^2 + (Float64(nv(g1)) - Float64(nv(g2)))^2) )
4+
# Accepted ways of selecting label keys: an `Auto` instance, `:` (a `Colon`), a tuple of
# `Int`s, or a tuple of `Symbol`s. The empty tuple `()` matches both tuple alternatives.
const LabelsType = Union{Auto, Colon, Tuple{Vararg{Int}}, Tuple{Vararg{Symbol}}}
5+
6+
"""
    NoGraphBaselineGraphKernel(;vertex_labels=Auto(), edge_labels=Auto(), gamma=0.5)

A graph kernel that does not use structural information of the graph. Instead, vertices
and edges are considered as sets, and features are created by splitting these sets into
classes according to their labels. The kernel value is then calculated by applying an RBF
kernel to these features.

This kernel provides a good baseline for verifying whether a graph kernel that uses
structural information is meaningful on some data.

# Keywords
- `vertex_labels`: Which vertex labels to consider. Either `Auto()`, `:` (all), or a tuple
  of vertex value keys.
- `edge_labels`: Which edge labels to consider. Either `Auto()`, `:` (all), or a tuple of
  edge value keys.
- `gamma`: Scaling factor for the RBF kernel. Must be positive and finite.

# Throws
- `DomainError` if `gamma` is not a positive, finite number.

# Examples
The constructions below show, in order:
- a custom `gamma` value;
- no vertex and edge labels, in which case only the number of vertices and edges is
  considered;
- all vertex labels, together with the edge labels for the keys `:a` and `:b`;
- vertex and edge label keys that are inferred from the types of these labels.

```jldoctest
julia> k1 = NoGraphBaselineGraphKernel(gamma=2.0);

julia> k2 = NoGraphBaselineGraphKernel(vertex_labels=(), edge_labels=());

julia> k3 = NoGraphBaselineGraphKernel(vertex_labels=:, edge_labels=(:a, :b));

julia> k4 = NoGraphBaselineGraphKernel(vertex_labels=Auto(), edge_labels=Auto());
```

# References
<https://mlai.cs.uni-bonn.de/publications/schulz2019gem.pdf>
"""
struct NoGraphBaselineGraphKernel{VL<:LabelsType,EL<:LabelsType} <: AbstractGraphKernel

    vertex_labels::VL
    edge_labels::EL
    gamma::Float64

    function NoGraphBaselineGraphKernel(;vertex_labels=Auto(), edge_labels=Auto(), gamma=0.5)

        # Reject zero, negative, infinite and NaN values before converting to Float64.
        (gamma > 0 && isfinite(gamma)) || throw(DomainError(gamma, "gamma must be 0 < gamma < ∞"))

        # The type parameters are taken from the label selectors that were passed in.
        return new{typeof(vertex_labels), typeof(edge_labels)}(vertex_labels, edge_labels, Float64(gamma))
    end
end
52+
53+
# Reduce the graph `g` to two counters: how many vertices, and how many edges, fall into
# each label class. A class is identified by the tuple of values found under the selected
# label keys. The result is a named tuple with the fields `vertex_class_sizes` and
# `edge_class_sizes`.
function preprocessed_form(kernel::NoGraphBaselineGraphKernel, g::AbstractValGraph)

    # Resolve the label selectors stored in the kernel into concrete tuples of keys.
    vertex_keys = _labels(kernel.vertex_labels, vertexvals_type(g))
    edge_keys = _labels(kernel.edge_labels, edgevals_type(g))

    # TODO, we should deduce the correct type instead of using Any
    vertex_counts = counter(Any)
    edge_counts = counter(Any)

    for v in vertices(g)
        # `vertex_keys` is a tuple, so `map` yields the class key as a tuple as well.
        inc!(vertex_counts, map(key -> get_vertexval(g, v, key), vertex_keys))
    end

    for e in edges(g, :)
        inc!(edge_counts, map(key -> get_edgeval(e, key), edge_keys))
    end

    return (vertex_class_sizes=vertex_counts, edge_class_sizes=edge_counts)
end
73+
74+
# Selector `:` — every field name of the value type `types` is used as a label key.
function _labels(::Colon, types)
    return fieldnames(types)
end
75+
76+
# Return `true` when `T` is a subtype of `Integer`, `AbstractString`, `AbstractChar` or
# `Symbol`, and `false` otherwise.
_is_suitable_label_type(T) = T <: Union{Integer, AbstractString, AbstractChar, Symbol}
80+
81+
# Selector `Auto()` — keep only those field names of `types` whose field type passes
# `_is_suitable_label_type`.
function _labels(::Auto, types)
    return filter(fieldnames(types)) do key
        _is_suitable_label_type(fieldtype(types, key))
    end
end
85+
86+
# Explicit tuple of keys — returned unchanged; `types` is not consulted.
# TODO we could actually verify that the keys are valid
function _labels(keys::Tuple, types)
    return keys
end
88+
89+
# Sum of squared differences between the per-class counts stored in two counters.
#
# Some classes might occur in only one of the counters, so both are traversed, and each
# class contributes its squared difference exactly once. Indexing a counter with a class
# it does not contain is expected to yield 0.
function _squared_count_distance(counts1, counts2)

    total = 0
    # Every class stored in the first counter, whether or not the second one has it.
    for (class, count1) in counts1
        total += (count1 - counts2[class])^2
    end
    # Classes stored only in the second counter. Membership is tested instead of
    # comparing the count with 0, so that a class stored with an explicit zero count in
    # the first counter (already handled by the loop above) is not counted twice.
    for (class, count2) in counts2
        if !haskey(counts1, class)
            total += count2^2
        end
    end

    return total
end

# Evaluate the kernel on two preprocessed graphs. `sizes1` and `sizes2` must provide the
# fields `vertex_class_sizes` and `edge_class_sizes`. The result is
# `exp(-gamma * d)`, where `d` is the sum of the squared differences of all vertex class
# counts and all edge class counts.
function apply_preprocessed(kernel::NoGraphBaselineGraphKernel, sizes1, sizes2)

    squared_sum =
        _squared_count_distance(sizes1.vertex_class_sizes, sizes2.vertex_class_sizes) +
        _squared_count_distance(sizes1.edge_class_sizes, sizes2.edge_class_sizes)

    return exp(-kernel.gamma * squared_sum)
end
10119

src/utils.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
"""
    Auto

Singleton marker type. Passing `Auto()` as an argument to a function or constructor
denotes that the corresponding value should be determined automatically.
"""
struct Auto
end

0 commit comments

Comments
 (0)