Skip to content

Commit 72f82d0

Browse files
Santos SafraoSantos Safrao
authored andcommitted
created the MSM class
1 parent 2398bc8 commit 72f82d0

File tree

10 files changed

+219
-722
lines changed

10 files changed

+219
-722
lines changed

examples/msm.m

Lines changed: 14 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,20 @@
1-
%% Load Data
2-
clear;
3-
clc;
4-
load('TsukubaHandDigitsDataset24x24.mat')
5-
% Check if the variable 'trainData' and 'testData' do not exist in the workspace
6-
% if ~(exist('trainData', 'var') == 1 && exist('testData', 'var') == 1)
7-
% % If they don't exist, load the data from the .mat file
8-
% load('TsukubaHandDigitsDataset24x24.mat');
9-
% end
10-
% you can use the following code to convert the test data format from 3d to 4d
11-
testData = subsetTestData(testData, 2);
12-
% specific_class = 5;
13-
%% to accomodate for MATLAB indexing
14-
% specific_class = specific_class + 1;
15-
training_data = trainData;
16-
testing_data = testData;
17-
% testing_data = testData(:, :, specific_class);
18-
size_of_test_data = size(testing_data);
1+
%% Load data
2+
testDataUsage = TestDataUsage.WholeData;
3+
[trainData, testData, testLabels] = prepareData(testDataUsage);
194

20-
% get number of elements of size_of_test_data
21-
array_size = numel(size_of_test_data);
5+
%% Train model
6+
numDimReferenceSubspace = 10;
7+
numDimInputSubspace = 4;
228

23-
if array_size == 4
24-
% do nothing
25-
num_sets = size_of_test_data(3);
26-
num_classes = size_of_test_data(4);
27-
elseif array_size == 3
28-
num_sets = 1;
29-
num_classes = size_of_test_data(3);
30-
else
31-
num_classes = 1;
32-
num_sets = 1;
33-
end
9+
model = MSM(trainData,...
10+
numDimReferenceSubspace,...
11+
numDimInputSubspace,...
12+
testLabels);
13+
14+
%% Evaluate model
15+
modelEvaluation = model.evaluate(testData);
16+
modelEvaluation.printResults();
3417

35-
%% Train Model
36-
num_dim_reference_subspaces = 1;
37-
num_dim_input_subpaces = 1;
3818

39-
reference_subspaces = cvlBasisVector(training_data, num_dim_reference_subspaces);
40-
input_subspaces = cvlBasisVector(testing_data, num_dim_input_subpaces);
41-
% save('reference_subspaces.mat', 'reference_subspaces');
42-
% reference_subspaces = reference_subspaces(:, :, 1);
43-
tic;
44-
%% Recognition Phase
45-
similarities = cvlCanonicalAngles(reference_subspaces, input_subspaces);
46-
similarities = similarities(:, :, end, end);
47-
% End timing and display the elapsed time
48-
elapsedTime = toc;
49-
fprintf('The code block executed in %.5f seconds.\n', elapsedTime);
50-
model_evaluation = ModelEvaluation(similarities, generateLabels(num_classes, num_sets));
5119

52-
displayModelResults('Mutual Subspace Methods', model_evaluation);
5320

54-
%% Print preditions
55-
% disp(model_evaluation.predicted_labels);
56-
% disp(model_evaluation.true_labels);
57-
% disp(similarities)
58-
% plotSimilarities(similarities)

examples/msm2.asv

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
%% Load Data
2+
clear;
3+
clc;
4+
load('TsukubaHandDigitsDataset24x24.mat')
5+
testDataUsage = TestDataUsage.WholeData;
6+
switch testDataUsage
7+
case TestDataUsage.SingleClass
8+
% Code for testing on a single entry
9+
10+
case TestDataUsage.WholeData
11+
% Code for testing on a subset of data
12+
case TestDataUsage.Subsets
13+
% Code for testing on the whole dataset
14+
end
15+
% specific_class = 0;
16+
% specific_class_index = specific_class + 1;
17+
% testing_data = testData(:, :, specific_class_index);
18+
% test_labels = specific_class_index;
19+
20+
training_data = trainData;
21+
testData = subsetTestData(testData, 4);
22+
testing_data = testData;
23+
testLabels = generateLabels(testing_data);
24+
numDimReferenceSubspace = 10;
25+
numDimInputSubspace = 4;
26+
27+
model = MSM(trainData,...
28+
numDimReferenceSubspace, numDimInputSubspace, testLabels);
29+
mode = model.train();
30+
modelEvaluation = model.evaluate(testData);
31+
32+

