Skip to content

Commit 85dc586

Browse files
committed
Fix bug parameter count
1 parent 2bf496d commit 85dc586

File tree

1 file changed

+51
-46
lines changed

1 file changed

+51
-46
lines changed

code.go

Lines changed: 51 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,15 @@ type DynamicSqlCodeLoader struct {
5353
Query string
5454
ParameterCount int
5555
Map func(col string) string
56+
driver string
5657
}
5758

5859
func NewDefaultDynamicSqlCodeLoader(db *sql.DB, query string, options ...int) *DynamicSqlCodeLoader {
5960
var parameterCount int
60-
if len(options) >= 1 && options[0] > 0 {
61+
if len(options) > 0 {
6162
parameterCount = options[0]
6263
} else {
63-
parameterCount = 0
64+
parameterCount = 1
6465
}
6566
return NewDynamicSqlCodeLoader(db, query, parameterCount, true)
6667
}
@@ -70,7 +71,7 @@ func NewDynamicSqlCodeLoader(db *sql.DB, query string, parameterCount int, optio
7071
if driver == DriverOracle {
7172
mp = strings.ToUpper
7273
}
73-
if parameterCount <= 0 {
74+
if parameterCount < 0 {
7475
parameterCount = 1
7576
}
7677
var handleDriver bool
@@ -99,14 +100,19 @@ func NewDynamicSqlCodeLoader(db *sql.DB, query string, parameterCount int, optio
99100
}
100101
func (l DynamicSqlCodeLoader) Load(ctx context.Context, master string) ([]CodeModel, error) {
101102
models := make([]CodeModel, 0)
102-
params := make([]interface{}, 0)
103-
params = append(params, master)
104-
if l.ParameterCount > 1 {
105-
for i := 2; i <= l.ParameterCount; i++ {
103+
104+
var rows *sql.Rows
105+
var er1 error
106+
if l.ParameterCount > 0 {
107+
params := make([]interface{}, 0)
108+
for i := 1; i <= l.ParameterCount; i++ {
106109
params = append(params, master)
107110
}
111+
rows, er1 = l.DB.QueryContext(ctx, l.Query, params...)
112+
} else {
113+
rows, er1 = l.DB.QueryContext(ctx, l.Query)
108114
}
109-
rows, er1 := l.DB.QueryContext(ctx, l.Query, params...)
115+
110116
if er1 != nil {
111117
return models, er1
112118
}
@@ -116,13 +122,19 @@ func (l DynamicSqlCodeLoader) Load(ctx context.Context, master string) ([]CodeMo
116122
return models, er2
117123
}
118124
// get list indexes column
119-
modelTypes := reflect.TypeOf(models).Elem()
120125
modelType := reflect.TypeOf(CodeModel{})
121-
indexes, er3 := getColumnIndexes(modelType, columns, l.Map)
126+
127+
fieldsIndexSelected := make([]int, 0)
128+
fieldsIndex, er3 := getColumnIndexes(modelType, l.Map)
122129
if er3 != nil {
123130
return models, er3
124131
}
125-
tb, er4 := scanType(rows, modelTypes, indexes)
132+
for _, columnsName := range columns {
133+
if index, ok := fieldsIndex[columnsName]; ok {
134+
fieldsIndexSelected = append(fieldsIndexSelected, index)
135+
}
136+
}
137+
tb, er4 := scanType(rows, modelType, fieldsIndexSelected)
126138
if er4 != nil {
127139
return models, er4
128140
}
@@ -133,7 +145,7 @@ func (l DynamicSqlCodeLoader) Load(ctx context.Context, master string) ([]CodeMo
133145
}
134146
return models, nil
135147
}
136-
func NewSqlCodeLoader(db *sql.DB, table string, config CodeConfig, options...func(i int) string) *SqlCodeLoader {
148+
func NewSqlCodeLoader(db *sql.DB, table string, config CodeConfig, options ...func(i int) string) *SqlCodeLoader {
137149
var build func(i int) string
138150
if len(options) > 0 && options[0] != nil {
139151
build = options[0]
@@ -217,14 +229,19 @@ func (l SqlCodeLoader) Load(ctx context.Context, master string) ([]CodeModel, er
217229
if er1 != nil {
218230
return nil, er1
219231
}
220-
// get list indexes column
221-
modelTypes := reflect.TypeOf(models).Elem()
232+
fieldsIndexSelected := make([]int, 0)
222233
modelType := reflect.TypeOf(CodeModel{})
223-
indexes, er2 := getColumnIndexes(modelType, columns, l.Map)
224-
if er2 != nil {
225-
return nil, er2
234+
// get list indexes column
235+
fieldsIndex, er3 := getColumnIndexes(modelType, l.Map)
236+
if er3 != nil {
237+
return models, er3
238+
}
239+
for _, columnsName := range columns {
240+
if index, ok := fieldsIndex[columnsName]; ok {
241+
fieldsIndexSelected = append(fieldsIndexSelected, index)
242+
}
226243
}
227-
tb, er3 := scanType(rows, modelTypes, indexes)
244+
tb, er3 := scanType(rows, modelType, fieldsIndexSelected)
228245
if er3 != nil {
229246
return nil, er3
230247
}
@@ -237,7 +254,15 @@ func (l SqlCodeLoader) Load(ctx context.Context, master string) ([]CodeModel, er
237254
return models, nil
238255
}
239256

240-
// StructScan : transfer struct to slice for scan
257+
func scanType(rows *sql.Rows, modelType reflect.Type, indexes []int) (t []interface{}, err error) {
258+
for rows.Next() {
259+
initModel := reflect.New(modelType).Interface()
260+
if err = rows.Scan(structScan(initModel, indexes)...); err == nil {
261+
t = append(t, initModel)
262+
}
263+
}
264+
return
265+
}
241266
func structScan(s interface{}, indexColumns []int) (r []interface{}) {
242267
if s != nil {
243268
maps := reflect.Indirect(reflect.ValueOf(s))
@@ -247,24 +272,23 @@ func structScan(s interface{}, indexColumns []int) (r []interface{}) {
247272
}
248273
return
249274
}
250-
func getColumnIndexes(modelType reflect.Type, columnsName []string, build func(string) string) (indexes []int, err error) {
275+
func getColumnIndexes(modelType reflect.Type, mp func(col string) string) (map[string]int, error) {
276+
mapp := make(map[string]int, 0)
251277
if modelType.Kind() != reflect.Struct {
252-
return nil, errors.New("bad type")
278+
return mapp, errors.New("bad type")
253279
}
254280
for i := 0; i < modelType.NumField(); i++ {
255281
field := modelType.Field(i)
256282
ormTag := field.Tag.Get("gorm")
257283
column, ok := findTag(ormTag, "column")
258-
if build != nil {
259-
column = build(column)
260-
}
261284
if ok {
262-
if contains(columnsName, column) {
263-
indexes = append(indexes, i)
285+
if mp != nil {
286+
column = mp(column)
264287
}
288+
mapp[column] = i
265289
}
266290
}
267-
return
291+
return mapp, nil
268292
}
269293
func findTag(tag string, key string) (string, bool) {
270294
if has := strings.Contains(tag, key); has {
@@ -282,25 +306,6 @@ func findTag(tag string, key string) (string, bool) {
282306
return "", false
283307
}
284308

285-
func contains(array []string, v string) bool {
286-
for _, s := range array {
287-
if s == v {
288-
return true
289-
}
290-
}
291-
return false
292-
}
293-
294-
func scanType(rows *sql.Rows, modelTypes reflect.Type, indexes []int) (t []interface{}, err error) {
295-
for rows.Next() {
296-
initArray := reflect.New(modelTypes).Interface()
297-
if err = rows.Scan(structScan(initArray, indexes)...); err == nil {
298-
t = append(t, initArray)
299-
}
300-
}
301-
return
302-
}
303-
304309
func buildParam(i int) string {
305310
return "?"
306311
}

0 commit comments

Comments
 (0)