@@ -828,26 +828,6 @@ type indexScanRangeBuilder struct {
828828 leftover []sql.Expression
829829}
830830
831- func castToInt64 (v any ) (int64 , bool ) {
832- switch v := v .(type ) {
833- case int :
834- return int64 (v ), true
835- case int8 :
836- return int64 (v ), true
837- case int16 :
838- return int64 (v ), true
839- case int32 :
840- return int64 (v ), true
841- case int64 :
842- return v , true
843- case float32 , float64 , decimal.Decimal :
844- // TODO: return an empty range here
845- return 0 , false
846- default :
847- return 0 , false
848- }
849- }
850-
851831func setToSignedIntRange (setVals []any , colExprTypes []sql.ColumnExpressionType ) (sql.MySQLRangeCollection , bool ) {
852832 if len (colExprTypes ) != 1 {
853833 return nil , false
@@ -856,22 +836,64 @@ func setToSignedIntRange(setVals []any, colExprTypes []sql.ColumnExpressionType)
856836 if ! types .IsSigned (typ ) {
857837 return nil , false
858838 }
859- var ok bool
860- keys := make ([]int64 , len (setVals ))
861- for i , val := range setVals {
862- keys [i ], ok = castToInt64 (val )
863- if ! ok {
839+
840+ keys := make ([]int64 , 0 , len (setVals ))
841+ for _ , val := range setVals {
842+ switch v := val .(type ) {
843+ case int :
844+ keys = append (keys , int64 (v ))
845+ case int8 :
846+ keys = append (keys , int64 (v ))
847+ case int16 :
848+ keys = append (keys , int64 (v ))
849+ case int32 :
850+ keys = append (keys , int64 (v ))
851+ case int64 :
852+ keys = append (keys , v )
853+ case uint :
854+ keys = append (keys , int64 (v ))
855+ case uint8 :
856+ keys = append (keys , int64 (v ))
857+ case uint16 :
858+ keys = append (keys , int64 (v ))
859+ case uint32 :
860+ keys = append (keys , int64 (v ))
861+ case uint64 :
862+ keys = append (keys , int64 (v ))
863+ // float32, float64, and decimal are ok as long as they don't round
864+ case float32 :
865+ key := int64 (v )
866+ if float32 (key ) == v {
867+ keys = append (keys , key )
868+ }
869+ case float64 :
870+ key := int64 (v )
871+ if float64 (key ) == v {
872+ keys = append (keys , key )
873+ }
874+ case decimal.Decimal :
875+ key := v .IntPart ()
876+ if v .Equal (decimal .NewFromInt (key )) {
877+ keys = append (keys , key )
878+ }
879+ default :
880+ // resort to default behavior for types that require more conversion
864881 return nil , false
865882 }
866883 }
884+
867885 slices .Sort (keys )
868- slices .Compact (keys )
886+ keys = slices .Compact (keys )
869887 res := make (sql.MySQLRangeCollection , len (keys ))
870888 for i , key := range keys {
871889 res [i ] = sql.MySQLRange {
872890 sql .ClosedRangeColumnExpr (key , key , typ ),
873891 }
874892 }
893+
894+ if len (res ) == 0 {
895+ return nil , true
896+ }
875897 return res , true
876898}
877899
0 commit comments