@@ -2,6 +2,10 @@ package dialer
22
33import (
44 "context"
5+ "crypto/hmac"
6+ "crypto/rand"
7+ "crypto/sha256"
8+ "encoding/base64"
59 "errors"
610 "fmt"
711 "net"
@@ -11,6 +15,7 @@ import (
1115
1216 "github.com/dolphindb/api-go/v3/dialer/protocol"
1317 "github.com/dolphindb/api-go/v3/model"
18+ "golang.org/x/crypto/pbkdf2"
1419)
1520
1621const (
@@ -72,6 +77,7 @@ type Conn interface {
7277 // for inner use
7378 GetReader () protocol.Reader
7479 enableScram () bool
80+ ConnLogin (userID , password string ) error
7581}
7682
7783type conn struct {
@@ -147,11 +153,6 @@ func NewSimpleConn(ctx context.Context, address, userID, pwd string) (Conn, erro
147153 return nil , err
148154 }
149155
150- err = Login (conn , userID , pwd )
151- if err != nil {
152- return nil , err
153- }
154-
155156 return conn , err
156157}
157158
@@ -242,11 +243,11 @@ func (c *conn) Connect() error {
242243 for _ , v := range c .highAvailabilitySites {
243244 c .nodePool .add (& node {address : v })
244245 }
245- return c .switchDatanode (nil )
246+ return c .switchDataNode (nil )
246247 } else if c .reconnect {
247248 c .nodePool = & nodePool {}
248249 c .nodePool .add (& node {address : c .addr })
249- return c .switchDatanode (nil )
250+ return c .switchDataNode (nil )
250251 } else {
251252 ok , err := c .connectNode (& node {address : c .addr })
252253 if err != nil {
@@ -292,11 +293,24 @@ func (c *conn) connect(addr string) error {
292293 c .isClosed = false
293294 c .refreshHeaderForResponse (h )
294295
295- if c .userID != "" {
296- Login (c , c .userID , c .password )
296+ args := make ([]model.DataForm , 0 )
297+ ret , err := c .runFuncInternal ("isNodeInitialized" , args )
298+ if err != nil {
299+ fmt .Println ("Server does not support the initialization check. Please upgrade to a newer version." )
300+ } else {
301+ if ! (ret .(* model.Scalar )).Value ().(bool ) {
302+ c .isConnected = false
303+ fmt .Println ("connection established, but the node has not been initialized" )
304+ return fmt .Errorf ("<DataNodeNotReady>" ) // use a special error to indicate that the node is not initialized
305+ }
306+ }
307+ err = nil
308+
309+ if c .userID != "" || c .password != "" {
310+ err = Login (c , c .userID , c .password )
297311 }
298312
299- return nil
313+ return err
300314}
301315
302316func (c * conn ) Close () error {
@@ -365,6 +379,20 @@ func (c *conn) RunFunc(s string, args []model.DataForm) (model.DataForm, error)
365379 return di , err
366380}
367381
382+ func (c * conn ) runFuncInternal (s string , args []model.DataForm ) (model.DataForm , error ) {
383+ bo := defaultByteOrder
384+
385+ _ , di , err := c .runInternal (& requestParams {
386+ commandType : functionCmd ,
387+ Command : generateFunctionCommand (s , bo , args ),
388+ SessionID : []byte (c .GetSession ()),
389+ Args : args ,
390+ ByteOrder : bo ,
391+ })
392+
393+ return di , err
394+ }
395+
368396// Upload sends local data to dolphindb and the specified variable is generated on the dolphindb.
369397func (c * conn ) Upload (vars map [string ]model.DataForm ) (model.DataForm , error ) {
370398 bo := defaultByteOrder
@@ -392,32 +420,39 @@ func (c *conn) Upload(vars map[string]model.DataForm) (model.DataForm, error) {
392420}
393421
394422func (c * conn ) run (params * requestParams ) (* responseHeader , model.DataForm , error ) {
395- retryTimes := c .getRetryTimes ()
396- if c .nodePool != nil && c .nodePool .len > 0 {
397- for i := 0 ; i <= retryTimes ; i ++ { // at least try once
398- rh , df , err := c .runInternal (params )
399- if err != nil {
400- n := c .nodePool .nodes [c .nodePool .lastInd ]
401- if c .connected () {
402- et := c .nodePool .parseError (err .Error (), n )
403- if et == IGNORE || et == NO_INITIALIZED {
404- continue
405- }
406- }
407- time .Sleep (300 * time .Millisecond )
408- err := c .switchDatanode (nil )
409- if err != nil {
410- return nil , nil , err
411- }
412- continue
413- }
423+ if c .nodePool == nil || c .nodePool .len <= 0 {
424+ return c .runInternal (params )
425+ }
414426
427+ rh , df , err := c .runInternal (params )
428+ for i := 0 ; i <= c .getRetryTimes (); i ++ {
429+ if err == nil {
415430 return rh , df , nil
416431 }
417- return nil , nil , fmt .Errorf ("failed to connect to %s after %d times of reconnecting" , c .addr , retryTimes )
418- } else {
419- return c .runInternal (params )
432+
433+ n := c .nodePool .nodes [c .nodePool .lastInd ]
434+ et := c .nodePool .parseError (err .Error (), n )
435+
436+ if c .connected () && et == UNKNOWN {
437+ return rh , df , err
438+ }
439+ if et == LOGIN_REQUIRED {
440+ return rh , df , err
441+ }
442+
443+ if ! (et == NEW_LEADER || et == NO_INITIALIZED || et == NODE_NOT_AVAIL ) {
444+ // not use the current node
445+ n = nil
446+ }
447+ time .Sleep (300 * time .Millisecond )
448+
449+ // for loop would break when switchDataNode run out of retry times
450+ if err := c .switchDataNode (n ); err != nil {
451+ return nil , nil , err
452+ }
453+ rh , df , err = c .runInternal (params )
420454 }
455+ return rh , df , err
421456}
422457
423458func (c * conn ) runInternal (params * requestParams ) (* responseHeader , model.DataForm , error ) {
@@ -465,3 +500,158 @@ func (c *conn) refreshHeaderForResponse(h *responseHeader) {
465500func (c * conn ) enableScram () bool {
466501 return c .behaviorOpt .EnableScram
467502}
503+
504+ func (c * conn ) ConnLogin (userID , password string ) error {
505+ if c .enableScram () {
506+ return c .scramLogin (userID , password )
507+ } else {
508+ err := c .scramLogin (userID , password )
509+ if err == nil {
510+ return nil
511+ }
512+ }
513+ args := make ([]model.DataForm , 2 )
514+ user , err := model .NewDataType (model .DtString , userID )
515+ if err != nil {
516+ return err
517+ }
518+ pwd , err := model .NewDataType (model .DtString , password )
519+ if err != nil {
520+ return err
521+ }
522+
523+ args [0 ] = model .NewScalar (user )
524+ args [1 ] = model .NewScalar (pwd )
525+ _ , err = c .runFuncInternal ("login" , args )
526+ if err != nil {
527+ return err
528+ }
529+ return nil
530+ }
531+
532+ func generateNonce (length int ) (string , error ) {
533+ buffer := make ([]byte , length )
534+ _ , err := rand .Read (buffer )
535+ if err != nil {
536+ return "" , err
537+ }
538+ return base64 .StdEncoding .EncodeToString (buffer ), nil
539+ }
540+
541+ func xorBytes (a , b []byte ) []byte {
542+ result := make ([]byte , len (a ))
543+ for i := range a {
544+ result [i ] = a [i ] ^ b [i ]
545+ }
546+ return result
547+ }
548+
549+ func (c * conn ) scramLogin (userID , password string ) error {
550+ args := make ([]model.DataForm , 2 )
551+ user , err := model .NewDataType (model .DtString , userID )
552+ if err != nil {
553+ return fmt .Errorf ("SCRAM login failed, %w" , err )
554+ }
555+ clientNonce , err := generateNonce (16 )
556+ if err != nil {
557+ return fmt .Errorf ("SCRAM login failed, %w" , err )
558+ }
559+ nonce , err := model .NewDataType (model .DtString , clientNonce )
560+ if err != nil {
561+ return fmt .Errorf ("SCRAM login failed, %w" , err )
562+ }
563+ args [0 ] = model .NewScalar (user )
564+ args [1 ] = model .NewScalar (nonce )
565+
566+ result , err := c .runFuncInternal ("scramClientFirst" , args )
567+ if err != nil {
568+ if strings .Contains (err .Error (), "Can't recognize function name scramClientFirst" ) {
569+ return fmt .Errorf ("SCRAM login is unavailable on current server" )
570+ }
571+ if strings .Contains (err .Error (), "sha256 authMode doesn't support scram authMode" ) {
572+ return fmt .Errorf ("user '%s' doesn't support scram authMode" , userID )
573+ }
574+ return fmt .Errorf ("scramClientFirst failed: %w" , err )
575+ }
576+
577+ retVec := result .(* model.Vector )
578+
579+ if retVec .Rows () != 3 {
580+ return fmt .Errorf ("SCRAM login failed, server error: get server nonce failed" )
581+ }
582+ saltStr := retVec .Get (0 ).Value ().(* model.Scalar ).Value ().(string )
583+ iterCount := int (retVec .Get (1 ).Value ().(* model.Scalar ).Value ().(int32 ))
584+ combinedNonce := retVec .Get (2 ).Value ().(* model.Scalar ).Value ().(string )
585+
586+ salt , err := base64 .StdEncoding .DecodeString (saltStr )
587+ if err != nil {
588+ return fmt .Errorf ("SCRAM login failed, base64 decode failed: %w" , err )
589+ }
590+
591+ saltedPassword := pbkdf2 .Key ([]byte (password ), salt , iterCount , 32 , sha256 .New )
592+
593+ mac := hmac .New (sha256 .New , saltedPassword )
594+ _ , err = mac .Write ([]byte ("Client Key" ))
595+ if err != nil {
596+ return fmt .Errorf ("SCRAM login failed, HMAC calculation failed: %w" , err )
597+ }
598+ clientKey := mac .Sum (nil )
599+
600+ storedKey := sha256 .Sum256 (clientKey )
601+
602+ authMessage := fmt .Sprintf (`n=%s,r=%s,r=%s,s=%s,i=%d,c=biws,r=%s` ,
603+ userID , clientNonce , combinedNonce , saltStr , iterCount , combinedNonce )
604+
605+ mac = hmac .New (sha256 .New , storedKey [:])
606+ _ , err = mac .Write ([]byte (authMessage ))
607+ if err != nil {
608+ return fmt .Errorf ("SCRAM login failed, HMAC calculation failed: %w" , err )
609+ }
610+ clientSig := mac .Sum (nil )
611+
612+ proof := xorBytes (clientKey , clientSig )
613+
614+ finalArgs := make ([]model.DataForm , 3 )
615+ combinedNonceScalar , err := model .NewDataType (model .DtString , combinedNonce )
616+ if err != nil {
617+ return fmt .Errorf ("SCRAM login failed, %w" , err )
618+ }
619+ proofScalar , err := model .NewDataType (model .DtString , base64 .StdEncoding .EncodeToString (proof ))
620+ if err != nil {
621+ return fmt .Errorf ("SCRAM login failed, %w" , err )
622+ }
623+
624+ finalArgs [0 ] = model .NewScalar (user )
625+ finalArgs [1 ] = model .NewScalar (combinedNonceScalar )
626+ finalArgs [2 ] = model .NewScalar (proofScalar )
627+
628+ finalResult , err := c .runFuncInternal ("scramClientFinal" , finalArgs )
629+ if err != nil {
630+ return fmt .Errorf ("scramClientFinal failed: %w" , err )
631+ }
632+ serverSigBase64 := finalResult .(* model.Scalar ).Value ().(string )
633+
634+ mac = hmac .New (sha256 .New , saltedPassword )
635+ _ , err = mac .Write ([]byte ("Server Key" ))
636+ if err != nil {
637+ return fmt .Errorf ("SCRAM login failed, HMAC calculation failed: %w" , err )
638+ }
639+ serverKey := mac .Sum (nil )
640+
641+ mac = hmac .New (sha256 .New , serverKey )
642+ _ , err = mac .Write ([]byte (authMessage ))
643+ if err != nil {
644+ return fmt .Errorf ("SCRAM login failed, HMAC calculation failed: %w" , err )
645+ }
646+ serverSig := mac .Sum (nil )
647+
648+ expectedSig := base64 .StdEncoding .EncodeToString (serverSig )
649+
650+ if serverSigBase64 != "" && expectedSig != serverSigBase64 {
651+ c .Close ()
652+ return fmt .Errorf ("invalid SCRAM server signature" )
653+ }
654+
655+ fmt .Println ("SCRAM login succeeded" )
656+ return nil
657+ }
0 commit comments