Skip to content

Commit 259a8b1

Browse files
Update Readme
* Add an example to README * Add wikipedia link to README
1 parent 5440e96 commit 259a8b1

File tree

1 file changed

+41
-1
lines changed

1 file changed

+41
-1
lines changed

README.md

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,44 @@
22

33
![](https://img.shields.io/badge/lifecycle-experimental-orange.svg)
44

5-
A Julia package for kernel functions on graphs.
5+
A Julia package for calculating [graph kernels](https://en.wikipedia.org/wiki/Graph_kernel) -
6+
kernel function where the inputs are graphs.
7+
8+
### Example
9+
10+
```julia
11+
julia> using GraphKernels: ShortestPathGraphKernel, svmtrain, svmpredict
12+
13+
julia> using GraphDatasets: loadgraphs, TUDatasets
14+
julia> using SimpleValueGraphs: get_graphval
15+
julia> using Random: shuffle
16+
julia> using Statistics: mean
17+
18+
# load the MUTAG dataset - it contains 188 graphs of two different classes
19+
julia> graphs = loadgraphs(TUDatasets.MUTAGDataset(); resolve_categories=true)
20+
188-element ValGraphCollection of graphs with
21+
eltype: Int8
22+
vertex value types: (chem = String,)
23+
edge value types: (bond_type = String,)
24+
graph value types: (class = Int8,)
25+
26+
# shuffle the graphs and split into train and test data
27+
julia> graphs = shuffle(graphs);
28+
julia> X_train, X_test = graphs[begin:120], graphs[121:end];
29+
julia> y_train, y_test = get_graphval.(X_train, :class), get_graphval.(X_test, :class);
30+
31+
# instantiate a ShortestPathGraphKernel
32+
# dist_key is set to nothing so that we use unit distances for all edges
33+
julia> kernel = ShortestPathGraphKernel(;dist_key=nothing)
34+
ShortestPathGraphKernel{ConstVertexKernel}(0.0, ConstVertexKernel(1.0), nothing)
35+
36+
# train a support vector machine with that kernel
37+
julia> model = svmtrain(X_train, y_train, kernel);
38+
39+
# predict classed on the test data
40+
julia> y_test_pred = svmpredict(model, X_test);
41+
42+
# compare with the actual classes and calculate the accuracy
43+
julia> accuracy = mean(y_test .== y_test_pred)
44+
0.8529411764705882
45+
```

0 commit comments

Comments
 (0)