using System;
using System.Threading;

namespace Cysharp.Threading.Tasks.Linq
{
    // note: refactor all inherit class and should remove this.
    // see Select and Where.

    /// <summary>
    /// Base enumerator for operators that pull elements from a source async sequence and decide,
    /// synchronously per source element, whether to publish a MoveNext result or keep pulling.
    /// </summary>
    /// <typeparam name="TSource">Element type of the source sequence.</typeparam>
    /// <typeparam name="TResult">Element type exposed through <see cref="Current"/>.</typeparam>
    internal abstract class AsyncEnumeratorBase<TSource, TResult> : MoveNextSource, IUniTaskAsyncEnumerator<TResult>
    {
        // Cached so that scheduling the source continuation does not allocate a delegate per step.
        static readonly Action<object> moveNextCallbackDelegate = MoveNextCallBack;

        readonly IUniTaskAsyncEnumerable<TSource> source;
        protected CancellationToken cancellationToken;

        // Obtained lazily from `source` on the first MoveNextAsync call.
        IUniTaskAsyncEnumerator<TSource> enumerator;
        UniTask<bool>.Awaiter sourceMoveNext;

        public AsyncEnumeratorBase(IUniTaskAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
        {
            this.source = source;
            this.cancellationToken = cancellationToken;
            TaskTracker.TrackActiveTask(this, 4);
        }

        // abstract

        /// <summary>
        /// Called each time the source's MoveNextAsync completes.
        /// If the return value is false, source.MoveNext is called again; if it is true,
        /// <paramref name="result"/> becomes the result of the pending MoveNextAsync
        /// (or the call is canceled if cancellation was requested).
        /// </summary>
        /// <param name="sourceHasCurrent">Result of the source's MoveNextAsync.</param>
        /// <param name="result">Value to publish when this method returns true.</param>
        protected abstract bool TryMoveNextCore(bool sourceHasCurrent, out bool result);

        // Util
        protected TSource SourceCurrent => enumerator.Current;

        // IUniTaskAsyncEnumerator<TResult>

        public TResult Current { get; protected set; }

        public UniTask<bool> MoveNextAsync()
        {
            if (enumerator == null)
            {
                enumerator = source.GetAsyncEnumerator(cancellationToken);
            }

            // completionSource is not declared here; presumably inherited from MoveNextSource.
            completionSource.Reset();
            if (!OnFirstIteration())
            {
                SourceMoveNext();
            }
            return new UniTask<bool>(this, completionSource.Version);
        }

        /// <summary>
        /// Hook invoked at the start of every MoveNextAsync call (not only the first, despite the name).
        /// Returning true skips advancing the source for this call; the derived class is then
        /// presumably responsible for completing the pending result itself. Default: false.
        /// </summary>
        protected virtual bool OnFirstIteration()
        {
            return false;
        }

        /// <summary>
        /// Advances the source. Synchronously completed steps are handled in a loop (no recursion);
        /// an incomplete step registers <see cref="MoveNextCallBack"/> as its continuation.
        /// </summary>
        protected void SourceMoveNext()
        {
            CONTINUE:
            sourceMoveNext = enumerator.MoveNextAsync().GetAwaiter();
            if (sourceMoveNext.IsCompleted)
            {
                bool result = false;
                try
                {
                    if (!TryMoveNextCore(sourceMoveNext.GetResult(), out result))
                    {
                        goto CONTINUE;
                    }
                }
                catch (Exception ex)
                {
                    completionSource.TrySetException(ex);
                    return;
                }

                SetResultOrCanceled(result);
            }
            else
            {
                sourceMoveNext.SourceOnCompleted(moveNextCallbackDelegate, this);
            }
        }

        // Publishes `result` to the pending MoveNextAsync, or cancels it if cancellation was requested.
        void SetResultOrCanceled(bool result)
        {
            if (cancellationToken.IsCancellationRequested)
            {
                completionSource.TrySetCanceled(cancellationToken);
            }
            else
            {
                completionSource.TrySetResult(result);
            }
        }

        // Continuation for an asynchronously completed source MoveNextAsync.
        static void MoveNextCallBack(object state)
        {
            var self = (AsyncEnumeratorBase<TSource, TResult>)state;
            bool result;
            try
            {
                if (!self.TryMoveNextCore(self.sourceMoveNext.GetResult(), out result))
                {
                    // Inside the try so a synchronous throw from the source faults the pending call.
                    self.SourceMoveNext();
                    return;
                }
            }
            catch (Exception ex)
            {
                self.completionSource.TrySetException(ex);
                return;
            }

            self.SetResultOrCanceled(result);
        }

        // if require additional resource to dispose, override and call base.DisposeAsync.
        public virtual UniTask DisposeAsync()
        {
            TaskTracker.RemoveTracking(this);
            if (enumerator != null)
            {
                return enumerator.DisposeAsync();
            }
            return default;
        }
    }

