Multi-word CAS를 구현해봤습니다.

https://paulcavallaro.com/blog/a-practical-multi-word-compare-and-swap-operation/

이 링크를 참고하면서 구현한 코드입니다:

rdcss.h

/*
 * Based on the paper:
 * “A Practical Multi-Word Compare-and-Swap Operation”
 * by Timothy L. Harris, Keir Fraser, and Ian A. Pratt (DISC 2002).
 */

#pragma once

#include <cstddef>
#include <atomic>

namespace lockfree_op
{
    /// Tag stored next to every atomic_ext value; it says how `value`
    /// has to be interpreted.
    enum class value_status_t : size_t
    {
        NORMAL = 0,     ///< `value` is a plain user value.
        RDCSS = 1,      ///< `value` is the address of an rdcss_descriptor.
        DCAS = 2,       ///< `value` is the address of a dcas_descriptor.
        UNDECIDED = 3,  ///< The word is locked by a cas1() in progress.
    };

    struct alignas(sizeof(size_t)) atomic_ext
    {
        std::atomic_size_t value = 0;
        std::atomic<value_status_t> status = value_status_t::NORMAL;
    };

    /// Plain (non-atomic) snapshot of an atomic_ext, as returned by cas1().
    struct atomic_captured
    {
        size_t value{};
        value_status_t status{value_status_t::NORMAL};
    };

    /// Arguments of one RDCSS (restricted double-compare single-swap):
    /// if `atom1` holds `expected1` and `atom2` holds (expected2, NORMAL),
    /// `atom2` becomes (new2, status2).
    /// The descriptor's address is published in `atom2` while the operation
    /// is in flight, so it must stay alive as long as another thread may
    /// still dereference it. Member order matters: callers use designated
    /// initializers.
    struct rdcss_descriptor
    {
        std::atomic<size_t>& atom1;     ///< Control word (only read).
        size_t expected1;               ///< Required value of `atom1`.
        atomic_ext& atom2;              ///< Data word (conditionally written).
        size_t expected2;               ///< Required value of `atom2`.
        size_t new2;                    ///< Value written on success.
        value_status_t status2;         ///< Status written on success.
    };

    namespace detail
    {
        /// Reinterpret a pointer as the integer stored in atomic_ext::value.
        template<class T>
        size_t toSizeT(const T* ptr)
        {
            return reinterpret_cast<size_t>(ptr);
        }

        /// Inverse of toSizeT() for rdcss_descriptor addresses.
        // `inline` is required: this non-template function is defined in a
        // header, so without it every translation unit that includes the
        // header emits its own definition and the link fails.
        inline rdcss_descriptor* toRdcssDescriptor(const size_t x)
        {
            return reinterpret_cast<rdcss_descriptor*>(x);
        }
    }

    /// Single-word compare-and-swap on the (value, status) pair of `atom`.
    ///
    /// `status` doubles as a per-word spin lock: a thread that CASes it to
    /// UNDECIDED owns the word until it publishes another status. Both the
    /// comparison and the write happen while the lock is held, so the pair
    /// returned is a consistent snapshot. This makes cas1 (and everything
    /// built on it) blocking rather than lock-free.
    ///
    /// @return the (value, status) pair observed before the operation; it
    ///         equals (expectedValue, expectedStatus) iff the swap happened.
    inline atomic_captured cas1(atomic_ext& atom,
        const size_t expectedValue,
        const value_status_t expectedStatus,
        const size_t newValue,
        const value_status_t newStatus)
    {
        while (true)
        {
            value_status_t oldStatus = atom.status.load();
            if (oldStatus == value_status_t::UNDECIDED)
                continue;  // another cas1 owns the word; spin

            // Take the word. Any failure (including a spurious one of the
            // weak CAS) just starts over.
            if (!atom.status.compare_exchange_weak(
                oldStatus, value_status_t::UNDECIDED))
                continue;

            // Read the value only while holding the lock: reading it before
            // locking could pair a status with a value from another update.
            const size_t oldValue = atom.value.load();
            const bool matches =
                oldStatus == expectedStatus && oldValue == expectedValue;

            // Nobody else may touch `value` while we own the word, so a plain
            // store suffices. A weak CAS here could fail spuriously and the
            // caller would be told the swap happened when it did not.
            if (matches)
                atom.value.store(newValue);

            // Release the word (we own it, so no CAS loop is needed).
            atom.status.store(matches ? newStatus : oldStatus);

            return atomic_captured{
                .value = oldValue, .status = oldStatus };
        }
    }

    /// True if `captured` holds the address of an rdcss_descriptor.
    // `inline`: defined in a header (avoids multiple-definition link errors).
    inline bool isRdcssDescriptor(const atomic_captured& captured)
    {
        return captured.status == value_status_t::RDCSS;
    }

