Skip to content

Commit ed01387

Browse files
authored
Merge pull request #156 from howsunjow/bug/categorical_metric_calculations
Bug/categorical metric calculations
2 parents 3cbb3e2 + 1f5b795 commit ed01387

File tree

2 files changed

+61
-2
lines changed

2 files changed

+61
-2
lines changed

niaarm/rule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def __post_init__(self, transactions):
237237
contains_antecedent &= transactions[attribute.name] >= attribute.min_val
238238
else:
239239
contains_antecedent &= (
240-
transactions[attribute.name] == attribute.categories[0]
240+
np.isin(transactions[attribute.name], attribute.categories)
241241
)
242242

243243
self.antecedent_count = contains_antecedent.sum()
@@ -255,7 +255,7 @@ def __post_init__(self, transactions):
255255
contains_consequent &= transactions[attribute.name] >= attribute.min_val
256256
else:
257257
contains_consequent &= (
258-
transactions[attribute.name] == attribute.categories[0]
258+
np.isin(transactions[attribute.name], attribute.categories)
259259
)
260260
self.__amplitude = 1 - (1 / (len(self.antecedent) + len(self.consequent))) * acc
261261
self.consequent_count = contains_consequent.sum()

tests/test_metrics.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import pandas as pd
23
from unittest import TestCase
34
from niaarm import Dataset, Feature, Rule
45

@@ -74,3 +75,61 @@ def test_zhang(self):
7475
def test_leverage(self):
7576
self.assertAlmostEqual(self.rule_one.leverage, 0.102040816326)
7677
self.assertAlmostEqual(self.rule_two.leverage, 0.102040816326)
78+
79+
80+
class TestMetricsMultipleCategories(TestCase):
81+
def setUp(self):
82+
self.data = Dataset(
83+
pd.DataFrame({"col1": [1.5, 2.5, 1.0], "col2": ["Green", "Blue", "Red"]})
84+
)
85+
self.rule = Rule(
86+
[Feature("col1", dtype="float", min_val=1.0, max_val=1.5)],
87+
[Feature("col2", dtype="cat", categories=["Red", "Green"])],
88+
transactions=self.data.transactions,
89+
)
90+
91+
def test_support(self):
92+
self.assertEqual(self.rule.support, 2 / 3)
93+
94+
def test_confidence(self):
95+
self.assertEqual(self.rule.confidence, 1)
96+
97+
def test_lift(self):
98+
self.assertEqual(self.rule.lift, 1.5)
99+
100+
def test_coverage(self):
101+
self.assertEqual(self.rule.coverage, 2 / 3)
102+
103+
def test_rhs_support(self):
104+
self.assertEqual(self.rule.rhs_support, 2 / 3)
105+
106+
def test_conviction(self):
107+
self.assertAlmostEqual(
108+
self.rule.conviction,
109+
(1 - self.rule.rhs_support)
110+
/ (1 - self.rule.confidence + 2.220446049250313e-16),
111+
)
112+
113+
def test_amplitude(self):
114+
self.assertEqual(self.rule.amplitude, 5 / 6)
115+
116+
def test_inclusion(self):
117+
self.assertEqual(self.rule.inclusion, 1)
118+
119+
def test_interestingness(self):
120+
self.assertEqual(self.rule.interestingness, 1 * 1 * (1 - (2 / 3) / 3))
121+
122+
def test_comprehensibility(self):
123+
self.assertAlmostEqual(self.rule.comprehensibility, 0.630929753571)
124+
125+
def test_netconf(self):
126+
self.assertAlmostEqual(self.rule.netconf, ((2/3) - (2/3 * 2/3))/(2/3 * 1/3))
127+
128+
def test_yulesq(self):
129+
self.assertAlmostEqual(self.rule.yulesq,1)
130+
131+
def test_zhang(self):
132+
self.assertAlmostEqual(self.rule.zhang, 1)
133+
134+
def test_leverage(self):
135+
self.assertAlmostEqual(self.rule.leverage, 2/3 - (2/3 * 2/3))

0 commit comments

Comments
 (0)