#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member

using System;
using System.Collections.Generic;
using System.Threading;
using Cysharp.Threading.Tasks.Internal;

namespace Cysharp.Threading.Tasks
{
    public partial struct UniTask
    {
        /// <summary>
        /// Races a result-producing task against a task without a result.
        /// </summary>
        /// <param name="leftTask">Task whose result is reported if it completes successfully first.</param>
        /// <param name="rightTask">Task without a result.</param>
        /// <returns>
        /// <c>(true, result)</c> when <paramref name="leftTask"/> completes successfully first, or
        /// <c>(false, default)</c> when <paramref name="rightTask"/> completes successfully first.
        /// A failure of either task (or of obtaining its awaiter) is forwarded to the returned task.
        /// </returns>
        public static UniTask<(bool hasResultLeft, T result)> WhenAny<T>(UniTask<T> leftTask, UniTask rightTask)
        {
            return new UniTask<(bool, T)>(new WhenAnyLRPromise<T>(leftTask, rightTask), 0);
        }

        /// <summary>
        /// Completes when the first of <paramref name="tasks"/> completes.
        /// </summary>
        /// <param name="tasks">Tasks to race; must contain at least one task.</param>
        /// <returns>The zero-based index of the winning task and its result.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="tasks"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="tasks"/> is empty.</exception>
        public static UniTask<(int winArgumentIndex, T result)> WhenAny<T>(params UniTask<T>[] tasks)
        {
            if (tasks == null)
            {
                throw new ArgumentNullException(nameof(tasks));
            }

            return new UniTask<(int, T)>(new WhenAnyPromise<T>(tasks, tasks.Length), 0);
        }

        /// <summary>
        /// Completes when the first of <paramref name="tasks"/> completes.
        /// </summary>
        /// <param name="tasks">Tasks to race; must yield at least one task.</param>
        /// <returns>The zero-based index of the winning task and its result.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="tasks"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="tasks"/> is empty.</exception>
        public static UniTask<(int winArgumentIndex, T result)> WhenAny<T>(IEnumerable<UniTask<T>> tasks)
        {
            if (tasks == null)
            {
                throw new ArgumentNullException(nameof(tasks));
            }

            // NOTE(review): Materialize presumably rents a pooled buffer (hence the `using` and the
            // separate span.Length). That is safe here because the promise constructor subscribes
            // to every task before returning.
            using (var span = ArrayPoolUtil.Materialize(tasks))
            {
                return new UniTask<(int, T)>(new WhenAnyPromise<T>(span.Array, span.Length), 0);
            }
        }

        /// <summary>
        /// Completes when the first of <paramref name="tasks"/> completes.
        /// Return value is winArgumentIndex.
        /// </summary>
        /// <param name="tasks">Tasks to race; must contain at least one task.</param>
        /// <returns>The zero-based index of the winning task.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="tasks"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="tasks"/> is empty.</exception>
        public static UniTask<int> WhenAny(params UniTask[] tasks)
        {
            if (tasks == null)
            {
                throw new ArgumentNullException(nameof(tasks));
            }

            return new UniTask<int>(new WhenAnyPromise(tasks, tasks.Length), 0);
        }

        /// <summary>
        /// Completes when the first of <paramref name="tasks"/> completes.
        /// Return value is winArgumentIndex.
        /// </summary>
        /// <param name="tasks">Tasks to race; must yield at least one task.</param>
        /// <returns>The zero-based index of the winning task.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="tasks"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="tasks"/> is empty.</exception>
        public static UniTask<int> WhenAny(IEnumerable<UniTask> tasks)
        {
            if (tasks == null)
            {
                throw new ArgumentNullException(nameof(tasks));
            }

            using (var span = ArrayPoolUtil.Materialize(tasks))
            {
                return new UniTask<int>(new WhenAnyPromise(span.Array, span.Length), 0);
            }
        }

        /// <summary>
        /// Promise backing <see cref="WhenAny{T}(UniTask{T}, UniTask)"/>.
        /// </summary>
        sealed class WhenAnyLRPromise<T> : IUniTaskSource<(bool, T)>
        {
            // Incremented by each successful completion; only the first (== 1) publishes a result.
            // Failures bypass this counter and go straight to core.TrySetException.
            int completedCount;
            UniTaskCompletionSourceCore<(bool, T)> core;

            public WhenAnyLRPromise(UniTask<T> leftTask, UniTask rightTask)
            {
                TaskTracker.TrackActiveTask(this, 3);

                // The right task is subscribed even when the left one fails to yield an awaiter,
                // so both tasks are always consumed.
                SubscribeLeft(leftTask);
                SubscribeRight(rightTask);
            }

