Skip to content

Commit eed1aa4

Browse files
Make kernels subtypes of KernelFunctions.Kernel
1 parent 6cab505 commit eed1aa4

File tree

5 files changed

+20
-60
lines changed

5 files changed

+20
-60
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.1.0"
55

66
[deps]
77
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
8+
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
89
LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
910
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -16,6 +17,7 @@ ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
1617

1718
[compat]
1819
DataStructures = "0.18.9"
20+
KernelFunctions = "0.10.5"
1921
LIBSVM = "0.6"
2022
LightGraphs = "1.3"
2123
SimpleValueGraphs = "0.3"

src/GraphKernels.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ using ThreadsX
1111

1212
import LIBSVM: svmtrain, svmpredict
1313

14+
using KernelFunctions: kernelmatrix, kernelmatrix_diag
15+
import KernelFunctions
16+
1417
export
1518
AbstractGraphKernel,
1619

@@ -19,17 +22,16 @@ export
1922
PyramidMatchGraphKernel,
2023
WeisfeilerLehmanGraphKernel,
2124

22-
NormalizeGraphKernel,
23-
2425
ConstVertexKernel,
2526
DiracVertexKernel,
2627
DotVertexKernel,
2728

29+
k_fold_cross_validation,
30+
31+
# reexport from KernelFunctions
2832
kernelmatrix,
2933
kernelmatrix_diag,
3034

31-
k_fold_cross_validation,
32-
3335
# overridden methods from LIBSVM
3436
svmtrain,
3537
svmpredict
@@ -39,7 +41,6 @@ include("replacedvertexvals.jl")
3941
include("vertex_kernels.jl")
4042
include("graph-kernels/abstract-graph-kernel.jl")
4143
include("graph-kernels/baseline-graph-kernel.jl")
42-
include("graph-kernels/normalize-graph-kernel.jl")
4344
include("graph-kernels/pyramid-match-graph-kernel.jl")
4445
include("graph-kernels/shortest-path-graph-kernel.jl")
4546
include("graph-kernels/weisfeiler-lehman-graph-kernel.jl")

src/graph-kernels/abstract-graph-kernel.jl

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# ================================================================
44

55
"""
6-
abstract type AbstractGraphKernel
6+
abstract type AbstractGraphKernel <: KernelFunctions.Kernel
77
88
A kernel function between two graphs.
99
@@ -17,10 +17,10 @@ that transforms a single graph into a suitable representation and `apply_preproc
1717
takes the representations for both graphs and calculates the kernel function.
1818
1919
### See also
20-
[`preprocessed_form`](@ref), [`apply_preprocessed`](@ref), [`kernel_matrix`](@ref), [`kernel_matrix_diag`](@ref)
20+
[`preprocessed_form`](@ref), [`apply_preprocessed`](@ref), [`KernelFunctions.Kernel`](@ref), [`kernel_matrix`](@ref), [`kernel_matrix_diag`](@ref)
2121
2222
"""
23-
abstract type AbstractGraphKernel end
23+
abstract type AbstractGraphKernel <: KernelFunctions.Kernel end
2424

2525
"""
2626
preprocessed_form(k::AbstractGraphKernel, g::AbstractGraph) = g
@@ -39,7 +39,7 @@ method.
3939
"""
4040
preprocessed_form(::AbstractGraphKernel, g::AbstractGraph) = g
4141

42-
function (kernel::AbstractGraphKernel)(g1, g2)
42+
function (kernel::AbstractGraphKernel)(g1::AbstractGraph, g2::AbstractGraph)
4343

4444
return apply_preprocessed(kernel, preprocessed_form(kernel, g1), preprocessed_form(kernel, g2))
4545
end
@@ -52,7 +52,7 @@ function _map_preprocessed_form(kernel::AbstractGraphKernel, graphs)
5252

5353
# TODO we should be able to avoid collecting the graphs
5454
# but currently ThreadX cannot split them otherwise,
55-
# maybe we can create some wrapper type that is splitable around graphs
55+
# maybe we can create some wrapper type that is splittable around graphs
5656
return ThreadsX.map(g -> preprocessed_form(kernel, g), collect(graphs))
5757
end
5858

@@ -63,7 +63,7 @@ Return a matrix of running the kernel on all pairs of graphs.
6363
### See also
6464
[`kernelmatrix_diag`](@ref)
6565
"""
66-
function kernelmatrix(kernel::AbstractGraphKernel, graphs)
66+
function KernelFunctions.kernelmatrix(kernel::AbstractGraphKernel, graphs::AbstractVector)
6767

6868
pre = _map_preprocessed_form(kernel, graphs)
6969

@@ -92,15 +92,7 @@ function _kernelmatrix_from_preprocessed(kernel, pre)
9292
return G
9393
end
9494

95-
"""
96-
kernelmatrix_diag(kernel::AbstractGraphKernel, graphs)
97-
98-
Calculate the diagonal of the kernelmatrix matrix of the graphs.
99-
100-
### See also
101-
[`kernelmatrix`](@ref)
102-
"""
103-
function kernelmatrix_diag(kernel::AbstractGraphKernel, graphs)
95+
function KernelFunctions.kernelmatrix_diag(kernel::AbstractGraphKernel, graphs::AbstractVector)
10496

10597
n = length(graphs)
10698
pre = _map_preprocessed_form(kernel, graphs)
@@ -112,13 +104,8 @@ function kernelmatrix_diag(kernel::AbstractGraphKernel, graphs)
112104
return D
113105
end
114106

115-
"""
116-
kernelmatrix(kernel::AbstractGraphKernel, graphs1, graphs2)
117107

118-
Calculate a matrix of invoking the kernel on all pairs.
119-
Entry `(i, j)` of the resulting matrix contains `kernel(graphs1[i], graphs2[j]`.
120-
"""
121-
function kernelmatrix(kernel::AbstractGraphKernel, graphs1, graphs2)
108+
function KernelFunctions.kernelmatrix(kernel::AbstractGraphKernel, graphs1::AbstractVector, graphs2::AbstractVector)
122109

123110
n_rows = length(graphs1)
124111
n_cols = length(graphs2)

src/graph-kernels/normalize-graph-kernel.jl

Lines changed: 0 additions & 32 deletions
This file was deleted.

src/vertex_kernels.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
# AbstractVertexKernel
55
# ================================================================
66

7-
abstract type AbstractVertexKernel end
7+
abstract type AbstractVertexKernel <: KernelFunctions.Kernel end
88

9-
function (kernel::AbstractVertexKernel)((g1, v1), (g2, v2))
9+
function (kernel::AbstractVertexKernel)(gv1::Tuple{AbstractGraph, Integer}, gv2::Tuple{AbstractGraph, Integer})
1010

11+
(g1, v1) = gv1
12+
(g2, v2) = gv2
1113
return kernel(g1, v1, g2, v2)
1214
end
1315

0 commit comments

Comments
 (0)