examples/msm2.m

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
% testDataUsage = TestDataUsage.WholeData;
2+
testDataUsage = TestDataUsage.Subsets;
3+
numSets = 2;
4+
[trainData, testData, testLabels] = prepareData(testDataUsage, numSets);
5+
6+
numDimReferenceSubspace = 10;
7+
numDimInputSubspace = 4;
8+
9+
model = MSM(trainData,...
10+
numDimReferenceSubspace,...
11+
numDimInputSubspace,...
12+
testLabels);
13+
14+
modelEvaluation = model.evaluate(testData);
15+
modelEvaluation.printResults();
16+
17+
18+
19+

src/classes/@MSM/MSM.m

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
classdef MSM
2+
properties
3+
name = 'Mutual Subspace Method';
4+
trainData;
5+
referenceSubspaces;
6+
numDimInputSubspace;
7+
numDimReferenceSubspace;
8+
trueTestLabels;
9+
end
10+
11+
methods
12+
function obj = MSM(trainData, numDimReferenceSubspace, numDimInputSubspace, labels)
13+
obj.trainData = trainData;
14+
obj.numDimReferenceSubspace = numDimReferenceSubspace;
15+
obj.numDimInputSubspace = numDimInputSubspace;
16+
obj.trueTestLabels = labels;
17+
subspaces = cvlBasisVector(obj.trainData,...
18+
obj.numDimReferenceSubspace);
19+
obj.referenceSubspaces = subspaces;
20+
end
21+
22+
23+
% Returns the predicted labels for the test data
24+
function prediction = predict(obj, testData)
25+
similarityScores = obj.getSimilarityScores(testData);
26+
eval = ModelEvaluation(similarityScores, obj.trueTestLabels, obj.name);
27+
prediction = eval.predicted_labels;
28+
end
29+
30+
% Returns the similarity scores for the test data (same as probabilities)
31+
function probabilities = predictProb(obj, testData)
32+
probabilities = obj.getSimilarityScores(testData);
33+
end
34+
35+
% Returns the evaluation object for the test data
36+
function eval = evaluate(obj, testData)
37+
similarityScores = obj.getSimilarityScores(testData);
38+
eval = ModelEvaluation(similarityScores, obj.trueTestLabels, obj.name);
39+
end
40+
end
41+
42+
methods (Access = private)
43+
function scores = getSimilarityScores(obj, testData)
44+
inputSubspace = cvlBasisVector(testData,...
45+
obj.numDimInputSubspace);
46+
similarities = cvlCanonicalAngles(obj.referenceSubspaces,...
47+
inputSubspace);
48+
scores = similarities(:, :, end, end);
49+
end
50+
end
51+
end

src/classes/@ModelEvaluation/ModelEvaluation.m

Lines changed: 41 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,18 @@
1212
num_positive_samples; % Number of positive samples
1313
num_negative_samples; % Number of negative samples
1414

15-
true_labels % true labels
16-
predicted_labels; % Predicted labels
15+
true_labels % true labels
16+
predicted_labels; % Predicted labels
17+
prediction_scores; % Prediction scores
1718

19+
model_name; % Name of the model
1820
is_similarity; % Flag indicating if values represent similarities (true) or distances (false)
1921
end
20-
22+
2123
methods
22-
function obj = ModelEvaluation(evaluation_values, labels, is_first_label_zero,optional_flag)
24+
function obj = ModelEvaluation(evaluation_values, labels, model_name,is_first_label_zero,optional_flag)
2325
% Constructor for ModelEvaluation class.
24-
%
26+
%
2527
% Parameters:
2628
% - evaluation_values: Matrix of evaluation scores (either similarities or distances).
2729
% - labels: Actual labels corresponding to evaluation_values.
@@ -31,64 +33,70 @@
3133