            // Runs the left-side handler now if the task is already done, otherwise registers it.
            void SubscribeLeft(UniTask<T> leftTask)
            {
                UniTask<T>.Awaiter awaiter;
                try
                {
                    awaiter = leftTask.GetAwaiter();
                }
                catch (Exception ex)
                {
                    core.TrySetException(ex);
                    return;
                }

                if (awaiter.IsCompleted)
                {
                    TryLeftInvokeContinuation(this, awaiter);
                }
                else
                {
                    // The callback receives all of its data through `state`, so the lambda captures nothing.
                    awaiter.SourceOnCompleted(state =>
                    {
                        using (var t = (StateTuple<WhenAnyLRPromise<T>, UniTask<T>.Awaiter>)state)
                        {
                            TryLeftInvokeContinuation(t.Item1, t.Item2);
                        }
                    }, StateTuple.Create(this, awaiter));
                }
            }

            // Runs the right-side handler now if the task is already done, otherwise registers it.
            void SubscribeRight(UniTask rightTask)
            {
                UniTask.Awaiter awaiter;
                try
                {
                    awaiter = rightTask.GetAwaiter();
                }
                catch (Exception ex)
                {
                    core.TrySetException(ex);
                    return;
                }

                if (awaiter.IsCompleted)
                {
                    TryRightInvokeContinuation(this, awaiter);
                }
                else
                {
                    awaiter.SourceOnCompleted(state =>
                    {
                        using (var t = (StateTuple<WhenAnyLRPromise<T>, UniTask.Awaiter>)state)
                        {
                            TryRightInvokeContinuation(t.Item1, t.Item2);
                        }
                    }, StateTuple.Create(this, awaiter));
                }
            }

            // Observes the left task's outcome: forwards a failure, or publishes (true, result) if first.
            static void TryLeftInvokeContinuation(WhenAnyLRPromise<T> self, in UniTask<T>.Awaiter awaiter)
            {
                T result;
                try
                {
                    result = awaiter.GetResult();
                }
                catch (Exception ex)
                {
                    self.core.TrySetException(ex);
                    return;
                }

                if (Interlocked.Increment(ref self.completedCount) == 1)
                {
                    self.core.TrySetResult((true, result));
                }
            }

            // Observes the right task's outcome: forwards a failure, or publishes (false, default) if first.
            static void TryRightInvokeContinuation(WhenAnyLRPromise<T> self, in UniTask.Awaiter awaiter)
            {
                try
                {
                    awaiter.GetResult();
                }
                catch (Exception ex)
                {
                    self.core.TrySetException(ex);
                    return;
                }

                if (Interlocked.Increment(ref self.completedCount) == 1)
                {
                    self.core.TrySetResult((false, default));
                }
            }

            public (bool, T) GetResult(short token)
            {
                // Stop tracking before handing the outcome (value or exception) over from the core.
                TaskTracker.RemoveTracking(this);
                GC.SuppressFinalize(this);
                return core.GetResult(token);
            }

            public UniTaskStatus GetStatus(short token)
            {
                return core.GetStatus(token);
            }

            public void OnCompleted(Action<object> continuation, object state, short token)
            {
                core.OnCompleted(continuation, state, token);
            }

            public UniTaskStatus UnsafeGetStatus()
            {
                return core.UnsafeGetStatus();
            }

            void IUniTaskSource.GetResult(short token)
            {
                GetResult(token);
            }
        }

        /// <summary>
        /// Promise backing the <c>WhenAny&lt;T&gt;</c> overloads over many result-producing tasks.
        /// </summary>
        sealed class WhenAnyPromise<T> : IUniTaskSource<(int, T)>
        {
            // Incremented by each successful completion; only the first (== 1) publishes a result.
            int completedCount;
            UniTaskCompletionSourceCore<(int, T)> core;

            /// <param name="tasks">Source array; may be longer than <paramref name="tasksLength"/>.</param>
            /// <param name="tasksLength">Number of leading entries of <paramref name="tasks"/> to race.</param>
            /// <exception cref="ArgumentException"><paramref name="tasksLength"/> is zero.</exception>
            public WhenAnyPromise(UniTask<T>[] tasks, int tasksLength)
            {
                if (tasksLength == 0)
                {
                    throw new ArgumentException("The tasks argument contains no tasks.");
                }

                TaskTracker.TrackActiveTask(this, 3);

                // Iterate tasksLength, not tasks.Length: the array may be a larger buffer.
                for (int i = 0; i < tasksLength; i++)
                {
                    UniTask<T>.Awaiter awaiter;
                    try
                    {
                        awaiter = tasks[i].GetAwaiter();
                    }
                    catch (Exception ex)
                    {
                        core.TrySetException(ex);
                        continue; // keep subscribing to (consuming) the remaining tasks.
                    }

                    if (awaiter.IsCompleted)
                    {
                        TryInvokeContinuation(this, awaiter, i);
                    }
                    else
                    {
                        awaiter.SourceOnCompleted(state =>
                        {
                            using (var t = (StateTuple<WhenAnyPromise<T>, UniTask<T>.Awaiter, int>)state)
                            {
                                TryInvokeContinuation(t.Item1, t.Item2, t.Item3);
                            }
                        }, StateTuple.Create(this, awaiter, i));
                    }
                }
            }

