using Cysharp.Threading.Tasks.Internal;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Runtime.ExceptionServices;
using System.Threading;
using System.Threading.Tasks;

namespace Cysharp.Threading.Tasks.Linq
{
    public static partial class UniTaskAsyncEnumerable
    {
        /// <summary>
        /// Adapts a synchronous <see cref="IEnumerable{T}"/> to an <see cref="IUniTaskAsyncEnumerable{T}"/>.
        /// Each <c>MoveNextAsync</c> advances the underlying enumerator on the calling thread and
        /// returns <c>CompletedTasks.True</c>/<c>CompletedTasks.False</c> without awaiting.
        /// </summary>
        /// <param name="source">Sequence to adapt; validated with <c>Error.ThrowArgumentNullException</c>.</param>
        public static IUniTaskAsyncEnumerable<TSource> ToUniTaskAsyncEnumerable<TSource>(this IEnumerable<TSource> source)
        {
            Error.ThrowArgumentNullException(source, nameof(source));

            return new ToUniTaskAsyncEnumerable<TSource>(source);
        }

        /// <summary>
        /// Adapts a <see cref="Task{TResult}"/> to a sequence that yields the task's result exactly once.
        /// </summary>
        /// <param name="source">Task to adapt; validated with <c>Error.ThrowArgumentNullException</c>.</param>
        public static IUniTaskAsyncEnumerable<TSource> ToUniTaskAsyncEnumerable<TSource>(this Task<TSource> source)
        {
            Error.ThrowArgumentNullException(source, nameof(source));

            return new ToUniTaskAsyncEnumerableTask<TSource>(source);
        }

        /// <summary>
        /// Adapts a <see cref="UniTask{T}"/> to a sequence that yields the task's result exactly once.
        /// </summary>
        /// <remarks>
        /// NOTE(review): every enumerator obtained from the returned sequence awaits the same
        /// <paramref name="source"/> value. UniTask documents that a UniTask should be awaited only once,
        /// so enumerating the result more than once may fail — confirm callers enumerate it a single time.
        /// </remarks>
        /// <param name="source">Task to adapt (a struct, so no null check is performed).</param>
        public static IUniTaskAsyncEnumerable<TSource> ToUniTaskAsyncEnumerable<TSource>(this UniTask<TSource> source)
        {
            return new ToUniTaskAsyncEnumerableUniTask<TSource>(source);
        }

        /// <summary>
        /// Adapts an <see cref="IObservable{T}"/> to an <see cref="IUniTaskAsyncEnumerable{T}"/>.
        /// The observable is subscribed on the first <c>MoveNextAsync</c> call, and notifications are
        /// buffered in a queue until the consumer reads them.
        /// </summary>
        /// <param name="source">Observable to adapt; validated with <c>Error.ThrowArgumentNullException</c>.</param>
        public static IUniTaskAsyncEnumerable<TSource> ToUniTaskAsyncEnumerable<TSource>(this IObservable<TSource> source)
        {
            Error.ThrowArgumentNullException(source, nameof(source));

            return new ToUniTaskAsyncEnumerableObservable<TSource>(source);
        }
    }

    /// <summary>Async-enumerable view over a synchronous <see cref="IEnumerable{T}"/>.</summary>
    internal class ToUniTaskAsyncEnumerable<T> : IUniTaskAsyncEnumerable<T>
    {
        readonly IEnumerable<T> source;

        public ToUniTaskAsyncEnumerable(IEnumerable<T> source)
        {
            this.source = source;
        }

        public IUniTaskAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        {
            return new _ToUniTaskAsyncEnumerable(source, cancellationToken);
        }

        class _ToUniTaskAsyncEnumerable : IUniTaskAsyncEnumerator<T>
        {
            readonly IEnumerable<T> source;
            CancellationToken cancellationToken;

            // Created lazily by the first MoveNextAsync. It stays null if MoveNextAsync is never called
            // or throws from ThrowIfCancellationRequested before GetEnumerator runs.
            IEnumerator<T> enumerator;

            public _ToUniTaskAsyncEnumerable(IEnumerable<T> source, CancellationToken cancellationToken)
            {
                this.source = source;
                this.cancellationToken = cancellationToken;
            }

