From 5403fc8f9c2cf5dafd73d051ad17775e2a8becdb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Osipiuk?= Date: Fri, 5 Dec 2025 19:41:05 +0100 Subject: [PATCH] Verify block types in ExchangeOperator --- .../io/trino/operator/ExchangeOperator.java | 12 ++ .../OutputValidatingSourceOperator.java | 128 ++++++++++++++++++ .../java/io/trino/split/PageValidations.java | 44 ++++++ .../sql/planner/LocalExecutionPlanner.java | 2 + .../trino/operator/TestExchangeOperator.java | 1 + 5 files changed, 187 insertions(+) create mode 100644 core/trino-main/src/main/java/io/trino/operator/OutputValidatingSourceOperator.java diff --git a/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java b/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java index 8e08cbfd5118..041f23827d84 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/ExchangeOperator.java @@ -13,6 +13,7 @@ */ package io.trino.operator; +import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import com.google.errorprone.annotations.ThreadSafe; import io.airlift.slice.Slice; @@ -29,14 +30,18 @@ import io.trino.spi.catalog.CatalogName; import io.trino.spi.connector.CatalogVersion; import io.trino.spi.exchange.ExchangeId; +import io.trino.spi.type.Type; import io.trino.split.RemoteSplit; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.util.Ciphers; import it.unimi.dsi.fastutil.ints.IntOpenHashSet; import it.unimi.dsi.fastutil.ints.IntSet; +import java.util.List; + import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static io.trino.SystemSessionProperties.isSourcePagesValidationEnabled; import static io.trino.connector.CatalogHandle.createRootCatalogHandle; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -51,6 +56,7 @@ public static class ExchangeOperatorFactory { private final int operatorId; private final PlanNodeId sourceId; + private final List outputTypes; private final DirectExchangeClientSupplier directExchangeClientSupplier; private final PagesSerdeFactory serdeFactory; private final RetryPolicy retryPolicy; @@ -64,6 +70,7 @@ public static class ExchangeOperatorFactory public ExchangeOperatorFactory( int operatorId, PlanNodeId sourceId, + List outputTypes, DirectExchangeClientSupplier directExchangeClientSupplier, PagesSerdeFactory serdeFactory, RetryPolicy retryPolicy, @@ -71,6 +78,7 @@ public ExchangeOperatorFactory( { this.operatorId = operatorId; this.sourceId = requireNonNull(sourceId, "sourceId is null"); + this.outputTypes = ImmutableList.copyOf(requireNonNull(outputTypes, "outputTypes is null")); this.directExchangeClientSupplier = requireNonNull(directExchangeClientSupplier, "directExchangeClientSupplier is null"); this.serdeFactory = requireNonNull(serdeFactory, "serdeFactory is null"); this.retryPolicy = requireNonNull(retryPolicy, "retryPolicy is null"); @@ -115,6 +123,10 @@ public SourceOperator createOperator(DriverContext driverContext) noMoreSplitsTracker, operatorInstanceId); noMoreSplitsTracker.operatorAdded(operatorInstanceId); + + if (isSourcePagesValidationEnabled(taskContext.getSession())) { + return new OutputValidatingSourceOperator(exchangeOperator, outputTypes); + } return exchangeOperator; } diff --git a/core/trino-main/src/main/java/io/trino/operator/OutputValidatingSourceOperator.java b/core/trino-main/src/main/java/io/trino/operator/OutputValidatingSourceOperator.java new file mode 100644 index 000000000000..dcb2e55e7c11 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/OutputValidatingSourceOperator.java @@ -0,0 +1,128 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator; + +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.ListenableFuture; +import io.trino.metadata.Split; +import io.trino.spi.Page; +import io.trino.spi.type.Type; +import io.trino.split.PageValidations; +import io.trino.sql.planner.plan.PlanNodeId; + +import java.util.List; +import java.util.function.Supplier; + +import static java.util.Objects.requireNonNull; + +public class OutputValidatingSourceOperator + implements SourceOperator +{ + private final SourceOperator delegate; + private final List outputTypes; + private final Supplier debugContextSupplier; + + public OutputValidatingSourceOperator(SourceOperator delegate, List outputTypes) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + this.outputTypes = ImmutableList.copyOf(requireNonNull(outputTypes, "outputTypes is null")); + this.debugContextSupplier = () -> "operator=%s, taskId/operatorId=%s/%s,".formatted( + delegate.getClass().getSimpleName(), + delegate.getOperatorContext().getDriverContext().getTaskId(), + delegate.getOperatorContext().getOperatorId()); + } + + @Override + public PlanNodeId getSourceId() + { + return delegate.getSourceId(); + } + + @Override + public void addSplit(Split split) + { + delegate.addSplit(split); + } + + @Override + public void noMoreSplits() + { + delegate.noMoreSplits(); + } + + @Override + public OperatorContext getOperatorContext() + { + return delegate.getOperatorContext(); + } + + @Override + public ListenableFuture isBlocked() + { + return delegate.isBlocked(); + } + + @Override + public boolean needsInput() + { + return delegate.needsInput(); + } + + @Override + public void addInput(Page page) + { + delegate.addInput(page); + } + + @Override + public Page getOutput() + { + Page page = delegate.getOutput(); + if (page != null) { + PageValidations.validateOutputPageTypes(page, outputTypes, debugContextSupplier); + } + return page; + } + + @Override + public ListenableFuture startMemoryRevoke() + { + return delegate.startMemoryRevoke(); + } + + @Override + public void finishMemoryRevoke() + { + delegate.finishMemoryRevoke(); + } + + @Override + public void finish() + { + delegate.finish(); + } + + @Override + public boolean isFinished() + { + return delegate.isFinished(); + } + + @Override + public void close() + throws Exception + { + delegate.close(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/split/PageValidations.java b/core/trino-main/src/main/java/io/trino/split/PageValidations.java index b06c84fc6d09..40d51e9b758d 100644 --- a/core/trino-main/src/main/java/io/trino/split/PageValidations.java +++ b/core/trino-main/src/main/java/io/trino/split/PageValidations.java @@ -14,6 +14,7 @@ package io.trino.split; import com.google.common.base.Joiner; +import io.trino.spi.Page; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.ValueBlock; @@ -64,6 +65,40 @@ public static void validateOutputPageTypes(SourcePage page, List expectedT } } + public static void validateOutputPageTypes(Page page, List expectedTypes, Supplier debugContextSupplier) + { + if (page.getChannelCount() != expectedTypes.size()) { + throw new TrinoException( + GENERIC_INTERNAL_ERROR, + "Invalid number of channels; got %s expected %s; context=%s; blocks=%s; types=%s".formatted( + page.getChannelCount(), + expectedTypes.size(), + debugContextSupplier.get(), + blocksDebugInfo(page), + expectedTypes)); + } + + List mismatches = null; + for (int channel = 0; channel < expectedTypes.size(); channel++) { + Type type = expectedTypes.get(channel); + Block block = page.getBlock(channel); + if (!isBlockValidForType(block, type)) { + if (mismatches == null) { + mismatches = new ArrayList<>(); + } + mismatches.add("Bad block %s for channel %s of type %s".formatted(blockDebugInfo(block), channel, type)); + } + } + if (mismatches != null) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, + "Bad block types found for context %s; blocks=%s; types=%s; mismatches=%s".formatted( + debugContextSupplier.get(), + blocksDebugInfo(page), + expectedTypes, + mismatches)); + } + } + private static boolean isBlockValidForType(Block block, Type type) { ValueBlock underlyingValueBlock = block.getUnderlyingValueBlock(); @@ -79,6 +114,15 @@ private static String blocksDebugInfo(SourcePage page) return Joiner.on(",").join(debugInfos); } + private static String blocksDebugInfo(Page page) + { + ArrayList debugInfos = new ArrayList<>(); + for (int i = 0; i < page.getChannelCount(); i++) { + debugInfos.add(blockDebugInfo(page.getBlock(i))); + } + return Joiner.on(",").join(debugInfos); + } + private static String blockDebugInfo(Block block) { ValueBlock underlyingValueBlock = block.getUnderlyingValueBlock(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index d63907662e3a..31aa2620fd0b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -949,9 +949,11 @@ private PhysicalOperation createRemoteSource(RemoteSourceNode node, LocalExecuti context.setDriverInstanceCount(getTaskConcurrency(session)); } + List types = node.getOutputSymbols().stream().map(Symbol::type).collect(toImmutableList()); OperatorFactory operatorFactory = new ExchangeOperatorFactory( context.getNextOperatorId(), node.getId(), + types, directExchangeClientSupplier, createExchangePagesSerdeFactory(plannerContext.getBlockEncodingSerde(), session), node.getRetryPolicy(), diff --git a/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java index b875542800b7..ee4f73e7d27f 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestExchangeOperator.java @@ -268,6 +268,7 @@ private SourceOperator createExchangeOperator() ExchangeOperatorFactory operatorFactory = new ExchangeOperatorFactory( 0, new PlanNodeId("test"), + TYPES, directExchangeClientSupplier, SERDE_FACTORY, RetryPolicy.NONE,