using Cysharp.Threading.Tasks.Internal;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;

namespace Cysharp.Threading.Tasks.Linq
{
    public static partial class UniTaskAsyncEnumerable
    {
        /// <summary>
        /// Correlates the elements of <paramref name="outer"/> and <paramref name="inner"/> by key and
        /// projects each matching pair through <paramref name="resultSelector"/>. Keys are compared with
        /// <see cref="EqualityComparer{T}.Default"/>.
        /// </summary>
        /// <param name="outer">Sequence that drives the iteration.</param>
        /// <param name="inner">Sequence that is collected into a lookup before <paramref name="outer"/> is enumerated.</param>
        /// <param name="outerKeySelector">Extracts the join key from an outer element.</param>
        /// <param name="innerKeySelector">Extracts the join key from an inner element.</param>
        /// <param name="resultSelector">Builds a result from an outer element and one matching inner element.</param>
        /// <returns>A lazily evaluated sequence of joined results.</returns>
        public static IUniTaskAsyncEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(this IUniTaskAsyncEnumerable<TOuter> outer, IUniTaskAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, TInner, TResult> resultSelector)
        {
            Error.ThrowArgumentNullException(outer, nameof(outer));
            Error.ThrowArgumentNullException(inner, nameof(inner));
            Error.ThrowArgumentNullException(outerKeySelector, nameof(outerKeySelector));
            Error.ThrowArgumentNullException(innerKeySelector, nameof(innerKeySelector));
            Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector));

