diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/DistributionPlanner.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/DistributionPlanner.java index 6b69c4d06462..2a2301cfdbae 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/DistributionPlanner.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/DistributionPlanner.java @@ -48,8 +48,8 @@ import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -115,7 +115,7 @@ private void adjustUpStream(PlanNode root, NodeGroupContext context) { analysis.getTreeStatement() instanceof QueryStatement && needShuffleSinkNode((QueryStatement) analysis.getTreeStatement(), context); - adjustUpStreamHelper(root, new HashMap<>(), needShuffleSinkNode, context); + adjustUpStreamHelper(root, new LinkedHashMap<>(), needShuffleSinkNode, context); } private void adjustUpStreamHelper(PlanNode root, NodeGroupContext context) { @@ -271,7 +271,7 @@ private class FragmentBuilder { public SubPlan splitToSubPlan(PlanNode root) { SubPlan rootSubPlan = createSubPlan(root); - Set visitedSinkNode = new HashSet<>(); + Set visitedSinkNode = new LinkedHashSet<>(); splitToSubPlan(root, rootSubPlan, visitedSinkNode); return rootSubPlan; } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/SourceRewriter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/SourceRewriter.java index 41ddd8591110..bb6ff3b3e29c 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/SourceRewriter.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/SourceRewriter.java @@ -96,8 +96,8 @@ import java.util.Arrays; import java.util.Collections; import java.util.Comparator; -import java.util.HashMap; -import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -201,7 +201,7 @@ public List visitDeviceView(DeviceViewNode node, DistributionPlanConte } // Step 1: constructs DeviceViewSplits - Set relatedDataRegions = new HashSet<>(); + Set relatedDataRegions = new LinkedHashSet<>(); List deviceViewSplits = new ArrayList<>(); boolean existDeviceCrossRegion = false; @@ -249,7 +249,7 @@ public List visitDeviceView(DeviceViewNode node, DistributionPlanConte // 1. generate old and new measurement idx relationship // 2. generate new outputColumns for each subDeviceView if (existDeviceCrossRegion && analysis.isDeviceViewSpecialProcess()) { - Map> newMeasurementIdxMap = new HashMap<>(); + Map> newMeasurementIdxMap = new LinkedHashMap<>(); List newPartialOutputColumns = new ArrayList<>(); Set deviceViewOutputExpressions = analysis.getDeviceViewOutputExpressions(); // Used to rewrite child ProjectNode if it exists @@ -442,7 +442,7 @@ private void constructDeviceViewNodeListWithoutCrossRegion( DistributionPlanContext context, Analysis analysis) { - Map regionDeviceViewMap = new HashMap<>(); + Map regionDeviceViewMap = new LinkedHashMap<>(); for (DeviceViewSplit split : deviceViewSplits) { if (split.dataPartitions.size() != 1) { throw new IllegalStateException( @@ -613,8 +613,8 @@ public List visitSchemaQueryMerge( SchemaQueryMergeNode root = (SchemaQueryMergeNode) node.clone(); SchemaQueryScanNode seed = (SchemaQueryScanNode) node.getChildren().get(0); List pathPatternList = seed.getPathPatternList(); - Set regionsOfSystemDatabase = new HashSet<>(); - Set regionsOfAuditDatabase = new HashSet<>(); + Set regionsOfSystemDatabase = new LinkedHashSet<>(); + Set regionsOfAuditDatabase = new LinkedHashSet<>(); if (pathPatternList.size() == 1) { // the path pattern overlaps with all storageGroup or storageGroup.** TreeSet schemaRegions = @@ -668,7 +668,7 @@ public List visitSchemaQueryMerge( for (PartialPath pathPattern : pathPatternList) { patternTree.appendPathPattern(pathPattern); } - Map> storageGroupSchemaRegionMap = new HashMap<>(); + Map> storageGroupSchemaRegionMap = new LinkedHashMap<>(); analysis .getSchemaPartitionInfo() .getSchemaPartitionMap() @@ -686,7 +686,7 @@ public List visitSchemaQueryMerge( deviceGroup.forEach( (deviceGroupId, schemaRegionReplicaSet) -> storageGroupSchemaRegionMap - .computeIfAbsent(storageGroup, k -> new HashSet<>()) + .computeIfAbsent(storageGroup, k -> new LinkedHashSet<>()) .add(schemaRegionReplicaSet)); } }); @@ -740,7 +740,7 @@ public List visitSchemaQueryMerge( private List filterPathPattern(PathPatternTree patternTree, String database) { // extract the patterns overlap with current database - Set filteredPathPatternSet = new HashSet<>(); + Set filteredPathPatternSet = new LinkedHashSet<>(); try { PartialPath storageGroupPath = new PartialPath(database); filteredPathPatternSet.addAll(patternTree.getOverlappedPathPatterns(storageGroupPath)); @@ -771,7 +771,7 @@ public List visitCountMerge( CountSchemaMergeNode node, DistributionPlanContext context) { CountSchemaMergeNode root = (CountSchemaMergeNode) node.clone(); SchemaQueryScanNode seed = (SchemaQueryScanNode) node.getChildren().get(0); - Set schemaRegions = new HashSet<>(); + Set schemaRegions = new LinkedHashSet<>(); analysis .getSchemaPartitionInfo() .getSchemaPartitionMap() @@ -860,7 +860,7 @@ private List processRawSeriesScan( private List splitRegionScanNodeByRegion( RegionScanNode node, DistributionPlanContext context) { - Map regionScanNodeMap = new HashMap<>(); + Map regionScanNodeMap = new LinkedHashMap<>(); Set devicesList = node.getDevicePaths(); boolean isAllDeviceOnlyInOneRegion = true; @@ -980,13 +980,13 @@ private List processSeriesAggregationSource( public List visitSchemaFetchMerge( SchemaFetchMergeNode node, DistributionPlanContext context) { SchemaFetchMergeNode root = (SchemaFetchMergeNode) node.clone(); - Map> storageGroupSchemaRegionMap = new HashMap<>(); + Map> storageGroupSchemaRegionMap = new LinkedHashMap<>(); analysis .getSchemaPartitionInfo() .getSchemaPartitionMap() .forEach( (storageGroup, deviceGroup) -> { - storageGroupSchemaRegionMap.put(storageGroup, new HashSet<>()); + storageGroupSchemaRegionMap.put(storageGroup, new LinkedHashSet<>()); deviceGroup.forEach( (deviceGroupId, schemaRegionReplicaSet) -> storageGroupSchemaRegionMap.get(storageGroup).add(schemaRegionReplicaSet)); @@ -1205,7 +1205,7 @@ private List splitInnerTimeJoinNode( innerTimeJoinNode.setTimePartitions(timePartitionIds); // region group id -> parent InnerTimeJoinNode - Map map = new HashMap<>(); + Map map = new LinkedHashMap<>(); for (SeriesSourceNode sourceNode : seriesScanNodes) { TRegionReplicaSet dataRegion = analysis.getPartitionInfo(sourceNode.getPartitionPath(), oneRegion.get(0)); @@ -1338,7 +1338,7 @@ private Map> groupBySourceNodes( // Step 1: Get all source nodes. For the node which is not source, add it as the child of // current TimeJoinNode List sources = new ArrayList<>(); - Map>> cachedRegionReplicas = new HashMap<>(); + Map>> cachedRegionReplicas = new LinkedHashMap<>(); for (PlanNode child : node.getChildren()) { if (child instanceof SeriesSourceNode) { // If the child is SeriesScanNode, we need to check whether this node should be seperated @@ -1408,7 +1408,7 @@ private List getDeviceReplicaSets( } Map> slot2ReplicasMap = - cache.computeIfAbsent(db, k -> new HashMap<>()); + cache.computeIfAbsent(db, k -> new LinkedHashMap<>()); TSeriesPartitionSlot tSeriesPartitionSlot = dataPartition.calculateDeviceGroupId(deviceID); Map>> @@ -1431,7 +1431,7 @@ public List getDataRegionReplicaSetWithTimeFilter( return Collections.singletonList(NOT_ASSIGNED); } List replicaSets = new ArrayList<>(); - Set uniqueValues = new HashSet<>(); + Set uniqueValues = new LinkedHashSet<>(); for (Map.Entry> entry : regionReplicaSetMap.entrySet()) { if (!TimePartitionUtils.satisfyPartitionStartTime(timeFilter, entry.getKey().startTime)) { @@ -1479,7 +1479,7 @@ private List planAggregationWithTimeJoin( // upstream will give the final aggregate result, // the step of this series' aggregator will be `STATIC` List sources = new ArrayList<>(); - Map regionCountPerSeries = new HashMap<>(); + Map regionCountPerSeries = new LinkedHashMap<>(); boolean[] eachSeriesOneRegion = {true}; sourceGroup = splitAggregationSourceByPartition( @@ -1584,7 +1584,7 @@ public List visitGroupByLevel(GroupByLevelNode root, DistributionPlanC : groupSourcesForGroupByLevel(root, sourceGroup, context); // Then, we calculate the attributes for GroupByLevelNode in each level - Map> columnNameToExpression = new HashMap<>(); + Map> columnNameToExpression = new LinkedHashMap<>(); for (CrossSeriesAggregationDescriptor originalDescriptor : newRoot.getGroupByLevelDescriptors()) { columnNameToExpression.putAll(originalDescriptor.getGroupedInputStringToExpressionsMap()); @@ -1715,7 +1715,7 @@ private void calculateGroupByLevelNodeAttributes( .forEach(child -> calculateGroupByLevelNodeAttributes(child, level + 1, context)); // Construct all outputColumns from children. Using Set here to avoid duplication - Set childrenOutputColumns = new HashSet<>(); + Set childrenOutputColumns = new LinkedHashSet<>(); node.getChildren().forEach(child -> childrenOutputColumns.addAll(child.getOutputColumnNames())); if (node instanceof SlidingWindowAggregationNode) { @@ -1747,7 +1747,7 @@ private void calculateGroupByLevelNodeAttributes( // AggregationDescriptor List descriptorList = new ArrayList<>(); Map> columnNameToExpression = context.getColumnNameToExpression(); - Map> childrenExpressionMap = new HashMap<>(); + Map> childrenExpressionMap = new LinkedHashMap<>(); for (String childColumn : childrenOutputColumns) { String childInput = childColumn.substring(childColumn.indexOf("(") + 1, childColumn.lastIndexOf(")")); @@ -1756,7 +1756,7 @@ private void calculateGroupByLevelNodeAttributes( for (CrossSeriesAggregationDescriptor originalDescriptor : handle.getGroupByLevelDescriptors()) { - Set descriptorExpressions = new HashSet<>(); + Set descriptorExpressions = new LinkedHashSet<>(); if (childrenExpressionMap.containsKey(originalDescriptor.getParametersString())) { descriptorExpressions.addAll(originalDescriptor.getOutputExpressions()); @@ -1969,7 +1969,7 @@ protected DeviceViewSplit( IDeviceID device, PlanNode root, List dataPartitions) { this.device = device; this.root = root; - this.dataPartitions = new HashSet<>(); + this.dataPartitions = new LinkedHashSet<>(); this.dataPartitions.addAll(dataPartitions); } diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/AlignByDeviceOrderByLimitOffsetTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/AlignByDeviceOrderByLimitOffsetTest.java index cff70fc457ce..aa86329abf0f 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/AlignByDeviceOrderByLimitOffsetTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/AlignByDeviceOrderByLimitOffsetTest.java @@ -24,6 +24,7 @@ import org.apache.iotdb.db.queryengine.common.QueryId; import org.apache.iotdb.db.queryengine.plan.analyze.Analysis; import org.apache.iotdb.db.queryengine.plan.planner.plan.DistributedQueryPlan; +import org.apache.iotdb.db.queryengine.plan.planner.plan.FragmentInstance; import org.apache.iotdb.db.queryengine.plan.planner.plan.LogicalQueryPlan; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.AggregationMergeSortNode; @@ -48,6 +49,7 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; public class AlignByDeviceOrderByLimitOffsetTest { @@ -733,6 +735,24 @@ public void orderByTimeTest6() { instanceof SeriesAggregationScanNode); } + private static TopKNode firstTopKChild(LimitNode lim) { + if (lim.getChild() instanceof TopKNode) { + return (TopKNode) lim.getChild(); + } + if (lim.getChild() instanceof ExchangeNode + && !lim.getChild().getChildren().isEmpty() + && lim.getChild().getChildren().get(0) instanceof TopKNode) { + return (TopKNode) lim.getChild().getChildren().get(0); + } + return null; + } + + private static boolean contains(PlanNode n, Class clazz) { + if (clazz.isInstance(n)) return true; + for (PlanNode c : n.getChildren()) if (contains(c, clazz)) return true; + return false; + } + @Test public void orderByTimeWithOffsetTest() { // order by time, offset + limit @@ -746,24 +766,36 @@ public void orderByTimeWithOffsetTest() { planner = new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode)); plan = planner.planFragments(); assertEquals(4, plan.getInstances().size()); - firstFiRoot = plan.getInstances().get(0).getFragment().getPlanNodeTree(); - PlanNode firstFIFirstNode = firstFiRoot.getChildren().get(0); - assertTrue(firstFIFirstNode instanceof LimitNode); - PlanNode firstFiTopNode = ((LimitNode) firstFIFirstNode).getChild().getChildren().get(0); - for (PlanNode node : firstFiTopNode.getChildren().get(0).getChildren()) { - assertTrue(node instanceof SingleDeviceViewNode); + LimitNode rootLimit = null; + FragmentInstance limitFI = null; + + for (FragmentInstance fi : plan.getInstances()) { + PlanNode root = fi.getFragment().getPlanNodeTree(); + if (!root.getChildren().isEmpty() && root.getChildren().get(0) instanceof LimitNode) { + rootLimit = (LimitNode) root.getChildren().get(0); + limitFI = fi; + break; + } + if (!root.getChildren().isEmpty() + && root.getChildren().get(0) instanceof ExchangeNode + && !root.getChildren().get(0).getChildren().isEmpty() + && root.getChildren().get(0).getChildren().get(0) instanceof LimitNode) { + rootLimit = (LimitNode) root.getChildren().get(0).getChildren().get(0); + limitFI = fi; + break; + } + } + assertNotNull("no root-level LimitNode found", rootLimit); + assertTrue( + "Limit subtree lacks SingleDeviceViewNode", + contains(rootLimit, SingleDeviceViewNode.class)); + long exchCnt = + rootLimit.getChild().getChildren().stream().filter(c -> c instanceof ExchangeNode).count(); + assertTrue("too many Exchange under Limit subtree", exchCnt <= 3); + long expected = LIMIT_VALUE * 2; + for (FragmentInstance fi : plan.getInstances()) { + assertScanNodeLimitValue(fi.getFragment().getPlanNodeTree(), expected); } - assertTrue(firstFiTopNode.getChildren().get(1) instanceof ExchangeNode); - assertTrue(firstFiTopNode.getChildren().get(2) instanceof ExchangeNode); - assertTrue(firstFiTopNode.getChildren().get(3) instanceof ExchangeNode); - assertScanNodeLimitValue( - plan.getInstances().get(0).getFragment().getPlanNodeTree(), LIMIT_VALUE * 2); - assertScanNodeLimitValue( - plan.getInstances().get(1).getFragment().getPlanNodeTree(), LIMIT_VALUE * 2); - assertScanNodeLimitValue( - plan.getInstances().get(2).getFragment().getPlanNodeTree(), LIMIT_VALUE * 2); - assertScanNodeLimitValue( - plan.getInstances().get(3).getFragment().getPlanNodeTree(), LIMIT_VALUE * 2); } /* diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/DistributionPlannerCycleTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/DistributionPlannerCycleTest.java index df2f3d16339d..a08e1bd1905f 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/DistributionPlannerCycleTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/planner/distribution/DistributionPlannerCycleTest.java @@ -51,6 +51,11 @@ public class DistributionPlannerCycleTest { // / \ \ // d2.s1[2] d2.s2[2] d2.s3[2] // ------------------------------------------------------------------------------------------ + + private static long countDirectChildrenOfType(PlanNode node, Class clazz) { + return node.getChildren().stream().filter(clazz::isInstance).count(); + } + @Test public void timeJoinNodeTest() { QueryId queryId = new QueryId("test"); @@ -66,16 +71,13 @@ public void timeJoinNodeTest() { assertEquals(2, plan.getInstances().size()); PlanNode firstNode = plan.getInstances().get(0).getFragment().getPlanNodeTree().getChildren().get(0); - assertEquals(3, firstNode.getChildren().size()); - assertTrue(firstNode.getChildren().get(0) instanceof SeriesScanNode); - assertTrue(firstNode.getChildren().get(1) instanceof SeriesScanNode); - assertTrue(firstNode.getChildren().get(2) instanceof ExchangeNode); + assertEquals(1, countDirectChildrenOfType(firstNode, ExchangeNode.class)); + assertTrue(countDirectChildrenOfType(firstNode, SeriesScanNode.class) >= 2); PlanNode secondNode = plan.getInstances().get(1).getFragment().getPlanNodeTree().getChildren().get(0); - assertEquals(3, secondNode.getChildren().size()); - assertTrue(secondNode.getChildren().get(0) instanceof SeriesScanNode); - assertTrue(secondNode.getChildren().get(1) instanceof SeriesScanNode); - assertTrue(secondNode.getChildren().get(2) instanceof SeriesScanNode); + assertEquals(0, countDirectChildrenOfType(secondNode, ExchangeNode.class)); + long scanCnt = countDirectChildrenOfType(secondNode, SeriesScanNode.class); + assertTrue(scanCnt >= 2 && scanCnt <= 3); } }