|
1 | | -# |
2 | | -# very similar to the No-Graph Baseline Kernel from https://mlai.cs.uni-bonn.de/publications/schulz2019gem.pdf |
3 | | -# but even simpler in that we do not consider any metadata |
4 | | -struct BaselineGraphKernel <: AbstractGraphKernel end |
5 | 1 |
|
6 | | -function apply_preprocessed(::BaselineGraphKernel, g1, g2) |
| 2 | +using DataStructures: counter, inc! |
7 | 3 |
|
8 | | - return exp(-( (ne(g1) - ne(g2))^2 + (Float64(nv(g1)) - Float64(nv(g2)))^2) ) |
# Accepted ways to select label keys for a kernel: `Auto` (keys are chosen from the
# label value types), `:` (every key), or an explicit tuple of integer or symbol keys.
const LabelsType = Union{
    Auto,
    Colon,
    Tuple{Vararg{Int}},
    Tuple{Vararg{Symbol}},
}
| 5 | + |
"""
    NoGraphBaselineGraphKernel(;vertex_labels=Auto(), edge_labels=Auto(), gamma=0.5)

A graph kernel that does not use structural information of the graph. Instead vertices and
edges are considered as sets and features are created by splitting these sets into classes
according to their labels. Then the kernel function is calculated by applying an RBF kernel
on these features.

This kernel provides a good baseline for verifying whether a graph kernel that uses
structural information is actually meaningful on some data.

# Keywords
- `vertex_labels`: Which vertex labels to consider. Either `Auto()`, `:` (all), or a tuple
  of vertex value keys.
- `edge_labels`: Which edge labels to consider. Either `Auto()`, `:` (all), or a tuple of
  edge value keys.
- `gamma`: Scaling factor for the RBF kernel. Must be positive and finite; it is stored as
  a `Float64`.

# Throws
- `DomainError` if `gamma` is not strictly positive and finite.

# Examples
The constructions below show, in order:
1. a custom `gamma` value;
2. no vertex and edge labels, in which case only the number of vertices and edges is
   considered;
3. all vertex labels, and the edge labels for the keys `:a` and `:b`;
4. suitable vertex and edge label keys inferred from the types of these labels.

```jldoctest
julia> k1 = NoGraphBaselineGraphKernel(gamma=2.0);

julia> k2 = NoGraphBaselineGraphKernel(vertex_labels=(), edge_labels=());

julia> k3 = NoGraphBaselineGraphKernel(vertex_labels=:, edge_labels=(:a, :b));

julia> k4 = NoGraphBaselineGraphKernel(vertex_labels=Auto(), edge_labels=Auto());
```

# References
<https://mlai.cs.uni-bonn.de/publications/schulz2019gem.pdf>
"""
struct NoGraphBaselineGraphKernel{VL <: LabelsType, EL <: LabelsType} <: AbstractGraphKernel

    vertex_labels::VL
    edge_labels::EL
    gamma::Float64

    function NoGraphBaselineGraphKernel(;vertex_labels=Auto(), edge_labels=Auto(), gamma=0.5)

        # Reject zero, negative, NaN and infinite scaling factors up front.
        (gamma > 0 && isfinite(gamma)) || throw(DomainError(gamma, "gamma must be 0 < gamma < ∞"))

        return new{typeof(vertex_labels), typeof(edge_labels)}(vertex_labels, edge_labels, Float64(gamma))
    end
end
| 52 | + |
# Reduce a graph to two counters: how many vertices, and how many edges, fall into each
# label class. A class is identified by the tuple of values stored under the selected
# label keys.
function preprocessed_form(kernel::NoGraphBaselineGraphKernel, g::AbstractValGraph)
    vertex_keys = _labels(kernel.vertex_labels, vertexvals_type(g))
    edge_keys = _labels(kernel.edge_labels, edgevals_type(g))

    # TODO, we should deduce the correct type instead of using Any
    vertex_counts = counter(Any)
    foreach(vertices(g)) do v
        # The selected keys form a tuple, so mapping over them yields the class tuple.
        inc!(vertex_counts, map(key -> get_vertexval(g, v, key), vertex_keys))
    end

    edge_counts = counter(Any)
    foreach(edges(g, :)) do e
        inc!(edge_counts, map(key -> get_edgeval(e, key), edge_keys))
    end

    return (vertex_class_sizes=vertex_counts, edge_class_sizes=edge_counts)
end
| 73 | + |
# `:` selects every key of the label type.
function _labels(::Colon, types)
    return fieldnames(types)
end
| 75 | + |
# A value type qualifies as a label when it is an integer, string, character or symbol type.
_is_suitable_label_type(T) = T <: Union{Integer, AbstractString, AbstractChar, Symbol}
| 80 | + |
# `Auto` keeps only those keys whose value type passes `_is_suitable_label_type`.
function _labels(::Auto, types)
    return filter(fieldnames(types)) do key
        _is_suitable_label_type(fieldtype(types, key))
    end
end
| 85 | + |
# An explicit tuple of keys is used exactly as given.
# TODO we could actually verify that the keys are valid
function _labels(keys::Tuple, types)
    return keys
end
| 88 | + |
# Sum of squared differences between the per-class counts of two counters.
# Some classes might occur in only one of the counters, so both are traversed: the first
# pass handles every class present in `counts_a`, the second adds the classes that are
# present in `counts_b` only. This way each class contributes exactly once.
function _squared_class_count_distance(counts_a, counts_b)
    total = 0
    for (class, n_a) in counts_a
        total += (n_a - counts_b[class])^2
    end
    for (class, n_b) in counts_b
        if counts_a[class] == 0
            total += n_b^2
        end
    end
    return total
end

# RBF kernel value for two preprocessed graphs, computed from the squared distance between
# their vertex class counts plus the squared distance between their edge class counts.
function apply_preprocessed(kernel::NoGraphBaselineGraphKernel, sizes1, sizes2)
    vertex_part = _squared_class_count_distance(
        sizes1.vertex_class_sizes, sizes2.vertex_class_sizes
    )
    edge_part = _squared_class_count_distance(
        sizes1.edge_class_sizes, sizes2.edge_class_sizes
    )

    return exp(-kernel.gamma * (vertex_part + edge_part))
end
10 | 119 |
|
0 commit comments