            return new Join<TOuter, TInner, TKey, TResult>(outer, inner, outerKeySelector, innerKeySelector, resultSelector, EqualityComparer<TKey>.Default);
        }

        /// <summary>
        /// Correlates the elements of <paramref name="outer"/> and <paramref name="inner"/> by key and
        /// projects each matching pair through <paramref name="resultSelector"/>, comparing keys with
        /// <paramref name="comparer"/>.
        /// </summary>
        /// <param name="outer">Sequence that drives the iteration.</param>
        /// <param name="inner">Sequence that is collected into a lookup before <paramref name="outer"/> is enumerated.</param>
        /// <param name="outerKeySelector">Extracts the join key from an outer element.</param>
        /// <param name="innerKeySelector">Extracts the join key from an inner element.</param>
        /// <param name="resultSelector">Builds a result from an outer element and one matching inner element.</param>
        /// <param name="comparer">Equality comparer used for the keys; must not be null.</param>
        /// <returns>A lazily evaluated sequence of joined results.</returns>
        public static IUniTaskAsyncEnumerable<TResult> Join<TOuter, TInner, TKey, TResult>(this IUniTaskAsyncEnumerable<TOuter> outer, IUniTaskAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, TInner, TResult> resultSelector, IEqualityComparer<TKey> comparer)
        {
            Error.ThrowArgumentNullException(outer, nameof(outer));
            Error.ThrowArgumentNullException(inner, nameof(inner));
            Error.ThrowArgumentNullException(outerKeySelector, nameof(outerKeySelector));
            Error.ThrowArgumentNullException(innerKeySelector, nameof(innerKeySelector));
            Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector));
            Error.ThrowArgumentNullException(comparer, nameof(comparer));

            return new Join<TOuter, TInner, TKey, TResult>(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer);
        }

        /// <summary>
        /// Same as <c>Join</c>, but the key selectors and the result selector return a <c>UniTask</c>
        /// that is awaited. Keys are compared with <see cref="EqualityComparer{T}.Default"/>.
        /// </summary>
        /// <param name="outer">Sequence that drives the iteration.</param>
        /// <param name="inner">Sequence that is collected into a lookup before <paramref name="outer"/> is enumerated.</param>
        /// <param name="outerKeySelector">Asynchronously extracts the join key from an outer element.</param>
        /// <param name="innerKeySelector">Asynchronously extracts the join key from an inner element.</param>
        /// <param name="resultSelector">Asynchronously builds a result from an outer element and one matching inner element.</param>
        /// <returns>A lazily evaluated sequence of joined results.</returns>
        public static IUniTaskAsyncEnumerable<TResult> JoinAwait<TOuter, TInner, TKey, TResult>(this IUniTaskAsyncEnumerable<TOuter> outer, IUniTaskAsyncEnumerable<TInner> inner, Func<TOuter, UniTask<TKey>> outerKeySelector, Func<TInner, UniTask<TKey>> innerKeySelector, Func<TOuter, TInner, UniTask<TResult>> resultSelector)
        {
            Error.ThrowArgumentNullException(outer, nameof(outer));
            Error.ThrowArgumentNullException(inner, nameof(inner));
            Error.ThrowArgumentNullException(outerKeySelector, nameof(outerKeySelector));
            Error.ThrowArgumentNullException(innerKeySelector, nameof(innerKeySelector));
            Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector));

            return new JoinAwait<TOuter, TInner, TKey, TResult>(outer, inner, outerKeySelector, innerKeySelector, resultSelector, EqualityComparer<TKey>.Default);
        }

        /// <summary>
        /// Same as <c>Join</c>, but the key selectors and the result selector return a <c>UniTask</c>
        /// that is awaited. Keys are compared with <paramref name="comparer"/>.
        /// </summary>
        /// <param name="outer">Sequence that drives the iteration.</param>
        /// <param name="inner">Sequence that is collected into a lookup before <paramref name="outer"/> is enumerated.</param>
        /// <param name="outerKeySelector">Asynchronously extracts the join key from an outer element.</param>
        /// <param name="innerKeySelector">Asynchronously extracts the join key from an inner element.</param>
        /// <param name="resultSelector">Asynchronously builds a result from an outer element and one matching inner element.</param>
        /// <param name="comparer">Equality comparer used for the keys; must not be null.</param>
        /// <returns>A lazily evaluated sequence of joined results.</returns>
        public static IUniTaskAsyncEnumerable<TResult> JoinAwait<TOuter, TInner, TKey, TResult>(this IUniTaskAsyncEnumerable<TOuter> outer, IUniTaskAsyncEnumerable<TInner> inner, Func<TOuter, UniTask<TKey>> outerKeySelector, Func<TInner, UniTask<TKey>> innerKeySelector, Func<TOuter, TInner, UniTask<TResult>> resultSelector, IEqualityComparer<TKey> comparer)
        {
            Error.ThrowArgumentNullException(outer, nameof(outer));
            Error.ThrowArgumentNullException(inner, nameof(inner));
            Error.ThrowArgumentNullException(outerKeySelector, nameof(outerKeySelector));
            Error.ThrowArgumentNullException(innerKeySelector, nameof(innerKeySelector));
            Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector));
            Error.ThrowArgumentNullException(comparer, nameof(comparer));

            return new JoinAwait<TOuter, TInner, TKey, TResult>(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer);
        }

        /// <summary>
        /// Same as <c>JoinAwait</c>, but every selector additionally receives the enumeration's
        /// <see cref="CancellationToken"/>. Keys are compared with <see cref="EqualityComparer{T}.Default"/>.
        /// </summary>
        /// <param name="outer">Sequence that drives the iteration.</param>
        /// <param name="inner">Sequence that is collected into a lookup before <paramref name="outer"/> is enumerated.</param>
        /// <param name="outerKeySelector">Asynchronously extracts the join key from an outer element.</param>
        /// <param name="innerKeySelector">Asynchronously extracts the join key from an inner element.</param>
        /// <param name="resultSelector">Asynchronously builds a result from an outer element and one matching inner element.</param>
        /// <returns>A lazily evaluated sequence of joined results.</returns>
        public static IUniTaskAsyncEnumerable<TResult> JoinAwaitWithCancellation<TOuter, TInner, TKey, TResult>(this IUniTaskAsyncEnumerable<TOuter> outer, IUniTaskAsyncEnumerable<TInner> inner, Func<TOuter, CancellationToken, UniTask<TKey>> outerKeySelector, Func<TInner, CancellationToken, UniTask<TKey>> innerKeySelector, Func<TOuter, TInner, CancellationToken, UniTask<TResult>> resultSelector)
        {
            Error.ThrowArgumentNullException(outer, nameof(outer));
            Error.ThrowArgumentNullException(inner, nameof(inner));
            Error.ThrowArgumentNullException(outerKeySelector, nameof(outerKeySelector));
            Error.ThrowArgumentNullException(innerKeySelector, nameof(innerKeySelector));
            Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector));

            return new JoinAwaitWithCancellation<TOuter, TInner, TKey, TResult>(outer, inner, outerKeySelector, innerKeySelector, resultSelector, EqualityComparer<TKey>.Default);
        }

        /// <summary>
        /// Same as <c>JoinAwait</c>, but every selector additionally receives the enumeration's
        /// <see cref="CancellationToken"/>. Keys are compared with <paramref name="comparer"/>.
        /// </summary>
        /// <param name="outer">Sequence that drives the iteration.</param>
        /// <param name="inner">Sequence that is collected into a lookup before <paramref name="outer"/> is enumerated.</param>
        /// <param name="outerKeySelector">Asynchronously extracts the join key from an outer element.</param>
        /// <param name="innerKeySelector">Asynchronously extracts the join key from an inner element.</param>
        /// <param name="resultSelector">Asynchronously builds a result from an outer element and one matching inner element.</param>
        /// <param name="comparer">Equality comparer used for the keys; must not be null.</param>
        /// <returns>A lazily evaluated sequence of joined results.</returns>
        public static IUniTaskAsyncEnumerable<TResult> JoinAwaitWithCancellation<TOuter, TInner, TKey, TResult>(this IUniTaskAsyncEnumerable<TOuter> outer, IUniTaskAsyncEnumerable<TInner> inner, Func<TOuter, CancellationToken, UniTask<TKey>> outerKeySelector, Func<TInner, CancellationToken, UniTask<TKey>> innerKeySelector, Func<TOuter, TInner, CancellationToken, UniTask<TResult>> resultSelector, IEqualityComparer<TKey> comparer)
        {
            Error.ThrowArgumentNullException(outer, nameof(outer));
            Error.ThrowArgumentNullException(inner, nameof(inner));
            Error.ThrowArgumentNullException(outerKeySelector, nameof(outerKeySelector));
            Error.ThrowArgumentNullException(innerKeySelector, nameof(innerKeySelector));
            Error.ThrowArgumentNullException(resultSelector, nameof(resultSelector));
            Error.ThrowArgumentNullException(comparer, nameof(comparer));

            return new JoinAwaitWithCancellation<TOuter, TInner, TKey, TResult>(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer);
        }
    }

    /// <summary>
    /// Deferred join with synchronous selectors. Holds the arguments and creates a fresh
    /// enumerator state machine per <see cref="GetAsyncEnumerator"/> call.
    /// </summary>
    internal sealed class Join<TOuter, TInner, TKey, TResult> : IUniTaskAsyncEnumerable<TResult>
    {
        readonly IUniTaskAsyncEnumerable<TOuter> outer;
        readonly IUniTaskAsyncEnumerable<TInner> inner;
        readonly Func<TOuter, TKey> outerKeySelector;
        readonly Func<TInner, TKey> innerKeySelector;
        readonly Func<TOuter, TInner, TResult> resultSelector;
        readonly IEqualityComparer<TKey> comparer;

        public Join(IUniTaskAsyncEnumerable<TOuter> outer, IUniTaskAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, TInner, TResult> resultSelector, IEqualityComparer<TKey> comparer)
        {
            this.outer = outer;
            this.inner = inner;
            this.outerKeySelector = outerKeySelector;
            this.innerKeySelector = innerKeySelector;
            this.resultSelector = resultSelector;
            this.comparer = comparer;
        }

        public IUniTaskAsyncEnumerator<TResult> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        {
            return new _Join(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer, cancellationToken);
        }

        /// <summary>
        /// Enumerator state machine: on the first MoveNextAsync the inner sequence is turned into a lookup,
        /// then each outer element's matches are walked one result per MoveNextAsync call.
        /// </summary>
        sealed class _Join : MoveNextSource, IUniTaskAsyncEnumerator<TResult>
        {
            // Cached so registering a continuation does not allocate a delegate each time.
            static readonly Action<object> MoveNextCoreDelegate = MoveNextCore;

            readonly IUniTaskAsyncEnumerable<TOuter> outer;
            readonly IUniTaskAsyncEnumerable<TInner> inner;
            readonly Func<TOuter, TKey> outerKeySelector;
            readonly Func<TInner, TKey> innerKeySelector;
            readonly Func<TOuter, TInner, TResult> resultSelector;
            readonly IEqualityComparer<TKey> comparer;

            CancellationToken cancellationToken;

            ILookup<TKey, TInner> lookup;                 // inner elements grouped by key; null until first MoveNextAsync
            IUniTaskAsyncEnumerator<TOuter> enumerator;   // outer enumerator, created after the lookup is ready
            UniTask<bool>.Awaiter awaiter;                // pending outer MoveNextAsync
            TOuter currentOuterValue;
            IEnumerator<TInner> valueEnumerator;          // matches for currentOuterValue still to be emitted

            // True while MoveNextCore is being invoked inline from SourceMoveNext; tells the callback
            // to return to the caller's loop instead of recursing into SourceMoveNext.
            bool continueNext;

            public _Join(IUniTaskAsyncEnumerable<TOuter> outer, IUniTaskAsyncEnumerable<TInner> inner, Func<TOuter, TKey> outerKeySelector, Func<TInner, TKey> innerKeySelector, Func<TOuter, TInner, TResult> resultSelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
            {
                this.outer = outer;
                this.inner = inner;
                this.outerKeySelector = outerKeySelector;
                this.innerKeySelector = innerKeySelector;
                this.resultSelector = resultSelector;
                this.comparer = comparer;
                this.cancellationToken = cancellationToken;
                TaskTracker.TrackActiveTask(this, 3);
            }

            public TResult Current { get; private set; }

            public UniTask<bool> MoveNextAsync()
            {
                cancellationToken.ThrowIfCancellationRequested();
                completionSource.Reset();

                if (lookup == null)
                {
                    CreateInnerHashSet().Forget();
                }
                else
                {
                    SourceMoveNext();
                }

                return new UniTask<bool>(this, completionSource.Version);
            }

            // Builds the lookup from the inner sequence, opens the outer enumerator, then starts iterating.
            async UniTaskVoid CreateInnerHashSet()
            {
                try
                {
                    lookup = await inner.ToLookupAsync(innerKeySelector, comparer, cancellationToken);
                    enumerator = outer.GetAsyncEnumerator(cancellationToken);
                }
                catch (Exception ex)
                {
                    completionSource.TrySetException(ex);
                    return;
                }

                SourceMoveNext();
            }

            // Emits the next match of the current outer element, or advances the outer sequence.
            void SourceMoveNext()
            {
                try
                {
                    LOOP:
                    if (valueEnumerator != null)
                    {
                        if (valueEnumerator.MoveNext())
                        {
                            Current = resultSelector(currentOuterValue, valueEnumerator.Current);
                            goto TRY_SET_RESULT_TRUE;
                        }
                        else
                        {
                            valueEnumerator.Dispose();
                            valueEnumerator = null;
                        }
                    }

                    awaiter = enumerator.MoveNextAsync().GetAwaiter();
                    if (awaiter.IsCompleted)
                    {
                        continueNext = true;
                        MoveNextCore(this);
                        if (continueNext)
                        {
                            continueNext = false;
                            goto LOOP; // avoid recursive
                        }
                    }
                    else
                    {
                        awaiter.SourceOnCompleted(MoveNextCoreDelegate, this);
                    }
                }
                catch (Exception ex)
                {
                    continueNext = false;
                    completionSource.TrySetException(ex);
                }

                return;

                TRY_SET_RESULT_TRUE:
                completionSource.TrySetResult(true);
            }

            // Handles completion of the outer MoveNextAsync, either inline or as a continuation.
            static void MoveNextCore(object state)
            {
                var self = (_Join)state;

                if (self.TryGetResult(self.awaiter, out var result))
                {
                    if (result)
                    {
                        try
                        {
                            self.currentOuterValue = self.enumerator.Current;
                            var key = self.outerKeySelector(self.currentOuterValue);
                            self.valueEnumerator = self.lookup[key].GetEnumerator();
                        }
                        catch (Exception ex)
                        {
                            // When running as a continuation there is no enclosing try/catch, so a throwing
                            // selector/comparer must be reported here or the pending MoveNextAsync never completes.
                            self.continueNext = false;
                            self.completionSource.TrySetException(ex);
                            return;
                        }

                        if (self.continueNext)
                        {
                            return; // the inline caller loops
                        }
                        else
                        {
                            self.SourceMoveNext();
                        }
                    }
                    else
                    {
                        // Outer sequence exhausted.
                        self.continueNext = false;
                        self.completionSource.TrySetResult(false);
                    }
                }
                else
                {
                    self.continueNext = false;
                }
            }

            public UniTask DisposeAsync()
            {
                TaskTracker.RemoveTracking(this);

                if (valueEnumerator != null)
                {
                    valueEnumerator.Dispose();
                }

                if (enumerator != null)
                {
                    return enumerator.DisposeAsync();
                }

                return default;
            }
        }
    }

    /// <summary>
    /// Deferred join whose key selectors and result selector return a <c>UniTask</c> that is awaited.
    /// </summary>
    internal sealed class JoinAwait<TOuter, TInner, TKey, TResult> : IUniTaskAsyncEnumerable<TResult>
    {
        readonly IUniTaskAsyncEnumerable<TOuter> outer;
        readonly IUniTaskAsyncEnumerable<TInner> inner;
        readonly Func<TOuter, UniTask<TKey>> outerKeySelector;
        readonly Func<TInner, UniTask<TKey>> innerKeySelector;
        readonly Func<TOuter, TInner, UniTask<TResult>> resultSelector;
        readonly IEqualityComparer<TKey> comparer;

        public JoinAwait(IUniTaskAsyncEnumerable<TOuter> outer, IUniTaskAsyncEnumerable<TInner> inner, Func<TOuter, UniTask<TKey>> outerKeySelector, Func<TInner, UniTask<TKey>> innerKeySelector, Func<TOuter, TInner, UniTask<TResult>> resultSelector, IEqualityComparer<TKey> comparer)
        {
            this.outer = outer;
            this.inner = inner;
            this.outerKeySelector = outerKeySelector;
            this.innerKeySelector = innerKeySelector;
            this.resultSelector = resultSelector;
            this.comparer = comparer;
        }

        public IUniTaskAsyncEnumerator<TResult> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        {
            return new _JoinAwait(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer, cancellationToken);
        }

        /// <summary>
        /// Enumerator state machine for the awaiting variant. Three awaiters are chained by hand:
        /// outer MoveNextAsync, then outer key selection, then result selection.
        /// </summary>
        sealed class _JoinAwait : MoveNextSource, IUniTaskAsyncEnumerator<TResult>
        {
            // Cached so registering a continuation does not allocate a delegate each time.
            static readonly Action<object> MoveNextCoreDelegate = MoveNextCore;
            static readonly Action<object> OuterSelectCoreDelegate = OuterSelectCore;
            static readonly Action<object> ResultSelectCoreDelegate = ResultSelectCore;

            readonly IUniTaskAsyncEnumerable<TOuter> outer;
            readonly IUniTaskAsyncEnumerable<TInner> inner;
            readonly Func<TOuter, UniTask<TKey>> outerKeySelector;
            readonly Func<TInner, UniTask<TKey>> innerKeySelector;
            readonly Func<TOuter, TInner, UniTask<TResult>> resultSelector;
            readonly IEqualityComparer<TKey> comparer;

            CancellationToken cancellationToken;

            ILookup<TKey, TInner> lookup;                 // inner elements grouped by key; null until first MoveNextAsync
            IUniTaskAsyncEnumerator<TOuter> enumerator;   // outer enumerator, created after the lookup is ready
            UniTask<bool>.Awaiter awaiter;                // pending outer MoveNextAsync
            TOuter currentOuterValue;
            IEnumerator<TInner> valueEnumerator;          // matches for currentOuterValue still to be emitted

            UniTask<TResult>.Awaiter resultAwaiter;       // pending resultSelector call
            UniTask<TKey>.Awaiter outerKeyAwaiter;        // pending outerKeySelector call

            // True while the callbacks are being invoked inline from SourceMoveNext; tells them
            // to return to the caller's loop instead of recursing into SourceMoveNext.
            bool continueNext;

            public _JoinAwait(IUniTaskAsyncEnumerable<TOuter> outer, IUniTaskAsyncEnumerable<TInner> inner, Func<TOuter, UniTask<TKey>> outerKeySelector, Func<TInner, UniTask<TKey>> innerKeySelector, Func<TOuter, TInner, UniTask<TResult>> resultSelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
            {
                this.outer = outer;
                this.inner = inner;
                this.outerKeySelector = outerKeySelector;
                this.innerKeySelector = innerKeySelector;
                this.resultSelector = resultSelector;
                this.comparer = comparer;
                this.cancellationToken = cancellationToken;
                TaskTracker.TrackActiveTask(this, 3);
            }

            public TResult Current { get; private set; }

            public UniTask<bool> MoveNextAsync()
            {
                cancellationToken.ThrowIfCancellationRequested();
                completionSource.Reset();

                if (lookup == null)
                {
                    CreateInnerHashSet().Forget();
                }
                else
                {
                    SourceMoveNext();
                }

                return new UniTask<bool>(this, completionSource.Version);
            }

            // Builds the lookup from the inner sequence, opens the outer enumerator, then starts iterating.
            async UniTaskVoid CreateInnerHashSet()
            {
                try
                {
                    lookup = await inner.ToLookupAwaitAsync(innerKeySelector, comparer, cancellationToken);
                    enumerator = outer.GetAsyncEnumerator(cancellationToken);
                }
                catch (Exception ex)
                {
                    completionSource.TrySetException(ex);
                    return;
                }

                SourceMoveNext();
            }

            // Starts result selection for the next match, or advances the outer sequence.
            void SourceMoveNext()
            {
                try
                {
                    LOOP:
                    if (valueEnumerator != null)
                    {
                        if (valueEnumerator.MoveNext())
                        {
                            resultAwaiter = resultSelector(currentOuterValue, valueEnumerator.Current).GetAwaiter();
                            if (resultAwaiter.IsCompleted)
                            {
                                ResultSelectCore(this);
                            }
                            else
                            {
                                resultAwaiter.SourceOnCompleted(ResultSelectCoreDelegate, this);
                            }
                            return;
                        }
                        else
                        {
                            valueEnumerator.Dispose();
                            valueEnumerator = null;
                        }
                    }

                    awaiter = enumerator.MoveNextAsync().GetAwaiter();
                    if (awaiter.IsCompleted)
                    {
                        continueNext = true;
                        MoveNextCore(this);
                        if (continueNext)
                        {
                            continueNext = false;
                            goto LOOP; // avoid recursive
                        }
                    }
                    else
                    {
                        awaiter.SourceOnCompleted(MoveNextCoreDelegate, this);
                    }
                }
                catch (Exception ex)
                {
                    continueNext = false;
                    completionSource.TrySetException(ex);
                }
            }

            // Handles completion of the outer MoveNextAsync and kicks off key selection.
            static void MoveNextCore(object state)
            {
                var self = (_JoinAwait)state;

                if (self.TryGetResult(self.awaiter, out var result))
                {
                    if (result)
                    {
                        try
                        {
                            self.currentOuterValue = self.enumerator.Current;
                            self.outerKeyAwaiter = self.outerKeySelector(self.currentOuterValue).GetAwaiter();
                        }
                        catch (Exception ex)
                        {
                            // When running as a continuation there is no enclosing try/catch, so a selector that
                            // throws synchronously must be reported here or the pending MoveNextAsync never completes.
                            self.continueNext = false;
                            self.completionSource.TrySetException(ex);
                            return;
                        }

                        if (self.outerKeyAwaiter.IsCompleted)
                        {
                            OuterSelectCore(self);
                        }
                        else
                        {
                            self.continueNext = false;
                            self.outerKeyAwaiter.SourceOnCompleted(OuterSelectCoreDelegate, self);
                        }
                    }
                    else
                    {
                        // Outer sequence exhausted.
                        self.continueNext = false;
                        self.completionSource.TrySetResult(false);
                    }
                }
                else
                {
                    self.continueNext = false;
                }
            }

            // Handles completion of the outer key selection and positions valueEnumerator on the matches.
            static void OuterSelectCore(object state)
            {
                var self = (_JoinAwait)state;

                if (self.TryGetResult(self.outerKeyAwaiter, out var key))
                {
                    try
                    {
                        self.valueEnumerator = self.lookup[key].GetEnumerator();
                    }
                    catch (Exception ex)
                    {
                        // The lookup may throw (e.g. from the comparer); report instead of leaking out of a continuation.
                        self.continueNext = false;
                        self.completionSource.TrySetException(ex);
                        return;
                    }

                    if (self.continueNext)
                    {
                        return; // the inline caller loops
                    }
                    else
                    {
                        self.SourceMoveNext();
                    }
                }
                else
                {
                    self.continueNext = false;
                }
            }

            // Handles completion of the result selection and publishes the element.
            static void ResultSelectCore(object state)
            {
                var self = (_JoinAwait)state;

                if (self.TryGetResult(self.resultAwaiter, out var result))
                {
                    self.Current = result;
                    self.completionSource.TrySetResult(true);
                }
            }

            public UniTask DisposeAsync()
            {
                TaskTracker.RemoveTracking(this);

                if (valueEnumerator != null)
                {
                    valueEnumerator.Dispose();
                }

                if (enumerator != null)
                {
                    return enumerator.DisposeAsync();
                }

                return default;
            }
        }
    }

    /// <summary>
    /// Deferred join whose selectors return a <c>UniTask</c> and also receive the enumeration's
    /// <see cref="CancellationToken"/>.
    /// </summary>
    internal sealed class JoinAwaitWithCancellation<TOuter, TInner, TKey, TResult> : IUniTaskAsyncEnumerable<TResult>
    {
        readonly IUniTaskAsyncEnumerable<TOuter> outer;
        readonly IUniTaskAsyncEnumerable<TInner> inner;
        readonly Func<TOuter, CancellationToken, UniTask<TKey>> outerKeySelector;
        readonly Func<TInner, CancellationToken, UniTask<TKey>> innerKeySelector;
        readonly Func<TOuter, TInner, CancellationToken, UniTask<TResult>> resultSelector;
        readonly IEqualityComparer<TKey> comparer;

        public JoinAwaitWithCancellation(IUniTaskAsyncEnumerable<TOuter> outer, IUniTaskAsyncEnumerable<TInner> inner, Func<TOuter, CancellationToken, UniTask<TKey>> outerKeySelector, Func<TInner, CancellationToken, UniTask<TKey>> innerKeySelector, Func<TOuter, TInner, CancellationToken, UniTask<TResult>> resultSelector, IEqualityComparer<TKey> comparer)
        {
            this.outer = outer;
            this.inner = inner;
            this.outerKeySelector = outerKeySelector;
            this.innerKeySelector = innerKeySelector;
            this.resultSelector = resultSelector;
            this.comparer = comparer;
        }

        public IUniTaskAsyncEnumerator<TResult> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        {
            return new _JoinAwaitWithCancellation(outer, inner, outerKeySelector, innerKeySelector, resultSelector, comparer, cancellationToken);
        }

        /// <summary>
        /// Enumerator state machine for the cancellation-aware awaiting variant. Three awaiters are chained
        /// by hand: outer MoveNextAsync, then outer key selection, then result selection.
        /// </summary>
        sealed class _JoinAwaitWithCancellation : MoveNextSource, IUniTaskAsyncEnumerator<TResult>
        {
            // Cached so registering a continuation does not allocate a delegate each time.
            static readonly Action<object> MoveNextCoreDelegate = MoveNextCore;
            static readonly Action<object> OuterSelectCoreDelegate = OuterSelectCore;
            static readonly Action<object> ResultSelectCoreDelegate = ResultSelectCore;

            readonly IUniTaskAsyncEnumerable<TOuter> outer;
            readonly IUniTaskAsyncEnumerable<TInner> inner;
            readonly Func<TOuter, CancellationToken, UniTask<TKey>> outerKeySelector;
            readonly Func<TInner, CancellationToken, UniTask<TKey>> innerKeySelector;
            readonly Func<TOuter, TInner, CancellationToken, UniTask<TResult>> resultSelector;
            readonly IEqualityComparer<TKey> comparer;

            CancellationToken cancellationToken;

            ILookup<TKey, TInner> lookup;                 // inner elements grouped by key; null until first MoveNextAsync
            IUniTaskAsyncEnumerator<TOuter> enumerator;   // outer enumerator, created after the lookup is ready
            UniTask<bool>.Awaiter awaiter;                // pending outer MoveNextAsync
            TOuter currentOuterValue;
            IEnumerator<TInner> valueEnumerator;          // matches for currentOuterValue still to be emitted

            UniTask<TResult>.Awaiter resultAwaiter;       // pending resultSelector call
            UniTask<TKey>.Awaiter outerKeyAwaiter;        // pending outerKeySelector call

            // True while the callbacks are being invoked inline from SourceMoveNext; tells them
            // to return to the caller's loop instead of recursing into SourceMoveNext.
            bool continueNext;

            public _JoinAwaitWithCancellation(IUniTaskAsyncEnumerable<TOuter> outer, IUniTaskAsyncEnumerable<TInner> inner, Func<TOuter, CancellationToken, UniTask<TKey>> outerKeySelector, Func<TInner, CancellationToken, UniTask<TKey>> innerKeySelector, Func<TOuter, TInner, CancellationToken, UniTask<TResult>> resultSelector, IEqualityComparer<TKey> comparer, CancellationToken cancellationToken)
            {
                this.outer = outer;
                this.inner = inner;
                this.outerKeySelector = outerKeySelector;
                this.innerKeySelector = innerKeySelector;
                this.resultSelector = resultSelector;
                this.comparer = comparer;
                this.cancellationToken = cancellationToken;
                TaskTracker.TrackActiveTask(this, 3);
            }

            public TResult Current { get; private set; }

            public UniTask<bool> MoveNextAsync()
            {
                cancellationToken.ThrowIfCancellationRequested();
                completionSource.Reset();

                if (lookup == null)
                {
                    CreateInnerHashSet().Forget();
                }
                else
                {
                    SourceMoveNext();
                }

                return new UniTask<bool>(this, completionSource.Version);
            }

            // Builds the lookup from the inner sequence, opens the outer enumerator, then starts iterating.
            async UniTaskVoid CreateInnerHashSet()
            {
                try
                {
                    lookup = await inner.ToLookupAwaitWithCancellationAsync(innerKeySelector, comparer, cancellationToken: cancellationToken);
                    enumerator = outer.GetAsyncEnumerator(cancellationToken);
                }
                catch (Exception ex)
                {
                    completionSource.TrySetException(ex);
                    return;
                }

                SourceMoveNext();
            }

            // Starts result selection for the next match, or advances the outer sequence.
            void SourceMoveNext()
            {
                try
                {
                    LOOP:
                    if (valueEnumerator != null)
                    {
                        if (valueEnumerator.MoveNext())
                        {
                            resultAwaiter = resultSelector(currentOuterValue, valueEnumerator.Current, cancellationToken).GetAwaiter();
                            if (resultAwaiter.IsCompleted)
                            {
                                ResultSelectCore(this);
                            }
                            else
                            {
                                resultAwaiter.SourceOnCompleted(ResultSelectCoreDelegate, this);
                            }
                            return;
                        }
                        else
                        {
                            valueEnumerator.Dispose();
                            valueEnumerator = null;
                        }
                    }

                    awaiter = enumerator.MoveNextAsync().GetAwaiter();
                    if (awaiter.IsCompleted)
                    {
                        continueNext = true;
                        MoveNextCore(this);
                        if (continueNext)
                        {
                            continueNext = false;
                            goto LOOP; // avoid recursive
                        }
                    }
                    else
                    {
                        awaiter.SourceOnCompleted(MoveNextCoreDelegate, this);
                    }
                }
                catch (Exception ex)
                {
                    continueNext = false;
                    completionSource.TrySetException(ex);
                }
            }

            // Handles completion of the outer MoveNextAsync and kicks off key selection.
            static void MoveNextCore(object state)
            {
                var self = (_JoinAwaitWithCancellation)state;

                if (self.TryGetResult(self.awaiter, out var result))
                {
                    if (result)
                    {
                        try
                        {
                            self.currentOuterValue = self.enumerator.Current;
                            self.outerKeyAwaiter = self.outerKeySelector(self.currentOuterValue, self.cancellationToken).GetAwaiter();
                        }
                        catch (Exception ex)
                        {
                            // When running as a continuation there is no enclosing try/catch, so a selector that
                            // throws synchronously must be reported here or the pending MoveNextAsync never completes.
                            self.continueNext = false;
                            self.completionSource.TrySetException(ex);
                            return;
                        }

                        if (self.outerKeyAwaiter.IsCompleted)
                        {
                            OuterSelectCore(self);
                        }
                        else
                        {
                            self.continueNext = false;
                            self.outerKeyAwaiter.SourceOnCompleted(OuterSelectCoreDelegate, self);
                        }
                    }
                    else
                    {
                        // Outer sequence exhausted.
                        self.continueNext = false;
                        self.completionSource.TrySetResult(false);
                    }
                }
                else
                {
                    self.continueNext = false;
                }
            }

            // Handles completion of the outer key selection and positions valueEnumerator on the matches.
            static void OuterSelectCore(object state)
            {
                var self = (_JoinAwaitWithCancellation)state;

                if (self.TryGetResult(self.outerKeyAwaiter, out var key))
                {
                    try
                    {
                        self.valueEnumerator = self.lookup[key].GetEnumerator();
                    }
                    catch (Exception ex)
                    {
                        // The lookup may throw (e.g. from the comparer); report instead of leaking out of a continuation.
                        self.continueNext = false;
                        self.completionSource.TrySetException(ex);
                        return;
                    }

                    if (self.continueNext)
                    {
                        return; // the inline caller loops
                    }
                    else
                    {
                        self.SourceMoveNext();
                    }
                }
                else
                {
                    self.continueNext = false;
                }
            }

            // Handles completion of the result selection and publishes the element.
            static void ResultSelectCore(object state)
            {
                var self = (_JoinAwaitWithCancellation)state;

                if (self.TryGetResult(self.resultAwaiter, out var result))
                {
                    self.Current = result;
                    self.completionSource.TrySetResult(true);
                }
            }

            public UniTask DisposeAsync()
            {
                TaskTracker.RemoveTracking(this);

                if (valueEnumerator != null)
                {
                    valueEnumerator.Dispose();
                }

                if (enumerator != null)
                {
                    return enumerator.DisposeAsync();
                }

                return default;
            }
        }
    }
}