@@ -14,7 +14,7 @@ use ferrumc_net_codec::encode::NetEncode;
1414use ferrumc_net_codec:: encode:: NetEncodeOpts ;
1515use ferrumc_state:: ServerState ;
1616use std:: sync:: atomic:: { AtomicBool , Ordering } ;
17- use std:: sync:: Arc ;
17+ use std:: sync:: { Arc , Mutex } ;
1818use std:: time:: Duration ;
1919use tokio:: io:: AsyncWriteExt ;
2020use tokio:: net:: tcp:: OwnedWriteHalf ;
@@ -41,6 +41,8 @@ pub struct StreamWriter {
4141 sender : UnboundedSender < Vec < u8 > > ,
4242 pub running : Arc < AtomicBool > ,
4343 pub compress : Arc < AtomicBool > ,
44+ pub state : Arc < ServerState > ,
45+ pub entity : Arc < Mutex < Option < Entity > > > ,
4446}
4547
4648impl Drop for StreamWriter {
@@ -55,11 +57,18 @@ impl StreamWriter {
5557 ///
5658 /// Spawns a background task that continuously reads from the channel
5759 /// and writes bytes to the network socket.
58- pub async fn new ( mut writer : OwnedWriteHalf , running : Arc < AtomicBool > ) -> Self {
60+ pub async fn new (
61+ mut writer : OwnedWriteHalf ,
62+ running : Arc < AtomicBool > ,
63+ state : Arc < ServerState > ,
64+ entity : Arc < Mutex < Option < Entity > > > ,
65+ ) -> Self {
5966 let compress = Arc :: new ( AtomicBool :: new ( false ) ) ; // Default: no compression
6067 let ( sender, mut receiver) : ( UnboundedSender < Vec < u8 > > , UnboundedReceiver < Vec < u8 > > ) =
6168 tokio:: sync:: mpsc:: unbounded_channel ( ) ;
6269 let running_clone = running. clone ( ) ;
70+ let entity_clone = entity. clone ( ) ;
71+ let state_clone = state. clone ( ) ;
6372
6473 // Task: forward packets from channel to socket
6574 tokio:: spawn ( async move {
@@ -68,21 +77,42 @@ impl StreamWriter {
6877 break ;
6978 } ;
7079
80+ // This handles ONLY if there was a writing error to the client.
7181 if let Err ( e) = writer. write_all ( & bytes) . await {
7282 error ! ( "Failed to write to client: {:?}" , e) ;
7383 running_clone. store ( false , Ordering :: Relaxed ) ;
84+ if let Some ( entity_id) = * entity_clone. lock ( ) . unwrap ( ) {
85+ state_clone. players . disconnect ( entity_id, None ) ;
86+ }
7487 break ;
7588 }
7689 }
90+
91+ // This handles cases where the channel closes without a write error
92+ if let Some ( entity_id) = * entity_clone. lock ( ) . unwrap ( ) {
93+ trace ! (
94+ "Write task ending for entity {:?}, ensuring disconnect" ,
95+ entity_id
96+ ) ;
97+ state_clone. players . disconnect ( entity_id, None ) ;
98+ }
7799 } ) ;
78100
79101 Self {
80102 sender,
81103 running,
82104 compress,
105+ state,
106+ entity,
83107 }
84108 }
85109
110+ /// Sets the entity ID for this stream writer.
111+ /// This should be called after the entity is created in the ECS.
112+ pub fn set_entity ( & self , entity : Entity ) {
113+ * self . entity . lock ( ) . unwrap ( ) = Some ( entity) ;
114+ }
115+
86116 /// Sends a packet to the client using the default `WithLength` encoding.
87117 pub fn send_packet ( & self , packet : impl NetEncode + Send ) -> Result < ( ) , NetError > {
88118 self . send_packet_with_opts ( & packet, & NetEncodeOpts :: WithLength )
@@ -177,9 +207,16 @@ pub async fn handle_connection(
177207
178208 let running = Arc :: new ( AtomicBool :: new ( true ) ) ;
179209
180- let stream = StreamWriter :: new ( tcp_writer, running. clone ( ) ) . await ;
210+ let entity_holder: Arc < Mutex < Option < Entity > > > = Arc :: new ( Mutex :: new ( None ) ) ;
211+
212+ let stream = StreamWriter :: new (
213+ tcp_writer,
214+ running. clone ( ) ,
215+ state. clone ( ) ,
216+ entity_holder. clone ( ) ,
217+ )
218+ . await ;
181219
182- // Perform handshake with timeout guard
183220 let handshake_result = timeout (
184221 MAX_HANDSHAKE_TIMEOUT ,
185222 handle_handshake ( & mut tcp_reader, & stream, state. clone ( ) ) ,
@@ -256,6 +293,11 @@ pub async fn handle_connection(
256293 }
257294 } ;
258295
296+ // Sets the entity for the stream writer.
297+ * entity_holder. lock ( ) . unwrap ( ) = Some ( entity) ;
298+
299+ trace ! ( "Entity {:?} assigned to connection" , entity) ;
300+
259301 // ---- Packet receive loop ----
260302 ' recv: loop {
261303 if !running. load ( Ordering :: Relaxed ) {
@@ -279,17 +321,20 @@ pub async fn handle_connection(
279321 if let NetError :: ConnectionDropped = err {
280322 trace!( "Connection dropped for entity {:?}" , entity) ;
281323 running. store( false , Ordering :: Relaxed ) ;
282- break ' recv;
324+ state. players. disconnect( entity, None ) ;
325+ } else {
326+ error!( "Failed to read packet skeleton: {:?} for {:?}" , err, entity) ;
327+ running. store( false , Ordering :: Relaxed ) ;
328+ state. players. disconnect( entity, None ) ;
283329 }
284- error!( "Failed to read packet skeleton: {:?} for {:?}" , err, entity) ;
285- running. store( false , Ordering :: Relaxed ) ;
286330 break ' recv;
287331 }
288332 }
289333 }
290334
291335 _ = & mut disconnect_receiver => {
292- debug!( "Received disconnect signal" ) ;
336+ debug!( "Received disconnect signal for entity {:?}" , entity) ;
337+ running. store( false , Ordering :: Relaxed ) ;
293338 break ' recv;
294339 }
295340 }
@@ -322,6 +367,7 @@ pub async fn handle_connection(
322367 _ => {
323368 warn ! ( "Error handling packet for {:?}: {:?}" , entity, err) ;
324369 running. store ( false , Ordering :: Relaxed ) ;
370+ state. players . disconnect ( entity, None ) ;
325371 break ' recv;
326372 }
327373 } ,
0 commit comments