3234
% Check if values represent similarities or distances
3335
obj.is_similarity = true;
36+
obj.model_name = model_name;
3437
% Set default value for is_first_label_zero if not provided
35-
if nargin < 3 || isempty(is_first_label_zero)
36-
is_first_label_zero = true; % Default value
38+
if nargin < 4 || isempty(is_first_label_zero)
39+
is_first_label_zero = false; % Default value
3740
end
38-
if nargin == 4 && optional_flag == 'D'
41+
if nargin == 5 && optional_flag == 'D'
3942
obj.is_similarity = false;
4043
end
41-
44+
4245
evaluation_values = evaluation_values(:, :);
46+
47+
if obj.is_similarity
48+
obj.prediction_scores = evaluation_values;
49+
end
50+
4351
is_multiclass_scenario = size(evaluation_values, 1) > 1;
4452
if is_multiclass_scenario
45-
53+
4654
binary_labels = zeros(size(evaluation_values));
4755
unique_labels = unique(labels);
4856
for i = 1:size(unique_labels, 2)
4957
binary_labels(i, labels == unique_labels(i)) = 1;
5058
end
51-
59+
5260
if obj.is_similarity
5361
[~, predicted_labels] = max(evaluation_values, [], 1);
5462
else
5563
[~, predicted_labels] = min(evaluation_values, [], 1);
5664
end
57-
65+
5866
if is_first_label_zero
5967
% do this to accomodate for MatLab indexing
6068
predicted_labels = predicted_labels - 1;
6169
labels = labels - 1;
6270
end
63-
obj.accuracy = mean(predicted_labels == labels);
71+
obj.accuracy = mean(predicted_labels == labels);
6472
obj.error_rate = 1 - obj.accuracy;
65-
66-
73+
74+
6775
else
6876
binary_labels = zeros(size(labels));
6977
binary_labels(labels ~= 0) = 1;
7078
predicted_labels = evaluation_values >= obj.classification_threshold;
7179
obj.accuracy = mean(predicted_labels == binary_labels);
7280
end
73-
81+
7482
evaluation_values = evaluation_values(:);
7583
binary_labels = binary_labels(:);
76-
84+
7785
obj.num_positive_samples = sum(binary_labels == 1);
7886
obj.num_negative_samples = sum(binary_labels == 0);
79-
87+
8088
% Sort evaluation values
8189
if obj.is_similarity
8290
[obj.sorted_values, sorted_indices] = sort(evaluation_values, 'ascend');
8391
else
8492
[obj.sorted_values, sorted_indices] = sort(evaluation_values, 'descend');
8593
end
86-
94+
8795
sorted_binary_labels = binary_labels(sorted_indices);
88-
96+
8997
obj.false_accept_rate = 1 - cumsum(sorted_binary_labels == 0) / obj.num_negative_samples;
9098
obj.false_reject_rate = cumsum(sorted_binary_labels == 1) / obj.num_positive_samples;
91-
99+
92100
% Compute equal error rate
93101
[~, eer_index] = min(abs(obj.false_accept_rate - obj.false_reject_rate));
94102
obj.equal_error_rate = (obj.false_accept_rate(eer_index) + obj.false_reject_rate(eer_index)) / 2;
@@ -97,25 +105,18 @@
97105
obj.true_labels = categorical(labels);
98106
obj.predicted_labels = categorical(predicted_labels);
99107
end
100-
101-
function plotConfusionMatrix(obj)
102-
% Plot Confusion Matrix
103-
cm = confusionmat(obj.true_labels, obj.predicted_labels);
104-
heatmap(cm, 'XLabel', 'Predicted Labels', 'YLabel', 'True Labels');
105-
title('Confusion Matrix');
106-
end
107-
108-
function plotErrorRateCDF(obj)
109-
figure;
110-
plot(obj.sorted_values, obj.false_accept_rate, 'b-', obj.sorted_values, obj.false_reject_rate, 'r-');
111-
xlabel('Threshold');
112-
ylabel('Error Rate');
113-
legend('False Accept Rate', 'False Reject Rate');
114-
title('CDF of Error Rates');
115-
end
116-
117108

118-
119-
109+
function printResults(obj)
110+
disp('---------- Model Evaluation Results ----------');
111+
fprintf('Model: %s\n', obj.model_name');
112+
fprintf('Accuracy: %.2f%%\n', obj.accuracy * 100);
113+
fprintf('Error Rate: %.2f%%\n', obj.error_rate * 100);
114+
fprintf('Equal Error Rate: %.2f%%\n', obj.equal_error_rate * 100);
115+
fprintf('Classification Threshold: %.2f\n', obj.classification_threshold);
116+
% predicted labels
117+
fprintf('Predicted Labels: \n');
118+
disp(obj.predicted_labels);
119+
disp('----------------------------------------------');
120+
end
120121
end
121122
end

0 commit comments

Comments
 (0)