Skip to content

Commit 8b55b5e

Browse files
committed
*: add user management related operations.
1 parent fd644d7 commit 8b55b5e

File tree

3 files changed

+166
-2
lines changed

3 files changed

+166
-2
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ require (
77
github.com/blang/semver v3.5.1+incompatible
88
github.com/go-ini/ini v1.62.0
99
github.com/go-sql-driver/mysql v1.6.0
10+
github.com/go-test/deep v1.0.7 // indirect
1011
github.com/iancoleman/strcase v0.0.0-20190422225806-e506e3ef7365
1112
github.com/imdario/mergo v0.3.12
1213
github.com/onsi/ginkgo v1.16.4

internal/query.go

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ limitations under the License.
1717
package internal
1818

1919
import (
20-
"strings"
2120
"errors"
21+
"strings"
2222
)
2323

2424
// Query contains a escaped query string with variables marked with a question mark (?) and a slice
@@ -53,3 +53,30 @@ func NewQuery(q string, args ...interface{}) Query {
5353
args: args,
5454
}
5555
}
56+
57+
// ConcatenateQueries concatenates the provided queries into a single query.
58+
func ConcatenateQueries(queries ...Query) Query {
59+
args := []interface{}{}
60+
query := ""
61+
62+
for _, pq := range queries {
63+
if query != "" {
64+
if !strings.HasSuffix(query, "\n") {
65+
query += "\n"
66+
}
67+
}
68+
69+
query += pq.escapedQuery
70+
args = append(args, pq.args...)
71+
}
72+
73+
return NewQuery(query, args...)
74+
}
75+
76+
// BuildAtomicQuery concatenates the provided queries into a single query wrapped in a BEGIN COMMIT block.
77+
func BuildAtomicQuery(queries ...Query) Query {
78+
queries = append([]Query{NewQuery("BEGIN")}, queries...)
79+
queries = append(queries, NewQuery("COMMIT"))
80+
81+
return ConcatenateQueries(queries...)
82+
}

internal/sql_runner.go

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package internal
1919
import (
2020
"context"
2121
"database/sql"
22+
"errors"
2223
"fmt"
2324
"strconv"
2425
"strings"
@@ -84,7 +85,6 @@ func NewConfigFromClusterKey(c client.Client, clusterKey client.ObjectKey, userN
8485
Port: utils.MysqlPort,
8586
}, nil
8687

87-
8888
case utils.RootUser:
8989
password, ok := secret.Data["root-password"]
9090
if !ok {
@@ -323,3 +323,139 @@ func columnValue(scanArgs []interface{}, slaveCols []string, colName string) str
323323

324324
return string(*scanArgs[columnIndex].(*sql.RawBytes))
325325
}
326+
327+
// CreateUserIfNotExists creates a user if it doesn't already exist and it gives it the specified permissions.
328+
func (s sqlRunner) CreateUserIfNotExists(
329+
user, pass string, allowedHosts []string, permissions []apiv1alpha1.UserPermission,
330+
) error {
331+
332+
// Throw error if there are no allowed hosts.
333+
if len(allowedHosts) == 0 {
334+
return errors.New("no allowedHosts specified")
335+
}
336+
337+
queries := []Query{
338+
getCreateUserQuery(user, pass, allowedHosts),
339+
// todo: getAlterUserQuery
340+
}
341+
342+
if len(permissions) > 0 {
343+
queries = append(queries, permissionsToQuery(permissions, user, allowedHosts))
344+
}
345+
346+
query := BuildAtomicQuery(queries...)
347+
348+
if err := s.QueryExec(query); err != nil {
349+
return fmt.Errorf("failed to configure user (user/pass/access), err: %s", err)
350+
}
351+
352+
return nil
353+
}
354+
355+
func getCreateUserQuery(user, pwd string, allowedHosts []string) Query {
356+
idsTmpl, idsArgs := getUsersIdentification(user, &pwd, allowedHosts)
357+
358+
return NewQuery(fmt.Sprintf("CREATE USER IF NOT EXISTS%s", idsTmpl), idsArgs...)
359+
}
360+
361+
func getUsersIdentification(user string, pwd *string, allowedHosts []string) (ids string, args []interface{}) {
362+
for i, host := range allowedHosts {
363+
// Add comma if more than one allowed hosts are used.
364+
if i > 0 {
365+
ids += ","
366+
}
367+
368+
if pwd != nil {
369+
ids += " ?@? IDENTIFIED BY ?"
370+
args = append(args, user, host, *pwd)
371+
} else {
372+
ids += " ?@?"
373+
args = append(args, user, host)
374+
}
375+
}
376+
377+
return ids, args
378+
}
379+
380+
// DropUser removes a MySQL user if it exists, along with its privileges.
381+
func (s sqlRunner) DropUser(user, host string) error {
382+
query := NewQuery("DROP USER IF EXISTS ?@?;", user, host)
383+
384+
if err := s.QueryExec(query); err != nil {
385+
return fmt.Errorf("failed to delete user, err: %s", err)
386+
}
387+
388+
return nil
389+
}
390+
391+
func permissionsToQuery(permissions []apiv1alpha1.UserPermission, user string, allowedHosts []string) Query {
392+
permQueries := []Query{}
393+
394+
for _, perm := range permissions {
395+
// If you wish to grant permissions on all tables, you should explicitly use "*".
396+
for _, table := range perm.Tables {
397+
args := []interface{}{}
398+
399+
escPerms := []string{}
400+
for _, perm := range perm.Privileges {
401+
escPerms = append(escPerms, Escape(perm))
402+
}
403+
404+
schemaTable := fmt.Sprintf("%s.%s", escapeID(perm.Database), escapeID(table))
405+
406+
// Build GRANT query.
407+
idsTmpl, idsArgs := getUsersIdentification(user, nil, allowedHosts)
408+
409+
query := "GRANT " + strings.Join(escPerms, ", ") + " ON " + schemaTable + " TO" + idsTmpl
410+
args = append(args, idsArgs...)
411+
412+
permQueries = append(permQueries, NewQuery(query, args...))
413+
}
414+
}
415+
416+
return ConcatenateQueries(permQueries...)
417+
}
418+
419+
func escapeID(id string) string {
420+
if id == "*" {
421+
return id
422+
}
423+
424+
// don't allow using ` in id name
425+
id = strings.ReplaceAll(id, "`", "")
426+
427+
return fmt.Sprintf("`%s`", id)
428+
}
429+
430+
// Escape escapes a string.
431+
func Escape(sql string) string {
432+
dest := make([]byte, 0, 2*len(sql))
433+
var escape byte
434+
for i := 0; i < len(sql); i++ {
435+
escape = 0
436+
switch sql[i] {
437+
case 0: /* Must be escaped for 'mysql' */
438+
escape = '0'
439+
case '\n': /* Must be escaped for logs */
440+
escape = 'n'
441+
case '\r':
442+
escape = 'r'
443+
case '\\':
444+
escape = '\\'
445+
case '\'':
446+
escape = '\''
447+
case '"': /* Better safe than sorry */
448+
escape = '"'
449+
case '\032': /* This gives problems on Win32 */
450+
escape = 'Z'
451+
}
452+
453+
if escape != 0 {
454+
dest = append(dest, '\\', escape)
455+
} else {
456+
dest = append(dest, sql[i])
457+
}
458+
}
459+
460+
return string(dest)
461+
}

0 commit comments

Comments
 (0)