    /// Finish an installed RDCSS. If the control word `atom1` still holds
    /// `expected1`, the descriptor in `atom2` is replaced by (new2, status2).
    /// Otherwise `atom2` is rolled back to (expected2, NORMAL), the exact pair
    /// rdcss() replaced when it installed the descriptor. Any thread may call
    /// this; once the descriptor has been replaced, further calls find it gone
    /// and cas1 leaves the word unchanged.
    inline void complete(rdcss_descriptor& descriptor)
    {
        const bool conditionHolds =
            descriptor.atom1.load() == descriptor.expected1;

        // On failure the status must be restored to NORMAL, not status2:
        // writing (expected2, DCAS) would turn a plain value into a bogus
        // descriptor pointer.
        cas1(descriptor.atom2,
            detail::toSizeT(&descriptor), value_status_t::RDCSS,
            conditionHolds ? descriptor.new2 : descriptor.expected2,
            conditionHolds ? descriptor.status2 : value_status_t::NORMAL);
    }

    /// Restricted double-compare single-swap. Sets `atom2` to
    /// (new2, status2) if `atom2` holds (expected2, NORMAL) and `atom1` holds
    /// `expected1` when the installed descriptor is completed.
    ///
    /// Other RDCSS descriptors found in `atom2` are completed (helped) first.
    /// `descriptor` must stay alive while another thread may still be
    /// helping it.
    ///
    /// @return the non-RDCSS pair observed in `atom2`. It equals
    ///         (expected2, NORMAL) iff this descriptor was installed; the
    ///         final outcome then depends on `atom1`.
    inline atomic_captured rdcss(rdcss_descriptor& descriptor)
    {
        while (true)
        {
            const atomic_captured captured = cas1(descriptor.atom2,
                descriptor.expected2, value_status_t::NORMAL,
                detail::toSizeT(&descriptor), value_status_t::RDCSS);

            if (!isRdcssDescriptor(captured))
            {
                // cas1 reports the expected pair iff our descriptor went in.
                if (captured.status == value_status_t::NORMAL &&
                    captured.value == descriptor.expected2)
                    complete(descriptor);

                return captured;
            }

            // Another RDCSS is in flight on this word: help it, then retry.
            if (rdcss_descriptor* other =
                    detail::toRdcssDescriptor(captured.value))
                complete(*other);
        }
    }
}

dcas.h

/*
 * Based on the paper:
 * “A Practical Multi-Word Compare-and-Swap Operation”
 * by Timothy L. Harris, Keir Fraser, and Ian A. Pratt (DISC 2002).
 */

#pragma once

#include <array>

#include "rdcss.h"

namespace lockfree_op
{
    /// Outcome of a DCAS, stored in dcas_descriptor::status. It is also read
    /// as a size_t through the RDCSS control word (see createProxy()).
    enum class dcas_status_t : size_t
    {
        UNDECIDED = 0,  ///< Descriptor installation still in progress.
        SUCCESSED = 1,  ///< Every word matched; new values get written.
        FAILED = 2      ///< Some word mismatched; old values get restored.
    };

    /// One word of a DCAS request: swap `atom` from `expectedValue` to
    /// `newValue`. applyRequirement() overwrites `expectedValue` with the
    /// value it observes in `atom` once the DCAS is resolved.
    struct cas_requirement
    {
        atomic_ext& atom;           ///< Target word.
        std::size_t expectedValue;  ///< Required current value (status NORMAL).
        std::size_t newValue;       ///< Value to store on success.
    };

    /// Shared state of one DCAS. Its address is published (tagged DCAS) in the
    /// target words while the operation is in progress. `cntRef` counts helper
    /// threads currently using it, so doDcas() can wait for them before the
    /// descriptor goes away. Member order matters: create() uses designated
    /// initializers.
    struct dcas_descriptor
    {
        std::atomic<dcas_status_t> status;              ///< Decided outcome.
        std::array<cas_requirement, 2>& requirements;   ///< Caller's request.
        std::array<std::size_t, 2> order;               ///< Indices by atom address.
        std::atomic_size_t cntRef;                      ///< Active helpers.

        /// Build an UNDECIDED descriptor for `requirements`.
        static dcas_descriptor create(
            std::array<cas_requirement, 2>& requirements);

        /// Run the DCAS to completion, helping conflicting descriptors.
        void doDcas();
    };

    namespace detail
    {
        /// Inverse of toSizeT() for dcas_descriptor addresses.
        // `inline`: defined in a header (avoids multiple-definition link errors).
        inline dcas_descriptor* toDcasDescriptor(const size_t x)
        {
            return reinterpret_cast<dcas_descriptor*>(x);
        }
    }

