12 | 12 | num_positive_samples; % Number of positive samples |
13 | 13 | num_negative_samples; % Number of negative samples |
14 | 14 |
|
15 | | - true_labels % true labels |
16 | | - predicted_labels; % Predicted labels |
| 15 | + true_labels; % True labels
| 16 | + predicted_labels; % Predicted labels |
| 17 | + prediction_scores; % Prediction scores |
17 | 18 |
|
| 19 | + model_name; % Name of the model |
18 | 20 | is_similarity; % Flag indicating if values represent similarities (true) or distances (false) |
19 | 21 | end |
20 | | - |
| 22 | + |
21 | 23 | methods |
22 | | - function obj = ModelEvaluation(evaluation_values, labels, is_first_label_zero,optional_flag) |
| 24 | + function obj = ModelEvaluation(evaluation_values, labels, model_name, is_first_label_zero, optional_flag)
23 | 25 | % Constructor for ModelEvaluation class. |
24 | | - % |
| 26 | + % |
25 | 27 | % Parameters: |
26 | 28 | % - evaluation_values: Matrix of evaluation scores (either similarities or distances). |
27 | 29 | % - labels: Actual labels corresponding to evaluation_values. |
|
31 | 33 |
|
32 | 34 | % Check if values represent similarities or distances |
33 | 35 | obj.is_similarity = true; |
| 36 | + obj.model_name = model_name; |
34 | 37 | % Set default value for is_first_label_zero if not provided |
35 | | - if nargin < 3 || isempty(is_first_label_zero) |
36 | | - is_first_label_zero = true; % Default value |
| 38 | + if nargin < 4 || isempty(is_first_label_zero) |
| 39 | + is_first_label_zero = false; % Default value |
37 | 40 | end |
38 | | - if nargin == 4 && optional_flag == 'D' |
| 41 | + if nargin == 5 && strcmp(optional_flag, 'D')
39 | 42 | obj.is_similarity = false; |
40 | 43 | end |
41 | | - |
| 44 | + |
42 | 45 | evaluation_values = evaluation_values(:, :); |
| 46 | + |
| 47 | + if obj.is_similarity |
| 48 | + obj.prediction_scores = evaluation_values; |
| 49 | + end |
| 50 | + |
43 | 51 | is_multiclass_scenario = size(evaluation_values, 1) > 1; |
44 | 52 | if is_multiclass_scenario |
45 | | - |
| 53 | + |
46 | 54 | binary_labels = zeros(size(evaluation_values)); |
47 | 55 | unique_labels = unique(labels); |
48 | 56 | for i = 1:size(unique_labels, 2) |
49 | 57 | binary_labels(i, labels == unique_labels(i)) = 1; |
50 | 58 | end |
51 | | - |
| 59 | + |
52 | 60 | if obj.is_similarity |
53 | 61 | [~, predicted_labels] = max(evaluation_values, [], 1); |
54 | 62 | else |
55 | 63 | [~, predicted_labels] = min(evaluation_values, [], 1); |
56 | 64 | end |
57 | | - |
| 65 | + |
58 | 66 | if is_first_label_zero |
59 | 67 | % subtract 1 to accommodate MATLAB's 1-based indexing
60 | 68 | predicted_labels = predicted_labels - 1; |
61 | 69 | labels = labels - 1; |
62 | 70 | end |
63 | | - obj.accuracy = mean(predicted_labels == labels); |
| 71 | + obj.accuracy = mean(predicted_labels == labels); |
64 | 72 | obj.error_rate = 1 - obj.accuracy; |
65 | | - |
66 | | - |
| 73 | + |
| 74 | + |
67 | 75 | else |
68 | 76 | binary_labels = zeros(size(labels)); |
69 | 77 | binary_labels(labels ~= 0) = 1; |
70 | 78 | predicted_labels = evaluation_values >= obj.classification_threshold; |
71 | 79 | obj.accuracy = mean(predicted_labels == binary_labels); |
72 | 80 | end |
73 | | - |
| 81 | + |
74 | 82 | evaluation_values = evaluation_values(:); |
75 | 83 | binary_labels = binary_labels(:); |
76 | | - |
| 84 | + |
77 | 85 | obj.num_positive_samples = sum(binary_labels == 1); |
78 | 86 | obj.num_negative_samples = sum(binary_labels == 0); |
79 | | - |
| 87 | + |
80 | 88 | % Sort evaluation values |
81 | 89 | if obj.is_similarity |
82 | 90 | [obj.sorted_values, sorted_indices] = sort(evaluation_values, 'ascend'); |
83 | 91 | else |
84 | 92 | [obj.sorted_values, sorted_indices] = sort(evaluation_values, 'descend'); |
85 | 93 | end |
86 | | - |
| 94 | + |
87 | 95 | sorted_binary_labels = binary_labels(sorted_indices); |
88 | | - |
| 96 | + |
89 | 97 | obj.false_accept_rate = 1 - cumsum(sorted_binary_labels == 0) / obj.num_negative_samples; |
90 | 98 | obj.false_reject_rate = cumsum(sorted_binary_labels == 1) / obj.num_positive_samples; |
91 | | - |
| 99 | + |
92 | 100 | % Compute equal error rate |
93 | 101 | [~, eer_index] = min(abs(obj.false_accept_rate - obj.false_reject_rate)); |
94 | 102 | obj.equal_error_rate = (obj.false_accept_rate(eer_index) + obj.false_reject_rate(eer_index)) / 2; |
|
97 | 105 | obj.true_labels = categorical(labels); |
98 | 106 | obj.predicted_labels = categorical(predicted_labels); |
99 | 107 | end |
100 | | - |
101 | | - function plotConfusionMatrix(obj) |
102 | | - % Plot Confusion Matrix |
103 | | - cm = confusionmat(obj.true_labels, obj.predicted_labels); |
104 | | - heatmap(cm, 'XLabel', 'Predicted Labels', 'YLabel', 'True Labels'); |
105 | | - title('Confusion Matrix'); |
106 | | - end |
107 | | - |
108 | | - function plotErrorRateCDF(obj) |
109 | | - figure; |
110 | | - plot(obj.sorted_values, obj.false_accept_rate, 'b-', obj.sorted_values, obj.false_reject_rate, 'r-'); |
111 | | - xlabel('Threshold'); |
112 | | - ylabel('Error Rate'); |
113 | | - legend('False Accept Rate', 'False Reject Rate'); |
114 | | - title('CDF of Error Rates'); |
115 | | - end |
116 | | - |
117 | 108 |
|
118 | | - |
119 | | - |
| 109 | + function printResults(obj) |
| 110 | + disp('---------- Model Evaluation Results ----------'); |
| 111 | + fprintf('Model: %s\n', obj.model_name);
| 112 | + fprintf('Accuracy: %.2f%%\n', obj.accuracy * 100); |
| 113 | + fprintf('Error Rate: %.2f%%\n', obj.error_rate * 100); |
| 114 | + fprintf('Equal Error Rate: %.2f%%\n', obj.equal_error_rate * 100); |
| 115 | + fprintf('Classification Threshold: %.2f\n', obj.classification_threshold); |
| 116 | + % predicted labels |
| 117 | + fprintf('Predicted Labels: \n'); |
| 118 | + disp(obj.predicted_labels); |
| 119 | + disp('----------------------------------------------'); |
| 120 | + end |
120 | 121 | end |
121 | 122 | end |
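
A minimal usage sketch of the updated constructor, with model_name as the new third argument; the score matrix, labels, and model name below are hypothetical, invented for illustration:

    % Hypothetical 3-class similarity matrix: rows = classes, columns = samples.
    similarity_scores = [0.9 0.2 0.1 0.8;
                         0.1 0.7 0.2 0.1;
                         0.0 0.1 0.7 0.1];
    labels = [1 2 3 1];  % ground-truth class of each sample

    evaluation = ModelEvaluation(similarity_scores, labels, 'MyModel');
    evaluation.printResults();

    % For distance matrices, pass 'D' as the optional flag; note that
    % is_first_label_zero now defaults to false, so set it explicitly
    % when labels start at 0:
    % evaluation = ModelEvaluation(distance_matrix, labels, 'MyModel', true, 'D');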
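The FAR/FRR bookkeeping in the constructor can be sanity-checked by hand on a small binary verification set; this standalone sketch (scores and labels are made up, not part of the commit) mirrors the same cumsum logic used in the similarity branch:

    % Hypothetical genuine/impostor scores (higher = more similar).
    scores = [0.95 0.80 0.60 0.40 0.30 0.10];
    binary_labels = [1 1 0 1 0 0];   % 1 = genuine pair, 0 = impostor pair

    num_pos = sum(binary_labels == 1);
    num_neg = sum(binary_labels == 0);

    % Sort ascending, as the similarity branch of the constructor does.
    [~, idx] = sort(scores(:), 'ascend');
    sorted_labels = binary_labels(idx);
    sorted_labels = sorted_labels(:);

    % Sweeping the threshold upward: impostors below it are correctly
    % rejected, so the remaining fraction is the false accept rate;
    % genuines below it are wrongly rejected, giving the false reject rate.
    far = 1 - cumsum(sorted_labels == 0) / num_neg;
    frr = cumsum(sorted_labels == 1) / num_pos;

    [~, eer_idx] = min(abs(far - frr));
    eer = (far(eer_idx) + frr(eer_idx)) / 2;
    fprintf('EER: %.2f%%\n', eer * 100);   % prints 33.33 for this data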