Tail Recurse by Mark Seemann
Tips on refactoring recursive functions to tail-recursive functions.
In a recent article, I described how to refactor an imperative loop to a recursive function. If you're coming from C# or Java, however, you may have learned to avoid recursion, since it can lead to stack overflows when the recursion is too deep.
In some Functional languages, like F# and Haskell, such stack overflows can be prevented with tail recursion. If the last operation a function performs is a recursive call to itself, the current stack frame can be eliminated, because its work is effectively complete. This allows a tail-recursive function to keep recursing without ever overflowing the stack.
When you have a recursive function that's not tail recursive, however, it can sometimes be difficult to figure out how to refactor it so that it becomes tail recursive. In this article, I'll outline some techniques I've found to be effective.
Introduce an accumulator #
It seems as though the universal trick related to recursion is to introduce an accumulator argument. This is how you refactor an imperative loop to a recursive function, but it can also be used to make a non-tail-recursive function tail recursive. This will also require you to introduce an 'implementation function' that does the actual work.
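As a minimal sketch of the general pattern, here is a hypothetical pair of sum functions (not from the article). The first is not tail recursive; the second delegates to an inner 'implementation function' that carries the running total in an accumulator argument.

```fsharp
// Hypothetical example: not tail recursive, because the addition
// h + ... can only happen after the recursive call has returned.
let rec sumNaive = function
    | [] -> 0
    | h :: t -> h + sumNaive t

// Tail recursive: the inner 'implementation function' sumImp carries the
// running total in the accumulator acc, so the recursive call is the very
// last thing it does.
let sum xs =
    let rec sumImp acc = function
        | [] -> acc
        | h :: t -> sumImp (acc + h) t
    sumImp 0 xs
```

Both return 55 for `[1..10]`; only `sum` has its recursive call in tail position.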
Example: discarding candidate points when finding a convex hull #
One of my functions examined the last three points in a sequence of points in order to determine whether the next-to-last point is in the interior of the set, or whether that point could potentially be on the hull. If the point is positively known to be in the interior of the set, it should be discarded. At one stage, the function was implemented this way:
    let tryDiscard points =
        let rec tryDiscardImp = function
            | [p1; p2; p3] when turn p1 p2 p3 = Direction.Right -> [p1; p3]
            | [p1; p2; p3] -> [p1; p2; p3]
            | p :: ps -> p :: tryDiscardImp ps
            | [] -> []
        let newPoints = tryDiscardImp points
        if newPoints.Length <> points.Length
        then Some newPoints
        else None
This function was earlier called check, which is the name used in the article about refactoring to recursion. The tryDiscard function is actually an inner function in a more complex function that's defined with the inline keyword, so the type of tryDiscard is somewhat complicated, but think of it as having the type (int * int) list -> (int * int) list option. If a point was discarded, the new, reduced list of points is returned in a Some case; otherwise, None is returned.
The tryDiscard function already has an 'implementation function' called tryDiscardImp, but while tryDiscardImp is recursive, it isn't tail recursive. The problem is that in the p :: ps case, the recursive call to tryDiscardImp isn't the tail call. Rather, the stack frame has to wait for the recursive call tryDiscardImp ps to complete, because only then can it cons p onto its return value.
Since an 'implementation function' already exists, you can make it tail recursive by adding an accumulator argument to tryDiscardImp:
    let tryDiscard points =
        let rec tryDiscardImp acc = function
            | [p1; p2; p3] when turn p1 p2 p3 = Direction.Right -> acc @ [p1; p3]
            | [p1; p2; p3] -> acc @ [p1; p2; p3]
            | p :: ps -> tryDiscardImp (acc @ [p]) ps
            | [] -> acc
        let newPoints = tryDiscardImp [] points
        if newPoints.Length <> points.Length
        then Some newPoints
        else None
As you can see, I added the acc argument to tryDiscardImp; it has the type (int * int) list (again: not really, but close enough). Instead of returning from each case, the tryDiscardImp function now appends points to acc until it reaches the end of the list, which is when it returns the accumulator. The p :: ps case now first appends the point in consideration to the accumulator (acc @ [p]), and only then recursively calls tryDiscardImp. This puts the call to tryDiscardImp in the tail position.
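To check that the refactoring preserves behaviour, here is a self-contained sketch of both versions with concrete types instead of inline. The article doesn't show `turn` or `Direction`, so the definitions below are hypothetical stand-ins: this sketch assumes `turn` classifies p1 -> p2 -> p3 by the sign of a cross product, with the y axis pointing up.

```fsharp
// Hypothetical stand-ins for the article's turn and Direction.
type Direction = Left | Right | Straight

// Positive cross product = counterclockwise (left) turn, negative = right turn.
let turn (x1, y1) (x2, y2) (x3, y3) =
    let cross = (x2 - x1) * (y3 - y2) - (y2 - y1) * (x3 - x2)
    if cross > 0 then Direction.Left
    elif cross < 0 then Direction.Right
    else Direction.Straight

// The original, non-tail-recursive version.
let tryDiscardOriginal (points : (int * int) list) =
    let rec tryDiscardImp = function
        | [p1; p2; p3] when turn p1 p2 p3 = Direction.Right -> [p1; p3]
        | [p1; p2; p3] -> [p1; p2; p3]
        | p :: ps -> p :: tryDiscardImp ps
        | [] -> []
    let newPoints = tryDiscardImp points
    if newPoints.Length <> points.Length then Some newPoints else None

// The accumulator version: every recursive call is in tail position.
let tryDiscardAcc (points : (int * int) list) =
    let rec tryDiscardImp acc = function
        | [p1; p2; p3] when turn p1 p2 p3 = Direction.Right -> acc @ [p1; p3]
        | [p1; p2; p3] -> acc @ [p1; p2; p3]
        | p :: ps -> tryDiscardImp (acc @ [p]) ps
        | [] -> acc
    let newPoints = tryDiscardImp [] points
    if newPoints.Length <> points.Length then Some newPoints else None
```

With these stand-ins, `tryDiscardAcc [(5, 5); (0, 0); (1, 1); (2, 0)]` returns `Some [(5, 5); (0, 0); (2, 0)]`, because (0, 0) -> (1, 1) -> (2, 0) is a right turn, while a left-turning triple such as `[(0, 0); (1, 0); (1, 1)]` yields `None`.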
The repeated use of the append operator (@) is terribly inefficient, though. I'll return to this issue later in this article. For now, let's take a step back.
Example: implementing map with recursion #
A common exercise for people new to Functional Programming is to implement a map function (C# developers will know it as Select) using recursion. This function already exists in the List module, but it can be enlightening to do the exercise.
An easy, but naive implementation is only two lines of code, using pattern matching:
    let rec mapNaive f = function
        | [] -> []
        | h::t -> f h :: mapNaive f t
Is mapNaive tail recursive? No, it isn't. The last operation happening is that f h is consed onto the return value of mapNaive f t. While mapNaive f t is a recursive call, it's not in the tail position. For long lists, this will create a stack overflow.
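One way to see why the call isn't in tail position is to rewrite mapNaive with each intermediate step named. This hypothetical mapNaiveExplicit behaves the same way; it only makes the pending work visible.

```fsharp
let rec mapNaiveExplicit f = function
    | [] -> []
    | h :: t ->
        let y = f h                      // map the head first, like mapNaive
        let rest = mapNaiveExplicit f t  // a recursive call, but not the last step...
        y :: rest                        // ...because this cons still has to run,
                                         // so the current stack frame must be kept
```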
How can you create a tail-recursive map implementation?
Example: inefficient tail-recursive map implementation #
According to my introduction, adding an accumulator and an 'implementation' function should do the trick. Here's the straightforward application of that technique:
    let mapTailRecursiveUsingAppend f xs =
        let rec mapImp f acc = function
            | [] -> acc
            | h::t -> mapImp f (acc @ [f h]) t
        mapImp f [] xs
The mapImp function does the actual work, and it's tail recursive. It appends the result of mapping the head of the list to the accumulator: acc @ [f h]. Only then does it recursively call itself with the new accumulator and the tail of the list.
While this version is tail recursive, it's horribly inefficient, because appending to the tail of a (linked) list is slow. In theory, this implementation would never result in a stack overflow, but the question is whether anyone has the patience to wait for it to finish.
    > mapTailRecursiveUsingAppend ((*) 2) [1..100000];;
    Real: 00:02:46.104, CPU: 00:02:44.750, GC gen0: 13068, gen1: 6380, gen2: 1
    val it : int list = [2; 4; 6; 8; 10; 12; 14; 16; 18; 20; 22; 24; 26; 28; 30; 32; 34; 36; ...]
Doubling 100,000 integers this way takes nearly 3 minutes on my (admittedly mediocre) laptop. A better approach is required.
Example: efficient tail-recursive map implementation #
The problem with mapTailRecursiveUsingAppend is that appending to a list is slow when the left-hand list is long. This is because lists are linked lists, so the append operation has to traverse the entire left-hand list and copy all of its elements in order to link them to the right-hand list.
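A minimal sketch of what appending entails: the hypothetical appendSketch below is not FSharp.Core's actual implementation of @, but it shows why every cell of the left-hand list must be rebuilt, since the last cell of xs has to point at ys instead of at the empty list.

```fsharp
// Hypothetical sketch, not FSharp.Core's actual implementation of @.
let rec appendSketch xs ys =
    match xs with
    | [] -> ys                          // the right-hand list is reused as-is
    | h :: t -> h :: appendSketch t ys  // one new cell per left-hand element
```

In mapTailRecursiveUsingAppend the accumulator is always the left-hand list, so step k copies k - 1 cells, and mapping n items copies n(n - 1)/2 cells in total.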
Consing a single item onto an existing list, on the other hand, is efficient:
    let mapTailRecursiveUsingRev f xs =
        let rec mapImp f acc = function
            | [] -> acc
            | h::t -> mapImp f (f h :: acc) t
        mapImp f [] xs |> List.rev
This function conses onto the accumulator (f h :: acc) instead of appending to the accumulator. The only problem with this is that acc is in reverse order compared to the input, so the final step must be to reverse the output of mapImp. While there's a cost to reversing a list, you pay it only once. In practice, it turns out to be efficient:
    > mapTailRecursiveUsingRev ((*) 2) [1..100000];;
    Real: 00:00:00.017, CPU: 00:00:00.015, GC gen0: 1, gen1: 0, gen2: 0
    val it : int list = [2; 4; 6; 8; 10; 12; 14; 16; 18; 20; 22; 24; 26; 28; 30; 32; 34; 36; ...]
From nearly three minutes to 17 milliseconds! That's a nice performance improvement.
The only problem, from a point of view where learning is in focus, is that it feels a bit like cheating: we've delegated an important step in the algorithm to List.rev. If we think it's OK to use the library functions, we could have directly used List.map. The whole point of this exercise, however, is to learn something about how to write tail-recursive functions.
At this point, we have two options:
- Learn how to write an efficient, tail-recursive implementation of reverse.
- Consider alternatives.
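The first option turns out to need no new idea: consing onto an accumulator already reverses the order, which is exactly what mapImp in mapTailRecursiveUsingRev exploits. A minimal sketch, with the hypothetical name reverse:

```fsharp
// Hypothetical name; tail-recursive list reversal using an accumulator.
let reverse xs =
    let rec revImp acc = function
        | [] -> acc
        | h :: t -> revImp (h :: acc) t  // recursive call in tail position
    revImp [] xs
```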
Example: efficient tail-recursive map using a difference list #
The mapTailRecursiveUsingAppend function is attractive because of its simplicity. If only there were an efficient way to append a single item to the tail of a (long) list, like acc appendSingle (f h)! (appendSingle is a hypothetical function that we wish existed.)
So far, we've treated data as data, and functions as functions, but in Functional Programming, functions are data as well!
What if we could partially apply a cons operation?
Unfortunately, the cons operator (::) can't be used as a function, so you'll have to introduce a little helper function:
    // 'a -> 'a list -> 'a list
    let cons x xs = x :: xs
This enables you to partially apply a cons operation:
    > cons 1;;
    val it : (int list -> int list) = <fun:it@4-5>
cons 1 is a function that awaits an int list argument. You can, for example, call it with the empty list, or another list:
    > (cons 1) [];;
    val it : int list = [1]
    > (cons 1) [2];;
    val it : int list = [1; 2]
That hardly seems useful, but what happens if you start composing such partially applied functions?
    > cons 1 >> cons 2;;
    val it : (int list -> int list) = <fun:it@7-8>
Notice that the result is another function with the same signature as cons 1. A way to read it is: cons 1 is a function that takes a list as input, appends the list after 1, and returns that new list. The return value of cons 1 is passed to cons 2, which takes that input, appends that list after 2, and returns that list. Got it? Try it out:
    > (cons 1 >> cons 2) [];;
    val it : int list = [2; 1]
Not what you expected? Try going through the data flow again. The input is the empty list ([]), which, when applied to cons 1, produces [1]. That value is then passed to cons 2, which puts 2 at the head of [1], yielding the final result of [2; 1].
This still doesn't seem to help, because it still reverses the list. True, but you can reverse the composition:
    > (cons 2 >> cons 1) [];;
    val it : int list = [1; 2]
    > (cons 1 << cons 2) [];;
    val it : int list = [1; 2]
Notice that in the first line, I reversed the composition by changing the order of the partially applied functions. This, however, is equivalent to keeping the order, but using the reverse composition operator (<<).
You can repeat this composition:
    > (cons 1 << cons 2 << cons 3 << cons 4) [];;
    val it : int list = [1; 2; 3; 4]
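The same right-to-left composition can also be built up one item at a time, which is what a loop or recursive function needs. The sketch below gives one possible shape for the wished-for appendSingle; both appendSingle and toList are hypothetical helpers, not part of FSharp.Core. A list is represented as a function 'a list -> 'a list, often called a difference list.

```fsharp
let cons x xs = x :: xs

// Hypothetical helper: add one item to the end of a difference list.
// This only builds a new closure; no existing list is traversed.
let appendSingle x (acc : 'a list -> 'a list) = acc << cons x

// Hypothetical helper: materialise the difference list by applying it to [].
let toList (acc : 'a list -> 'a list) = acc []

let dl = id |> appendSingle 1 |> appendSingle 2 |> appendSingle 3
let xs = toList dl   // [1; 2; 3]
```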
That's exactly what you need, enabling you to write acc << (cons (f h)) in order to efficiently append a single element to the tail of a list!
    let mapTailRecursiveUsingDifferenceList f xs =
        let cons x xs = x :: xs
        let rec mapImp f acc = function
            | [] -> acc []
            | h::t -> mapImp f (acc << (cons (f h))) t
        mapImp f id xs
This mapImp function's accumulator is no longer a list, but a function. For every item, it composes a new accumulator function from the old one, effectively appending the mapped item to the tail of the accumulated list. Yet, because acc isn't a list, but rather a function, the 'append' operation doesn't trigger a list traversal.
When the recursive function finally reaches the end of the list (the [] case), it invokes the acc function with the empty list ([]) as the initial input.
This implementation is also tail recursive, because the accumulator is being completely composed (acc << (cons (f h))) before mapImp is recursively invoked.
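To make the data flow concrete, here is a hand-unrolled trace of the accumulator as mapImp would build it for f = (*) 2 over [1; 2; 3]. The names acc0 to acc3 are illustrative only.

```fsharp
let cons x xs = x :: xs
let f = (*) 2

// Accumulator after each step of mapImp over [1; 2; 3], starting from id.
let acc0 : int list -> int list = id
let acc1 = acc0 << cons (f 1)   // behaves like fun xs -> 2 :: xs
let acc2 = acc1 << cons (f 2)   // behaves like fun xs -> 2 :: 4 :: xs
let acc3 = acc2 << cons (f 3)   // behaves like fun xs -> 2 :: 4 :: 6 :: xs

// Only this final application, as in the [] case, actually builds the list.
let result = acc3 []            // [2; 4; 6]
```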
Is it efficient, then?
    > mapTailRecursiveUsingDifferenceList ((*) 2) [1..100000];;
    Real: 00:00:00.024, CPU: 00:00:00.031, GC gen0: 1, gen1: 0, gen2: 0
    val it : int list = [2; 4; 6; 8; 10; 12; 14; 16; 18; 20; 22; 24; 26; 28; 30; 32; 34; 36; ...]
24 milliseconds is decent. It's not as good as mapTailRecursiveUsingRev (17 milliseconds), but it's close.
In practice, you'll probably find that mapTailRecursiveUsingRev is not only more efficient, but also easier to understand. The advantage of using the difference list technique, however, is that now mapImp has a shape that almost screams to be refactored to a fold.
Example: implementing map with fold #
The mapImp function in mapTailRecursiveUsingDifferenceList almost has the shape required by the accumulator function in List.fold. This enables you to rewrite mapImp using fold:
    let mapUsingFold f xs =
        let cons x xs = x :: xs
        let mapImp = List.fold (fun acc h -> acc << (cons (f h))) id
        mapImp xs []
As usual in Functional Programming, the ultimate goal seems to be to avoid writing recursive functions after all!
The mapUsingFold function is as efficient as mapTailRecursiveUsingDifferenceList:
    > mapUsingFold ((*) 2) [1..100000];;
    Real: 00:00:00.025, CPU: 00:00:00.031, GC gen0: 2, gen1: 1, gen2: 1
    val it : int list = [2; 4; 6; 8; 10; 12; 14; 16; 18; 20; 22; 24; 26; 28; 30; 32; 34; 36; ...]
Not only does 25 milliseconds seem fast, but it's also comparable with the performance of the built-in map function:
    > List.map ((*) 2) [1..100000];;
    Real: 00:00:00.011, CPU: 00:00:00.015, GC gen0: 0, gen1: 0, gen2: 0
    val it : int list = [2; 4; 6; 8; 10; 12; 14; 16; 18; 20; 22; 24; 26; 28; 30; 32; 34; 36; ...]
Granted, List.map seems to be twice as fast, but it's also been the subject of more scrutiny than the above fun exercise.
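The same fold refactoring can, in principle, also be applied to mapTailRecursiveUsingRev, since its mapImp has the same accumulator shape. A minimal sketch with the hypothetical name mapUsingFoldAndRev; no timings were measured for this variant.

```fsharp
// Hypothetical name: the fold counterpart of mapTailRecursiveUsingRev.
// Cons each mapped item onto a list accumulator, then reverse once.
let mapUsingFoldAndRev f xs =
    List.fold (fun acc h -> f h :: acc) [] xs |> List.rev
```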
In Functional Programming, recursive functions take the place of imperative loops. In order to be efficient, they must be tail recursive.
You can make a function tail recursive by introducing an accumulator argument to an 'implementation function'. This also tends to put you in a position where you can ultimately refactor the implementation to use a fold instead of explicit recursion.