Skip to content

Commit d51a238

Browse files
committed
Refactor SQL
1 parent 25a6658 commit d51a238

File tree

2 files changed

+93
-45
lines changed

2 files changed

+93
-45
lines changed

code.go

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,12 @@ type DynamicSqlCodeLoader struct {
5454
}
5555

5656
func NewDefaultDynamicSqlCodeLoader(db *sql.DB, query string) *DynamicSqlCodeLoader {
57-
driver := GetDriver(db)
57+
driver := getDriver(db)
5858
return &DynamicSqlCodeLoader{DB: db, Query: query, ParameterCount: 0, Driver: driver}
5959
return NewDynamicSqlCodeLoader(db, query, 0, true)
6060
}
6161
func NewDynamicSqlCodeLoader(db *sql.DB, query string, parameterCount int, handleDriver bool) *DynamicSqlCodeLoader {
62-
driver := GetDriver(db)
62+
driver := getDriver(db)
6363
if parameterCount <= 0 {
6464
parameterCount = 1
6565
}
@@ -101,11 +101,11 @@ func (l DynamicSqlCodeLoader) Load(ctx context.Context, master string) ([]CodeMo
101101
// get list indexes column
102102
modelTypes := reflect.TypeOf(models).Elem()
103103
modelType := reflect.TypeOf(CodeModel{})
104-
indexes, er3 := GetColumnIndexes(modelType, columns, driver)
104+
indexes, er3 := getColumnIndexes(modelType, columns, driver)
105105
if er3 != nil {
106106
return models, er3
107107
}
108-
tb, er4 := ScanType(rows, modelTypes, indexes)
108+
tb, er4 := scanType(rows, modelTypes, indexes)
109109
if er4 != nil {
110110
return models, er4
111111
}
@@ -117,9 +117,19 @@ func (l DynamicSqlCodeLoader) Load(ctx context.Context, master string) ([]CodeMo
117117
return models, nil
118118
}
119119
func NewSqlCodeLoader(db *sql.DB, table string, config CodeConfig) *SqlCodeLoader {
120-
driver := GetDriver(db)
120+
driver := getDriver(db)
121121
return &SqlCodeLoader{DB: db, Table: table, Config: config, Driver: driver}
122122
}
123+
124+
func buildParam(i int) string {
125+
return "?"
126+
}
127+
func buildOracleParam(i int) string {
128+
return ":val" + strconv.Itoa(i)
129+
}
130+
func buildDollarParam(i int) string {
131+
return "$" + strconv.Itoa(i)
132+
}
123133
func (l SqlCodeLoader) Load(ctx context.Context, master string) ([]CodeModel, error) {
124134
models := make([]CodeModel, 0)
125135
s := make([]string, 0)
@@ -154,14 +164,14 @@ func (l SqlCodeLoader) Load(ctx context.Context, master string) ([]CodeModel, er
154164
p1 := ""
155165
i := 1
156166
if len(c.Master) > 0 {
157-
i = i + 1
158167
if l.Driver == DriverPostgres {
159-
p1 = fmt.Sprintf("%s = $1", c.Master)
168+
p1 = fmt.Sprintf("%s = $%d", c.Master, i)
160169
} else if l.Driver == DriverOracle {
161-
p1 = fmt.Sprintf("%s = :val1", c.Master)
170+
p1 = fmt.Sprintf("%s = :val%d", c.Master, i)
162171
} else {
163172
p1 = fmt.Sprintf("%s = ?", c.Master)
164173
}
174+
i = i + 1
165175
values = append(values, master)
166176
}
167177
cols := strings.Join(s, ",")
@@ -170,7 +180,7 @@ func (l SqlCodeLoader) Load(ctx context.Context, master string) ([]CodeModel, er
170180
if l.Driver == DriverPostgres {
171181
p2 = fmt.Sprintf("%s = $%d", c.Status, i)
172182
} else if l.Driver == DriverOracle {
173-
p1 = fmt.Sprintf("%s = :val%d", c.Status, i)
183+
p2 = fmt.Sprintf("%s = :val%d", c.Status, i)
174184
} else {
175185
p2 = fmt.Sprintf("%s = ?", c.Status)
176186
}
@@ -194,18 +204,6 @@ func (l SqlCodeLoader) Load(ctx context.Context, master string) ([]CodeModel, er
194204
}
195205
}
196206
if len(sql2) > 0 {
197-
if l.Driver == DriverOracle || l.Driver == DriverPostgres {
198-
var x string
199-
if l.Driver == DriverOracle {
200-
x = ":val"
201-
} else {
202-
x = "$"
203-
}
204-
for i := 0; i < len(values); i++ {
205-
count := i + 1
206-
sql2 = strings.Replace(sql2, "?", x+strconv.Itoa(count), 1)
207-
}
208-
}
209207
rows, err1 := l.DB.Query(sql2, values...)
210208
if err1 != nil {
211209
return nil, err1
@@ -218,11 +216,11 @@ func (l SqlCodeLoader) Load(ctx context.Context, master string) ([]CodeModel, er
218216
// get list indexes column
219217
modelTypes := reflect.TypeOf(models).Elem()
220218
modelType := reflect.TypeOf(CodeModel{})
221-
indexes, er2 := GetColumnIndexes(modelType, columns, GetDriver(l.DB))
219+
indexes, er2 := getColumnIndexes(modelType, columns, getDriver(l.DB))
222220
if er2 != nil {
223221
return nil, er2
224222
}
225-
tb, er3 := ScanType(rows, modelTypes, indexes)
223+
tb, er3 := scanType(rows, modelTypes, indexes)
226224
if er3 != nil {
227225
return nil, er3
228226
}
@@ -236,7 +234,7 @@ func (l SqlCodeLoader) Load(ctx context.Context, master string) ([]CodeModel, er
236234
}
237235

238236
// StructScan : transfer struct to slice for scan
239-
func StructScan(s interface{}, indexColumns []int) (r []interface{}) {
237+
func structScan(s interface{}, indexColumns []int) (r []interface{}) {
240238
if s != nil {
241239
maps := reflect.Indirect(reflect.ValueOf(s))
242240
for _, index := range indexColumns {
@@ -245,15 +243,14 @@ func StructScan(s interface{}, indexColumns []int) (r []interface{}) {
245243
}
246244
return
247245
}
248-
249-
func GetColumnIndexes(modelType reflect.Type, columnsName []string, driver string) (indexes []int, err error) {
246+
func getColumnIndexes(modelType reflect.Type, columnsName []string, driver string) (indexes []int, err error) {
250247
if modelType.Kind() != reflect.Struct {
251248
return nil, errors.New("bad type")
252249
}
253250
for i := 0; i < modelType.NumField(); i++ {
254251
field := modelType.Field(i)
255252
ormTag := field.Tag.Get("gorm")
256-
column, ok := FindTag(ormTag, "column")
253+
column, ok := findTag(ormTag, "column")
257254
if driver == DriverOracle {
258255
column = strings.ToUpper(column)
259256
}
@@ -266,7 +263,7 @@ func GetColumnIndexes(modelType reflect.Type, columnsName []string, driver strin
266263
return
267264
}
268265

269-
func FindTag(tag string, key string) (string, bool) {
266+
func findTag(tag string, key string) (string, bool) {
270267
if has := strings.Contains(tag, key); has {
271268
str1 := strings.Split(tag, ";")
272269
num := len(str1)
@@ -291,17 +288,17 @@ func contains(array []string, v string) bool {
291288
return false
292289
}
293290

294-
func ScanType(rows *sql.Rows, modelTypes reflect.Type, indexes []int) (t []interface{}, err error) {
291+
func scanType(rows *sql.Rows, modelTypes reflect.Type, indexes []int) (t []interface{}, err error) {
295292
for rows.Next() {
296293
initArray := reflect.New(modelTypes).Interface()
297-
if err = rows.Scan(StructScan(initArray, indexes)...); err == nil {
294+
if err = rows.Scan(structScan(initArray, indexes)...); err == nil {
298295
t = append(t, initArray)
299296
}
300297
}
301298
return
302299
}
303300

304-
func GetDriver(db *sql.DB) string {
301+
func getDriver(db *sql.DB) string {
305302
driver := reflect.TypeOf(db.Driver()).String()
306303
switch driver {
307304
case "*pq.Driver":

handler.go

Lines changed: 65 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,67 @@ import (
1010

1111
const internalServerError = "Internal Server Error"
1212

13+
type CodeHandlerConfig struct {
14+
Master *bool `mapstructure:"master" json:"master,omitempty" gorm:"column:master" bson:"master,omitempty" dynamodbav:"master,omitempty" firestore:"master,omitempty"`
15+
Id string `mapstructure:"id" json:"id,omitempty" gorm:"column:id" bson:"id,omitempty" dynamodbav:"id,omitempty" firestore:"id,omitempty"`
16+
Name string `mapstructure:"name" json:"name,omitempty" gorm:"column:name" bson:"name,omitempty" dynamodbav:"name,omitempty" firestore:"name,omitempty"`
17+
Resource string `mapstructure:"resource" json:"resource,omitempty" gorm:"column:resource" bson:"resource,omitempty" dynamodbav:"resource,omitempty" firestore:"resource,omitempty"`
18+
Action string `mapstructure:"action" json:"action,omitempty" gorm:"column:action" bson:"action,omitempty" dynamodbav:"action,omitempty" firestore:"action,omitempty"`
19+
}
1320
type CodeHandler struct {
14-
Loader CodeLoader
21+
Codes func(ctx context.Context, master string) ([]CodeModel, error)
22+
RequiredMaster bool
23+
Error func(context.Context, string)
24+
Log func(ctx context.Context, resource string, action string, success bool, desc string) error
1525
Resource string
1626
Action string
17-
RequiredMaster bool
18-
LogError func(context.Context, string)
19-
WriteLog func(ctx context.Context, resource string, action string, success bool, desc string) error
27+
Id string
28+
Name string
2029
}
2130

22-
func NewDefaultCodeHandler(loader CodeLoader, resource string, action string, logError func(context.Context, string), writeLog func(context.Context, string, string, bool, string) error) *CodeHandler {
23-
return NewCodeHandler(loader, resource, action, true, logError, writeLog)
31+
func NewDefaultCodeHandler(load func(ctx context.Context, master string) ([]CodeModel, error), logError func(context.Context, string), options ...func(context.Context, string, string, bool, string) error) *CodeHandler {
32+
var writeLog func(context.Context, string, string, bool, string) error
33+
if len(options) >= 1 {
34+
writeLog = options[0]
35+
}
36+
return NewCodeHandlerWithLog(load, logError, true, writeLog, "", "")
37+
}
38+
func NewCodeHandlerByConfig(load func(ctx context.Context, master string) ([]CodeModel, error), c CodeHandlerConfig, logError func(context.Context, string), options ...func(context.Context, string, string, bool, string) error) *CodeHandler {
39+
var requireMaster bool
40+
if c.Master != nil {
41+
requireMaster = *c.Master
42+
} else {
43+
requireMaster = true
44+
}
45+
var writeLog func(context.Context, string, string, bool, string) error
46+
if len(options) >= 1 {
47+
writeLog = options[0]
48+
}
49+
h := NewCodeHandlerWithLog(load, logError, requireMaster, writeLog, c.Resource, c.Action)
50+
h.Id = c.Id
51+
h.Name = c.Name
52+
return h
2453
}
25-
func NewCodeHandler(loader CodeLoader, resource string, action string, requiredMaster bool, logError func(context.Context, string), writeLog func(context.Context, string, string, bool, string) error) *CodeHandler {
26-
if len(resource) == 0 {
54+
func NewCodeHandler(load func(ctx context.Context, master string) ([]CodeModel, error), logError func(context.Context, string), requiredMaster bool, options ...func(context.Context, string, string, bool, string) error) *CodeHandler {
55+
var writeLog func(context.Context, string, string, bool, string) error
56+
if len(options) >= 1 {
57+
writeLog = options[0]
58+
}
59+
return NewCodeHandlerWithLog(load, logError, requiredMaster, writeLog, "", "")
60+
}
61+
func NewCodeHandlerWithLog(load func(ctx context.Context, master string) ([]CodeModel, error), logError func(context.Context, string), requiredMaster bool, writeLog func(context.Context, string, string, bool, string) error, options ...string) *CodeHandler {
62+
var resource, action string
63+
if len(options) >= 1 && len(options[0]) > 0 {
64+
resource = options[0]
65+
} else {
2766
resource = "code"
2867
}
29-
if len(action) == 0 {
68+
if len(options) >= 2 && len(options[1]) > 0 {
69+
action = options[1]
70+
} else {
3071
action = "load"
3172
}
32-
h := CodeHandler{Loader: loader, Resource: resource, Action: action, RequiredMaster: requiredMaster, WriteLog: writeLog, LogError: logError}
73+
h := CodeHandler{Codes: load, Resource: resource, Action: action, RequiredMaster: requiredMaster, Log: writeLog, Error: logError}
3374
return &h
3475
}
3576
func (c *CodeHandler) Load(w http.ResponseWriter, r *http.Request) {
@@ -49,14 +90,24 @@ func (c *CodeHandler) Load(w http.ResponseWriter, r *http.Request) {
4990
code = strings.Trim(string(b), " ")
5091
}
5192
}
52-
result, er4 := c.Loader.Load(r.Context(), code)
93+
result, er4 := c.Codes(r.Context(), code)
5394
if er4 != nil {
54-
respondError(w, r, http.StatusInternalServerError, internalServerError, c.LogError, c.Resource, c.Action, er4, c.WriteLog)
95+
respondError(w, r, http.StatusInternalServerError, internalServerError, c.Error, c.Resource, c.Action, er4, c.Log)
5596
} else {
56-
succeed(w, r, http.StatusOK, result, c.WriteLog, c.Resource, c.Action)
97+
if len(c.Id) == 0 && len(c.Name) == 0 {
98+
succeed(w, r, http.StatusOK, result, c.Log, c.Resource, c.Action)
99+
} else {
100+
rs := make([]map[string]string, 0)
101+
for _, r := range result {
102+
m := make(map[string]string)
103+
m[c.Id] = r.Id
104+
m[c.Name] = r.Name
105+
rs = append(rs, m)
106+
}
107+
succeed(w, r, http.StatusOK, rs, c.Log, c.Resource, c.Action)
108+
}
57109
}
58110
}
59-
60111
func respondString(w http.ResponseWriter, r *http.Request, code int, result string) {
61112
w.WriteHeader(code)
62113
w.Write([]byte(result))

0 commit comments

Comments
 (0)