using Cysharp.Threading.Tasks.Internal;
using System;
using System.Threading;

namespace Cysharp.Threading.Tasks.Linq
{
    public static partial class UniTaskAsyncEnumerable
    {
        /// <summary>
        /// Concatenates two asynchronous sequences: every element of <paramref name="first"/> is
        /// produced, followed by every element of <paramref name="second"/>.
        /// </summary>
        /// <typeparam name="TSource">Element type of both sequences.</typeparam>
        /// <param name="first">The sequence enumerated first.</param>
        /// <param name="second">The sequence enumerated once <paramref name="first"/> is exhausted.</param>
        /// <returns>
        /// A deferred sequence; neither input is enumerated until the returned sequence's
        /// enumerator is advanced.
        /// </returns>
        public static IUniTaskAsyncEnumerable<TSource> Concat<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second)
        {
            // Arguments are validated eagerly (before any enumeration) so bad input fails at the call site.
            Error.ThrowArgumentNullException(first, nameof(first));
            Error.ThrowArgumentNullException(second, nameof(second));

            return new Concat<TSource>(first, second);
        }
    }

    /// <summary>Enumerable wrapper that yields <c>first</c> then <c>second</c>.</summary>
    internal sealed class Concat<TSource> : IUniTaskAsyncEnumerable<TSource>
    {
        readonly IUniTaskAsyncEnumerable<TSource> first;
        readonly IUniTaskAsyncEnumerable<TSource> second;

        public Concat(IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second)
        {
            this.first = first;
            this.second = second;
        }

        public IUniTaskAsyncEnumerator<TSource> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        {
            return new _Concat(first, second, cancellationToken);
        }

        /// <summary>
        /// Enumerator that drains <c>first</c>, disposes its enumerator, then drains <c>second</c>.
        /// Each MoveNextAsync call is backed by the inherited <c>completionSource</c>.
        /// </summary>
        sealed class _Concat : MoveNextSource, IUniTaskAsyncEnumerator<TSource>
        {
            // Cached so the continuation does not allocate a new delegate per MoveNextAsync.
            static readonly Action<object> MoveNextCoreDelegate = MoveNextCore;

            enum IteratingState
            {
                IteratingFirst,
                IteratingSecond,
                Complete
            }

            readonly IUniTaskAsyncEnumerable<TSource> first;
            readonly IUniTaskAsyncEnumerable<TSource> second;
            CancellationToken cancellationToken;

            IteratingState iteratingState;

            // Enumerator of the currently active source. It is null before the first MoveNextAsync,
            // and also null while switching from `first` to `second`.
            IUniTaskAsyncEnumerator<TSource> enumerator;
            UniTask<bool>.Awaiter awaiter;

            public _Concat(IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, CancellationToken cancellationToken)
            {
                this.first = first;
                this.second = second;
                this.cancellationToken = cancellationToken;
                this.iteratingState = IteratingState.IteratingFirst;
                TaskTracker.TrackActiveTask(this, 3);
            }

            public TSource Current { get; private set; }

            /// <summary>
            /// Advances to the next element of the concatenation.
            /// </summary>
            /// <returns>
            /// <c>true</c> when <see cref="Current"/> holds a new element; <c>false</c> once both
            /// sources are exhausted.
            /// </returns>
            public UniTask<bool> MoveNextAsync()
            {
                cancellationToken.ThrowIfCancellationRequested();

                if (iteratingState == IteratingState.Complete)
                {
                    return CompletedTasks.False;
                }

                completionSource.Reset();
                StartIterate();
                return new UniTask<bool>(this, completionSource.Version);
            }

            // Ensures an enumerator exists for the active source and awaits its next MoveNextAsync.
            // Any failure, including one from GetAsyncEnumerator, faults the pending
            // completionSource instead of escaping. This matters when StartIterate runs from the
            // fire-and-forget RunSecondAfterDisposeAsync, where a thrown exception would never
            // reach the awaiting consumer.
            void StartIterate()
            {
                try
                {
                    if (enumerator == null)
                    {
                        if (iteratingState == IteratingState.IteratingFirst)
                        {
                            enumerator = first.GetAsyncEnumerator(cancellationToken);
                        }
                        else if (iteratingState == IteratingState.IteratingSecond)
                        {
                            enumerator = second.GetAsyncEnumerator(cancellationToken);
                        }
                    }

                    awaiter = enumerator.MoveNextAsync().GetAwaiter();
                }
                catch (Exception ex)
                {
                    completionSource.TrySetException(ex);
                    return;
                }

                if (awaiter.IsCompleted)
                {
                    MoveNextCoreDelegate(this);
                }
                else
                {
                    awaiter.SourceOnCompleted(MoveNextCoreDelegate, this);
                }
            }

            // Continuation run when the active source's MoveNextAsync completes.
            static void MoveNextCore(object state)
            {
                var self = (_Concat)state;

                // NOTE(review): assumes MoveNextSource.TryGetResult has already faulted/canceled
                // completionSource when it returns false; confirm against MoveNextSource.
                if (!self.TryGetResult(self.awaiter, out var result))
                {
                    return;
                }

                if (result)
                {
                    self.Current = self.enumerator.Current;
                    self.completionSource.TrySetResult(true);
                    return;
                }

                if (self.iteratingState == IteratingState.IteratingFirst)
                {
                    // `first` is exhausted. Dispose it and continue with `second`. The pending
                    // MoveNextAsync stays incomplete until `second` yields or ends.
                    self.RunSecondAfterDisposeAsync().Forget();
                    return;
                }

                self.iteratingState = IteratingState.Complete;
                self.completionSource.TrySetResult(false);
            }

            // Disposes the exhausted `first` enumerator, then starts iterating `second`.
            async UniTaskVoid RunSecondAfterDisposeAsync()
            {
                // Detach first so DisposeAsync on this enumerator never disposes it a second time,
                // even if the dispose below throws.
                var exhausted = enumerator;
                enumerator = null;
                awaiter = default;

                try
                {
                    await exhausted.DisposeAsync();
                }
                catch (Exception ex)
                {
                    // Fault the pending MoveNextAsync and stop. Continuing to StartIterate would
                    // try to complete completionSource a second time.
                    iteratingState = IteratingState.Complete;
                    completionSource.TrySetException(ex);
                    return;
                }

                iteratingState = IteratingState.IteratingSecond;
                StartIterate();
            }

            /// <summary>Disposes the currently active source enumerator, if any.</summary>
            public UniTask DisposeAsync()
            {
                TaskTracker.RemoveTracking(this);
                if (enumerator != null)
                {
                    return enumerator.DisposeAsync();
                }
                return default;
            }
        }
    }
}