            public T Current => enumerator.Current;

            public UniTask<bool> MoveNextAsync()
            {
                cancellationToken.ThrowIfCancellationRequested();

                if (enumerator == null)
                {
                    enumerator = source.GetEnumerator();
                }

                if (enumerator.MoveNext())
                {
                    return CompletedTasks.True;
                }

                return CompletedTasks.False;
            }

            public UniTask DisposeAsync()
            {
                // `await foreach` calls DisposeAsync even when the first MoveNextAsync threw. Guard
                // against a never-created enumerator so an OperationCanceledException is not masked
                // by a NullReferenceException.
                enumerator?.Dispose();
                return default;
            }
        }
    }

    /// <summary>Single-element async sequence that yields the result of a <see cref="Task{TResult}"/>.</summary>
    internal class ToUniTaskAsyncEnumerableTask<T> : IUniTaskAsyncEnumerable<T>
    {
        readonly Task<T> source;

        public ToUniTaskAsyncEnumerableTask(Task<T> source)
        {
            this.source = source;
        }

        public IUniTaskAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        {
            return new _ToUniTaskAsyncEnumerableTask(source, cancellationToken);
        }

        class _ToUniTaskAsyncEnumerableTask : IUniTaskAsyncEnumerator<T>
        {
            readonly Task<T> source;
            CancellationToken cancellationToken;

            T current;

            // Set once the single element has been produced; later MoveNextAsync calls return false.
            bool called;

            public _ToUniTaskAsyncEnumerableTask(Task<T> source, CancellationToken cancellationToken)
            {
                this.source = source;
                this.cancellationToken = cancellationToken;

                this.called = false;
            }

            public T Current => current;

            public async UniTask<bool> MoveNextAsync()
            {
                // The token is only checked before awaiting; the await on source itself is not cancellable here.
                cancellationToken.ThrowIfCancellationRequested();

                if (called)
                {
                    return false;
                }

                called = true;

                current = await source;
                return true;
            }

            public UniTask DisposeAsync()
            {
                return default;
            }
        }
    }

    /// <summary>Single-element async sequence that yields the result of a <see cref="UniTask{T}"/>.</summary>
    internal class ToUniTaskAsyncEnumerableUniTask<T> : IUniTaskAsyncEnumerable<T>
    {
        readonly UniTask<T> source;

        public ToUniTaskAsyncEnumerableUniTask(UniTask<T> source)
        {
            this.source = source;
        }

        public IUniTaskAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        {
            return new _ToUniTaskAsyncEnumerableUniTask(source, cancellationToken);
        }

        class _ToUniTaskAsyncEnumerableUniTask : IUniTaskAsyncEnumerator<T>
        {
            readonly UniTask<T> source;
            CancellationToken cancellationToken;

            T current;

            // Set once the single element has been produced; later MoveNextAsync calls return false.
            bool called;

            public _ToUniTaskAsyncEnumerableUniTask(UniTask<T> source, CancellationToken cancellationToken)
            {
                this.source = source;
                this.cancellationToken = cancellationToken;

                this.called = false;
            }

            public T Current => current;

            public async UniTask<bool> MoveNextAsync()
            {
                // The token is only checked before awaiting; the await on source itself is not cancellable here.
                cancellationToken.ThrowIfCancellationRequested();

                if (called)
                {
                    return false;
                }

                called = true;

                current = await source;
                return true;
            }

            public UniTask DisposeAsync()
            {
                return default;
            }
        }
    }

    /// <summary>
    /// Async-enumerable view over an <see cref="IObservable{T}"/> that buffers every notification
    /// in a queue until the consumer pulls it.
    /// </summary>
    internal class ToUniTaskAsyncEnumerableObservable<T> : IUniTaskAsyncEnumerable<T>
    {
        readonly IObservable<T> source;

        public ToUniTaskAsyncEnumerableObservable(IObservable<T> source)
        {
            this.source = source;
        }

        public IUniTaskAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        {
            return new _ToUniTaskAsyncEnumerableObservable(source, cancellationToken);
        }

        class _ToUniTaskAsyncEnumerableObservable : MoveNextSource, IUniTaskAsyncEnumerator<T>, IObserver<T>
        {
            static readonly Action<object> OnCanceledDelegate = OnCanceled;

