Skip to content

Commit 4462fa0

Browse files
authored
rewrite last query info (#3321)
1 parent 4d67403 commit 4462fa0

File tree

8 files changed

+41
-73
lines changed

8 files changed

+41
-73
lines changed

sql/base_session.go

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ type BaseSession struct {
3333
logger *logrus.Entry
3434
locks map[string]bool
3535
storedProcParams map[string]*StoredProcParam
36-
lastQueryInfo map[string]*atomic.Value
3736
systemVars map[string]SystemVarValue
3837
statusVars map[string]StatusVarValue
38+
lastQueryInfo *LastQueryInfo
3939
idxReg *IndexRegistry
4040
viewReg *ViewRegistry
4141
transactionDb string
@@ -497,28 +497,8 @@ func (s *BaseSession) SetViewRegistry(reg *ViewRegistry) {
497497
s.viewReg = reg
498498
}
499499

500-
func (s *BaseSession) SetLastQueryInfoInt(key string, value int64) {
501-
s.lastQueryInfo[key].Store(value)
502-
}
503-
504-
func (s *BaseSession) GetLastQueryInfoInt(key string) int64 {
505-
value, ok := s.lastQueryInfo[key].Load().(int64)
506-
if !ok {
507-
panic(fmt.Sprintf("last query info value stored for %s is not an int64 value, but a %T", key, s.lastQueryInfo[key]))
508-
}
509-
return value
510-
}
511-
512-
func (s *BaseSession) SetLastQueryInfoString(key string, value string) {
513-
s.lastQueryInfo[key].Store(value)
514-
}
515-
516-
func (s *BaseSession) GetLastQueryInfoString(key string) string {
517-
value, ok := s.lastQueryInfo[key].Load().(string)
518-
if !ok {
519-
panic(fmt.Sprintf("last query info value stored for %s is not a string value, but a %T", key, s.lastQueryInfo[key]))
520-
}
521-
return value
500+
func (s *BaseSession) GetLastQueryInfo() *LastQueryInfo {
501+
return s.lastQueryInfo
522502
}
523503

524504
func (s *BaseSession) GetTransaction() Transaction {

sql/expression/auto_uuid.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func (au *AutoUuid) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
7777
// read this value too early. We should verify this isn't how MySQL behaves, and then could fix
7878
// by setting a PENDING_LAST_INSERT_UUID value in the session query info, then moving it to
7979
// LAST_INSERT_UUID in the session query info at the end of execution.
80-
ctx.Session.SetLastQueryInfoString(sql.LastInsertUuid, uuidValue)
80+
ctx.GetLastQueryInfo().LastInsertUUID.Store(uuidValue)
8181
au.foundUuid = true
8282
}
8383

sql/expression/function/queryinfo.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func (r *RowCount) IsNullable() bool {
6666

6767
// Eval implements sql.Expression
6868
func (r *RowCount) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
69-
return ctx.GetLastQueryInfoInt(sql.RowCount), nil
69+
return ctx.GetLastQueryInfo().RowCount.Load(), nil
7070
}
7171

7272
// Children implements sql.Expression
@@ -126,8 +126,8 @@ func (l *LastInsertUuid) IsNullable() bool {
126126
}
127127

128128
func (l *LastInsertUuid) Eval(ctx *sql.Context, _ sql.Row) (interface{}, error) {
129-
lastInsertUuid := ctx.GetLastQueryInfoString(sql.LastInsertUuid)
130-
result, _, err := l.Type().Convert(ctx, lastInsertUuid)
129+
lastInsertUUID := ctx.GetLastQueryInfo().LastInsertUUID.Load()
130+
result, _, err := l.Type().Convert(ctx, lastInsertUUID)
131131
if err != nil {
132132
return nil, err
133133
}
@@ -204,7 +204,7 @@ func (r *LastInsertId) IsNullable() bool {
204204
func (r *LastInsertId) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
205205
// With no arguments, just return the last insert id for this session
206206
if len(r.Children()) == 0 {
207-
lastInsertId := ctx.GetLastQueryInfoInt(sql.LastInsertId)
207+
lastInsertId := ctx.GetLastQueryInfo().LastInsertId.Load()
208208
unsigned, _, err := types.Uint64.Convert(ctx, lastInsertId)
209209
if err != nil {
210210
return nil, err
@@ -222,7 +222,7 @@ func (r *LastInsertId) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
222222
return nil, err
223223
}
224224

225-
ctx.SetLastQueryInfoInt(sql.LastInsertId, id.(int64))
225+
ctx.GetLastQueryInfo().LastInsertId.Store(id.(int64))
226226
return id, nil
227227
}
228228

@@ -296,7 +296,7 @@ func (r *FoundRows) IsNullable() bool {
296296

297297
// Eval implements sql.Expression
298298
func (r *FoundRows) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
299-
return ctx.GetLastQueryInfoInt(sql.FoundRows), nil
299+
return ctx.GetLastQueryInfo().FoundRows.Load(), nil
300300
}
301301

302302
// Children implements sql.Expression

sql/iters/rel_iters.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,9 @@ func (i *topRowsIter) Next(ctx *sql.Context) (sql.Row, error) {
6767

6868
func (i *topRowsIter) Close(ctx *sql.Context) error {
6969
i.topRows = nil
70-
7170
if i.calcFoundRows {
72-
ctx.SetLastQueryInfoInt(sql.FoundRows, i.numFoundRows)
71+
ctx.GetLastQueryInfo().FoundRows.Store(i.numFoundRows)
7372
}
74-
7573
return i.childIter.Close(ctx)
7674
}
7775

@@ -467,9 +465,8 @@ func (li *LimitIter) Close(ctx *sql.Context) error {
467465
if err != nil {
468466
return err
469467
}
470-
471468
if li.CalcFoundRows {
472-
ctx.SetLastQueryInfoInt(sql.FoundRows, li.currentPos)
469+
ctx.GetLastQueryInfo().FoundRows.Store(li.currentPos)
473470
}
474471
return nil
475472
}

sql/plan/process.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -354,19 +354,20 @@ func (i *TrackedRowIter) GetIter() sql.RowIter {
354354
}
355355

356356
func (i *TrackedRowIter) updateSessionVars(ctx *sql.Context) {
357+
// TODO: possible to just remove switch entirely?
357358
switch i.QueryType {
358359
case QueryTypeSelect:
359-
ctx.SetLastQueryInfoInt(sql.RowCount, -1)
360+
ctx.GetLastQueryInfo().RowCount.Store(-1)
360361
case QueryTypeDdl:
361-
ctx.SetLastQueryInfoInt(sql.RowCount, 0)
362+
ctx.GetLastQueryInfo().RowCount.Store(0)
362363
case QueryTypeUpdate:
363364
// This is handled by RowUpdateAccumulator
364365
default:
365366
panic(fmt.Sprintf("Unexpected query type %v", i.QueryType))
366367
}
367368

368369
if i.ShouldSetFoundRows {
369-
ctx.SetLastQueryInfoInt(sql.FoundRows, i.numRows)
370+
ctx.GetLastQueryInfo().FoundRows.Store(i.numRows)
370371
}
371372
}
372373

sql/rowexec/dml_iters.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -746,14 +746,15 @@ func (a *accumulatorIter) Next(ctx *sql.Context) (r sql.Row, err error) {
746746
}
747747
if err == io.EOF {
748748
// TODO: The information flow here is pretty gnarly. We
749-
// set some session variables based on the result, and
750-
// we actually use a session variable to set
751-
// InsertID. This should be improved.
749+
// set some session variables based on the result, and
750+
// we actually use a session variable to set
751+
// InsertID. This should be improved.
752752

753753
// UPDATE statements also set FoundRows to the number of rows that
754754
// matched the WHERE clause, same as a SELECT.
755+
lastQueryInfo := ctx.GetLastQueryInfo()
755756
if ma, ok := a.updateRowHandler.(matchingAccumulator); ok {
756-
ctx.SetLastQueryInfoInt(sql.FoundRows, ma.RowsMatched())
757+
lastQueryInfo.FoundRows.Store(ma.RowsMatched())
757758
}
758759

759760
res := a.updateRowHandler.okResult() // TODO: Should add warnings here
@@ -763,14 +764,13 @@ func (a *accumulatorIter) Next(ctx *sql.Context) (r sql.Row, err error) {
763764
// to be fixed. See comment in buildRowUpdateAccumulator in rowexec/dml.go
764765
switch rowHandler := a.updateRowHandler.(type) {
765766
case *onDuplicateUpdateHandler, *replaceRowHandler:
766-
lastInsertId := ctx.Session.GetLastQueryInfoInt(sql.LastInsertId)
767+
lastInsertId := lastQueryInfo.LastInsertId.Load()
767768
res.InsertID = uint64(lastInsertId)
768769
case *insertRowHandler:
769770
res.InsertID = rowHandler.lastInsertId
770771
}
771-
772772
// By definition, ROW_COUNT() is equal to RowsAffected.
773-
ctx.SetLastQueryInfoInt(sql.RowCount, int64(res.RowsAffected))
773+
lastQueryInfo.RowCount.Store(int64(res.RowsAffected))
774774

775775
return sql.NewRow(res), nil
776776
} else if isIg {

sql/rowexec/insert.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ func (i *insertIter) updateLastInsertId(ctx *sql.Context, row sql.Row) {
355355
}
356356
if i.firstGeneratedAutoIncRowIdx == 0 {
357357
autoIncVal := i.getAutoIncVal(row)
358-
ctx.SetLastQueryInfoInt(sql.LastInsertId, autoIncVal)
358+
ctx.GetLastQueryInfo().LastInsertId.Store(autoIncVal)
359359
}
360360
i.firstGeneratedAutoIncRowIdx--
361361
}

sql/session.go

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -131,14 +131,8 @@ type Session interface {
131131
DelLock(lockName string) error
132132
// IterLocks iterates through all locks owned by this user
133133
IterLocks(cb func(name string) error) error
134-
// SetLastQueryInfoInt sets session-level query info for the key given, applying to the query just executed.
135-
SetLastQueryInfoInt(key string, value int64)
136-
// GetLastQueryInfoInt returns the session-level query info for the key given, for the query most recently executed.
137-
GetLastQueryInfoInt(key string) int64
138-
// SetLastQueryInfoString sets session-level query info as a string for the key given, applying to the query just executed.
139-
SetLastQueryInfoString(key string, value string)
140-
// GetLastQueryInfoString returns the session-level query info as a string for the key given, for the query most recently executed.
141-
GetLastQueryInfoString(key string) string
134+
// GetLastQueryInfo returns session-level info for the most recently executed query.
135+
GetLastQueryInfo() *LastQueryInfo
142136
// GetTransaction returns the active transaction, if any
143137
GetTransaction() Transaction
144138
// SetTransaction sets the session's transaction
@@ -253,12 +247,21 @@ type (
253247
}
254248
)
255249

256-
const (
257-
RowCount = "row_count"
258-
FoundRows = "found_rows"
259-
LastInsertId = "last_insert_id"
260-
LastInsertUuid = "last_insert_uuid"
261-
)
250+
type LastQueryInfo struct {
251+
RowCount atomic.Int64 // Session-level Row Count for the last executed query
252+
FoundRows atomic.Int64 // Session-level Found Rows for the last executed query
253+
LastInsertId atomic.Int64 // Session-level ID for the last executed insert query
254+
LastInsertUUID atomic.Value // Session-level UUID for the last executed insert query
255+
}
256+
257+
func defaultLastQueryInfo() *LastQueryInfo {
258+
ret := LastQueryInfo{}
259+
ret.RowCount.Store(0)
260+
ret.FoundRows.Store(1) // this is kind of a hack -- it handles the case of `select found_rows()` before any select statement is issued
261+
ret.LastInsertId.Store(0)
262+
ret.LastInsertUUID.Store("")
263+
return &ret
264+
}
262265

263266
// Session ID 0 used as invalid SessionID
264267
var autoSessionIDs uint32 = 1
@@ -703,19 +706,6 @@ func (i *spanIter) Close(ctx *Context) error {
703706
return i.iter.Close(ctx)
704707
}
705708

706-
func defaultLastQueryInfo() map[string]*atomic.Value {
707-
ret := make(map[string]*atomic.Value)
708-
ret[RowCount] = &atomic.Value{}
709-
ret[RowCount].Store(int64(0))
710-
ret[FoundRows] = &atomic.Value{}
711-
ret[FoundRows].Store(int64(1)) // this is kind of a hack -- it handles the case of `select found_rows()` before any select statement is issue)
712-
ret[LastInsertId] = &atomic.Value{}
713-
ret[LastInsertId].Store(int64(0))
714-
ret[LastInsertUuid] = &atomic.Value{}
715-
ret[LastInsertUuid].Store("")
716-
return ret
717-
}
718-
719709
// cc: https://dev.mysql.com/doc/refman/8.0/en/temporary-files.html
720710
func GetTmpdirSessionVar() string {
721711
ret := os.Getenv("TMPDIR")

0 commit comments

Comments
 (0)