Skip to content

Commit 3e57e08

Browse files
committed
throw exception when loading model with incorrect version
1 parent 1ed533b commit 3e57e08

File tree

7 files changed

+156
-2
lines changed

7 files changed

+156
-2
lines changed

improveai/src/main/java/ai/improve/xgbpredictor/ImprovePredictor.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ public class ImprovePredictor implements Serializable {
2222
private ObjFunction obj;
2323
private GradBooster gbm;
2424
private ModelMetadata modelMetadata;
25-
// private IMPModelMetadata modelMetadata;
2625

2726
private float base_score;
2827

improveai/src/main/java/ai/improve/xgbpredictor/ModelMetadata.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package ai.improve.xgbpredictor;
22

3+
import ai.improve.constants.BuildProperties;
34
import biz.k11i.xgboost.util.ModelReader;
45
import com.google.gson.JsonArray;
56
import com.google.gson.JsonObject;
@@ -15,6 +16,9 @@ public class ModelMetadata {
1516
private static final String Tag = "ModelMetadata";
1617

1718
public static final String USER_DEFINED_METADATA = "user_defined_metadata";
19+
20+
public static final String IMPROVE_VERSION_KEY = "ai.improve.version";
21+
1822
private Map<String, String> storage = new HashMap<>();
1923

2024
private String modelName;
@@ -63,6 +67,13 @@ public String getUserDefinedMetadata() {
6367
private void parseMetadata(String value) throws IOException {
6468
try {
6569
JsonObject root = JsonParser.parseString(value).getAsJsonObject().getAsJsonObject("json");
70+
if(root.has(IMPROVE_VERSION_KEY)) {
71+
String modelVersion = root.get(IMPROVE_VERSION_KEY).getAsString();
72+
if(!canParseModel(modelVersion, BuildProperties.getSDKVersion())) {
73+
throw new IOException("Major version don't match. ImproveAI SDK version(" + BuildProperties.getSDKVersion()+") " +
74+
"can't load the model of version("+ modelVersion + ").");
75+
}
76+
}
6677
modelName = root.get("model_name").getAsString();
6778
modelSeed = root.get("model_seed").getAsLong();
6879

@@ -71,8 +82,23 @@ private void parseMetadata(String value) throws IOException {
7182
for (int i = 0; i < featuresArray.size(); ++i) {
7283
modelFeatureNames.add(featuresArray.get(i).getAsString());
7384
}
74-
} catch (Throwable t) {
85+
} catch (RuntimeException e) {
7586
throw new IOException("Failed to parse the model metadata. Looks like the model being loaded is invalid.");
7687
}
7788
}
89+
90+
/**
91+
* Check if the SDK can parse the model.
92+
* @return Returns true, if {@value IMPROVE_VERSION_KEY} property is null;
93+
* Returns true if the @{value IMPROVE_VERSION_KEY} property is not null and its major version
94+
* matches the major version of the SDK; otherwise, return false.
95+
*/
96+
public static boolean canParseModel(String modelVersion, String sdkVersion) {
97+
if(modelVersion == null) {
98+
return true;
99+
}
100+
String modelMajorVersion = modelVersion.split("\\.")[0];
101+
String sdkMajorVersion = sdkVersion.split("\\.")[0];
102+
return modelMajorVersion.equals(sdkMajorVersion);
103+
}
78104
}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
package ai.improve.xgbpredictor;
2+
3+
import org.junit.jupiter.api.Test;
4+
5+
import static org.junit.jupiter.api.Assertions.assertEquals;
6+
import static org.junit.jupiter.api.Assertions.assertFalse;
7+
import static org.junit.jupiter.api.Assertions.assertTrue;
8+
import static org.junit.jupiter.api.Assertions.fail;
9+
10+
import static ai.improve.DecisionModelTest.DefaultFailMessage;
11+
12+
import java.io.DataOutputStream;
13+
import java.io.File;
14+
import java.io.FileInputStream;
15+
import java.io.FileOutputStream;
16+
import java.io.IOException;
17+
import java.net.URISyntaxException;
18+
import java.net.URL;
19+
import java.nio.ByteBuffer;
20+
import java.nio.ByteOrder;
21+
22+
import ai.improve.log.IMPLog;
23+
import biz.k11i.xgboost.util.ModelReader;
24+
25+
public class ModelMetadataTest {
26+
public static final String Tag = "ModelMetadataTest";
27+
28+
@Test
29+
public void testCanParseModel() {
30+
assertTrue(ModelMetadata.canParseModel(null, "7.0.1"));
31+
assertTrue(ModelMetadata.canParseModel("7.0.1", "7.0.1"));
32+
assertTrue(ModelMetadata.canParseModel("7.0.1", "7.0"));
33+
assertTrue(ModelMetadata.canParseModel("7.0.1", "7"));
34+
assertTrue(ModelMetadata.canParseModel("7.0", "7.0.1"));
35+
assertTrue(ModelMetadata.canParseModel("7", "7.0.1"));
36+
assertTrue(ModelMetadata.canParseModel("7.1.1", "7.0.1"));
37+
assertFalse(ModelMetadata.canParseModel("6.1.1", "7.0.1"));
38+
assertFalse(ModelMetadata.canParseModel("V7.1.1", "7.0.1"));
39+
assertFalse(ModelMetadata.canParseModel("77.1.1", "7.0.1"));
40+
assertFalse(ModelMetadata.canParseModel("", "7.0.1"));
41+
}
42+
43+
@Test
44+
public void testParseMetadata_valid() throws IOException, URISyntaxException {
45+
URL resource = getClass().getClassLoader().getResource("metadata/metadata_valid");
46+
ModelReader modelReader = new ModelReader(new FileInputStream(new File(resource.toURI())));
47+
ModelMetadata metadata = new ModelMetadata(modelReader);
48+
IMPLog.d(Tag, metadata.getModelName());
49+
}
50+
51+
@Test
52+
public void testParseMetadata_invalid() throws URISyntaxException {
53+
try {
54+
URL resource = getClass().getClassLoader().getResource("metadata/metadata_invalid");
55+
ModelReader modelReader = new ModelReader(new FileInputStream(new File(resource.toURI())));
56+
ModelMetadata metadata = new ModelMetadata(modelReader);
57+
IMPLog.d(Tag, metadata.getModelName());
58+
} catch (IOException e) {
59+
e.printStackTrace();
60+
return ;
61+
}
62+
fail("An IOException should have been thrown");
63+
}
64+
65+
@Test
66+
public void testParseMetadata_no_version() throws IOException, URISyntaxException {
67+
URL resource = getClass().getClassLoader().getResource("metadata/metadata_no_version");
68+
ModelReader modelReader = new ModelReader(new FileInputStream(new File(resource.toURI())));
69+
ModelMetadata metadata = new ModelMetadata(modelReader);
70+
assertEquals("test", metadata.getModelName());
71+
}
72+
73+
@Test
74+
public void testParseMetadata_outdated_version() throws URISyntaxException {
75+
try {
76+
URL resource = getClass().getClassLoader().getResource("metadata/metadata_outdated_version");
77+
ModelReader modelReader = new ModelReader(new FileInputStream(new File(resource.toURI())));
78+
ModelMetadata metadata = new ModelMetadata(modelReader);
79+
} catch (IOException e) {
80+
e.printStackTrace();
81+
assertTrue(e.getMessage().startsWith("Major version don't match"));
82+
return ;
83+
}
84+
fail(DefaultFailMessage);
85+
}
86+
87+
// Generate metadata for testing
88+
// Please copy the generated files to directory 'resources/metadata/'
89+
@Test
90+
public void testGen() throws IOException {
91+
// Valid metadata
92+
String userDefined = "{\"json\":{\"model_name\":\"test\",\"model_seed\":100000000,\"ai.improve.version\":\"7.0.1\",\"feature_names\":[\"12345678\"]}}";
93+
String path = "./metadata_valid";
94+
genMetadata(userDefined, path);
95+
96+
// Invalid userDefined
97+
userDefined = "\"json\":{\"model_name\":\"test\",\"model_seed\":100000000,\"ai.improve.version\":\"7.0.1\",\"feature_names\":[\"12345678\"]}}";
98+
path = "./metadata_invalid";
99+
genMetadata(userDefined, path);
100+
101+
// No model version
102+
userDefined = "{\"json\":{\"model_name\":\"test\",\"model_seed\":100000000,\"feature_names\":[\"12345678\"]}}";
103+
path = "./metadata_no_version";
104+
genMetadata(userDefined, path);
105+
106+
// outdated model version
107+
userDefined = "{\"json\":{\"model_name\":\"test\",\"model_seed\":100000000,\"ai.improve.version\":\"1.0.1\",\"feature_names\":[\"12345678\"]}}";
108+
path = "./metadata_outdated_version";
109+
genMetadata(userDefined, path);
110+
}
111+
112+
private void genMetadata(String userDefined, String path) throws IOException {
113+
FileOutputStream fos = new FileOutputStream(new File(path));
114+
DataOutputStream dos = new DataOutputStream(fos);
115+
dos.write(longToLittleEndianBytes(1));
116+
dos.write(longToLittleEndianBytes("user_defined_metadata".getBytes().length));
117+
dos.write("user_defined_metadata".getBytes());
118+
dos.write(longToLittleEndianBytes(userDefined.getBytes().length));
119+
dos.write(userDefined.getBytes());
120+
dos.close();
121+
}
122+
123+
private byte[] longToLittleEndianBytes(long v) {
124+
ByteBuffer buffer = ByteBuffer.allocate(8);
125+
buffer.order(ByteOrder.LITTLE_ENDIAN);
126+
buffer.putLong(v);
127+
return buffer.array();
128+
}
129+
}
155 Bytes
Binary file not shown.
127 Bytes
Binary file not shown.
156 Bytes
Binary file not shown.
156 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)