using Cysharp.Threading.Tasks.Internal;
using System;
using System.Threading;

// NOTE(review): the generic type-argument lists in this file were reconstructed from how each
// member is used (selector call sites, awaiter result types, constructor arguments).
// Confirm them against the callers of these operators.

namespace Cysharp.Threading.Tasks.Linq
{
    public static partial class UniTaskAsyncEnumerable
    {
        /// <summary>
        /// Projects each element of <paramref name="source"/> to an async sequence and flattens
        /// the resulting sequences into one sequence.
        /// </summary>
        /// <param name="source">Sequence whose elements are projected.</param>
        /// <param name="selector">Projection producing the inner sequence for one source element.</param>
        /// <returns>A lazily evaluated sequence of the inner sequences' elements.</returns>
        /// <remarks>Both arguments are validated with <c>Error.ThrowArgumentNullException</c>.</remarks>
        public static IUniTaskAsyncEnumerable<TResult> SelectMany<TSource, TResult>(this IUniTaskAsyncEnumerable<TSource> source, Func<TSource, IUniTaskAsyncEnumerable<TResult>> selector)
        {
            Error.ThrowArgumentNullException(source, nameof(source));
            Error.ThrowArgumentNullException(selector, nameof(selector));

            // The inner element is the result, so the result selector simply returns it.
            return new SelectMany<TSource, TResult, TResult>(source, selector, (x, y) => y);
        }

        /// <summary>
        /// Same as the single-argument selector overload, but <paramref name="selector"/> also
        /// receives the zero-based index of the source element.
        /// </summary>
        /// <param name="source">Sequence whose elements are projected.</param>
        /// <param name="selector">Projection receiving the source element and its index.</param>
        /// <returns>A lazily evaluated sequence of the inner sequences' elements.</returns>
        public static IUniTaskAsyncEnumerable<TResult> SelectMany<TSource, TResult>(this IUniTaskAsyncEnumerable<TSource> source, Func<TSource, int, IUniTaskAsyncEnumerable<TResult>> selector)
        {
            Error.ThrowArgumentNullException(source, nameof(source));
            Error.ThrowArgumentNullException(selector, nameof(selector));

            return new SelectMany<TSource, TResult, TResult>(source, selector, (x, y) => y);
        }

        /// <summary>
        /// Projects each source element to an inner sequence, then maps every
        /// (source element, inner element) pair through <paramref name="resultSelector"/>.
        /// </summary>
        /// <param name="source">Sequence whose elements are projected.</param>
        /// <param name="collectionSelector">Projection producing the inner sequence for one source element.</param>
        /// <param name="resultSelector">Maps a source element and one of its inner elements to a result.</param>
        /// <returns>A lazily evaluated sequence of the mapped results.</returns>
        public static IUniTaskAsyncEnumerable<TResult> SelectMany<TSource, TCollection, TResult>(this IUniTaskAsyncEnumerable<TSource> source, Func<TSource, IUniTaskAsyncEnumerable<TCollection>> collectionSelector, Func<TSource, TCollection, TResult> resultSelector)
        {
            Error.ThrowArgumentNullException(source, nameof(source));
            Error.ThrowArgumentNullException(collectionSelector, nameof(collectionSelector));
            // Validate eagerly; otherwise a null selector only fails once enumeration reaches the first inner element.
            Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector));

