using Cysharp.Threading.Tasks.Internal;
using System;
using System.Threading;

namespace Cysharp.Threading.Tasks.Linq
{
    public static partial class UniTaskAsyncEnumerable
    {
        /// <summary>
        /// Builds an async sequence from a writer delegate. The delegate receives an
        /// <see cref="IAsyncWriter{T}"/> and pushes each element with <c>YieldAsync</c>.
        /// </summary>
        /// <typeparam name="T">Element type of the produced sequence.</typeparam>
        /// <param name="create">
        /// Producer delegate. It is invoked on the first <c>MoveNextAsync</c> call of each enumerator,
        /// with that enumerator's writer and the token passed to <c>GetAsyncEnumerator</c>.
        /// </param>
        /// <returns>A sequence that runs <paramref name="create"/> once per enumerator.</returns>
        public static IUniTaskAsyncEnumerable<T> Create<T>(Func<IAsyncWriter<T>, CancellationToken, UniTask> create)
        {
            // Null check is delegated to the project's Error helper.
            Error.ThrowArgumentNullException(create, nameof(create));
            return new Create<T>(create);
        }
    }

    /// <summary>
    /// Writer handed to the producer delegate of <c>UniTaskAsyncEnumerable.Create</c>.
    /// </summary>
    /// <typeparam name="T">Element type being produced.</typeparam>
    public interface IAsyncWriter<T>
    {
        /// <summary>
        /// Publishes <paramref name="value"/> as the next element of the sequence.
        /// </summary>
        /// <param name="value">The element to publish.</param>
        /// <returns>A task for the producer to await before publishing the next element.</returns>
        UniTask YieldAsync(T value);
    }

    /// <summary>
    /// Sequence implementation behind <c>UniTaskAsyncEnumerable.Create</c>; it only stores the
    /// producer delegate and creates a fresh enumerator per <see cref="GetAsyncEnumerator"/> call.
    /// </summary>
    internal sealed class Create<T> : IUniTaskAsyncEnumerable<T>
    {
        readonly Func<IAsyncWriter<T>, CancellationToken, UniTask> create;

        public Create(Func<IAsyncWriter<T>, CancellationToken, UniTask> create)
        {
            this.create = create;
        }

        /// <summary>Creates an enumerator that will run the producer delegate with the given token.</summary>
        public IUniTaskAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
        {
            return new _Create(create, cancellationToken);
        }

        /// <summary>
        /// Enumerator driving the producer. <c>state</c> values:
        /// <c>-1</c> = producer not started yet, <c>0</c> = producer started and exchanging values
        /// through the writer, <c>-2</c> = finished (completed or faulted).
        /// </summary>
        sealed class _Create : MoveNextSource, IUniTaskAsyncEnumerator<T>
        {
            readonly Func<IAsyncWriter<T>, CancellationToken, UniTask> create;
            readonly CancellationToken cancellationToken;

            int state = -1;
            AsyncWriter writer;

            public _Create(Func<IAsyncWriter<T>, CancellationToken, UniTask> create, CancellationToken cancellationToken)
            {
                this.create = create;
                this.cancellationToken = cancellationToken;

                // NOTE(review): the literal 3 is presumably a stack-frame skip count for the
                // tracker's captured stack trace - confirm against TaskTracker.TrackActiveTask.
                TaskTracker.TrackActiveTask(this, 3);
            }

            /// <summary>Most recent value published by the producer through <c>YieldAsync</c>.</summary>
            public T Current { get; private set; }

            /// <summary>
            /// Stops tracking this enumerator. It does not touch the writer, so a producer that is
            /// currently awaiting <c>YieldAsync</c> is not resumed from here.
            /// </summary>
            public UniTask DisposeAsync()
            {
                TaskTracker.RemoveTracking(this);
                return default;
            }

            /// <summary>
            /// Requests the next element. Returns a task backed by <c>completionSource</c>; the code
            /// in this class completes it with <c>true</c> from <see cref="SetResult"/> and with
            /// <c>false</c> or an exception once the producer has ended.
            /// </summary>
            public UniTask<bool> MoveNextAsync()
            {
                // Already finished: default(UniTask<bool>) is returned without restarting anything.
                if (state == -2) return default;

                completionSource.Reset();
                MoveNext();
                return new UniTask<bool>(this, completionSource.Version);
            }

            void MoveNext()
            {
                try
                {
                    switch (state)
                    {
                        case -1: // init
                            {
                                writer = new AsyncWriter(this);
                                RunWriterTask(create(writer, cancellationToken)).Forget();
                                if (Volatile.Read(ref state) == -2)
                                {
                                    // The producer already finished inside the call above and
                                    // RunWriterTask has completed completionSource.
                                    return; // complete synchronously
                                }
                                state = 0; // wait YieldAsync, it set TrySetResult(true)
                                return;
                            }
                        case 0:
                            // Release the producer parked in YieldAsync so it can publish the next value.
                            writer.SignalWriter();
                            return;
                        default:
                            goto DONE;
                    }
                }
                catch (Exception ex)
                {
                    // Covers a producer delegate that throws before returning its UniTask.
                    state = -2;
                    completionSource.TrySetException(ex);
                    return;
                }

                DONE:
                state = -2;
                completionSource.TrySetResult(false);
                return;
            }

            /// <summary>
            /// Awaits the producer's task and translates its outcome into the end of the sequence:
            /// normal completion yields <c>false</c>, an exception is forwarded to the consumer.
            /// </summary>
            async UniTaskVoid RunWriterTask(UniTask task)
            {
                try
                {
                    await task;
                }
                catch (Exception ex)
                {
                    Volatile.Write(ref state, -2);
                    completionSource.TrySetException(ex);
                    return;
                }

                Volatile.Write(ref state, -2);
                completionSource.TrySetResult(false);
            }

            /// <summary>Called by the writer: stores the value and completes the pending MoveNextAsync with true.</summary>
            public void SetResult(T value)
            {
                Current = value;
                completionSource.TrySetResult(true);
            }
        }

        /// <summary>
        /// Writer passed to the producer. It is also the <see cref="IUniTaskSource"/> behind the task
        /// returned from <see cref="YieldAsync"/>, which is completed by <see cref="SignalWriter"/>.
        /// </summary>
        sealed class AsyncWriter : IUniTaskSource, IAsyncWriter<T>
        {
            readonly _Create enumerator;

            UniTaskCompletionSourceCore<AsyncUnit> core;

            public AsyncWriter(_Create enumerator)
            {
                this.enumerator = enumerator;
            }

            public void GetResult(short token)
            {
                core.GetResult(token);
            }

            public UniTaskStatus GetStatus(short token)
            {
                return core.GetStatus(token);
            }

            public UniTaskStatus UnsafeGetStatus()
            {
                return core.UnsafeGetStatus();
            }

            public void OnCompleted(Action<object> continuation, object state, short token)
            {
                core.OnCompleted(continuation, state, token);
            }

            /// <summary>
            /// Hands <paramref name="value"/> to the enumerator and returns a task that completes
            /// when <see cref="SignalWriter"/> is called (the enumerator does so on its next MoveNext).
            /// </summary>
            public UniTask YieldAsync(T value)
            {
                // Reset before publishing so the returned task observes the new Version.
                core.Reset();
                enumerator.SetResult(value);
                return new UniTask(this, core.Version);
            }

            /// <summary>Completes the task returned by the last <see cref="YieldAsync"/> call.</summary>
            public void SignalWriter()
            {
                core.TrySetResult(AsyncUnit.Default);
            }
        }
    }
}