Skip to content

Commit 6ee69af

Browse files
Santos SafraoSantos Safrao
authored andcommitted
added Kernel Mutual Subspace Method
1 parent fddb9ae commit 6ee69af

File tree

5 files changed

+107
-71
lines changed

5 files changed

+107
-71
lines changed

examples/kmsm.asv

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
%% Load data
2+
testDataUsage = TestDataUsage.Subsets;
3+
n = 2;
4+
[trainData, testData, testLabels] = prepareData(testDataUsage, n);
5+
6+
%% Train model
7+
numDimReferenceSubspace = 10;
8+
numDimInputSubspace = 4;
9+
sigma = 1;
10+
11+
model = KMSM(trainData,...
12+
numDimReferenceSubspace,...
13+
numDimInputSubspace,...
14+
sigma,...
15+
testLabels);
16+
17+
%% Evaluate model
18+
modelEvaluation = model.evaluate(testData);
19+
modelEvaluation.printResults();

examples/kmsm.m

Lines changed: 15 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,19 @@
1-
%% Load Data
2-
clear;
3-
clc;
4-
load('TsukubaHandDigitsDataset24x24.mat')
5-
% Check if the variable 'trainData' and 'testData' do not exist in the workspace
6-
% if ~(exist('trainData', 'var') == 1 && exist('testData', 'var') == 1)
7-
% % If they don't exist, load the data from the .mat file
8-
% load('TsukubaHandDigitsDataset24x24.mat');
9-
% end
10-
% you can use the following code to convert the test data format from 3d to 4d
11-
testData = subsetTestData(testData, 2);
12-
% specific_class = 5;
13-
%% to accomodate for MATLAB indexing
14-
% specific_class = specific_class + 1;
15-
training_data = cvlNormalize(trainData);
16-
testing_data = cvlNormalize(testData);
17-
% testing_data = testData(:, :, specific_class);
18-
size_of_test_data = size(testing_data);
1+
%% Load data
2+
testDataUsage = TestDataUsage.Subsets;
3+
n = 2;
4+
[trainData, testData, testLabels] = prepareData(testDataUsage, n);
195

20-
% get number of elements of size_of_test_data
21-
array_size = numel(size_of_test_data);
22-
23-
if array_size == 4
24-
% do nothing
25-
num_sets = size_of_test_data(3);
26-
num_classes = size_of_test_data(4);
27-
elseif array_size == 3
28-
num_sets = 1;
29-
num_classes = size_of_test_data(3);
30-
else
31-
num_classes = 1;
32-
num_sets = 1;
33-
end
34-
35-
%% Train Model
36-
num_dim_reference_subspaces = 10;
37-
num_dim_input_subpaces = 5;
6+
%% Train model
7+
numDimReferenceSubspace = 10;
8+
numDimInputSubspace = 4;
389
sigma = 1;
3910

40-
reference_subspaces = cvlKernelBasisVector(training_data, num_dim_reference_subspaces, sigma);
41-
input_subspaces = cvlKernelBasisVector(testing_data, num_dim_input_subpaces, sigma);
42-
% save('reference_subspaces.mat', 'reference_subspaces');
43-
% reference_subspaces = reference_subspaces(:, :, 1);
44-
tic;
45-
%% Recognition Phase
46-
similarities = cvlKernelCanonicalAngles(training_data,reference_subspaces,...
47-
testing_data, input_subspaces, sigma);
48-
similarities = similarities(:, :, end, end);
49-
% End timing and display the elapsed time
50-
elapsedTime = toc;
51-
fprintf('The code block executed in %.5f seconds.\n', elapsedTime);
52-
model_evaluation = ModelEvaluation(similarities, generateLabels(num_classes, num_sets));
53-
54-
displayModelResults('Kernel Mutual Subspace Methods', model_evaluation);
11+
model = KMSM(trainData,...
12+
numDimReferenceSubspace,...
13+
numDimInputSubspace,...
14+
sigma,...
15+
testLabels);
5516

56-
%% Print preditions
57-
% disp(model_evaluation.predicted_labels);
58-
% disp(model_evaluation.true_labels);
59-
% disp(similarities)
60-
% plotSimilarities(similarities)
17+
%% Evaluate model
18+
modelEvaluation = model.evaluate(testData);
19+
modelEvaluation.printResults();

