@@ -1074,27 +1074,14 @@ with args@Arguments {
10741074 -> MutableConnState peerAddr handle handleError version m
10751075 -> DisconnectionException handleError
10761076 -> m (Connected peerAddr handle handleError )
1077- terminateInboundWithErrorOrQuery connId connStateId connVar connThread stateVar mutableConnState connectionError = do
1077+ terminateInboundWithErrorOrQuery connId connStateId connVar connThread stateVar mutableConnState connError = do
10781078 transitions <- atomically $ do
10791079 connState <- readTVar connVar
10801080
1081- let connState' =
1082- case classifyHandlerError <$> connectionError of
1083- ConnectionHandlerError HandshakeFailure ->
1084- TerminatingState connId connThread
1085- (case connectionError of
1086- ConnectionHandlerError err -> Just err
1087- _ -> Nothing )
1088- ConnectionHandlerError HandshakeProtocolViolation ->
1089- TerminatedState (case connectionError of
1090- ConnectionHandlerError err -> Just err
1091- _ -> Nothing )
1092- -- On inbound query, connection is terminating.
1093- _ ->
1094- TerminatingState connId connThread Nothing
1095- transition = mkTransition connState connState'
1081+ let connState' = computeConnStateOnError connId connThread connError
1082+ transition = mkTransition connState connState'
10961083 absConnState = State. abstractState (Known connState)
1097- shouldTrace = absConnState /= TerminatedSt
1084+ shouldTrace = absConnState /= TerminatedSt
10981085
10991086 updated <-
11001087 modifyTMVarSTM
@@ -1155,7 +1142,34 @@ with args@Arguments {
11551142 traverse_ (traceWith trTracer . TransitionTrace connStateId) transitions
11561143 traceCounters stateVar
11571144
1158- return (Disconnected connId connectionError)
1145+ return (Disconnected connId connError)
1146+
1147+ -- Compute connection state based on `DisconnectionException`. Shared
1148+ -- between:
1149+ --
1150+ -- * `terminateInboundWithErrorOrQuery` and
1151+ -- * `terminateOutboundWithErrorOrQuery`.
1152+ --
1153+ computeConnStateOnError
1154+ :: ConnectionId peerAddr
1155+ -> Async m ()
1156+ -> DisconnectionException handleError
1157+ -> ConnectionState peerAddr handle handleError version m
1158+ computeConnStateOnError connId connThread connError =
1159+ case classifyHandlerError <$> connError of
1160+ ConnectionHandlerError HandshakeFailure ->
1161+ TerminatingState connId connThread
1162+ (case connError of
1163+ ConnectionHandlerError err -> Just err
1164+ _ -> Nothing )
1165+ ConnectionHandlerError HandshakeProtocolViolation ->
1166+ TerminatedState (case connError of
1167+ ConnectionHandlerError err -> Just err
1168+ _ -> Nothing )
1169+ -- On outbound query, connection is terminated.
1170+ _ ->
1171+ TerminatedState Nothing
1172+
11591173
11601174 -- We need 'mask' in order to guarantee that the traces are logged if an
11611175 -- async exception lands between the successful STM action and the logging
@@ -1810,29 +1824,16 @@ with args@Arguments {
18101824 -> MutableConnState peerAddr handle handleError version m
18111825 -> DisconnectionException handleError
18121826 -> m (Connected peerAddr handle handleError )
1813- terminateOutboundWithErrorOrQuery connId connStateId connVar connThread stateVar mutableConnState connectionError = do
1827+ terminateOutboundWithErrorOrQuery connId connStateId connVar connThread stateVar mutableConnState connError = do
18141828 transitions <- atomically $ do
18151829 connState <- readTVar connVar
18161830
1817- let connState' =
1818- case classifyHandlerError <$> connectionError of
1819- ConnectionHandlerError HandshakeFailure ->
1820- TerminatingState connId connThread
1821- (case connectionError of
1822- ConnectionHandlerError err -> Just err
1823- _ -> Nothing )
1824- ConnectionHandlerError HandshakeProtocolViolation ->
1825- TerminatedState (case connectionError of
1826- ConnectionHandlerError err -> Just err
1827- _ -> Nothing )
1828- -- On outbound query, connection is terminated.
1829- _ ->
1830- TerminatedState Nothing
1831- transition = mkTransition connState connState'
1832- absConnState = State. abstractState (Known connState)
1831+ let connState' = computeConnStateOnError connId connThread connError
1832+ transition = mkTransition connState connState'
1833+ absConnState = State. abstractState (Known connState)
18331834 shouldTransition = absConnState /= TerminatedSt
18341835
1835- -- 'connectionError ' might be either a handshake negotiation
1836+ -- 'connError ' might be either a handshake negotiation
18361837 -- a protocol failure (an IO exception, a timeout or
18371838 -- codec failure). In the first case we should not reset
18381839 -- the connection as this is not a protocol error.
@@ -1883,7 +1884,7 @@ with args@Arguments {
18831884 traverse_ (traceWith trTracer . TransitionTrace connStateId) transitions
18841885 traceCounters stateVar
18851886
1886- return (Disconnected connId connectionError )
1887+ return (Disconnected connId connError )
18871888
18881889
18891890 releaseOutboundConnectionImpl
0 commit comments