using Cysharp.Threading.Tasks.Internal;
using System;
using System.Collections.Generic;
using System.Threading;

namespace Cysharp.Threading.Tasks.Linq
{
    public static partial class UniTaskAsyncEnumerable
    {
        /// <summary>
        /// Produces the set difference of two async sequences: the distinct elements of
        /// <paramref name="first"/> that do not appear in <paramref name="second"/>,
        /// compared with <see cref="EqualityComparer{T}.Default"/>.
        /// </summary>
        /// <typeparam name="TSource">Element type of both sequences.</typeparam>
        /// <param name="first">Sequence whose elements are yielded.</param>
        /// <param name="second">Sequence whose elements are excluded from the result.</param>
        /// <returns>A lazily evaluated async sequence; nothing is enumerated until it is iterated.</returns>
        /// <remarks>
        /// Both arguments are checked eagerly through <c>Error.ThrowArgumentNullException</c>
        /// before the result is created.
        /// </remarks>
        public static IUniTaskAsyncEnumerable<TSource> Except<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second)
        {
            Error.ThrowArgumentNullException(first, nameof(first));
            Error.ThrowArgumentNullException(second, nameof(second));

            return new Except<TSource>(first, second, EqualityComparer<TSource>.Default);
        }

        /// <summary>
        /// Produces the set difference of two async sequences: the distinct elements of
        /// <paramref name="first"/> that do not appear in <paramref name="second"/>,
        /// using <paramref name="comparer"/> to decide equality.
        /// </summary>
        /// <typeparam name="TSource">Element type of both sequences.</typeparam>
        /// <param name="first">Sequence whose elements are yielded.</param>
        /// <param name="second">Sequence whose elements are excluded from the result.</param>
        /// <param name="comparer">Equality comparer applied to every element of both sequences.</param>
        /// <returns>A lazily evaluated async sequence; nothing is enumerated until it is iterated.</returns>
        /// <remarks>
        /// All three arguments are checked eagerly through <c>Error.ThrowArgumentNullException</c>
        /// before the result is created.
        /// </remarks>
        public static IUniTaskAsyncEnumerable<TSource> Except<TSource>(this IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
        {
            Error.ThrowArgumentNullException(first, nameof(first));
            Error.ThrowArgumentNullException(second, nameof(second));
            Error.ThrowArgumentNullException(comparer, nameof(comparer));

            return new Except<TSource>(first, second, comparer);
        }
    }

    /// <summary>
    /// Deferred async sequence backing <c>UniTaskAsyncEnumerable.Except</c>. It only captures
    /// its inputs; a new enumerator is created for each call to <see cref="GetAsyncEnumerator"/>.
    /// </summary>
    internal sealed class Except<TSource> : IUniTaskAsyncEnumerable<TSource>
    {
        readonly IUniTaskAsyncEnumerable<TSource> first;
        readonly IUniTaskAsyncEnumerable<TSource> second;
        readonly IEqualityComparer<TSource> comparer;

        public Except(IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer)
        {
            this.first = first;
            this.second = second;
            this.comparer = comparer;
        }

        /// <summary>Creates an enumerator bound to <paramref name="cancellationToken"/>.</summary>
        public IUniTaskAsyncEnumerator<TSource> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        {
            return new _Except(first, second, comparer, cancellationToken);
        }

        /// <summary>
        /// Enumerator that first collects <c>second</c> into a hash set and then streams
        /// <c>first</c>, yielding only the elements that can still be added to that set.
        /// </summary>
        class _Except : AsyncEnumeratorBase<TSource, TSource>
        {
            // Cached once so registering the continuation does not allocate a delegate per enumerator.
            static readonly Action<object> HashSetAsyncCoreDelegate = HashSetAsyncCore;

            readonly IEqualityComparer<TSource> comparer;
            readonly IUniTaskAsyncEnumerable<TSource> second;

            // Starts as the contents of 'second'; every element yielded from 'first' is added
            // as well, which is what makes the output distinct.
            HashSet<TSource> set;
            UniTask<HashSet<TSource>>.Awaiter awaiter;

            public _Except(IUniTaskAsyncEnumerable<TSource> first, IUniTaskAsyncEnumerable<TSource> second, IEqualityComparer<TSource> comparer, CancellationToken cancellationToken)
                : base(first, cancellationToken)
            {
                this.second = second;
                this.comparer = comparer;
            }

            /// <summary>
            /// Starts loading <c>second</c> into the exclusion set the first time it is called.
            /// </summary>
            /// <returns>
            /// <c>false</c> when the set already exists; <c>true</c> when this call started
            /// the load and took over advancing the source.
            /// NOTE(review): the exact meaning of the return value is defined by
            /// AsyncEnumeratorBase, which is not visible here - confirm.
            /// </returns>
            protected override bool OnFirstIteration()
            {
                if (set != null)
                {
                    return false;
                }

                // Build the set with the caller's comparer. Without it the set would use default
                // equality and the comparer passed to Except would have no effect at all.
                awaiter = second.ToHashSetAsync(comparer, cancellationToken).GetAwaiter();

                if (awaiter.IsCompleted)
                {
                    // Completed synchronously: take the result and begin pulling from 'first' right away.
                    set = awaiter.GetResult();
                    SourceMoveNext();
                }
                else
                {
                    // Resume in HashSetAsyncCore once the set is available.
                    awaiter.SourceOnCompleted(HashSetAsyncCoreDelegate, this);
                }

                return true;
            }

            /// <summary>
            /// Continuation invoked when the asynchronous load of <c>second</c> finishes.
            /// </summary>
            /// <param name="state">The <see cref="_Except"/> instance that registered the continuation.</param>
            static void HashSetAsyncCore(object state)
            {
                var self = (_Except)state;

                // The set is stored and the source advanced only when TryGetResult succeeds;
                // presumably the base class reports the failure itself otherwise - confirm.
                if (self.TryGetResult(self.awaiter, out var result))
                {
                    self.set = result;
                    self.SourceMoveNext();
                }
            }

            /// <summary>
            /// Decides whether the current source element belongs in the output.
            /// </summary>
            /// <param name="sourceHasCurrent">Whether the source produced an element.</param>
            /// <param name="result">Set to <c>true</c> when <c>Current</c> was assigned; <c>false</c> otherwise.</param>
            /// <returns>
            /// <c>true</c> when <paramref name="result"/> is final for this step; <c>false</c> when
            /// the element was rejected (presumably the base class then advances the source
            /// again - confirm against AsyncEnumeratorBase).
            /// </returns>
            protected override bool TryMoveNextCore(bool sourceHasCurrent, out bool result)
            {
                if (!sourceHasCurrent)
                {
                    // Source is exhausted.
                    result = false;
                    return true;
                }

                var candidate = SourceCurrent;

                // Add fails when the element is already in the set, i.e. it came from 'second'
                // or was already yielded from 'first'.
                if (!set.Add(candidate))
                {
                    result = default;
                    return false;
                }

                Current = candidate;
                result = true;
                return true;
            }
        }
    }
}