src/classes/@CMSM/CMSM.m

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
classdef CMSM
22
properties
33
name = 'Constrained Mutual Subspace Method';
4-
trainData;
54
generalizedDifferenceSubspace; % GDS
65
referenceSubspaces;
76
numDimInputSubspace;
@@ -13,11 +12,11 @@
1312
function obj = CMSM(trainData, numDimReferenceSubspace, numDimInputSubspace, indexOfEigsToKeep, labels)
1413
numDim = size(trainData, 1);
1514
numClasses = size(trainData, 3);
16-
obj.trainData = trainData;
15+
trainData = cvlNormalize(trainData);
1716
obj.numDimReferenceSubspace = numDimReferenceSubspace;
1817
obj.numDimInputSubspace = numDimInputSubspace;
1918
obj.trueTestLabels = labels;
20-
subspaces = cvlBasisVector(obj.trainData,...
19+
subspaces = cvlBasisVector(trainData,...
2120
obj.numDimReferenceSubspace);
2221
P = zeros(numDim, numDim);
2322
for I=1:numClasses
@@ -61,6 +60,7 @@
6160
function scores = getSimilarityScores(obj, testData)
6261
testDatasize = size(testData);
6362
testDatasizeNumElements = numel(testDatasize);
63+
testData = cvlNormalize(testData);
6464

6565
if testDatasizeNumElements == 4
6666
numSets = testDatasize(3);

src/classes/@KMSM/KMSM.m

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
classdef KMSM
2+
properties
3+
name = 'Kernel Mutual Subspace Method';
4+
trainData;
5+
referenceSubspaces;
6+
numDimInputSubspace;
7+
numDimReferenceSubspace;
8+
sigma;
9+
trueTestLabels;
10+
end
11+
12+
methods
13+
function obj = KMSM(trainData, numDimReferenceSubspace, numDimInputSubspace, sigma, labels)
14+
obj.trainData = cvlNormalize(trainData);
15+
obj.sigma = sigma;
16+
obj.numDimReferenceSubspace = numDimReferenceSubspace;
17+
obj.numDimInputSubspace = numDimInputSubspace;
18+
obj.trueTestLabels = labels;
19+
obj.referenceSubspaces = cvlKernelBasisVector(obj.trainData,...
20+
obj.numDimReferenceSubspace,...
21+
obj.sigma);
22+
end
23+
24+
25+
% Returns the predicted labels for the test data
26+
function prediction = predict(obj, testData)
27+
similarityScores = obj.getSimilarityScores(testData);
28+
eval = ModelEvaluation(similarityScores, obj.trueTestLabels, obj.name);
29+
prediction = eval.predicted_labels;
30+
end
31+
32+
% Returns the similarity scores for the test data (same as probabilities)
33+
function probabilities = predictProb(obj, testData)
34+
probabilities = obj.getSimilarityScores(testData);
35+
end
36+
37+
% Returns the evaluation object for the test data
38+
function eval = evaluate(obj, testData)
39+
similarityScores = obj.getSimilarityScores(testData);
40+
eval = ModelEvaluation(similarityScores, obj.trueTestLabels, obj.name);
41+
end
42+
end
43+
44+
methods (Access = private)
45+
function scores = getSimilarityScores(obj, testData)
46+
testData = cvlNormalize(testData);
47+
inputSubspace = cvlKernelBasisVector(testData,...
48+
obj.numDimInputSubspace,...
49+
obj.sigma);
50+
similarities = cvlKernelCanonicalAngles(obj.trainData,...
51+
obj.referenceSubspaces,...
52+
testData,...
53+
inputSubspace,...
54+
obj.sigma);
55+
scores = similarities(:, :, end, end);
56+
end
57+
end
58+
end

src/classes/@MSM/MSM.m

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,50 @@
11
classdef MSM
22
properties
33
name = 'Mutual Subspace Method';
4-
trainData;
54
referenceSubspaces;
65
numDimInputSubspace;
76
numDimReferenceSubspace;
87
trueTestLabels;
98
end
10-
9+
1110
methods
1211
function obj = MSM(trainData, numDimReferenceSubspace, numDimInputSubspace, labels)
13-
obj.trainData = trainData;
12+
trainData = cvlNormalize(trainData);
1413
obj.numDimReferenceSubspace = numDimReferenceSubspace;
1514
obj.numDimInputSubspace = numDimInputSubspace;
1615
obj.trueTestLabels = labels;
17-
subspaces = cvlBasisVector(obj.trainData,...
18-
obj.numDimReferenceSubspace);
16+
subspaces = cvlBasisVector(trainData,...
17+
obj.numDimReferenceSubspace);
1918
obj.referenceSubspaces = subspaces;
2019
end
21-
22-
20+
21+
2322
% Returns the predicted labels for the test data
2423
function prediction = predict(obj, testData)
2524
similarityScores = obj.getSimilarityScores(testData);
2625
eval = ModelEvaluation(similarityScores, obj.trueTestLabels, obj.name);
2726
prediction = eval.predicted_labels;
2827
end
29-
28+
3029
% Returns the similarity scores for the test data (same as probabilities)
3130
function probabilities = predictProb(obj, testData)
3231
probabilities = obj.getSimilarityScores(testData);
3332
end
34-
33+
3534
% Returns the evaluation object for the test data
3635
function eval = evaluate(obj, testData)
3736
similarityScores = obj.getSimilarityScores(testData);
3837
eval = ModelEvaluation(similarityScores, obj.trueTestLabels, obj.name);
3938
end
4039
end
41-
40+
4241
methods (Access = private)
4342
function scores = getSimilarityScores(obj, testData)
43+
testData = cvlNormalize(testData);
4444
inputSubspace = cvlBasisVector(testData,...
45-
obj.numDimInputSubspace);
45+
obj.numDimInputSubspace);
4646
similarities = cvlCanonicalAngles(obj.referenceSubspaces,...
47-
inputSubspace);
47+
inputSubspace);
4848
scores = similarities(:, :, end, end);
4949
end
5050
end

0 commit comments

Comments
 (0)