Skip to content

Commit 3aa1e6d

Browse files
Bump LIBSVM version to v0.7
The svmtrain and svmpredict functions slightly changed in 0.7 for custom kernels, therefore some code had to be changed.
1 parent eed1aa4 commit 3aa1e6d

File tree

3 files changed

+7
-15
lines changed

3 files changed

+7
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
1818
[compat]
1919
DataStructures = "0.18.9"
2020
KernelFunctions = "0.10.5"
21-
LIBSVM = "0.6"
21+
LIBSVM = "0.7"
2222
LightGraphs = "1.3"
2323
SimpleValueGraphs = "0.3"
2424
ThreadsX = "0.1"

src/GraphKernels.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ include("integrations/LIBSVM.jl")
5555
5656
Simple k-fold cross validation implementation for quick testing during development.
5757
"""
58-
function k_fold_cross_validation(kernel::AbstractGraphKernel, graphs; k_folds=5, class_key=1, kwargs...)
58+
function k_fold_cross_validation(kernel::KernelFunctions.Kernel, graphs::AbstractVector{<:AbstractGraph}; k_folds=5, class_key=1, kwargs...)
5959

6060
n = length(graphs)
6161
indices = randperm(MersenneTwister(123), n)

src/integrations/LIBSVM.jl

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
11

22
struct GraphSVMModel
33
svm::LIBSVM.SVM
4-
kernel::AbstractGraphKernel
4+
kernel::KernelFunctions.Kernel
55
graphs::AbstractVector
66
end
77

88

9-
function svmtrain(graphs::AbstractVector{<:AbstractGraph}, labels, kernel::AbstractGraphKernel; kwargs...)
10-
11-
n = length(graphs)
12-
13-
X = vcat(transpose(1:n), kernelmatrix(kernel, graphs))
9+
function svmtrain(graphs::AbstractVector{<:AbstractGraph}, labels, kernel::KernelFunctions.Kernel; kwargs...)
1410

11+
X = kernelmatrix(kernel, graphs)
1512
svm = svmtrain(X, labels, kernel=Kernel.Precomputed; kwargs...)
1613

1714
return GraphSVMModel(svm, kernel, graphs)
@@ -22,12 +19,7 @@ function svmpredict(model::GraphSVMModel, unpredicted_graphs::AbstractVector{<:A
2219
graphs = model.graphs
2320
kernel = model.kernel
2421

25-
m = length(graphs)
26-
n = length(unpredicted_graphs)
27-
28-
X = Matrix{Float64}(undef, m + 1, n)
29-
X[1, :] = 1:n
30-
X[2:end, :] = kernelmatrix(kernel, graphs, unpredicted_graphs)
31-
22+
# TODO might only be necessary to do the calculations for support vectors
23+
X = kernelmatrix(kernel, graphs, unpredicted_graphs)
3224
return svmpredict(model.svm, X)[1] # for simplicity return only the labels for now
3325
end

0 commit comments

Comments
 (0)