Skip to content

Commit d00811f

Browse files
committed
v3.0.3
1 parent ddb39b9 commit d00811f

File tree

15 files changed

+359
-257
lines changed

15 files changed

+359
-257
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,4 @@ performance/
2828
cover.sh
2929
gobco-counts.json
3030
branchCover.sh
31+
TradesSmall.csv

api/pool.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,17 +118,15 @@ func newConn(addr string, opt *PoolOption) (dialer.Conn, error) {
118118
return nil, err
119119
}
120120

121+
conn.SetUserID(opt.UserID)
122+
conn.SetPassword(opt.Password)
123+
121124
err = conn.Connect()
122125
if err != nil {
123126
fmt.Printf("Failed to connect to the server: %s\n", err.Error())
124127
return nil, err
125128
}
126129

127-
err = dialer.Login(conn, opt.UserID, opt.Password)
128-
if err != nil {
129-
return nil, err
130-
}
131-
132130
return conn, nil
133131
}
134132

api/utils.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import (
77
uuid "github.com/satori/go.uuid"
88
)
99

10-
var API_VERSION = "3.0.2.1"
10+
var API_VERSION = "3.0.3"
1111

1212
func GetAPIVersion() string {
1313
return API_VERSION

dialer/behavior.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,5 +94,8 @@ func (f *BehaviorOptions) GetTryReconnectNums() int {
9494
if f.TryReconnectNums == nil {
9595
return 0
9696
}
97+
if *f.TryReconnectNums < 0 {
98+
return 0
99+
}
97100
return *f.TryReconnectNums
98101
}

dialer/dialer.go

Lines changed: 222 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ package dialer
22

33
import (
44
"context"
5+
"crypto/hmac"
6+
"crypto/rand"
7+
"crypto/sha256"
8+
"encoding/base64"
59
"errors"
610
"fmt"
711
"net"
@@ -11,6 +15,7 @@ import (
1115

1216
"github.com/dolphindb/api-go/v3/dialer/protocol"
1317
"github.com/dolphindb/api-go/v3/model"
18+
"golang.org/x/crypto/pbkdf2"
1419
)
1520

1621
const (
@@ -72,6 +77,7 @@ type Conn interface {
7277
// for inner use
7378
GetReader() protocol.Reader
7479
enableScram() bool
80+
ConnLogin(userID, password string) error
7581
}
7682

7783
type conn struct {
@@ -147,11 +153,6 @@ func NewSimpleConn(ctx context.Context, address, userID, pwd string) (Conn, erro
147153
return nil, err
148154
}
149155

150-
err = Login(conn, userID, pwd)
151-
if err != nil {
152-
return nil, err
153-
}
154-
155156
return conn, err
156157
}
157158

@@ -242,11 +243,11 @@ func (c *conn) Connect() error {
242243
for _, v := range c.highAvailabilitySites {
243244
c.nodePool.add(&node{address: v})
244245
}
245-
return c.switchDatanode(nil)
246+
return c.switchDataNode(nil)
246247
} else if c.reconnect {
247248
c.nodePool = &nodePool{}
248249
c.nodePool.add(&node{address: c.addr})
249-
return c.switchDatanode(nil)
250+
return c.switchDataNode(nil)
250251
} else {
251252
ok, err := c.connectNode(&node{address: c.addr})
252253
if err != nil {
@@ -292,11 +293,24 @@ func (c *conn) connect(addr string) error {
292293
c.isClosed = false
293294
c.refreshHeaderForResponse(h)
294295

295-
if c.userID != "" {
296-
Login(c, c.userID, c.password)
296+
args := make([]model.DataForm, 0)
297+
ret, err := c.runFuncInternal("isNodeInitialized", args)
298+
if err != nil {
299+
fmt.Println("Server does not support the initialization check. Please upgrade to a newer version.")
300+
} else {
301+
if !(ret.(*model.Scalar)).Value().(bool) {
302+
c.isConnected = false
303+
fmt.Println("connection established, but the node has not been initialized")
304+
return fmt.Errorf("<DataNodeNotReady>") // use a special error to indicate that the node is not initialized
305+
}
306+
}
307+
err = nil
308+
309+
if c.userID != "" || c.password != "" {
310+
err = Login(c, c.userID, c.password)
297311
}
298312

299-
return nil
313+
return err
300314
}
301315

302316
func (c *conn) Close() error {
@@ -365,6 +379,20 @@ func (c *conn) RunFunc(s string, args []model.DataForm) (model.DataForm, error)
365379
return di, err
366380
}
367381

382+
func (c *conn) runFuncInternal(s string, args []model.DataForm) (model.DataForm, error) {
383+
bo := defaultByteOrder
384+
385+
_, di, err := c.runInternal(&requestParams{
386+
commandType: functionCmd,
387+
Command: generateFunctionCommand(s, bo, args),
388+
SessionID: []byte(c.GetSession()),
389+
Args: args,
390+
ByteOrder: bo,
391+
})
392+
393+
return di, err
394+
}
395+
368396
// Upload sends local data to dolphindb and the specified variable is generated on the dolphindb.
369397
func (c *conn) Upload(vars map[string]model.DataForm) (model.DataForm, error) {
370398
bo := defaultByteOrder
@@ -392,32 +420,39 @@ func (c *conn) Upload(vars map[string]model.DataForm) (model.DataForm, error) {
392420
}
393421

394422
func (c *conn) run(params *requestParams) (*responseHeader, model.DataForm, error) {
395-
retryTimes := c.getRetryTimes()
396-
if c.nodePool != nil && c.nodePool.len > 0 {
397-
for i := 0; i <= retryTimes; i++ { // at least try once
398-
rh, df, err := c.runInternal(params)
399-
if err != nil {
400-
n := c.nodePool.nodes[c.nodePool.lastInd]
401-
if c.connected() {
402-
et := c.nodePool.parseError(err.Error(), n)
403-
if et == IGNORE || et == NO_INITIALIZED {
404-
continue
405-
}
406-
}
407-
time.Sleep(300 * time.Millisecond)
408-
err := c.switchDatanode(nil)
409-
if err != nil {
410-
return nil, nil, err
411-
}
412-
continue
413-
}
423+
if c.nodePool == nil || c.nodePool.len <= 0 {
424+
return c.runInternal(params)
425+
}
414426

427+
rh, df, err := c.runInternal(params)
428+
for i := 0; i <= c.getRetryTimes(); i++ {
429+
if err == nil {
415430
return rh, df, nil
416431
}
417-
return nil, nil, fmt.Errorf("failed to connect to %s after %d times of reconnecting", c.addr, retryTimes)
418-
} else {
419-
return c.runInternal(params)
432+
433+
n := c.nodePool.nodes[c.nodePool.lastInd]
434+
et := c.nodePool.parseError(err.Error(), n)
435+
436+
if c.connected() && et == UNKNOWN {
437+
return rh, df, err
438+
}
439+
if et == LOGIN_REQUIRED {
440+
return rh, df, err
441+
}
442+
443+
if !(et == NEW_LEADER || et == NO_INITIALIZED || et == NODE_NOT_AVAIL) {
444+
// not use the current node
445+
n = nil
446+
}
447+
time.Sleep(300 * time.Millisecond)
448+
449+
// for loop would break when switchDataNode run out of retry times
450+
if err := c.switchDataNode(n); err != nil {
451+
return nil, nil, err
452+
}
453+
rh, df, err = c.runInternal(params)
420454
}
455+
return rh, df, err
421456
}
422457

423458
func (c *conn) runInternal(params *requestParams) (*responseHeader, model.DataForm, error) {
@@ -465,3 +500,158 @@ func (c *conn) refreshHeaderForResponse(h *responseHeader) {
465500
func (c *conn) enableScram() bool {
466501
return c.behaviorOpt.EnableScram
467502
}
503+
504+
func (c *conn) ConnLogin(userID, password string) error {
505+
if c.enableScram() {
506+
return c.scramLogin(userID, password)
507+
} else {
508+
err := c.scramLogin(userID, password)
509+
if err == nil {
510+
return nil
511+
}
512+
}
513+
args := make([]model.DataForm, 2)
514+
user, err := model.NewDataType(model.DtString, userID)
515+
if err != nil {
516+
return err
517+
}
518+
pwd, err := model.NewDataType(model.DtString, password)
519+
if err != nil {
520+
return err
521+
}
522+
523+
args[0] = model.NewScalar(user)
524+
args[1] = model.NewScalar(pwd)
525+
_, err = c.runFuncInternal("login", args)
526+
if err != nil {
527+
return err
528+
}
529+
return nil
530+
}
531+
532+
func generateNonce(length int) (string, error) {
533+
buffer := make([]byte, length)
534+
_, err := rand.Read(buffer)
535+
if err != nil {
536+
return "", err
537+
}
538+
return base64.StdEncoding.EncodeToString(buffer), nil
539+
}
540+
541+
func xorBytes(a, b []byte) []byte {
542+
result := make([]byte, len(a))
543+
for i := range a {
544+
result[i] = a[i] ^ b[i]
545+
}
546+
return result
547+
}
548+
549+
func (c *conn) scramLogin(userID, password string) error {
550+
args := make([]model.DataForm, 2)
551+
user, err := model.NewDataType(model.DtString, userID)
552+
if err != nil {
553+
return fmt.Errorf("SCRAM login failed, %w", err)
554+
}
555+
clientNonce, err := generateNonce(16)
556+
if err != nil {
557+
return fmt.Errorf("SCRAM login failed, %w", err)
558+
}
559+
nonce, err := model.NewDataType(model.DtString, clientNonce)
560+
if err != nil {
561+
return fmt.Errorf("SCRAM login failed, %w", err)
562+
}
563+
args[0] = model.NewScalar(user)
564+
args[1] = model.NewScalar(nonce)
565+
566+
result, err := c.runFuncInternal("scramClientFirst", args)
567+
if err != nil {
568+
if strings.Contains(err.Error(), "Can't recognize function name scramClientFirst") {
569+
return fmt.Errorf("SCRAM login is unavailable on current server")
570+
}
571+
if strings.Contains(err.Error(), "sha256 authMode doesn't support scram authMode") {
572+
return fmt.Errorf("user '%s' doesn't support scram authMode", userID)
573+
}
574+
return fmt.Errorf("scramClientFirst failed: %w", err)
575+
}
576+
577+
retVec := result.(*model.Vector)
578+
579+
if retVec.Rows() != 3 {
580+
return fmt.Errorf("SCRAM login failed, server error: get server nonce failed")
581+
}
582+
saltStr := retVec.Get(0).Value().(*model.Scalar).Value().(string)
583+
iterCount := int(retVec.Get(1).Value().(*model.Scalar).Value().(int32))
584+
combinedNonce := retVec.Get(2).Value().(*model.Scalar).Value().(string)
585+
586+
salt, err := base64.StdEncoding.DecodeString(saltStr)
587+
if err != nil {
588+
return fmt.Errorf("SCRAM login failed, base64 decode failed: %w", err)
589+
}
590+
591+
saltedPassword := pbkdf2.Key([]byte(password), salt, iterCount, 32, sha256.New)
592+
593+
mac := hmac.New(sha256.New, saltedPassword)
594+
_, err = mac.Write([]byte("Client Key"))
595+
if err != nil {
596+
return fmt.Errorf("SCRAM login failed, HMAC calculation failed: %w", err)
597+
}
598+
clientKey := mac.Sum(nil)
599+
600+
storedKey := sha256.Sum256(clientKey)
601+
602+
authMessage := fmt.Sprintf(`n=%s,r=%s,r=%s,s=%s,i=%d,c=biws,r=%s`,
603+
userID, clientNonce, combinedNonce, saltStr, iterCount, combinedNonce)
604+
605+
mac = hmac.New(sha256.New, storedKey[:])
606+
_, err = mac.Write([]byte(authMessage))
607+
if err != nil {
608+
return fmt.Errorf("SCRAM login failed, HMAC calculation failed: %w", err)
609+
}
610+
clientSig := mac.Sum(nil)
611+
612+
proof := xorBytes(clientKey, clientSig)
613+
614+
finalArgs := make([]model.DataForm, 3)
615+
combinedNonceScalar, err := model.NewDataType(model.DtString, combinedNonce)
616+
if err != nil {
617+
return fmt.Errorf("SCRAM login failed, %w", err)
618+
}
619+
proofScalar, err := model.NewDataType(model.DtString, base64.StdEncoding.EncodeToString(proof))
620+
if err != nil {
621+
return fmt.Errorf("SCRAM login failed, %w", err)
622+
}
623+
624+
finalArgs[0] = model.NewScalar(user)
625+
finalArgs[1] = model.NewScalar(combinedNonceScalar)
626+
finalArgs[2] = model.NewScalar(proofScalar)
627+
628+
finalResult, err := c.runFuncInternal("scramClientFinal", finalArgs)
629+
if err != nil {
630+
return fmt.Errorf("scramClientFinal failed: %w", err)
631+
}
632+
serverSigBase64 := finalResult.(*model.Scalar).Value().(string)
633+
634+
mac = hmac.New(sha256.New, saltedPassword)
635+
_, err = mac.Write([]byte("Server Key"))
636+
if err != nil {
637+
return fmt.Errorf("SCRAM login failed, HMAC calculation failed: %w", err)
638+
}
639+
serverKey := mac.Sum(nil)
640+
641+
mac = hmac.New(sha256.New, serverKey)
642+
_, err = mac.Write([]byte(authMessage))
643+
if err != nil {
644+
return fmt.Errorf("SCRAM login failed, HMAC calculation failed: %w", err)
645+
}
646+
serverSig := mac.Sum(nil)
647+
648+
expectedSig := base64.StdEncoding.EncodeToString(serverSig)
649+
650+
if serverSigBase64 != "" && expectedSig != serverSigBase64 {
651+
c.Close()
652+
return fmt.Errorf("invalid SCRAM server signature")
653+
}
654+
655+
fmt.Println("SCRAM login succeeded")
656+
return nil
657+
}

0 commit comments

Comments
 (0)