@@ -19,6 +19,7 @@ package internal
1919import (
2020 "context"
2121 "database/sql"
22+ "errors"
2223 "fmt"
2324 "strconv"
2425 "strings"
@@ -84,7 +85,6 @@ func NewConfigFromClusterKey(c client.Client, clusterKey client.ObjectKey, userN
8485 Port : utils .MysqlPort ,
8586 }, nil
8687
87-
8888 case utils .RootUser :
8989 password , ok := secret .Data ["root-password" ]
9090 if ! ok {
@@ -323,3 +323,139 @@ func columnValue(scanArgs []interface{}, slaveCols []string, colName string) str
323323
324324 return string (* scanArgs [columnIndex ].(* sql.RawBytes ))
325325}
326+
327+ // CreateUserIfNotExists creates a user if it doesn't already exist and it gives it the specified permissions.
328+ func (s sqlRunner ) CreateUserIfNotExists (
329+ user , pass string , allowedHosts []string , permissions []apiv1alpha1.UserPermission ,
330+ ) error {
331+
332+ // Throw error if there are no allowed hosts.
333+ if len (allowedHosts ) == 0 {
334+ return errors .New ("no allowedHosts specified" )
335+ }
336+
337+ queries := []Query {
338+ getCreateUserQuery (user , pass , allowedHosts ),
339+ // todo: getAlterUserQuery
340+ }
341+
342+ if len (permissions ) > 0 {
343+ queries = append (queries , permissionsToQuery (permissions , user , allowedHosts ))
344+ }
345+
346+ query := BuildAtomicQuery (queries ... )
347+
348+ if err := s .QueryExec (query ); err != nil {
349+ return fmt .Errorf ("failed to configure user (user/pass/access), err: %s" , err )
350+ }
351+
352+ return nil
353+ }
354+
355+ func getCreateUserQuery (user , pwd string , allowedHosts []string ) Query {
356+ idsTmpl , idsArgs := getUsersIdentification (user , & pwd , allowedHosts )
357+
358+ return NewQuery (fmt .Sprintf ("CREATE USER IF NOT EXISTS%s" , idsTmpl ), idsArgs ... )
359+ }
360+
361+ func getUsersIdentification (user string , pwd * string , allowedHosts []string ) (ids string , args []interface {}) {
362+ for i , host := range allowedHosts {
363+ // Add comma if more than one allowed hosts are used.
364+ if i > 0 {
365+ ids += ","
366+ }
367+
368+ if pwd != nil {
369+ ids += " ?@? IDENTIFIED BY ?"
370+ args = append (args , user , host , * pwd )
371+ } else {
372+ ids += " ?@?"
373+ args = append (args , user , host )
374+ }
375+ }
376+
377+ return ids , args
378+ }
379+
380+ // DropUser removes a MySQL user if it exists, along with its privileges.
381+ func (s sqlRunner ) DropUser (user , host string ) error {
382+ query := NewQuery ("DROP USER IF EXISTS ?@?;" , user , host )
383+
384+ if err := s .QueryExec (query ); err != nil {
385+ return fmt .Errorf ("failed to delete user, err: %s" , err )
386+ }
387+
388+ return nil
389+ }
390+
391+ func permissionsToQuery (permissions []apiv1alpha1.UserPermission , user string , allowedHosts []string ) Query {
392+ permQueries := []Query {}
393+
394+ for _ , perm := range permissions {
395+ // If you wish to grant permissions on all tables, you should explicitly use "*".
396+ for _ , table := range perm .Tables {
397+ args := []interface {}{}
398+
399+ escPerms := []string {}
400+ for _ , perm := range perm .Privileges {
401+ escPerms = append (escPerms , Escape (perm ))
402+ }
403+
404+ schemaTable := fmt .Sprintf ("%s.%s" , escapeID (perm .Database ), escapeID (table ))
405+
406+ // Build GRANT query.
407+ idsTmpl , idsArgs := getUsersIdentification (user , nil , allowedHosts )
408+
409+ query := "GRANT " + strings .Join (escPerms , ", " ) + " ON " + schemaTable + " TO" + idsTmpl
410+ args = append (args , idsArgs ... )
411+
412+ permQueries = append (permQueries , NewQuery (query , args ... ))
413+ }
414+ }
415+
416+ return ConcatenateQueries (permQueries ... )
417+ }
418+
419+ func escapeID (id string ) string {
420+ if id == "*" {
421+ return id
422+ }
423+
424+ // don't allow using ` in id name
425+ id = strings .ReplaceAll (id , "`" , "" )
426+
427+ return fmt .Sprintf ("`%s`" , id )
428+ }
429+
430+ // Escape escapes a string.
431+ func Escape (sql string ) string {
432+ dest := make ([]byte , 0 , 2 * len (sql ))
433+ var escape byte
434+ for i := 0 ; i < len (sql ); i ++ {
435+ escape = 0
436+ switch sql [i ] {
437+ case 0 : /* Must be escaped for 'mysql' */
438+ escape = '0'
439+ case '\n' : /* Must be escaped for logs */
440+ escape = 'n'
441+ case '\r' :
442+ escape = 'r'
443+ case '\\' :
444+ escape = '\\'
445+ case '\'' :
446+ escape = '\''
447+ case '"' : /* Better safe than sorry */
448+ escape = '"'
449+ case '\032' : /* This gives problems on Win32 */
450+ escape = 'Z'
451+ }
452+
453+ if escape != 0 {
454+ dest = append (dest , '\\' , escape )
455+ } else {
456+ dest = append (dest , sql [i ])
457+ }
458+ }
459+
460+ return string (dest )
461+ }
0 commit comments