|
27 | 27 | import org.apache.lucene.tests.util.LuceneTestCase; |
28 | 28 | import org.elasticsearch.common.ParsingException; |
29 | 29 | import org.elasticsearch.common.io.Streams; |
| 30 | +import org.elasticsearch.xcontent.XContentParseException; |
30 | 31 | import org.hamcrest.CoreMatchers; |
31 | 32 |
|
32 | 33 | import java.io.ByteArrayOutputStream; |
|
40 | 41 | import static com.o19s.es.ltr.LtrTestUtils.randomFeature; |
41 | 42 | import static com.o19s.es.ltr.LtrTestUtils.randomFeatureSet; |
42 | 43 | import static java.util.Collections.singletonList; |
| 44 | +import static org.hamcrest.Matchers.instanceOf; |
43 | 45 |
|
44 | 46 | public class XGBoostJsonParserTests extends LuceneTestCase { |
45 | 47 | private final XGBoostJsonParser parser = new XGBoostJsonParser(); |
@@ -251,6 +253,28 @@ public void testMissingFeat() throws IOException { |
251 | 253 | CoreMatchers.containsString("Unknown feature [feat2]")); |
252 | 254 | } |
253 | 255 |
|
| 256 | + public void testInvalidLeaf() throws IOException { |
| 257 | + // The leaf nodes are missing nodeid field. |
| 258 | + String model = "[{" + |
| 259 | + "\"nodeid\": 0," + |
| 260 | + "\"split\":\"feat1\"," + |
| 261 | + "\"depth\":0," + |
| 262 | + "\"split_condition\":0.123," + |
| 263 | + "\"yes\":1," + |
| 264 | + "\"no\": 2," + |
| 265 | + "\"missing\":1,"+ |
| 266 | + "\"children\": [" + |
| 267 | + " {\"depth\": 1, \"leaf\": 0.5}," + |
| 268 | + " {\"depth\": 1, \"leaf\": 0.2}" + |
| 269 | + "]}]"; |
| 270 | + FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); |
| 271 | + // In this test case, the ParsingException is wrapped in an XContentParseException, because the |
| 272 | + // ParsingException that occurs while parsing the invalid leaf node happens within the ObjectParser. |
| 273 | + Throwable e = expectThrows(XContentParseException.class, () -> parser.parse(set, model)).getCause(); |
| 274 | + assertThat(e, instanceOf(ParsingException.class)); |
| 275 | + assertThat(e.getMessage(), CoreMatchers.containsString("This leaf does not have all the required fields")); |
| 276 | + } |
| 277 | + |
254 | 278 | public void testComplexModel() throws Exception { |
255 | 279 | String model = readModel("/models/xgboost-wmf.json"); |
256 | 280 | List<StoredFeature> features = new ArrayList<>(); |
|
0 commit comments