using Cysharp.Threading.Tasks.Internal;
using System;
using System.Threading;

namespace Cysharp.Threading.Tasks.Linq
{
    public static partial class UniTaskAsyncEnumerable
    {
        /// <summary>
        /// Pairs the elements of two async sequences position by position and yields them as
        /// <c>(First, Second)</c> tuples. Delegates to the selector-based overload.
        /// </summary>
        /// <param name="first">The sequence supplying the first tuple element.</param>
        /// <param name="second">The sequence supplying the second tuple element.</param>
        /// <remarks>Null arguments are rejected through <c>Error.ThrowArgumentNullException</c>.</remarks>
        public static IUniTaskAsyncEnumerable<(TFirst First, TSecond Second)> Zip<TFirst, TSecond>(this IUniTaskAsyncEnumerable<TFirst> first, IUniTaskAsyncEnumerable<TSecond> second)
        {
            Error.ThrowArgumentNullException(first, nameof(first));
            Error.ThrowArgumentNullException(second, nameof(second));

            return Zip(first, second, (x, y) => (x, y));
        }

        /// <summary>
        /// Combines the elements of two async sequences position by position using a synchronous selector.
        /// </summary>
        /// <param name="first">The sequence that is advanced first on every step.</param>
        /// <param name="second">The sequence that is advanced only after <paramref name="first"/> produced an element.</param>
        /// <param name="resultSelector">Maps one element from each sequence to the yielded value.</param>
        /// <remarks>Null arguments are rejected through <c>Error.ThrowArgumentNullException</c>.</remarks>
        public static IUniTaskAsyncEnumerable<TResult> Zip<TFirst, TSecond, TResult>(this IUniTaskAsyncEnumerable<TFirst> first, IUniTaskAsyncEnumerable<TSecond> second, Func<TFirst, TSecond, TResult> resultSelector)
        {
            Error.ThrowArgumentNullException(first, nameof(first));
            Error.ThrowArgumentNullException(second, nameof(second));
            Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector));

