using System;
using System.Collections.Generic;
using System.Threading;

namespace Cysharp.Threading.Tasks
{
    /// <summary>
    /// Factory methods for creating channels.
    /// </summary>
    public static class Channel
    {
        /// <summary>
        /// Creates a channel backed by an unbounded in-memory queue whose reader
        /// supports a single pending wait at a time.
        /// </summary>
        /// <typeparam name="T">Type of the items carried by the channel.</typeparam>
        /// <returns>A new channel instance.</returns>
        public static Channel<T> CreateSingleConsumerUnbounded<T>()
        {
            return new SingleConsumerUnboundedChannel<T>();
        }
    }

    /// <summary>
    /// Base class of a channel: a <see cref="Writer"/> side paired with a <see cref="Reader"/> side.
    /// </summary>
    /// <typeparam name="TWrite">Type of the items accepted by the writer.</typeparam>
    /// <typeparam name="TRead">Type of the items produced by the reader.</typeparam>
    public abstract class Channel<TWrite, TRead>
    {
        /// <summary>Gets the readable half of the channel.</summary>
        public ChannelReader<TRead> Reader { get; protected set; }

        /// <summary>Gets the writable half of the channel.</summary>
        public ChannelWriter<TWrite> Writer { get; protected set; }

        /// <summary>Converts a channel to its <see cref="Reader"/>.</summary>
        public static implicit operator ChannelReader<TRead>(Channel<TWrite, TRead> channel) => channel.Reader;

        /// <summary>Converts a channel to its <see cref="Writer"/>.</summary>
        public static implicit operator ChannelWriter<TWrite>(Channel<TWrite, TRead> channel) => channel.Writer;
    }

    /// <summary>
    /// A channel whose written and read item types are the same.
    /// </summary>
    /// <typeparam name="T">Type of the items carried by the channel.</typeparam>
    public abstract class Channel<T> : Channel<T, T>
    {
    }

    /// <summary>
    /// Read side of a channel.
    /// </summary>
    /// <typeparam name="T">Type of the items read.</typeparam>
    public abstract class ChannelReader<T>
    {
        /// <summary>Attempts to read an item without waiting.</summary>
        /// <param name="item">The item that was read, when the call returns <c>true</c>.</param>
        /// <returns><c>true</c> when an item was read; otherwise <c>false</c>.</returns>
        public abstract bool TryRead(out T item);

        /// <summary>Waits until data may be available to read.</summary>
        /// <param name="cancellationToken">Token used to cancel the wait.</param>
        /// <returns>
        /// A task whose result <see cref="ReadAsync"/> treats as "an item may be read" when
        /// <c>true</c> and "the channel is closed" when <c>false</c>.
        /// </returns>
        public abstract UniTask<bool> WaitToReadAsync(CancellationToken cancellationToken = default(CancellationToken));

        /// <summary>Gets a task representing completion of the channel.</summary>
        public abstract UniTask Completion { get; }

        /// <summary>
        /// Reads one item, returning synchronously when <see cref="TryRead"/> succeeds and
        /// otherwise waiting via <see cref="WaitToReadAsync"/>.
        /// </summary>
        /// <param name="cancellationToken">Token used to cancel the wait.</param>
        /// <returns>The item that was read.</returns>
        /// <exception cref="ChannelClosedException">
        /// The wait finished without an item becoming readable.
        /// </exception>
        public virtual UniTask<T> ReadAsync(CancellationToken cancellationToken = default(CancellationToken))
        {
            if (this.TryRead(out var item))
            {
                return UniTask.FromResult(item);
            }

            return ReadAsyncCore(cancellationToken);
        }

        // Slow path of ReadAsync: wait once, then try to read; anything else is reported as closed.
        async UniTask<T> ReadAsyncCore(CancellationToken cancellationToken = default(CancellationToken))
        {
            if (await WaitToReadAsync(cancellationToken))
            {
                if (TryRead(out var item))
                {
                    return item;
                }
            }

            throw new ChannelClosedException();
        }

        /// <summary>Exposes the channel content as an asynchronous sequence.</summary>
        /// <param name="cancellationToken">Token used to cancel the enumeration.</param>
        /// <returns>An asynchronous enumerable over the items of the channel.</returns>
        public abstract IUniTaskAsyncEnumerable<T> ReadAllAsync(CancellationToken cancellationToken = default(CancellationToken));
    }

