diff --git a/CHANGELOG.md b/CHANGELOG.md index 7eaf7c4..dab08a9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +**v0.5.2 - Oxygen:** + +This version is a bug fix version. + +- Multicast: don't cancel the upstream sequence when a client is cancelled + **v0.5.1 - Nitrogen:** This version removes compilation unsafe flags diff --git a/Sources/AsyncSubjects/AsyncCurrentValueSubject.swift b/Sources/AsyncSubjects/AsyncCurrentValueSubject.swift index 6dfed67..5225105 100644 --- a/Sources/AsyncSubjects/AsyncCurrentValueSubject.swift +++ b/Sources/AsyncSubjects/AsyncCurrentValueSubject.swift @@ -67,28 +67,24 @@ public final class AsyncCurrentValueSubject: AsyncSubject where Element /// Sends a value to all consumers /// - Parameter element: the value to send public func send(_ element: Element) { - let channels = self.state.withCriticalRegion { state -> [AsyncBufferedChannel] in + self.state.withCriticalRegion { state in state.current = element - return Array(state.channels.values) - } - - for channel in channels { - channel.send(element) + for channel in state.channels.values { + channel.send(element) + } } } /// Finishes the async sequences with a normal ending. /// - Parameter termination: The termination to finish the subject. public func send(_ termination: Termination) { - let channels = self.state.withCriticalRegion { state -> [AsyncBufferedChannel] in + self.state.withCriticalRegion { state in state.terminalState = termination let channels = Array(state.channels.values) state.channels.removeAll() - return channels - } - - for channel in channels { - channel.finish() + for channel in channels { + channel.finish() + } } } @@ -138,10 +134,10 @@ public final class AsyncCurrentValueSubject: AsyncSubject where Element } public mutating func next() async -> Element? { - await withTaskCancellationHandler { [unregister] in - unregister() - } operation: { + await withTaskCancellationHandler { await self.iterator.next() + } onCancel: { [unregister] in + unregister() } } } diff --git a/Sources/AsyncSubjects/AsyncPassthroughSubject.swift b/Sources/AsyncSubjects/AsyncPassthroughSubject.swift index bb66f86..61690f1 100644 --- a/Sources/AsyncSubjects/AsyncPassthroughSubject.swift +++ b/Sources/AsyncSubjects/AsyncPassthroughSubject.swift @@ -52,27 +52,23 @@ public final class AsyncPassthroughSubject: AsyncSubject { /// Sends a value to all consumers /// - Parameter element: the value to send public func send(_ element: Element) { - let channels = self.state.withCriticalRegion { state in - state.channels.values - } - - for channel in channels { - channel.send(element) + self.state.withCriticalRegion { state in + for channel in state.channels.values { + channel.send(element) + } } } /// Finishes the subject with a normal ending. /// - Parameter termination: The termination to finish the subject public func send(_ termination: Termination) { - let channels = self.state.withCriticalRegion { state -> [AsyncBufferedChannel] in + self.state.withCriticalRegion { state in state.terminalState = termination let channels = Array(state.channels.values) state.channels.removeAll() - return channels - } - - for channel in channels { - channel.finish() + for channel in channels { + channel.finish() + } } } @@ -120,10 +116,10 @@ public final class AsyncPassthroughSubject: AsyncSubject { } public mutating func next() async -> Element? { - await withTaskCancellationHandler { [unregister] in - unregister() - } operation: { + await withTaskCancellationHandler { await self.iterator.next() + } onCancel: { [unregister] in + unregister() } } } diff --git a/Sources/AsyncSubjects/AsyncReplaySubject.swift b/Sources/AsyncSubjects/AsyncReplaySubject.swift index 2b7a9bc..f4e610e 100644 --- a/Sources/AsyncSubjects/AsyncReplaySubject.swift +++ b/Sources/AsyncSubjects/AsyncReplaySubject.swift @@ -46,33 +46,29 @@ public final class AsyncReplaySubject: AsyncSubject where Element: Send /// Sends a value to all consumers /// - Parameter element: the value to send public func send(_ element: Element) { - let channels = self.state.withCriticalRegion { state -> [AsyncBufferedChannel] in + self.state.withCriticalRegion { state in if state.buffer.count >= state.bufferSize && !state.buffer.isEmpty { state.buffer.removeFirst() } state.buffer.append(element) - return Array(state.channels.values) - } - - for channel in channels { - channel.send(element) + for channel in state.channels.values { + channel.send(element) + } } } /// Finishes the subject with a normal ending. /// - Parameter termination: The termination to finish the subject. public func send(_ termination: Termination) { - let channels = self.state.withCriticalRegion { state -> [AsyncBufferedChannel] in + self.state.withCriticalRegion { state in state.terminalState = termination let channels = Array(state.channels.values) state.channels.removeAll() state.buffer.removeAll() state.bufferSize = 0 - return channels - } - - for channel in channels { - channel.finish() + for channel in channels { + channel.finish() + } } } @@ -124,10 +120,10 @@ public final class AsyncReplaySubject: AsyncSubject where Element: Send } public mutating func next() async -> Element? { - await withTaskCancellationHandler { [unregister] in - unregister() - } operation: { + await withTaskCancellationHandler { await self.iterator.next() + } onCancel: { [unregister] in + unregister() } } } diff --git a/Sources/AsyncSubjects/AsyncThrowingCurrentValueSubject.swift b/Sources/AsyncSubjects/AsyncThrowingCurrentValueSubject.swift index 494123e..2294b09 100644 --- a/Sources/AsyncSubjects/AsyncThrowingCurrentValueSubject.swift +++ b/Sources/AsyncSubjects/AsyncThrowingCurrentValueSubject.swift @@ -67,32 +67,28 @@ public final class AsyncThrowingCurrentValueSubject: As /// Sends a value to all consumers /// - Parameter element: the value to send public func send(_ element: Element) { - let channels = self.state.withCriticalRegion { state -> [AsyncThrowingBufferedChannel] in + self.state.withCriticalRegion { state in state.current = element - return Array(state.channels.values) - } - - for channel in channels { - channel.send(element) + for channel in state.channels.values { + channel.send(element) + } } } /// Finishes the subject with either a normal ending or an error. /// - Parameter termination: The termination to finish the subject. public func send(_ termination: Termination) { - let channels = self.state.withCriticalRegion { state -> [AsyncThrowingBufferedChannel] in + self.state.withCriticalRegion { state in state.terminalState = termination let channels = Array(state.channels.values) state.channels.removeAll() - return channels - } - - for channel in channels { - switch termination { - case .finished: - channel.finish() - case .failure(let error): - channel.fail(error) + for channel in channels { + switch termination { + case .finished: + channel.finish() + case .failure(let error): + channel.fail(error) + } } } } @@ -149,10 +145,10 @@ public final class AsyncThrowingCurrentValueSubject: As } public mutating func next() async throws -> Element? { - try await withTaskCancellationHandler { [unregister] in - unregister() - } operation: { + try await withTaskCancellationHandler { try await self.iterator.next() + } onCancel: { [unregister] in + unregister() } } } diff --git a/Sources/AsyncSubjects/AsyncThrowingPassthroughSubject.swift b/Sources/AsyncSubjects/AsyncThrowingPassthroughSubject.swift index c2a1eb4..c1da4a5 100644 --- a/Sources/AsyncSubjects/AsyncThrowingPassthroughSubject.swift +++ b/Sources/AsyncSubjects/AsyncThrowingPassthroughSubject.swift @@ -53,31 +53,28 @@ public final class AsyncThrowingPassthroughSubject: Asy /// Sends a value to all consumers /// - Parameter element: the value to send public func send(_ element: Element) { - let channels = self.state.withCriticalRegion { state in - state.channels.values - } - - for channel in channels { - channel.send(element) + self.state.withCriticalRegion { state in + for channel in state.channels.values { + channel.send(element) + } } } /// Finishes the subject with either a normal ending or an error. /// - Parameter termination: The termination to finish the subject public func send(_ termination: Termination) { - let channels = self.state.withCriticalRegion { state -> [AsyncThrowingBufferedChannel] in + self.state.withCriticalRegion { state in state.terminalState = termination let channels = Array(state.channels.values) state.channels.removeAll() - return channels - } - for channel in channels { - switch termination { - case .finished: - channel.finish() - case .failure(let error): - channel.fail(error) + for channel in channels { + switch termination { + case .finished: + channel.finish() + case .failure(let error): + channel.fail(error) + } } } } @@ -132,10 +129,10 @@ public final class AsyncThrowingPassthroughSubject: Asy } public mutating func next() async throws -> Element? { - try await withTaskCancellationHandler { [unregister] in - unregister() - } operation: { + try await withTaskCancellationHandler { try await self.iterator.next() + } onCancel: { [unregister] in + unregister() } } } diff --git a/Sources/AsyncSubjects/AsyncThrowingReplaySubject.swift b/Sources/AsyncSubjects/AsyncThrowingReplaySubject.swift index 709249e..c736d49 100644 --- a/Sources/AsyncSubjects/AsyncThrowingReplaySubject.swift +++ b/Sources/AsyncSubjects/AsyncThrowingReplaySubject.swift @@ -45,37 +45,33 @@ public final class AsyncThrowingReplaySubject: AsyncSub /// Sends a value to all consumers /// - Parameter element: the value to send public func send(_ element: Element) { - let channels = self.state.withCriticalRegion { state -> [AsyncThrowingBufferedChannel] in + self.state.withCriticalRegion { state in if state.buffer.count >= state.bufferSize && !state.buffer.isEmpty { state.buffer.removeFirst() } state.buffer.append(element) - return Array(state.channels.values) - } - - for channel in channels { - channel.send(element) + for channel in state.channels.values { + channel.send(element) + } } } /// Finishes the subject with either a normal ending or an error. /// - Parameter termination: The termination to finish the subject public func send(_ termination: Termination) { - let channels = self.state.withCriticalRegion { state -> [AsyncThrowingBufferedChannel] in + self.state.withCriticalRegion { state in state.terminalState = termination let channels = Array(state.channels.values) state.channels.removeAll() state.buffer.removeAll() state.bufferSize = 0 - return channels - } - - for channel in channels { - switch termination { - case .finished: - channel.finish() - case .failure(let error): - channel.fail(error) + for channel in channels { + switch termination { + case .finished: + channel.finish() + case .failure(let error): + channel.fail(error) + } } } } @@ -134,10 +130,10 @@ public final class AsyncThrowingReplaySubject: AsyncSub } public mutating func next() async throws -> Element? { - try await withTaskCancellationHandler { [unregister] in - unregister() - } operation: { + try await withTaskCancellationHandler { try await self.iterator.next() + } onCancel: { [unregister] in + unregister() } } } diff --git a/Sources/Operators/AsyncMulticastSequence.swift b/Sources/Operators/AsyncMulticastSequence.swift index 16a5db9..9fd6a32 100644 --- a/Sources/Operators/AsyncMulticastSequence.swift +++ b/Sources/Operators/AsyncMulticastSequence.swift @@ -105,31 +105,37 @@ where Base.Element == Subject.Element, Subject.Failure == Error, Base.AsyncItera } func next() async { - let (canAccessBase, iterator) = self.state.withCriticalRegion { state -> (Bool, Base.AsyncIterator?) in - switch state { - case .available(let iterator): - state = .busy - return (true, iterator) - case .busy: - return (false, nil) + await Task { + let (canAccessBase, iterator) = self.state.withCriticalRegion { state -> (Bool, Base.AsyncIterator?) in + switch state { + case .available(let iterator): + state = .busy + return (true, iterator) + case .busy: + return (false, nil) + } } - } - guard canAccessBase, var iterator = iterator else { return } + guard canAccessBase, var iterator = iterator else { return } - do { - if let element = try await iterator.next() { - self.subject.send(element) - } else { - self.subject.send(.finished) + let toSend: Result + do { + let element = try await iterator.next() + toSend = .success(element) + } catch { + toSend = .failure(error) } - } catch { - self.subject.send(.failure(error)) - } - self.state.withCriticalRegion { state in - state = .available(iterator) - } + self.state.withCriticalRegion { state in + state = .available(iterator) + } + + switch toSend { + case .success(.some(let element)): self.subject.send(element) + case .success(.none): self.subject.send(.finished) + case .failure(let error): self.subject.send(.failure(error)) + } + }.value } public func makeAsyncIterator() -> AsyncIterator { @@ -149,6 +155,8 @@ where Base.Element == Subject.Element, Subject.Failure == Error, Base.AsyncItera let isConnected: ManagedCriticalState public mutating func next() async rethrows -> Element? { + guard !Task.isCancelled else { return nil } + let shouldWaitForGate = self.isConnected.withCriticalRegion { isConnected -> Bool in if !isConnected { isConnected = true @@ -161,10 +169,11 @@ where Base.Element == Subject.Element, Subject.Failure == Error, Base.AsyncItera } if !self.subjectIterator.hasBufferedElements { - await self.asyncMulticastSequence.next() + await self.asyncMulticastSequence.next() } - return try await self.subjectIterator.next() + let element = try await self.subjectIterator.next() + return element } } } diff --git a/Tests/Operators/AsyncMulticastSequenceTests.swift b/Tests/Operators/AsyncMulticastSequenceTests.swift index b048967..a77e4cf 100644 --- a/Tests/Operators/AsyncMulticastSequenceTests.swift +++ b/Tests/Operators/AsyncMulticastSequenceTests.swift @@ -8,27 +8,6 @@ import AsyncExtensions import XCTest -private struct SpyAsyncSequenceForOnNextCall: AsyncSequence { - typealias Element = Element - typealias AsyncIterator = Iterator - - let onNext: () -> Void - - func makeAsyncIterator() -> AsyncIterator { - Iterator(onNext: self.onNext) - } - - struct Iterator: AsyncIteratorProtocol { - let onNext: () -> Void - - func next() async throws -> Element? { - self.onNext() - try await Task.sleep(nanoseconds: 100_000_000_000) - return nil - } - } -} - private class SpyAsyncSequenceForNumberOfIterators: AsyncSequence { typealias Element = Element typealias AsyncIterator = Iterator @@ -177,45 +156,4 @@ final class AsyncMulticastSequenceTests: XCTestCase { XCTAssertEqual(error as? MockError, expectedError) } } - - func test_multicast_finishes_when_task_is_cancelled() { - let taskHasFinishedExpectation = expectation(description: "Task has finished") - - let stream = AsyncThrowingPassthroughSubject() - let sut = AsyncLazySequence<[Int]>([1, 2, 3, 4, 5]) - .multicast(stream) - .autoconnect() - - Task { - for try await _ in sut {} - taskHasFinishedExpectation.fulfill() - }.cancel() - - wait(for: [taskHasFinishedExpectation], timeout: 1) - } - - func test_multicast_finishes_when_task_is_cancelled_while_waiting_for_next() { - let canCancelExpectation = expectation(description: "the task can be cancelled") - let taskHasFinishedExpectation = expectation(description: "Task has finished") - - let spyAsyncSequence = SpyAsyncSequenceForOnNextCall { - canCancelExpectation.fulfill() - } - - let stream = AsyncThrowingPassthroughSubject() - let sut = spyAsyncSequence - .multicast(stream) - .autoconnect() - - let task = Task { - for try await _ in sut {} - taskHasFinishedExpectation.fulfill() - } - - wait(for: [canCancelExpectation], timeout: 1) - - task.cancel() - - wait(for: [taskHasFinishedExpectation], timeout: 1) - } } diff --git a/Tests/Operators/AsyncSequence+ShareTests.swift b/Tests/Operators/AsyncSequence+ShareTests.swift index c41e336..51b18d9 100644 --- a/Tests/Operators/AsyncSequence+ShareTests.swift +++ b/Tests/Operators/AsyncSequence+ShareTests.swift @@ -40,15 +40,15 @@ private struct LongAsyncSequence: AsyncSequence, AsyncIteratorProtocol } mutating func next() async throws -> Element? { - return try await withTaskCancellationHandler { [onCancel] in - onCancel() - } operation: { + return try await withTaskCancellationHandler { try await Task.sleep(nanoseconds: self.interval.nanoseconds) self.currentIndex += 1 if self.currentIndex == self.failAt { throw MockError(code: 0) } return self.elements.next() + } onCancel: {[onCancel] in + onCancel() } }