Skip to content

Commit fddb9ae

Browse files
Santos SafraoSantos Safrao
authored andcommitted
added the CMSM class
1 parent 72f82d0 commit fddb9ae

File tree

9 files changed

+149
-174
lines changed

9 files changed

+149
-174
lines changed

examples/cmsm.m

Lines changed: 16 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,107 +1,23 @@
1-
%% Load Data
2-
clear;
3-
clc;
4-
% load('TsukubaHandDigitsDataset.mat')
5-
% Check if the variable 'trainData' and 'testData' do not exist in the workspace
6-
if ~(exist('trainData', 'var') == 1 && exist('testData', 'var') == 1)
7-
% If they don't exist, load the data from the .mat file
8-
load('TsukubaHandDigitsDataset24x24.mat');
9-
end
10-
% you can use the following code to convert the test data format from 3d to 4d
11-
% testData = subsetTestData(testData, 4);
12-
specific_class = 1;
13-
class_num = 6;
14-
%% to accomodate for MATLAB indexing
15-
specific_class = specific_class + 1;
16-
training_data = trainData;
17-
% testing_data = testData;
18-
testing_data = testData(:, :, specific_class);
19-
size_of_test_data = size(testing_data);
20-
dim = size_of_test_data(1);
21-
del_subspace_dim = 3;
1+
%% Load data
2+
testDataUsage = TestDataUsage.SingleClass;
3+
n = 2;
4+
[trainData, testData, testLabels] = prepareData(testDataUsage, n);
225

23-
% get number of elements of size_of_test_data
24-
array_size = numel(size_of_test_data);
6+
%% Train model
7+
numDimReferenceSubspace = 10;
8+
numDimInputSubspace = 4;
9+
indexOfEigsToKeep = 3;
2510

26-
if array_size == 4
27-
% do nothing
28-
num_samples_per_set = size_of_test_data(2);
29-
num_sets = size_of_test_data(3);
30-
num_classes = size_of_test_data(4);
31-
elseif array_size == 3
32-
num_sets = 1;
33-
num_classes = size_of_test_data(3);
34-
else
35-
num_classes = 1;
36-
num_sets = 1;
37-
end
11+
model = CMSM(trainData,...
12+
numDimReferenceSubspace,...
13+
numDimInputSubspace,...
14+
indexOfEigsToKeep,...
15+
testLabels);
3816

39-
%% Train Model
40-
tic;
41-
num_dim_reference_subspaces = 10;
42-
num_dim_input_subpaces = 5;
17+
%% Evaluate model
18+
modelEvaluation = model.evaluate(testData);
19+
modelEvaluation.printResults();
4320

44-
reference_subspaces = cvlBasisVector(training_data, num_dim_reference_subspaces);
45-
input_subspaces = cvlBasisVector(testing_data, num_dim_input_subpaces);
4621

47-
save('reference_subspaces.mat', 'reference_subspaces');
48-
% reference_subspaces = reference_subspaces(:, :, 1);
49-
% Generalizated difference subspace(Constraint Subspace)
50-
P = zeros(dim, dim);
51-
for I=1:class_num
52-
P = P + reference_subspaces(:,:,I)*reference_subspaces(:,:,I)';
53-
end
54-
[B, C] = eig(P);
55-
C = diag(C);
56-
[~, index] = sort(C,'descend');
57-
B = B(:,index); C = C(index);
58-
difference = B(:,del_subspace_dim+1:rank(P))';
5922

60-
difference_subspace = zeros(size(difference,1), num_dim_reference_subspaces, size(reference_subspaces,3));
61-
for I=1:size(reference_subspaces,3)
62-
difference_subspace(:,:,I) = orth(difference*reference_subspaces(:,:,I));
63-
end
6423

