Skip to content

Commit f3d9ec6

Browse files
committed
fix: peer handling of messages at the counterparty (receiver) side
1 parent 17d5fcc commit f3d9ec6

File tree

5 files changed

+250
-168
lines changed

5 files changed

+250
-168
lines changed

auth/peer.go

Lines changed: 193 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ package auth
77

88
import (
99
"context"
10-
"crypto/rand"
1110
"encoding/base64"
1211
"encoding/json"
1312
"fmt"
@@ -21,7 +20,7 @@ import (
2120
)
2221

2322
// AUTH_PROTOCOL_ID is the protocol ID for authentication messages as specified in BRC-31 (Authrite)
24-
const AUTH_PROTOCOL_ID = "authrite message signature"
23+
const AUTH_PROTOCOL_ID = "auth message signature"
2524

2625
// AUTH_VERSION is the version of the auth protocol
2726
const AUTH_VERSION = "0.1"
@@ -284,15 +283,11 @@ func (p *Peer) GetAuthenticatedSession(ctx context.Context, identityKey *ec.Publ
284283

285284
// initiateHandshake starts the mutual authentication handshake with a peer
286285
func (p *Peer) initiateHandshake(ctx context.Context, peerIdentityKey *ec.PublicKey, maxWaitTimeMs int) (*PeerSession, error) {
287-
// Create a session nonce
288-
nonceBytes := make([]byte, 32)
289-
_, err := rand.Read(nonceBytes)
286+
sessionNonce, err := utils.CreateNonce(ctx, p.wallet, wallet.Counterparty{Type: wallet.CounterpartyTypeSelf})
290287
if err != nil {
291-
return nil, NewAuthError("failed to generate nonce", err)
288+
return nil, NewAuthError("failed to create session nonce", err)
292289
}
293290

294-
sessionNonce := base64.StdEncoding.EncodeToString(nonceBytes)
295-
296291
// Add a preliminary session entry (not yet authenticated)
297292
session := &PeerSession{
298293
IsAuthenticated: false,
@@ -495,75 +490,191 @@ func (p *Peer) handleInitialRequest(ctx context.Context, message *AuthMessage, s
495490
IdentityKey: identityKeyResult.PublicKey,
496491
Nonce: ourNonce,
497492
YourNonce: message.InitialNonce,
498-
InitialNonce: message.InitialNonce,
493+
InitialNonce: session.SessionNonce,
499494
Certificates: certs,
500495
}
501496

497+
data := message.InitialNonce + session.SessionNonce
498+
sigData, err := base64.StdEncoding.DecodeString(data)
499+
if err != nil {
500+
return NewAuthError("failed to prepare data to sign", err)
501+
}
502+
503+
keyID := fmt.Sprintf("%s %s", message.InitialNonce, session.SessionNonce)
504+
505+
arg := wallet.CreateSignatureArgs{
506+
EncryptionArgs: wallet.EncryptionArgs{
507+
ProtocolID: wallet.Protocol{
508+
// SecurityLevel set to 2 (SecurityLevelEveryAppAndCounterparty) as specified in BRC-31 (Authrite)
509+
SecurityLevel: wallet.SecurityLevelEveryAppAndCounterparty,
510+
Protocol: AUTH_PROTOCOL_ID,
511+
},
512+
KeyID: keyID,
513+
Counterparty: wallet.Counterparty{
514+
Type: wallet.CounterpartyTypeOther,
515+
Counterparty: message.IdentityKey,
516+
},
517+
},
518+
// Sign the certificate request data, as in TypeScript
519+
Data: sigData,
520+
}
521+
522+
sigResult, err := p.wallet.CreateSignature(ctx, arg, "")
523+
if err != nil {
524+
return fmt.Errorf("failed to sign initial response: %w", err)
525+
}
526+
527+
response.Signature = sigResult.Signature.Serialize()
528+
502529
// Send the response
503530
return p.transport.Send(ctx, response)
504531
}
505532

506533
// handleInitialResponse processes the response to our initial authentication request
507534
func (p *Peer) handleInitialResponse(ctx context.Context, message *AuthMessage, senderPublicKey *ec.PublicKey) error {
508-
// Validate the response has required nonces
509-
if message.YourNonce == "" || message.InitialNonce == "" {
535+
valid, err := utils.VerifyNonce(ctx, message.YourNonce, p.wallet, wallet.Counterparty{Type: wallet.CounterpartyTypeSelf})
536+
if err != nil {
537+
return fmt.Errorf("failed to validate nonce: %w", err)
538+
}
539+
if !valid {
510540
return ErrInvalidNonce
511541
}
512542

513-
// Find corresponding initial request callback by the initial nonce
514-
for id, callback := range p.onInitialResponseReceivedCallbacks {
515-
if callback.SessionNonce == message.InitialNonce {
516-
// Process certificates if included
517-
if len(message.Certificates) > 0 {
518-
// Create utils.AuthMessage from our message
519-
utilsMessage := &AuthMessage{
520-
IdentityKey: message.IdentityKey,
521-
Certificates: message.Certificates,
522-
}
523-
524-
// Convert our RequestedCertificateSet to utils.RequestedCertificateSet
525-
utilsRequestedCerts := &utils.RequestedCertificateSet{
526-
Certifiers: p.CertificatesToRequest.Certifiers,
527-
}
528-
529-
// Convert map type
530-
certTypes := make(utils.RequestedCertificateTypeIDAndFieldList)
531-
for k, v := range p.CertificatesToRequest.CertificateTypes {
532-
certTypes[k] = v
533-
}
534-
utilsRequestedCerts.CertificateTypes = certTypes
535-
536-
// Call ValidateCertificates with proper types
537-
err := ValidateCertificates(
538-
ctx,
539-
p.wallet,
540-
utilsMessage,
541-
utilsRequestedCerts,
542-
)
543-
if err != nil {
544-
// Log the error but continue - certificate error shouldn't stop auth
545-
p.logger.Printf("Warning: Certificate validation failed: %v", err)
546-
}
547-
548-
// Notify certificate listeners
549-
for _, callback := range p.onCertificateReceivedCallbacks {
550-
err := callback(senderPublicKey, message.Certificates)
551-
if err != nil {
552-
// Log callback error but continue
553-
p.logger.Printf("Warning: Certificate callback error: %v", err)
554-
}
555-
}
543+
session, err := p.sessionManager.GetSession(senderPublicKey.ToDERHex())
544+
if err != nil || session == nil {
545+
return ErrSessionNotFound
546+
}
547+
548+
data := message.InitialNonce + session.SessionNonce
549+
550+
sigData, err := base64.StdEncoding.DecodeString(data)
551+
if err != nil {
552+
return NewAuthError("failed to prepare data to sign", err)
553+
}
554+
555+
signature, err := ec.ParseSignature(message.Signature)
556+
if err != nil {
557+
return fmt.Errorf("failed to parse signature: %w", err)
558+
}
559+
560+
verifyResult, err := p.wallet.VerifySignature(ctx, wallet.VerifySignatureArgs{
561+
Data: sigData,
562+
Signature: signature,
563+
EncryptionArgs: wallet.EncryptionArgs{
564+
ProtocolID: wallet.Protocol{
565+
// SecurityLevel set to 2 (SecurityLevelEveryAppAndCounterparty) as specified in BRC-31 (Authrite)
566+
SecurityLevel: wallet.SecurityLevelEveryAppAndCounterparty,
567+
Protocol: AUTH_PROTOCOL_ID,
568+
},
569+
KeyID: fmt.Sprintf("%s %s", message.InitialNonce, session.SessionNonce),
570+
Counterparty: wallet.Counterparty{
571+
Type: wallet.CounterpartyTypeOther,
572+
Counterparty: message.IdentityKey,
573+
},
574+
},
575+
}, "")
576+
if err != nil {
577+
return fmt.Errorf("unable to verify signature in initial response: %w", err)
578+
} else if !verifyResult.Valid {
579+
return ErrInvalidSignature
580+
}
581+
582+
session.PeerNonce = message.InitialNonce
583+
session.PeerIdentityKey = message.IdentityKey
584+
session.IsAuthenticated = true
585+
session.LastUpdate = time.Now().UnixMilli()
586+
p.sessionManager.UpdateSession(session)
587+
588+
if p.CertificatesToRequest != nil && len(p.CertificatesToRequest.Certifiers) > 0 && len(message.Certificates) > 0 {
589+
// Create utils.AuthMessage from our message
590+
utilsMessage := &AuthMessage{
591+
IdentityKey: message.IdentityKey,
592+
Certificates: message.Certificates,
593+
}
594+
595+
// Convert our RequestedCertificateSet to utils.RequestedCertificateSet
596+
utilsRequestedCerts := &utils.RequestedCertificateSet{
597+
Certifiers: p.CertificatesToRequest.Certifiers,
598+
}
599+
600+
// Convert map type
601+
certTypes := make(utils.RequestedCertificateTypeIDAndFieldList)
602+
for k, v := range p.CertificatesToRequest.CertificateTypes {
603+
certTypes[k] = v
604+
}
605+
utilsRequestedCerts.CertificateTypes = certTypes
606+
607+
// Call ValidateCertificates with proper types
608+
err := ValidateCertificates(
609+
ctx,
610+
p.wallet,
611+
utilsMessage,
612+
utilsRequestedCerts,
613+
)
614+
if err != nil {
615+
return fmt.Errorf("invalid certificates: %w", err)
616+
}
617+
618+
for _, callback := range p.onCertificateReceivedCallbacks {
619+
err := callback(senderPublicKey, message.Certificates)
620+
if err != nil {
621+
return fmt.Errorf("certificate received callback error: %w", err)
556622
}
623+
}
624+
}
557625

626+
p.lastInteractedWithPeer = message.IdentityKey
627+
628+
for id, callback := range p.onInitialResponseReceivedCallbacks {
629+
if callback.SessionNonce == session.SessionNonce {
558630
// Call the initial response callback with the peer's nonce
559-
err := callback.Callback(message.Nonce)
631+
err := callback.Callback(session.SessionNonce)
560632
delete(p.onInitialResponseReceivedCallbacks, id)
561633
return err
562634
}
563635
}
564636

565-
// No matching callback found
566-
return fmt.Errorf("no matching initial request found for response with nonce %s", message.InitialNonce)
637+
// The peer might also request certificates from us
638+
if len(message.RequestedCertificates.Certifiers) > 0 || len(message.RequestedCertificates.CertificateTypes) > 0 {
639+
err = p.sendCertificates(ctx, message)
640+
if err != nil {
641+
return fmt.Errorf("failed to send requested certificates: %w", err)
642+
}
643+
}
644+
645+
return nil
646+
}
647+
648+
func (p *Peer) sendCertificates(ctx context.Context, message *AuthMessage) error {
649+
if len(p.onCertificateRequestReceivedCallbacks) > 0 {
650+
for _, callback := range p.onCertificateRequestReceivedCallbacks {
651+
err := callback(message.IdentityKey, message.RequestedCertificates)
652+
if err != nil {
653+
// Log callback error but continue
654+
return fmt.Errorf("on certificate request callback failed: %w", err)
655+
}
656+
}
657+
return nil
658+
}
659+
660+
certs, err := utils.GetVerifiableCertificates(
661+
ctx,
662+
&utils.GetVerifiableCertificatesOptions{
663+
Wallet: p.wallet,
664+
RequestedCertificates: &message.RequestedCertificates,
665+
VerifierIdentityKey: message.IdentityKey,
666+
},
667+
)
668+
if err != nil {
669+
return fmt.Errorf("failed to get verifiable certificates: %w", err)
670+
}
671+
672+
err = p.SendCertificateResponse(ctx, message.IdentityKey, certs)
673+
if err != nil {
674+
return fmt.Errorf("failed to send certificate response: %w", err)
675+
}
676+
677+
return nil
567678
}
568679

569680
// handleCertificateRequest processes a certificate request message
@@ -574,8 +685,11 @@ func (p *Peer) handleCertificateRequest(ctx context.Context, message *AuthMessag
574685
return ErrSessionNotFound
575686
}
576687

577-
// Verify nonces match
578-
if message.YourNonce != session.SessionNonce {
688+
valid, err := utils.VerifyNonce(ctx, message.YourNonce, p.wallet, wallet.Counterparty{Type: wallet.CounterpartyTypeSelf})
689+
if err != nil {
690+
return fmt.Errorf("failed to validate nonce: %w", err)
691+
}
692+
if !valid {
579693
return ErrInvalidNonce
580694
}
581695

@@ -623,31 +737,10 @@ func (p *Peer) handleCertificateRequest(ctx context.Context, message *AuthMessag
623737
return fmt.Errorf("invalid signature in certificate request: %w", err)
624738
}
625739

626-
// Notify certificate request listeners
627-
for _, callback := range p.onCertificateRequestReceivedCallbacks {
628-
err := callback(senderPublicKey, message.RequestedCertificates)
629-
if err != nil {
630-
// Log callback error but continue
631-
p.logger.Printf("Warning: Certificate request callback error: %v", err)
632-
}
633-
}
634-
635-
// If we have auto-response enabled, automatically send certificates
636740
if len(message.RequestedCertificates.Certifiers) > 0 || len(message.RequestedCertificates.CertificateTypes) > 0 {
637-
certs, err := utils.GetVerifiableCertificates(
638-
ctx,
639-
&utils.GetVerifiableCertificatesOptions{
640-
Wallet: p.wallet,
641-
RequestedCertificates: &message.RequestedCertificates,
642-
VerifierIdentityKey: senderPublicKey,
643-
},
644-
)
645-
if err == nil && len(certs) > 0 {
646-
// Auto-respond with available certificates
647-
err = p.SendCertificateResponse(ctx, senderPublicKey, certs)
648-
if err != nil {
649-
p.logger.Printf("Warning: Failed to auto-respond with certificates: %v", err)
650-
}
741+
err = p.sendCertificates(ctx, message)
742+
if err != nil {
743+
return fmt.Errorf("failed to send requested certificates: %w", err)
651744
}
652745
}
653746

@@ -662,8 +755,11 @@ func (p *Peer) handleCertificateResponse(ctx context.Context, message *AuthMessa
662755
return ErrSessionNotFound
663756
}
664757

665-
// Verify nonces match
666-
if message.YourNonce != session.SessionNonce {
758+
valid, err := utils.VerifyNonce(ctx, message.YourNonce, p.wallet, wallet.Counterparty{Type: wallet.CounterpartyTypeSelf})
759+
if err != nil {
760+
return fmt.Errorf("failed to validate nonce: %w", err)
761+
}
762+
if !valid {
667763
return ErrInvalidNonce
668764
}
669765

@@ -752,27 +848,19 @@ func (p *Peer) handleCertificateResponse(ctx context.Context, message *AuthMessa
752848

753849
// handleGeneralMessage processes a general message
754850
func (p *Peer) handleGeneralMessage(ctx context.Context, message *AuthMessage, senderPublicKey *ec.PublicKey) error {
851+
valid, err := utils.VerifyNonce(ctx, message.YourNonce, p.wallet, wallet.Counterparty{Type: wallet.CounterpartyTypeSelf})
852+
if err != nil {
853+
return fmt.Errorf("failed to validate nonce: %w", err)
854+
}
855+
if !valid {
856+
return ErrInvalidNonce
857+
}
858+
755859
// Validate the session exists and is authenticated
756860
session, err := p.sessionManager.GetSession(senderPublicKey.ToDERHex())
757861
if err != nil || session == nil {
758862
return ErrSessionNotFound
759863
}
760-
if !session.IsAuthenticated {
761-
if p.CertificatesToRequest != nil && len(p.CertificatesToRequest.Certifiers) > 0 {
762-
return ErrMissingCertificate
763-
}
764-
765-
return ErrNotAuthenticated
766-
}
767-
768-
// Verify nonces match
769-
if message.YourNonce != session.SessionNonce {
770-
return ErrInvalidNonce
771-
}
772-
773-
// Update session timestamp
774-
session.LastUpdate = time.Now().UnixMilli()
775-
p.sessionManager.UpdateSession(session)
776864

777865
// Try to parse the signature
778866
signature, err := ec.ParseSignature(message.Signature)
@@ -802,6 +890,10 @@ func (p *Peer) handleGeneralMessage(ctx context.Context, message *AuthMessage, s
802890
return fmt.Errorf("invalid signature in general message: %w", err)
803891
}
804892

893+
// Update session timestamp
894+
session.LastUpdate = time.Now().UnixMilli()
895+
p.sessionManager.UpdateSession(session)
896+
805897
// Update last interacted peer
806898
if p.autoPersistLastSession {
807899
p.lastInteractedWithPeer = senderPublicKey

0 commit comments

Comments
 (0)