    /// <summary>
    /// Write side of a channel.
    /// </summary>
    /// <typeparam name="T">Type of the items written.</typeparam>
    public abstract class ChannelWriter<T>
    {
        /// <summary>Attempts to write an item.</summary>
        /// <param name="item">The item to write.</param>
        /// <returns><c>true</c> when the item was accepted; otherwise <c>false</c>.</returns>
        public abstract bool TryWrite(T item);

        /// <summary>Attempts to mark the channel as complete.</summary>
        /// <param name="error">Optional error to associate with the completion.</param>
        /// <returns><c>true</c> when this call completed the channel; otherwise <c>false</c>.</returns>
        public abstract bool TryComplete(Exception error = null);

        /// <summary>Marks the channel as complete.</summary>
        /// <param name="error">Optional error to associate with the completion.</param>
        /// <exception cref="ChannelClosedException"><see cref="TryComplete"/> returned <c>false</c>.</exception>
        public void Complete(Exception error = null)
        {
            if (!TryComplete(error))
            {
                throw new ChannelClosedException();
            }
        }
    }

    /// <summary>
    /// Thrown when an operation is attempted on a channel that is already closed.
    /// </summary>
    public partial class ChannelClosedException : InvalidOperationException
    {
        /// <summary>Creates the exception with the default message.</summary>
        public ChannelClosedException() :
            base("Channel is already closed.")
        {
        }

        /// <summary>Creates the exception with a custom message.</summary>
        /// <param name="message">The exception message.</param>
        public ChannelClosedException(string message) : base(message)
        {
        }

        /// <summary>Creates the exception with the default message and an inner exception.</summary>
        /// <param name="innerException">The exception that caused this one.</param>
        public ChannelClosedException(Exception innerException) :
            base("Channel is already closed", innerException)
        {
        }

        /// <summary>Creates the exception with a custom message and an inner exception.</summary>
        /// <param name="message">The exception message.</param>
        /// <param name="innerException">The exception that caused this one.</param>
        public ChannelClosedException(string message, Exception innerException) : base(message, innerException)
        {
        }
    }

    // Queue-backed channel. All shared state (items, closed, completionError, completedTask,
    // completedTaskSource) is read and written while holding a lock on `items`.
    internal class SingleConsumerUnboundedChannel<T> : Channel<T>
    {
        readonly Queue<T> items;
        readonly SingleConsumerUnboundedChannelReader readerSource;
        UniTaskCompletionSource completedTaskSource;
        UniTask completedTask;

        Exception completionError;
        bool closed;

        public SingleConsumerUnboundedChannel()
        {
            items = new Queue<T>();
            Writer = new SingleConsumerUnboundedChannelWriter(this);
            readerSource = new SingleConsumerUnboundedChannelReader(this);
            Reader = readerSource;
        }

        sealed class SingleConsumerUnboundedChannelWriter : ChannelWriter<T>
        {
            readonly SingleConsumerUnboundedChannel<T> parent;

            public SingleConsumerUnboundedChannelWriter(SingleConsumerUnboundedChannel<T> parent)
            {
                this.parent = parent;
            }

            // Enqueues the item unless the channel is closed, then wakes the reader if it has waited.
            public override bool TryWrite(T item)
            {
                bool waiting;
                lock (parent.items)
                {
                    if (parent.closed) return false;

                    parent.items.Enqueue(item);
                    waiting = parent.readerSource.isWaiting;
                }

                // Signal outside the lock.
                if (waiting)
                {
                    parent.readerSource.SingalContinuation();
                }

                return true;
            }

            // Closes the channel. When the queue is already empty the completion task and any
            // pending wait are finished here; otherwise TryRead finishes completion when it
            // dequeues the last item.
            public override bool TryComplete(Exception error = null)
            {
                bool waiting;
                lock (parent.items)
                {
                    if (parent.closed) return false;
                    parent.closed = true;

                    // Store the error before anything is signalled: the signals below may run
                    // continuations that call back into the reader, and those calls must observe
                    // the error together with `closed`.
                    parent.completionError = error;
                    waiting = parent.readerSource.isWaiting;

                    if (parent.items.Count == 0)
                    {
                        if (error == null)
                        {
                            if (parent.completedTaskSource != null)
                            {
                                parent.completedTaskSource.TrySetResult();
                            }
                            else
                            {
                                parent.completedTask = UniTask.CompletedTask;
                            }
                        }
                        else
                        {
                            if (parent.completedTaskSource != null)
                            {
                                parent.completedTaskSource.TrySetException(error);
                            }
                            else
                            {
                                parent.completedTask = UniTask.FromException(error);
                            }
                        }

                        if (waiting)
                        {
                            parent.readerSource.SingalCompleted(error);
                        }
                    }
                }

                return true;
            }
        }

