Skip to content

Commit 944204b

Browse files
committed
Refactor driver
1 parent c8ba5b5 commit 944204b

File tree

1 file changed

+43
-25
lines changed

1 file changed

+43
-25
lines changed

code.go

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -45,24 +45,39 @@ type SqlCodeLoader struct {
4545
Table string
4646
Config CodeConfig
4747
Build func(i int) string
48+
Map func(col string) string
4849
}
4950
type DynamicSqlCodeLoader struct {
5051
DB *sql.DB
5152
Query string
5253
ParameterCount int
53-
Driver string
54+
Build func(string) string
5455
}
5556

56-
func NewDefaultDynamicSqlCodeLoader(db *sql.DB, query string) *DynamicSqlCodeLoader {
57-
driver := getDriver(db)
58-
return &DynamicSqlCodeLoader{DB: db, Query: query, ParameterCount: 0, Driver: driver}
59-
return NewDynamicSqlCodeLoader(db, query, 0, true)
57+
func NewDefaultDynamicSqlCodeLoader(db *sql.DB, query string, options...int) *DynamicSqlCodeLoader {
58+
var parameterCount int
59+
if len(options) >= 1 && options[0] > 0 {
60+
parameterCount = options[0]
61+
} else {
62+
parameterCount = 0
63+
}
64+
return NewDynamicSqlCodeLoader(db, query, parameterCount, true)
6065
}
61-
func NewDynamicSqlCodeLoader(db *sql.DB, query string, parameterCount int, handleDriver bool) *DynamicSqlCodeLoader {
66+
func NewDynamicSqlCodeLoader(db *sql.DB, query string, parameterCount int, options...bool) *DynamicSqlCodeLoader {
6267
driver := getDriver(db)
68+
var build func(string) string
69+
if driver == DriverOracle {
70+
build = strings.ToUpper
71+
}
6372
if parameterCount <= 0 {
6473
parameterCount = 1
6574
}
75+
var handleDriver bool
76+
if len(options) >= 1 {
77+
handleDriver = options[0]
78+
} else {
79+
handleDriver = true
80+
}
6681
if handleDriver {
6782
if driver == DriverOracle || driver == DriverPostgres {
6883
var x string
@@ -77,7 +92,7 @@ func NewDynamicSqlCodeLoader(db *sql.DB, query string, parameterCount int, handl
7792
}
7893
}
7994
}
80-
return &DynamicSqlCodeLoader{DB: db, Query: query, ParameterCount: parameterCount, Driver: driver}
95+
return &DynamicSqlCodeLoader{DB: db, Query: query, ParameterCount: parameterCount, Build: build}
8196
}
8297
func (l DynamicSqlCodeLoader) Load(ctx context.Context, master string) ([]CodeModel, error) {
8398
models := make([]CodeModel, 0)
@@ -88,7 +103,6 @@ func (l DynamicSqlCodeLoader) Load(ctx context.Context, master string) ([]CodeMo
88103
params = append(params, master)
89104
}
90105
}
91-
driver := l.Driver
92106
rows, er1 := l.DB.Query(l.Query, params...)
93107
if er1 != nil {
94108
return models, er1
@@ -101,7 +115,7 @@ func (l DynamicSqlCodeLoader) Load(ctx context.Context, master string) ([]CodeMo
101115
// get list indexes column
102116
modelTypes := reflect.TypeOf(models).Elem()
103117
modelType := reflect.TypeOf(CodeModel{})
104-
indexes, er3 := getColumnIndexes(modelType, columns, driver)
118+
indexes, er3 := getColumnIndexes(modelType, columns, l.Build)
105119
if er3 != nil {
106120
return models, er3
107121
}
@@ -118,17 +132,12 @@ func (l DynamicSqlCodeLoader) Load(ctx context.Context, master string) ([]CodeMo
118132
}
119133
func NewSqlCodeLoader(db *sql.DB, table string, config CodeConfig) *SqlCodeLoader {
120134
build := getBuild(db)
121-
return &SqlCodeLoader{DB: db, Table: table, Config: config, Build: build}
122-
}
123-
124-
func buildParam(i int) string {
125-
return "?"
126-
}
127-
func buildOracleParam(i int) string {
128-
return ":val" + strconv.Itoa(i)
129-
}
130-
func buildDollarParam(i int) string {
131-
return "$" + strconv.Itoa(i)
135+
driver := getDriver(db)
136+
var mp func(string)string
137+
if driver == DriverOracle {
138+
mp = strings.ToUpper
139+
}
140+
return &SqlCodeLoader{DB: db, Table: table, Config: config, Build: build, Map: mp}
132141
}
133142
func (l SqlCodeLoader) Load(ctx context.Context, master string) ([]CodeModel, error) {
134143
models := make([]CodeModel, 0)
@@ -203,7 +212,7 @@ func (l SqlCodeLoader) Load(ctx context.Context, master string) ([]CodeModel, er
203212
// get list indexes column
204213
modelTypes := reflect.TypeOf(models).Elem()
205214
modelType := reflect.TypeOf(CodeModel{})
206-
indexes, er2 := getColumnIndexes(modelType, columns, getDriver(l.DB))
215+
indexes, er2 := getColumnIndexes(modelType, columns, l.Map)
207216
if er2 != nil {
208217
return nil, er2
209218
}
@@ -230,16 +239,16 @@ func structScan(s interface{}, indexColumns []int) (r []interface{}) {
230239
}
231240
return
232241
}
233-
func getColumnIndexes(modelType reflect.Type, columnsName []string, driver string) (indexes []int, err error) {
242+
func getColumnIndexes(modelType reflect.Type, columnsName []string, build func(string) string) (indexes []int, err error) {
234243
if modelType.Kind() != reflect.Struct {
235244
return nil, errors.New("bad type")
236245
}
237246
for i := 0; i < modelType.NumField(); i++ {
238247
field := modelType.Field(i)
239248
ormTag := field.Tag.Get("gorm")
240249
column, ok := findTag(ormTag, "column")
241-
if driver == DriverOracle {
242-
column = strings.ToUpper(column)
250+
if build != nil {
251+
column = build(column)
243252
}
244253
if ok {
245254
if contains(columnsName, column) {
@@ -249,7 +258,6 @@ func getColumnIndexes(modelType reflect.Type, columnsName []string, driver strin
249258
}
250259
return
251260
}
252-
253261
func findTag(tag string, key string) (string, bool) {
254262
if has := strings.Contains(tag, key); has {
255263
str1 := strings.Split(tag, ";")
@@ -284,6 +292,16 @@ func scanType(rows *sql.Rows, modelTypes reflect.Type, indexes []int) (t []inter
284292
}
285293
return
286294
}
295+
296+
func buildParam(i int) string {
297+
return "?"
298+
}
299+
func buildOracleParam(i int) string {
300+
return ":val" + strconv.Itoa(i)
301+
}
302+
func buildDollarParam(i int) string {
303+
return "$" + strconv.Itoa(i)
304+
}
287305
func getBuild(db *sql.DB) func(i int) string {
288306
driver := reflect.TypeOf(db.Driver()).String()
289307
switch driver {

0 commit comments

Comments
 (0)