Skip to content

Commit e4e22e2

Browse files
author
James Cor
committed
use generics and apply to unsigned
1 parent 2446563 commit e4e22e2

File tree

1 file changed

+78
-20
lines changed

1 file changed

+78
-20
lines changed

sql/analyzer/costed_index_scan.go

Lines changed: 78 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package analyzer
1616

1717
import (
18+
"cmp"
1819
"fmt"
1920
"slices"
2021
"sort"
@@ -828,15 +829,24 @@ type indexScanRangeBuilder struct {
828829
leftover []sql.Expression
829830
}
830831

831-
func setToSignedIntRange(setVals []any, colExprTypes []sql.ColumnExpressionType) (sql.MySQLRangeCollection, bool) {
832-
if len(colExprTypes) != 1 {
833-
return nil, false
832+
func keysToRangeColl[N cmp.Ordered](keys []N, typ sql.Type) sql.MySQLRangeCollection {
833+
slices.Sort(keys)
834+
keys = slices.Compact(keys)
835+
// TODO: for integers, if len(keys) - 1 == keys[len(keys)-1] - keys[0],
836+
// then we can just have one continuous range. unsure if this is worth
837+
res := make(sql.MySQLRangeCollection, len(keys))
838+
for i, key := range keys {
839+
res[i] = sql.MySQLRange{
840+
sql.ClosedRangeColumnExpr(key, key, typ),
841+
}
834842
}
835-
typ := colExprTypes[0].Type
836-
if !types.IsSigned(typ) {
837-
return nil, false
843+
if len(res) == 0 {
844+
return nil
838845
}
846+
return res
847+
}
839848

849+
func setToIntRangeColl(setVals []any, typ sql.Type) (sql.MySQLRangeCollection, bool) {
840850
keys := make([]int64, 0, len(setVals))
841851
for _, val := range setVals {
842852
switch v := val.(type) {
@@ -882,19 +892,55 @@ func setToSignedIntRange(setVals []any, colExprTypes []sql.ColumnExpressionType)
882892
}
883893
}
884894

885-
slices.Sort(keys)
886-
keys = slices.Compact(keys)
887-
res := make(sql.MySQLRangeCollection, len(keys))
888-
for i, key := range keys {
889-
res[i] = sql.MySQLRange{
890-
sql.ClosedRangeColumnExpr(key, key, typ),
891-
}
892-
}
895+
return keysToRangeColl(keys, typ), true
896+
}
893897

894-
if len(res) == 0 {
895-
return nil, true
898+
func setToUintRangeColl(setVals []any, typ sql.Type) (sql.MySQLRangeCollection, bool) {
899+
keys := make([]uint64, 0, len(setVals))
900+
for _, val := range setVals {
901+
switch v := val.(type) {
902+
case int:
903+
keys = append(keys, uint64(v))
904+
case int8:
905+
keys = append(keys, uint64(v))
906+
case int16:
907+
keys = append(keys, uint64(v))
908+
case int32:
909+
keys = append(keys, uint64(v))
910+
case int64:
911+
keys = append(keys, uint64(v))
912+
case uint:
913+
keys = append(keys, uint64(v))
914+
case uint8:
915+
keys = append(keys, uint64(v))
916+
case uint16:
917+
keys = append(keys, uint64(v))
918+
case uint32:
919+
keys = append(keys, uint64(v))
920+
case uint64:
921+
keys = append(keys, v)
922+
// float32, float64, and decimal are ok as long as they don't round
923+
case float32:
924+
key := uint64(v)
925+
if float32(key) == v {
926+
keys = append(keys, key)
927+
}
928+
case float64:
929+
key := uint64(v)
930+
if float64(key) == v {
931+
keys = append(keys, key)
932+
}
933+
case decimal.Decimal:
934+
key := v.IntPart()
935+
if v.Equal(decimal.NewFromInt(key)) {
936+
keys = append(keys, uint64(key))
937+
}
938+
default:
939+
// resort to default behavior for types that require more conversion
940+
return nil, false
941+
}
896942
}
897-
return res, true
943+
return keysToRangeColl(keys, typ), true
898944
}
899945

900946
// buildRangeCollection converts our representation of the best index scan
@@ -910,11 +956,23 @@ func (b *indexScanRangeBuilder) buildRangeCollection(f indexFilter) (sql.MySQLRa
910956
case *iScanOr:
911957
ranges, err = b.rangeBuildOr(f, inScan)
912958
case *iScanLeaf:
913-
// TODO: special case for in set. can skip building range tree and overlapping range check since it's a series of equality checks
959+
// When the filter is a simple IN, we can skip costly checks like building the RangeTree.
914960
if f.Op() == sql.IndexScanOpInSet {
915961
cets := b.idx.ColumnExpressionTypes()
916-
if ranges, ok := setToSignedIntRange(f.setValues, cets); ok {
917-
return ranges, nil
962+
if len(cets) == 1 {
963+
typ := cets[0].Type
964+
var ok bool
965+
// TODO: it's possible to apply this optimization to other
966+
// numeric types (float32, float64, decimal).
967+
if types.IsSigned(typ) {
968+
if ranges, ok = setToIntRangeColl(f.setValues, typ); ok {
969+
return ranges, nil
970+
}
971+
} else if types.IsUnsigned(typ) {
972+
if ranges, ok = setToUintRangeColl(f.setValues, typ); ok {
973+
return ranges, nil
974+
}
975+
}
918976
}
919977
}
920978
ranges, err = b.rangeBuildLeaf(f, inScan)

0 commit comments

Comments
 (0)