using Cysharp.Threading.Tasks; using Cysharp.Threading.Tasks.Internal; using Cysharp.Threading.Tasks.Linq; using System; using System.Threading; using System.Threading.Tasks; namespace Cysharp.Threading.Tasks.Linq { public static partial class UniTaskAsyncEnumerable { public static IUniTaskAsyncEnumerable Do(this IUniTaskAsyncEnumerable source, Action onNext) { Error.ThrowArgumentNullException(source, nameof(source)); return source.Do(onNext, null, null); } public static IUniTaskAsyncEnumerable Do(this IUniTaskAsyncEnumerable source, Action onNext, Action onError) { Error.ThrowArgumentNullException(source, nameof(source)); return source.Do(onNext, onError, null); } public static IUniTaskAsyncEnumerable Do(this IUniTaskAsyncEnumerable source, Action onNext, Action onCompleted) { Error.ThrowArgumentNullException(source, nameof(source)); return source.Do(onNext, null, onCompleted); } public static IUniTaskAsyncEnumerable Do(this IUniTaskAsyncEnumerable source, Action onNext, Action onError, Action onCompleted) { Error.ThrowArgumentNullException(source, nameof(source)); return new Do(source, onNext, onError, onCompleted); } public static IUniTaskAsyncEnumerable Do(this IUniTaskAsyncEnumerable source, IObserver observer) { Error.ThrowArgumentNullException(source, nameof(source)); Error.ThrowArgumentNullException(observer, nameof(observer)); return source.Do(observer.OnNext, observer.OnError, observer.OnCompleted); // alloc delegate. } // not yet impl. //public static IUniTaskAsyncEnumerable DoAwait(this IUniTaskAsyncEnumerable source, Func onNext) //{ // throw new NotImplementedException(); //} //public static IUniTaskAsyncEnumerable DoAwait(this IUniTaskAsyncEnumerable source, Func onNext, Func onError) //{ // throw new NotImplementedException(); //} //public static IUniTaskAsyncEnumerable DoAwait(this IUniTaskAsyncEnumerable source, Func onNext, Func onCompleted) //{ // throw new NotImplementedException(); //} //public static IUniTaskAsyncEnumerable DoAwait(this IUniTaskAsyncEnumerable source, Func onNext, Func onError, Func onCompleted) //{ // throw new NotImplementedException(); //} //public static IUniTaskAsyncEnumerable DoAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func onNext) //{ // throw new NotImplementedException(); //} //public static IUniTaskAsyncEnumerable DoAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func onNext, Func onError) //{ // throw new NotImplementedException(); //} //public static IUniTaskAsyncEnumerable DoAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func onNext, Func onCompleted) //{ // throw new NotImplementedException(); //} //public static IUniTaskAsyncEnumerable DoAwaitWithCancellation(this IUniTaskAsyncEnumerable source, Func onNext, Func onError, Func onCompleted) //{ // throw new NotImplementedException(); //} } internal sealed class Do : IUniTaskAsyncEnumerable { readonly IUniTaskAsyncEnumerable source; readonly Action onNext; readonly Action onError; readonly Action onCompleted; public Do(IUniTaskAsyncEnumerable source, Action onNext, Action onError, Action onCompleted) { this.source = source; this.onNext = onNext; this.onError = onError; this.onCompleted = onCompleted; } public IUniTaskAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { return new _Do(source, onNext, onError, onCompleted, cancellationToken); } sealed class _Do : MoveNextSource, IUniTaskAsyncEnumerator { static readonly Action MoveNextCoreDelegate = MoveNextCore; readonly IUniTaskAsyncEnumerable source; readonly Action onNext; readonly Action onError; readonly Action onCompleted; CancellationToken cancellationToken; IUniTaskAsyncEnumerator enumerator; UniTask.Awaiter awaiter; public _Do(IUniTaskAsyncEnumerable source, Action onNext, Action onError, Action onCompleted, CancellationToken cancellationToken) { this.source = source; this.onNext = onNext; this.onError = onError; this.onCompleted = onCompleted; this.cancellationToken = cancellationToken; TaskTracker.TrackActiveTask(this, 3); } public TSource Current { get; private set; } public UniTask MoveNextAsync() { cancellationToken.ThrowIfCancellationRequested(); completionSource.Reset(); bool isCompleted = false; try { if (enumerator == null) { enumerator = source.GetAsyncEnumerator(cancellationToken); } awaiter = enumerator.MoveNextAsync().GetAwaiter(); isCompleted = awaiter.IsCompleted; } catch (Exception ex) { CallTrySetExceptionAfterNotification(ex); return new UniTask(this, completionSource.Version); } if (isCompleted) { MoveNextCore(this); } else { awaiter.SourceOnCompleted(MoveNextCoreDelegate, this); } return new UniTask(this, completionSource.Version); } void CallTrySetExceptionAfterNotification(Exception ex) { if (onError != null) { try { onError(ex); } catch (Exception ex2) { completionSource.TrySetException(ex2); return; } } completionSource.TrySetException(ex); } bool TryGetResultWithNotification(UniTask.Awaiter awaiter, out T result) { try { result = awaiter.GetResult(); return true; } catch (Exception ex) { CallTrySetExceptionAfterNotification(ex); result = default; return false; } } static void MoveNextCore(object state) { var self = (_Do)state; if (self.TryGetResultWithNotification(self.awaiter, out var result)) { if (result) { var v = self.enumerator.Current; if (self.onNext != null) { try { self.onNext(v); } catch (Exception ex) { self.CallTrySetExceptionAfterNotification(ex); } } self.Current = v; self.completionSource.TrySetResult(true); } else { if (self.onCompleted != null) { try { self.onCompleted(); } catch (Exception ex) { self.CallTrySetExceptionAfterNotification(ex); return; } } self.completionSource.TrySetResult(false); } } } public UniTask DisposeAsync() { TaskTracker.RemoveTracking(this); if (enumerator != null) { return enumerator.DisposeAsync(); } return default; } } } }