User guide

The stack safety problem

A program is stack safe if it won’t overflow the stack, no matter how large its input is. Stack safety is closely related to recursive algorithms, since deep recursion normally requires a large stack.

Consider the following (binary) Tree type:

import java.util.function.BiFunction;
import java.util.function.Function;

public abstract class Tree<A> {
    public abstract <B> B visit(
            Function<A, B> onLeaf,
            Function<Tree<A>, B> onUnaryBranch,
            BiFunction<Tree<A>, Tree<A>, B> onBinaryBranch);
}

Example 1: A very basic Tree type

Branches of this tree type can have either one or two children. None of the branches have any associated value. Each leaf, however, has a value of type A. You might notice that this type cannot represent an empty tree. That doesn’t matter much, since this example is for instructional purposes only.

As it happens, this simple interface is all we need to inspect and traverse the tree. The visit method realizes a kind of functional visitor pattern, but without the need for a separate TreeVisitor type. Instead, we must provide three callbacks, one for each node type.
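The node types themselves are not shown in this guide. As a minimal sketch, assuming three hypothetical subclasses (Leaf, UnaryBranch and BinaryBranch are names made up for this illustration), each node type could simply call the callback that corresponds to it. The describe method is likewise an assumption; it only shows how a caller uses visit:

```java
import java.util.function.BiFunction;
import java.util.function.Function;

// Same shape as the Tree type in Example 1 (package-private so the sketch fits in one file).
abstract class Tree<A> {
    public abstract <B> B visit(
            Function<A, B> onLeaf,
            Function<Tree<A>, B> onUnaryBranch,
            BiFunction<Tree<A>, Tree<A>, B> onBinaryBranch);
}

// Hypothetical node types; their names are assumptions made for this sketch.
final class Leaf<A> extends Tree<A> {
    private final A value;

    Leaf(A value) {
        this.value = value;
    }

    @Override
    public <B> B visit(
            Function<A, B> onLeaf,
            Function<Tree<A>, B> onUnaryBranch,
            BiFunction<Tree<A>, Tree<A>, B> onBinaryBranch) {
        return onLeaf.apply(value); // a leaf hands its value to the first callback
    }
}

final class UnaryBranch<A> extends Tree<A> {
    private final Tree<A> child;

    UnaryBranch(Tree<A> child) {
        this.child = child;
    }

    @Override
    public <B> B visit(
            Function<A, B> onLeaf,
            Function<Tree<A>, B> onUnaryBranch,
            BiFunction<Tree<A>, Tree<A>, B> onBinaryBranch) {
        return onUnaryBranch.apply(child); // a unary branch hands over its only child
    }
}

final class BinaryBranch<A> extends Tree<A> {
    private final Tree<A> left;
    private final Tree<A> right;

    BinaryBranch(Tree<A> left, Tree<A> right) {
        this.left = left;
        this.right = right;
    }

    @Override
    public <B> B visit(
            Function<A, B> onLeaf,
            Function<Tree<A>, B> onUnaryBranch,
            BiFunction<Tree<A>, Tree<A>, B> onBinaryBranch) {
        return onBinaryBranch.apply(left, right); // a binary branch hands over both children
    }
}

final class TreeDescriber {
    // Renders a tree as text by supplying one callback per node type.
    static String describe(Tree<String> tree) {
        return tree.visit(
                value -> value,
                child -> "(" + describe(child) + ")",
                (left, right) -> "(" + describe(left) + " " + describe(right) + ")");
    }
}
```

Note that the caller never needs to know which subclass it is holding; the node decides which callback runs.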

We can make a very general purpose method for recursively traversing the leaves of a Tree:

public class TreeOps {
    public static <A, B> B foldLeft(Tree<A> tree, BiFunction<B, A, B> reduce, B init) {
        return tree.visit(
                leafValue -> reduce.apply(init, leafValue),
                child -> foldLeft(child, reduce, init),
                (leftChild, rightChild) -> {
                    B leftAcc = foldLeft(leftChild, reduce, init);
                    B resultAcc = foldLeft(rightChild, reduce, leftAcc);
                    return resultAcc;
                });
    }
}

Example 2: A stack unsafe recursive traversal algorithm

foldLeft does a depth-first traversal (from left to right), applying reduce to each leaf value to accumulate a result. The accumulator is initialized with the value of init. Take a moment to understand this algorithm before proceeding.

This algorithm may overflow the stack if the depth of the tree is on the order of thousands. That case is not as contrived as you might think; this tree type could be well suited to algorithms that need a fast append operation, in which case such trees could become very unbalanced, leaning either to the left or to the right.
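To see how such a tree can come about, here is a sketch under assumed names. The leaf and binary factory methods and the append helper are hypothetical and not part of the Tree type shown above; append makes the old tree the left child of a new root, so repeated appends produce a tree that leans entirely to the left. The fold is the recursive one from Example 2:

```java
import java.util.function.BiFunction;
import java.util.function.Function;

abstract class Tree<A> {
    public abstract <B> B visit(
            Function<A, B> onLeaf,
            Function<Tree<A>, B> onUnaryBranch,
            BiFunction<Tree<A>, Tree<A>, B> onBinaryBranch);

    // Hypothetical factory methods (assumptions of this sketch).
    static <A> Tree<A> leaf(A value) {
        return new Tree<A>() {
            @Override
            public <B> B visit(Function<A, B> onLeaf, Function<Tree<A>, B> onUnaryBranch,
                    BiFunction<Tree<A>, Tree<A>, B> onBinaryBranch) {
                return onLeaf.apply(value);
            }
        };
    }

    static <A> Tree<A> binary(Tree<A> left, Tree<A> right) {
        return new Tree<A>() {
            @Override
            public <B> B visit(Function<A, B> onLeaf, Function<Tree<A>, B> onUnaryBranch,
                    BiFunction<Tree<A>, Tree<A>, B> onBinaryBranch) {
                return onBinaryBranch.apply(left, right);
            }
        };
    }

    // Hypothetical constant-time append: the old tree becomes the left child of a new root.
    Tree<A> append(A value) {
        return binary(this, leaf(value));
    }
}

final class DeepTreeDemo {
    // The recursive fold from Example 2.
    static <A, B> B foldLeft(Tree<A> tree, BiFunction<B, A, B> reduce, B init) {
        return tree.visit(
                leafValue -> reduce.apply(init, leafValue),
                child -> foldLeft(child, reduce, init),
                (leftChild, rightChild) -> {
                    B leftAcc = foldLeft(leftChild, reduce, init);
                    return foldLeft(rightChild, reduce, leftAcc);
                });
    }

    // Builds the tree in a loop, so construction itself never recurses.
    static Tree<Integer> buildLeftLeaning(int leafCount) {
        Tree<Integer> tree = Tree.leaf(1);
        for (int i = 2; i <= leafCount; i++) {
            tree = tree.append(i);
        }
        return tree;
    }

    public static void main(String[] args) {
        Integer small = foldLeft(buildLeftLeaning(100), Integer::sum, 0);
        System.out.println("sum of 100 leaves: " + small);
        try {
            foldLeft(buildLeftLeaning(200_000), Integer::sum, 0);
            System.out.println("deep fold completed (the stack was large enough)");
        } catch (StackOverflowError e) {
            System.out.println("deep fold overflowed the stack");
        }
    }
}
```

Whether the deep fold overflows depends on the configured thread stack size, which is why the sketch handles both outcomes.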

Now, a tail recursive algorithm can always be rewritten as a loop (see Wikipedia: Tail call), but tree traversal is not, and can’t be rewritten to be, tail recursive: the traversal of the right child has to wait for the result of the left child. To make it stack safe, we must either rewrite the algorithm to use an explicit stack that resides on the heap, or use a trampoline.

Use a trampoline

A trampoline is a data structure that represents either an unevaluated calculation or a single value. Alternatively, it can be viewed as a value that will be resolved later. Instead of running a method or function immediately, we wrap the call in an object that can be run later.

When a trampoline is run, it resolves the calculations inside one by one in a loop. When there are no more unevaluated calculations left, the final calculated value is returned. This process is described in Anatomy of Trampoline.
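As a rough mental model only, the following sketch shows one way a trampoline offering the ret, suspend, map, flatMap and run operations used in this guide could be built. It is not the library's implementation: the class name MiniTrampoline and the internal Done, More and Bind cases are assumptions of this sketch. The point to notice is the loop in run, which keeps pending flatMap functions on a heap-allocated deque instead of on the call stack:

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.function.Function;
import java.util.function.Supplier;

// A simplified stand-in, NOT the library's Trampoline class.
abstract class MiniTrampoline<A> {
    static <A> MiniTrampoline<A> ret(A value) {
        return new Done<>(value);
    }

    static <A> MiniTrampoline<A> suspend(Supplier<MiniTrampoline<A>> thunk) {
        return new More<>(thunk);
    }

    <B> MiniTrampoline<B> flatMap(Function<A, MiniTrampoline<B>> next) {
        return new Bind<>(this, next);
    }

    <B> MiniTrampoline<B> map(Function<A, B> transformValue) {
        return flatMap(a -> ret(transformValue.apply(a)));
    }

    // Resolves the pending steps one by one in a loop.
    @SuppressWarnings({"unchecked", "rawtypes"})
    A run() {
        Deque<Function> continuations = new ArrayDeque<>(); // pending flatMap functions, on the heap
        MiniTrampoline current = this;
        while (true) {
            if (current instanceof Done) {
                Object value = ((Done) current).value;
                if (continuations.isEmpty()) {
                    return (A) value; // nothing left to evaluate
                }
                current = (MiniTrampoline) continuations.pop().apply(value);
            } else if (current instanceof More) {
                current = (MiniTrampoline) ((More) current).thunk.get(); // perform one suspended step
            } else {
                Bind bind = (Bind) current;
                continuations.push(bind.next); // remember what to do with the result
                current = bind.source;
            }
        }
    }

    private static final class Done<A> extends MiniTrampoline<A> {
        final A value;
        Done(A value) { this.value = value; }
    }

    private static final class More<A> extends MiniTrampoline<A> {
        final Supplier<MiniTrampoline<A>> thunk;
        More(Supplier<MiniTrampoline<A>> thunk) { this.thunk = thunk; }
    }

    private static final class Bind<S, A> extends MiniTrampoline<A> {
        final MiniTrampoline<S> source;
        final Function<S, MiniTrampoline<A>> next;
        Bind(MiniTrampoline<S> source, Function<S, MiniTrampoline<A>> next) {
            this.source = source;
            this.next = next;
        }
    }
}

final class MiniTrampolineDemo {
    // Sums 1..n recursively; each recursive call is suspended instead of executed right away.
    static MiniTrampoline<Long> sumTo(long n) {
        if (n == 0) {
            return MiniTrampoline.ret(0L);
        }
        return MiniTrampoline.suspend(() -> sumTo(n - 1)).map(rest -> rest + n);
    }
}
```

Calling MiniTrampolineDemo.sumTo(100_000).run() performs one hundred thousand logically recursive steps, yet each call to sumTo returns immediately after building a small object, so the call stack stays shallow.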

To make foldLeft stack safe using a trampoline, we rewrite it like this:

public final class TreeOps {
    public static <A, B> B foldLeft(Tree<A> tree, BiFunction<B, A, B> reduce, B init) {
        return trampolinedFoldLeft(tree, reduce, init).run();
    }

    private static <A, B> Trampoline<B> trampolinedFoldLeft(
            Tree<A> tree, 
            BiFunction<B, A, B> reduce, B init) {
        return tree.visit(
                leafValue -> Trampoline.ret(reduce.apply(init, leafValue)),
                child -> Trampoline.suspend(() -> trampolinedFoldLeft(child, reduce, init)),
                (leftChild, rightChild) -> {
                    Trampoline<B> leftAccTrampoline = Trampoline.suspend(() ->
                            trampolinedFoldLeft(leftChild, reduce, init));
                    Trampoline<B> resultAccTrampoline = leftAccTrampoline.flatMap(leftAcc ->
                            trampolinedFoldLeft(rightChild, reduce, leftAcc));
                    return resultAccTrampoline;
                });
    }
}

Example 3: A stack safe recursive traversal algorithm

trampolinedFoldLeft is a recursive method that immediately returns a Trampoline instance; only one visit at the root is done at this point. There is a one-to-one correspondence between each part of trampolinedFoldLeft and Example 2’s foldLeft. Before proceeding, have a look at the Javadoc to learn about the Trampoline methods.
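To try the idea end to end without the library on the classpath, here is a self-contained sketch. MiniTrampoline is a simplified stand-in written for this illustration (it is not the library's Trampoline), and the leaf and binary factory methods on Tree are hypothetical. The fold itself mirrors Example 3:

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;

// Simplified stand-in for the library's Trampoline; an assumption of this sketch.
abstract class MiniTrampoline<A> {
    static <A> MiniTrampoline<A> ret(A value) { return new Done<>(value); }
    static <A> MiniTrampoline<A> suspend(Supplier<MiniTrampoline<A>> thunk) { return new More<>(thunk); }
    <B> MiniTrampoline<B> flatMap(Function<A, MiniTrampoline<B>> next) { return new Bind<>(this, next); }

    @SuppressWarnings({"unchecked", "rawtypes"})
    A run() {
        Deque<Function> continuations = new ArrayDeque<>(); // pending flatMap steps live on the heap
        MiniTrampoline current = this;
        while (true) {
            if (current instanceof Done) {
                Object value = ((Done) current).value;
                if (continuations.isEmpty()) return (A) value;
                current = (MiniTrampoline) continuations.pop().apply(value);
            } else if (current instanceof More) {
                current = (MiniTrampoline) ((More) current).thunk.get();
            } else {
                continuations.push(((Bind) current).next);
                current = ((Bind) current).source;
            }
        }
    }

    static final class Done<A> extends MiniTrampoline<A> {
        final A value;
        Done(A value) { this.value = value; }
    }
    static final class More<A> extends MiniTrampoline<A> {
        final Supplier<MiniTrampoline<A>> thunk;
        More(Supplier<MiniTrampoline<A>> thunk) { this.thunk = thunk; }
    }
    static final class Bind<S, A> extends MiniTrampoline<A> {
        final MiniTrampoline<S> source;
        final Function<S, MiniTrampoline<A>> next;
        Bind(MiniTrampoline<S> source, Function<S, MiniTrampoline<A>> next) {
            this.source = source;
            this.next = next;
        }
    }
}

abstract class Tree<A> {
    public abstract <B> B visit(Function<A, B> onLeaf, Function<Tree<A>, B> onUnaryBranch,
            BiFunction<Tree<A>, Tree<A>, B> onBinaryBranch);

    // Hypothetical factory methods (assumptions of this sketch).
    static <A> Tree<A> leaf(A value) {
        return new Tree<A>() {
            @Override
            public <B> B visit(Function<A, B> onLeaf, Function<Tree<A>, B> onUnaryBranch,
                    BiFunction<Tree<A>, Tree<A>, B> onBinaryBranch) {
                return onLeaf.apply(value);
            }
        };
    }

    static <A> Tree<A> binary(Tree<A> left, Tree<A> right) {
        return new Tree<A>() {
            @Override
            public <B> B visit(Function<A, B> onLeaf, Function<Tree<A>, B> onUnaryBranch,
                    BiFunction<Tree<A>, Tree<A>, B> onBinaryBranch) {
                return onBinaryBranch.apply(left, right);
            }
        };
    }
}

final class SafeTreeOps {
    static <A, B> B foldLeft(Tree<A> tree, BiFunction<B, A, B> reduce, B init) {
        return trampolinedFoldLeft(tree, reduce, init).run();
    }

    // Same structure as Example 3, written against the stand-in trampoline.
    static <A, B> MiniTrampoline<B> trampolinedFoldLeft(Tree<A> tree, BiFunction<B, A, B> reduce, B init) {
        return tree.visit(
                leafValue -> MiniTrampoline.ret(reduce.apply(init, leafValue)),
                child -> MiniTrampoline.suspend(() -> trampolinedFoldLeft(child, reduce, init)),
                (leftChild, rightChild) -> {
                    MiniTrampoline<B> leftAcc = MiniTrampoline.suspend(() ->
                            trampolinedFoldLeft(leftChild, reduce, init));
                    return leftAcc.flatMap(acc -> trampolinedFoldLeft(rightChild, reduce, acc));
                });
    }
}
```

With this in place, folding a left-leaning tree built in a loop from a hundred thousand leaves completes without deep recursion, because every recursive call is created as data first and only resolved inside the run loop.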

The stack safe foldLeft explained

leafValue -> Trampoline.ret(reduce.apply(init, leafValue))

Example 3a: Trampolined left fold of a leaf

This callback is used for leaves of the tree. When recursion hits the bottom, a value is returned. Since trampolinedFoldLeft returns a Trampoline, the value must be wrapped in a Trampoline.

child -> Trampoline.suspend(() -> trampolinedFoldLeft(child, reduce, init))

Example 3b: Trampolined left fold of a unary branch

This callback is used for unary (one-child) branches. If we left out the suspend here, trampolinedFoldLeft would have been called immediately, leaving us with the exact problem we tried to solve. Hence, we wrap the recursive call in a suspend, creating a trampoline that will do the recursive call later.

(leftChild, rightChild) -> {
    Trampoline<B> leftAccTrampoline = Trampoline.suspend(() ->
            trampolinedFoldLeft(leftChild, reduce, init));
    Trampoline<B> resultAccTrampoline = leftAccTrampoline.flatMap(leftAcc ->
            trampolinedFoldLeft(rightChild, reduce, leftAcc));
    return resultAccTrampoline;
}

Example 3c: Trampolined left fold of a binary branch

This callback is used for binary (two-child) branches. When this branch is traversed, the left child is traversed first, then the right child. As with the unary branch, we need to suspend the traversal of the left child to avoid the immediate recursion. In foldLeft, the accumulated result from traversing the left child is used in the traversal of the right child. Hence, the second recursion depends on the first one, which is exactly what flatMap is for.

The last callback could also have been written as:

(leftChild, rightChild) -> Trampoline.suspend(() -> {
    Trampoline<B> leftAccTrampoline = trampolinedFoldLeft(leftChild, reduce, init);
    Trampoline<B> resultAccTrampoline = leftAccTrampoline.flatMap(leftAcc ->
            trampolinedFoldLeft(rightChild, reduce, leftAcc));
    return resultAccTrampoline;
})

Example 4: Alternative suspension of the traversal of a binary branch

The difference is that instead of suspending the traversal of the left child, we suspend the traversal of the whole branch. The two are equivalent with regard to functionality and stack safety. The only observable difference is that if there were some time-consuming computation while processing the branch node itself, the latter would create the resulting trampoline faster, postponing the time-consuming calculation until the resulting trampoline is run.

Don’t do this

A developer who’s unfamiliar with trampolines might try to write the two-child branch traversal like this:

(leftChild, rightChild) -> Trampoline.suspend(() -> {
    B leftAcc = trampolinedFoldLeft(leftChild, reduce, init).run();
    Trampoline<B> resultAccTrampoline = trampolinedFoldLeft(rightChild, reduce, leftAcc);
    return resultAccTrampoline;
})

Example 5: Don’t call run while creating a trampoline

This is not stack safe. The reason is that when run is called, the corresponding code for child nodes is “unsuspended”, forcing the “unsuspension” of their child nodes and so on until the bottom of the recursion, possibly causing a StackOverflowError. The golden rule is:

Don’t call run on a sub-trampoline of the trampoline you’re creating!

A sign that you’re doing it wrong is a call to run inside a method that returns a Trampoline. This is safe if the trampoline you’re running is unrelated (that is, not using any of the same recursive methods) to the one you’re creating, but even in that case it is better (with regard to maintainability) to use map/flatMap.

Performance considerations

Even though the JVM’s garbage collector is pretty amazing, doing things in a plain loop is generally faster than using a trampoline, since every suspended step allocates objects. So, you should always consider rewriting the tail call part as a loop. foldLeft could be implemented like this, where we traverse the left child of binary branches in a trampoline, and handle the tail recursion in a loop:

public final class TreeOps {
    public static <A, B> B foldLeft(Tree<A> tree, BiFunction<B, A, B> reduce, B init) {
        return trampolinedFoldLeft(tree, reduce, init).run();
    }

    public static <A, B> Trampoline<B> trampolinedFoldLeft(
            Tree<A> tree, 
            BiFunction<B, A, B> reduce, 
            B init) {
        Trampoline<B> accTrampoline = Trampoline.ret(init);
        Tree<A> currTree = tree;
        while (isBranch(currTree)) {
            accTrampoline = trampolinedFoldLeftOfLeftChild(currTree, reduce, accTrampoline);
            currTree = getRightmostChild(currTree);
        }
        Tree<A> rightmostLeaf = currTree; 
        return accTrampoline.map(acc -> reduce.apply(acc, getLeafValue(rightmostLeaf)));
    }

    private static <A, B> Trampoline<B> trampolinedFoldLeftOfLeftChild(
            Tree<A> tree,
            BiFunction<B, A, B> reduce,
            Trampoline<B> initTrampoline) {
        return tree.visit(
                leafValue -> {
                    throw new AssertionError("Didn't you just test that this is a branch?");
                },
                child -> initTrampoline,
                (leftChild, rightChild) -> initTrampoline.flatMap(init -> 
                        trampolinedFoldLeft(leftChild, reduce, init)));
    }
}

Example 6: Replacing tail recursion with a loop

Note that we need to implement isBranch, getRightmostChild and getLeafValue as well. They’re left out to help us focus on the traversal logic. In this case, “loopifying” the tail call leads to more code and, for left-heavy and balanced trees, the performance gain will be quite small. The lesson is that it is not always best to rewrite tail recursion as a loop.
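For completeness, the three helpers can be expressed with visit alone. The following is one possible sketch, not necessarily how the author would write them; the class name TreeInspection, the choice of IllegalArgumentException, and the leaf, unary and binary factory methods are assumptions made so that the example can be run on its own:

```java
import java.util.function.BiFunction;
import java.util.function.Function;

abstract class Tree<A> {
    public abstract <B> B visit(Function<A, B> onLeaf, Function<Tree<A>, B> onUnaryBranch,
            BiFunction<Tree<A>, Tree<A>, B> onBinaryBranch);

    // Hypothetical factory methods (assumptions of this sketch).
    static <A> Tree<A> leaf(A value) {
        return new Tree<A>() {
            @Override
            public <B> B visit(Function<A, B> onLeaf, Function<Tree<A>, B> onUnaryBranch,
                    BiFunction<Tree<A>, Tree<A>, B> onBinaryBranch) {
                return onLeaf.apply(value);
            }
        };
    }

    static <A> Tree<A> unary(Tree<A> child) {
        return new Tree<A>() {
            @Override
            public <B> B visit(Function<A, B> onLeaf, Function<Tree<A>, B> onUnaryBranch,
                    BiFunction<Tree<A>, Tree<A>, B> onBinaryBranch) {
                return onUnaryBranch.apply(child);
            }
        };
    }

    static <A> Tree<A> binary(Tree<A> left, Tree<A> right) {
        return new Tree<A>() {
            @Override
            public <B> B visit(Function<A, B> onLeaf, Function<Tree<A>, B> onUnaryBranch,
                    BiFunction<Tree<A>, Tree<A>, B> onBinaryBranch) {
                return onBinaryBranch.apply(left, right);
            }
        };
    }
}

final class TreeInspection {
    // A node is a branch if either of the two branch callbacks fires.
    static <A> boolean isBranch(Tree<A> tree) {
        return tree.<Boolean>visit(leafValue -> false, child -> true, (left, right) -> true);
    }

    // The only child of a unary branch, or the right child of a binary branch.
    static <A> Tree<A> getRightmostChild(Tree<A> tree) {
        return tree.<Tree<A>>visit(
                leafValue -> { throw new IllegalArgumentException("A leaf has no children"); },
                child -> child,
                (left, right) -> right);
    }

    // The value of a leaf; calling this on a branch is a programming error.
    static <A> A getLeafValue(Tree<A> tree) {
        return tree.<A>visit(
                leafValue -> leafValue,
                child -> { throw new IllegalArgumentException("Not a leaf"); },
                (left, right) -> { throw new IllegalArgumentException("Not a leaf"); });
    }
}
```

Each helper costs one extra visit call per loop iteration, which is part of the reason the gain from the loop can be small.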

Why is it called —

“Trampoline”?

While running a trampoline, the stack returns to the same state between each step. If you look at a stack trace while stepping through a trampolined calculation, you will observe this. It’s like the stack pointer jumps up and down.
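This can also be observed programmatically. The sketch below uses a reduced stand-in trampoline (MiniTrampoline, with only ret, suspend and run; it is an assumption of this example, not the library's class) and records the length of the current stack trace at every step of a suspended countdown:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

// Reduced stand-in with only ret, suspend and run; an assumption of this sketch.
abstract class MiniTrampoline<A> {
    static <A> MiniTrampoline<A> ret(A value) { return new Done<>(value); }
    static <A> MiniTrampoline<A> suspend(Supplier<MiniTrampoline<A>> thunk) { return new More<>(thunk); }

    A run() {
        MiniTrampoline<A> current = this;
        while (current instanceof More) {
            // Each step is called from this loop and returns to it before the next one starts.
            current = ((More<A>) current).thunk.get();
        }
        return ((Done<A>) current).value;
    }

    static final class Done<A> extends MiniTrampoline<A> {
        final A value;
        Done(A value) { this.value = value; }
    }

    static final class More<A> extends MiniTrampoline<A> {
        final Supplier<MiniTrampoline<A>> thunk;
        More(Supplier<MiniTrampoline<A>> thunk) { this.thunk = thunk; }
    }
}

final class StackDepthDemo {
    // Counts down from n, recording the call-stack depth seen at every step.
    static MiniTrampoline<Integer> countDown(int n, List<Integer> depths) {
        depths.add(Thread.currentThread().getStackTrace().length);
        if (n == 0) {
            return MiniTrampoline.ret(0);
        }
        return MiniTrampoline.suspend(() -> countDown(n - 1, depths));
    }

    // Depths observed in the suspended steps; the first, direct call is skipped
    // because it is made by the caller rather than by the run loop.
    static List<Integer> suspendedStepDepths(int steps) {
        List<Integer> depths = new ArrayList<>();
        countDown(steps, depths).run();
        return depths.subList(1, depths.size());
    }
}
```

Every recorded depth is the same, whereas a plain recursive countdown would report a depth that grows by at least one frame per step.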

“map”?

“Mapping” is a mathematical term which basically means associating a value in one domain with a value in another domain. The transformValue function passed to map represents such a mapping.

“flatMap”?

Imagine that there existed a function static <A> Trampoline<A> flatten(Trampoline<Trampoline<A>> trampolinedTrampoline). Simply by studying the signature, you should be able to figure out what it does. Originally, the map and flatten terms were used for Lists, where mapping from a list element to a list would produce a List of Lists. In that context, a function that converts a List of Lists to a List is naturally named “flatten”.

The expected relations between ret, map, flatten and flatMap are:

The last relation should explain the name flatMap: it is a composition of flatten and map.
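As an illustration of what such relations conventionally look like, the sketch below states three of them as assumptions and checks them against a simplified stand-in (MiniTrampoline, not the library's Trampoline): map(f) behaves like flatMap(a -> ret(f(a))), flatten(t) behaves like t.flatMap(inner -> inner), and flatMap(f) behaves like flatten(map(f)). The helper names in TrampolineRelations are made up for this example:

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.function.Function;

// Simplified stand-in with ret, map, flatMap and run; an assumption of this sketch.
abstract class MiniTrampoline<A> {
    static <A> MiniTrampoline<A> ret(A value) { return new Done<>(value); }

    <B> MiniTrampoline<B> flatMap(Function<A, MiniTrampoline<B>> next) { return new Bind<>(this, next); }

    <B> MiniTrampoline<B> map(Function<A, B> transformValue) {
        return flatMap(a -> ret(transformValue.apply(a)));
    }

    @SuppressWarnings({"unchecked", "rawtypes"})
    A run() {
        Deque<Function> continuations = new ArrayDeque<>();
        MiniTrampoline current = this;
        while (true) {
            if (current instanceof Bind) {
                continuations.push(((Bind) current).next);
                current = ((Bind) current).source;
            } else {
                Object value = ((Done) current).value;
                if (continuations.isEmpty()) return (A) value;
                current = (MiniTrampoline) continuations.pop().apply(value);
            }
        }
    }

    static final class Done<A> extends MiniTrampoline<A> {
        final A value;
        Done(A value) { this.value = value; }
    }

    static final class Bind<S, A> extends MiniTrampoline<A> {
        final MiniTrampoline<S> source;
        final Function<S, MiniTrampoline<A>> next;
        Bind(MiniTrampoline<S> source, Function<S, MiniTrampoline<A>> next) {
            this.source = source;
            this.next = next;
        }
    }
}

final class TrampolineRelations {
    // Removes one level of nesting: a trampoline of a trampoline becomes a trampoline.
    static <A> MiniTrampoline<A> flatten(MiniTrampoline<MiniTrampoline<A>> trampolinedTrampoline) {
        return trampolinedTrampoline.flatMap(inner -> inner);
    }

    // flatMap expressed as "map, then flatten".
    static <A, B> MiniTrampoline<B> flatMapViaFlatten(
            MiniTrampoline<A> trampoline, Function<A, MiniTrampoline<B>> f) {
        return flatten(trampoline.map(f));
    }

    // map expressed as "flatMap into ret".
    static <A, B> MiniTrampoline<B> mapViaFlatMap(MiniTrampoline<A> trampoline, Function<A, B> f) {
        return trampoline.flatMap(a -> MiniTrampoline.ret(f.apply(a)));
    }
}
```

For any trampoline t and function f, t.flatMap(f).run() and TrampolineRelations.flatMapViaFlatten(t, f).run() should produce the same value in this sketch.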