Parallel Map in Kotlin

written in collections, coroutines, kotlin, parallel

Ever wondered how to run map in parallel using coroutines? This is how you do it.

```kotlin
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope

suspend fun <A, B> Iterable<A>.pmap(f: suspend (A) -> B): List<B> = coroutineScope {
    map { async { f(it) } }.awaitAll()
}
```

Confused? Let’s unpack it.

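First we have the function signature, which is very similar to the signature of the standard map extension function on Iterable. The only things we added are the suspend keywords: one on our function and another on the parameter f. pmap has to be suspend because coroutineScope is a suspend function. f is suspend so that the lambda we pass can call other suspend functions, such as delay.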

Then we have coroutineScope, which creates the scope in which the async calls are executed. It only returns once every coroutine launched inside it has completed. If one of them throws an exception, the others are cancelled and the exception is rethrown to the caller. That's the main benefit of structured concurrency.
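Here is a minimal sketch of that cancellation behaviour. It assumes the default single-threaded runBlocking dispatcher, which makes the start order predictable. The helper cancelledSiblings and its inputs are hypothetical, made up for illustration. Element 3 fails right away while elements 1 and 2 are still suspended in delay, so coroutineScope cancels them before rethrowing the failure.

```kotlin
import java.util.Collections
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking

suspend fun <A, B> Iterable<A>.pmap(f: suspend (A) -> B): List<B> = coroutineScope {
    map { async { f(it) } }.awaitAll()
}

// Hypothetical helper: returns the elements whose coroutines were cancelled
// because a sibling failed.
suspend fun cancelledSiblings(): List<Int> {
    val cancelled: MutableList<Int> = Collections.synchronizedList(mutableListOf())
    try {
        (1..3).pmap { n ->
            try {
                if (n == 3) throw IllegalStateException("element $n failed")
                delay(10_000) // long-running work that never gets to finish
                n
            } catch (e: CancellationException) {
                cancelled.add(n) // record the cancellation, then let it propagate
                throw e
            }
        }
    } catch (e: IllegalStateException) {
        // coroutineScope rethrows the original failure once the siblings are cancelled
        println("pmap failed: ${e.message}")
    }
    return cancelled.sorted()
}

fun main() = runBlocking {
    println(cancelledSiblings()) // [1, 2]
}
```

The whole call returns almost immediately instead of after 10 seconds, because the two slow siblings are cancelled rather than awaited.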

Finally we have the actual execution, which is divided into two steps. The first step launches a new coroutine for each function application using async. Since async returns a Deferred<B>, this effectively wraps each element's result in a Deferred, turning the list into a List<Deferred<B>>.
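To make the two steps visible, here is the same logic with the intermediate types written out. pmapExplicit is a hypothetical name used only for this sketch; it behaves like pmap.

```kotlin
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.runBlocking

// Same logic as pmap, with the intermediate types spelled out
suspend fun <A, B> Iterable<A>.pmapExplicit(f: suspend (A) -> B): List<B> = coroutineScope {
    // Step 1: one coroutine per element; each result is wrapped in a Deferred<B>
    val deferreds: List<Deferred<B>> = map { element -> async { f(element) } }
    // Step 2: suspend until every Deferred completes, unwrapping List<Deferred<B>> into List<B>
    val results: List<B> = deferreds.awaitAll()
    results
}

fun main() = runBlocking {
    println(listOf(1, 2, 3).pmapExplicit { it.toString() }) // [1, 2, 3]
}
```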

In the second step we wait for all function invocations to complete and unwrap the results using awaitAll(). This is similar to doing .map { it.await() }, but better. awaitAll() fails as soon as any of the invocations fails, whereas .map { it.await() } only notices the failure when it sequentially reaches the await() call on the failing Deferred. For example, say we call pmap with 3 elements: f(0) completes in 2 minutes, f(1) completes in 5 minutes and f(2) fails immediately. With .map { it.await() } we'd have to wait for f(0) and f(1) to complete before seeing the failure of f(2), whereas awaitAll() knows that something failed right away.
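Here is a scaled-down version of that scenario: 1 second, 2 seconds and an immediate failure. A failing child of coroutineScope also cancels its siblings, so this sketch uses supervisorScope instead to isolate the difference between the two styles. That swap, the WorkFailed class and the millisUntilFailureSeen helper are all assumptions made for illustration.

```kotlin
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.supervisorScope

class WorkFailed : Exception("work failed")

// Hypothetical helper: starts three jobs (1 s, 2 s, immediate failure) and returns
// how many milliseconds pass before the caller sees the failure.
suspend fun millisUntilFailureSeen(useAwaitAll: Boolean): Long = supervisorScope {
    val start = System.nanoTime()
    val deferreds = listOf(1_000L, 2_000L, 0L).map { ms ->
        async {
            delay(ms)
            if (ms == 0L) throw WorkFailed()
            ms
        }
    }
    val elapsed = try {
        if (useAwaitAll) deferreds.awaitAll() else deferreds.map { it.await() }
        -1L // not reached: the third job always fails
    } catch (e: WorkFailed) {
        (System.nanoTime() - start) / 1_000_000
    }
    deferreds.forEach { it.cancel() } // stop leftover work so the scope can return
    elapsed
}

fun main() = runBlocking {
    println("awaitAll:           ${millisUntilFailureSeen(useAwaitAll = true)} ms")
    println("map { it.await() }: ${millisUntilFailureSeen(useAwaitAll = false)} ms")
}
```

The first line should report a few milliseconds. The second should report roughly 2,000, because it only reaches the failing Deferred after awaiting the two slow ones.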

How to use it

Easy! Just like you use map:

```kotlin
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.runBlocking

suspend fun <A, B> Iterable<A>.pmap(f: suspend (A) -> B): List<B> = coroutineScope {
    map { async { f(it) } }.awaitAll()
}

fun main(args: Array<String>) = runBlocking(Dispatchers.Default) {
    println((1..100).pmap { it * 2 })
}
```
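Like map, pmap keeps the input order even when the coroutines finish in a different order, because awaitAll returns the results in the same order as the list of Deferred values. Here is a small sketch to check this. The delays are arbitrary, chosen so that the last element finishes first, and resultAndFinishOrder is a hypothetical helper.

```kotlin
import java.util.Collections
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking

suspend fun <A, B> Iterable<A>.pmap(f: suspend (A) -> B): List<B> = coroutineScope {
    map { async { f(it) } }.awaitAll()
}

// Hypothetical helper: returns the pmap result and the order in which elements finished
suspend fun resultAndFinishOrder(): Pair<List<Int>, List<Int>> {
    val finishOrder: MutableList<Int> = Collections.synchronizedList(mutableListOf())
    val result = listOf(3, 2, 1).pmap { n ->
        delay(n * 100L) // element 1 sleeps the least, so it finishes first
        finishOrder.add(n)
        n * 10
    }
    return result to finishOrder.toList()
}

fun main() = runBlocking {
    val (result, finishOrder) = resultAndFinishOrder()
    println("result:       $result")      // [30, 20, 10], same order as the input
    println("finish order: $finishOrder") // [1, 2, 3]
}
```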

(Psst! I’m using Kotlin Playground so you can actually run this code!)

Prove that it’s running in parallel

OK, so let's resort to the good old delay to prove that this actually runs in parallel. We'll add a 1-second delay before each multiplication and measure how long the whole run takes.

Running over 100 elements, the result should be close to 1,000 milliseconds if it's running in parallel, and close to 100,000 milliseconds (100 × 1,000 ms) if it's running sequentially.

```kotlin
import kotlin.system.measureTimeMillis
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking

suspend fun <A, B> Iterable<A>.pmap(f: suspend (A) -> B): List<B> = coroutineScope {
    map { async { f(it) } }.awaitAll()
}

fun main(args: Array<String>) = runBlocking(Dispatchers.Default) {
    val time = measureTimeMillis {
        val output = (1..100).pmap {
            delay(1000)
            it * 2
        }
        println(output)
    }
    println("Total time: $time")
}
```
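For comparison, here is a scaled-down sketch that times the ordinary sequential map against pmap: 10 elements at 100 ms each, numbers chosen only to keep it quick. The timings helper is hypothetical. Since map is an inline function, the delay call compiles inside it, but each element still waits for the previous one.

```kotlin
import kotlin.system.measureTimeMillis
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking

suspend fun <A, B> Iterable<A>.pmap(f: suspend (A) -> B): List<B> = coroutineScope {
    map { async { f(it) } }.awaitAll()
}

// Hypothetical helper: returns (sequential ms, parallel ms) for 10 elements of 100 ms each
suspend fun timings(): Pair<Long, Long> {
    val sequential = measureTimeMillis {
        // Plain map: the delays run one after another
        (1..10).map { delay(100); it * 2 }
    }
    val parallel = measureTimeMillis {
        // pmap: all ten delays overlap
        (1..10).pmap { delay(100); it * 2 }
    }
    return sequential to parallel
}

fun main() = runBlocking {
    val (sequential, parallel) = timings()
    println("map:  $sequential ms") // roughly 10 x 100 ms
    println("pmap: $parallel ms")   // roughly 100 ms
}
```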

Beware of runBlocking

A previous iteration of this article proposed the following solution:

```kotlin
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.runBlocking

// DON'T DO THIS
fun <A, B> Iterable<A>.pmapOld(f: suspend (A) -> B): List<B> = runBlocking {
    map { async { f(it) } }.awaitAll()
}
```

As Gildor pointed out in the comments, this is a very bad idea. runBlocking blocks the thread that calls it until its body completes. By default it also runs its coroutines on an event loop confined to that same thread. This means pmapOld forcefully blocks the caller's thread until its execution finishes, instead of letting the caller decide how the execution should go.

Note that the same would happen if we simply wrapped our pmap call with runBlocking:

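```kotlin
import java.math.BigInteger
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.runBlocking

suspend fun <A, B> Iterable<A>.pmap(f: suspend (A) -> B): List<B> = coroutineScope {
    map { async { f(it) } }.awaitAll()
}

// Hypothetical CPU-bound helper (iterative Fibonacci) used as the workload
fun fibonacci(n: BigInteger): BigInteger {
    var a = BigInteger.ZERO
    var b = BigInteger.ONE
    repeat(n.toInt()) {
        val next = a + b
        a = b
        b = next
    }
    return a
}

// DON'T DO THIS
fun main() = runBlocking<Unit> {
    (1..100).pmap { fibonacci(it.toBigInteger()) }
}
```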

This is because coroutineScope inherits the caller's context. So again we'd be running on the single-threaded event loop that runBlocking uses by default, which may or may not be OK depending on your use case. Remember that you can always change the dispatcher used by runBlocking by passing one, e.g. runBlocking(Dispatchers.IO).
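Here is a minimal sketch of that inheritance. The hypothetical threadsUsed helper records which Thread ran each invocation of the lambda. Under plain runBlocking every call lands on the calling thread. Passing Dispatchers.Default to runBlocking moves the same pmap onto the default thread pool without touching its definition.

```kotlin
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.runBlocking

suspend fun <A, B> Iterable<A>.pmap(f: suspend (A) -> B): List<B> = coroutineScope {
    map { async { f(it) } }.awaitAll()
}

// Hypothetical helper: the set of threads that executed the pmap lambda
suspend fun threadsUsed(): Set<Thread> = (1..20).pmap { Thread.currentThread() }.toSet()

fun main() {
    val caller = Thread.currentThread()

    // Inherits runBlocking's default, thread-confined event loop
    val confined = runBlocking { threadsUsed() }
    println(confined == setOf(caller)) // true: everything ran on the calling thread

    // The caller picks the dispatcher; pmap itself is unchanged
    val pooled = runBlocking(Dispatchers.Default) { threadsUsed() }
    println(caller in pooled) // false: the work ran on Dispatchers.Default worker threads
}
```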
