Unit 40: Fork and Join
Learning Objectives
Students should
- understand the task deque and work stealing
- understand the behaviour of fork and join (and compute)
- be able to order fork and join efficiently
- be able to use RecursiveTask
Thread Pool
We now look under the hood of parallel Stream and CompletableFuture<T> to explore how Java manages its threads. Recall that creating and destroying threads is not cheap, and as much as possible we should reuse existing threads to perform different tasks. This goal can be achieved by using a thread pool.
A thread pool consists of (i) a collection of threads, each waiting for a task to execute, and (ii) a collection of tasks to be executed. Typically the tasks are put in a shared queue, and an idle thread picks up a task from the shared queue to execute.
To illustrate this concept, here is a trivial thread pool with a single thread:
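A minimal sketch of such a pool in Java, assuming we may use java.util.concurrent.LinkedBlockingQueue as the thread-safe queue (its take() method blocks while the queue is empty, so the worker does not spin). The class name SingleThreadPool, the STOP sentinel used to shut the worker down, and the demo are our own illustrative choices.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

// Hypothetical single-threaded pool: one worker thread plus a shared task queue.
class SingleThreadPool {
  // Sentinel task that tells the worker to stop (illustrative design choice).
  private static final Runnable STOP = () -> { };

  // Shared, thread-safe queue of tasks waiting to be executed.
  private final BlockingQueue<Runnable> queue = new LinkedBlockingQueue<>();

  // The only thread in the pool: it repeatedly takes a task and runs it.
  private final Thread worker = new Thread(() -> {
    try {
      while (true) {
        Runnable task = queue.take(); // waits while the queue is empty
        if (task == STOP) {
          return;
        }
        task.run();
      }
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    }
  });

  SingleThreadPool() {
    worker.start();
  }

  // Adds a task to the shared queue; the worker picks it up later.
  void submit(Runnable task) {
    queue.add(task);
  }

  // Lets the worker finish all queued tasks, then waits for it to exit.
  void shutdown() {
    queue.add(STOP);
    try {
      worker.join();
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    }
  }
}

public class ThreadPoolDemo {
  public static void main(String[] args) {
    List<Integer> done = Collections.synchronizedList(new ArrayList<>());
    SingleThreadPool pool = new SingleThreadPool();
    for (int i = 0; i < 5; i++) {
      int count = i; // effectively final copy for the lambda
      pool.submit(() -> done.add(count));
    }
    pool.shutdown();
    System.out.println(done); // [0, 1, 2, 3, 4]
  }
}
```

Because there is only one worker and the queue is first-in-first-out, the tasks run in the order they were submitted. The same thread is reused for every task instead of creating one thread per task.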
In the sample code above, we assume that Queue<T> can be safely modified concurrently (i.e., it is thread-safe). Otherwise, just as in the example with List that you have seen in parallel streams, items might be lost.
Fork and Join
Java implements a thread pool called ForkJoinPool that is fine-tuned for the fork-join model of recursive parallel execution.
The fork-join model is essentially a parallel divide-and-conquer model of computation. The general idea is to solve a problem by breaking it up into smaller instances of the same problem (fork), solving these smaller instances recursively, and then combining their results (join). The recursion continues until the problem is small enough: at that point we have reached the base case, and we solve the problem sequentially without further parallelization.
In Java, we can create a task that we can fork and join as an instance of the abstract class RecursiveTask<T>. RecursiveTask<T> supports the method fork(), which submits the task (typically a smaller subtask) for execution in a thread pool, and the method join(), which waits for that task to complete and returns its result. RecursiveTask<T> has an abstract method compute(), which we, as the client, have to define to specify what the task computes and returns.
Here is a simple RecursiveTask<T> that recursively sums up the content of an array:
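A minimal sketch of such a class, assuming a constructor that takes the bounds low (inclusive) and high (exclusive) of the range to sum plus the array itself, and a hypothetical cut-off FORK_THRESHOLD below which the range is summed sequentially. The body of compute() follows the order of calls walked through next: fork left, compute right, then join left.

```java
import java.util.concurrent.RecursiveTask;

// Sums array[low..high) by splitting the range in half until it is small.
class Summer extends RecursiveTask<Integer> {
  // Hypothetical cut-off: ranges shorter than this are summed sequentially.
  private static final int FORK_THRESHOLD = 16;
  private final int low;
  private final int high;
  private final int[] array;

  Summer(int low, int high, int[] array) {
    this.low = low;
    this.high = high;
    this.array = array;
  }

  @Override
  protected Integer compute() {
    // Base case: the range is small, so sum it sequentially.
    if (high - low < FORK_THRESHOLD) {
      int sum = 0;
      for (int i = low; i < high; i++) {
        sum += array[i];
      }
      return sum;
    }
    // Recursive case: split the range into two halves.
    int middle = (low + high) / 2;
    Summer left = new Summer(low, middle, array);
    Summer right = new Summer(middle, high, array);
    left.fork();                    // submit left; some thread may run it
    int rightSum = right.compute(); // run right ourselves, as a normal call
    int leftSum = left.join();      // wait for left and read its result
    return leftSum + rightSum;
  }
}

public class SummerDemo {
  public static void main(String[] args) {
    int[] array = new int[1000];
    for (int i = 0; i < array.length; i++) {
      array[i] = i + 1;
    }
    Summer task = new Summer(0, array.length, array);
    System.out.println(task.compute()); // 500500
  }
}
```

Calling task.compute() from main works because fork() on a thread outside any ForkJoinPool submits the task to ForkJoinPool.commonPool(). An alternative is ForkJoinPool.commonPool().invoke(task), which starts the whole computation inside the pool.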
To run this task, we create a Summer instance, task, that covers the whole array, and call task.compute().
The call task.compute() above is just like any other method invocation. It causes the method compute() to be invoked, and if the array is big enough, two new Summer instances, left and right, are created. We then call left.fork(), which adds the task left to a thread pool so that one of the threads can call its compute() method. We subsequently call right.compute() (which is a normal method call). Finally, we call left.join(), which blocks until the recursive sum computed by left is completed and returned. We add the results from left and right together and return the sum.
There are other ways we can combine and order the execution of fork(), compute(), and join(). Some are better than others. We will explore more in the exercises.
ForkJoinPool
Let's now explore the idea behind how Java manages the thread pool with fork-join tasks. The details are beyond the scope of this module, but a few key points are worth noting:
- Each thread has a deque¹ of tasks.
- When a thread is idle, it checks its deque of tasks. If the deque is not empty, it picks up a task at the head of the deque to execute (e.g., invokes its compute() method). Otherwise, if the deque is empty, it picks up a task from the tail of the deque of another thread to run. The latter mechanism is called work stealing.
- When fork() is called on a task, that task is added to the head of the deque of the executing thread. This is done so that the most recently forked task gets executed next, similar to how normal recursive calls are executed.
- When join() is called, several cases might happen. If the subtask to be joined hasn't been executed, the current thread calls its compute() method and executes the subtask. If the subtask to be joined has been completed (some other thread has stolen it and completed it), then the result is read, and join() returns. If the subtask to be joined has been stolen and is being executed by another thread, then the current thread either finds some other task to work on from its local deque, or steals another task from another deque.
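The head/tail discipline of the deque can be modelled, very loosely, with a single-threaded toy. This is only a sketch under our own simplifications: the class WorkerDeque and the task names are hypothetical, tasks are plain strings, and nothing here is thread-safe, whereas the real ForkJoinPool uses its own concurrent data structures.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Toy, single-threaded model of one worker's task deque (not ForkJoinPool's code).
class WorkerDeque {
  private final Deque<String> tasks = new ArrayDeque<>();

  // fork(): the forked task is added to the head of the executing thread's deque.
  void push(String task) {
    tasks.addFirst(task);
  }

  // The owner, when idle, takes its next task from the head (most recent first).
  String takeOwn() {
    return tasks.pollFirst();
  }

  // Another thread with an empty deque steals from the tail (oldest first).
  String steal() {
    return tasks.pollLast();
  }
}

public class DequeDemo {
  public static void main(String[] args) {
    WorkerDeque deque = new WorkerDeque();
    deque.push("A"); // forked first
    deque.push("B");
    deque.push("C"); // forked last
    System.out.println(deque.takeOwn()); // C: the owner runs the most recent fork
    System.out.println(deque.steal());   // A: a thief takes the oldest fork
    System.out.println(deque.takeOwn()); // B
    System.out.println(deque.takeOwn()); // null: empty, so this thread must steal
  }
}
```

The owner behaves like a stack user (last forked, first run), while thieves treat the same deque like a queue, which is why a double-ended queue fits.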
The beauty of the mechanism here is that the threads always look for something to do and they cooperate to get as much work done as possible.
The mechanism here is similar to the work-stealing schedulers used in .NET (the Task Parallel Library) and in Rust (the Rayon library).
Order of fork() and join()
One implication of how ForkJoinPool adds and removes tasks from the deque concerns the order in which we call fork() and join(). Since the most recently forked task is likely to be executed next, we should join() the most recently forked task first. In other words, the order of forking should be the reverse of the order of joining.
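In the class Summer above, suppose we fork both left and right instead of calling right.compute(). Then joining right before left is more efficient than joining left before right, since right, the most recently forked task, is likely still at the head of the deque.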
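A sketch of the two orderings, using a hypothetical variant of Summer, ForkBothSummer, that forks both halves. The flag mostRecentFirst selects the join order. Both orders return the same sum; by the reasoning above, they differ only in how well the join order matches the deque. Java evaluates the operands of + from left to right, so right.join() + left.join() joins right first.

```java
import java.util.concurrent.RecursiveTask;

// Hypothetical Summer variant that forks BOTH halves, to compare join orders.
class ForkBothSummer extends RecursiveTask<Integer> {
  private static final int FORK_THRESHOLD = 16; // hypothetical cut-off
  private final int low;
  private final int high;
  private final int[] array;
  private final boolean mostRecentFirst; // join in reverse order of forking?

  ForkBothSummer(int low, int high, int[] array, boolean mostRecentFirst) {
    this.low = low;
    this.high = high;
    this.array = array;
    this.mostRecentFirst = mostRecentFirst;
  }

  @Override
  protected Integer compute() {
    if (high - low < FORK_THRESHOLD) {
      int sum = 0;
      for (int i = low; i < high; i++) {
        sum += array[i];
      }
      return sum;
    }
    int middle = (low + high) / 2;
    ForkBothSummer left = new ForkBothSummer(low, middle, array, mostRecentFirst);
    ForkBothSummer right = new ForkBothSummer(middle, high, array, mostRecentFirst);
    left.fork();
    right.fork(); // right is now the most recently forked task
    if (mostRecentFirst) {
      // Preferred: fork left, fork right, join right, join left (no crossing).
      return right.join() + left.join();
    }
    // Less efficient: join left while right is nearer the head of the deque.
    return left.join() + right.join();
  }
}

public class JoinOrderDemo {
  public static void main(String[] args) {
    int[] array = new int[1000];
    for (int i = 0; i < array.length; i++) {
      array[i] = i + 1;
    }
    System.out.println(new ForkBothSummer(0, array.length, array, true).compute());  // 500500
    System.out.println(new ForkBothSummer(0, array.length, array, false).compute()); // 500500
  }
}
```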
In other words, the sequence of tasks on which you call fork(), compute(), and join() should read the same forwards and backwards: it should form a palindrome, with no crossing. Additionally, there should be at most one compute(), and it should be in the middle of the palindrome.
For example, the following is ok.
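One sequence that follows these rules, sketched with hypothetical leaf tasks f1, f2, and f3 of a made-up class Value that simply returns a number: fork f1, fork f2, compute f3, join f2, join f1. The tasks read f1 f2 f3 f2 f1, a palindrome with the single compute() in the middle.

```java
import java.util.concurrent.RecursiveTask;

// Hypothetical leaf task that just returns its value (stands in for real work).
class Value extends RecursiveTask<Integer> {
  private final int v;

  Value(int v) {
    this.v = v;
  }

  @Override
  protected Integer compute() {
    return v;
  }
}

public class GoodOrder {
  // fork f1, fork f2, compute f3, join f2, join f1: tasks read f1 f2 f3 f2 f1.
  static int run() {
    Value f1 = new Value(1);
    Value f2 = new Value(2);
    Value f3 = new Value(3);
    f1.fork();
    f2.fork();
    int r3 = f3.compute(); // the only compute(), in the middle
    int r2 = f2.join();    // most recent fork joined first
    int r1 = f1.join();
    return r1 + r2 + r3;
  }

  public static void main(String[] args) {
    System.out.println(run()); // 6
  }
}
```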
But the following is not.
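One sequence that breaks the rules, sketched with hypothetical leaf tasks of a made-up class Value: fork f1, fork f2, compute f3, join f1, join f2. The tasks read f1 f2 f3 f1 f2, which is not a palindrome, because the f1 and f2 fork-join pairs cross. The result is still correct; the concern is efficiency, since f1 is joined while the more recently forked f2 is nearer the head of the deque.

```java
import java.util.concurrent.RecursiveTask;

// Hypothetical leaf task that just returns its value (stands in for real work).
class Value extends RecursiveTask<Integer> {
  private final int v;

  Value(int v) {
    this.v = v;
  }

  @Override
  protected Integer compute() {
    return v;
  }
}

public class BadOrder {
  // fork f1, fork f2, compute f3, join f1, join f2: tasks read f1 f2 f3 f1 f2.
  static int run() {
    Value f1 = new Value(1);
    Value f2 = new Value(2);
    Value f3 = new Value(3);
    f1.fork();
    f2.fork();
    int r3 = f3.compute();
    int r1 = f1.join(); // crossing: f1 joined before the more recent f2
    int r2 = f2.join();
    return r1 + r2 + r3;
  }

  public static void main(String[] args) {
    System.out.println(run()); // 6
  }
}
```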
¹ A deque is a double-ended queue. It behaves like both a stack and a queue.