            return new Zip<TFirst, TSecond, TResult>(first, second, resultSelector);
        }

        /// <summary>
        /// Combines the elements of two async sequences position by position using an asynchronous selector
        /// whose returned task is awaited before the value is yielded.
        /// </summary>
        /// <param name="first">The sequence that is advanced first on every step.</param>
        /// <param name="second">The sequence that is advanced only after <paramref name="first"/> produced an element.</param>
        /// <param name="selector">Maps one element from each sequence to a task producing the yielded value.</param>
        /// <remarks>Null arguments are rejected through <c>Error.ThrowArgumentNullException</c>.</remarks>
        public static IUniTaskAsyncEnumerable<TResult> ZipAwait<TFirst, TSecond, TResult>(this IUniTaskAsyncEnumerable<TFirst> first, IUniTaskAsyncEnumerable<TSecond> second, Func<TFirst, TSecond, UniTask<TResult>> selector)
        {
            Error.ThrowArgumentNullException(first, nameof(first));
            Error.ThrowArgumentNullException(second, nameof(second));
            Error.ThrowArgumentNullException(selector, nameof(selector));

            return new ZipAwait<TFirst, TSecond, TResult>(first, second, selector);
        }

        /// <summary>
        /// Same as <c>ZipAwait</c>, but the selector additionally receives the enumeration's
        /// <see cref="CancellationToken"/>.
        /// </summary>
        /// <param name="first">The sequence that is advanced first on every step.</param>
        /// <param name="second">The sequence that is advanced only after <paramref name="first"/> produced an element.</param>
        /// <param name="selector">Maps one element from each sequence plus the token to a task producing the yielded value.</param>
        /// <remarks>Null arguments are rejected through <c>Error.ThrowArgumentNullException</c>.</remarks>
        public static IUniTaskAsyncEnumerable<TResult> ZipAwaitWithCancellation<TFirst, TSecond, TResult>(this IUniTaskAsyncEnumerable<TFirst> first, IUniTaskAsyncEnumerable<TSecond> second, Func<TFirst, TSecond, CancellationToken, UniTask<TResult>> selector)
        {
            Error.ThrowArgumentNullException(first, nameof(first));
            Error.ThrowArgumentNullException(second, nameof(second));
            Error.ThrowArgumentNullException(selector, nameof(selector));

            return new ZipAwaitWithCancellation<TFirst, TSecond, TResult>(first, second, selector);
        }
    }

    /// <summary>
    /// Async sequence that zips two sources with a synchronous result selector.
    /// Holds only the sources and the selector; enumeration state lives in the nested enumerator.
    /// </summary>
    internal sealed class Zip<TFirst, TSecond, TResult> : IUniTaskAsyncEnumerable<TResult>
    {
        readonly IUniTaskAsyncEnumerable<TFirst> first;
        readonly IUniTaskAsyncEnumerable<TSecond> second;
        readonly Func<TFirst, TSecond, TResult> resultSelector;

        public Zip(IUniTaskAsyncEnumerable<TFirst> first, IUniTaskAsyncEnumerable<TSecond> second, Func<TFirst, TSecond, TResult> resultSelector)
        {
            this.first = first;
            this.second = second;
            this.resultSelector = resultSelector;
        }

        /// <summary>Creates a fresh enumerator bound to <paramref name="cancellationToken"/>.</summary>
        public IUniTaskAsyncEnumerator<TResult> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        {
            return new _Zip(first, second, resultSelector, cancellationToken);
        }

        /// <summary>
        /// Callback-driven enumerator: each MoveNextAsync advances the first source, then the second,
        /// then applies the selector and completes the shared completion source.
        /// </summary>
        sealed class _Zip : MoveNextSource, IUniTaskAsyncEnumerator<TResult>
        {
            // Cached delegates so registering a continuation does not allocate per step.
            static readonly Action<object> firstMoveNextCoreDelegate = FirstMoveNextCore;
            static readonly Action<object> secondMoveNextCoreDelegate = SecondMoveNextCore;

            readonly IUniTaskAsyncEnumerable<TFirst> first;
            readonly IUniTaskAsyncEnumerable<TSecond> second;
            readonly Func<TFirst, TSecond, TResult> resultSelector;

            CancellationToken cancellationToken;

            // Created lazily on the first MoveNextAsync call.
            IUniTaskAsyncEnumerator<TFirst> firstEnumerator;
            IUniTaskAsyncEnumerator<TSecond> secondEnumerator;

            UniTask<bool>.Awaiter firstAwaiter;
            UniTask<bool>.Awaiter secondAwaiter;

            public _Zip(IUniTaskAsyncEnumerable<TFirst> first, IUniTaskAsyncEnumerable<TSecond> second, Func<TFirst, TSecond, TResult> resultSelector, CancellationToken cancellationToken)
            {
                this.first = first;
                this.second = second;
                this.resultSelector = resultSelector;
                this.cancellationToken = cancellationToken;
                TaskTracker.TrackActiveTask(this, 3);
            }

            /// <summary>The value produced by the most recent successful step.</summary>
            public TResult Current { get; private set; }

            /// <summary>
            /// Starts one zip step. The returned task yields <c>true</c> when both sources produced
            /// an element and <c>false</c> as soon as either source is exhausted.
            /// </summary>
            public UniTask<bool> MoveNextAsync()
            {
                completionSource.Reset();

                if (firstEnumerator == null)
                {
                    firstEnumerator = first.GetAsyncEnumerator(cancellationToken);
                    secondEnumerator = second.GetAsyncEnumerator(cancellationToken);
                }

                firstAwaiter = firstEnumerator.MoveNextAsync().GetAwaiter();

                if (firstAwaiter.IsCompleted)
                {
                    // Already finished: continue inline instead of registering a continuation.
                    FirstMoveNextCore(this);
                }
                else
                {
                    firstAwaiter.SourceOnCompleted(firstMoveNextCoreDelegate, this);
                }

                return new UniTask<bool>(this, completionSource.Version);
            }

            // Runs once the first source's MoveNextAsync has completed.
            static void FirstMoveNextCore(object state)
            {
                var self = (_Zip)state;

                // TryGetResult comes from MoveNextSource (not shown here); presumably it reports the
                // failure itself when it returns false, since nothing else is done on that path.
                if (self.TryGetResult(self.firstAwaiter, out var result))
                {
                    if (result)
                    {
                        try
                        {
                            self.secondAwaiter = self.secondEnumerator.MoveNextAsync().GetAwaiter();
                        }
                        catch (Exception ex)
                        {
                            self.completionSource.TrySetException(ex);
                            return;
                        }

                        if (self.secondAwaiter.IsCompleted)
                        {
                            SecondMoveNextCore(self);
                        }
                        else
                        {
                            self.secondAwaiter.SourceOnCompleted(secondMoveNextCoreDelegate, self);
                        }
                    }
                    else
                    {
                        // First source exhausted: the second source is not advanced.
                        self.completionSource.TrySetResult(false);
                    }
                }
            }

            // Runs once the second source's MoveNextAsync has completed.
            static void SecondMoveNextCore(object state)
            {
                var self = (_Zip)state;

                if (self.TryGetResult(self.secondAwaiter, out var result))
                {
                    if (result)
                    {
                        try
                        {
                            self.Current = self.resultSelector(self.firstEnumerator.Current, self.secondEnumerator.Current);
                        }
                        catch (Exception ex)
                        {
                            self.completionSource.TrySetException(ex);
                            // Stop here, as FirstMoveNextCore does, instead of falling through
                            // to a second completion attempt on an already-faulted step.
                            return;
                        }

                        if (self.cancellationToken.IsCancellationRequested)
                        {
                            self.completionSource.TrySetCanceled(self.cancellationToken);
                        }
                        else
                        {
                            self.completionSource.TrySetResult(true);
                        }
                    }
                    else
                    {
                        self.completionSource.TrySetResult(false);
                    }
                }
            }

            /// <summary>
            /// Stops tracking this enumerator and disposes both inner enumerators, if they were created.
            /// </summary>
            public async UniTask DisposeAsync()
            {
                TaskTracker.RemoveTracking(this);

                try
                {
                    if (firstEnumerator != null)
                    {
                        await firstEnumerator.DisposeAsync();
                    }
                }
                finally
                {
                    // Dispose the second enumerator even when disposing the first one throws.
                    if (secondEnumerator != null)
                    {
                        await secondEnumerator.DisposeAsync();
                    }
                }
            }
        }
    }

    /// <summary>
    /// Async sequence that zips two sources with an asynchronous result selector.
    /// </summary>
    internal sealed class ZipAwait<TFirst, TSecond, TResult> : IUniTaskAsyncEnumerable<TResult>
    {
        readonly IUniTaskAsyncEnumerable<TFirst> first;
        readonly IUniTaskAsyncEnumerable<TSecond> second;
        readonly Func<TFirst, TSecond, UniTask<TResult>> resultSelector;

        public ZipAwait(IUniTaskAsyncEnumerable<TFirst> first, IUniTaskAsyncEnumerable<TSecond> second, Func<TFirst, TSecond, UniTask<TResult>> resultSelector)
        {
            this.first = first;
            this.second = second;
            this.resultSelector = resultSelector;
        }

        /// <summary>Creates a fresh enumerator bound to <paramref name="cancellationToken"/>.</summary>
        public IUniTaskAsyncEnumerator<TResult> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        {
            return new _ZipAwait(first, second, resultSelector, cancellationToken);
        }

        /// <summary>
        /// Callback-driven enumerator: advances the first source, then the second, then awaits the
        /// selector's task before completing the step.
        /// </summary>
        sealed class _ZipAwait : MoveNextSource, IUniTaskAsyncEnumerator<TResult>
        {
            // Cached delegates so registering a continuation does not allocate per step.
            static readonly Action<object> firstMoveNextCoreDelegate = FirstMoveNextCore;
            static readonly Action<object> secondMoveNextCoreDelegate = SecondMoveNextCore;
            static readonly Action<object> resultAwaitCoreDelegate = ResultAwaitCore;

            readonly IUniTaskAsyncEnumerable<TFirst> first;
            readonly IUniTaskAsyncEnumerable<TSecond> second;
            readonly Func<TFirst, TSecond, UniTask<TResult>> resultSelector;

            CancellationToken cancellationToken;

            // Created lazily on the first MoveNextAsync call.
            IUniTaskAsyncEnumerator<TFirst> firstEnumerator;
            IUniTaskAsyncEnumerator<TSecond> secondEnumerator;

            UniTask<bool>.Awaiter firstAwaiter;
            UniTask<bool>.Awaiter secondAwaiter;
            UniTask<TResult>.Awaiter resultAwaiter;

            public _ZipAwait(IUniTaskAsyncEnumerable<TFirst> first, IUniTaskAsyncEnumerable<TSecond> second, Func<TFirst, TSecond, UniTask<TResult>> resultSelector, CancellationToken cancellationToken)
            {
                this.first = first;
                this.second = second;
                this.resultSelector = resultSelector;
                this.cancellationToken = cancellationToken;
                TaskTracker.TrackActiveTask(this, 3);
            }

            /// <summary>The value produced by the most recent successful step.</summary>
            public TResult Current { get; private set; }

            /// <summary>
            /// Starts one zip step. The returned task yields <c>true</c> when both sources produced
            /// an element and the selector's task completed, and <c>false</c> as soon as either
            /// source is exhausted.
            /// </summary>
            public UniTask<bool> MoveNextAsync()
            {
                completionSource.Reset();

                if (firstEnumerator == null)
                {
                    firstEnumerator = first.GetAsyncEnumerator(cancellationToken);
                    secondEnumerator = second.GetAsyncEnumerator(cancellationToken);
                }

                firstAwaiter = firstEnumerator.MoveNextAsync().GetAwaiter();

                if (firstAwaiter.IsCompleted)
                {
                    // Already finished: continue inline instead of registering a continuation.
                    FirstMoveNextCore(this);
                }
                else
                {
                    firstAwaiter.SourceOnCompleted(firstMoveNextCoreDelegate, this);
                }

                return new UniTask<bool>(this, completionSource.Version);
            }

            // Runs once the first source's MoveNextAsync has completed.
            static void FirstMoveNextCore(object state)
            {
                var self = (_ZipAwait)state;

                // TryGetResult comes from MoveNextSource (not shown here); presumably it reports the
                // failure itself when it returns false, since nothing else is done on that path.
                if (self.TryGetResult(self.firstAwaiter, out var result))
                {
                    if (result)
                    {
                        try
                        {
                            self.secondAwaiter = self.secondEnumerator.MoveNextAsync().GetAwaiter();
                        }
                        catch (Exception ex)
                        {
                            self.completionSource.TrySetException(ex);
                            return;
                        }

                        if (self.secondAwaiter.IsCompleted)
                        {
                            SecondMoveNextCore(self);
                        }
                        else
                        {
                            self.secondAwaiter.SourceOnCompleted(secondMoveNextCoreDelegate, self);
                        }
                    }
                    else
                    {
                        // First source exhausted: the second source is not advanced.
                        self.completionSource.TrySetResult(false);
                    }
                }
            }

            // Runs once the second source's MoveNextAsync has completed; kicks off the selector.
            static void SecondMoveNextCore(object state)
            {
                var self = (_ZipAwait)state;

                if (self.TryGetResult(self.secondAwaiter, out var result))
                {
                    if (result)
                    {
                        try
                        {
                            self.resultAwaiter = self.resultSelector(self.firstEnumerator.Current, self.secondEnumerator.Current).GetAwaiter();

                            if (self.resultAwaiter.IsCompleted)
                            {
                                ResultAwaitCore(self);
                            }
                            else
                            {
                                self.resultAwaiter.SourceOnCompleted(resultAwaitCoreDelegate, self);
                            }
                        }
                        catch (Exception ex)
                        {
                            self.completionSource.TrySetException(ex);
                        }
                    }
                    else
                    {
                        self.completionSource.TrySetResult(false);
                    }
                }
            }

            // Runs once the selector's task has completed; publishes the value and finishes the step.
            static void ResultAwaitCore(object state)
            {
                var self = (_ZipAwait)state;

                if (self.TryGetResult(self.resultAwaiter, out var result))
                {
                    self.Current = result;

                    if (self.cancellationToken.IsCancellationRequested)
                    {
                        self.completionSource.TrySetCanceled(self.cancellationToken);
                    }
                    else
                    {
                        self.completionSource.TrySetResult(true);
                    }
                }
            }

            /// <summary>
            /// Stops tracking this enumerator and disposes both inner enumerators, if they were created.
            /// </summary>
            public async UniTask DisposeAsync()
            {
                TaskTracker.RemoveTracking(this);

                try
                {
                    if (firstEnumerator != null)
                    {
                        await firstEnumerator.DisposeAsync();
                    }
                }
                finally
                {
                    // Dispose the second enumerator even when disposing the first one throws.
                    if (secondEnumerator != null)
                    {
                        await secondEnumerator.DisposeAsync();
                    }
                }
            }
        }
    }

    /// <summary>
    /// Async sequence that zips two sources with an asynchronous result selector that also
    /// receives the enumeration's cancellation token.
    /// </summary>
    internal sealed class ZipAwaitWithCancellation<TFirst, TSecond, TResult> : IUniTaskAsyncEnumerable<TResult>
    {
        readonly IUniTaskAsyncEnumerable<TFirst> first;
        readonly IUniTaskAsyncEnumerable<TSecond> second;
        readonly Func<TFirst, TSecond, CancellationToken, UniTask<TResult>> resultSelector;

        public ZipAwaitWithCancellation(IUniTaskAsyncEnumerable<TFirst> first, IUniTaskAsyncEnumerable<TSecond> second, Func<TFirst, TSecond, CancellationToken, UniTask<TResult>> resultSelector)
        {
            this.first = first;
            this.second = second;
            this.resultSelector = resultSelector;
        }

        /// <summary>Creates a fresh enumerator bound to <paramref name="cancellationToken"/>.</summary>
        public IUniTaskAsyncEnumerator<TResult> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        {
            return new _ZipAwaitWithCancellation(first, second, resultSelector, cancellationToken);
        }

        /// <summary>
        /// Callback-driven enumerator: advances the first source, then the second, then awaits the
        /// selector's task (invoked with the cancellation token) before completing the step.
        /// </summary>
        sealed class _ZipAwaitWithCancellation : MoveNextSource, IUniTaskAsyncEnumerator<TResult>
        {
            // Cached delegates so registering a continuation does not allocate per step.
            static readonly Action<object> firstMoveNextCoreDelegate = FirstMoveNextCore;
            static readonly Action<object> secondMoveNextCoreDelegate = SecondMoveNextCore;
            static readonly Action<object> resultAwaitCoreDelegate = ResultAwaitCore;

            readonly IUniTaskAsyncEnumerable<TFirst> first;
            readonly IUniTaskAsyncEnumerable<TSecond> second;
            readonly Func<TFirst, TSecond, CancellationToken, UniTask<TResult>> resultSelector;

            CancellationToken cancellationToken;

            // Created lazily on the first MoveNextAsync call.
            IUniTaskAsyncEnumerator<TFirst> firstEnumerator;
            IUniTaskAsyncEnumerator<TSecond> secondEnumerator;

            UniTask<bool>.Awaiter firstAwaiter;
            UniTask<bool>.Awaiter secondAwaiter;
            UniTask<TResult>.Awaiter resultAwaiter;

            public _ZipAwaitWithCancellation(IUniTaskAsyncEnumerable<TFirst> first, IUniTaskAsyncEnumerable<TSecond> second, Func<TFirst, TSecond, CancellationToken, UniTask<TResult>> resultSelector, CancellationToken cancellationToken)
            {
                this.first = first;
                this.second = second;
                this.resultSelector = resultSelector;
                this.cancellationToken = cancellationToken;
                TaskTracker.TrackActiveTask(this, 3);
            }

            /// <summary>The value produced by the most recent successful step.</summary>
            public TResult Current { get; private set; }

            /// <summary>
            /// Starts one zip step. The returned task yields <c>true</c> when both sources produced
            /// an element and the selector's task completed, and <c>false</c> as soon as either
            /// source is exhausted.
            /// </summary>
            public UniTask<bool> MoveNextAsync()
            {
                completionSource.Reset();

                if (firstEnumerator == null)
                {
                    firstEnumerator = first.GetAsyncEnumerator(cancellationToken);
                    secondEnumerator = second.GetAsyncEnumerator(cancellationToken);
                }

                firstAwaiter = firstEnumerator.MoveNextAsync().GetAwaiter();

                if (firstAwaiter.IsCompleted)
                {
                    // Already finished: continue inline instead of registering a continuation.
                    FirstMoveNextCore(this);
                }
                else
                {
                    firstAwaiter.SourceOnCompleted(firstMoveNextCoreDelegate, this);
                }

                return new UniTask<bool>(this, completionSource.Version);
            }

            // Runs once the first source's MoveNextAsync has completed.
            static void FirstMoveNextCore(object state)
            {
                var self = (_ZipAwaitWithCancellation)state;

                // TryGetResult comes from MoveNextSource (not shown here); presumably it reports the
                // failure itself when it returns false, since nothing else is done on that path.
                if (self.TryGetResult(self.firstAwaiter, out var result))
                {
                    if (result)
                    {
                        try
                        {
                            self.secondAwaiter = self.secondEnumerator.MoveNextAsync().GetAwaiter();
                        }
                        catch (Exception ex)
                        {
                            self.completionSource.TrySetException(ex);
                            return;
                        }

                        if (self.secondAwaiter.IsCompleted)
                        {
                            SecondMoveNextCore(self);
                        }
                        else
                        {
                            self.secondAwaiter.SourceOnCompleted(secondMoveNextCoreDelegate, self);
                        }
                    }
                    else
                    {
                        // First source exhausted: the second source is not advanced.
                        self.completionSource.TrySetResult(false);
                    }
                }
            }

            // Runs once the second source's MoveNextAsync has completed; kicks off the selector.
            static void SecondMoveNextCore(object state)
            {
                var self = (_ZipAwaitWithCancellation)state;

                if (self.TryGetResult(self.secondAwaiter, out var result))
                {
                    if (result)
                    {
                        try
                        {
                            self.resultAwaiter = self.resultSelector(self.firstEnumerator.Current, self.secondEnumerator.Current, self.cancellationToken).GetAwaiter();

                            if (self.resultAwaiter.IsCompleted)
                            {
                                ResultAwaitCore(self);
                            }
                            else
                            {
                                self.resultAwaiter.SourceOnCompleted(resultAwaitCoreDelegate, self);
                            }
                        }
                        catch (Exception ex)
                        {
                            self.completionSource.TrySetException(ex);
                        }
                    }
                    else
                    {
                        self.completionSource.TrySetResult(false);
                    }
                }
            }

            // Runs once the selector's task has completed; publishes the value and finishes the step.
            static void ResultAwaitCore(object state)
            {
                var self = (_ZipAwaitWithCancellation)state;

                if (self.TryGetResult(self.resultAwaiter, out var result))
                {
                    self.Current = result;

                    if (self.cancellationToken.IsCancellationRequested)
                    {
                        self.completionSource.TrySetCanceled(self.cancellationToken);
                    }
                    else
                    {
                        self.completionSource.TrySetResult(true);
                    }
                }
            }

            /// <summary>
            /// Stops tracking this enumerator and disposes both inner enumerators, if they were created.
            /// </summary>
            public async UniTask DisposeAsync()
            {
                TaskTracker.RemoveTracking(this);

                try
                {
                    if (firstEnumerator != null)
                    {
                        await firstEnumerator.DisposeAsync();
                    }
                }
                finally
                {
                    // Dispose the second enumerator even when disposing the first one throws.
                    if (secondEnumerator != null)
                    {
                        await secondEnumerator.DisposeAsync();
                    }
                }
            }
        }
    }
}