Thread-safe Cached Enumerator - lock with yield


Question


I have a custom "CachedEnumerable" class (inspired by Caching IEnumerable) that I need to make thread safe for my ASP.NET Core web app.


Is the following implementation of the enumerator thread safe? (All other reads/writes to IList _cache are locked appropriately.) (Possibly related to "Does the C# Yield free a lock?")


And more specifically, if there are 2 threads accessing the enumerator, how do I prevent one thread from incrementing "index" and causing a second enumerating thread to get the wrong element from the _cache (i.e., the element at index + 1 instead of the one at index)? Is this race condition a real concern?

public IEnumerator<T> GetEnumerator()
{
    var index = 0;

    while (true)
    {
        T current;
        lock (_enumeratorLock)
        {
            if (index >= _cache.Count && !MoveNext()) break;
            current = _cache[index];
            index++;
        }
        yield return current;
    }
}
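A detail that bears on the race described above: `index` is a local variable of an iterator method, so the compiler stores it in the generated state machine, and each GetEnumerator() call (for example, each foreach) produces an independent enumerator with its own copy of it. Two threads that each run their own foreach never share `index`; the concern only applies when a single enumerator object is shared between threads. A minimal sketch with two independent cursors over one iterator (`IndependentCursorsDemo` and `Numbers` are hypothetical names used only for illustration):

```csharp
using System.Collections.Generic;

public static class IndependentCursorsDemo
{
    // Iterator method: 'index' lives in the compiler-generated state machine,
    // so every enumerator obtained from the returned sequence has its own copy.
    public static IEnumerable<int> Numbers(int count)
    {
        var index = 0;
        while (index < count)
        {
            yield return index;
            index++;
        }
    }

    // Advances one enumerator twice and a second enumerator once, over the same sequence.
    // The two cursors do not affect each other, so this returns (1, 0).
    public static (int First, int Second) Run()
    {
        IEnumerable<int> sequence = Numbers(10);
        using var first = sequence.GetEnumerator();
        using var second = sequence.GetEnumerator();
        first.MoveNext();
        first.MoveNext();
        second.MoveNext();
        return (first.Current, second.Current);
    }
}
```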


Full code of my version of CachedEnumerable:

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;

    public class CachedEnumerable<T> : IDisposable, IEnumerable<T>
    {
        IEnumerator<T> _enumerator;
        private IList<T> _cache = new List<T>();
        public bool CachingComplete { get; private set; } = false;

        public CachedEnumerable(IEnumerable<T> enumerable)
        {
            switch (enumerable)
            {
                case CachedEnumerable<T> cachedEnumerable: //This case is actually dealt with by the extension method.
                    _cache = cachedEnumerable._cache;
                    CachingComplete = cachedEnumerable.CachingComplete;
                    _enumerator = cachedEnumerable.GetEnumerator();

                    break;
                case IList<T> list:
                    //_cache = list; //without clone...
                    //Clone:
                    _cache = new T[list.Count];
                    list.CopyTo((T[]) _cache, 0);
                    CachingComplete = true;
                    break;
                default:
                    _enumerator = enumerable.GetEnumerator();
                    break;
            }
        }

        public CachedEnumerable(IEnumerator<T> enumerator)
        {
            _enumerator = enumerator;
        }

        private int CurCacheCount
        {
            get
            {
                lock (_enumeratorLock)
                {
                    return _cache.Count;
                }
            }
        }

        public IEnumerator<T> GetEnumerator()
        {
            var index = 0;

            while (true)
            {
                T current;
                lock (_enumeratorLock)
                {
                    if (index >= _cache.Count && !MoveNext()) break;
                    current = _cache[index];
                    index++;
                }
                yield return current;
            }
        }

        //private readonly AsyncLock _enumeratorLock = new AsyncLock();
        private readonly object _enumeratorLock = new object();

        private bool MoveNext()
        {
            if (CachingComplete) return false;

            if (_enumerator != null && _enumerator.MoveNext()) //The null check should have been unnecessary b/c of the lock...
            {
                _cache.Add(_enumerator.Current);
                return true;
            }
            else
            {
                CachingComplete = true;
                DisposeWrappedEnumerator(); //Release the enumerator, as it is no longer needed.
            }

            return false;
        }

        public T ElementAt(int index)
        {
            lock (_enumeratorLock)
            {
                if (index < _cache.Count)
                {
                    return _cache[index];
                }
            }

            EnumerateUntil(index);

            lock (_enumeratorLock)
            {
                if (_cache.Count <= index) throw new ArgumentOutOfRangeException(nameof(index));
                return _cache[index];
            }
        }


        public bool TryGetElementAt(int index, out T value)
        {
            lock (_enumeratorLock)
            {
                value = default;
                if (index < CurCacheCount)
                {
                    value = _cache[index];
                    return true;
                }
            }

            EnumerateUntil(index);

            lock (_enumeratorLock)
            {
                if (_cache.Count <= index) return false;
                value = _cache[index];
            }

            return true;
        }

        private void EnumerateUntil(int index)
        {
            while (true)
            {
                lock (_enumeratorLock)
                {
                    if (_cache.Count > index || !MoveNext()) break;
                }
            }
        }


        public void Dispose()
        {
            DisposeWrappedEnumerator();
        }

        private void DisposeWrappedEnumerator()
        {
            if (_enumerator != null)
            {
                _enumerator.Dispose();
                _enumerator = null;
                if (_cache is List<T> list)
                {
                    list.TrimExcess(); // List<T> has no Trim(); TrimExcess() releases unused capacity
                }
            }
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        public int CachedCount
        {
            get
            {
                lock (_enumeratorLock)
                {
                    return _cache.Count;
                }
            }
        }

        public int Count()
        {
            if (CachingComplete)
            {
                return _cache.Count;
            }

            EnsureCachingComplete();

            return _cache.Count;
        }

        private void EnsureCachingComplete()
        {
            if (CachingComplete)
            {
                return;
            }

            //Enumerate the rest of the collection
            while (!CachingComplete)
            {
                lock (_enumeratorLock)
                {
                    if (!MoveNext()) break;
                }
            }
        }

        public T[] ToArray()
        {
            EnsureCachingComplete();
            //Once Caching is complete, we don't need to lock
            if (!(_cache is T[] array))
            {
                array = _cache.ToArray();
                _cache = array;
            }

            return array;
        }

        public T this[int index] => ElementAt(index);
    }

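public static class CachedEnumerableExtensions // hypothetical name: the enclosing static class was not shown
{
    public static CachedEnumerable<T> Cached<T>(this IEnumerable<T> source)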
    {
        //no gain in caching a cache.
        if (source is CachedEnumerable<T> cached)
        {
            return cached;
        }

        return new CachedEnumerable<T>(source);
    }
}


Basic Usage: (Although not a meaningful use case)

var cached = expensiveEnumerable.Cached();
foreach (var element in cached) {
   Console.WriteLine(element);
}


Update


I tested the current implementation based on @Theodor's answer (https://stackoverflow.com/a/58547863/5683904) and confirmed (AFAICT) that it is thread-safe when enumerated with foreach, without creating duplicate values:

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

class Program
{
    static async Task Main(string[] args)
    {
        var enumerable = Enumerable.Range(0, 1_000_000);
        var cachedEnumerable = new CachedEnumerable<int>(enumerable);
        var c = new ConcurrentDictionary<int, List<int>>();
        // Task.Run puts each consumer on the thread pool. Without it, Test (which contains no
        // active await) would run to completion synchronously inside Select, one call after another.
        var tasks = Enumerable.Range(1, 100).Select(id => Task.Run(() => Test(id, cachedEnumerable, c)));
        await Task.WhenAll(tasks);
        foreach (var keyValuePair in c)
        {
            var hasDuplicates = keyValuePair.Value.Distinct().Count() != keyValuePair.Value.Count;
            Console.WriteLine($"Task #{keyValuePair.Key} count: {keyValuePair.Value.Count}. Has duplicates? {hasDuplicates}");
        }
    }

    static async Task Test(int id, IEnumerable<int> cache, ConcurrentDictionary<int, List<int>> c)
    {
        foreach (var i in cache)
        {
            //await Task.Delay(10);
            c.AddOrUpdate(id, v => new List<int>() {i}, (k, v) =>
            {
                v.Add(i);
                return v;
            });
        }
    }
}

Answer


Your class is not thread safe, because shared state is mutated in unprotected regions inside your class. The unprotected regions are:

  1. The constructor
  2. The Dispose method

The shared state is:

  1. The _enumerator private field
  2. The _cache private field
  3. The CachingComplete public property


Some other issues regarding your class:

  1. Implementing IDisposable places on the caller the responsibility to dispose your class. There is no need for IEnumerables to be disposable. IEnumerators, on the contrary, are disposable, but the language disposes them automatically (a feature of the foreach statement).
  2. Your class offers extended functionality not expected from an IEnumerable (ElementAt, Count, etc.). Maybe you intended to implement a CachedList instead? Without implementing the IList<T> interface, LINQ methods like Count() and ToArray() cannot take advantage of your extended functionality, and will take the slow path, as they do with plain vanilla IEnumerables.
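To illustrate the second point: Enumerable.Count() checks whether its source implements ICollection<T> and, if so, just reads the Count property; a source that exposes only IEnumerable<T> gets enumerated from start to end. A minimal sketch, assuming a hypothetical TrackedCollection<T> wrapper that records how often it is enumerated (the Hide helper is likewise illustrative):

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;

// Hypothetical read-only wrapper that implements ICollection<T> and counts enumerations.
public sealed class TrackedCollection<T> : ICollection<T>
{
    private readonly List<T> _items;
    public int EnumerationCount { get; private set; }

    public TrackedCollection(IEnumerable<T> items) => _items = new List<T>(items);

    public int Count => _items.Count;   // read directly by Enumerable.Count()'s fast path
    public bool IsReadOnly => true;

    public IEnumerator<T> GetEnumerator()
    {
        EnumerationCount++;
        return _items.GetEnumerator();
    }
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

    public bool Contains(T item) => _items.Contains(item);
    public void CopyTo(T[] array, int arrayIndex) => _items.CopyTo(array, arrayIndex);
    public void Add(T item) => throw new NotSupportedException();
    public bool Remove(T item) => throw new NotSupportedException();
    public void Clear() => throw new NotSupportedException();
}

public static class FastPathDemo
{
    // Exposes only IEnumerable<T>, the way a plain iterator does.
    private static IEnumerable<T> Hide<T>(IEnumerable<T> source)
    {
        foreach (var item in source) yield return item;
    }

    public static (int FastPathEnumerations, int SlowPathEnumerations) Run()
    {
        var fast = new TrackedCollection<int>(Enumerable.Range(0, 100));
        Enumerable.Count(fast);        // ICollection<T> detected: Count is read, nothing is enumerated
        var slow = new TrackedCollection<int>(Enumerable.Range(0, 100));
        Enumerable.Count(Hide(slow));  // only IEnumerable<T> visible: the source is enumerated once
        return (fast.EnumerationCount, slow.EnumerationCount);
    }
}
```

Run() returns (0, 1): the collection seen as ICollection<T> is never enumerated, while the same kind of collection hidden behind an iterator is enumerated once.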

Update: I just noticed another thread-safety issue. This one is related to the public IEnumerator<T> GetEnumerator() method. The enumerator is compiler-generated, since the method is an iterator (it uses yield return). Compiler-generated enumerators are not thread safe. Consider this code, for example:

var enumerable = Enumerable.Range(0, 1_000_000);
var cachedEnumerable = new CachedEnumerable<int>(enumerable);
var enumerator = cachedEnumerable.GetEnumerator();
var tasks = Enumerable.Range(1, 4).Select(id => Task.Run(() =>
{
    int count = 0;
    while (enumerator.MoveNext())
    {
        count++;
    }
    Console.WriteLine($"Task #{id} count: {count}");
})).ToArray();
Task.WaitAll(tasks);


Four threads are concurrently using the same IEnumerator. The enumerable has 1,000,000 items. You might expect each thread to enumerate ~250,000 items, but that's not what happens.

Output:

Task #1 count: 0
Task #4 count: 0
Task #3 count: 0
Task #2 count: 1000000


The MoveNext in the line while (enumerator.MoveNext()) is not your safe MoveNext. It is the compiler-generated, unsafe MoveNext. Although unsafe, it includes a mechanism, probably intended for dealing with exceptions, that temporarily marks the enumerator as finished before calling externally provided code. So when multiple threads call MoveNext concurrently, all but the first get a return value of false and terminate the enumeration immediately, having completed zero loops. To solve this, you would probably have to write your own IEnumerator class.


Update: Actually, my last point about thread-safe enumeration is a bit unfair, because enumerating with the IEnumerator interface is an inherently unsafe operation that cannot be fixed without the cooperation of the calling code. Obtaining the next element is not atomic: it involves two steps (calling MoveNext(), then reading Current), and another thread can advance the enumerator in between. So your thread-safety concerns are limited to protecting the internal state of your class (the fields _enumerator, _cache and CachingComplete). These are left unprotected only in the constructor and in the Dispose method, but I suppose that normal use of your class may not follow the code paths that create race conditions resulting in internal state corruption.
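Because the race lives between the two calls, one way out is an API that performs both steps under a single lock, similar in spirit to Try-style methods such as ConcurrentQueue<T>.TryDequeue. A minimal sketch, assuming callers can use a TryTake(out T) method instead of IEnumerator<T> (SharedCursor<T> and SharedCursorDemo are hypothetical names, not part of the question's code):

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

// Hypothetical shared cursor: MoveNext() and Current are combined into one locked call,
// so no other thread can advance the source between the two steps.
public sealed class SharedCursor<T> : IDisposable
{
    private readonly object _lock = new object();
    private readonly IEnumerator<T> _source;
    private bool _finished;

    public SharedCursor(IEnumerable<T> source) => _source = source.GetEnumerator();

    public bool TryTake(out T item)
    {
        lock (_lock)
        {
            if (!_finished && _source.MoveNext())
            {
                item = _source.Current;   // read while still holding the lock
                return true;
            }
            _finished = true;             // never call MoveNext on the source again
            item = default;
            return false;
        }
    }

    public void Dispose()
    {
        lock (_lock) { _finished = true; _source.Dispose(); }
    }
}

public static class SharedCursorDemo
{
    // Four workers drain one cursor; returns (total items taken, distinct items taken).
    public static (int Total, int Distinct) Run(int count)
    {
        using var cursor = new SharedCursor<int>(Enumerable.Range(0, count));
        var taken = new List<int>[4];
        var tasks = Enumerable.Range(0, 4).Select(w => Task.Run(() =>
        {
            var local = new List<int>();
            while (cursor.TryTake(out var item)) local.Add(item);
            taken[w] = local;             // each worker writes only its own slot
        })).ToArray();
        Task.WaitAll(tasks);
        var all = taken.SelectMany(list => list).ToList();
        return (all.Count, all.Distinct().Count());
    }
}
```

Total and Distinct should both equal count: every element is handed out exactly once. The trade-off is that callers give up foreach.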


Personally, I would prefer to take care of these code paths too, rather than leave them to the whims of chance.
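One way to take care of the Dispose path, sketched on a hypothetical, stripped-down cache (GuardedCache<T> is not the question's class; it keeps only the kind of state the answer lists): every access to the enumerator, the list and the completion flag, including Dispose, goes through the same lock, and the constructor never aliases another instance's list, which would be guarded by a different lock object.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical, stripped-down cache: source enumerator, cache list and completion flag,
// all touched only under _lock.
public sealed class GuardedCache<T> : IDisposable
{
    private readonly object _lock = new object();
    private readonly List<T> _cache = new List<T>();   // owned by this instance, never shared
    private IEnumerator<T> _enumerator;
    private bool _cachingComplete;

    public GuardedCache(IEnumerable<T> source)
    {
        _enumerator = (source ?? throw new ArgumentNullException(nameof(source))).GetEnumerator();
    }

    // Returns the element at 'index', pulling from the source as needed.
    public bool TryGetElementAt(int index, out T value)
    {
        lock (_lock)
        {
            while (_cache.Count <= index && !_cachingComplete)
            {
                if (_enumerator.MoveNext()) _cache.Add(_enumerator.Current);
                else ReleaseSource();
            }
            if (index >= 0 && index < _cache.Count) { value = _cache[index]; return true; }
            value = default;
            return false;
        }
    }

    // Same lock as the readers: Dispose cannot run in the middle of a MoveNext.
    public void Dispose()
    {
        lock (_lock) { ReleaseSource(); }
    }

    // Caller must hold _lock.
    private void ReleaseSource()
    {
        _cachingComplete = true;
        _enumerator?.Dispose();
        _enumerator = null;
    }
}
```

After Dispose, already cached elements remain readable and requests beyond the cache return false, instead of touching an enumerator that another thread has just disposed.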


Update: I wrote a cache for IAsyncEnumerables to demonstrate an alternative technique. The enumeration of the source IAsyncEnumerable is not driven by the callers (using locks or semaphores to obtain exclusive access), but by a separate worker task. The first caller starts the worker task. Each caller first yields all the items that are already cached, and then awaits either more items or a notification that there are no more items. As the notification mechanism I used a TaskCompletionSource<bool>. A lock is still used to ensure that all access to shared resources is synchronized.

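using System;
using System.Collections.Generic;
using System.Runtime.ExceptionServices;
using System.Threading;
using System.Threading.Tasks;

public class CachedAsyncEnumerable<T> : IAsyncEnumerable<T>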
{
    private readonly object _locker = new object();
    private IAsyncEnumerable<T> _source;
    private Task _sourceEnumerationTask;
    private List<T> _buffer;
    private TaskCompletionSource<bool> _moveNextTCS;
    private Exception _sourceEnumerationException;
    private int _sourceEnumerationVersion; // Incremented on exception

    public CachedAsyncEnumerable(IAsyncEnumerable<T> source)
    {
        _source = source ?? throw new ArgumentNullException(nameof(source));
    }

    public async IAsyncEnumerator<T> GetAsyncEnumerator(
        CancellationToken cancellationToken = default)
    {
        lock (_locker)
        {
            if (_sourceEnumerationTask == null)
            {
                _buffer = new List<T>();
                _moveNextTCS = new TaskCompletionSource<bool>();
                _sourceEnumerationTask = Task.Run(
                    () => EnumerateSourceAsync(cancellationToken));
            }
        }
        int index = 0;
        int localVersion = -1;
        while (true)
        {
            T current = default;
            Task<bool> moveNextTask = null;
            lock (_locker)
            {
                if (localVersion == -1)
                {
                    localVersion = _sourceEnumerationVersion;
                }
                else if (_sourceEnumerationVersion != localVersion)
                {
                    ExceptionDispatchInfo
                        .Capture(_sourceEnumerationException).Throw();
                }
                if (index < _buffer.Count)
                {
                    current = _buffer[index];
                    index++;
                }
                else
                {
                    moveNextTask = _moveNextTCS.Task;
                }
            }
            if (moveNextTask == null)
            {
                yield return current;
                continue;
            }
            var moved = await moveNextTask;
            if (!moved) yield break;
            lock (_locker)
            {
                current = _buffer[index];
                index++;
            }
            yield return current;
        }
    }

    private async Task EnumerateSourceAsync(CancellationToken cancellationToken)
    {
        TaskCompletionSource<bool> localMoveNextTCS;
        try
        {
            await foreach (var item in _source.WithCancellation(cancellationToken))
            {
                lock (_locker)
                {
                    _buffer.Add(item);
                    localMoveNextTCS = _moveNextTCS;
                    _moveNextTCS = new TaskCompletionSource<bool>();
                }
                localMoveNextTCS.SetResult(true);
            }
            lock (_locker)
            {
                localMoveNextTCS = _moveNextTCS;
                _buffer.TrimExcess();
                _source = null;
            }
            localMoveNextTCS.SetResult(false);
        }
        catch (Exception ex)
        {
            lock (_locker)
            {
                localMoveNextTCS = _moveNextTCS;
                _sourceEnumerationException = ex;
                _sourceEnumerationVersion++;
                _sourceEnumerationTask = null;
            }
            localMoveNextTCS.SetException(ex);
        }
    }
}


This implementation follows a specific strategy for dealing with exceptions. If an exception occurs while enumerating the source IAsyncEnumerable, the exception will be propagated to all current callers, the currently used IAsyncEnumerator will be discarded, and the incomplete cached data will be discarded too. A new worker-task may start again later, when the next enumeration request is received.
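The "propagated to all current callers" part follows from every waiting caller awaiting the same _moveNextTCS.Task: faulting that single task faults every await on it. A minimal sketch of just that mechanism, outside the class (SharedSignalDemo and ObserveAsync are hypothetical names):

```csharp
using System;
using System.Threading.Tasks;

public static class SharedSignalDemo
{
    // Several consumers await one TaskCompletionSource, as the callers of
    // CachedAsyncEnumerable await _moveNextTCS.Task; a single SetException faults them all.
    // Returns how many of the awaiters observed the exception.
    public static async Task<int> Run(int consumers)
    {
        var tcs = new TaskCompletionSource<bool>();
        var waiters = new Task<bool>[consumers];
        for (int i = 0; i < consumers; i++)
            waiters[i] = ObserveAsync(tcs.Task);

        tcs.SetException(new InvalidOperationException("source enumeration failed"));

        bool[] results = await Task.WhenAll(waiters);
        int failed = 0;
        foreach (var sawException in results)
            if (sawException) failed++;
        return failed;
    }

    // Awaits the shared task; reports true if the await threw.
    private static async Task<bool> ObserveAsync(Task<bool> moveNextTask)
    {
        try { await moveNextTask; return false; }
        catch (InvalidOperationException) { return true; }
    }
}
```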
