Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 36 additions & 9 deletions pkg/database/mysql/dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"strings"
"text/template"
"time"

dbutil "github.com/databacker/mysql-backup/pkg/util/database"
)

/*
Expand Down Expand Up @@ -177,16 +179,11 @@ func (data *Data) Dump() error {

// Lock all tables before dumping if present
if data.LockTables && (len(tables) > 0 || len(views) > 0) {
var b bytes.Buffer
b.WriteString("LOCK TABLES ")
for index, table := range append(tables, views...) {
if index != 0 {
b.WriteString(",")
}
b.WriteString("`" + table.Name() + "` READ /*!32311 LOCAL */")
lockCommand, err := data.getBackupLockCommand(tables, views)
if err != nil {
return fmt.Errorf("failed to get lock command: %w", err)
}

if _, err := data.Connection.Exec(b.String()); err != nil {
if _, err := data.Connection.Exec(lockCommand); err != nil {
return err
}

Expand Down Expand Up @@ -537,6 +534,36 @@ func (data *Data) getProceduresOrFunctionsCreateQueries(t string) ([]string, err
return toGet, nil
}

// getBackupLockCommand returns the SQL command to lock the tables for backup
// It may vary depending on the database variant or version, so it is generated dynamically
func (data *Data) getBackupLockCommand(tables, views []Table) (string, error) {
dbVar, err := dbutil.DetectVariant(data.Connection)
if err != nil {
return "", fmt.Errorf("failed to determine database variant: %w", err)
}
var lockString string
switch dbVar {
case dbutil.VariantMariaDB:
lockString = "LOCK TABLES"
case dbutil.VariantMySQL:
lockString = "LOCK TABLES"
case dbutil.VariantPercona:
// Percona just use the simple LOCK TABLES FOR BACKUP command
return "LOCK TABLES FOR BACKUP", nil
default:
lockString = "LOCK TABLES"
}
var b bytes.Buffer
b.WriteString(lockString + " ")
for index, table := range append(tables, views...) {
if index != 0 {
b.WriteString(",")
}
b.WriteString("`" + table.Name() + "` READ /*!32311 LOCAL */")
}
return b.String(), nil
}

func (meta *metaData) updateMetadata(data *Data) (err error) {
var serverVersion sql.NullString
err = data.tx.QueryRow("SELECT version()").Scan(&serverVersion)
Expand Down
12 changes: 12 additions & 0 deletions pkg/util/database/const.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package database

type Variant string

const (
// VariantMariaDB is the MariaDB variant of MySQL.
VariantMariaDB Variant = "mariadb"
// VariantMySQL is the MySQL variant of MySQL.
VariantMySQL Variant = "mysql"
// VariantPercona is the Percona variant of MySQL.
VariantPercona Variant = "percona"
)
47 changes: 47 additions & 0 deletions pkg/util/database/detect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package database

import (
"database/sql"
"fmt"
"strings"
)

// DetectVariant returns the variant of the database, which can affect some commands.
// It uses several heuristics to determine the variant based on the version and comment.
// None of this is 100% reliable, but it should work for most cases.
func DetectVariant(conn *sql.DB) (Variant, error) {
// Check @@version and @@version_comment
var version, comment string
err := conn.QueryRow("SELECT @@version, @@version_comment").Scan(&version, &comment)
if err != nil {
return "", fmt.Errorf("failed to query version: %w", err)
}

versionLower := strings.ToLower(version)
commentLower := strings.ToLower(comment)

// Heuristic 1: version string or comment
switch {
case strings.Contains(versionLower, "mariadb") || strings.Contains(commentLower, "mariadb"):
return VariantMariaDB, nil
case strings.Contains(commentLower, "percona"):
return VariantPercona, nil
case strings.Contains(commentLower, "mysql"):
return VariantMySQL, nil
}

// Heuristic 2: Check for Aria engine (MariaDB)
var dummy string
err = conn.QueryRow("SELECT 1 FROM information_schema.engines WHERE engine = 'Aria' LIMIT 1").Scan(&dummy)
if err == nil {
return VariantMariaDB, nil
}

// Heuristic 3: Percona plugins
err = conn.QueryRow("SELECT 1 FROM information_schema.plugins WHERE plugin_name LIKE '%percona%' LIMIT 1").Scan(&dummy)
if err == nil {
return VariantPercona, nil
}

return VariantMySQL, nil
}
118 changes: 114 additions & 4 deletions test/backup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"bytes"
"compress/gzip"
"context"
"database/sql"
"errors"
"fmt"
"io"
Expand All @@ -28,6 +29,8 @@ import (
"github.com/databacker/mysql-backup/pkg/database"
"github.com/databacker/mysql-backup/pkg/storage"
"github.com/databacker/mysql-backup/pkg/storage/credentials"
dbutil "github.com/databacker/mysql-backup/pkg/util/database"

"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
Expand All @@ -47,6 +50,8 @@ const (
mysqlRootPass = "root"
smbImage = "mysqlbackup_smb_test:latest"
mysqlImage = "mysql:8.2.0"
mariaImage = "mariadb:11.8.2-noble"
perconaImage = "percona:8.0.42-33"
bucketName = "mybucket"
)

Expand Down Expand Up @@ -880,8 +885,57 @@ func populatePrePost(base string, targets []backupTarget) (err error) {
return nil
}

func startDatabase(dc *dockerContext, baseDir, image, name string) (containerPort, error) {
resp, err := dc.cli.ImagePull(context.Background(), image, types.ImagePullOptions{})
if err != nil {
return containerPort{}, fmt.Errorf("failed to pull mysql image: %v", err)
}
io.Copy(os.Stdout, resp)
resp.Close()

// start the mysql container; configure it for lots of debug logging, in case we need it
mysqlConf := `
[mysqld]
log_error =/var/log/mysql/mysql_error.log
general_log_file=/var/log/mysql/mysql.log
general_log =1
slow_query_log =1
slow_query_log_file=/var/log/mysql/mysql_slow.log
long_query_time =2
log_queries_not_using_indexes = 1
`
if err := os.Mkdir(baseDir, 0o755); err != nil {
return containerPort{}, fmt.Errorf("failed to create mysql base directory: %v", err)
}
confFile := filepath.Join(baseDir, "log.cnf")
if err := os.WriteFile(confFile, []byte(mysqlConf), 0644); err != nil {
return containerPort{}, fmt.Errorf("failed to write mysql config file: %v", err)
}
logDir := filepath.Join(baseDir, "mysql_logs")
if err := os.Mkdir(logDir, 0755); err != nil {
return containerPort{}, fmt.Errorf("failed to create mysql log directory: %v", err)
}

// start mysql
cid, port, err := dc.startContainer(
image, name, "3306/tcp", []string{fmt.Sprintf("%s:/etc/mysql/conf.d/log.conf:ro", confFile), fmt.Sprintf("%s:/var/log/mysql", logDir)}, nil, []string{
fmt.Sprintf("MYSQL_ROOT_PASSWORD=%s", mysqlRootPass),
"MYSQL_DATABASE=tester",
fmt.Sprintf("MYSQL_USER=%s", mysqlUser),
fmt.Sprintf("MYSQL_PASSWORD=%s", mysqlPass),
})
if err != nil {
return containerPort{}, fmt.Errorf("failed to start mysql container: %v", err)
}
return containerPort{name: name, id: cid, port: port}, nil
}

func TestIntegration(t *testing.T) {
syscall.Umask(0)
dc, err := getDockerContext()
if err != nil {
t.Fatalf("failed to get docker client: %v", err)
}
t.Run("dump", func(t *testing.T) {
var (
err error
Expand All @@ -898,10 +952,6 @@ func TestIntegration(t *testing.T) {
if err := os.Chmod(base, 0o777); err != nil {
t.Fatalf("failed to chmod temp dir: %v", err)
}
dc, err := getDockerContext()
if err != nil {
t.Fatalf("failed to get docker client: %v", err)
}
backupFile := filepath.Join(base, "backup.sql")
compactBackupFile := filepath.Join(base, "backup-compact.sql")
if mysql, smb, s3, s3backend, err = setup(dc, base, backupFile, compactBackupFile); err != nil {
Expand Down Expand Up @@ -1033,4 +1083,64 @@ func TestIntegration(t *testing.T) {
})
})
})
t.Run("dbutil", func(t *testing.T) {
t.Run("detect", func(t *testing.T) {
// start all database variants
// wait for them to be ready
// then run the detect command on each of them
// then tear them down

// set up dirs

base := t.TempDir()
tests := []struct {
name string
image string
containerName string
variant dbutil.Variant
}{
{"mysql", mysqlImage, "mysql-detect", dbutil.VariantMySQL},
{"maria", mariaImage, "maria-detect", dbutil.VariantMariaDB},
{"percona", perconaImage, "percona-detect", dbutil.VariantPercona},
}
// tear down at the end
var cids []string
defer func() {
if err := teardown(dc, cids...); err != nil {
log.Errorf("failed to teardown test: %v", err)
}
}()

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
container, err := startDatabase(dc, filepath.Join(base, tt.name), tt.image, tt.containerName)
if err != nil {
t.Fatalf("failed to start mysql container: %v", err)
}
cids = append(cids, container.id)
if err = dc.waitForDBConnectionAndGrantPrivileges(container.id, mysqlRootUser, mysqlRootPass); err != nil {
return
}
dbconn := database.Connection{
User: mysqlRootUser,
Pass: mysqlRootPass,
Host: "localhost",
Port: container.port,
}

db, err := sql.Open("mysql", dbconn.MySQL())
if err != nil {
t.Fatalf("failed to open connection to database: %v", err)
}
v, err := dbutil.DetectVariant(db)
if err != nil {
t.Errorf("error detecting database variant: %v", err)
}
if v != tt.variant {
t.Errorf("expected database variant to be %s, got %s", tt.variant, v)
}
})
}
})
})
}