Skip to content

Commit 53ef02d

Browse files
committed
Use multistmt package to parse multi-statement migrations for clickhouse, cassandra, and neo4j
Addresses: #406
1 parent 6204921 commit 53ef02d

File tree

3 files changed

+119
-57
lines changed

3 files changed

+119
-57
lines changed

database/cassandra/cassandra.go

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212

1313
"github.com/gocql/gocql"
1414
"github.com/golang-migrate/migrate/v4/database"
15+
"github.com/golang-migrate/migrate/v4/database/multistmt"
1516
"github.com/hashicorp/go-multierror"
1617
)
1718

@@ -20,6 +21,12 @@ func init() {
2021
database.Register("cassandra", db)
2122
}
2223

24+
var (
25+
multiStmtDelimiter = []byte(";")
26+
27+
DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
28+
)
29+
2330
var DefaultMigrationsTable = "schema_migrations"
2431

2532
var (
@@ -33,6 +40,7 @@ type Config struct {
3340
MigrationsTable string
3441
KeyspaceName string
3542
MultiStatementEnabled bool
43+
MultiStatementMaxSize int
3644
}
3745

3846
type Cassandra struct {
@@ -58,6 +66,10 @@ func WithInstance(session *gocql.Session, config *Config) (database.Driver, erro
5866
config.MigrationsTable = DefaultMigrationsTable
5967
}
6068

69+
if config.MultiStatementMaxSize <= 0 {
70+
config.MultiStatementMaxSize = DefaultMultiStatementMaxSize
71+
}
72+
6173
c := &Cassandra{
6274
session: session,
6375
config: config,
@@ -148,10 +160,19 @@ func (c *Cassandra) Open(url string) (database.Driver, error) {
148160
return nil, err
149161
}
150162

163+
multiStatementMaxSize := DefaultMultiStatementMaxSize
164+
if s := u.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
165+
multiStatementMaxSize, err = strconv.Atoi(s)
166+
if err != nil {
167+
return nil, err
168+
}
169+
}
170+
151171
return WithInstance(session, &Config{
152172
KeyspaceName: strings.TrimPrefix(u.Path, "/"),
153173
MigrationsTable: u.Query().Get("x-migrations-table"),
154174
MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true",
175+
MultiStatementMaxSize: multiStatementMaxSize,
155176
})
156177
}
157178

@@ -174,31 +195,30 @@ func (c *Cassandra) Unlock() error {
174195
}
175196

176197
func (c *Cassandra) Run(migration io.Reader) error {
177-
migr, err := ioutil.ReadAll(migration)
178-
if err != nil {
179-
return err
180-
}
181-
// run migration
182-
query := string(migr[:])
183-
184198
if c.config.MultiStatementEnabled {
185-
// split query by semi-colon
186-
queries := strings.Split(query, ";")
187-
188-
for _, q := range queries {
189-
tq := strings.TrimSpace(q)
199+
var err error
200+
if e := multistmt.Parse(migration, multiStmtDelimiter, c.config.MultiStatementMaxSize, func(m []byte) bool {
201+
tq := strings.TrimSpace(string(m))
190202
if tq == "" {
191-
continue
203+
return true
192204
}
193-
if err := c.session.Query(tq).Exec(); err != nil {
194-
// TODO: cast to Cassandra error and get line number
195-
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
205+
if e := c.session.Query(tq).Exec(); e != nil {
206+
err = database.Error{OrigErr: e, Err: "migration failed", Query: m}
207+
return false
196208
}
209+
return true
210+
}); e != nil {
211+
return e
197212
}
198-
return nil
213+
return err
199214
}
200215

201-
if err := c.session.Query(query).Exec(); err != nil {
216+
migr, err := ioutil.ReadAll(migration)
217+
if err != nil {
218+
return err
219+
}
220+
// run migration
221+
if err := c.session.Query(string(migr)).Exec(); err != nil {
202222
// TODO: cast to Cassandra error and get line number
203223
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
204224
}

database/clickhouse/clickhouse.go

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,30 @@ import (
66
"io"
77
"io/ioutil"
88
"net/url"
9+
"strconv"
910
"strings"
1011
"time"
1112

1213
"github.com/golang-migrate/migrate/v4"
1314
"github.com/golang-migrate/migrate/v4/database"
15+
"github.com/golang-migrate/migrate/v4/database/multistmt"
1416
"github.com/hashicorp/go-multierror"
1517
)
1618

17-
var DefaultMigrationsTable = "schema_migrations"
19+
var (
20+
multiStmtDelimiter = []byte(";")
1821

19-
var ErrNilConfig = fmt.Errorf("no config")
22+
DefaultMigrationsTable = "schema_migrations"
23+
DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
24+
25+
ErrNilConfig = fmt.Errorf("no config")
26+
)
2027

2128
type Config struct {
2229
DatabaseName string
2330
MigrationsTable string
2431
MultiStatementEnabled bool
32+
MultiStatementMaxSize int
2533
}
2634

2735
func init() {
@@ -66,12 +74,21 @@ func (ch *ClickHouse) Open(dsn string) (database.Driver, error) {
6674
return nil, err
6775
}
6876

77+
multiStatementMaxSize := DefaultMultiStatementMaxSize
78+
if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
79+
multiStatementMaxSize, err = strconv.Atoi(s)
80+
if err != nil {
81+
return nil, err
82+
}
83+
}
84+
6985
ch = &ClickHouse{
7086
conn: conn,
7187
config: &Config{
7288
MigrationsTable: purl.Query().Get("x-migrations-table"),
7389
DatabaseName: purl.Query().Get("database"),
7490
MultiStatementEnabled: purl.Query().Get("x-multi-statement") == "true",
91+
MultiStatementMaxSize: multiStatementMaxSize,
7592
},
7693
}
7794

@@ -93,28 +110,35 @@ func (ch *ClickHouse) init() error {
93110
ch.config.MigrationsTable = DefaultMigrationsTable
94111
}
95112

113+
if ch.config.MultiStatementMaxSize <= 0 {
114+
ch.config.MultiStatementMaxSize = DefaultMultiStatementMaxSize
115+
}
116+
96117
return ch.ensureVersionTable()
97118
}
98119

99120
func (ch *ClickHouse) Run(r io.Reader) error {
100-
migration, err := ioutil.ReadAll(r)
101-
if err != nil {
102-
return err
103-
}
104-
105121
if ch.config.MultiStatementEnabled {
106-
// split query by semi-colon
107-
queries := strings.Split(string(migration), ";")
108-
for _, q := range queries {
109-
tq := strings.TrimSpace(q)
122+
var err error
123+
if e := multistmt.Parse(r, multiStmtDelimiter, ch.config.MultiStatementMaxSize, func(m []byte) bool {
124+
tq := strings.TrimSpace(string(m))
110125
if tq == "" {
111-
continue
126+
return true
112127
}
113-
if _, err := ch.conn.Exec(q); err != nil {
114-
return database.Error{OrigErr: err, Err: "migration failed", Query: []byte(q)}
128+
if _, e := ch.conn.Exec(string(m)); e != nil {
129+
err = database.Error{OrigErr: e, Err: "migration failed", Query: m}
130+
return false
115131
}
132+
return true
133+
}); e != nil {
134+
return e
116135
}
117-
return nil
136+
return err
137+
}
138+
139+
migration, err := ioutil.ReadAll(r)
140+
if err != nil {
141+
return err
118142
}
119143

120144
if _, err := ch.conn.Exec(string(migration)); err != nil {

database/neo4j/neo4j.go

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"sync/atomic"
1212

1313
"github.com/golang-migrate/migrate/v4/database"
14+
"github.com/golang-migrate/migrate/v4/database/multistmt"
1415
"github.com/hashicorp/go-multierror"
1516
"github.com/neo4j/neo4j-go-driver/neo4j"
1617
)
@@ -22,15 +23,19 @@ func init() {
2223

2324
const DefaultMigrationsLabel = "SchemaMigration"
2425

25-
var StatementSeparator = []byte(";")
26+
var (
27+
StatementSeparator = []byte(";")
28+
DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
29+
)
2630

2731
var (
2832
ErrNilConfig = fmt.Errorf("no config")
2933
)
3034

3135
type Config struct {
32-
MigrationsLabel string
33-
MultiStatement bool
36+
MigrationsLabel string
37+
MultiStatement bool
38+
MultiStatementMaxSize int
3439
}
3540

3641
type Neo4j struct {
@@ -87,6 +92,14 @@ func (n *Neo4j) Open(url string) (database.Driver, error) {
8792
}
8893
}
8994

95+
multiStatementMaxSize := DefaultMultiStatementMaxSize
96+
if s := uri.Query().Get("x-multi-statement-max-size"); s != "" {
97+
multiStatementMaxSize, err = strconv.Atoi(s)
98+
if err != nil {
99+
return nil, err
100+
}
101+
}
102+
90103
uri.RawQuery = ""
91104

92105
driver, err := neo4j.NewDriver(uri.String(), authToken, func(config *neo4j.Config) {
@@ -97,8 +110,9 @@ func (n *Neo4j) Open(url string) (database.Driver, error) {
97110
}
98111

99112
return WithInstance(driver, &Config{
100-
MigrationsLabel: DefaultMigrationsLabel,
101-
MultiStatement: multi,
113+
MigrationsLabel: DefaultMigrationsLabel,
114+
MultiStatement: multi,
115+
MultiStatementMaxSize: multiStatementMaxSize,
102116
})
103117
}
104118

@@ -123,11 +137,6 @@ func (n *Neo4j) Unlock() error {
123137
}
124138

125139
func (n *Neo4j) Run(migration io.Reader) (err error) {
126-
body, err := ioutil.ReadAll(migration)
127-
if err != nil {
128-
return err
129-
}
130-
131140
session, err := n.driver.Session(neo4j.AccessModeWrite)
132141
if err != nil {
133142
return err
@@ -139,30 +148,39 @@ func (n *Neo4j) Run(migration io.Reader) (err error) {
139148
}()
140149

141150
if n.config.MultiStatement {
142-
statements := bytes.Split(body, StatementSeparator)
143151
_, err = session.WriteTransaction(func(transaction neo4j.Transaction) (interface{}, error) {
144-
for _, stmt := range statements {
152+
var stmtRunErr error
153+
if err := multistmt.Parse(migration, StatementSeparator, n.config.MultiStatementMaxSize, func(stmt []byte) bool {
145154
trimStmt := bytes.TrimSpace(stmt)
146155
if len(trimStmt) == 0 {
147-
continue
156+
return true
148157
}
149-
result, err := transaction.Run(string(trimStmt[:]), nil)
158+
trimStmt = bytes.TrimSuffix(trimStmt, StatementSeparator)
159+
if len(trimStmt) == 0 {
160+
return true
161+
}
162+
163+
result, err := transaction.Run(string(trimStmt), nil)
150164
if _, err := neo4j.Collect(result, err); err != nil {
151-
return nil, err
165+
stmtRunErr = err
166+
return false
152167
}
168+
return true
169+
}); err != nil {
170+
return nil, err
153171
}
154-
return nil, nil
172+
return nil, stmtRunErr
155173
})
156-
if err != nil {
157-
return err
158-
}
159-
} else {
160-
if _, err := neo4j.Collect(session.Run(string(body[:]), nil)); err != nil {
161-
return err
162-
}
174+
return err
163175
}
164176

165-
return nil
177+
body, err := ioutil.ReadAll(migration)
178+
if err != nil {
179+
return err
180+
}
181+
182+
_, err = neo4j.Collect(session.Run(string(body[:]), nil))
183+
return err
166184
}
167185

168186
func (n *Neo4j) SetVersion(version int, dirty bool) (err error) {

0 commit comments

Comments
 (0)