Skip to content

Commit 0b1008f

Browse files
authored
custom AppendDateFormat (#3322)
1 parent 4462fa0 commit 0b1008f

File tree

2 files changed

+72
-42
lines changed

2 files changed

+72
-42
lines changed

sql/types/datetime.go

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"fmt"
2020
"math"
2121
"reflect"
22+
"strconv"
2223
"time"
2324
"unicode"
2425

@@ -452,59 +453,85 @@ func (t datetimeType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltype
452453
return sqltypes.Value{}, err
453454
}
454455

455-
var typ query.Type
456-
var val []byte
457-
458456
switch t.baseType {
459457
case sqltypes.Date:
460-
typ = sqltypes.Date
461458
if vt.Equal(ZeroTime) {
462-
val = vt.AppendFormat(dest, ZeroDateStr)
459+
dest = append(dest, ZeroDateStr...)
463460
} else {
464-
val = vt.AppendFormat(dest, sql.DateLayout)
461+
dest = appendDateFormat(dest, vt)
465462
}
466-
case sqltypes.Datetime:
467-
typ = sqltypes.Datetime
468-
if vt.Equal(ZeroTime) {
469-
val = vt.AppendFormat(dest, ZeroTimestampDatetimeStr)
470-
} else {
471-
val = vt.AppendFormat(dest, sql.TimestampDatetimeLayout)
472-
}
473-
case sqltypes.Timestamp:
474-
typ = sqltypes.Timestamp
463+
case sqltypes.Datetime, sqltypes.Timestamp:
475464
if vt.Equal(ZeroTime) {
476-
val = vt.AppendFormat(dest, ZeroTimestampDatetimeStr)
465+
dest = append(dest, ZeroTimestampDatetimeStr...)
477466
} else {
478-
val = vt.AppendFormat(dest, sql.TimestampDatetimeLayout)
467+
dest = appendDatetimeFormat(dest, vt)
479468
}
480469
default:
481470
return sqltypes.Value{}, sql.ErrInvalidBaseType.New(t.baseType.String(), "datetime")
482471
}
483472

484-
valBytes := val
485-
486-
return sqltypes.MakeTrusted(typ, valBytes), nil
473+
return sqltypes.MakeTrusted(t.baseType, dest), nil
487474
}
488475

489476
// SQLValue implements the ValueType interface.
490477
func (t datetimeType) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) {
491478
if v.IsNull() {
492479
return sqltypes.NULL, nil
493480
}
481+
494482
switch t.baseType {
495483
case sqltypes.Date:
496-
t := values.ReadDate(v.Val)
497-
dest = t.AppendFormat(dest, sql.DateLayout)
484+
vt := values.ReadDate(v.Val)
485+
if vt.Equal(ZeroTime) {
486+
dest = append(dest, ZeroDateStr...)
487+
} else {
488+
dest = appendDateFormat(dest, vt)
489+
}
498490
case sqltypes.Datetime, sqltypes.Timestamp:
499491
x := values.ReadInt64(v.Val)
500-
t := time.UnixMicro(x).UTC()
501-
dest = t.AppendFormat(dest, sql.TimestampDatetimeLayout)
492+
vt := time.UnixMicro(x).UTC()
493+
if vt.Equal(ZeroTime) {
494+
dest = append(dest, ZeroTimestampDatetimeStr...)
495+
} else {
496+
dest = appendDatetimeFormat(dest, vt)
497+
}
502498
default:
503499
return sqltypes.Value{}, sql.ErrInvalidBaseType.New(t.baseType.String(), "datetime")
504500
}
505501
return sqltypes.MakeTrusted(t.baseType, dest), nil
506502
}
507503