    /// <summary>
    /// Base enumerator for operators whose per-element step is itself asynchronous:
    /// each source element is passed to <see cref="TransformAsync"/>, and the awaited value is
    /// handed to <see cref="TrySetCurrentCore"/> to decide whether to emit, skip, or stop.
    /// </summary>
    /// <typeparam name="TSource">Element type of the source sequence.</typeparam>
    /// <typeparam name="TResult">Element type exposed through <see cref="Current"/>.</typeparam>
    /// <typeparam name="TAwait">Result type of the per-element asynchronous transform.</typeparam>
    internal abstract class AsyncEnumeratorAwaitSelectorBase<TSource, TResult, TAwait> : MoveNextSource, IUniTaskAsyncEnumerator<TResult>
    {
        // Cached continuations: one for the source's MoveNext, one for the transform's awaiter.
        static readonly Action<object> moveNextCallbackDelegate = MoveNextCallBack;
        static readonly Action<object> setCurrentCallbackDelegate = SetCurrentCallBack;

        readonly IUniTaskAsyncEnumerable<TSource> source;
        protected CancellationToken cancellationToken;

        // Obtained lazily from `source` on the first MoveNextAsync call.
        IUniTaskAsyncEnumerator<TSource> enumerator;
        UniTask<bool>.Awaiter sourceMoveNext;

        UniTask<TAwait>.Awaiter resultAwaiter;

        public AsyncEnumeratorAwaitSelectorBase(IUniTaskAsyncEnumerable<TSource> source, CancellationToken cancellationToken)
        {
            this.source = source;
            this.cancellationToken = cancellationToken;
            TaskTracker.TrackActiveTask(this, 4);
        }

        // abstract

        /// <summary>Starts the asynchronous work for one source element.</summary>
        protected abstract UniTask<TAwait> TransformAsync(TSource sourceCurrent);

        /// <summary>
        /// Receives the awaited transform result. Return true when <see cref="Current"/> was set and
        /// should be emitted; return false to skip this element and pull the next one. Setting
        /// <paramref name="terminateIteration"/> ends the sequence (MoveNextAsync yields false) and
        /// takes precedence over the return value.
        /// </summary>
        protected abstract bool TrySetCurrentCore(TAwait awaitResult, out bool terminateIteration);

        // Util
        protected TSource SourceCurrent { get; private set; }

        // Step outcome: Current was (or was not) set; skipping means "pull the next source element".
        protected (bool waitCallback, bool requireNextIteration) ActionCompleted(bool trySetCurrentResult, out bool moveNextResult)
        {
            if (trySetCurrentResult)
            {
                moveNextResult = true;
                return (false, false);
            }
            else
            {
                moveNextResult = default;
                return (false, true);
            }
        }

        // Step outcome: the transform is still pending; SetCurrentCallBack will finish this step.
        protected (bool waitCallback, bool requireNextIteration) WaitAwaitCallback(out bool moveNextResult)
        {
            moveNextResult = default;
            return (true, false);
        }

        // Step outcome: the sequence is finished; MoveNextAsync yields false.
        protected (bool waitCallback, bool requireNextIteration) IterateFinished(out bool moveNextResult)
        {
            moveNextResult = false;
            return (false, false);
        }

        // IUniTaskAsyncEnumerator<TResult>

        public TResult Current { get; protected set; }

        public UniTask<bool> MoveNextAsync()
        {
            if (enumerator == null)
            {
                enumerator = source.GetAsyncEnumerator(cancellationToken);
            }

            // completionSource is not declared here; presumably inherited from MoveNextSource.
            completionSource.Reset();
            SourceMoveNext();
            return new UniTask<bool>(this, completionSource.Version);
        }

        /// <summary>
        /// Advances the source. Synchronously completed steps loop in place; an incomplete source step
        /// registers MoveNextCallBack, and an incomplete transform registers SetCurrentCallBack.
        /// </summary>
        protected void SourceMoveNext()
        {
            CONTINUE:
            sourceMoveNext = enumerator.MoveNextAsync().GetAwaiter();
            if (sourceMoveNext.IsCompleted)
            {
                bool result = false;
                try
                {
                    (bool waitCallback, bool requireNextIteration) = TryMoveNextCore(sourceMoveNext.GetResult(), out result);

                    if (waitCallback)
                    {
                        return;
                    }

                    if (requireNextIteration)
                    {
                        goto CONTINUE;
                    }
                    else
                    {
                        // Observe cancellation here too, as SetCurrentCallBack and AsyncEnumeratorBase do.
                        SetResultOrCanceled(result);
                    }
                }
                catch (Exception ex)
                {
                    completionSource.TrySetException(ex);
                    return;
                }
            }
            else
            {
                sourceMoveNext.SourceOnCompleted(moveNextCallbackDelegate, this);
            }
        }

