Skip to content

Commit 3b32ad1

Browse files
committed
Enable CCQ for inference commands.
1 parent 004efb7 commit 3b32ad1

File tree

5 files changed

+43
-26
lines changed

5 files changed

+43
-26
lines changed

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,7 +1120,9 @@ public PlanFactory visitRerankCommand(EsqlBaseParser.RerankCommandContext ctx) {
11201120
}
11211121

11221122
return p -> {
1123-
checkForRemoteClusters(p, source, "RERANK");
1123+
if (EsqlCapabilities.Cap.INFERENCE_CCQ_SUPPORT.isEnabled() == false) {
1124+
checkForRemoteClusters(p, source, "RERANK");
1125+
}
11241126
return applyRerankOptions(new Rerank(source, p, queryText, rerankFields, scoreAttribute), ctx.commandNamedParameters());
11251127
};
11261128
}
@@ -1161,7 +1163,9 @@ public PlanFactory visitCompletionCommand(EsqlBaseParser.CompletionCommandContex
11611163
}
11621164

11631165
return p -> {
1164-
checkForRemoteClusters(p, source, "COMPLETION");
1166+
if (EsqlCapabilities.Cap.INFERENCE_CCQ_SUPPORT.isEnabled() == false) {
1167+
checkForRemoteClusters(p, source, "COMPLETION");
1168+
}
11651169
return applyCompletionOptions(new Completion(source, p, prompt, targetField), ctx.commandNamedParameters());
11661170
};
11671171
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
import org.elasticsearch.xpack.esql.core.tree.Source;
2424
import org.elasticsearch.xpack.esql.core.type.DataType;
2525
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
26+
import org.elasticsearch.xpack.esql.plan.logical.ExecutesOn;
2627
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
28+
import org.elasticsearch.xpack.esql.plan.logical.PipelineBreaker;
2729

2830
import java.io.IOException;
2931
import java.util.List;
@@ -33,7 +35,12 @@
3335
import static org.elasticsearch.xpack.esql.core.type.DataType.TEXT;
3436
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
3537

