Skip to content

Commit 79396ff

Browse files
authored
Merge pull request #402 from abhinavcohesity/master
Adding support for schema management in snowflake
2 parents 7236e82 + b00a0cc commit 79396ff

File tree

5 files changed

+406
-3
lines changed

5 files changed

+406
-3
lines changed

database/snowflake/README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Snowflake
2+
3+
`snowflake://user:password@accountname/schema/dbname?query`
4+
5+
| URL Query | WithInstance Config | Description |
6+
|------------|---------------------|-------------|
7+
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
8+
9+
Snowflake is PostgreSQL compatible but has some specific features (or lack thereof) that require slightly different behavior.
10+
Snowflake doesn't run locally hence there are no tests. The library works against hosted instances of snowflake.

database/snowflake/snowflake.go

Lines changed: 371 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,371 @@
1+
package snowflake
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"fmt"
7+
"io"
8+
"io/ioutil"
9+
nurl "net/url"
10+
"strconv"
11+
"strings"
12+
13+
"github.com/golang-migrate/migrate/v4/database"
14+
"github.com/hashicorp/go-multierror"
15+
"github.com/lib/pq"
16+
sf "github.com/snowflakedb/gosnowflake"
17+
)
18+
19+
func init() {
20+
db := Snowflake{}
21+
database.Register("snowflake", &db)
22+
}
23+
24+
var DefaultMigrationsTable = "schema_migrations"
25+
26+
var (
27+
ErrNilConfig = fmt.Errorf("no config")
28+
ErrNoDatabaseName = fmt.Errorf("no database name")
29+
ErrNoPassword = fmt.Errorf("no password")
30+
ErrNoSchema = fmt.Errorf("no schema")
31+
ErrNoSchemaOrDatabase = fmt.Errorf("no schema/database name")
32+
)
33+
34+
type Config struct {
35+
MigrationsTable string
36+
DatabaseName string
37+
}
38+
39+
type Snowflake struct {
40+
isLocked bool
41+
conn *sql.Conn
42+
db *sql.DB
43+
44+
// Open and WithInstance need to guarantee that config is never nil
45+
config *Config
46+
}
47+
48+
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
49+
if config == nil {
50+
return nil, ErrNilConfig
51+
}
52+
53+
if err := instance.Ping(); err != nil {
54+
return nil, err
55+
}
56+
57+
if config.DatabaseName == "" {
58+
query := `SELECT CURRENT_DATABASE()`
59+
var databaseName string
60+
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
61+
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
62+
}
63+
64+
if len(databaseName) == 0 {
65+
return nil, ErrNoDatabaseName
66+
}
67+
68+
config.DatabaseName = databaseName
69+
}
70+
71+
if len(config.MigrationsTable) == 0 {
72+
config.MigrationsTable = DefaultMigrationsTable
73+
}
74+
75+
conn, err := instance.Conn(context.Background())
76+
77+
if err != nil {
78+
return nil, err
79+
}
80+
81+
px := &Snowflake{
82+
conn: conn,
83+
db: instance,
84+
config: config,
85+
}
86+
87+
if err := px.ensureVersionTable(); err != nil {
88+
return nil, err
89+
}
90+
91+
return px, nil
92+
}
93+
94+
func (p *Snowflake) Open(url string) (database.Driver, error) {
95+
purl, err := nurl.Parse(url)
96+
if err != nil {
97+
return nil, err
98+
}
99+
100+
password, isPasswordSet := purl.User.Password()
101+
if !isPasswordSet {
102+
return nil, ErrNoPassword
103+
}
104+
105+
splitPath := strings.Split(purl.Path, "/")
106+
if len(splitPath) < 3 {
107+
return nil, ErrNoSchemaOrDatabase
108+
}
109+
110+
database := splitPath[2]
111+
if len(database) == 0 {
112+
return nil, ErrNoDatabaseName
113+
}
114+
115+
schema := splitPath[1]
116+
if len(schema) == 0 {
117+
return nil, ErrNoSchema
118+
}
119+
120+
cfg := &sf.Config{
121+
Account: purl.Host,
122+
User: purl.User.Username(),
123+
Password: password,
124+
Database: database,
125+
Schema: schema,
126+
}
127+
128+
dsn, err := sf.DSN(cfg)
129+
if err != nil {
130+
return nil, err
131+
}
132+
133+
db, err := sql.Open("snowflake", dsn)
134+
if err != nil {
135+
return nil, err
136+
}
137+
138+
migrationsTable := purl.Query().Get("x-migrations-table")
139+
140+
px, err := WithInstance(db, &Config{
141+
DatabaseName: database,
142+
MigrationsTable: migrationsTable,
143+
})
144+
if err != nil {
145+
return nil, err
146+
}
147+
148+
return px, nil
149+
}
150+
151+
func (p *Snowflake) Close() error {
152+
connErr := p.conn.Close()
153+
dbErr := p.db.Close()
154+
if connErr != nil || dbErr != nil {
155+
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
156+
}
157+
return nil
158+
}
159+
160+
func (p *Snowflake) Lock() error {
161+
if p.isLocked {
162+
return database.ErrLocked
163+
}
164+
p.isLocked = true
165+
return nil
166+
}
167+
168+
func (p *Snowflake) Unlock() error {
169+
p.isLocked = false
170+
return nil
171+
}
172+
173+
func (p *Snowflake) Run(migration io.Reader) error {
174+
migr, err := ioutil.ReadAll(migration)
175+
if err != nil {
176+
return err
177+
}
178+
179+
// run migration
180+
query := string(migr[:])
181+
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
182+
if pgErr, ok := err.(*pq.Error); ok {
183+
var line uint
184+
var col uint
185+
var lineColOK bool
186+
if pgErr.Position != "" {
187+
if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
188+
line, col, lineColOK = computeLineFromPos(query, int(pos))
189+
}
190+
}
191+
message := fmt.Sprintf("migration failed: %s", pgErr.Message)
192+
if lineColOK {
193+
message = fmt.Sprintf("%s (column %d)", message, col)
194+
}
195+
if pgErr.Detail != "" {
196+
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
197+
}
198+
return database.Error{OrigErr: err, Err: message, Query: migr, Line: line}
199+
}
200+
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
201+
}
202+
203+
return nil
204+
}
205+
206+
func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
207+
// replace crlf with lf
208+
s = strings.Replace(s, "\r\n", "\n", -1)
209+
// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
210+
runes := []rune(s)
211+
if pos > len(runes) {
212+
return 0, 0, false
213+
}
214+
sel := runes[:pos]
215+
line = uint(runesCount(sel, newLine) + 1)
216+
col = uint(pos - 1 - runesLastIndex(sel, newLine))
217+
return line, col, true
218+
}
219+
220+
const newLine = '\n'
221+
222+
func runesCount(input []rune, target rune) int {
223+
var count int
224+
for _, r := range input {
225+
if r == target {
226+
count++
227+
}
228+
}
229+
return count
230+
}
231+
232+
func runesLastIndex(input []rune, target rune) int {
233+
for i := len(input) - 1; i >= 0; i-- {
234+
if input[i] == target {
235+
return i
236+
}
237+
}
238+
return -1
239+
}
240+
241+
func (p *Snowflake) SetVersion(version int, dirty bool) error {
242+
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
243+
if err != nil {
244+
return &database.Error{OrigErr: err, Err: "transaction start failed"}
245+
}
246+
247+
query := `DELETE FROM "` + p.config.MigrationsTable + `"`
248+
if _, err := tx.Exec(query); err != nil {
249+
if errRollback := tx.Rollback(); errRollback != nil {
250+
err = multierror.Append(err, errRollback)
251+
}
252+
return &database.Error{OrigErr: err, Query: []byte(query)}
253+
}
254+
255+
// Also re-write the schema version for nil dirty versions to prevent
256+
// empty schema version for failed down migration on the first migration
257+
// See: https://github.com/golang-migrate/migrate/issues/330
258+
if version >= 0 || (version == database.NilVersion && dirty) {
259+
query = `INSERT INTO "` + p.config.MigrationsTable + `" (version,
260+
dirty) VALUES (` + strconv.FormatInt(int64(version), 10) + `,
261+
` + strconv.FormatBool(dirty) + `)`
262+
if _, err := tx.Exec(query); err != nil {
263+
if errRollback := tx.Rollback(); errRollback != nil {
264+
err = multierror.Append(err, errRollback)
265+
}
266+
return &database.Error{OrigErr: err, Query: []byte(query)}
267+
}
268+
}
269+
270+
if err := tx.Commit(); err != nil {
271+
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
272+
}
273+
274+
return nil
275+
}
276+
277+
func (p *Snowflake) Version() (version int, dirty bool, err error) {
278+
query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
279+
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
280+
switch {
281+
case err == sql.ErrNoRows:
282+
return database.NilVersion, false, nil
283+
284+
case err != nil:
285+
if e, ok := err.(*pq.Error); ok {
286+
if e.Code.Name() == "undefined_table" {
287+
return database.NilVersion, false, nil
288+
}
289+
}
290+
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
291+
292+
default:
293+
return version, dirty, nil
294+
}
295+
}
296+
297+
func (p *Snowflake) Drop() (err error) {
298+
// select all tables in current schema
299+
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
300+
tables, err := p.conn.QueryContext(context.Background(), query)
301+
if err != nil {
302+
return &database.Error{OrigErr: err, Query: []byte(query)}
303+
}
304+
defer func() {
305+
if errClose := tables.Close(); errClose != nil {
306+
err = multierror.Append(err, errClose)
307+
}
308+
}()
309+
310+
// delete one table after another
311+
tableNames := make([]string, 0)
312+
for tables.Next() {
313+
var tableName string
314+
if err := tables.Scan(&tableName); err != nil {
315+
return err
316+
}
317+
if len(tableName) > 0 {
318+
tableNames = append(tableNames, tableName)
319+
}
320+
}
321+
322+
if len(tableNames) > 0 {
323+
// delete one by one ...
324+
for _, t := range tableNames {
325+
query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
326+
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
327+
return &database.Error{OrigErr: err, Query: []byte(query)}
328+
}
329+
}
330+
}
331+
332+
return nil
333+
}
334+
335+
// ensureVersionTable checks if versions table exists and, if not, creates it.
336+
// Note that this function locks the database, which deviates from the usual
337+
// convention of "caller locks" in the Snowflake type.
338+
func (p *Snowflake) ensureVersionTable() (err error) {
339+
if err = p.Lock(); err != nil {
340+
return err
341+
}
342+
343+
defer func() {
344+
if e := p.Unlock(); e != nil {
345+
if err == nil {
346+
err = e
347+
} else {
348+
err = multierror.Append(err, e)
349+
}
350+
}
351+
}()
352+
353+
// check if migration table exists
354+
var count int
355+
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
356+
if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil {
357+
return &database.Error{OrigErr: err, Query: []byte(query)}
358+
}
359+
if count == 1 {
360+
return nil
361+
}
362+
363+
// if not, create the empty migration table
364+
query = `CREATE TABLE if not exists "` + p.config.MigrationsTable + `" (
365+
version bigint not null primary key, dirty boolean not null)`
366+
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
367+
return &database.Error{OrigErr: err, Query: []byte(query)}
368+
}
369+
370+
return nil
371+
}

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,14 @@ require (
3333
github.com/mattn/go-sqlite3 v1.10.0
3434
github.com/nakagami/firebirdsql v0.0.0-20190310045651-3c02a58cfed8
3535
github.com/neo4j/neo4j-go-driver v1.8.0-beta02
36+
github.com/snowflakedb/gosnowflake v1.3.5
3637
github.com/stretchr/testify v1.5.1
3738
github.com/tidwall/pretty v0.0.0-20180105212114-65a9db5fad51 // indirect
3839
github.com/xanzy/go-gitlab v0.15.0
3940
github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c // indirect
4041
github.com/xdg/stringprep v1.0.0 // indirect
4142
gitlab.com/nyarla/go-crypt v0.0.0-20160106005555-d9a5dc2b789b // indirect
4243
go.mongodb.org/mongo-driver v1.1.0
43-
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073 // indirect
4444
golang.org/x/exp v0.0.0-20200213203834-85f925bdd4d0 // indirect
4545
golang.org/x/net v0.0.0-20200202094626-16171245cfb2
4646
golang.org/x/tools v0.0.0-20200213224642-88e652f7a869

0 commit comments

Comments
 (0)