        // Runs one step for the source element that just became available (if any).
        (bool waitCallback, bool requireNextIteration) TryMoveNextCore(bool sourceHasCurrent, out bool result)
        {
            if (sourceHasCurrent)
            {
                SourceCurrent = enumerator.Current;
                var task = TransformAsync(SourceCurrent);
                if (UnwarapTask(task, out var taskResult))
                {
                    var currentResult = TrySetCurrentCore(taskResult, out var terminateIteration);
                    if (terminateIteration)
                    {
                        return IterateFinished(out result);
                    }

                    return ActionCompleted(currentResult, out result);
                }
                else
                {
                    return WaitAwaitCallback(out result);
                }
            }

            return IterateFinished(out result);
        }

        /// <summary>
        /// Returns true with the value if <paramref name="taskResult"/> is already complete; otherwise
        /// registers SetCurrentCallBack as its continuation and returns false.
        /// (Name misspelling kept: derived classes may call this protected member.)
        /// </summary>
        protected bool UnwarapTask(UniTask<TAwait> taskResult, out TAwait result)
        {
            resultAwaiter = taskResult.GetAwaiter();

            if (resultAwaiter.IsCompleted)
            {
                result = resultAwaiter.GetResult();
                return true;
            }
            else
            {
                resultAwaiter.SourceOnCompleted(setCurrentCallbackDelegate, this);
                result = default;
                return false;
            }
        }

        // Publishes `result` to the pending MoveNextAsync, or cancels it if cancellation was requested.
        void SetResultOrCanceled(bool result)
        {
            if (cancellationToken.IsCancellationRequested)
            {
                completionSource.TrySetCanceled(cancellationToken);
            }
            else
            {
                completionSource.TrySetResult(result);
            }
        }

        // Continuation for an asynchronously completed source MoveNextAsync.
        static void MoveNextCallBack(object state)
        {
            var self = (AsyncEnumeratorAwaitSelectorBase<TSource, TResult, TAwait>)state;
            bool result = false;
            try
            {
                (bool waitCallback, bool requireNextIteration) = self.TryMoveNextCore(self.sourceMoveNext.GetResult(), out result);

                if (waitCallback)
                {
                    return;
                }

                if (requireNextIteration)
                {
                    self.SourceMoveNext();
                    return;
                }
                else
                {
                    self.SetResultOrCanceled(result);
                }
            }
            catch (Exception ex)
            {
                self.completionSource.TrySetException(ex);
                return;
            }
        }

        // Continuation for an asynchronously completed TransformAsync result.
        static void SetCurrentCallBack(object state)
        {
            var self = (AsyncEnumeratorAwaitSelectorBase<TSource, TResult, TAwait>)state;

            bool doneSetCurrent;
            bool terminateIteration;
            try
            {
                var result = self.resultAwaiter.GetResult();
                doneSetCurrent = self.TrySetCurrentCore(result, out terminateIteration);
            }
            catch (Exception ex)
            {
                self.completionSource.TrySetException(ex);
                return;
            }

            if (self.cancellationToken.IsCancellationRequested)
            {
                self.completionSource.TrySetCanceled(self.cancellationToken);
            }
            else if (terminateIteration)
            {
                // Same precedence as the synchronous path in TryMoveNextCore: termination wins.
                self.completionSource.TrySetResult(false);
            }
            else if (doneSetCurrent)
            {
                self.completionSource.TrySetResult(true);
            }
            else
            {
                try
                {
                    // The source's MoveNextAsync may throw synchronously; without this catch the
                    // exception escapes into the awaiter continuation and the pending call never completes.
                    self.SourceMoveNext();
                }
                catch (Exception ex)
                {
                    self.completionSource.TrySetException(ex);
                }
            }
        }

        // if require additional resource to dispose, override and call base.DisposeAsync.
        public virtual UniTask DisposeAsync()
        {
            TaskTracker.RemoveTracking(this);
            if (enumerator != null)
            {
                return enumerator.DisposeAsync();
            }
            return default;
        }
    }
}