Skip to content

Commit 0648bad

Browse files
committed
support generic types to avoid typecasting what Decision.get() returns.
1 parent 7580b9e commit 0648bad

File tree

6 files changed

+118
-19
lines changed

6 files changed

+118
-19
lines changed

improveai/src/main/java/ai/improve/Decision.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
import ai.improve.log.IMPLog;
77
import ai.improve.util.ModelUtils;
88

9-
public class Decision {
9+
public class Decision<T> {
1010
public static final String Tag = "Decision";
1111

1212
private DecisionModel model;
1313

14-
protected List<?> variants;
14+
protected List<T> variants;
1515

1616
protected Map<String, Object> givens;
1717

@@ -25,7 +25,7 @@ public class Decision {
2525
* */
2626
protected int tracked;
2727

28-
protected Object best;
28+
protected T best;
2929

3030
// The message_id of the tracked decision
3131
protected String id;
@@ -39,7 +39,7 @@ protected Decision(DecisionModel model) {
3939
* @return Returns the chosen variant memoized.
4040
* @throws IllegalStateException Thrown if called before chooseFrom()
4141
*/
42-
public Object peek() {
42+
public T peek() {
4343
return best;
4444
}
4545

@@ -49,7 +49,7 @@ public Object peek() {
4949
* might return null.
5050
* @throws IllegalStateException Thrown if variants is null or empty.
5151
* */
52-
public synchronized Object get() {
52+
public synchronized T get() {
5353
if(tracked == 0) {
5454
DecisionTracker tracker = model.getTracker();
5555
if (tracker != null) {

improveai/src/main/java/ai/improve/DecisionContext.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public DecisionContext(DecisionModel decisionModel, Map givens) {
2222
/**
2323
* @see ai.improve.DecisionModel#chooseFrom(List)
2424
*/
25-
public Decision chooseFrom(List variants) {
25+
public <T> Decision<T> chooseFrom(List<T> variants) {
2626
if(variants == null || variants.size() <= 0) {
2727
throw new IllegalArgumentException("variants to choose from can't be null or empty");
2828
}
@@ -43,7 +43,7 @@ public Decision chooseFrom(List variants) {
4343
/**
4444
* @see ai.improve.DecisionModel#chooseFrom(List, List)
4545
*/
46-
public Decision chooseFrom(List variants, List scores) {
46+
public <T> Decision<T> chooseFrom(List<T> variants, List<Double> scores) {
4747
if(variants == null || scores == null || variants.size() <= 0) {
4848
throw new IllegalArgumentException("variants and scores can't be null or empty");
4949
}
@@ -67,7 +67,7 @@ public Decision chooseFrom(List variants, List scores) {
6767
/**
6868
* @see ai.improve.DecisionModel#chooseFrom(List)
6969
*/
70-
public Decision chooseFirst(List variants) {
70+
public <T> Decision<T> chooseFirst(List<T> variants) {
7171
if(variants == null || variants.size() <= 0) {
7272
throw new IllegalArgumentException("variants can't be null or empty");
7373
}
@@ -98,7 +98,7 @@ public Object first(Object... variants) {
9898
/**
9999
* @see ai.improve.DecisionModel#chooseRandom(List)
100100
*/
101-
public Decision chooseRandom(List variants) {
101+
public <T> Decision<T> chooseRandom(List<T> variants) {
102102
if(variants == null || variants.size() <= 0) {
103103
throw new IllegalArgumentException("variants can't be null or empty");
104104
}

improveai/src/main/java/ai/improve/DecisionModel.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ protected FeatureEncoder getFeatureEncoder() {
278278
* booleans.
279279
* @return an IMPDecision object.
280280
* */
281-
public <T> Decision chooseFrom(List<T> variants) {
281+
public <T> Decision<T> chooseFrom(List<T> variants) {
282282
return given(null).chooseFrom(variants);
283283
}
284284

@@ -291,7 +291,7 @@ public <T> Decision chooseFrom(List<T> variants) {
291291
* @throws IllegalArgumentException Thrown if variants or scores is null or empty; Thrown if
292292
* variants.size() != scores.size().
293293
*/
294-
public Decision chooseFrom(List variants, List scores) {
294+
public <T> Decision<T> chooseFrom(List<T> variants, List<Double> scores) {
295295
return given(null).chooseFrom(variants, scores);
296296
}
297297

@@ -338,7 +338,7 @@ public Object which(Object... variants) {
338338
* @param variants See chooseFrom()
339339
* @return A Decision object which has the first variant as the best.
340340
*/
341-
public Decision chooseFirst(List variants) {
341+
public <T> Decision<T> chooseFirst(List<T> variants) {
342342
if(variants == null || variants.size() <= 0) {
343343
throw new IllegalArgumentException("variants can't be null or empty");
344344
}
@@ -362,7 +362,7 @@ public Object first(Object... variants) {
362362
* @return A Decision object containing a random variant as the decision.
363363
* @throws IllegalArgumentException Thrown if variants is null or empty.
364364
*/
365-
public Decision chooseRandom(List variants) {
365+
public <T> Decision<T> chooseRandom(List<T> variants) {
366366
return given(null).chooseRandom(variants);
367367
}
368368

@@ -399,7 +399,7 @@ protected Map<String, Object> combinedGivens(Map<String, Object> givens) {
399399
* @throws IllegalArgumentException Thrown if variants is null or empty.
400400
* @return scores of the variants
401401
*/
402-
public <T> List<Double> score(List<T> variants) {
402+
public List<Double> score(List variants) {
403403
return scoreInternal(variants, combinedGivens(null));
404404
}
405405

@@ -416,12 +416,12 @@ public <T> List<Double> score(List<T> variants) {
416416
* @throws IllegalArgumentException Thrown if variants is null or empty
417417
* @return scores of the variants
418418
*/
419-
protected <T> List<Double> scoreInternal(List<T> variants, Map<String, ?> givens) {
419+
protected List<Double> scoreInternal(List variants, Map<String, ?> givens) {
420420
if(variants == null || variants.size() <= 0) {
421421
throw new IllegalArgumentException("variants can't be null or empty");
422422
}
423423

424-
IMPLog.d(Tag, "givens: " + givens);
424+
// IMPLog.d(Tag, "givens: " + givens);
425425

426426
if(predictor == null) {
427427
// When tracking a decision like this:
@@ -516,7 +516,7 @@ public int compare(Integer obj1, Integer obj2) {
516516
}
517517
});
518518

519-
List<T> result = new ArrayList<T>(variants.size());
519+
List<T> result = new ArrayList<>(variants.size());
520520
for(int i = 0; i < indices.length; ++i) {
521521
result.add(variants.get(indices[i]));
522522
}

improveai/src/test/java/ai/improve/DecisionContextTest.java

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,14 @@ public static void setUp() {
2727
DecisionModel.setDefaultTrackURL(Track_URL);
2828
}
2929

30-
private List variants() {
30+
private List<String> variants() {
3131
return Arrays.asList("Hello", "Hi", "Hey");
3232
}
3333

34+
private List<Double> scores() {
35+
return Arrays.asList(0.1, 0.2, 0.3);
36+
}
37+
3438
private Map givens() {
3539
return Map.of("lang", "en");
3640
}
@@ -43,6 +47,13 @@ public void testChooseFrom() {
4347
assertNotNull(decision);
4448
}
4549

50+
@Test
51+
public void testChooseFrom_generic() {
52+
DecisionModel decisionModel = new DecisionModel("greetings");
53+
String greeting = decisionModel.given(null).chooseFrom(variants()).get();
54+
IMPLog.d(Tag, "greeting is " + greeting);
55+
}
56+
4657
@Test
4758
public void testChooseFrom_null_variants() {
4859
DecisionModel decisionModel = new DecisionModel("greetings");
@@ -81,6 +92,13 @@ public void testChooseFromVaiantsAndScores() {
8192
assertEquals("en", decision.givens.get("lang"));
8293
}
8394

95+
@Test
96+
public void testChooseFromVaiantsAndScores_generic() {
97+
DecisionModel decisionModel = new DecisionModel("greetings");
98+
String greeting = decisionModel.given(null).chooseFrom(variants(), scores()).get();
99+
IMPLog.d(Tag, "greeting is " + greeting);
100+
}
101+
84102
@Test
85103
public void testChooseFromVaiantsAndScores_null_variants() {
86104
Map givens = Map.of("lang", "en");
@@ -130,10 +148,18 @@ public void testChooseFirst() {
130148
List variants = Arrays.asList("hi", "hello", "hey");
131149
DecisionModel decisionModel = new DecisionModel("greetings");
132150
Decision decision = decisionModel.given(givens).chooseFirst(variants);
133-
assertEquals(givens, decision.givens);
151+
assertEquals(1, decision.givens.size());
152+
assertEquals("en", decision.givens.get("lang"));
134153
assertEquals("hi", decision.get());
135154
}
136155

156+
@Test
157+
public void testChooseFirst_generic() {
158+
DecisionModel decisionModel = new DecisionModel("greetings");
159+
String greeting = decisionModel.given(null).chooseFirst(variants()).get();
160+
IMPLog.d(Tag, "greeting is " + greeting);
161+
}
162+
137163
@Test
138164
public void testChooseFirst_null_variants() {
139165
Map givens = Map.of("lang", "en");
@@ -224,6 +250,15 @@ public void testChooseRandom() {
224250
assertEquals(loop/3, countMap.get("hey"), 100);
225251
}
226252

253+
@Test
254+
public void testChooseRandom_generic() {
255+
Map givens = Map.of("lang", "en");
256+
List<String> variants = Arrays.asList("hi", "hello", "hey");
257+
DecisionModel decisionModel = new DecisionModel("greetings");
258+
String greeting = decisionModel.given(givens).chooseRandom(variants).get();
259+
IMPLog.d(Tag, "greeting is " + greeting);
260+
}
261+
227262
@Test
228263
public void testChooseRandom_null_variants() {
229264
DecisionModel decisionModel = new DecisionModel("greetings");

improveai/src/test/java/ai/improve/DecisionModelTest.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ public class DecisionModelTest {
4343
DecisionModel.setDefaultTrackApiKey(Track_Api_Key);
4444
}
4545

46+
private List<String> variants() {
47+
return Arrays.asList("Hello", "Hi", "Hey");
48+
}
49+
50+
private List<Double> scores() {
51+
return Arrays.asList(0.1, 0.2, 0.3);
52+
}
53+
4654
@BeforeEach
4755
public void setUp() throws Exception {
4856
IMPLog.d(Tag, "setUp");
@@ -319,6 +327,16 @@ public void testRankInvalid_largerScores() {
319327
fail("An IndexOutOfBoundException should have been thrown, we should never reach here");
320328
}
321329

330+
@Test
331+
public void testRank_generic() {
332+
List<String> variants = Arrays.asList("hi", "hello", "hey");
333+
List<Double> scores = Arrays.asList(0.1, 0.2, 0.3);
334+
// test that we don't have to do type cast here like:
335+
// String greeting = (String) DecisionModel.rank(variants, scores).get(0);
336+
String greeting = DecisionModel.rank(variants, scores).get(0);
337+
assertEquals("hey", greeting);
338+
}
339+
322340
@Test
323341
public void testTopScoringVariant() {
324342
int count = 100;
@@ -469,6 +487,13 @@ public void testChooseFrom() {
469487
// product = IMPDecisionModel.lo(modelUrl).chooseFrom(["clutch", "dress", "jacket"]).get()
470488
}
471489

490+
@Test
491+
public void testChooseFrom_generic() {
492+
DecisionModel decisionModel = new DecisionModel("greetings");
493+
String greeting = decisionModel.chooseFrom(variants()).get();
494+
IMPLog.d(Tag, "greetings is " + greeting);
495+
}
496+
472497
@Test
473498
public void testChooseFromVariantsAndScores() {
474499
List variants = Arrays.asList("hi", "hello", "Hey");
@@ -481,6 +506,13 @@ public void testChooseFromVariantsAndScores() {
481506
assertEquals(variants, decision.variants);
482507
}
483508

509+
@Test
510+
public void testChooseFromVariantsAndScores_generic() {
511+
DecisionModel decisionModel = new DecisionModel("greetings");
512+
String greeting = decisionModel.chooseFrom(variants(), scores()).get();
513+
IMPLog.d(Tag, "greetings is " + greeting);
514+
}
515+
484516
@Test
485517
public void testChooseFromVariantsAndScores_empty_variants() {
486518
List variants = new ArrayList();
@@ -720,6 +752,13 @@ public void testChooseFirst() {
720752
assertEquals(variants.size(), decision.scores.size());
721753
}
722754

755+
@Test
756+
public void testChooseFirst_generic() {
757+
DecisionModel decisionModel = new DecisionModel("greetings");
758+
String greeting = decisionModel.chooseFirst(variants()).get();
759+
IMPLog.d(Tag, "greetings is " + greeting);
760+
}
761+
723762
@Test
724763
public void testChooseFirst_null_variants() {
725764
List variants = null;
@@ -831,6 +870,13 @@ public void testChooseRandom() {
831870
assertEquals(loop/3, countMap.get("hey"), 100);
832871
}
833872

873+
@Test
874+
public void testChooseRandom_generic() {
875+
DecisionModel decisionModel = new DecisionModel("greetings");
876+
String greeting = decisionModel.chooseRandom(variants()).get();
877+
IMPLog.d(Tag, "greetings is " + greeting);
878+
}
879+
834880
@Test
835881
public void testChooseRandom_empty_variants() {
836882
List variants = new ArrayList();

improveai/src/test/java/ai/improve/DecisionTest.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,4 +162,22 @@ public void testAddReward_negative_infinity() {
162162
}
163163
fail(DefaultFailMessage);
164164
}
165+
166+
@Test
167+
public void testGet_generic() {
168+
List<String> variants = Arrays.asList("hi", "hello", "Hey");
169+
DecisionModel decisionModel = new DecisionModel("greetings");
170+
// Unit test that no type cast needed here
171+
String greeting = decisionModel.chooseFrom(variants).get();
172+
assertEquals("hi", greeting);
173+
}
174+
175+
@Test
176+
public void testPeek_generic() {
177+
List<String> variants = Arrays.asList("hi", "hello", "Hey");
178+
DecisionModel decisionModel = new DecisionModel("greetings");
179+
// Unit test that no type cast needed here
180+
String greeting = decisionModel.chooseFrom(variants).peek();
181+
assertEquals("hi", greeting);
182+
}
165183
}

0 commit comments

Comments
 (0)