|
1 | 1 |
|
"""
    PyramidMatchGraphKernel <: AbstractGraphKernel

A graph kernel that compares vertex embeddings of graphs using pyramid match.

For each graph, the vertices are mapped into a hypercube of dimension `embedding_dim`.
The hypercube is then split into buckets of a certain granularity, controlled by
`histogram_levels`, and the number of vertices in a bucket for both graphs are compared.
If vertex labels are provided, only vertices with the same label in a bucket are compared.

# Keywords
- `embedding_dim=6`: The number of dimensions of the space where the vertices are embedded.
    This is an upper bound - in some cases only an embedding of lower dimension is created.
- `histogram_levels=4`: The number of levels for which histograms are compared. For
    each level `l ∈ 0:histogram_levels`, the embeddings are split into 2^l buckets and
    the number of vertices in that bucket are compared separately.
- `vertex_labels=Auto()`: By which vertex labels the vertices are grouped together. The
    embedding is created on all vertices, but then the histograms are separately compared
    for each group of vertex labels. Finally the results of each group are added together.
    Either `Auto`, `:` (all), or a tuple of vertex value keys.

# Throws
- `DomainError` if `embedding_dim < 1` or `histogram_levels < 0`.

# References
<https://ojs.aaai.org/index.php/AAAI/article/view/10839>
"""
struct PyramidMatchGraphKernel{VL <: LabelsType} <: AbstractGraphKernel

    embedding_dim::Int
    histogram_levels::Int
    vertex_labels::VL

    function PyramidMatchGraphKernel(;embedding_dim::Integer=6, histogram_levels::Integer=4, vertex_labels=Auto())

        embedding_dim >= 1 || throw(DomainError(embedding_dim, "embedding_dim must be >= 1"))
        # Report the value that actually violated the constraint.
        histogram_levels >= 0 || throw(DomainError(histogram_levels, "histogram_levels must be >= 0"))

        # The fields are `Int`; convert explicitly so any `Integer` subtype is accepted.
        return new{typeof(vertex_labels)}(Int(embedding_dim), Int(histogram_levels), vertex_labels)
    end
end
18 | 40 |
|
"""
    preprocessed_form(kernel::PyramidMatchGraphKernel, g::AbstractGraph)

Embed the vertices of `g` and group the embedded points by vertex label class.

Returns a `Dict` that maps each label class (the tuple of the selected vertex values)
to a `Matrix{Float64}` with one row per vertex of that class. Every column of such a
matrix is sorted independently in ascending order, so a row no longer corresponds to a
single vertex.
"""
function preprocessed_form(kernel::PyramidMatchGraphKernel, g::AbstractGraph)

    # The rows of embedding correspond to vertices. The number of columns
    # might sometimes be smaller than embedding_dim.
    embedding = _embedding(g, kernel.embedding_dim)

    vertex_labels = _labels(kernel.vertex_labels, vertexvals_type(g))

    # Collect, for each label class, the rows of the embedding that belong to it.
    # The class key is computed exactly once per vertex.
    # TODO, we should deduce the correct key type instead of using Any
    rows_by_vertex_class = Dict{Any, Vector{Int}}()
    for v in vertices(g)
        class_key = tuple((get_vertexval(g, v, i) for i in vertex_labels)...)
        push!(get!(() -> Int[], rows_by_vertex_class, class_key), v)
    end

    points_by_vertex_class = Dict{Any, Matrix{Float64}}()
    for (class_key, rows) in rows_by_vertex_class
        # Indexing with a vector of rows copies, so sorting in place
        # does not modify `embedding`.
        points = convert(Matrix{Float64}, embedding[rows, :])

        # the j-th column of this matrix contains the j-th coordinates
        # of the embedding vectors sorted in ascending order
        sort!(points; dims=1)

        points_by_vertex_class[class_key] = points
    end

    return points_by_vertex_class
end
38 | 79 |
|
"""
    apply_preprocessed(kernel::PyramidMatchGraphKernel, points_by_class_1, points_by_class_2)

Compute the pyramid match kernel value from two preprocessed graphs.

Both arguments map a vertex label class to a matrix of embedded points whose columns
are sorted in ascending order. For every class that occurs in both arguments the
pyramid match score of the two point sets is computed; the scores are added up.
Classes that occur in only one of the arguments contribute nothing.

Returns a `Float64`.
"""
function apply_preprocessed(kernel::PyramidMatchGraphKernel, points_by_class_1, points_by_class_2)

    L = kernel.histogram_levels
    result = 0.0

    for (class, points1) in points_by_class_1

        points2 = get(points_by_class_2, class, nothing)
        points2 === nothing && continue

        result += _pyramid_match_class_score(points1, points2, L)
    end
    return result
end

# Pyramid match score of two column-wise sorted point matrices for levels 0:L.
function _pyramid_match_class_score(points1, points2, L::Integer)

    # TODO the embedding dim might actually be smaller for some graph
    # we should maybe consider some kind of scaling in such a case
    d = min(size(points1, 2), size(points2, 2))
    n1 = size(points1, 1)
    n2 = size(points2, 1)

    # hist_intersect[l + 1] is the histogram intersection at level l
    hist_intersect = zeros(Int64, L + 1)

    # TODO ensure no int overflow, same result on 32 bit platform
    # TODO add @inbounds

    # At level 0 every dimension consists of a single cell containing all points.
    hist_intersect[1] = d * min(n1, n2)
    for l in 1:L
        cell_boundaries = range(0.0, 1.0, length=2^l + 1)
        for j in 1:d
            i1 = 1
            i2 = 1
            while i1 <= n1
                # TODO it is possible that searchsortedlast is not implemented
                # efficiently on a StepRangeLen
                cell_num = searchsortedlast(cell_boundaries, points1[i1, j])
                # TODO maybe we should verify here, that cell_num is a valid index
                cell_lower = cell_boundaries[cell_num]
                cell_upper = cell_boundaries[cell_num + 1]
                # TODO maybe we need some correction for rounding errors here

                # the point at i1 itself lies in the cell
                i1 += 1
                num1 = 1
                # the first loop is just for safety we probably don't need it
                while i1 <= n1 && points1[i1, j] < cell_lower
                    i1 += 1
                end
                # count number of points from points1 that fall into the bucket
                # [cell_lower, cell_upper) along the j-th dimension
                # of the hypercube
                while i1 <= n1 && points1[i1, j] < cell_upper
                    num1 += 1
                    i1 += 1
                end

                # count number of points from points2 that fall into that bucket
                # We could also use binary search here
                while i2 <= n2 && points2[i2, j] < cell_lower
                    i2 += 1
                end
                num2 = 0
                while i2 <= n2 && points2[i2, j] < cell_upper
                    num2 += 1
                    i2 += 1
                end
                hist_intersect[l + 1] += min(num1, num2)
            end
        end
    end

    # Matches that are new at a coarser level l are weighted by 1 / 2^(L - l).
    # An explicit loop is used instead of `sum(f, 0:L-1)`, because reducing an
    # empty range with an anonymous function throws when L == 0.
    coarser_levels = 0.0
    for l in 0:L-1
        coarser_levels += (hist_intersect[l + 1] - hist_intersect[l + 2]) / 2^(L - l)
    end

    return hist_intersect[L + 1] + coarser_levels
end
100 | 151 |
|
101 | 152 |
|
|
0 commit comments