65-
% process input difference subspace
66-
if array_size == 4
67-
input_difference_subspace = zeros(size(difference,1), num_dim_input_subpaces, num_sets, num_classes);
68-
for I=1:num_classes
69-
for J=1:num_sets
70-
input_difference_subspace(:,:,J,I) = orth(difference*input_subspaces(:,:,J,I));
71-
end
72-
end
73-
elseif array_size == 3
74-
input_difference_subspace = zeros(size(difference,1), num_dim_input_subpaces, num_classes);
75-
for I=1:num_classes
76-
input_difference_subspace(:,:,I) = orth(difference*input_subspaces(:,:,I));
77-
end
78-
else
79-
input_difference_subspace = orth(difference*input_subspaces);
80-
end
81-
82-
reference_subspaces= difference_subspace;
83-
input_subspaces = input_difference_subspace;
84-
85-
reference_subspaces = reference_subspaces(:, :, 1:3);
86-
87-
tic;
88-
%% Recognition Phase
89-
% convert reference and input subspace to cells
90-
reference_subspaces = mat2cell(reference_subspaces, size(reference_subspaces, 1), size(reference_subspaces, 2), ones(1, size(reference_subspaces, 3)));
91-
input_subspaces = mat2cell(input_subspaces, size(input_subspaces, 1), size(input_subspaces, 2), ones(1, size(input_subspaces, 3)));
92-
similarities = cvlCanonicalAngles(reference_subspaces, input_subspaces);
93-
similarities = similarities(:, :, end, end);
94-
% End timing and display the elapsed time
95-
elapsedTime = toc;
96-
fprintf('The code block executed in %.5f seconds.\n', elapsedTime);
97-
model_evaluation = ModelEvaluation(similarities, generateLabels(num_classes, num_sets, specific_class));
98-
99-
displayModelResults('Contained Mutual Subspace Methods', model_evaluation);
100-
101-
%% Print preditions
102-
disp(model_evaluation.predicted_labels);
103-
disp(model_evaluation.true_labels);
104-
disp(similarities);
105-
plotSimilarityMatrix(similarities, 'CMSM')
106-
% disp(similarities)
107-
% plotSimilarities(similarities)

examples/msm2.asv

Lines changed: 0 additions & 32 deletions
This file was deleted.

examples/msm2.m

Lines changed: 0 additions & 19 deletions
This file was deleted.

reference_subspaces.mat

0 Bytes
Binary file not shown.

src/classes/@CMSM/CMSM.m

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
classdef CMSM
2+
properties
3+
name = 'Constrained Mutual Subspace Method';
4+
trainData;
5+
generalizedDifferenceSubspace; % GDS
6+
referenceSubspaces;
7+
numDimInputSubspace;
8+
numDimReferenceSubspace;
9+
trueTestLabels;
10+
end
11+
12+
methods
13+
function obj = CMSM(trainData, numDimReferenceSubspace, numDimInputSubspace, indexOfEigsToKeep, labels)
14+
numDim = size(trainData, 1);
15+
numClasses = size(trainData, 3);
16+
obj.trainData = trainData;
17+
obj.numDimReferenceSubspace = numDimReferenceSubspace;
18+
obj.numDimInputSubspace = numDimInputSubspace;
19+
obj.trueTestLabels = labels;
20+
subspaces = cvlBasisVector(obj.trainData,...
21+
obj.numDimReferenceSubspace);
22+
P = zeros(numDim, numDim);
23+
for I=1:numClasses
24+
P = P + subspaces(:,:,I)*subspaces(:,:,I)';
25+
end
26+
[B, C] = eig(P);
27+
C = diag(C);
28+
[~, index] = sort(C,'descend');
29+
B = B(:,index);
30+
GDS = B(:,indexOfEigsToKeep+1:rank(P))';
31+
32+
differenceSubspaces = zeros(size(GDS,1), numDimReferenceSubspace, numClasses);
33+
for I=1:numClasses
34+
differenceSubspaces(:,:,I) = orth(GDS*subspaces(:,:,I));
35+
end
36+
obj.generalizedDifferenceSubspace = GDS;
37+
obj.referenceSubspaces = differenceSubspaces;
38+
end
39+
40+
41+
% Returns the predicted labels for the test data
42+
function prediction = predict(obj, testData)
43+
similarityScores = obj.getSimilarityScores(testData);
44+
eval = ModelEvaluation(similarityScores, obj.trueTestLabels, obj.name);
45+
prediction = eval.predicted_labels;
46+
end
47+
48+
% Returns the similarity scores for the test data (same as probabilities)
49+
function probabilities = predictProb(obj, testData)
50+
probabilities = obj.getSimilarityScores(testData);
51+
end
52+
53+
% Returns the evaluation object for the test data
54+
function eval = evaluate(obj, testData)
55+
similarityScores = obj.getSimilarityScores(testData);
56+
eval = ModelEvaluation(similarityScores, obj.trueTestLabels, obj.name);
57+
end
58+
end
59+
60+
methods (Access = private)
61+
function scores = getSimilarityScores(obj, testData)
62+
testDatasize = size(testData);
63+
testDatasizeNumElements = numel(testDatasize);
64+
65+
if testDatasizeNumElements == 4
66+
numSets = testDatasize(3);
67+
numClasses = testDatasize(4);
68+
elseif testDatasizeNumElements == 3
69+
numSets = 1;
70+
numClasses = testDatasize(3);
71+
end
72+
73+
subspace = cvlBasisVector(testData,...
74+
obj.numDimInputSubspace);
75+
GDSnumDim = size(obj.generalizedDifferenceSubspace, 1);
76+
if testDatasizeNumElements == 4
77+
differenceSubspace = zeros(GDSnumDim,...
78+
obj.numDimInputSubspace,...
79+
numSets,...
80+
numClasses);
81+
82+
for I=1:numClasses
83+
for J=1:numSets
84+
differenceSubspace(:,:,J,I) = orth(obj.generalizedDifferenceSubspace*subspace(:,:,J,I));
85+
end
86+
end
87+
elseif testDatasizeNumElements == 3
88+
differenceSubspace = zeros(GDSnumDim, obj.numDimInputSubspace, numClasses);
89+
for I=1:numClasses
90+
differenceSubspace(:,:,I) = orth(obj.generalizedDifferenceSubspace*subspace(:,:,I));
91+
end
92+
else
93+
differenceSubspace = orth(obj.generalizedDifferenceSubspace*subspace);
94+
end
95+
inputSubspace = differenceSubspace;
96+
similarities = cvlCanonicalAngles(obj.referenceSubspaces,...
97+
inputSubspace);
98+
scores = similarities(:, :, end, end);
99+
end
100+
end
101+
end
File renamed without changes.

