@@ -19,6 +19,7 @@ static const int MAX_EVENTS = 128;
1919typedef struct conn_data_s {
2020 int fd ;
2121 int epoll_fd ;
22+ uint32_t event_mask ;
2223 hyper_waker * read_waker ;
2324 hyper_waker * write_waker ;
2425} conn_data ;
@@ -62,6 +63,7 @@ static int listen_on(const char *host, const char *port) {
6263 if (bind (sock , resp -> ai_addr , resp -> ai_addrlen ) == 0 ) {
6364 break ;
6465 }
66+ perror ("bind" );
6567
6668 // Failed, tidy up
6769 close (sock );
@@ -108,6 +110,7 @@ static int register_signal_handler() {
108110 perror ("signalfd" );
109111 return 1 ;
110112 }
113+ sigaddset (& mask , SIGPIPE );
111114 if (sigprocmask (SIG_BLOCK , & mask , NULL ) < 0 ) {
112115 perror ("sigprocmask" );
113116 return 1 ;
@@ -116,6 +119,19 @@ static int register_signal_handler() {
116119 return signal_fd ;
117120}
118121
122+ // Register connection FD with epoll, associated with this `conn`
123+ static bool update_conn_data_registrations (conn_data * conn , bool create ) {
124+ struct epoll_event transport_event ;
125+ transport_event .events = conn -> event_mask ;
126+ transport_event .data .ptr = conn ;
127+ if (epoll_ctl (conn -> epoll_fd , create ? EPOLL_CTL_ADD : EPOLL_CTL_MOD , conn -> fd , & transport_event ) < 0 ) {
128+ perror ("epoll_ctl (transport)" );
129+ return false;
130+ } else {
131+ return true;
132+ }
133+ }
134+
119135static size_t read_cb (void * userdata , hyper_context * ctx , uint8_t * buf , size_t buf_len ) {
120136 conn_data * conn = (conn_data * )userdata ;
121137 ssize_t ret = read (conn -> fd , buf , buf_len );
@@ -134,6 +150,14 @@ static size_t read_cb(void *userdata, hyper_context *ctx, uint8_t *buf, size_t b
134150 if (conn -> read_waker != NULL ) {
135151 hyper_waker_free (conn -> read_waker );
136152 }
153+
154+ if (!(conn -> event_mask & EPOLLIN )) {
155+ conn -> event_mask |= EPOLLIN ;
156+ if (!update_conn_data_registrations (conn , false)) {
157+ return HYPER_IO_ERROR ;
158+ }
159+ }
160+
137161 conn -> read_waker = hyper_context_waker (ctx );
138162 return HYPER_IO_PENDING ;
139163}
@@ -156,28 +180,31 @@ static size_t write_cb(void *userdata, hyper_context *ctx, const uint8_t *buf, s
156180 if (conn -> write_waker != NULL ) {
157181 hyper_waker_free (conn -> write_waker );
158182 }
183+
184+ if (!(conn -> event_mask & EPOLLOUT )) {
185+ conn -> event_mask |= EPOLLOUT ;
186+ if (!update_conn_data_registrations (conn , false)) {
187+ return HYPER_IO_ERROR ;
188+ }
189+ }
190+
159191 conn -> write_waker = hyper_context_waker (ctx );
160192 return HYPER_IO_PENDING ;
161193}
162194
163195static conn_data * create_conn_data (int epoll , int fd ) {
164196 conn_data * conn = malloc (sizeof (conn_data ));
165-
166- // Add fd to epoll set, associated with this `conn`
167- struct epoll_event transport_event ;
168- transport_event .events = EPOLLIN ;
169- transport_event .data .ptr = conn ;
170- if (epoll_ctl (epoll , EPOLL_CTL_ADD , fd , & transport_event ) < 0 ) {
171- perror ("epoll_ctl (transport, add)" );
172- free (conn );
173- return NULL ;
174- }
175-
176197 conn -> fd = fd ;
177198 conn -> epoll_fd = epoll ;
199+ conn -> event_mask = 0 ;
178200 conn -> read_waker = NULL ;
179201 conn -> write_waker = NULL ;
180202
203+ if (!update_conn_data_registrations (conn , true)) {
204+ free (conn );
205+ return NULL ;
206+ }
207+
181208 return conn ;
182209}
183210
@@ -477,13 +504,27 @@ int main(int argc, char *argv[]) {
477504 } else {
478505 // Existing transport socket, poke the wakers or close the socket
479506 conn_data * conn = events [n ].data .ptr ;
480- if ((events [n ].events & EPOLLIN ) && conn -> read_waker ) {
507+ if (events [n ].events & EPOLLIN ) {
508+ if (conn -> read_waker ) {
481509 hyper_waker_wake (conn -> read_waker );
482510 conn -> read_waker = NULL ;
511+ } else {
512+ conn -> event_mask &= ~EPOLLIN ;
513+ if (!update_conn_data_registrations (conn , false)) {
514+ epoll_ctl (conn -> epoll_fd , EPOLL_CTL_DEL , conn -> fd , NULL );
515+ }
516+ }
483517 }
484- if ((events [n ].events & EPOLLOUT ) && conn -> write_waker ) {
485- hyper_waker_wake (conn -> write_waker );
486- conn -> write_waker = NULL ;
518+ if (events [n ].events & EPOLLOUT ) {
519+ if (conn -> read_waker ) {
520+ hyper_waker_wake (conn -> read_waker );
521+ conn -> read_waker = NULL ;
522+ } else {
523+ conn -> event_mask &= ~EPOLLOUT ;
524+ if (!update_conn_data_registrations (conn , false)) {
525+ epoll_ctl (conn -> epoll_fd , EPOLL_CTL_DEL , conn -> fd , NULL );
526+ }
527+ }
487528 }
488529 }
489530 }
0 commit comments