Skip to content

Commit f543eb1

Browse files
authored
Merge pull request #55 from bgauzere/master
patch for computation of treelet kernel: intersection does not induce…
2 parents e735e9a + 32e8ccf commit f543eb1

File tree

2 files changed

+7
-15
lines changed

2 files changed

+7
-15
lines changed

gklearn/kernels/treelet.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -452,20 +452,12 @@ def _kernel_do(self, canonkey1, canonkey2):
452452
kernel : float
453453
Treelet kernel between 2 graphs.
454454
"""
455-
keys = set(canonkey1.keys()) & set(
456-
canonkey2.keys()
457-
) # find same canonical keys in both graphs
455+
keys = set(canonkey1.keys()) | set(canonkey2.keys()) # find same canonical keys in both graphs
458456
if len(keys) == 0: # There is nothing in common...
459-
return 0
457+
return 0
460458

461-
vector1 = np.array(
462-
[(canonkey1[key] if (key in canonkey1.keys()) else 0) for key in
463-
keys]
464-
)
465-
vector2 = np.array(
466-
[(canonkey2[key] if (key in canonkey2.keys()) else 0) for key in
467-
keys]
468-
)
459+
vector1 = np.array([canonkey1.get(key,0) for key in keys])
460+
vector2 = np.array([canonkey2.get(key,0)for key in keys])
469461

470462
# vector1, vector2 = [], []
471463
# keys1, keys2 = canonkey1, canonkey2

gklearn/kernels/treeletKernel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,9 @@ def _treeletkernel_do(canonkey1, canonkey2, sub_kernel):
160160
kernel : float
161161
Treelet Kernel between 2 graphs.
162162
"""
163-
keys = set(canonkey1.keys()) & set(canonkey2.keys()) # find same canonical keys in both graphs
164-
vector1 = np.array([(canonkey1[key] if (key in canonkey1.keys()) else 0) for key in keys])
165-
vector2 = np.array([(canonkey2[key] if (key in canonkey2.keys()) else 0) for key in keys])
163+
keys = set(canonkey1.keys()) | set(canonkey2.keys()) # find union of canonical keys in both graphs
164+
vector1 = np.array([canonkey1.get(key,0) for key in keys])
165+
vector2 = np.array([canonkey2.get(key,0) for key in keys])
166166
kernel = sub_kernel(vector1, vector2)
167167
return kernel
168168

0 commit comments

Comments
 (0)