            // Observes task i: forwards a failure, or publishes (i, result) if it is the first success.
            static void TryInvokeContinuation(WhenAnyPromise<T> self, in UniTask<T>.Awaiter awaiter, int i)
            {
                T result;
                try
                {
                    result = awaiter.GetResult();
                }
                catch (Exception ex)
                {
                    self.core.TrySetException(ex);
                    return;
                }

                if (Interlocked.Increment(ref self.completedCount) == 1)
                {
                    self.core.TrySetResult((i, result));
                }
            }

            public (int, T) GetResult(short token)
            {
                TaskTracker.RemoveTracking(this);
                GC.SuppressFinalize(this);
                return core.GetResult(token);
            }

            public UniTaskStatus GetStatus(short token)
            {
                return core.GetStatus(token);
            }

            public void OnCompleted(Action<object> continuation, object state, short token)
            {
                core.OnCompleted(continuation, state, token);
            }

            public UniTaskStatus UnsafeGetStatus()
            {
                return core.UnsafeGetStatus();
            }

            void IUniTaskSource.GetResult(short token)
            {
                GetResult(token);
            }
        }

        /// <summary>
        /// Promise backing the non-generic <c>WhenAny</c> overloads; the result is the winner's index.
        /// </summary>
        sealed class WhenAnyPromise : IUniTaskSource<int>
        {
            // Incremented by each successful completion; only the first (== 1) publishes a result.
            int completedCount;
            UniTaskCompletionSourceCore<int> core;

            /// <param name="tasks">Source array; may be longer than <paramref name="tasksLength"/>.</param>
            /// <param name="tasksLength">Number of leading entries of <paramref name="tasks"/> to race.</param>
            /// <exception cref="ArgumentException"><paramref name="tasksLength"/> is zero.</exception>
            public WhenAnyPromise(UniTask[] tasks, int tasksLength)
            {
                if (tasksLength == 0)
                {
                    throw new ArgumentException("The tasks argument contains no tasks.");
                }

                TaskTracker.TrackActiveTask(this, 3);

                // Iterate tasksLength, not tasks.Length: the array may be a larger buffer.
                for (int i = 0; i < tasksLength; i++)
                {
                    UniTask.Awaiter awaiter;
                    try
                    {
                        awaiter = tasks[i].GetAwaiter();
                    }
                    catch (Exception ex)
                    {
                        core.TrySetException(ex);
                        continue; // keep subscribing to (consuming) the remaining tasks.
                    }

                    if (awaiter.IsCompleted)
                    {
                        TryInvokeContinuation(this, awaiter, i);
                    }
                    else
                    {
                        awaiter.SourceOnCompleted(state =>
                        {
                            using (var t = (StateTuple<WhenAnyPromise, UniTask.Awaiter, int>)state)
                            {
                                TryInvokeContinuation(t.Item1, t.Item2, t.Item3);
                            }
                        }, StateTuple.Create(this, awaiter, i));
                    }
                }
            }

            // Observes task i: forwards a failure, or publishes i if it is the first success.
            static void TryInvokeContinuation(WhenAnyPromise self, in UniTask.Awaiter awaiter, int i)
            {
                try
                {
                    awaiter.GetResult();
                }
                catch (Exception ex)
                {
                    self.core.TrySetException(ex);
                    return;
                }

                if (Interlocked.Increment(ref self.completedCount) == 1)
                {
                    self.core.TrySetResult(i);
                }
            }

            public int GetResult(short token)
            {
                TaskTracker.RemoveTracking(this);
                GC.SuppressFinalize(this);
                return core.GetResult(token);
            }

            public UniTaskStatus GetStatus(short token)
            {
                return core.GetStatus(token);
            }

            public void OnCompleted(Action<object> continuation, object state, short token)
            {
                core.OnCompleted(continuation, state, token);
            }

            public UniTaskStatus UnsafeGetStatus()
            {
                return core.UnsafeGetStatus();
            }

            void IUniTaskSource.GetResult(short token)
            {
                GetResult(token);
            }
        }
    }
}