@@ -42,6 +42,33 @@ pub enum ThreadStatus {
4242 Error ,
4343}
4444
45+ /// Internal representation of a Lua thread status.
46+ ///
47+ /// The number in `New` and `Yielded` variants is the number of arguments pushed
48+ /// to the thread stack.
49+ #[ derive( Clone , Copy ) ]
50+ enum ThreadStatusInner {
51+ New ( c_int ) ,
52+ Running ,
53+ Yielded ( c_int ) ,
54+ Finished ,
55+ Error ,
56+ }
57+
58+ impl ThreadStatusInner {
59+ #[ cfg( feature = "async" ) ]
60+ #[ inline( always) ]
61+ fn is_resumable ( self ) -> bool {
62+ matches ! ( self , ThreadStatusInner :: New ( _) | ThreadStatusInner :: Yielded ( _) )
63+ }
64+
65+ #[ cfg( feature = "async" ) ]
66+ #[ inline( always) ]
67+ fn is_yielded ( self ) -> bool {
68+ matches ! ( self , ThreadStatusInner :: Yielded ( _) )
69+ }
70+ }
71+
4572/// Handle to an internal Lua thread (coroutine).
4673#[ derive( Clone ) ]
4774pub struct Thread ( pub ( crate ) ValueRef , pub ( crate ) * mut ffi:: lua_State ) ;
@@ -60,9 +87,8 @@ unsafe impl Sync for Thread {}
6087#[ cfg( feature = "async" ) ]
6188#[ cfg_attr( docsrs, doc( cfg( feature = "async" ) ) ) ]
6289#[ must_use = "futures do nothing unless you `.await` or poll them" ]
63- pub struct AsyncThread < A , R > {
90+ pub struct AsyncThread < R > {
6491 thread : Thread ,
65- init_args : Option < A > ,
6692 ret : PhantomData < R > ,
6793 recycle : bool ,
6894}
@@ -122,17 +148,25 @@ impl Thread {
122148 R : FromLuaMulti ,
123149 {
124150 let lua = self . 0 . lua . lock ( ) ;
125- if self . status_inner ( & lua) != ThreadStatus :: Resumable {
126- return Err ( Error :: CoroutineUnresumable ) ;
127- }
151+ let mut pushed_nargs = match self . status_inner ( & lua) {
152+ ThreadStatusInner :: New ( nargs) | ThreadStatusInner :: Yielded ( nargs) => nargs,
153+ _ => return Err ( Error :: CoroutineUnresumable ) ,
154+ } ;
128155
129156 let state = lua. state ( ) ;
130157 let thread_state = self . state ( ) ;
131158 unsafe {
132159 let _sg = StackGuard :: new ( state) ;
133160 let _thread_sg = StackGuard :: with_top ( thread_state, 0 ) ;
134161
135- let nresults = self . resume_inner ( & lua, args) ?;
162+ let nargs = args. push_into_stack_multi ( & lua) ?;
163+ if nargs > 0 {
164+ check_stack ( thread_state, nargs) ?;
165+ ffi:: lua_xmove ( state, thread_state, nargs) ;
166+ pushed_nargs += nargs;
167+ }
168+
169+ let ( _, nresults) = self . resume_inner ( & lua, pushed_nargs) ?;
136170 check_stack ( state, nresults + 1 ) ?;
137171 ffi:: lua_xmove ( thread_state, state, nresults) ;
138172
@@ -143,50 +177,50 @@ impl Thread {
143177 /// Resumes execution of this thread.
144178 ///
145179 /// It's similar to `resume()` but leaves `nresults` values on the thread stack.
146- unsafe fn resume_inner ( & self , lua : & RawLua , args : impl IntoLuaMulti ) -> Result < c_int > {
180+ unsafe fn resume_inner ( & self , lua : & RawLua , nargs : c_int ) -> Result < ( ThreadStatusInner , c_int ) > {
147181 let state = lua. state ( ) ;
148182 let thread_state = self . state ( ) ;
149-
150- let nargs = args. push_into_stack_multi ( lua) ?;
151- if nargs > 0 {
152- check_stack ( thread_state, nargs) ?;
153- ffi:: lua_xmove ( state, thread_state, nargs) ;
154- }
155-
156183 let mut nresults = 0 ;
157184 let ret = ffi:: lua_resume ( thread_state, state, nargs, & mut nresults as * mut c_int ) ;
158- if ret != ffi:: LUA_OK && ret != ffi:: LUA_YIELD {
159- if ret == ffi:: LUA_ERRMEM {
185+ match ret {
186+ ffi:: LUA_OK => Ok ( ( ThreadStatusInner :: Finished , nresults) ) ,
187+ ffi:: LUA_YIELD => Ok ( ( ThreadStatusInner :: Yielded ( 0 ) , nresults) ) ,
188+ ffi:: LUA_ERRMEM => {
160189 // Don't call error handler for memory errors
161- return Err ( pop_error ( thread_state, ret) ) ;
190+ Err ( pop_error ( thread_state, ret) )
191+ }
192+ _ => {
193+ check_stack ( state, 3 ) ?;
194+ protect_lua ! ( state, 0 , 1 , |state| error_traceback_thread( state, thread_state) ) ?;
195+ Err ( pop_error ( state, ret) )
162196 }
163- check_stack ( state, 3 ) ?;
164- protect_lua ! ( state, 0 , 1 , |state| error_traceback_thread( state, thread_state) ) ?;
165- return Err ( pop_error ( state, ret) ) ;
166197 }
167-
168- Ok ( nresults)
169198 }
170199
171200 /// Gets the status of the thread.
172201 pub fn status ( & self ) -> ThreadStatus {
173- self . status_inner ( & self . 0 . lua . lock ( ) )
202+ match self . status_inner ( & self . 0 . lua . lock ( ) ) {
203+ ThreadStatusInner :: New ( _) | ThreadStatusInner :: Yielded ( _) => ThreadStatus :: Resumable ,
204+ ThreadStatusInner :: Running => ThreadStatus :: Running ,
205+ ThreadStatusInner :: Finished => ThreadStatus :: Finished ,
206+ ThreadStatusInner :: Error => ThreadStatus :: Error ,
207+ }
174208 }
175209
176210 /// Gets the status of the thread (internal implementation).
177- pub ( crate ) fn status_inner ( & self , lua : & RawLua ) -> ThreadStatus {
211+ fn status_inner ( & self , lua : & RawLua ) -> ThreadStatusInner {
178212 let thread_state = self . state ( ) ;
179213 if thread_state == lua. state ( ) {
180214 // The thread is currently running
181- return ThreadStatus :: Running ;
215+ return ThreadStatusInner :: Running ;
182216 }
183217 let status = unsafe { ffi:: lua_status ( thread_state) } ;
184- if status != ffi :: LUA_OK && status != ffi:: LUA_YIELD {
185- ThreadStatus :: Error
186- } else if status == ffi:: LUA_YIELD || unsafe { ffi :: lua_gettop ( thread_state ) > 0 } {
187- ThreadStatus :: Resumable
188- } else {
189- ThreadStatus :: Finished
218+ let top = unsafe { ffi:: lua_gettop ( thread_state ) } ;
219+ match status {
220+ ffi:: LUA_YIELD => ThreadStatusInner :: Yielded ( top ) ,
221+ ffi :: LUA_OK if top > 0 => ThreadStatusInner :: New ( top - 1 ) ,
222+ ffi :: LUA_OK => ThreadStatusInner :: Finished ,
223+ _ => ThreadStatusInner :: Error ,
190224 }
191225 }
192226
@@ -224,7 +258,7 @@ impl Thread {
224258 #[ cfg_attr( docsrs, doc( cfg( any( feature = "lua54" , feature = "luau" ) ) ) ) ]
225259 pub fn reset ( & self , func : crate :: function:: Function ) -> Result < ( ) > {
226260 let lua = self . 0 . lua . lock ( ) ;
227- if self . status_inner ( & lua) == ThreadStatus :: Running {
261+ if matches ! ( self . status_inner( & lua) , ThreadStatusInner :: Running ) {
228262 return Err ( Error :: runtime ( "cannot reset a running thread" ) ) ;
229263 }
230264
@@ -257,7 +291,9 @@ impl Thread {
257291
258292 /// Converts [`Thread`] to an [`AsyncThread`] which implements [`Future`] and [`Stream`] traits.
259293 ///
260- /// `args` are passed as arguments to the thread function for first call.
294+ /// Only resumable threads can be converted to [`AsyncThread`].
295+ ///
296+ /// `args` are pushed to the thread stack and will be used when the thread is resumed.
261297 /// The object calls [`resume`] while polling and also allow to run Rust futures
262298 /// to completion using an executor.
263299 ///
@@ -290,7 +326,7 @@ impl Thread {
290326 /// end)
291327 /// "#).eval()?;
292328 ///
293- /// let mut stream = thread.into_async::<i64>(1);
329+ /// let mut stream = thread.into_async::<i64>(1)? ;
294330 /// let mut sum = 0;
295331 /// while let Some(n) = stream.try_next().await? {
296332 /// sum += n;
@@ -303,15 +339,31 @@ impl Thread {
303339 /// ```
304340 #[ cfg( feature = "async" ) ]
305341 #[ cfg_attr( docsrs, doc( cfg( feature = "async" ) ) ) ]
306- pub fn into_async < R > ( self , args : impl IntoLuaMulti ) -> AsyncThread < impl IntoLuaMulti , R >
342+ pub fn into_async < R > ( self , args : impl IntoLuaMulti ) -> Result < AsyncThread < R > >
307343 where
308344 R : FromLuaMulti ,
309345 {
310- AsyncThread {
311- thread : self ,
312- init_args : Some ( args) ,
313- ret : PhantomData ,
314- recycle : false ,
346+ let lua = self . 0 . lua . lock ( ) ;
347+ if !self . status_inner ( & lua) . is_resumable ( ) {
348+ return Err ( Error :: CoroutineUnresumable ) ;
349+ }
350+
351+ let state = lua. state ( ) ;
352+ let thread_state = self . state ( ) ;
353+ unsafe {
354+ let _sg = StackGuard :: new ( state) ;
355+
356+ let nargs = args. push_into_stack_multi ( & lua) ?;
357+ if nargs > 0 {
358+ check_stack ( thread_state, nargs) ?;
359+ ffi:: lua_xmove ( state, thread_state, nargs) ;
360+ }
361+
362+ Ok ( AsyncThread {
363+ thread : self ,
364+ ret : PhantomData ,
365+ recycle : false ,
366+ } )
315367 }
316368 }
317369
@@ -392,7 +444,7 @@ impl LuaType for Thread {
392444}
393445
394446#[ cfg( feature = "async" ) ]
395- impl < A , R > AsyncThread < A , R > {
447+ impl < R > AsyncThread < R > {
396448 #[ inline]
397449 pub ( crate ) fn set_recyclable ( & mut self , recyclable : bool ) {
398450 self . recycle = recyclable;
@@ -401,15 +453,15 @@ impl<A, R> AsyncThread<A, R> {
401453
402454#[ cfg( feature = "async" ) ]
403455#[ cfg( any( feature = "lua54" , feature = "luau" ) ) ]
404- impl < A , R > Drop for AsyncThread < A , R > {
456+ impl < R > Drop for AsyncThread < R > {
405457 fn drop ( & mut self ) {
406458 if self . recycle {
407459 if let Some ( lua) = self . thread . 0 . lua . try_lock ( ) {
408460 unsafe {
409461 // For Lua 5.4 this also closes all pending to-be-closed variables
410462 if !lua. recycle_thread ( & mut self . thread ) {
411463 #[ cfg( feature = "lua54" ) ]
412- if self . thread . status_inner ( & lua) == ThreadStatus :: Error {
464+ if matches ! ( self . thread. status_inner( & lua) , ThreadStatusInner :: Error ) {
413465 #[ cfg( not( feature = "vendored" ) ) ]
414466 ffi:: lua_resetthread ( self . thread . state ( ) ) ;
415467 #[ cfg( feature = "vendored" ) ]
@@ -423,14 +475,15 @@ impl<A, R> Drop for AsyncThread<A, R> {
423475}
424476
425477#[ cfg( feature = "async" ) ]
426- impl < A : IntoLuaMulti , R : FromLuaMulti > Stream for AsyncThread < A , R > {
478+ impl < R : FromLuaMulti > Stream for AsyncThread < R > {
427479 type Item = Result < R > ;
428480
429481 fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
430482 let lua = self . thread . 0 . lua . lock ( ) ;
431- if self . thread . status_inner ( & lua) != ThreadStatus :: Resumable {
432- return Poll :: Ready ( None ) ;
433- }
483+ let nargs = match self . thread . status_inner ( & lua) {
484+ ThreadStatusInner :: New ( nargs) | ThreadStatusInner :: Yielded ( nargs) => nargs,
485+ _ => return Poll :: Ready ( None ) ,
486+ } ;
434487
435488 let state = lua. state ( ) ;
436489 let thread_state = self . thread . state ( ) ;
@@ -439,36 +492,34 @@ impl<A: IntoLuaMulti, R: FromLuaMulti> Stream for AsyncThread<A, R> {
439492 let _thread_sg = StackGuard :: with_top ( thread_state, 0 ) ;
440493 let _wg = WakerGuard :: new ( & lua, cx. waker ( ) ) ;
441494
442- // This is safe as we are not moving the whole struct
443- let this = self . get_unchecked_mut ( ) ;
444- let nresults = if let Some ( args) = this. init_args . take ( ) {
445- this. thread . resume_inner ( & lua, args) ?
446- } else {
447- this. thread . resume_inner ( & lua, ( ) ) ?
448- } ;
495+ let ( status, nresults) = ( self . thread ) . resume_inner ( & lua, nargs) ?;
449496
450- if nresults == 1 && is_poll_pending ( thread_state) {
451- return Poll :: Pending ;
497+ if status. is_yielded ( ) {
498+ if nresults == 1 && is_poll_pending ( thread_state) {
499+ return Poll :: Pending ;
500+ }
501+ // Continue polling
502+ cx. waker ( ) . wake_by_ref ( ) ;
452503 }
453504
454505 check_stack ( state, nresults + 1 ) ?;
455506 ffi:: lua_xmove ( thread_state, state, nresults) ;
456507
457- cx. waker ( ) . wake_by_ref ( ) ;
458508 Poll :: Ready ( Some ( R :: from_stack_multi ( nresults, & lua) ) )
459509 }
460510 }
461511}
462512
463513#[ cfg( feature = "async" ) ]
464- impl < A : IntoLuaMulti , R : FromLuaMulti > Future for AsyncThread < A , R > {
514+ impl < R : FromLuaMulti > Future for AsyncThread < R > {
465515 type Output = Result < R > ;
466516
467517 fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
468518 let lua = self . thread . 0 . lua . lock ( ) ;
469- if self . thread . status_inner ( & lua) != ThreadStatus :: Resumable {
470- return Poll :: Ready ( Err ( Error :: CoroutineUnresumable ) ) ;
471- }
519+ let nargs = match self . thread . status_inner ( & lua) {
520+ ThreadStatusInner :: New ( nargs) | ThreadStatusInner :: Yielded ( nargs) => nargs,
521+ _ => return Poll :: Ready ( Err ( Error :: CoroutineUnresumable ) ) ,
522+ } ;
472523
473524 let state = lua. state ( ) ;
474525 let thread_state = self . thread . state ( ) ;
@@ -477,21 +528,13 @@ impl<A: IntoLuaMulti, R: FromLuaMulti> Future for AsyncThread<A, R> {
477528 let _thread_sg = StackGuard :: with_top ( thread_state, 0 ) ;
478529 let _wg = WakerGuard :: new ( & lua, cx. waker ( ) ) ;
479530
480- // This is safe as we are not moving the whole struct
481- let this = self . get_unchecked_mut ( ) ;
482- let nresults = if let Some ( args) = this. init_args . take ( ) {
483- this. thread . resume_inner ( & lua, args) ?
484- } else {
485- this. thread . resume_inner ( & lua, ( ) ) ?
486- } ;
487-
488- if nresults == 1 && is_poll_pending ( thread_state) {
489- return Poll :: Pending ;
490- }
531+ let ( status, nresults) = self . thread . resume_inner ( & lua, nargs) ?;
491532
492- if ffi:: lua_status ( thread_state) == ffi:: LUA_YIELD {
493- // Ignore value returned via yield()
494- cx. waker ( ) . wake_by_ref ( ) ;
533+ if status. is_yielded ( ) {
534+ if !( nresults == 1 && is_poll_pending ( thread_state) ) {
535+ // Ignore value returned via yield()
536+ cx. waker ( ) . wake_by_ref ( ) ;
537+ }
495538 return Poll :: Pending ;
496539 }
497540
@@ -545,7 +588,7 @@ mod assertions {
545588 #[ cfg( feature = "send" ) ]
546589 static_assertions:: assert_impl_all!( Thread : Send , Sync ) ;
547590 #[ cfg( all( feature = "async" , not( feature = "send" ) ) ) ]
548- static_assertions:: assert_not_impl_any!( AsyncThread <( ) , ( ) >: Send ) ;
591+ static_assertions:: assert_not_impl_any!( AsyncThread <( ) >: Send ) ;
549592 #[ cfg( all( feature = "async" , feature = "send" ) ) ]
550- static_assertions:: assert_impl_all!( AsyncThread <( ) , ( ) >: Send , Sync ) ;
593+ static_assertions:: assert_impl_all!( AsyncThread <( ) >: Send , Sync ) ;
551594}
0 commit comments