Skip to content
93 changes: 72 additions & 21 deletions Sources/NIOCore/AsyncChannel/AsyncChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@

import DequeModule

@usableFromInline
enum OutboundAction<OutboundOut>: Sendable where OutboundOut: Sendable {
/// Write value
case write(OutboundOut)
/// Write value and flush pipeline
case writeAndFlush(OutboundOut, EventLoopPromise<Void>)
/// flush writes to writer
case flush(EventLoopPromise<Void>)
}

/// A ``ChannelHandler`` that is used to transform the inbound portion of a NIO
/// ``Channel`` into an asynchronous sequence that supports back-pressure. It's also used
/// to write the outbound portion of a NIO ``Channel`` from Swift Concurrency with back-pressure
Expand Down Expand Up @@ -77,7 +87,7 @@ internal final class NIOAsyncChannelHandler<InboundIn: Sendable, ProducerElement

@usableFromInline
typealias Writer = NIOAsyncWriter<
OutboundOut,
OutboundAction<OutboundOut>,
NIOAsyncChannelHandlerWriterDelegate<OutboundOut>
>

Expand Down Expand Up @@ -372,7 +382,10 @@ struct NIOAsyncChannelHandlerProducerDelegate: @unchecked Sendable, NIOAsyncSequ

@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
@usableFromInline
struct NIOAsyncChannelHandlerWriterDelegate<Element: Sendable>: NIOAsyncWriterSinkDelegate, @unchecked Sendable {
struct NIOAsyncChannelHandlerWriterDelegate<OutboundOut: Sendable>: NIOAsyncWriterSinkDelegate, @unchecked Sendable {
@usableFromInline
typealias Element = OutboundAction<OutboundOut>

@usableFromInline
let eventLoop: EventLoop

Expand All @@ -386,7 +399,7 @@ struct NIOAsyncChannelHandlerWriterDelegate<Element: Sendable>: NIOAsyncWriterSi
let _didTerminate: ((any Error)?) -> Void

@inlinable
init<InboundIn, ProducerElement>(handler: NIOAsyncChannelHandler<InboundIn, ProducerElement, Element>) {
init<InboundIn, ProducerElement>(handler: NIOAsyncChannelHandler<InboundIn, ProducerElement, OutboundOut>) {
self.eventLoop = handler.eventLoop
self._didYieldContentsOf = handler._didYield(sequence:)
self._didYield = handler._didYield(element:)
Expand Down Expand Up @@ -430,35 +443,27 @@ struct NIOAsyncChannelHandlerWriterDelegate<Element: Sendable>: NIOAsyncWriterSi
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension NIOAsyncChannelHandler {
@inlinable
func _didYield(sequence: Deque<OutboundOut>) {
func _didYield(sequence: Deque<OutboundAction<OutboundOut>>) {
// This is always called from an async context, so we must loop-hop.
// Because we always loop-hop, we're always at the top of a stack frame. As this
// is the only source of writes for us, and as this channel handler doesn't implement
// func write(), we cannot possibly re-entrantly write. That means we can skip many of the
// awkward re-entrancy protections NIO usually requires, and can safely just do an iterative
// write.
self.eventLoop.preconditionInEventLoop()
guard let context = self.context else {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test has been moved into _doOutboundWrites as we need to complete promises even if the channel handler is no longer there.

// Already removed from the channel by now, we can stop.
return
}

self._doOutboundWrites(context: context, writes: sequence)
}

@inlinable
func _didYield(element: OutboundOut) {
func _didYield(element: OutboundAction<OutboundOut>) {
// This is always called from an async context, so we must loop-hop.
// Because we always loop-hop, we're always at the top of a stack frame. As this
// is the only source of writes for us, and as this channel handler doesn't implement
// func write(), we cannot possibly re-entrantly write. That means we can skip many of the
// awkward re-entrancy protections NIO usually requires, and can safely just do an iterative
// write.
self.eventLoop.preconditionInEventLoop()
guard let context = self.context else {
// Already removed from the channel by now, we can stop.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test has been moved into _doOutboundWrites as we need to complete promises even if the channel handler is no longer there.

return
}

self._doOutboundWrite(context: context, write: element)
}
Expand All @@ -475,18 +480,64 @@ extension NIOAsyncChannelHandler {
}

@inlinable
func _doOutboundWrites(context: ChannelHandlerContext, writes: Deque<OutboundOut>) {
for write in writes {
context.write(Self.wrapOutboundOut(write), promise: nil)
func _doOutboundWrites(context: ChannelHandlerContext?, writes: Deque<OutboundAction<OutboundOut>>) {
// write everything but the last item
for write in writes.dropLast() {
switch write {
case .write(let value), .writeAndFlush(let value, _):
guard let context = self.context else {
// Already removed from the channel by now, we can stop.
return
}
context.write(Self.wrapOutboundOut(value), promise: nil)
context.flush()
case .flush(let promise):
promise.succeed()
}
}
// write last item
switch writes.last {
case .write(let value):
guard let context = self.context else {
// Already removed from the channel by now, we can stop.
return
}
context.write(Self.wrapOutboundOut(value), promise: nil)
context.flush()
case .flush(let promise):
promise.succeed()
case .writeAndFlush(let value, let promise):
guard let context = self.context else {
// Already removed from the channel by now, we can stop.
promise.succeed()
return
}
context.writeAndFlush(Self.wrapOutboundOut(value), promise: promise)
case .none:
break
}

context.flush()
}

@inlinable
func _doOutboundWrite(context: ChannelHandlerContext, write: OutboundOut) {
context.write(Self.wrapOutboundOut(write), promise: nil)
context.flush()
func _doOutboundWrite(context: ChannelHandlerContext?, write: OutboundAction<OutboundOut>) {
switch write {
case .write(let value):
guard let context = self.context else {
// Already removed from the channel by now, we can stop.
return
}
context.write(Self.wrapOutboundOut(value), promise: nil)
context.flush()
case .flush(let promise):
promise.succeed()
case .writeAndFlush(let value, let promise):
guard let context = self.context else {
// Already removed from the channel by now, we can stop.
promise.succeed()
return
}
context.writeAndFlush(Self.wrapOutboundOut(value), promise: promise)
}
}
}

Expand Down
80 changes: 76 additions & 4 deletions Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
@usableFromInline
typealias _Writer = NIOAsyncWriter<
OutboundOut,
OutboundAction<OutboundOut>,
NIOAsyncChannelHandlerWriterDelegate<OutboundOut>
>

Expand Down Expand Up @@ -72,6 +72,9 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
@usableFromInline
internal let _backing: Backing

@usableFromInline
internal let eventLoop: EventLoop?

/// Creates a new ``NIOAsyncChannelOutboundWriter`` backed by a ``NIOAsyncChannelOutboundWriter/TestSink``.
/// This is mostly useful for testing purposes where one wants to observe the written data.
@inlinable
Expand All @@ -93,7 +96,7 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
) throws {
eventLoop.preconditionInEventLoop()
let writer = _Writer.makeWriter(
elementType: OutboundOut.self,
elementType: OutboundAction<OutboundOut>.self,
isWritable: true,
finishOnDeinit: closeOnDeinit,
delegate: .init(handler: handler)
Expand All @@ -103,11 +106,13 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
handler.writer = writer.writer

self._backing = .writer(writer.writer)
self.eventLoop = eventLoop
}

@inlinable
init(continuation: AsyncStream<OutboundOut>.Continuation) {
self._backing = .asyncStream(continuation)
self.eventLoop = nil
}

/// Send a write into the ``ChannelPipeline`` and flush it right away.
Expand All @@ -119,7 +124,26 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
case .asyncStream(let continuation):
continuation.yield(data)
case .writer(let writer):
try await writer.yield(data)
try await writer.yield(.write(data))
}
}

/// Send a write into the ``ChannelPipeline`` and flush it right away.
///
/// This method suspends until the write has been written and flushed.
@inlinable
public func writeAndFlush(_ data: OutboundOut) async throws {
switch self._backing {
case .asyncStream(let continuation):
continuation.yield(data)
case .writer(let writer):
if let eventLoop {
try await self.withPromise(eventLoop: eventLoop) { promise in
try await writer.yield(.writeAndFlush(data, promise))
}
} else {
try await writer.yield(.write(data))
}
}
}

Expand All @@ -134,7 +158,29 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
continuation.yield(data)
}
case .writer(let writer):
try await writer.yield(contentsOf: sequence)
try await writer.yield(contentsOf: sequence.map { .write($0) })
}
}

/// Send a sequence of writes into the ``ChannelPipeline`` and flush them right away.
///
/// This method suspends if the underlying channel is not writable and will resume once the it becomes writable again.
@inlinable
public func writeAndFlush<Writes: Sequence>(contentsOf sequence: Writes) async throws
where Writes.Element == OutboundOut {
switch self._backing {
case .asyncStream(let continuation):
for data in sequence {
continuation.yield(data)
}
case .writer(let writer):
if let eventLoop {
try await withPromise(eventLoop: eventLoop) { promise in
try await writer.yield(contentsOf: sequence.map { .writeAndFlush($0, promise) })
}
} else {
try await writer.yield(contentsOf: sequence.map { .write($0) })
}
}
}

Expand All @@ -151,6 +197,18 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
}
}

/// Ensure all writes to the writer have been read
@inlinable
public func flush() async throws {
if case .writer(let writer) = self._backing,
let eventLoop
{
try await self.withPromise(eventLoop: eventLoop) { promise in
try await writer.yield(.flush(promise))
}
}
}

/// Finishes the writer.
///
/// This might trigger a half closure if the ``NIOAsyncChannel`` was configured to support it.
Expand All @@ -162,6 +220,20 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
writer.finish()
}
}

@usableFromInline
func withPromise(
eventLoop: EventLoop,
_ process: (EventLoopPromise<Void>) async throws -> Void
) async throws {
let promise = eventLoop.makePromise(of: Void.self)
do {
try await process(promise)
try await promise.futureResult.get()
} catch {
promise.fail(error)
}
}
}

@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
Expand Down
Loading