    /// True if `captured` holds the address of a dcas_descriptor.
    // `inline`: defined in a header (avoids multiple-definition link errors).
    inline bool isDcasDescriptor(const atomic_captured& captured)
    {
        return captured.status == value_status_t::DCAS;
    }

    /// Replace `descriptor` in `requirement.atom` with its final value:
    /// newValue if `successed`, otherwise expectedValue, both with status
    /// NORMAL. Then store the plain value now visible in the word into
    /// `requirement.expectedValue`.
    ///
    /// NOTE(review): complete() can run on several threads for the same
    /// descriptor (its owner and helpers from doDcas()). The plain write to
    /// `requirement.expectedValue` below is therefore a data race, and it also
    /// races with createProxy() reading that field in a lagging helper.
    /// Consider reporting the observed value through a separate result.
    // `inline`: defined in a header (avoids multiple-definition link errors).
    inline void applyRequirement(
        dcas_descriptor& descriptor,
        cas_requirement& requirement,
        bool successed)
    {
        const size_t assignValue =
            successed ? requirement.newValue : requirement.expectedValue;
        atomic_captured old = cas1(requirement.atom,
            detail::toSizeT(&descriptor), value_status_t::DCAS,
            assignValue, value_status_t::NORMAL);

        // Wait until the word holds a plain value. A cas1 from (0, NORMAL) to
        // (0, NORMAL) never changes the word, so it serves as a consistent
        // read. This waits for other descriptors' owners rather than helping.
        while (isRdcssDescriptor(old) || isDcasDescriptor(old))
            old = cas1(requirement.atom,
                0, value_status_t::NORMAL,
                0, value_status_t::NORMAL);

        requirement.expectedValue = old.value;
    }

    /// Resolve a decided DCAS by applying its outcome to every word, in the
    /// same address order used during installation.
    // `inline`: defined in a header (avoids multiple-definition link errors).
    inline void complete(dcas_descriptor& descriptor)
    {
        const bool successed =
            descriptor.status.load() == dcas_status_t::SUCCESSED;

        for (const size_t i : descriptor.order)
            applyRequirement(
                descriptor, descriptor.requirements[i], successed);
    }

    /// Build the RDCSS that installs `descriptor` into `requirement.atom`.
    /// It succeeds only while the DCAS status is still UNDECIDED and the word
    /// still holds (expectedValue, NORMAL).
    // `inline`: defined in a header (avoids multiple-definition link errors).
    inline rdcss_descriptor createProxy(
        dcas_descriptor& descriptor,
        cas_requirement& requirement)
    {
        // The RDCSS control word is the DCAS status reinterpreted as a
        // size_t atomic. That is not sanctioned by the standard, so at least
        // make sure the two atomic types have identical size and alignment.
        static_assert(
            sizeof(std::atomic<dcas_status_t>) == sizeof(std::atomic_size_t) &&
            alignof(std::atomic<dcas_status_t>) == alignof(std::atomic_size_t),
            "dcas_descriptor::status must match std::atomic_size_t layout");

        return rdcss_descriptor{
            .atom1 =
            reinterpret_cast<std::atomic_size_t&>(descriptor.status),
            .expected1 = static_cast<size_t>(dcas_status_t::UNDECIDED),
            .atom2 = requirement.atom,
            .expected2 = requirement.expectedValue,
            .new2 = detail::toSizeT(&descriptor),
            .status2 = value_status_t::DCAS
        };
    }

    /// True if rdcss(proxy) found the expected plain value, i.e. it installed
    /// the DCAS descriptor into the word.
    // `inline`: defined in a header (avoids multiple-definition link errors).
    inline bool wasInstalledSuccessfully(
        rdcss_descriptor& proxy,
        atomic_captured& captured)
    {
        return !isDcasDescriptor(captured) &&
            captured.value == proxy.expected2;
    }

    /// True if the word holds a DCAS descriptor other than the one `proxy`
    /// is trying to install.
    // `inline`: defined in a header (avoids multiple-definition link errors).
    inline bool hasPredecessorBeenFound(
        rdcss_descriptor& proxy,
        atomic_captured& captured)
    {
        return isDcasDescriptor(captured) &&
            captured.value != proxy.new2;
    }

    /// True if the word holds a plain value different from the expected one,
    /// which means the DCAS must fail.
    // `inline`: defined in a header (avoids multiple-definition link errors).
    inline bool isRdcssFailed(rdcss_descriptor& proxy, atomic_captured& captured)
    {
        return !isDcasDescriptor(captured) &&
            captured.value != proxy.expected2;
    }