            return new SelectMany<TSource, TCollection, TResult>(source, collectionSelector, resultSelector);
        }

        /// <summary>
        /// Index-aware variant of the <paramref name="collectionSelector"/> /
        /// <paramref name="resultSelector"/> overload.
        /// </summary>
        /// <param name="source">Sequence whose elements are projected.</param>
        /// <param name="collectionSelector">Projection receiving the source element and its zero-based index.</param>
        /// <param name="resultSelector">Maps a source element and one of its inner elements to a result.</param>
        /// <returns>A lazily evaluated sequence of the mapped results.</returns>
        public static IUniTaskAsyncEnumerable<TResult> SelectMany<TSource, TCollection, TResult>(this IUniTaskAsyncEnumerable<TSource> source, Func<TSource, int, IUniTaskAsyncEnumerable<TCollection>> collectionSelector, Func<TSource, TCollection, TResult> resultSelector)
        {
            Error.ThrowArgumentNullException(source, nameof(source));
            Error.ThrowArgumentNullException(collectionSelector, nameof(collectionSelector));
            Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector));

            return new SelectMany<TSource, TCollection, TResult>(source, collectionSelector, resultSelector);
        }

        /// <summary>
        /// Like <c>SelectMany</c>, but <paramref name="selector"/> returns a <c>UniTask</c> that is
        /// awaited to obtain the inner sequence.
        /// </summary>
        /// <param name="source">Sequence whose elements are projected.</param>
        /// <param name="selector">Asynchronous projection producing the inner sequence.</param>
        /// <returns>A lazily evaluated sequence of the inner sequences' elements.</returns>
        public static IUniTaskAsyncEnumerable<TResult> SelectManyAwait<TSource, TResult>(this IUniTaskAsyncEnumerable<TSource> source, Func<TSource, UniTask<IUniTaskAsyncEnumerable<TResult>>> selector)
        {
            Error.ThrowArgumentNullException(source, nameof(source));
            Error.ThrowArgumentNullException(selector, nameof(selector));

            return new SelectManyAwait<TSource, TResult, TResult>(source, selector, (x, y) => UniTask.FromResult(y));
        }

        /// <summary>
        /// Index-aware variant of <c>SelectManyAwait</c>: <paramref name="selector"/> also receives
        /// the zero-based index of the source element.
        /// </summary>
        /// <param name="source">Sequence whose elements are projected.</param>
        /// <param name="selector">Asynchronous projection receiving the source element and its index.</param>
        /// <returns>A lazily evaluated sequence of the inner sequences' elements.</returns>
        public static IUniTaskAsyncEnumerable<TResult> SelectManyAwait<TSource, TResult>(this IUniTaskAsyncEnumerable<TSource> source, Func<TSource, int, UniTask<IUniTaskAsyncEnumerable<TResult>>> selector)
        {
            Error.ThrowArgumentNullException(source, nameof(source));
            Error.ThrowArgumentNullException(selector, nameof(selector));

            return new SelectManyAwait<TSource, TResult, TResult>(source, selector, (x, y) => UniTask.FromResult(y));
        }

        /// <summary>
        /// <c>SelectManyAwait</c> with an asynchronous <paramref name="resultSelector"/> applied to
        /// every (source element, inner element) pair.
        /// </summary>
        /// <param name="source">Sequence whose elements are projected.</param>
        /// <param name="collectionSelector">Asynchronous projection producing the inner sequence.</param>
        /// <param name="resultSelector">Asynchronous mapping from a source/inner element pair to a result.</param>
        /// <returns>A lazily evaluated sequence of the mapped results.</returns>
        public static IUniTaskAsyncEnumerable<TResult> SelectManyAwait<TSource, TCollection, TResult>(this IUniTaskAsyncEnumerable<TSource> source, Func<TSource, UniTask<IUniTaskAsyncEnumerable<TCollection>>> collectionSelector, Func<TSource, TCollection, UniTask<TResult>> resultSelector)
        {
            Error.ThrowArgumentNullException(source, nameof(source));
            Error.ThrowArgumentNullException(collectionSelector, nameof(collectionSelector));
            Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector));

            return new SelectManyAwait<TSource, TCollection, TResult>(source, collectionSelector, resultSelector);
        }

        /// <summary>
        /// Index-aware variant of the <c>SelectManyAwait</c> overload taking a
        /// <paramref name="resultSelector"/>.
        /// </summary>
        /// <param name="source">Sequence whose elements are projected.</param>
        /// <param name="collectionSelector">Asynchronous projection receiving the source element and its index.</param>
        /// <param name="resultSelector">Asynchronous mapping from a source/inner element pair to a result.</param>
        /// <returns>A lazily evaluated sequence of the mapped results.</returns>
        public static IUniTaskAsyncEnumerable<TResult> SelectManyAwait<TSource, TCollection, TResult>(this IUniTaskAsyncEnumerable<TSource> source, Func<TSource, int, UniTask<IUniTaskAsyncEnumerable<TCollection>>> collectionSelector, Func<TSource, TCollection, UniTask<TResult>> resultSelector)
        {
            Error.ThrowArgumentNullException(source, nameof(source));
            Error.ThrowArgumentNullException(collectionSelector, nameof(collectionSelector));
            Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector));

            return new SelectManyAwait<TSource, TCollection, TResult>(source, collectionSelector, resultSelector);
        }

        /// <summary>
        /// Like <c>SelectManyAwait</c>, but the enumeration's <see cref="CancellationToken"/> is
        /// also passed to <paramref name="selector"/>.
        /// </summary>
        /// <param name="source">Sequence whose elements are projected.</param>
        /// <param name="selector">Asynchronous, cancellable projection producing the inner sequence.</param>
        /// <returns>A lazily evaluated sequence of the inner sequences' elements.</returns>
        public static IUniTaskAsyncEnumerable<TResult> SelectManyAwaitWithCancellation<TSource, TResult>(this IUniTaskAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, UniTask<IUniTaskAsyncEnumerable<TResult>>> selector)
        {
            Error.ThrowArgumentNullException(source, nameof(source));
            Error.ThrowArgumentNullException(selector, nameof(selector));

            return new SelectManyAwaitWithCancellation<TSource, TResult, TResult>(source, selector, (x, y, c) => UniTask.FromResult(y));
        }

        /// <summary>
        /// Index-aware variant of <c>SelectManyAwaitWithCancellation</c>.
        /// </summary>
        /// <param name="source">Sequence whose elements are projected.</param>
        /// <param name="selector">Asynchronous projection receiving the element, its index and the token.</param>
        /// <returns>A lazily evaluated sequence of the inner sequences' elements.</returns>
        public static IUniTaskAsyncEnumerable<TResult> SelectManyAwaitWithCancellation<TSource, TResult>(this IUniTaskAsyncEnumerable<TSource> source, Func<TSource, int, CancellationToken, UniTask<IUniTaskAsyncEnumerable<TResult>>> selector)
        {
            Error.ThrowArgumentNullException(source, nameof(source));
            Error.ThrowArgumentNullException(selector, nameof(selector));

            return new SelectManyAwaitWithCancellation<TSource, TResult, TResult>(source, selector, (x, y, c) => UniTask.FromResult(y));
        }

        /// <summary>
        /// <c>SelectManyAwaitWithCancellation</c> with an asynchronous, cancellable
        /// <paramref name="resultSelector"/>.
        /// </summary>
        /// <param name="source">Sequence whose elements are projected.</param>
        /// <param name="collectionSelector">Asynchronous, cancellable projection producing the inner sequence.</param>
        /// <param name="resultSelector">Asynchronous, cancellable mapping from a source/inner pair to a result.</param>
        /// <returns>A lazily evaluated sequence of the mapped results.</returns>
        public static IUniTaskAsyncEnumerable<TResult> SelectManyAwaitWithCancellation<TSource, TCollection, TResult>(this IUniTaskAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, UniTask<IUniTaskAsyncEnumerable<TCollection>>> collectionSelector, Func<TSource, TCollection, CancellationToken, UniTask<TResult>> resultSelector)
        {
            Error.ThrowArgumentNullException(source, nameof(source));
            Error.ThrowArgumentNullException(collectionSelector, nameof(collectionSelector));
            Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector));

            return new SelectManyAwaitWithCancellation<TSource, TCollection, TResult>(source, collectionSelector, resultSelector);
        }

        /// <summary>
        /// Index-aware variant of the <c>SelectManyAwaitWithCancellation</c> overload taking a
        /// <paramref name="resultSelector"/>.
        /// </summary>
        /// <param name="source">Sequence whose elements are projected.</param>
        /// <param name="collectionSelector">Asynchronous projection receiving the element, its index and the token.</param>
        /// <param name="resultSelector">Asynchronous, cancellable mapping from a source/inner pair to a result.</param>
        /// <returns>A lazily evaluated sequence of the mapped results.</returns>
        public static IUniTaskAsyncEnumerable<TResult> SelectManyAwaitWithCancellation<TSource, TCollection, TResult>(this IUniTaskAsyncEnumerable<TSource> source, Func<TSource, int, CancellationToken, UniTask<IUniTaskAsyncEnumerable<TCollection>>> collectionSelector, Func<TSource, TCollection, CancellationToken, UniTask<TResult>> resultSelector)
        {
            Error.ThrowArgumentNullException(source, nameof(source));
            Error.ThrowArgumentNullException(collectionSelector, nameof(collectionSelector));
            Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector));

            return new SelectManyAwaitWithCancellation<TSource, TCollection, TResult>(source, collectionSelector, resultSelector);
        }
    }

    /// <summary>
    /// Sequence returned by the synchronous-selector <c>SelectMany</c> overloads. It only stores
    /// its arguments; all work happens in the enumerator created by <see cref="GetAsyncEnumerator"/>.
    /// </summary>
    internal sealed class SelectMany<TSource, TCollection, TResult> : IUniTaskAsyncEnumerable<TResult>
    {
        readonly IUniTaskAsyncEnumerable<TSource> source;
        // Exactly one of selector1 / selector2 is non-null, depending on the constructor used.
        readonly Func<TSource, IUniTaskAsyncEnumerable<TCollection>> selector1;
        readonly Func<TSource, int, IUniTaskAsyncEnumerable<TCollection>> selector2;
        readonly Func<TSource, TCollection, TResult> resultSelector;

        public SelectMany(IUniTaskAsyncEnumerable<TSource> source, Func<TSource, IUniTaskAsyncEnumerable<TCollection>> selector, Func<TSource, TCollection, TResult> resultSelector)
        {
            this.source = source;
            this.selector1 = selector;
            this.selector2 = null;
            this.resultSelector = resultSelector;
        }

        public SelectMany(IUniTaskAsyncEnumerable<TSource> source, Func<TSource, int, IUniTaskAsyncEnumerable<TCollection>> selector, Func<TSource, TCollection, TResult> resultSelector)
        {
            this.source = source;
            this.selector1 = null;
            this.selector2 = selector;
            this.resultSelector = resultSelector;
        }

        public IUniTaskAsyncEnumerator<TResult> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        {
            return new _SelectMany(source, selector1, selector2, resultSelector, cancellationToken);
        }

        /// <summary>
        /// Callback-driven enumerator. Each <see cref="MoveNextAsync"/> either advances the current
        /// inner enumerator or, when there is none, advances the source and opens a new inner one.
        /// The outcome is reported through <c>completionSource</c>, which is inherited from
        /// <c>MoveNextSource</c> (declared elsewhere).
        /// </summary>
        sealed class _SelectMany : MoveNextSource, IUniTaskAsyncEnumerator<TResult>
        {
            // Cached delegates so registering a continuation does not allocate a new delegate each time.
            static readonly Action<object> sourceMoveNextCoreDelegate = SourceMoveNextCore;
            static readonly Action<object> selectedSourceMoveNextCoreDelegate = SelectedSourceMoveNextCore;
            static readonly Action<object> selectedEnumeratorDisposeAsyncCoreDelegate = SelectedEnumeratorDisposeAsyncCore;

            readonly IUniTaskAsyncEnumerable<TSource> source;
            readonly Func<TSource, IUniTaskAsyncEnumerable<TCollection>> selector1;
            readonly Func<TSource, int, IUniTaskAsyncEnumerable<TCollection>> selector2;
            readonly Func<TSource, TCollection, TResult> resultSelector;
            CancellationToken cancellationToken;

            TSource sourceCurrent;
            int sourceIndex;
            IUniTaskAsyncEnumerator<TSource> sourceEnumerator;
            IUniTaskAsyncEnumerator<TCollection> selectedEnumerator;
            UniTask<bool>.Awaiter sourceAwaiter;
            UniTask<bool>.Awaiter selectedAwaiter;
            UniTask.Awaiter selectedDisposeAsyncAwaiter;

            public _SelectMany(IUniTaskAsyncEnumerable<TSource> source, Func<TSource, IUniTaskAsyncEnumerable<TCollection>> selector1, Func<TSource, int, IUniTaskAsyncEnumerable<TCollection>> selector2, Func<TSource, TCollection, TResult> resultSelector, CancellationToken cancellationToken)
            {
                this.source = source;
                this.selector1 = selector1;
                this.selector2 = selector2;
                this.resultSelector = resultSelector;
                this.cancellationToken = cancellationToken;
                // Paired with TaskTracker.RemoveTracking in DisposeAsync; the meaning of the
                // second argument is defined by TaskTracker (not visible here).
                TaskTracker.TrackActiveTask(this, 3);
            }

            public TResult Current { get; private set; }

            public UniTask<bool> MoveNextAsync()
            {
                completionSource.Reset();

                if (selectedEnumerator != null)
                {
                    // An inner sequence is open: keep draining it.
                    MoveNextSelected();
                }
                else
                {
                    // No inner sequence: advance the source (created lazily on first call).
                    if (sourceEnumerator == null)
                    {
                        sourceEnumerator = source.GetAsyncEnumerator(cancellationToken);
                    }

                    MoveNextSource();
                }

                return new UniTask<bool>(this, completionSource.Version);
            }

            // Starts MoveNextAsync on the source and continues in SourceMoveNextCore.
            void MoveNextSource()
            {
                try
                {
                    sourceAwaiter = sourceEnumerator.MoveNextAsync().GetAwaiter();
                }
                catch (Exception ex)
                {
                    completionSource.TrySetException(ex);
                    return;
                }

                if (sourceAwaiter.IsCompleted)
                {
                    SourceMoveNextCore(this);
                }
                else
                {
                    sourceAwaiter.SourceOnCompleted(sourceMoveNextCoreDelegate, this);
                }
            }

            // Starts MoveNextAsync on the inner enumerator and continues in SelectedSourceMoveNextCore.
            void MoveNextSelected()
            {
                try
                {
                    selectedAwaiter = selectedEnumerator.MoveNextAsync().GetAwaiter();
                }
                catch (Exception ex)
                {
                    completionSource.TrySetException(ex);
                    return;
                }

                if (selectedAwaiter.IsCompleted)
                {
                    SelectedSourceMoveNextCore(this);
                }
                else
                {
                    selectedAwaiter.SourceOnCompleted(selectedSourceMoveNextCoreDelegate, this);
                }
            }

            // Continuation after the source advanced: open the inner sequence for the new
            // element, or complete with false when the source is exhausted.
            static void SourceMoveNextCore(object state)
            {
                var self = (_SelectMany)state;

                if (self.TryGetResult(self.sourceAwaiter, out var result))
                {
                    if (result)
                    {
                        try
                        {
                            self.sourceCurrent = self.sourceEnumerator.Current;
                            if (self.selector1 != null)
                            {
                                self.selectedEnumerator = self.selector1(self.sourceCurrent).GetAsyncEnumerator(self.cancellationToken);
                            }
                            else
                            {
                                // checked: overflow of the element index raises instead of wrapping.
                                self.selectedEnumerator = self.selector2(self.sourceCurrent, checked(self.sourceIndex++)).GetAsyncEnumerator(self.cancellationToken);
                            }
                        }
                        catch (Exception ex)
                        {
                            self.completionSource.TrySetException(ex);
                            return;
                        }

                        self.MoveNextSelected(); // iterate the newly selected inner sequence
                    }
                    else
                    {
                        self.completionSource.TrySetResult(false);
                    }
                }
            }

            // Continuation after the inner enumerator advanced: publish the mapped element, or
            // dispose the exhausted inner enumerator and go back to the source.
            static void SelectedSourceMoveNextCore(object state)
            {
                var self = (_SelectMany)state;

                if (self.TryGetResult(self.selectedAwaiter, out var result))
                {
                    if (result)
                    {
                        try
                        {
                            self.Current = self.resultSelector(self.sourceCurrent, self.selectedEnumerator.Current);
                        }
                        catch (Exception ex)
                        {
                            self.completionSource.TrySetException(ex);
                            return;
                        }

                        self.completionSource.TrySetResult(true);
                    }
                    else
                    {
                        // dispose selected source and try iterate source.
                        try
                        {
                            self.selectedDisposeAsyncAwaiter = self.selectedEnumerator.DisposeAsync().GetAwaiter();
                        }
                        catch (Exception ex)
                        {
                            self.completionSource.TrySetException(ex);
                            return;
                        }

                        if (self.selectedDisposeAsyncAwaiter.IsCompleted)
                        {
                            SelectedEnumeratorDisposeAsyncCore(self);
                        }
                        else
                        {
                            self.selectedDisposeAsyncAwaiter.SourceOnCompleted(selectedEnumeratorDisposeAsyncCoreDelegate, self);
                        }
                    }
                }
            }

            // Continuation after the inner enumerator was disposed: clear it and advance the source.
            static void SelectedEnumeratorDisposeAsyncCore(object state)
            {
                var self = (_SelectMany)state;

                if (self.TryGetResult(self.selectedDisposeAsyncAwaiter))
                {
                    self.selectedEnumerator = null;
                    self.selectedAwaiter = default;

                    self.MoveNextSource(); // iterate next source
                }
            }

            public async UniTask DisposeAsync()
            {
                TaskTracker.RemoveTracking(this);
                try
                {
                    if (selectedEnumerator != null)
                    {
                        await selectedEnumerator.DisposeAsync();
                    }
                }
                finally
                {
                    // Dispose the source even when disposing the inner enumerator throws.
                    if (sourceEnumerator != null)
                    {
                        await sourceEnumerator.DisposeAsync();
                    }
                }
            }
        }
    }

    /// <summary>
    /// Sequence returned by the <c>SelectManyAwait</c> overloads: both the collection selector and
    /// the result selector return a <c>UniTask</c> that the enumerator waits on.
    /// </summary>
    internal sealed class SelectManyAwait<TSource, TCollection, TResult> : IUniTaskAsyncEnumerable<TResult>
    {
        readonly IUniTaskAsyncEnumerable<TSource> source;
        // Exactly one of selector1 / selector2 is non-null, depending on the constructor used.
        readonly Func<TSource, UniTask<IUniTaskAsyncEnumerable<TCollection>>> selector1;
        readonly Func<TSource, int, UniTask<IUniTaskAsyncEnumerable<TCollection>>> selector2;
        readonly Func<TSource, TCollection, UniTask<TResult>> resultSelector;

        public SelectManyAwait(IUniTaskAsyncEnumerable<TSource> source, Func<TSource, UniTask<IUniTaskAsyncEnumerable<TCollection>>> selector, Func<TSource, TCollection, UniTask<TResult>> resultSelector)
        {
            this.source = source;
            this.selector1 = selector;
            this.selector2 = null;
            this.resultSelector = resultSelector;
        }

        public SelectManyAwait(IUniTaskAsyncEnumerable<TSource> source, Func<TSource, int, UniTask<IUniTaskAsyncEnumerable<TCollection>>> selector, Func<TSource, TCollection, UniTask<TResult>> resultSelector)
        {
            this.source = source;
            this.selector1 = null;
            this.selector2 = selector;
            this.resultSelector = resultSelector;
        }

        public IUniTaskAsyncEnumerator<TResult> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        {
            return new _SelectManyAwait(source, selector1, selector2, resultSelector, cancellationToken);
        }

        /// <summary>
        /// Callback-driven enumerator with two extra wait points compared with the synchronous
        /// variant: the collection selector's task and the result selector's task.
        /// </summary>
        sealed class _SelectManyAwait : MoveNextSource, IUniTaskAsyncEnumerator<TResult>
        {
            static readonly Action<object> sourceMoveNextCoreDelegate = SourceMoveNextCore;
            static readonly Action<object> selectedSourceMoveNextCoreDelegate = SelectedSourceMoveNextCore;
            static readonly Action<object> selectedEnumeratorDisposeAsyncCoreDelegate = SelectedEnumeratorDisposeAsyncCore;
            static readonly Action<object> selectorAwaitCoreDelegate = SelectorAwaitCore;
            static readonly Action<object> resultSelectorAwaitCoreDelegate = ResultSelectorAwaitCore;

            readonly IUniTaskAsyncEnumerable<TSource> source;
            readonly Func<TSource, UniTask<IUniTaskAsyncEnumerable<TCollection>>> selector1;
            readonly Func<TSource, int, UniTask<IUniTaskAsyncEnumerable<TCollection>>> selector2;
            readonly Func<TSource, TCollection, UniTask<TResult>> resultSelector;
            CancellationToken cancellationToken;

            TSource sourceCurrent;
            int sourceIndex;
            IUniTaskAsyncEnumerator<TSource> sourceEnumerator;
            IUniTaskAsyncEnumerator<TCollection> selectedEnumerator;
            UniTask<bool>.Awaiter sourceAwaiter;
            UniTask<bool>.Awaiter selectedAwaiter;
            UniTask.Awaiter selectedDisposeAsyncAwaiter;

            // await additional
            UniTask<IUniTaskAsyncEnumerable<TCollection>>.Awaiter collectionSelectorAwaiter;
            UniTask<TResult>.Awaiter resultSelectorAwaiter;

            public _SelectManyAwait(IUniTaskAsyncEnumerable<TSource> source, Func<TSource, UniTask<IUniTaskAsyncEnumerable<TCollection>>> selector1, Func<TSource, int, UniTask<IUniTaskAsyncEnumerable<TCollection>>> selector2, Func<TSource, TCollection, UniTask<TResult>> resultSelector, CancellationToken cancellationToken)
            {
                this.source = source;
                this.selector1 = selector1;
                this.selector2 = selector2;
                this.resultSelector = resultSelector;
                this.cancellationToken = cancellationToken;
                // Paired with TaskTracker.RemoveTracking in DisposeAsync.
                TaskTracker.TrackActiveTask(this, 3);
            }

            public TResult Current { get; private set; }

            public UniTask<bool> MoveNextAsync()
            {
                completionSource.Reset();

                if (selectedEnumerator != null)
                {
                    // An inner sequence is open: keep draining it.
                    MoveNextSelected();
                }
                else
                {
                    // No inner sequence: advance the source (created lazily on first call).
                    if (sourceEnumerator == null)
                    {
                        sourceEnumerator = source.GetAsyncEnumerator(cancellationToken);
                    }

                    MoveNextSource();
                }

                return new UniTask<bool>(this, completionSource.Version);
            }

            // Starts MoveNextAsync on the source and continues in SourceMoveNextCore.
            void MoveNextSource()
            {
                try
                {
                    sourceAwaiter = sourceEnumerator.MoveNextAsync().GetAwaiter();
                }
                catch (Exception ex)
                {
                    completionSource.TrySetException(ex);
                    return;
                }

                if (sourceAwaiter.IsCompleted)
                {
                    SourceMoveNextCore(this);
                }
                else
                {
                    sourceAwaiter.SourceOnCompleted(sourceMoveNextCoreDelegate, this);
                }
            }

            // Starts MoveNextAsync on the inner enumerator and continues in SelectedSourceMoveNextCore.
            void MoveNextSelected()
            {
                try
                {
                    selectedAwaiter = selectedEnumerator.MoveNextAsync().GetAwaiter();
                }
                catch (Exception ex)
                {
                    completionSource.TrySetException(ex);
                    return;
                }

                if (selectedAwaiter.IsCompleted)
                {
                    SelectedSourceMoveNextCore(this);
                }
                else
                {
                    selectedAwaiter.SourceOnCompleted(selectedSourceMoveNextCoreDelegate, this);
                }
            }

            // Continuation after the source advanced: invoke the collection selector and wait
            // for its task, or complete with false when the source is exhausted.
            static void SourceMoveNextCore(object state)
            {
                var self = (_SelectManyAwait)state;

                if (self.TryGetResult(self.sourceAwaiter, out var result))
                {
                    if (result)
                    {
                        try
                        {
                            self.sourceCurrent = self.sourceEnumerator.Current;

                            if (self.selector1 != null)
                            {
                                self.collectionSelectorAwaiter = self.selector1(self.sourceCurrent).GetAwaiter();
                            }
                            else
                            {
                                // checked: overflow of the element index raises instead of wrapping.
                                self.collectionSelectorAwaiter = self.selector2(self.sourceCurrent, checked(self.sourceIndex++)).GetAwaiter();
                            }

                            if (self.collectionSelectorAwaiter.IsCompleted)
                            {
                                SelectorAwaitCore(self);
                            }
                            else
                            {
                                self.collectionSelectorAwaiter.SourceOnCompleted(selectorAwaitCoreDelegate, self);
                            }
                        }
                        catch (Exception ex)
                        {
                            self.completionSource.TrySetException(ex);
                            return;
                        }
                    }
                    else
                    {
                        self.completionSource.TrySetResult(false);
                    }
                }
            }

            // Continuation after the inner enumerator advanced: invoke the result selector and
            // wait for it, or dispose the exhausted inner enumerator and go back to the source.
            static void SelectedSourceMoveNextCore(object state)
            {
                var self = (_SelectManyAwait)state;

                if (self.TryGetResult(self.selectedAwaiter, out var result))
                {
                    if (result)
                    {
                        try
                        {
                            self.resultSelectorAwaiter = self.resultSelector(self.sourceCurrent, self.selectedEnumerator.Current).GetAwaiter();
                            if (self.resultSelectorAwaiter.IsCompleted)
                            {
                                ResultSelectorAwaitCore(self);
                            }
                            else
                            {
                                self.resultSelectorAwaiter.SourceOnCompleted(resultSelectorAwaitCoreDelegate, self);
                            }
                        }
                        catch (Exception ex)
                        {
                            self.completionSource.TrySetException(ex);
                            return;
                        }
                    }
                    else
                    {
                        // dispose selected source and try iterate source.
                        try
                        {
                            self.selectedDisposeAsyncAwaiter = self.selectedEnumerator.DisposeAsync().GetAwaiter();
                        }
                        catch (Exception ex)
                        {
                            self.completionSource.TrySetException(ex);
                            return;
                        }

                        if (self.selectedDisposeAsyncAwaiter.IsCompleted)
                        {
                            SelectedEnumeratorDisposeAsyncCore(self);
                        }
                        else
                        {
                            self.selectedDisposeAsyncAwaiter.SourceOnCompleted(selectedEnumeratorDisposeAsyncCoreDelegate, self);
                        }
                    }
                }
            }

            // Continuation after the inner enumerator was disposed: clear it and advance the source.
            static void SelectedEnumeratorDisposeAsyncCore(object state)
            {
                var self = (_SelectManyAwait)state;

                if (self.TryGetResult(self.selectedDisposeAsyncAwaiter))
                {
                    self.selectedEnumerator = null;
                    self.selectedAwaiter = default;

                    self.MoveNextSource(); // iterate next source
                }
            }

            // Continuation after the collection selector's task finished: open the inner enumerator.
            static void SelectorAwaitCore(object state)
            {
                var self = (_SelectManyAwait)state;

                if (self.TryGetResult(self.collectionSelectorAwaiter, out var result))
                {
                    // When this runs as a registered continuation there is no caller-side
                    // try/catch, so a failure here (e.g. the selector produced null) must be
                    // reported through the completion source or MoveNextAsync never completes.
                    try
                    {
                        self.selectedEnumerator = result.GetAsyncEnumerator(self.cancellationToken);
                    }
                    catch (Exception ex)
                    {
                        self.completionSource.TrySetException(ex);
                        return;
                    }

                    self.MoveNextSelected(); // iterate the newly selected inner sequence
                }
            }

            // Continuation after the result selector's task finished: publish the element.
            static void ResultSelectorAwaitCore(object state)
            {
                var self = (_SelectManyAwait)state;

                if (self.TryGetResult(self.resultSelectorAwaiter, out var result))
                {
                    self.Current = result;
                    self.completionSource.TrySetResult(true);
                }
            }

            public async UniTask DisposeAsync()
            {
                TaskTracker.RemoveTracking(this);
                try
                {
                    if (selectedEnumerator != null)
                    {
                        await selectedEnumerator.DisposeAsync();
                    }
                }
                finally
                {
                    // Dispose the source even when disposing the inner enumerator throws.
                    if (sourceEnumerator != null)
                    {
                        await sourceEnumerator.DisposeAsync();
                    }
                }
            }
        }
    }

    /// <summary>
    /// Sequence returned by the <c>SelectManyAwaitWithCancellation</c> overloads: identical to
    /// <c>SelectManyAwait</c> except that the enumeration's <see cref="CancellationToken"/> is
    /// passed to both selectors.
    /// </summary>
    internal sealed class SelectManyAwaitWithCancellation<TSource, TCollection, TResult> : IUniTaskAsyncEnumerable<TResult>
    {
        readonly IUniTaskAsyncEnumerable<TSource> source;
        // Exactly one of selector1 / selector2 is non-null, depending on the constructor used.
        readonly Func<TSource, CancellationToken, UniTask<IUniTaskAsyncEnumerable<TCollection>>> selector1;
        readonly Func<TSource, int, CancellationToken, UniTask<IUniTaskAsyncEnumerable<TCollection>>> selector2;
        readonly Func<TSource, TCollection, CancellationToken, UniTask<TResult>> resultSelector;

        public SelectManyAwaitWithCancellation(IUniTaskAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, UniTask<IUniTaskAsyncEnumerable<TCollection>>> selector, Func<TSource, TCollection, CancellationToken, UniTask<TResult>> resultSelector)
        {
            this.source = source;
            this.selector1 = selector;
            this.selector2 = null;
            this.resultSelector = resultSelector;
        }

        public SelectManyAwaitWithCancellation(IUniTaskAsyncEnumerable<TSource> source, Func<TSource, int, CancellationToken, UniTask<IUniTaskAsyncEnumerable<TCollection>>> selector, Func<TSource, TCollection, CancellationToken, UniTask<TResult>> resultSelector)
        {
            this.source = source;
            this.selector1 = null;
            this.selector2 = selector;
            this.resultSelector = resultSelector;
        }

        public IUniTaskAsyncEnumerator<TResult> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        {
            return new _SelectManyAwaitWithCancellation(source, selector1, selector2, resultSelector, cancellationToken);
        }

        /// <summary>
        /// Callback-driven enumerator; same flow as the <c>SelectManyAwait</c> enumerator with
        /// the stored <see cref="CancellationToken"/> forwarded to every selector call.
        /// </summary>
        sealed class _SelectManyAwaitWithCancellation : MoveNextSource, IUniTaskAsyncEnumerator<TResult>
        {
            static readonly Action<object> sourceMoveNextCoreDelegate = SourceMoveNextCore;
            static readonly Action<object> selectedSourceMoveNextCoreDelegate = SelectedSourceMoveNextCore;
            static readonly Action<object> selectedEnumeratorDisposeAsyncCoreDelegate = SelectedEnumeratorDisposeAsyncCore;
            static readonly Action<object> selectorAwaitCoreDelegate = SelectorAwaitCore;
            static readonly Action<object> resultSelectorAwaitCoreDelegate = ResultSelectorAwaitCore;

            readonly IUniTaskAsyncEnumerable<TSource> source;
            readonly Func<TSource, CancellationToken, UniTask<IUniTaskAsyncEnumerable<TCollection>>> selector1;
            readonly Func<TSource, int, CancellationToken, UniTask<IUniTaskAsyncEnumerable<TCollection>>> selector2;
            readonly Func<TSource, TCollection, CancellationToken, UniTask<TResult>> resultSelector;
            CancellationToken cancellationToken;

            TSource sourceCurrent;
            int sourceIndex;
            IUniTaskAsyncEnumerator<TSource> sourceEnumerator;
            IUniTaskAsyncEnumerator<TCollection> selectedEnumerator;
            UniTask<bool>.Awaiter sourceAwaiter;
            UniTask<bool>.Awaiter selectedAwaiter;
            UniTask.Awaiter selectedDisposeAsyncAwaiter;

            // await additional
            UniTask<IUniTaskAsyncEnumerable<TCollection>>.Awaiter collectionSelectorAwaiter;
            UniTask<TResult>.Awaiter resultSelectorAwaiter;

            public _SelectManyAwaitWithCancellation(IUniTaskAsyncEnumerable<TSource> source, Func<TSource, CancellationToken, UniTask<IUniTaskAsyncEnumerable<TCollection>>> selector1, Func<TSource, int, CancellationToken, UniTask<IUniTaskAsyncEnumerable<TCollection>>> selector2, Func<TSource, TCollection, CancellationToken, UniTask<TResult>> resultSelector, CancellationToken cancellationToken)
            {
                this.source = source;
                this.selector1 = selector1;
                this.selector2 = selector2;
                this.resultSelector = resultSelector;
                this.cancellationToken = cancellationToken;
                // Paired with TaskTracker.RemoveTracking in DisposeAsync.
                TaskTracker.TrackActiveTask(this, 3);
            }

            public TResult Current { get; private set; }

            public UniTask<bool> MoveNextAsync()
            {
                completionSource.Reset();

                if (selectedEnumerator != null)
                {
                    // An inner sequence is open: keep draining it.
                    MoveNextSelected();
                }
                else
                {
                    // No inner sequence: advance the source (created lazily on first call).
                    if (sourceEnumerator == null)
                    {
                        sourceEnumerator = source.GetAsyncEnumerator(cancellationToken);
                    }

                    MoveNextSource();
                }

                return new UniTask<bool>(this, completionSource.Version);
            }

            // Starts MoveNextAsync on the source and continues in SourceMoveNextCore.
            void MoveNextSource()
            {
                try
                {
                    sourceAwaiter = sourceEnumerator.MoveNextAsync().GetAwaiter();
                }
                catch (Exception ex)
                {
                    completionSource.TrySetException(ex);
                    return;
                }

                if (sourceAwaiter.IsCompleted)
                {
                    SourceMoveNextCore(this);
                }
                else
                {
                    sourceAwaiter.SourceOnCompleted(sourceMoveNextCoreDelegate, this);
                }
            }

            // Starts MoveNextAsync on the inner enumerator and continues in SelectedSourceMoveNextCore.
            void MoveNextSelected()
            {
                try
                {
                    selectedAwaiter = selectedEnumerator.MoveNextAsync().GetAwaiter();
                }
                catch (Exception ex)
                {
                    completionSource.TrySetException(ex);
                    return;
                }

                if (selectedAwaiter.IsCompleted)
                {
                    SelectedSourceMoveNextCore(this);
                }
                else
                {
                    selectedAwaiter.SourceOnCompleted(selectedSourceMoveNextCoreDelegate, this);
                }
            }

            // Continuation after the source advanced: invoke the collection selector (with the
            // token) and wait for its task, or complete with false when the source is exhausted.
            static void SourceMoveNextCore(object state)
            {
                var self = (_SelectManyAwaitWithCancellation)state;

                if (self.TryGetResult(self.sourceAwaiter, out var result))
                {
                    if (result)
                    {
                        try
                        {
                            self.sourceCurrent = self.sourceEnumerator.Current;

                            if (self.selector1 != null)
                            {
                                self.collectionSelectorAwaiter = self.selector1(self.sourceCurrent, self.cancellationToken).GetAwaiter();
                            }
                            else
                            {
                                // checked: overflow of the element index raises instead of wrapping.
                                self.collectionSelectorAwaiter = self.selector2(self.sourceCurrent, checked(self.sourceIndex++), self.cancellationToken).GetAwaiter();
                            }

                            if (self.collectionSelectorAwaiter.IsCompleted)
                            {
                                SelectorAwaitCore(self);
                            }
                            else
                            {
                                self.collectionSelectorAwaiter.SourceOnCompleted(selectorAwaitCoreDelegate, self);
                            }
                        }
                        catch (Exception ex)
                        {
                            self.completionSource.TrySetException(ex);
                            return;
                        }
                    }
                    else
                    {
                        self.completionSource.TrySetResult(false);
                    }
                }
            }

            // Continuation after the inner enumerator advanced: invoke the result selector (with
            // the token) and wait for it, or dispose the exhausted inner enumerator.
            static void SelectedSourceMoveNextCore(object state)
            {
                var self = (_SelectManyAwaitWithCancellation)state;

                if (self.TryGetResult(self.selectedAwaiter, out var result))
                {
                    if (result)
                    {
                        try
                        {
                            self.resultSelectorAwaiter = self.resultSelector(self.sourceCurrent, self.selectedEnumerator.Current, self.cancellationToken).GetAwaiter();
                            if (self.resultSelectorAwaiter.IsCompleted)
                            {
                                ResultSelectorAwaitCore(self);
                            }
                            else
                            {
                                self.resultSelectorAwaiter.SourceOnCompleted(resultSelectorAwaitCoreDelegate, self);
                            }
                        }
                        catch (Exception ex)
                        {
                            self.completionSource.TrySetException(ex);
                            return;
                        }
                    }
                    else
                    {
                        // dispose selected source and try iterate source.
                        try
                        {
                            self.selectedDisposeAsyncAwaiter = self.selectedEnumerator.DisposeAsync().GetAwaiter();
                        }
                        catch (Exception ex)
                        {
                            self.completionSource.TrySetException(ex);
                            return;
                        }

                        if (self.selectedDisposeAsyncAwaiter.IsCompleted)
                        {
                            SelectedEnumeratorDisposeAsyncCore(self);
                        }
                        else
                        {
                            self.selectedDisposeAsyncAwaiter.SourceOnCompleted(selectedEnumeratorDisposeAsyncCoreDelegate, self);
                        }
                    }
                }
            }

            // Continuation after the inner enumerator was disposed: clear it and advance the source.
            static void SelectedEnumeratorDisposeAsyncCore(object state)
            {
                var self = (_SelectManyAwaitWithCancellation)state;

                if (self.TryGetResult(self.selectedDisposeAsyncAwaiter))
                {
                    self.selectedEnumerator = null;
                    self.selectedAwaiter = default;

                    self.MoveNextSource(); // iterate next source
                }
            }

            // Continuation after the collection selector's task finished: open the inner enumerator.
            static void SelectorAwaitCore(object state)
            {
                var self = (_SelectManyAwaitWithCancellation)state;

                if (self.TryGetResult(self.collectionSelectorAwaiter, out var result))
                {
                    // When this runs as a registered continuation there is no caller-side
                    // try/catch, so a failure here (e.g. the selector produced null) must be
                    // reported through the completion source or MoveNextAsync never completes.
                    try
                    {
                        self.selectedEnumerator = result.GetAsyncEnumerator(self.cancellationToken);
                    }
                    catch (Exception ex)
                    {
                        self.completionSource.TrySetException(ex);
                        return;
                    }

                    self.MoveNextSelected(); // iterate the newly selected inner sequence
                }
            }

            // Continuation after the result selector's task finished: publish the element.
            static void ResultSelectorAwaitCore(object state)
            {
                var self = (_SelectManyAwaitWithCancellation)state;

                if (self.TryGetResult(self.resultSelectorAwaiter, out var result))
                {
                    self.Current = result;
                    self.completionSource.TrySetResult(true);
                }
            }

            public async UniTask DisposeAsync()
            {
                TaskTracker.RemoveTracking(this);
                try
                {
                    if (selectedEnumerator != null)
                    {
                        await selectedEnumerator.DisposeAsync();
                    }
                }
                finally
                {
                    // Dispose the source even when disposing the inner enumerator throws.
                    if (sourceEnumerator != null)
                    {
                        await sourceEnumerator.DisposeAsync();
                    }
                }
            }
        }
    }
}