Skip to content

Commit c8ba5b5

Browse files
committed
Refactor driver
1 parent d51a238 commit c8ba5b5

File tree

1 file changed

+16
-19
lines changed

1 file changed

+16
-19
lines changed

code.go

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ type SqlCodeLoader struct {
4444
DB *sql.DB
4545
Table string
4646
Config CodeConfig
47-
Driver string
47+
Build func(i int) string
4848
}
4949
type DynamicSqlCodeLoader struct {
5050
DB *sql.DB
@@ -117,8 +117,8 @@ func (l DynamicSqlCodeLoader) Load(ctx context.Context, master string) ([]CodeMo
117117
return models, nil
118118
}
119119
func NewSqlCodeLoader(db *sql.DB, table string, config CodeConfig) *SqlCodeLoader {
120-
driver := getDriver(db)
121-
return &SqlCodeLoader{DB: db, Table: table, Config: config, Driver: driver}
120+
build := getBuild(db)
121+
return &SqlCodeLoader{DB: db, Table: table, Config: config, Build: build}
122122
}
123123

124124
func buildParam(i int) string {
@@ -164,26 +164,13 @@ func (l SqlCodeLoader) Load(ctx context.Context, master string) ([]CodeModel, er
164164
p1 := ""
165165
i := 1
166166
if len(c.Master) > 0 {
167-
if l.Driver == DriverPostgres {
168-
p1 = fmt.Sprintf("%s = $%d", c.Master, i)
169-
} else if l.Driver == DriverOracle {
170-
p1 = fmt.Sprintf("%s = :val%d", c.Master, i)
171-
} else {
172-
p1 = fmt.Sprintf("%s = ?", c.Master)
173-
}
167+
p1 = fmt.Sprintf("%s = %s", c.Master, l.Build(i))
174168
i = i + 1
175169
values = append(values, master)
176170
}
177171
cols := strings.Join(s, ",")
178172
if len(c.Status) > 0 && c.Active != nil {
179-
p2 := ""
180-
if l.Driver == DriverPostgres {
181-
p2 = fmt.Sprintf("%s = $%d", c.Status, i)
182-
} else if l.Driver == DriverOracle {
183-
p2 = fmt.Sprintf("%s = :val%d", c.Status, i)
184-
} else {
185-
p2 = fmt.Sprintf("%s = ?", c.Status)
186-
}
173+
p2 := fmt.Sprintf("%s = %s", c.Status, l.Build(i))
187174
values = append(values, c.Active)
188175
if cols == "" {
189176
cols = "*"
@@ -297,7 +284,17 @@ func scanType(rows *sql.Rows, modelTypes reflect.Type, indexes []int) (t []inter
297284
}
298285
return
299286
}
300-
287+
func getBuild(db *sql.DB) func(i int) string {
288+
driver := reflect.TypeOf(db.Driver()).String()
289+
switch driver {
290+
case "*pq.Driver":
291+
return buildDollarParam
292+
case "*godror.drv":
293+
return buildOracleParam
294+
default:
295+
return buildParam
296+
}
297+
}
301298
func getDriver(db *sql.DB) string {
302299
driver := reflect.TypeOf(db.Driver()).String()
303300
switch driver {

0 commit comments

Comments
 (0)