diff --git a/pkg/database/mysql/dump.go b/pkg/database/mysql/dump.go index e31df4a..61d09f7 100644 --- a/pkg/database/mysql/dump.go +++ b/pkg/database/mysql/dump.go @@ -25,6 +25,8 @@ import ( "strings" "text/template" "time" + + dbutil "github.com/databacker/mysql-backup/pkg/util/database" ) /* @@ -177,16 +179,11 @@ func (data *Data) Dump() error { // Lock all tables before dumping if present if data.LockTables && (len(tables) > 0 || len(views) > 0) { - var b bytes.Buffer - b.WriteString("LOCK TABLES ") - for index, table := range append(tables, views...) { - if index != 0 { - b.WriteString(",") - } - b.WriteString("`" + table.Name() + "` READ /*!32311 LOCAL */") + lockCommand, err := data.getBackupLockCommand(tables, views) + if err != nil { + return fmt.Errorf("failed to get lock command: %w", err) } - - if _, err := data.Connection.Exec(b.String()); err != nil { + if _, err := data.Connection.Exec(lockCommand); err != nil { return err } @@ -537,6 +534,36 @@ func (data *Data) getProceduresOrFunctionsCreateQueries(t string) ([]string, err return toGet, nil } +// getBackupLockCommand returns the SQL command to lock the tables for backup +// It may vary depending on the database variant or version, so it is generated dynamically +func (data *Data) getBackupLockCommand(tables, views []Table) (string, error) { + dbVar, err := dbutil.DetectVariant(data.Connection) + if err != nil { + return "", fmt.Errorf("failed to determine database variant: %w", err) + } + var lockString string + switch dbVar { + case dbutil.VariantMariaDB: + lockString = "LOCK TABLES" + case dbutil.VariantMySQL: + lockString = "LOCK TABLES" + case dbutil.VariantPercona: + // Percona just use the simple LOCK TABLES FOR BACKUP command + return "LOCK TABLES FOR BACKUP", nil + default: + lockString = "LOCK TABLES" + } + var b bytes.Buffer + b.WriteString(lockString + " ") + for index, table := range append(tables, views...) { + if index != 0 { + b.WriteString(",") + } + b.WriteString("`" + table.Name() + "` READ /*!32311 LOCAL */") + } + return b.String(), nil +} + func (meta *metaData) updateMetadata(data *Data) (err error) { var serverVersion sql.NullString err = data.tx.QueryRow("SELECT version()").Scan(&serverVersion) diff --git a/pkg/util/database/const.go b/pkg/util/database/const.go new file mode 100644 index 0000000..b876686 --- /dev/null +++ b/pkg/util/database/const.go @@ -0,0 +1,12 @@ +package database + +type Variant string + +const ( + // VariantMariaDB is the MariaDB variant of MySQL. + VariantMariaDB Variant = "mariadb" + // VariantMySQL is the MySQL variant of MySQL. + VariantMySQL Variant = "mysql" + // VariantPercona is the Percona variant of MySQL. + VariantPercona Variant = "percona" +) diff --git a/pkg/util/database/detect.go b/pkg/util/database/detect.go new file mode 100644 index 0000000..dd967c5 --- /dev/null +++ b/pkg/util/database/detect.go @@ -0,0 +1,47 @@ +package database + +import ( + "database/sql" + "fmt" + "strings" +) + +// DetectVariant returns the variant of the database, which can affect some commands. +// It uses several heuristics to determine the variant based on the version and comment. +// None of this is 100% reliable, but it should work for most cases. +func DetectVariant(conn *sql.DB) (Variant, error) { + // Check @@version and @@version_comment + var version, comment string + err := conn.QueryRow("SELECT @@version, @@version_comment").Scan(&version, &comment) + if err != nil { + return "", fmt.Errorf("failed to query version: %w", err) + } + + versionLower := strings.ToLower(version) + commentLower := strings.ToLower(comment) + + // Heuristic 1: version string or comment + switch { + case strings.Contains(versionLower, "mariadb") || strings.Contains(commentLower, "mariadb"): + return VariantMariaDB, nil + case strings.Contains(commentLower, "percona"): + return VariantPercona, nil + case strings.Contains(commentLower, "mysql"): + return VariantMySQL, nil + } + + // Heuristic 2: Check for Aria engine (MariaDB) + var dummy string + err = conn.QueryRow("SELECT 1 FROM information_schema.engines WHERE engine = 'Aria' LIMIT 1").Scan(&dummy) + if err == nil { + return VariantMariaDB, nil + } + + // Heuristic 3: Percona plugins + err = conn.QueryRow("SELECT 1 FROM information_schema.plugins WHERE plugin_name LIKE '%percona%' LIMIT 1").Scan(&dummy) + if err == nil { + return VariantPercona, nil + } + + return VariantMySQL, nil +} diff --git a/test/backup_test.go b/test/backup_test.go index 190ab19..dcbeb0e 100644 --- a/test/backup_test.go +++ b/test/backup_test.go @@ -8,6 +8,7 @@ import ( "bytes" "compress/gzip" "context" + "database/sql" "errors" "fmt" "io" @@ -28,6 +29,8 @@ import ( "github.com/databacker/mysql-backup/pkg/database" "github.com/databacker/mysql-backup/pkg/storage" "github.com/databacker/mysql-backup/pkg/storage/credentials" + dbutil "github.com/databacker/mysql-backup/pkg/util/database" + "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" @@ -47,6 +50,8 @@ const ( mysqlRootPass = "root" smbImage = "mysqlbackup_smb_test:latest" mysqlImage = "mysql:8.2.0" + mariaImage = "mariadb:11.8.2-noble" + perconaImage = "percona:8.0.42-33" bucketName = "mybucket" ) @@ -880,8 +885,57 @@ func populatePrePost(base string, targets []backupTarget) (err error) { return nil } +func startDatabase(dc *dockerContext, baseDir, image, name string) (containerPort, error) { + resp, err := dc.cli.ImagePull(context.Background(), image, types.ImagePullOptions{}) + if err != nil { + return containerPort{}, fmt.Errorf("failed to pull mysql image: %v", err) + } + io.Copy(os.Stdout, resp) + resp.Close() + + // start the mysql container; configure it for lots of debug logging, in case we need it + mysqlConf := ` +[mysqld] +log_error =/var/log/mysql/mysql_error.log +general_log_file=/var/log/mysql/mysql.log +general_log =1 +slow_query_log =1 +slow_query_log_file=/var/log/mysql/mysql_slow.log +long_query_time =2 +log_queries_not_using_indexes = 1 +` + if err := os.Mkdir(baseDir, 0o755); err != nil { + return containerPort{}, fmt.Errorf("failed to create mysql base directory: %v", err) + } + confFile := filepath.Join(baseDir, "log.cnf") + if err := os.WriteFile(confFile, []byte(mysqlConf), 0644); err != nil { + return containerPort{}, fmt.Errorf("failed to write mysql config file: %v", err) + } + logDir := filepath.Join(baseDir, "mysql_logs") + if err := os.Mkdir(logDir, 0755); err != nil { + return containerPort{}, fmt.Errorf("failed to create mysql log directory: %v", err) + } + + // start mysql + cid, port, err := dc.startContainer( + image, name, "3306/tcp", []string{fmt.Sprintf("%s:/etc/mysql/conf.d/log.conf:ro", confFile), fmt.Sprintf("%s:/var/log/mysql", logDir)}, nil, []string{ + fmt.Sprintf("MYSQL_ROOT_PASSWORD=%s", mysqlRootPass), + "MYSQL_DATABASE=tester", + fmt.Sprintf("MYSQL_USER=%s", mysqlUser), + fmt.Sprintf("MYSQL_PASSWORD=%s", mysqlPass), + }) + if err != nil { + return containerPort{}, fmt.Errorf("failed to start mysql container: %v", err) + } + return containerPort{name: name, id: cid, port: port}, nil +} + func TestIntegration(t *testing.T) { syscall.Umask(0) + dc, err := getDockerContext() + if err != nil { + t.Fatalf("failed to get docker client: %v", err) + } t.Run("dump", func(t *testing.T) { var ( err error @@ -898,10 +952,6 @@ func TestIntegration(t *testing.T) { if err := os.Chmod(base, 0o777); err != nil { t.Fatalf("failed to chmod temp dir: %v", err) } - dc, err := getDockerContext() - if err != nil { - t.Fatalf("failed to get docker client: %v", err) - } backupFile := filepath.Join(base, "backup.sql") compactBackupFile := filepath.Join(base, "backup-compact.sql") if mysql, smb, s3, s3backend, err = setup(dc, base, backupFile, compactBackupFile); err != nil { @@ -1033,4 +1083,64 @@ func TestIntegration(t *testing.T) { }) }) }) + t.Run("dbutil", func(t *testing.T) { + t.Run("detect", func(t *testing.T) { + // start all database variants + // wait for them to be ready + // then run the detect command on each of them + // then tear them down + + // set up dirs + + base := t.TempDir() + tests := []struct { + name string + image string + containerName string + variant dbutil.Variant + }{ + {"mysql", mysqlImage, "mysql-detect", dbutil.VariantMySQL}, + {"maria", mariaImage, "maria-detect", dbutil.VariantMariaDB}, + {"percona", perconaImage, "percona-detect", dbutil.VariantPercona}, + } + // tear down at the end + var cids []string + defer func() { + if err := teardown(dc, cids...); err != nil { + log.Errorf("failed to teardown test: %v", err) + } + }() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + container, err := startDatabase(dc, filepath.Join(base, tt.name), tt.image, tt.containerName) + if err != nil { + t.Fatalf("failed to start mysql container: %v", err) + } + cids = append(cids, container.id) + if err = dc.waitForDBConnectionAndGrantPrivileges(container.id, mysqlRootUser, mysqlRootPass); err != nil { + return + } + dbconn := database.Connection{ + User: mysqlRootUser, + Pass: mysqlRootPass, + Host: "localhost", + Port: container.port, + } + + db, err := sql.Open("mysql", dbconn.MySQL()) + if err != nil { + t.Fatalf("failed to open connection to database: %v", err) + } + v, err := dbutil.DetectVariant(db) + if err != nil { + t.Errorf("error detecting database variant: %v", err) + } + if v != tt.variant { + t.Errorf("expected database variant to be %s, got %s", tt.variant, v) + } + }) + } + }) + }) }