1515package analyzer
1616
1717import (
18+ "cmp"
1819 "fmt"
1920 "slices"
2021 "sort"
@@ -828,15 +829,24 @@ type indexScanRangeBuilder struct {
828829 leftover []sql.Expression
829830}
830831
831- func setToSignedIntRange (setVals []any , colExprTypes []sql.ColumnExpressionType ) (sql.MySQLRangeCollection , bool ) {
832- if len (colExprTypes ) != 1 {
833- return nil , false
832+ func keysToRangeColl [N cmp.Ordered ](keys []N , typ sql.Type ) sql.MySQLRangeCollection {
833+ slices .Sort (keys )
834+ keys = slices .Compact (keys )
835+ // TODO: for integers, if len(keys) - 1 == keys[len(keys)-1] - keys[0],
836+ // then we can just have one continuous range. unsure if this is worth
837+ res := make (sql.MySQLRangeCollection , len (keys ))
838+ for i , key := range keys {
839+ res [i ] = sql.MySQLRange {
840+ sql .ClosedRangeColumnExpr (key , key , typ ),
841+ }
834842 }
835- typ := colExprTypes [0 ].Type
836- if ! types .IsSigned (typ ) {
837- return nil , false
843+ if len (res ) == 0 {
844+ return nil
838845 }
846+ return res
847+ }
839848
849+ func setToIntRangeColl (setVals []any , typ sql.Type ) (sql.MySQLRangeCollection , bool ) {
840850 keys := make ([]int64 , 0 , len (setVals ))
841851 for _ , val := range setVals {
842852 switch v := val .(type ) {
@@ -882,19 +892,55 @@ func setToSignedIntRange(setVals []any, colExprTypes []sql.ColumnExpressionType)
882892 }
883893 }
884894
885- slices .Sort (keys )
886- keys = slices .Compact (keys )
887- res := make (sql.MySQLRangeCollection , len (keys ))
888- for i , key := range keys {
889- res [i ] = sql.MySQLRange {
890- sql .ClosedRangeColumnExpr (key , key , typ ),
891- }
892- }
895+ return keysToRangeColl (keys , typ ), true
896+ }
893897
894- if len (res ) == 0 {
895- return nil , true
898+ func setToUintRangeColl (setVals []any , typ sql.Type ) (sql.MySQLRangeCollection , bool ) {
899+ keys := make ([]uint64 , 0 , len (setVals ))
900+ for _ , val := range setVals {
901+ switch v := val .(type ) {
902+ case int :
903+ keys = append (keys , uint64 (v ))
904+ case int8 :
905+ keys = append (keys , uint64 (v ))
906+ case int16 :
907+ keys = append (keys , uint64 (v ))
908+ case int32 :
909+ keys = append (keys , uint64 (v ))
910+ case int64 :
911+ keys = append (keys , uint64 (v ))
912+ case uint :
913+ keys = append (keys , uint64 (v ))
914+ case uint8 :
915+ keys = append (keys , uint64 (v ))
916+ case uint16 :
917+ keys = append (keys , uint64 (v ))
918+ case uint32 :
919+ keys = append (keys , uint64 (v ))
920+ case uint64 :
921+ keys = append (keys , v )
922+ // float32, float64, and decimal are ok as long as they don't round
923+ case float32 :
924+ key := uint64 (v )
925+ if float32 (key ) == v {
926+ keys = append (keys , key )
927+ }
928+ case float64 :
929+ key := uint64 (v )
930+ if float64 (key ) == v {
931+ keys = append (keys , key )
932+ }
933+ case decimal.Decimal :
934+ key := v .IntPart ()
935+ if v .Equal (decimal .NewFromInt (key )) {
936+ keys = append (keys , uint64 (key ))
937+ }
938+ default :
939+ // resort to default behavior for types that require more conversion
940+ return nil , false
941+ }
896942 }
897- return res , true
943+ return keysToRangeColl ( keys , typ ) , true
898944}
899945
900946// buildRangeCollection converts our representation of the best index scan
@@ -910,11 +956,23 @@ func (b *indexScanRangeBuilder) buildRangeCollection(f indexFilter) (sql.MySQLRa
910956 case * iScanOr :
911957 ranges , err = b .rangeBuildOr (f , inScan )
912958 case * iScanLeaf :
913- // TODO: special case for in set. can skip building range tree and overlapping range check since it's a series of equality checks
959+ // When the filter is a simple IN, we can skip costly checks like building the RangeTree.
914960 if f .Op () == sql .IndexScanOpInSet {
915961 cets := b .idx .ColumnExpressionTypes ()
916- if ranges , ok := setToSignedIntRange (f .setValues , cets ); ok {
917- return ranges , nil
962+ if len (cets ) == 1 {
963+ typ := cets [0 ].Type
964+ var ok bool
965+ // TODO: it's possible to apply this optimization to other
966+ // numeric types (float32, float64, decimal).
967+ if types .IsSigned (typ ) {
968+ if ranges , ok = setToIntRangeColl (f .setValues , typ ); ok {
969+ return ranges , nil
970+ }
971+ } else if types .IsUnsigned (typ ) {
972+ if ranges , ok = setToUintRangeColl (f .setValues , typ ); ok {
973+ return ranges , nil
974+ }
975+ }
918976 }
919977 }
920978 ranges , err = b .rangeBuildLeaf (f , inScan )
0 commit comments