
Coroutines in C++: A Minimal Library

23 Aug 2024


In the previous post we started studying coroutines in C++. One of the main takeaways is that the coroutine API offered by the standard library is pretty low-level and not very useful for end developers.

In this post I’d like to cover the implementation of a minimal library that supports co_await and co_return. It’s a dummy library in the sense that it doesn’t do anything with co_await and treats co_return as return, so it doesn’t add any value compared to non-coroutine code.

The goal, however, is to understand the machinery needed for even this simplest example to work.

Multiple co_awaits

At the end of [1] we mentioned that while we had an example using co_await, it didn’t handle multiple co_awaits in the same function. Suppose we have the following code:

Task h() {
    co_return "hello";
}

Task g() {
    std::string v1 = co_await h(); // (1)
    std::string v2 = co_await h();
    co_return v1 + v2;
}

Task f() {
    std::string r = co_await g();
    co_return r; // (2)
}

int main() {
    auto coro = f();
    coro();
}

The problem is that when we await the first h() inside g() at (1), it will return control back to f(), which will resume, call co_return at (2) and return to main(), and we’ll never call the second h().

On the other hand, if we are to ignore co_await and treat co_return as regular return, the expected sequence of events is:

- f runs
- f awaits g
- g runs
- g awaits h
- h runs
- h returns
- g runs (resumes)
- g awaits h
- h runs
- h returns
- g runs (resumes)
- g returns
- f runs (resumes)
- f returns
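As a baseline, the sketch below writes f(), g() and h() as plain C++ functions, i.e. with co_await dropped and co_return replaced by return. The global events vector is hypothetical instrumentation added only for illustration; it records the same fourteen steps listed above, in the same order.

```cpp
#include <string>
#include <vector>

// Hypothetical instrumentation: records the order of events.
std::vector<std::string> events;

// Plain-function version of h(): co_return becomes return.
std::string h() {
    events.push_back("h runs");
    events.push_back("h returns");
    return "hello";
}

// Plain-function version of g(): each co_await h() becomes a direct call.
std::string g() {
    events.push_back("g runs");
    events.push_back("g awaits h");
    std::string v1 = h();
    events.push_back("g runs (resumes)");
    events.push_back("g awaits h");
    std::string v2 = h();
    events.push_back("g runs (resumes)");
    events.push_back("g returns");
    return v1 + v2;
}

// Plain-function version of f().
std::string f() {
    events.push_back("f runs");
    events.push_back("f awaits g");
    std::string r = g();
    events.push_back("f runs (resumes)");
    events.push_back("f returns");
    return r;
}
```

Calling f() returns "hellohello". The coroutine-based library should end up with the same value and the same ordering of steps; the rest of the post is about the machinery that makes this happen.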

Translating it into the coroutine API calls we learned in [1], this would look like (ignoring some APIs like initial_suspend()):

- f.resume()
- g.await_suspend(f) // (1)
- g.resume()
- h.await_suspend(g)
- h.resume()
- h.final_suspend()
- g.resume()
- h.await_suspend(g)
- h.resume()
- h.final_suspend()
- g.resume()
- g.final_suspend()
- f.resume()
- f.final_suspend()

The first thing to observe, at (1), is that await_suspend is called on the right-hand side operand of co_await, i.e. the awaitable, not on the coroutine doing the awaiting. A handle to the caller is passed as a parameter. In other words, if we do co_await h() inside g(), the API called is h.await_suspend(g).

Further, await_suspend is the API that gives us the opportunity to connect the caller g and the “callee”, h. This is important because we need h to call g.resume() after it is finished, so when we call await_suspend, we must store a reference to g in h. In order to do so, we add an optional Task (defined in [1]) to the Promise class (also defined in [1]):

struct Promise {
    ...
    std::optional<Task> caller;
    ...
};

We also want await_suspend() to be called before we start executing the awaitable. Otherwise, if the awaitable only has a co_return (such as h()), it will finish executing and return before we have the opportunity to store who the caller was. So we change the Promise to always start suspended:

struct Promise {
    ...
    std::suspend_always initial_suspend() {
        return {};
    }
    ...
};

Then we modify our implementation of Task::await_suspend() to save a reference to the caller and only then start the execution of the awaitable via .resume():

void Task::await_suspend(Task handler) {
    _handle.promise().caller = handler;
    // start executing callee
    _handle.resume();
}

To summarize, let’s go over what’s happening in f():

Task f() {
    std::string r = co_await g();
    co_return r;
}

Recall that a co_await gets transpiled roughly to:

T co_await(P& promise, Awaitable& awaitable) {
    using handle_t = std::coroutine_handle<P>;

    if (!awaitable.await_ready()) {
        // (1) <suspend-coroutine>

        awaitable.await_suspend(
            handle_t::from_promise(promise)
        );
        // (2) <return-to-caller>

        // (3) <resume-point>
    }
    return awaitable.await_resume();
}

So when we run std::string r = co_await g();, g() first returns an instance of Task (an awaitable). We then suspend f() at (1) and call Task::await_suspend(), which, as we saw, stores a reference to f in g()’s promise and then starts executing g().

As we’ll show next, g() will run to completion before we hit (2). In fact we’ll completely bypass (2) because we’ll have g() resume f() “manually”, which will cause the execution to jump to (3).

Let’s now show how the callee can resume the caller once it is done. Recall that a coroutine is roughly transpiled to [1]:

{
    Promise promise; // lives in the coroutine frame
    co_await promise.initial_suspend();
    try {
        <body-statements>
    } catch (...) {
        promise.unhandled_exception();
    }
    co_await promise.final_suspend();
}

As we see above, final_suspend() returns an awaitable which will be co_await’ed. We’ll have that awaitable resume the caller. To that end, we can create a new awaitable type to support this custom behavior, which we’ll call FinalAwaiter:

// forward declaration
struct Promise;

struct FinalAwaiter {
    using Handle = std::coroutine_handle<Promise>;

    bool await_ready() noexcept { return false; }

    void await_suspend(Handle h) noexcept;

    void await_resume() noexcept {}
};

Which is what we return in final_suspend():

struct Promise {
    ...

    FinalAwaiter final_suspend() noexcept {
        return FinalAwaiter{};
    }

    ...
};

Note that we don’t need to pass any info about the Promise when creating a FinalAwaiter. The promise will be available via the Handle parameter to FinalAwaiter::await_suspend(). The implementation of await_suspend() then resumes the caller:

void FinalAwaiter::await_suspend(Handle h) noexcept {
    if (h.promise().caller) {
        h.promise().caller->resume();
    }
}

Let’s go over it via another example to see the whole picture. When a function g() does:

Task g() {
    std::string v1 = co_await h();
    std::string v2 = co_await h();
    co_return v1 + v2;
}

g() suspends on the first co_await, and h() runs to completion (i.e. without yielding control back to g()) and then resumes g() explicitly. g() proceeds to the second co_await and suspends again; the second h() runs to completion and resumes g() once again. Finally, g() itself hits co_return v1 + v2, which gets translated into:

...
promise.return_value(v1 + v2); // (1)
FinalAwaiter fa = promise.final_suspend(); // (2)
co_await fa; // (3)

In (1) we store the value inside the Promise. In (2), promise.final_suspend() returns a FinalAwaiter, which can reach the caller through the handle passed to its await_suspend(), so when we co_await it in (3) it calls await_suspend and resumes the caller, in this case f().

So when the caller gets resumed, it will call

return awaitable.await_resume();

Which we can implement as:

std::string Task::await_resume() {
    return promise().get_value();
}

to retrieve the value that the callee set via promise.return_value(). The code for this part is available on GitHub.

Needless to say, this flow is super confusing because we seem to be hijacking the coroutine API by resuming the caller from inside the callee. If we look at the API calls, indented to highlight the call stack, we get a chain of nested function calls:

- f.resume()
    // await g()
    - g.await_suspend(f)
        - g.resume()
            // await h()
            - h.await_suspend(g)
                - h.resume()
                    // co_return
                    - h.final_suspend()
                        - g.resume()
                            // await h()
                            - h.await_suspend(g)
                                - h.resume()
                                    // co_return
                                    - h.final_suspend()
                                        - g.resume()
                                            // co_return
                                            - g.final_suspend()
                                                - f.resume()
                                                    // co_return
                                                    - f.final_suspend()

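This has the additional downside that it can cause stack overflows: every resume() is a nested call, so the stack keeps growing with each co_await and co_return in the chain.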

Symmetric Transfer

There’s an alternative overload of await_suspend() which allows us to return a coroutine handle instead of void [3]. In this case co_await gets transpiled to something like:

T co_await(P& promise, Awaitable& awaitable) {
    if (!awaitable.await_ready()) {
        // (1) <suspend-coroutine>

        auto h = awaitable.await_suspend(
            handle_t::from_promise(promise)
        );
        h.resume();

        // (2) <return-to-caller>

        // (3) <resume-point>
    }
    return awaitable.await_resume();
}

With this overload, h.resume() is called outside of await_suspend(), which avoids some of the nested calls. We have to modify Task::await_suspend() to:

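std::coroutine_handle<> Task::await_suspend(
    std::coroutine_handle<> caller
) {
    // caller is now stored in the Promise as a std::coroutine_handle<>
    handle_.promise().caller = caller;
    // returning the callee's handle makes it the next coroutine to resume
    return handle_;
}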

and FinalAwaiter::await_suspend() to:

std::coroutine_handle<> FinalAwaiter::await_suspend(
    Handle h
) noexcept {
    return h.promise().caller;
}

With this API our example call stack looks like:

- f.resume()
    // await g()
    - g.await_suspend(f)
    - g.resume()
        // await h()
        - h.await_suspend(g)
        - h.resume()
            // co_return
            - h.final_suspend()
            - g.resume()
                // await h()
                - h.await_suspend(g)
                - h.resume()
                    // co_return
                    - h.final_suspend()
                    - g.resume()
                        // co_return
                        - g.final_suspend()
                        - f.resume()
                            // co_return
                            - f.final_suspend()

So it cut the maximum depth in half. The problem with this API is that upon a co_return, the snippet:

auto h = awaitable.await_suspend(
    handle_t::from_promise(promise)
);

calls FinalAwaiter::await_suspend(), which expects a “caller” to be set. In our current setup, the first coroutine called (in our case f()) doesn’t have one, so await_suspend() would hand back a null handle, and resuming a null handle is undefined behavior.

To handle this we can use a different final awaitable (instead of FinalAwaiter) for the root task. One option is to define an “adapter” coroutine that awaits a Task to extract its value, but whose return type is a different coroutine type, RootTask.

RootTask adapter(Task&& t) {
    std::string value = co_await t;
    co_return value;
}

Suppose RootTask is associated with the promise RootTaskPromise (which we’ll define later). Recall that co_return value is transpiled to something like:

auto promise = RootTaskPromise();
...
promise.return_value(value);
co_await promise.final_suspend();

Here final_suspend() is called on RootTaskPromise, not on the Promise used by Task, so it doesn’t have to return a FinalAwaiter. Let’s complete the implementation:

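// forward declaration: RootTaskPromise is defined below
struct RootTaskPromise;

struct RootTask {
    using coroutine_handle_t = std::coroutine_handle<RootTaskPromise>;
    using promise_type = RootTaskPromise;

    RootTask(coroutine_handle_t coroutine)
        : m_coroutine(coroutine) {}

    // RootTask owns the coroutine frame: movable but not copyable
    RootTask(RootTask&& other) noexcept
        : m_coroutine(std::exchange(other.m_coroutine, {})) {}
    RootTask(const RootTask&) = delete;

    // the frame stays alive after final_suspend(), so we destroy it here
    ~RootTask() {
        if (m_coroutine) {
            m_coroutine.destroy();
        }
    }

    void resume() {
        m_coroutine.resume();
    }

    // defined after RootTaskPromise, since it accesses its members
    std::string& result();

    coroutine_handle_t m_coroutine;
};

and a corresponding promise RootTaskPromise:

struct RootTaskPromise {

    std::suspend_never initial_suspend() {
        return {};
    }

    std::coroutine_handle<RootTaskPromise> get_return_object() {
        return {
            std::coroutine_handle<RootTaskPromise>::from_promise(*this)
        };
    }

    void unhandled_exception() noexcept {}

    void return_value(std::string result) {
        m_result = std::move(result);
    }

    std::suspend_always final_suspend() noexcept {
        return {};
    }

    std::string m_result;
};

std::string& RootTask::result() {
    return m_coroutine.promise().m_result;
}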

We can then define a utility function that calls the adapter function and then extracts the value:

std::string sync_wait(Task&& t) {
    RootTask rt = adapter(std::move(t));
    return rt.result();
}

The code for this part is available on GitHub.

Conclusion

In this post we aimed to build a minimal (and useless) coroutine library that doesn’t add value on top of regular functions. The goal, however, was to build something higher level, with semantics closer to those in other programming languages, since we’ve seen that coroutines in C++ are pretty low-level.

This exercise was heavily based on Lewis Baker’s coroutine library. Reading library code is a frustrating but rewarding exercise. Even with reference code that I was able to run, I had a hard time grokking the logic, and I’m sure I’m still missing the overall picture.

Coroutine code in C++ is some of the most difficult code I’ve seen in a while. Even after spending hours playing with it and writing this post, I don’t feel like I have a satisfactory understanding.

References