using Cysharp.Threading.Tasks.Internal;
using System;
using System.Threading;

namespace Cysharp.Threading.Tasks.Linq
{
    public static partial class UniTaskAsyncEnumerable
    {
        /// <summary>
        /// Returns the elements of <paramref name="source"/>, or a single
        /// <c>default(TSource)</c> element when the source yields nothing.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequence.</typeparam>
        /// <param name="source">Sequence to wrap; validated via <c>Error.ThrowArgumentNullException</c>.</param>
        /// <returns>A lazily evaluated sequence; the source is not enumerated until the result is.</returns>
        public static IUniTaskAsyncEnumerable<TSource> DefaultIfEmpty<TSource>(this IUniTaskAsyncEnumerable<TSource> source)
        {
            Error.ThrowArgumentNullException(source, nameof(source));

            return new DefaultIfEmpty<TSource>(source, default);
        }

        /// <summary>
        /// Returns the elements of <paramref name="source"/>, or a single
        /// <paramref name="defaultValue"/> element when the source yields nothing.
        /// </summary>
        /// <typeparam name="TSource">Element type of the sequence.</typeparam>
        /// <param name="source">Sequence to wrap; validated via <c>Error.ThrowArgumentNullException</c>.</param>
        /// <param name="defaultValue">Value emitted once if the source turns out to be empty.</param>
        /// <returns>A lazily evaluated sequence; the source is not enumerated until the result is.</returns>
        public static IUniTaskAsyncEnumerable<TSource> DefaultIfEmpty<TSource>(this IUniTaskAsyncEnumerable<TSource> source, TSource defaultValue)
        {
            Error.ThrowArgumentNullException(source, nameof(source));

            return new DefaultIfEmpty<TSource>(source, defaultValue);
        }
    }

    /// <summary>
    /// Sequence wrapper backing the <c>DefaultIfEmpty</c> operators. It only stores
    /// its arguments; all work happens in the enumerator it hands out.
    /// </summary>
    internal sealed class DefaultIfEmpty<TSource> : IUniTaskAsyncEnumerable<TSource>
    {
        readonly IUniTaskAsyncEnumerable<TSource> source;
        readonly TSource defaultValue;

        public DefaultIfEmpty(IUniTaskAsyncEnumerable<TSource> source, TSource defaultValue)
        {
            this.source = source;
            this.defaultValue = defaultValue;
        }

        /// <summary>Creates a fresh enumerator bound to <paramref name="cancellationToken"/>.</summary>
        public IUniTaskAsyncEnumerator<TSource> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        {
            return new _DefaultIfEmpty(source, defaultValue, cancellationToken);
        }

        /// <summary>
        /// Enumerator that forwards the source's elements and substitutes the
        /// default value when the source finishes without producing any.
        /// </summary>
        sealed class _DefaultIfEmpty : MoveNextSource, IUniTaskAsyncEnumerator<TSource>
        {
            enum IteratingState : byte
            {
                // No element has been received from the source yet.
                Empty,
                // At least one element has been received from the source.
                Iterating,
                // The default value has been emitted; the next MoveNextAsync returns false.
                Completed
            }

            // Cached once so that registering the continuation does not create a new delegate per call.
            static readonly Action<object> MoveNextCoreDelegate = MoveNextCore;

            readonly IUniTaskAsyncEnumerable<TSource> source;
            readonly TSource defaultValue;
            CancellationToken cancellationToken;

            IteratingState iteratingState;
            IUniTaskAsyncEnumerator<TSource> enumerator;
            UniTask<bool>.Awaiter awaiter;

            public _DefaultIfEmpty(IUniTaskAsyncEnumerable<TSource> source, TSource defaultValue, CancellationToken cancellationToken)
            {
                this.source = source;
                this.defaultValue = defaultValue;
                this.cancellationToken = cancellationToken;

                this.iteratingState = IteratingState.Empty;
                TaskTracker.TrackActiveTask(this, 3);
            }

            /// <summary>Element produced by the most recent successful <see cref="MoveNextAsync"/>.</summary>
            public TSource Current { get; private set; }

            /// <summary>
            /// Advances to the next element. Throws if the token passed to the
            /// constructor has been cancelled. Returns <c>false</c> immediately once
            /// the default value has already been emitted.
            /// </summary>
            public UniTask<bool> MoveNextAsync()
            {
                cancellationToken.ThrowIfCancellationRequested();
                completionSource.Reset();

                if (iteratingState == IteratingState.Completed)
                {
                    return CompletedTasks.False;
                }

                // The source enumerator is created lazily, on the first advance.
                if (enumerator == null)
                {
                    enumerator = source.GetAsyncEnumerator(cancellationToken);
                }

                awaiter = enumerator.MoveNextAsync().GetAwaiter();

                if (awaiter.IsCompleted)
                {
                    // Already finished: handle the result inline instead of scheduling a continuation.
                    MoveNextCore(this);
                }
                else
                {
                    awaiter.SourceOnCompleted(MoveNextCoreDelegate, this);
                }

                return new UniTask<bool>(this, completionSource.Version);
            }

            /// <summary>
            /// Handles the outcome of the source's <c>MoveNextAsync</c> and completes
            /// the pending task. Static with an explicit state argument so it can be
            /// used through the cached <see cref="MoveNextCoreDelegate"/>.
            /// </summary>
            static void MoveNextCore(object state)
            {
                var self = (_DefaultIfEmpty)state;

                // NOTE(review): when TryGetResult returns false nothing else is done here;
                // presumably the base helper has already reported the failure to
                // completionSource - confirm against MoveNextSource.
                if (self.TryGetResult(self.awaiter, out var result))
                {
                    if (result)
                    {
                        self.iteratingState = IteratingState.Iterating;
                        self.Current = self.enumerator.Current;
                        self.completionSource.TrySetResult(true);
                    }
                    else
                    {
                        if (self.iteratingState == IteratingState.Empty)
                        {
                            // The source ended without yielding anything: emit the default value once.
                            self.iteratingState = IteratingState.Completed;
                            self.Current = self.defaultValue;
                            self.completionSource.TrySetResult(true);
                        }
                        else
                        {
                            self.completionSource.TrySetResult(false);
                        }
                    }
                }
            }

            /// <summary>
            /// Stops task tracking for this enumerator and disposes the source
            /// enumerator if one was ever created.
            /// </summary>
            public UniTask DisposeAsync()
            {
                TaskTracker.RemoveTracking(this);
                if (enumerator != null)
                {
                    return enumerator.DisposeAsync();
                }
                return default;
            }
        }
    }
}