        // Reader that also acts as the reusable task source behind WaitToReadAsync: each wait
        // resets `core` and hands out a UniTask<bool> bound to the new version.
        sealed class SingleConsumerUnboundedChannelReader : ChannelReader<T>, IUniTaskSource<bool>
        {
            readonly Action<object> CancellationCallbackDelegate = CancellationCallback;
            readonly SingleConsumerUnboundedChannel<T> parent;

            CancellationToken cancellationToken;
            CancellationTokenRegistration cancellationTokenRegistration;
            UniTaskCompletionSourceCore<bool> core;
            internal bool isWaiting;

            public SingleConsumerUnboundedChannelReader(SingleConsumerUnboundedChannel<T> parent)
            {
                this.parent = parent;

                TaskTracker.TrackActiveTask(this, 4);
            }

            public override UniTask Completion
            {
                get
                {
                    // Take the same lock as TryComplete/TryRead so the source is not created
                    // concurrently with (and therefore missed by) the code that completes it.
                    lock (parent.items)
                    {
                        if (parent.completedTaskSource != null) return parent.completedTaskSource.Task;

                        // completedTask is only assigned once the channel is closed AND drained.
                        // While items remain, fall through and create a source that TryRead
                        // completes when it dequeues the last item.
                        if (parent.closed && parent.items.Count == 0)
                        {
                            return parent.completedTask;
                        }

                        parent.completedTaskSource = new UniTaskCompletionSource();
                        return parent.completedTaskSource.Task;
                    }
                }
            }

            public override bool TryRead(out T item)
            {
                lock (parent.items)
                {
                    if (parent.items.Count != 0)
                    {
                        item = parent.items.Dequeue();

                        // Complete once every value has been consumed.
                        if (parent.closed && parent.items.Count == 0)
                        {
                            if (parent.completionError != null)
                            {
                                if (parent.completedTaskSource != null)
                                {
                                    parent.completedTaskSource.TrySetException(parent.completionError);
                                }
                                else
                                {
                                    parent.completedTask = UniTask.FromException(parent.completionError);
                                }
                            }
                            else
                            {
                                if (parent.completedTaskSource != null)
                                {
                                    parent.completedTaskSource.TrySetResult();
                                }
                                else
                                {
                                    parent.completedTask = UniTask.CompletedTask;
                                }
                            }
                        }
                    }
                    else
                    {
                        item = default;
                        return false;
                    }
                }

                return true;
            }

            public override UniTask<bool> WaitToReadAsync(CancellationToken cancellationToken)
            {
                if (cancellationToken.IsCancellationRequested)
                {
                    return UniTask.FromCanceled<bool>(cancellationToken);
                }

                lock (parent.items)
                {
                    if (parent.items.Count != 0)
                    {
                        return CompletedTasks.True;
                    }

                    if (parent.closed)
                    {
                        if (parent.completionError == null)
                        {
                            return CompletedTasks.False;
                        }
                        else
                        {
                            return UniTask.FromException<bool>(parent.completionError);
                        }
                    }

                    // Nothing to read yet: drop the previous registration and re-arm the source.
                    cancellationTokenRegistration.Dispose();

                    core.Reset();
                    isWaiting = true;

                    this.cancellationToken = cancellationToken;
                    if (this.cancellationToken.CanBeCanceled)
                    {
                        cancellationTokenRegistration = this.cancellationToken.RegisterWithoutCaptureExecutionContext(CancellationCallbackDelegate, this);
                    }

                    return new UniTask<bool>(this, core.Version);
                }
            }

            // Completes the pending wait with `true`.
            public void SingalContinuation()
            {
                core.TrySetResult(true);
            }

            // Completes the pending wait as canceled with the given token.
            public void SingalCancellation(CancellationToken cancellationToken)
            {
                TaskTracker.RemoveTracking(this);
                core.TrySetCanceled(cancellationToken);
            }

