1515import java .io .IOException ;
1616import java .util .Arrays ;
1717import java .util .ArrayList ;
18+ import java .util .HashMap ;
1819import java .util .List ;
1920import java .util .ListIterator ;
21+ import java .util .Map ;
2022import java .util .Optional ;
2123
2224public class XGBoostRawJsonParser implements LtrRankerParser {
@@ -37,9 +39,42 @@ public NaiveAdditiveDecisionTree parse(FeatureSet set, String model) {
3739 }
3840
3941 NaiveAdditiveDecisionTree .Node [] trees = modelDefinition .getLearner ().getTrees (set );
42+ List <String > modelFeatures = modelDefinition .learner .featureNames ;
43+
44+ // remap features according to the order in the feature set
45+ Map <Integer , Integer > modelFeaturesReordering = new HashMap <>();
46+ for (int i = 0 ; i < modelFeatures .size (); i ++) {
47+ modelFeaturesReordering .put (i , set .featureOrdinal (modelFeatures .get (i )));
48+ }
49+
50+ // Reorder features in each tree
51+ NaiveAdditiveDecisionTree .Node [] adjustedTrees = new NaiveAdditiveDecisionTree .Node [trees .length ];
52+ for (int i = 0 ; i < trees .length ; i ++) {
53+ adjustedTrees [i ] = reorderTreeFeatures (trees [i ], modelFeaturesReordering );
54+ }
55+
4056 float [] weights = new float [trees .length ];
4157 Arrays .fill (weights , 1F );
42- return new NaiveAdditiveDecisionTree (trees , weights , set .size (), modelDefinition .getLearner ().getObjective ().getNormalizer ());
58+ return new NaiveAdditiveDecisionTree (
59+ adjustedTrees , weights , set .size (), modelDefinition .getLearner ().getObjective ().getNormalizer ()
60+ );
61+ }
62+
63+ private NaiveAdditiveDecisionTree .Node reorderTreeFeatures (NaiveAdditiveDecisionTree .Node node ,
64+ Map <Integer , Integer > modelFeaturesReordering ) {
65+ if (node instanceof NaiveAdditiveDecisionTree .Split splitNode ) {
66+ return new NaiveAdditiveDecisionTree .Split (
67+ reorderTreeFeatures (splitNode .getLeft (), modelFeaturesReordering ),
68+ reorderTreeFeatures (splitNode .getRight (), modelFeaturesReordering ),
69+ modelFeaturesReordering .get (splitNode .getFeature ()),
70+ splitNode .getThreshold (),
71+ splitNode .getLeftNodeId (),
72+ splitNode .getMissingNodeId ()
73+ );
74+ }
75+
76+ // if the node is Leaf we don't do anything
77+ return node ;
4378 }
4479
4580 private static class XGBoostDefinition {
@@ -95,6 +130,7 @@ public static XGBoostRawJsonParser.XGBoostDefinition parse(XContentParser parser
95130 } else {
96131 throw new ParsingException (parser .getTokenLocation (), "Expected [START_OBJECT] but got [" + startToken + "]" );
97132 }
133+
98134 return definition ;
99135 }
100136
0 commit comments