src/utils/displayModelResults.m

Lines changed: 0 additions & 22 deletions
This file was deleted.

src/utils/subsetTestData.asv

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
function data = subsetTestData(testData, numSets)
2+
numDim = size(testData, 1); % Image dimensions
3+
numClasses = size(testData, 3); % Number of classes
4+
5+
% Initialize the new test data structure
6+
numSamplesPerClass = size(testData, 2);
7+
numSamplesPerSet = numSamplesPerClass / numSets;
8+
9+
% Check if numSamplesPerSet is a positive integer
10+
if ~isreal(numSamplesPerSet) || ~ispositive(numSamplesPerSet) || floor(numSamplesPerSet) ~= numSamplesPerSet
11+
error(['numSamplesPerSet is not a positive integer. The data cannot be evenly divided into ',...
12+
num2str(numSets), ' sets. Please adjust numSets to a value that allows equal division.']);
13+
end
14+
15+
data = reshape(testData, [numDim, numSamplesPerSet, numSets, numClasses]);
16+
end
17+
18+
function bool = ispositive(x)
19+
bool = x > 0;
20+
end

src/utils/subsetTestData.m

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,16 @@
55
% Initialize the new test data structure
66
numSamplesPerClass = size(testData, 2);
77
numSamplesPerSet = numSamplesPerClass / numSets;
8+
9+
% Check if numSamplesPerSet is a positive integer
10+
if ~isreal(numSamplesPerSet) || ~ispositive(numSamplesPerSet) || floor(numSamplesPerSet) ~= numSamplesPerSet
11+
error(['numSamplesPerSet is not a positive integer. The data cannot be evenly divided into ',...
12+
num2str(numSets), ' sets. Please adjust numSets to a value that allows equal division.']);
13+
end
14+
815
data = reshape(testData, [numDim, numSamplesPerSet, numSets, numClasses]);
9-
end
16+
end
17+
18+
function bool = ispositive(x)
19+
bool = x > 0;
20+
end

0 commit comments

Comments
 (0)