Skip to content

Commit c8efc06

Browse files
authored
+ TaskHelper.ForEachAsync method (#112)
* + TaskHelper.ForEachAsync
1 parent 821e49f commit c8efc06

3 files changed

Lines changed: 354 additions & 0 deletions

File tree

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
using System;
2+
using System.Collections.Concurrent;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Threading;
6+
using System.Threading.Tasks;
7+
8+
using NUnit.Framework;
9+
10+
#if NET45_OR_GREATER || TARGETS_NETCOREAPP
11+
using TaskEx = System.Threading.Tasks.Task;
12+
#elif NET40_OR_GREATER
13+
using TaskEx = System.Threading.Tasks.TaskEx;
14+
#else
15+
using TaskEx = System.Threading.Tasks.Task;
16+
#endif
17+
18+
namespace CodeJam.Threading
19+
{
20+
public partial class TaskHelperTests
21+
{
22+
private class StubScheduler : TaskScheduler
23+
{
24+
public StubScheduler(int maximumConcurrencyLevel)
25+
{
26+
MaximumConcurrencyLevel = maximumConcurrencyLevel;
27+
}
28+
29+
protected override void QueueTask(Task task) => throw new NotImplementedException();
30+
31+
protected override IEnumerable<Task> GetScheduledTasks() => throw new NotImplementedException();
32+
33+
protected override bool TryExecuteTaskInline(Task task, bool taskWasPreviouslyQueued) => throw new NotImplementedException();
34+
35+
public override int MaximumConcurrencyLevel { get; }
36+
}
37+
38+
[Test]
39+
public void TestGetMaxDegreeOfParallelism()
40+
{
41+
var scheduler = TaskScheduler.Current;
42+
var stubScheduler = new StubScheduler(17);
43+
44+
Assert.AreEqual(scheduler.GetMaxDegreeOfParallelism(1), 1);
45+
Assert.AreEqual(scheduler.GetMaxDegreeOfParallelism(0), Environment.ProcessorCount);
46+
Assert.AreEqual(scheduler.GetMaxDegreeOfParallelism(-1), Environment.ProcessorCount);
47+
48+
Assert.AreEqual(stubScheduler.GetMaxDegreeOfParallelism(1), 1);
49+
Assert.AreEqual(stubScheduler.GetMaxDegreeOfParallelism(20), 17);
50+
Assert.AreEqual(stubScheduler.GetMaxDegreeOfParallelism(0), 17);
51+
Assert.AreEqual(stubScheduler.GetMaxDegreeOfParallelism(-1), 17);
52+
}
53+
54+
[Test]
55+
public void TestWithAggregateExceptions()
56+
{
57+
var tcs = new TaskCompletionSource<int>();
58+
tcs.SetException(new InvalidOperationException());
59+
var errorTask = TaskEx.WhenAll(tcs.Task);
60+
61+
Assert.Throws<InvalidOperationException>(() => errorTask.GetAwaiter().GetResult());
62+
var ex = Assert.Throws<AggregateException>(() => errorTask.WithAggregateException().GetAwaiter().GetResult());
63+
Assert.That(ex.InnerExceptions[0], Is.TypeOf<InvalidOperationException>());
64+
}
65+
66+
[Test]
67+
public void TestForEachAsync()
68+
{
69+
TaskEx.Run(() => TestForEachAsyncCore()).Wait();
70+
}
71+
72+
public async Task TestForEachAsyncCore()
73+
{
74+
var tasks = Enumerable.Range(0, 20).ToArray();
75+
76+
var result = await tasks.ForEachAsync((i, ct) => TaskEx.FromResult(i.ToString()), 4);
77+
78+
CollectionAssert.AreEquivalent(result, tasks.Select(t => t.ToString()));
79+
}
80+
81+
[Test]
82+
public void TestForEachAsyncThrows()
83+
{
84+
var tasks = Enumerable.Range(0, 20).ToArray();
85+
86+
var forEachTask = tasks.ForEachAsync(
87+
(i, ct) => throw new ArgumentException("a"),
88+
4);
89+
90+
var ex = Assert.Throws<AggregateException>(() => forEachTask.GetAwaiter().GetResult());
91+
92+
Assert.That(ex.InnerExceptions.Count, Is.InRange(1, 4));
93+
}
94+
95+
[Test]
96+
public void TestForEachAsyncThrowsBreaks()
97+
{
98+
var tasks = Enumerable.Range(0, 20).ToArray();
99+
var results = new ConcurrentBag<int>();
100+
101+
var forEachTask = tasks.ForEachAsync(
102+
(i, ct) =>
103+
{
104+
results.Add(i);
105+
return i == 0 ? throw new ArgumentException("a") : TaskEx.Delay(-1, ct);
106+
},
107+
4);
108+
109+
var ex = Assert.Throws<AggregateException>(() => forEachTask.GetAwaiter().GetResult());
110+
111+
Assert.AreEqual(ex.InnerExceptions.Count, 1);
112+
Assert.That(results.Count, Is.InRange(1, 4));
113+
}
114+
115+
[Test]
116+
public void ForEachAsyncCancellation()
117+
{
118+
var tasks = Enumerable.Range(0, 20).ToArray();
119+
var results = new ConcurrentBag<int>();
120+
var cts = new CancellationTokenSource();
121+
122+
var forEachTask = tasks.ForEachAsync(
123+
(i, ct) =>
124+
{
125+
results.Add(i);
126+
return TaskEx.Delay(-1, ct);
127+
},
128+
4,
129+
cts.Token);
130+
131+
cts.CancelAfter(TimeSpan.FromSeconds(2));
132+
133+
Assert.Throws<TaskCanceledException>(() => forEachTask.GetAwaiter().GetResult());
134+
Assert.AreEqual(results.Count, 4);
135+
}
136+
}
137+
}
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
using System;
2+
using System.Collections.Concurrent;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Runtime.ExceptionServices;
6+
using System.Threading;
7+
using System.Threading.Tasks;
8+
9+
using JetBrains.Annotations;
10+
11+
#if NET45_OR_GREATER || TARGETS_NETSTANDARD || TARGETS_NETCOREAPP
12+
using TaskEx = System.Threading.Tasks.Task;
13+
#else
14+
using TaskEx = System.Threading.Tasks.TaskEx;
15+
16+
#endif
17+
18+
namespace CodeJam.Threading
19+
{
20+
/// <summary>
21+
/// Helper methods for <see cref="Task"/> and <see cref="Task{TResult}"/>.
22+
/// </summary>
23+
public static partial class TaskHelper
24+
{
25+
private const int _maxProcessorCountRefreshTicks = 30000;
26+
private static volatile int _processorCount;
27+
private static volatile int _lastProcessorCountRefreshTicks;
28+
29+
// BASEDON PlatformHelper.ProcessorCount
30+
private static int ProcessorCount
31+
{
32+
get
33+
{
34+
var tickCount = Environment.TickCount;
35+
if (_processorCount == 0 || tickCount - _lastProcessorCountRefreshTicks >= _maxProcessorCountRefreshTicks)
36+
{
37+
_processorCount = Environment.ProcessorCount;
38+
_lastProcessorCountRefreshTicks = tickCount;
39+
}
40+
return _processorCount;
41+
}
42+
}
43+
44+
/// <summary>
45+
/// Gets the maximum degree of parallelism for the scheduler.
46+
/// Matches to the <see cref="Parallel.ForEach{TSource}(IEnumerable{TSource}, Action{TSource})"/> behavior.
47+
/// Limits <paramref name="value"/> by <see cref="TaskScheduler.MaximumConcurrencyLevel"/> value (if non-zero positiver value).
48+
/// Otherwise, uses <see cref="Environment.ProcessorCount"/> as fallback value.
49+
/// </summary>
50+
public static int GetMaxDegreeOfParallelism([NotNull] this TaskScheduler scheduler, int value)
51+
{
52+
Code.NotNull(scheduler, nameof(scheduler));
53+
54+
var concurrencyLimit = scheduler.MaximumConcurrencyLevel;
55+
56+
if (concurrencyLimit > 0 && concurrencyLimit != int.MaxValue)
57+
return value <= 0 ? concurrencyLimit : Math.Min(concurrencyLimit, value);
58+
59+
return value <= 0 ? ProcessorCount : value;
60+
}
61+
62+
/// <summary>
63+
/// Runs actions over source items concurrently and asynchronously.
64+
/// </summary>
65+
/// <typeparam name="T">Type of items to process</typeparam>
66+
/// <param name="source">The source.</param>
67+
/// <param name="callback">The callback.</param>
68+
/// <param name="maxDegreeOfParallelism">
69+
/// The maximum degree of parallelism. If zero or negative, default scheduler value is used.
70+
/// See <see cref="GetMaxDegreeOfParallelism"/> documentation for more details.
71+
/// </param>
72+
/// <param name="cancellation">The cancellation.</param>
73+
// BASEDON https://stackoverflow.com/a/25877042
74+
public static async Task ForEachAsync<T>(
75+
[NotNull] this IEnumerable<T> source,
76+
[NotNull] Func<T, CancellationToken, Task> callback,
77+
int maxDegreeOfParallelism = 0,
78+
CancellationToken cancellation = default)
79+
{
80+
Code.NotNull(source, nameof(source));
81+
Code.NotNull(callback, nameof(callback));
82+
83+
maxDegreeOfParallelism = TaskScheduler.Current.GetMaxDegreeOfParallelism(maxDegreeOfParallelism);
84+
85+
using (var customCancellation = CreateCancellation(cancellation))
86+
using (customCancellation.CancellationScope())
87+
{
88+
var combinedToken = customCancellation.Token;
89+
90+
await Partitioner.Create(source)
91+
.GetPartitions(maxDegreeOfParallelism)
92+
.Select(
93+
partition => TaskEx.Run(
94+
async () =>
95+
{
96+
try
97+
{
98+
using (partition)
99+
{
100+
while (partition.MoveNext() && !combinedToken.IsCancellationRequested)
101+
{
102+
cancellation.ThrowIfCancellationRequested();
103+
await callback(partition.Current, combinedToken).ConfigureAwait(false);
104+
}
105+
}
106+
}
107+
catch (Exception)
108+
{
109+
// ReSharper disable once AccessToDisposedClosure
110+
customCancellation.Cancel();
111+
throw;
112+
}
113+
},
114+
combinedToken))
115+
.WhenAll()
116+
.WithAggregateException()
117+
.ConfigureAwait(false);
118+
119+
customCancellation.Token.ThrowIfCancellationRequested();
120+
}
121+
}
122+
123+
/// <summary>
124+
/// Runs actions over source items concurrently and asynchronously.
125+
/// </summary>
126+
/// <typeparam name="T">Type of items to process</typeparam>
127+
/// <typeparam name="TResult">The type of the result.</typeparam>
128+
/// <param name="source">The source.</param>
129+
/// <param name="callback">The callback.</param>
130+
/// <param name="maxDegreeOfParallelism">The maximum degree of parallelism. If zero or negative, default scheduler value is used.
131+
/// See <see cref="GetMaxDegreeOfParallelism" /> documentation for more details.</param>
132+
/// <param name="cancellation">The cancellation.</param>
133+
[ItemNotNull]
134+
public static async Task<TResult[]> ForEachAsync<T, TResult>(
135+
[NotNull] this IEnumerable<T> source,
136+
[NotNull] Func<T, CancellationToken, Task<TResult>> callback,
137+
int maxDegreeOfParallelism,
138+
CancellationToken cancellation = default)
139+
{
140+
Code.NotNull(source, nameof(source));
141+
Code.NotNull(callback, nameof(callback));
142+
143+
var queue = new ConcurrentQueue<TResult>();
144+
145+
await source
146+
.ForEachAsync(
147+
async (t, ct) =>
148+
{
149+
var x = await callback(t, ct).ConfigureAwait(false);
150+
queue.Enqueue(x);
151+
},
152+
maxDegreeOfParallelism,
153+
cancellation)
154+
.ConfigureAwait(false);
155+
156+
return queue.ToArray();
157+
}
158+
159+
/// <summary>
160+
/// Simplifies the <see cref="AggregateException"/> handling on await.
161+
/// By default awaiter rethrows only first exception of the <see cref="Exception.InnerException"/>.
162+
/// This helper rethrows original <see cref="AggregateException"/> as is.
163+
/// </summary>
164+
/// <param name="source">The task that may throw <see cref="AggregateException"/>.</param>
165+
// BASEDON https://stackoverflow.com/a/18315625
166+
public static async Task WithAggregateException([NotNull] this Task source)
167+
{
168+
Code.NotNull(source, nameof(source));
169+
try
170+
{
171+
await source.ConfigureAwait(false);
172+
}
173+
catch
174+
{
175+
if (source.Exception != null)
176+
ExceptionDispatchInfo.Capture(source.Exception).Throw();
177+
throw;
178+
}
179+
}
180+
181+
/// <summary>
182+
/// Simplifies the <see cref="AggregateException"/> handling on await.
183+
/// By default awaiter rethrows only first exception of the <see cref="Exception.InnerException"/>.
184+
/// This helper rethrows original <see cref="AggregateException"/> as is.
185+
/// </summary>
186+
/// <param name="source">The task that may throw <see cref="AggregateException"/>.</param>
187+
// BASEDON https://stackoverflow.com/a/18315625
188+
public static async Task<T> WithAggregateException<T>([NotNull] this Task<T> source)
189+
{
190+
Code.NotNull(source, nameof(source));
191+
try
192+
{
193+
return await source.ConfigureAwait(false);
194+
}
195+
catch
196+
{
197+
if (source.Exception != null)
198+
ExceptionDispatchInfo.Capture(source.Exception).Throw();
199+
throw;
200+
}
201+
}
202+
}
203+
}

CodeJam.Main/Threading/TaskHelper.NonGenerated.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,5 +119,19 @@ public static async Task WaitForCancellationAsync(this CancellationToken cancell
119119
{
120120
}
121121
}
122+
123+
/// <summary>
124+
/// Creates safe for await <see cref="TaskCompletionSource{TResult}"/> with <see cref="TaskCreationOptions.RunContinuationsAsynchronously"/> mode.
125+
/// See https://devblogs.microsoft.com/premier-developer/the-danger-of-taskcompletionsourcet-class/ for explanation.
126+
/// </summary>
127+
public static TaskCompletionSource<T> CreateAsyncTaskSource<T>() =>
128+
new TaskCompletionSource<T>(TaskCreationOptions.RunContinuationsAsynchronously);
129+
130+
/// <summary>
131+
/// Creates safe for await <see cref="TaskCompletionSource{TResult}"/> with <see cref="TaskCreationOptions.RunContinuationsAsynchronously"/> mode.
132+
/// See https://devblogs.microsoft.com/premier-developer/the-danger-of-taskcompletionsourcet-class/ for explanation.
133+
/// </summary>
134+
public static TaskCompletionSource<T> CreateAsyncTaskSource<T>(TaskCreationOptions creationOptions) =>
135+
new TaskCompletionSource<T>(creationOptions | TaskCreationOptions.RunContinuationsAsynchronously);
122136
}
123137
}

0 commit comments

Comments
 (0)