36-
public class Completion extends InferencePlan<Completion> implements TelemetryAware, PostAnalysisVerificationAware {
38+
public class Completion extends InferencePlan<Completion>
39+
implements
40+
TelemetryAware,
41+
PostAnalysisVerificationAware,
42+
PipelineBreaker,
43+
ExecutesOn.Coordinator {
3744

3845
public static final String DEFAULT_OUTPUT_FIELD_NAME = "completion";
3946

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Rerank.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
import org.elasticsearch.xpack.esql.core.type.DataType;
2727
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
2828
import org.elasticsearch.xpack.esql.plan.logical.Eval;
29+
import org.elasticsearch.xpack.esql.plan.logical.ExecutesOn;
2930
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
31+
import org.elasticsearch.xpack.esql.plan.logical.PipelineBreaker;
3032
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
3133

3234
import java.io.IOException;
@@ -37,7 +39,12 @@
3739
import static org.elasticsearch.xpack.esql.common.Failure.fail;
3840
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
3941

40-
public class Rerank extends InferencePlan<Rerank> implements PostAnalysisVerificationAware, TelemetryAware {
42+
public class Rerank extends InferencePlan<Rerank>
43+
implements
44+
PostAnalysisVerificationAware,
45+
TelemetryAware,
46+
PipelineBreaker,
47+
ExecutesOn.Coordinator {
4148

4249
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Rerank", Rerank::new);
4350
public static final String DEFAULT_INFERENCE_ID = ".rerank-v1-elasticsearch";

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import org.elasticsearch.xpack.esql.plan.logical.PipelineBreaker;
2525
import org.elasticsearch.xpack.esql.plan.logical.TopN;
2626
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
27+
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
28+
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
2729
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
2830
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
2931
import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig;
@@ -37,6 +39,7 @@
3739
import org.elasticsearch.xpack.esql.plan.physical.MergeExec;
3840
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
3941
import org.elasticsearch.xpack.esql.plan.physical.TopNExec;
42+
import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec;
4043
import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec;
4144
import org.elasticsearch.xpack.esql.session.Versioned;
4245

@@ -95,6 +98,7 @@ private PhysicalPlan mapUnary(UnaryPlan unary) {
9598
mappedChild = addExchangeForFragment(enrich.child(), mappedChild);
9699
return MapperUtils.mapUnary(unary, mappedChild);
97100
}
101+
98102
// in case of a fragment, push to it any current streaming operator
99103
if (unary instanceof PipelineBreaker == false
100104
|| (unary instanceof Limit limit && limit.local())
@@ -141,16 +145,23 @@ else if (aggregate.groupings()
141145
return new TopNExec(topN.source(), mappedChild, topN.order(), topN.limit(), null);
142146
}
143147

144-
if (unary instanceof Rerank rerank) {
145-
mappedChild = addExchangeForFragment(rerank, mappedChild);
146-
return new RerankExec(
147-
rerank.source(),
148-
mappedChild,
149-
rerank.inferenceId(),
150-
rerank.queryText(),
151-
rerank.rerankFields(),
152-
rerank.scoreAttribute()
153-
);
148+
// Inference operations must execute on coordinator, not on remote clusters
149+
if (unary instanceof InferencePlan) {
150+
mappedChild = addExchangeForFragment(unary.child(), mappedChild);
151+
return switch (unary) {
152+
case Rerank r -> new RerankExec(
153+
r.source(),
154+
mappedChild,
155+
r.inferenceId(),
156+
r.queryText(),
157+
r.rerankFields(),
158+
r.scoreAttribute()
159+
);
160+
case Completion c -> new CompletionExec(c.source(), mappedChild, c.inferenceId(), c.prompt(), c.targetField());
161+
default -> throw new EsqlIllegalArgumentException(
162+
"unsupported inference plan type [" + unary.getClass().getSimpleName() + "]"
163+
);
164+
};
154165
}
155166

156167
//

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4220,12 +4220,6 @@ public void testInvalidRerank() {
42204220
expectError("FROM foo* | RERANK ON title WITH inferenceId", "line 1:20: extraneous input 'ON' expecting {QUOTED_STRING");
42214221
expectError("FROM foo* | RERANK \"query text\" WITH inferenceId", "line 1:33: mismatched input 'WITH' expecting 'on'");
42224222

4223-
var fromPatterns = randomIndexPatterns(CROSS_CLUSTER);
4224-
expectError(
4225-
"FROM " + fromPatterns + " | RERANK \"query text\" ON title WITH { \"inference_id\" : \"inference_id\" }",
4226-
"invalid index pattern [" + unquoteIndexPattern(fromPatterns) + "], remote clusters are not supported with RERANK"
4227-
);
4228-
42294223
expectError(
42304224
"FROM foo* | RERANK \"query text\" ON title WITH { \"inference_id\": { \"a\": 123 } }",
42314225
"Option [inference_id] must be a valid string, found [{ \"a\": 123 }]"
@@ -4311,12 +4305,6 @@ public void testInvalidCompletion() {
43114305

43124306
expectError("FROM foo* | COMPLETION completion=prompt WITH", "ine 1:46: mismatched input '<EOF>' expecting '{'");
43134307

4314-
var fromPatterns = randomIndexPatterns(CROSS_CLUSTER);
4315-
expectError(
4316-
"FROM " + fromPatterns + " | COMPLETION prompt_field WITH { \"inference_id\" : \"inference_id\" }",
4317-
"invalid index pattern [" + unquoteIndexPattern(fromPatterns) + "], remote clusters are not supported with COMPLETION"
4318-
);
4319-
43204308
expectError(
43214309
"FROM foo* | COMPLETION prompt WITH { \"inference_id\": { \"a\": 123 } }",
43224310
"line 1:54: Option [inference_id] must be a valid string, found [{ \"a\": 123 }]"

0 commit comments

Comments
 (0)