Unit 40: Fork and Join
Learning Objectives
Students should
- understand the task deque and work stealing.
- understand the behaviour of fork and join (and compute).
- be able to order fork and join efficiently.
- be able to use RecursiveTask.
Thread Pool
We now look under the hood of parallel Stream and CompletableFuture<T> to explore how Java manages its threads. Recall that creating and destroying threads is not cheap, and as much as possible we should reuse existing threads to perform different tasks. This goal can be achieved by using a thread pool.
A thread pool consists of (i) a collection of threads, each waiting for a task to execute, and (ii) a collection of tasks to be executed. Typically the tasks are put in a shared queue, and an idle thread picks up a task from the shared queue to execute.
To illustrate this concept, here is a trivial thread pool with a single thread:
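A minimal sketch of such a pool in Java follows. It assumes java.util.concurrent.LinkedBlockingQueue as the thread-safe queue of tasks; the class names SingleThreadPool and ThreadPoolDemo are hypothetical. A single worker thread repeatedly takes the next Runnable from the shared queue and runs it, so the same thread is reused for every task.

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;

// Hypothetical pool with exactly one worker thread and a shared, thread-safe task queue.
class SingleThreadPool {
  private final BlockingQueue<Runnable> queue = new LinkedBlockingQueue<>();

  SingleThreadPool() {
    Thread worker = new Thread(() -> {
      while (true) {
        try {
          Runnable task = queue.take(); // waits until a task is available
          task.run();                   // the same thread is reused for every task
        } catch (InterruptedException e) {
          return;                       // stop the worker if it is interrupted
        }
      }
    });
    worker.setDaemon(true); // simplification: lets the JVM exit once main() finishes
    worker.start();
  }

  // Adds a task to the shared queue; the worker picks it up when it is idle.
  void submit(Runnable task) {
    queue.add(task);
  }
}

class ThreadPoolDemo {
  public static void main(String[] args) throws InterruptedException {
    SingleThreadPool pool = new SingleThreadPool();
    CountDownLatch done = new CountDownLatch(5);
    for (int i = 0; i < 5; i++) {
      int count = i; // effectively final copy for the lambda
      pool.submit(() -> {
        System.out.println(count);
        done.countDown();
      });
    }
    done.await(); // wait for all five tasks before main() returns
  }
}
```

Because there is only one worker and the queue is first-in-first-out, the demo prints 0 to 4 in order. A real pool would also provide an explicit way to shut the worker down instead of relying on a daemon thread.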
We assume that Queue<T> can be safely modified concurrently (i.e., it is thread-safe) in the sample code above. Otherwise, just like the example you have seen in parallel streams with List, items might be lost.
Fork and Join
Java implements a thread pool called ForkJoinPool that is fine-tuned for the fork-join model of recursive parallel execution.
The fork-join model is essentially a parallel divide-and-conquer model of computation. The general idea is to solve a problem by breaking it up into smaller instances of the same problem (fork), solving the smaller instances recursively, and then combining the results (join). This repeats recursively until the problem size is small enough: we have reached the base case, so we just solve the problem sequentially without further parallelization.
In Java, we can create a task that we can fork and join as an instance of the abstract class RecursiveTask<T>. RecursiveTask<T> supports the method fork(), which submits a smaller version of the task for execution, and the method join(), which waits for that smaller task to complete and returns its result. RecursiveTask<T> has an abstract method compute(), which we, as the client, have to define to specify the computation the task performs.
Here is a simple RecursiveTask<T>, called Summer, that recursively sums up the content of an array, together with the code that creates a Summer task and runs it by calling task.compute():
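Below is a sketch of what Summer and the code that runs it could look like. The constructor Summer(low, high, array), the half-open range [low, high), and the FORK_THRESHOLD cut-off are assumptions made for illustration; the order of the fork(), compute(), and join() calls follows the walkthrough in the next paragraph.

```java
import java.util.concurrent.RecursiveTask;

// Sums array[low], ..., array[high - 1] using the fork-join model.
class Summer extends RecursiveTask<Integer> {
  // Assumed cut-off: ranges shorter than this are summed sequentially.
  private static final int FORK_THRESHOLD = 2;
  private final int low;
  private final int high;
  private final int[] array;

  Summer(int low, int high, int[] array) {
    this.low = low;
    this.high = high;
    this.array = array;
  }

  @Override
  protected Integer compute() {
    // Base case: the range is small, so sum it without further parallelism.
    if (high - low < FORK_THRESHOLD) {
      int sum = 0;
      for (int i = low; i < high; i++) {
        sum += array[i];
      }
      return sum;
    }
    // Recursive case: split the range into two halves.
    int middle = (low + high) / 2;
    Summer left = new Summer(low, middle, array);
    Summer right = new Summer(middle, high, array);
    left.fork();                          // submit left for (possibly parallel) execution
    return right.compute() + left.join(); // compute right here, then wait for left
  }
}

class SummerDemo {
  public static void main(String[] args) {
    int[] array = new int[1000];
    for (int i = 0; i < array.length; i++) {
      array[i] = i;
    }
    // Create the task and run it by calling compute() directly.
    Summer task = new Summer(0, array.length, array);
    int sum = task.compute();
    System.out.println(sum); // 0 + 1 + ... + 999 = 499500
  }
}
```

Calling task.compute() runs the top-level task in the current thread; alternatively, ForkJoinPool.commonPool().invoke(task) runs it inside the common pool and returns the result. A threshold of 2 splits the array all the way down to single elements, which keeps the example short but creates far more tasks than a real program would want.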
The line task.compute() above is just like any other method invocation. It causes the method compute() to be invoked and, if the array is big enough, two new Summer instances, left and right, to be created. We then call left.fork(), which adds the task left to a thread pool so that one of the threads can call its compute() method. We subsequently call right.compute() (which is a normal method call). Finally, we call left.join(), which blocks until the computation of the recursive sum is completed and returned. We add the results from left and right together and return the sum.
There are other ways we can combine and order the execution of fork(), compute(), and join(). Some are better than others. We will explore more in the exercises.
ForkJoinPool
Let's now explore the idea behind how Java manages the thread pool with fork-join tasks. The details are beyond the scope of this module, but it is interesting to note a few key points:
- Each thread has a deque[1] of tasks.
- When a thread is idle, it checks its deque of tasks. If the deque is not empty, it picks up a task at the head of the deque to execute (e.g., invokes its compute() method). Otherwise, if the deque is empty, it picks up a task from the tail of the deque of another thread to run. The latter mechanism is called work stealing.
- When fork() is called on a task, that task adds itself to the head of the deque of the executing thread. The method returns immediately and execution continues. This may cause another fork() to be executed, which adds another task to the head of the deque. This is done so that the most recently forked task gets executed next, similar to how normal recursive calls are executed.
- When join() is called, several cases might happen. If the subtask to be joined hasn't been executed, its compute() method is called and the subtask is executed. If the subtask to be joined has been completed (some other thread has stolen it and completed it), then the result is read, and join() returns. If the subtask to be joined has been stolen and is being executed by another thread, then the current thread finds some other task to work on, either from its local deque or by stealing a task from another thread's deque.
The beauty of the mechanism here is that the threads always look for something to do and they cooperate to get as much work done as possible.
Note that task stealing is always done from the back. In other words, an idle worker thread always steals a task from the tail of the deque of another worker thread. This is because tasks are added at the head of the deque, so the tasks at the tail are the oldest ones and are expected to have more unfinished computation than the tasks at the head. Stealing from the tail therefore minimizes the number of steals needed[2].
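The toy model below is a single-threaded sketch, not ForkJoinPool's actual code, of why the tail of a deque holds the larger pieces of work. The class ToyWorker and its methods are hypothetical: fork() adds a task at the head, the owner takes its next task from the head, and a thief takes from the tail of another worker's deque.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Toy model of one worker's deque. It only mimics where tasks are added and
// removed; there is no real concurrency here.
class ToyWorker {
  private final Deque<String> deque = new ArrayDeque<>();

  // fork(): the new task goes to the head of this worker's deque.
  void fork(String task) {
    deque.addFirst(task);
  }

  // An idle owner takes the most recently forked task from the head.
  String takeOwn() {
    return deque.pollFirst();
  }

  // An idle thief takes the least recently forked task from the victim's tail.
  String stealFrom(ToyWorker victim) {
    return victim.deque.pollLast();
  }
}

class WorkStealingDemo {
  public static void main(String[] args) {
    ToyWorker owner = new ToyWorker();
    ToyWorker thief = new ToyWorker();
    // Each recursive split forks a task half the size of the previous one.
    owner.fork("sum of 512 elements");
    owner.fork("sum of 256 elements");
    owner.fork("sum of 128 elements");
    System.out.println(owner.takeOwn());        // sum of 128 elements
    System.out.println(thief.stealFrom(owner)); // sum of 512 elements
  }
}
```

The owner keeps working on the small, most recent task, while the thief walks away with the largest remaining piece, so one steal transfers a lot of work.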
The mechanism here is similar to that implemented in .NET and Rust.
Order of fork() and join()
One implication of how ForkJoinPool adds and removes tasks from the deque concerns the order in which we call fork() and join(). Since the most recently forked task is likely to be executed next, we should join() the most recently forked task first. In other words, the order of joining should be the reverse of the order of forking.
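In the class Summer above, suppose we fork both left and right, in that order. Then joining right (the most recently forked task) before left is more efficient than joining left before right.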
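As a sketch, the hypothetical class TwoForkSummer below is a variant of Summer that forks both halves, with a reverseJoin flag choosing between the two join orders. Both orders compute the same sum; the sketch does not measure timing, since the efficiency difference comes from how ForkJoinPool manages its deque, as discussed in this section.

```java
import java.util.concurrent.RecursiveTask;

// Hypothetical variant of Summer that forks BOTH halves of the range,
// so the order of the two join() calls matters.
class TwoForkSummer extends RecursiveTask<Integer> {
  private final int low;
  private final int high;
  private final int[] array;
  private final boolean reverseJoin; // true: join the last-forked task first

  TwoForkSummer(int low, int high, int[] array, boolean reverseJoin) {
    this.low = low;
    this.high = high;
    this.array = array;
    this.reverseJoin = reverseJoin;
  }

  @Override
  protected Integer compute() {
    if (high - low < 2) { // base case: at most one element
      int sum = 0;
      for (int i = low; i < high; i++) {
        sum += array[i];
      }
      return sum;
    }
    int middle = (low + high) / 2;
    TwoForkSummer left = new TwoForkSummer(low, middle, array, reverseJoin);
    TwoForkSummer right = new TwoForkSummer(middle, high, array, reverseJoin);
    left.fork();
    right.fork(); // inside a pool worker, right is now at the head of its deque
    if (reverseJoin) {
      // More efficient: fork left, fork right, join right, join left.
      int r = right.join();
      int l = left.join();
      return l + r;
    }
    // Less efficient: left is joined first although right is at the head.
    int l = left.join();
    int r = right.join();
    return l + r;
  }

  public static void main(String[] args) {
    int[] array = {3, 1, 4, 1, 5, 9, 2, 6};
    System.out.println(new TwoForkSummer(0, array.length, array, true).compute());  // 31
    System.out.println(new TwoForkSummer(0, array.length, array, false).compute()); // 31
  }
}
```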
In other words, your sequence of fork(), compute(), and join() calls should form a palindrome with no crossing: if a task a is forked before a task b, then b must be joined before a. Additionally, there should be at most a single compute(), and it should be in the middle of the palindrome.
For example, the following is ok.
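A sketch of an ordering that satisfies the rule, assuming a hypothetical Leaf task that simply returns a number: fork a, fork b, compute c, join b, join a. The joins mirror the forks, and the single compute() sits in the middle.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Hypothetical leaf task: returns a fixed value without forking.
class Leaf extends RecursiveTask<Integer> {
  private final int value;

  Leaf(int value) {
    this.value = value;
  }

  @Override
  protected Integer compute() {
    return value;
  }
}

// Order: fork a, fork b, compute c, join b, join a (a palindrome).
class PalindromeOrder extends RecursiveTask<Integer> {
  @Override
  protected Integer compute() {
    Leaf a = new Leaf(1);
    Leaf b = new Leaf(2);
    Leaf c = new Leaf(3);
    a.fork();
    b.fork();              // inside a pool worker, b is now at the head of its deque
    int sum = c.compute(); // the single compute() sits in the middle
    sum += b.join();       // join the most recently forked task first
    sum += a.join();
    return sum;
  }

  public static void main(String[] args) {
    System.out.println(ForkJoinPool.commonPool().invoke(new PalindromeOrder())); // 6
  }
}
```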
But the following is not.
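A sketch of an ordering that breaks the rule, again assuming a hypothetical Leaf task that simply returns a number: fork a, join a, compute c. Here compute() comes after the palindrome instead of in its middle, so a is joined right after it is forked; the current thread most likely runs a itself and only then computes c, which loses the parallelism.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Hypothetical leaf task: returns a fixed value without forking.
class Leaf extends RecursiveTask<Integer> {
  private final int value;

  Leaf(int value) {
    this.value = value;
  }

  @Override
  protected Integer compute() {
    return value;
  }
}

// Order: fork a, join a, compute c. compute() is not in the middle.
class ComputeAfterJoin extends RecursiveTask<Integer> {
  @Override
  protected Integer compute() {
    Leaf a = new Leaf(1);
    Leaf c = new Leaf(3);
    a.fork();
    int sum = a.join(); // joined immediately: this thread most likely runs a itself
    sum += c.compute(); // c only starts after a has finished
    return sum;
  }

  public static void main(String[] args) {
    System.out.println(ForkJoinPool.commonPool().invoke(new ComputeAfterJoin())); // 4
  }
}
```

The result is still correct; the problem is only that the work of a and c is done one after the other.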
Several reasons combine to make this ordering more efficient. First, operations on the deque have to be atomic. In other words, when a thread \(T_1\) is operating on its deque \(D_1\), \(T_1\) has to finish its operation before another thread \(T_2\) can operate on \(D_1\) (e.g., to steal a task). Atomic operations are relatively expensive, so the more deque operations are performed, the slower the program will be.
Second, when join() is called on a subtask that has not yet been computed, the calling thread finds and executes that subtask itself. If the subtask is at the head of the deque, a single pop is enough. Otherwise, the thread has to search the deque for it, which requires a combination of pops and pushes instead of a single pop.
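The cost difference can be illustrated with a toy, single-threaded model of one thread's deque. The class ToyJoinDeque is hypothetical and far simpler than ForkJoinPool's internals: it models the search as a head-to-tail scan, and join() returns how many entries it had to examine before finding the task.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;

// Toy model of the calling thread's deque during join().
class ToyJoinDeque {
  private final Deque<String> deque = new ArrayDeque<>();

  void fork(String task) {
    deque.addFirst(task); // forked tasks go to the head
  }

  // Finds and removes the task to be joined, scanning from the head.
  // Returns the number of entries examined (1 if it was at the head).
  int join(String task) {
    int examined = 0;
    Iterator<String> it = deque.iterator(); // head-to-tail order
    while (it.hasNext()) {
      examined++;
      if (it.next().equals(task)) {
        it.remove();
        return examined;
      }
    }
    return examined; // not found: e.g. the task was stolen by another thread
  }
}

class JoinCostDemo {
  public static void main(String[] args) {
    ToyJoinDeque reverse = new ToyJoinDeque();
    reverse.fork("left");
    reverse.fork("right");
    int reverseCost = reverse.join("right") + reverse.join("left");

    ToyJoinDeque same = new ToyJoinDeque();
    same.fork("left");
    same.fork("right");
    int sameCost = same.join("left") + same.join("right");

    System.out.println(reverseCost); // 2: each join finds its task at the head
    System.out.println(sameCost);    // 3: joining left first must skip over right
  }
}
```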
[1] A deque is a double-ended queue. It behaves like both a stack and a queue.
[2] Of course, we are assuming a "typical" program. We can always create a program where the workload is not split 50%-50% but 90%-10% instead. If the 10% portion of the workload goes to the task at the tail of the deque, then we actually need more stealing.