@@ -49,11 +49,14 @@ type (
4949 RetryOptions exported.RetryOptions
5050
5151 clientMu sync.RWMutex
52- client * amqp. Client
52+ client amqpwrap. AMQPClient
5353 negotiateClaimMu sync.Mutex
5454 // indicates that the client was closed permanently, and not just
5555 // for recovery.
5656 closedPermanently bool
57+
58+ // newClientFn exists so we can stub out newClient for unit tests.
59+ newClientFn func (ctx context.Context ) (amqpwrap.AMQPClient , error )
5760 }
5861
5962 // NamespaceOption provides structure for configuring a new Service Bus namespace
@@ -143,6 +146,8 @@ func NamespaceWithRetryOptions(retryOptions exported.RetryOptions) NamespaceOpti
143146func NewNamespace (opts ... NamespaceOption ) (* Namespace , error ) {
144147 ns := & Namespace {}
145148
149+ ns .newClientFn = ns .newClientImpl
150+
146151 for _ , opt := range opts {
147152 err := opt (ns )
148153 if err != nil {
@@ -153,7 +158,7 @@ func NewNamespace(opts ...NamespaceOption) (*Namespace, error) {
153158 return ns , nil
154159}
155160
156- func (ns * Namespace ) newClient (ctx context.Context ) (* amqp. Client , error ) {
161+ func (ns * Namespace ) newClientImpl (ctx context.Context ) (amqpwrap. AMQPClient , error ) {
157162 defaultConnOptions := []amqp.ConnOption {
158163 amqp .ConnSASLAnonymous (),
159164 amqp .ConnMaxSessions (65535 ),
@@ -181,10 +186,12 @@ func (ns *Namespace) newClient(ctx context.Context) (*amqp.Client, error) {
181186 return nil , err
182187 }
183188
184- return amqp .New (nConn , append (defaultConnOptions , amqp .ConnServerHostname (ns .FQDN ))... )
189+ client , err := amqp .New (nConn , append (defaultConnOptions , amqp .ConnServerHostname (ns .FQDN ))... )
190+ return & amqpwrap.AMQPClientWrapper {Inner : client }, err
185191 }
186192
187- return amqp .Dial (ns .getAMQPHostURI (), defaultConnOptions ... )
193+ client , err := amqp .Dial (ns .getAMQPHostURI (), defaultConnOptions ... )
194+ return & amqpwrap.AMQPClientWrapper {Inner : client }, err
188195}
189196
190197// NewAMQPSession creates a new AMQP session with the internally cached *amqp.Client.
@@ -202,7 +209,7 @@ func (ns *Namespace) NewAMQPSession(ctx context.Context) (amqpwrap.AMQPSession,
202209 return nil , 0 , err
203210 }
204211
205- return & amqpwrap. AMQPSessionWrapper { Inner : session } , clientRevision , err
212+ return session , clientRevision , err
206213}
207214
208215// NewRPCLink creates a new amqp-common *rpc.Link with the internally cached *amqp.Client.
@@ -214,7 +221,7 @@ func (ns *Namespace) NewRPCLink(ctx context.Context, managementPath string) (RPC
214221 }
215222
216223 return NewRPCLink (RPCLinkArgs {
217- Client : & amqpwrap. AMQPClientWrapper { Inner : client } ,
224+ Client : client ,
218225 Address : managementPath ,
219226 LogEvent : exported .EventReceiver ,
220227 })
@@ -241,7 +248,9 @@ func (ns *Namespace) Close(ctx context.Context, permanently bool) error {
241248 }
242249
243250 if ns .client != nil {
244- return ns .client .Close ()
251+ err := ns .client .Close ()
252+ ns .client = nil
253+ return err
245254 }
246255
247256 return nil
@@ -291,16 +300,12 @@ func (ns *Namespace) Recover(ctx context.Context, theirConnID uint64) (bool, err
291300 _ = oldClient .Close ()
292301 }
293302
294- var err error
295303 log .Writef (exported .EventConn , "Creating a new client (rev:%d)" , ns .connID )
296- ns .client , err = ns .newClient (ctx )
297304
298- if err != nil {
305+ if _ , _ , err := ns . updateClientWithoutLock ( ctx ); err != nil {
299306 return false , err
300307 }
301308
302- ns .connID ++
303- log .Writef (exported .EventConn , "New client created, (rev: %d)" , ns .connID )
304309 return true , nil
305310}
306311
@@ -310,7 +315,6 @@ func (ns *Namespace) NegotiateClaim(ctx context.Context, entityPath string) (con
310315 return ns .startNegotiateClaimRenewer (ctx ,
311316 entityPath ,
312317 NegotiateClaim ,
313- ns .GetAMQPClientImpl ,
314318 nextClaimRefreshDuration )
315319}
316320
@@ -320,15 +324,14 @@ func (ns *Namespace) NegotiateClaim(ctx context.Context, entityPath string) (con
320324// when the background renewal stops or an error.
321325func (ns * Namespace ) startNegotiateClaimRenewer (ctx context.Context ,
322326 entityPath string ,
323- cbsNegotiateClaim func (ctx context.Context , audience string , conn * amqp.Client , provider auth.TokenProvider ) error ,
324- nsGetAMQPClientImpl func (ctx context.Context ) (* amqp.Client , uint64 , error ),
327+ cbsNegotiateClaim func (ctx context.Context , audience string , conn amqpwrap.AMQPClient , provider auth.TokenProvider ) error ,
325328 nextClaimRefreshDurationFn func (expirationTime time.Time , currentTime time.Time ) time.Duration ) (func (), <- chan struct {}, error ) {
326329 audience := ns .GetEntityAudience (entityPath )
327330
328331 refreshClaim := func (ctx context.Context ) (time.Time , error ) {
329332 log .Writef (exported .EventAuth , "(%s) refreshing claim" , entityPath )
330333
331- amqpClient , clientRevision , err := nsGetAMQPClientImpl (ctx )
334+ amqpClient , clientRevision , err := ns . GetAMQPClientImpl (ctx )
332335
333336 if err != nil {
334337 return time.Time {}, err
@@ -430,7 +433,7 @@ func (ns *Namespace) startNegotiateClaimRenewer(ctx context.Context,
430433 }, refreshStoppedCh , nil
431434}
432435
433- func (ns * Namespace ) GetAMQPClientImpl (ctx context.Context ) (* amqp. Client , uint64 , error ) {
436+ func (ns * Namespace ) GetAMQPClientImpl (ctx context.Context ) (amqpwrap. AMQPClient , uint64 , error ) {
434437 if err := ns .Check (); err != nil {
435438 return nil , 0 , err
436439 }
@@ -442,17 +445,27 @@ func (ns *Namespace) GetAMQPClientImpl(ctx context.Context) (*amqp.Client, uint6
442445 return nil , 0 , ErrClientClosed
443446 }
444447
448+ return ns .updateClientWithoutLock (ctx )
449+ }
450+
451+ // updateClientWithoutLock takes care of initializing a client (if needed)
452+ // and returns the initialized client and it's connection ID, or an error.
453+ func (ns * Namespace ) updateClientWithoutLock (ctx context.Context ) (amqpwrap.AMQPClient , uint64 , error ) {
445454 if ns .client != nil {
446455 return ns .client , ns .connID , nil
447456 }
448457
449- var err error
450- ns . client , err = ns .newClient (ctx )
458+ log . Writef ( exported . EventConn , "Creating new client, current rev: %d" , ns . connID )
459+ tempClient , err : = ns .newClientFn (ctx )
451460
452- if err = = nil {
453- ns . connID ++
461+ if err ! = nil {
462+ return nil , 0 , err
454463 }
455464
465+ ns .connID ++
466+ ns .client = tempClient
467+ log .Writef (exported .EventConn , "Client created, new rev: %d" , ns .connID )
468+
456469 return ns .client , ns .connID , err
457470}
458471
0 commit comments