            // Completes the pending wait with the error, or with `false` when there is none.
            public void SingalCompleted(Exception error)
            {
                TaskTracker.RemoveTracking(this);
                if (error != null)
                {
                    core.TrySetException(error);
                }
                else
                {
                    core.TrySetResult(false);
                }
            }

            public override IUniTaskAsyncEnumerable<T> ReadAllAsync(CancellationToken cancellationToken = default)
            {
                return new ReadAllAsyncEnumerable(this, cancellationToken);
            }

            bool IUniTaskSource<bool>.GetResult(short token)
            {
                return core.GetResult(token);
            }

            void IUniTaskSource.GetResult(short token)
            {
                core.GetResult(token);
            }

            UniTaskStatus IUniTaskSource.GetStatus(short token)
            {
                return core.GetStatus(token);
            }

            void IUniTaskSource.OnCompleted(Action<object> continuation, object state, short token)
            {
                core.OnCompleted(continuation, state, token);
            }

            UniTaskStatus IUniTaskSource.UnsafeGetStatus()
            {
                return core.UnsafeGetStatus();
            }

            static void CancellationCallback(object state)
            {
                var self = (SingleConsumerUnboundedChannelReader)state;
                self.SingalCancellation(self.cancellationToken);
            }

            // Enumerable and enumerator in one object; GetAsyncEnumerator may be called only once.
            sealed class ReadAllAsyncEnumerable : IUniTaskAsyncEnumerable<T>, IUniTaskAsyncEnumerator<T>
            {
                readonly Action<object> CancellationCallback1Delegate = CancellationCallback1;
                readonly Action<object> CancellationCallback2Delegate = CancellationCallback2;

                readonly SingleConsumerUnboundedChannelReader parent;
                CancellationToken cancellationToken1;
                CancellationToken cancellationToken2;
                CancellationTokenRegistration cancellationTokenRegistration1;
                CancellationTokenRegistration cancellationTokenRegistration2;

                T current;
                bool cacheValue;
                bool running;

                public ReadAllAsyncEnumerable(SingleConsumerUnboundedChannelReader parent, CancellationToken cancellationToken)
                {
                    this.parent = parent;
                    this.cancellationToken1 = cancellationToken;
                }

                public IUniTaskAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
                {
                    if (running)
                    {
                        throw new InvalidOperationException("Enumerator is already running, does not allow call GetAsyncEnumerator twice.");
                    }

                    // Track the second token only when it differs from the one given to ReadAllAsync.
                    if (this.cancellationToken1 != cancellationToken)
                    {
                        this.cancellationToken2 = cancellationToken;
                    }

                    if (this.cancellationToken1.CanBeCanceled)
                    {
                        this.cancellationTokenRegistration1 = this.cancellationToken1.RegisterWithoutCaptureExecutionContext(CancellationCallback1Delegate, this);
                    }

                    if (this.cancellationToken2.CanBeCanceled)
                    {
                        this.cancellationTokenRegistration2 = this.cancellationToken2.RegisterWithoutCaptureExecutionContext(CancellationCallback2Delegate, this);
                    }

                    running = true;
                    return this;
                }

                public T Current
                {
                    get
                    {
                        if (!cacheValue)
                        {
                            parent.TryRead(out current);

                            // Remember the dequeued value so that reading Current again before
                            // the next MoveNextAsync does not consume another item.
                            cacheValue = true;
                        }

                        return current;
                    }
                }

                public UniTask<bool> MoveNextAsync()
                {
                    cacheValue = false;

                    // None is fine here: both tokens were registered in GetAsyncEnumerator.
                    return parent.WaitToReadAsync(CancellationToken.None);
                }

                public UniTask DisposeAsync()
                {
                    cancellationTokenRegistration1.Dispose();
                    cancellationTokenRegistration2.Dispose();
                    return default;
                }

                static void CancellationCallback1(object state)
                {
                    var self = (ReadAllAsyncEnumerable)state;
                    self.parent.SingalCancellation(self.cancellationToken1);
                }

                static void CancellationCallback2(object state)
                {
                    var self = (ReadAllAsyncEnumerable)state;
                    self.parent.SingalCancellation(self.cancellationToken2);
                }
            }
        }
    }
}