504+
func appendDateFormat(dest []byte, t time.Time) []byte {
505+
year := t.Year()
506+
if year == 0 {
507+
dest = append(dest, '0', '0', '0', '0')
508+
} else {
509+
dest = strconv.AppendInt(dest, int64(year), 10)
510+
}
511+
dest = append(dest, '-')
512+
513+
month := int64(t.Month())
514+
if month < 10 {
515+
dest = append(dest, '0')
516+
}
517+
dest = strconv.AppendInt(dest, month, 10)
518+
dest = append(dest, '-')
519+
520+
day := int64(t.Day())
521+
if day < 10 {
522+
dest = append(dest, '0')
523+
}
524+
dest = strconv.AppendInt(dest, day, 10)
525+
return dest
526+
}
527+
528+
func appendDatetimeFormat(dest []byte, t time.Time) []byte {
529+
dest = appendDateFormat(dest, t)
530+
dest = append(dest, ' ')
531+
dest = appendTimeFormat(dest, int64(t.Hour()), int64(t.Minute()), int64(t.Second()), int64(t.Nanosecond()/1000))
532+
return dest
533+
}
534+
508535
func (t datetimeType) String() string {
509536
switch t.baseType {
510537
case sqltypes.Date:

sql/types/time.go

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -264,20 +264,22 @@ func (t TimespanType_) SQL(_ *sql.Context, dest []byte, v interface{}) (sqltypes
264264
if v == nil {
265265
return sqltypes.NULL, nil
266266
}
267+
267268
ti, err := t.ConvertToTimespan(v)
268269
if err != nil {
269270
return sqltypes.Value{}, err
270271
}
271272

272-
val := ti.Bytes()
273-
return sqltypes.MakeTrusted(sqltypes.Time, val), nil
273+
dest = ti.AppendBytes(dest)
274+
return sqltypes.MakeTrusted(sqltypes.Time, dest), nil
274275
}
275276

276277
// SQLValue implements ValueType interface.
277278
func (t TimespanType_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) {
278279
if v.IsNull() {
279280
return sqltypes.NULL, nil
280281
}
282+
281283
x := values.ReadInt64(v.Val)
282284
dest = Timespan(x).AppendBytes(dest)
283285
return sqltypes.MakeTrusted(sqltypes.Time, dest), nil
@@ -502,39 +504,40 @@ func (t Timespan) Bytes() []byte {
502504
}
503505

504506
func (t Timespan) AppendBytes(dest []byte) []byte {
505-
isNegative, hours, minutes, seconds, microseconds := t.timespanToUnits()
506-
sz := 10
507-
if microseconds > 0 {
508-
sz += 7
509-
}
510-
if isNegative {
507+
isNeg, h, m, s, ms := t.timespanToUnits()
508+
if isNeg {
511509
dest = append(dest, '-')
512510
}
511+
dest = appendTimeFormat(dest, int64(h), int64(m), int64(s), int64(ms))
512+
return dest
513+
}
513514

514-
if hours < 10 {
515+
func appendTimeFormat(dest []byte, h, m, s, ms int64) []byte {
516+
if h < 10 {
515517
dest = append(dest, '0')
516518
}
517-
dest = strconv.AppendInt(dest, int64(hours), 10)
519+
dest = strconv.AppendInt(dest, h, 10)
518520
dest = append(dest, ':')
519521

520-
if minutes < 10 {
522+
if m < 10 {
521523
dest = append(dest, '0')
522524
}
523-
dest = strconv.AppendInt(dest, int64(minutes), 10)
525+
dest = strconv.AppendInt(dest, m, 10)
524526
dest = append(dest, ':')
525527

526-
if seconds < 10 {
528+
if s < 10 {
527529
dest = append(dest, '0')
528530
}
529-
dest = strconv.AppendInt(dest, int64(seconds), 10)
530-
if microseconds > 0 {
531+
dest = strconv.AppendInt(dest, s, 10)
532+
533+
if ms > 0 {
531534
dest = append(dest, '.')
532-
cmp := int32(100000)
533-
for cmp > 0 && microseconds < cmp {
535+
cmp := int64(100000)
536+
for cmp > 0 && ms < cmp {
534537
dest = append(dest, '0')
535538
cmp /= 10
536539
}
537-
dest = strconv.AppendInt(dest, int64(microseconds), 10)
540+
dest = strconv.AppendInt(dest, ms, 10)
538541
}
539542
return dest
540543
}

0 commit comments

Comments
 (0)