    /// Try to install `*descriptor` into all of its words (in address order)
    /// and then decide its status:
    /// - install succeeded, or the word already holds this descriptor: go on;
    /// - the word holds another DCAS descriptor: take a cntRef reference on it,
    ///   point `descriptor` at it and return undecided, so the caller helps
    ///   that descriptor first;
    /// - the word holds a different plain value: the DCAS fails.
    /// The decision is published by a CAS from `oldStatus`, so only the first
    /// thread to decide wins.
    ///
    /// NOTE(review): lifetime issues remain.
    /// - The predecessor's address is read from the word before its cntRef is
    ///   incremented. If its owner returns in between, fetch_add touches a
    ///   dead stack object.
    /// - `proxy` is a local whose address is published by rdcss(), so other
    ///   threads may still complete() it after this function returns.
    /// Safe reclamation (hazard pointers, epochs, heap descriptors) is needed.
    // `inline`: defined in a header (avoids multiple-definition link errors).
    inline void progressDescriptor(
        dcas_descriptor*& descriptor,
        dcas_status_t& oldStatus)
    {
        dcas_status_t status = dcas_status_t::SUCCESSED;

        for (const size_t i : descriptor->order)
        {
            rdcss_descriptor proxy =
                createProxy(*descriptor, descriptor->requirements[i]);
            atomic_captured captured = rdcss(proxy);

            if (wasInstalledSuccessfully(proxy, captured))
                continue;

            if (hasPredecessorBeenFound(proxy, captured))
            {
                dcas_descriptor* oldDescriptor =
                    detail::toDcasDescriptor(captured.value);
                if (oldDescriptor != nullptr)
                {
                    oldDescriptor->cntRef.fetch_add(1);
                    descriptor = oldDescriptor;

                    return;
                }
            }

            if (isRdcssFailed(proxy, captured))
            {
                status = dcas_status_t::FAILED;

                break;
            }
        }

        // Strong CAS: a spurious failure would only force the caller to redo
        // the whole installation pass for nothing.
        descriptor->status.compare_exchange_strong(oldStatus, status);
    }

    /// Build an UNDECIDED descriptor for `requirements`. `order` lists the
    /// requirement indices by increasing atom address, so every DCAS visits
    /// shared words in the same global order.
    ///
    /// NOTE(review): if both requirements name the same atom, the second
    /// install sees this descriptor already present and treats it as
    /// installed, so one expected value is never checked. Pass distinct atoms.
    // `inline`: out-of-class definition in a header (avoids multiple-definition
    // link errors).
    inline dcas_descriptor dcas_descriptor::create(
        std::array<cas_requirement, 2>& requirements)
    {
        std::array<size_t, 2> order =
            (detail::toSizeT(&requirements[0].atom) <
                detail::toSizeT(&requirements[1].atom)) ?
            std::array<size_t, 2>{0, 1} : std::array<size_t, 2>{1, 0};

        return dcas_descriptor{
            .status = dcas_status_t::UNDECIDED,
            .requirements = requirements,
            .order = order,
            .cntRef = 0
        };
    }

    void dcas_descriptor::doDcas()
    {
        dcas_descriptor* descriptor = this;
        while (true)
        {
            dcas_status_t oldStatus = descriptor->status.load();
            if (oldStatus != dcas_status_t::UNDECIDED)
            {
                complete(*descriptor);

                if (descriptor == this)
                    break;

                descriptor->cntRef.fetch_sub(1);
                descriptor = this;

                continue;
            }

            progressDescriptor(descriptor, oldStatus);
        }

        while (cntRef.load() != 0);
    }

    /// Double compare-and-swap. If every requirements[i].atom holds
    /// (expectedValue, NORMAL), each one is set to its newValue. The
    /// descriptor lives on this stack frame; doDcas() waits for helpers
    /// before returning.
    ///
    /// @return true iff the DCAS succeeded. On return each `expectedValue`
    ///         has been overwritten with the value observed in its atom.
    // `inline`: defined in a header (avoids multiple-definition link errors).
    inline bool dcas(std::array<cas_requirement, 2>& requirements)
    {
        dcas_descriptor descriptor = dcas_descriptor::create(requirements);
        descriptor.doDcas();

        return descriptor.status.load() == dcas_status_t::SUCCESSED;
    }
}

원문에서는 tagged pointer를 이용해 타입을 구분하고 있는데,
제가 생각하기에는 어차피 tagged pointer의 주소값과 포인터가 가리키는 실제값을 동시에 atomic하게 바꿔야 할 것 같아서 value_status_t를 사용했습니다.

혹시 correctness 관련 문제나 리팩토링이 필요한 부분이 보인다면
적극적인 코드 리뷰 환영합니다:grinning:

3 Likes