Skip to content

Commit 251b1a0

Browse files
authored
fix: xgboost raw parser (#518)
LGTM
1 parent 5ccc619 commit 251b1a0

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ public class XGBoostRawJsonParser implements LtrRankerParser {
2525

2626
public static final String TYPE = "model/xgboost+json+raw";
2727

28-
private static final Integer MISSING_NODE_ID = Integer.MAX_VALUE;
29-
3028
@Override
3129
public NaiveAdditiveDecisionTree parse(FeatureSet set, String model) {
3230
XGBoostRawJsonParser.XGBoostDefinition modelDefinition;
@@ -439,8 +437,16 @@ private NaiveAdditiveDecisionTree.Node asLibTree(Integer nodeId) {
439437
}
440438

441439
if (isSplit(nodeId)) {
442-
return new NaiveAdditiveDecisionTree.Split(asLibTree(leftChildren.get(nodeId)), asLibTree(rightChildren.get(nodeId)),
443-
splitIndices.get(nodeId), splitConditions.get(nodeId), splitIndices.get(nodeId), MISSING_NODE_ID);
440+
Integer missingNodeId =
441+
defaultLeft.get(nodeId) == 1 ? leftChildren.get(nodeId) : rightChildren.get(nodeId);
442+
return new NaiveAdditiveDecisionTree.Split(
443+
asLibTree(leftChildren.get(nodeId)),
444+
asLibTree(rightChildren.get(nodeId)),
445+
splitIndices.get(nodeId),
446+
splitConditions.get(nodeId),
447+
leftChildren.get(nodeId),
448+
missingNodeId
449+
);
444450
} else {
445451
return new NaiveAdditiveDecisionTree.Leaf(baseWeights.get(nodeId));
446452
}

0 commit comments

Comments
 (0)