            readonly IObservable<T> source;
            CancellationToken cancellationToken;

            // True once Current has dequeued the head value for the current step; reset by MoveNextAsync.
            bool useCachedCurrent;
            T current;
            bool subscribeCompleted;

            // Buffered OnNext values. It also serves as the lock object for all mutable state in this class.
            readonly Queue<T> queuedResult;
            Exception error;

            // Null until the first MoveNextAsync subscribes (and forever, if that call was already cancelled).
            IDisposable subscription;
            CancellationTokenRegistration cancellationTokenRegistration;

            public _ToUniTaskAsyncEnumerableObservable(IObservable<T> source, CancellationToken cancellationToken)
            {
                this.source = source;
                this.cancellationToken = cancellationToken;
                this.queuedResult = new Queue<T>();

                if (cancellationToken.CanBeCanceled)
                {
                    cancellationTokenRegistration = cancellationToken.RegisterWithoutCaptureExecutionContext(OnCanceledDelegate, this);
                }
            }

            /// <summary>
            /// Dequeues the head of the buffer on first read after a MoveNextAsync and caches it, so that
            /// repeated reads within the same step return the same value. Returns default when the buffer is empty.
            /// </summary>
            public T Current
            {
                get
                {
                    if (useCachedCurrent)
                    {
                        return current;
                    }

                    lock (queuedResult)
                    {
                        if (queuedResult.Count != 0)
                        {
                            current = queuedResult.Dequeue();
                            useCachedCurrent = true;
                            return current;
                        }
                        else
                        {
                            return default; // undefined.
                        }
                    }
                }
            }

            public UniTask<bool> MoveNextAsync()
            {
                lock (queuedResult)
                {
                    useCachedCurrent = false;

                    if (cancellationToken.IsCancellationRequested)
                    {
                        return UniTask.FromCanceled<bool>(cancellationToken);
                    }

                    if (subscription == null)
                    {
                        // If the source notifies synchronously inside Subscribe, the observer callbacks
                        // take this same lock on the same thread; C# `lock` is re-entrant, so that is safe.
                        subscription = source.Subscribe(this);
                    }

                    // Drain buffered values before surfacing a terminal notification. Otherwise OnNext
                    // values that arrived before OnError would be lost; the OnCompleted path below
                    // already behaves this way.
                    if (queuedResult.Count != 0)
                    {
                        return CompletedTasks.True;
                    }

                    if (error != null)
                    {
                        return UniTask.FromException<bool>(error);
                    }

                    if (subscribeCompleted)
                    {
                        return CompletedTasks.False;
                    }

                    // Nothing buffered yet: re-arm completionSource (inherited from MoveNextSource)
                    // and wait for the next OnNext/OnError/OnCompleted/cancellation.
                    completionSource.Reset();
                    return new UniTask<bool>(this, completionSource.Version);
                }
            }

            public UniTask DisposeAsync()
            {
                // subscription is null when MoveNextAsync was never called or returned FromCanceled
                // before subscribing; `await foreach` still calls DisposeAsync in that case.
                subscription?.Dispose();
                cancellationTokenRegistration.Dispose();
                completionSource.Reset();
                return default;
            }

            public void OnCompleted()
            {
                lock (queuedResult)
                {
                    subscribeCompleted = true;
                    completionSource.TrySetResult(false);
                }
            }

            public void OnError(Exception error)
            {
                lock (queuedResult)
                {
                    this.error = error;
                    completionSource.TrySetException(error);
                }
            }

            public void OnNext(T value)
            {
                lock (queuedResult)
                {
                    queuedResult.Enqueue(value);
                    // NOTE(review): TrySetResult may run the awaiting continuation while this lock is
                    // held (the original comment asked "too long lock?") — consider signalling outside the lock.
                    completionSource.TrySetResult(true);
                }
            }

            static void OnCanceled(object state)
            {
                var self = (_ToUniTaskAsyncEnumerableObservable)state;
                lock (self.queuedResult)
                {
                    self.completionSource.TrySetCanceled(self.cancellationToken);
                }
            }
        }
    }
}