Monday, January 25, 2016

Why you should (sometimes) NOT use tail recursion in Scala

There was a recent post on /r/scala (direct article link) about how great tail recursion is. I agree with everything said in that article; this isn't an attempt to refute his points. But tail recursion has a dark side: it can be a huge hassle.

Non-tail recursive code has a very useful property: as each invocation on the stack completes, the previous invocation picks up exactly where it left off. So you can do some work, recurse, and then do more work when the recursion finishes. You can even recurse in a loop, meaning that the amount of work left to be done is dynamic and only known at runtime. In order to use tail recursion, you have to restructure your code so that there is no work to be done after the recursive call returns. While it's always possible to restructure your code in this way, it can be a nontrivial transformation. In complex cases, you may even need to use your own stack to keep track of remaining work. Sure, it's not the call stack, so you don't need to worry about blowing the stack when your collection gets too large. But that safety comes with a bit of complexity.
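To make the contrast concrete, here is a minimal sketch using a hypothetical `Node` tree type (not from any real library). `sumRecursive` recurses inside a loop and does its addition after the calls return; `sumTailRecursive` moves the pending work into an explicit stack so that nothing remains after the recursive call. Summing is an easy case, because additions can happen in any order. When the post-recursion work is order-sensitive, as in a topological sort, the transformation gets harder.

```scala
import scala.annotation.tailrec

object RecursionStyles {
  // Hypothetical tree type, for illustration only.
  final case class Node(value: Int, children: List[Node])

  // Non-tail recursion: recurse inside a loop (the map over children), then do
  // more work (the addition) after those calls return. Call-stack depth grows
  // with the depth of the tree.
  def sumRecursive(n: Node): Int =
    n.value + n.children.map(sumRecursive).sum

  // Tail recursion: nothing is left to do after the recursive call, so the
  // pending work is kept in an explicit stack (a List on the heap) instead.
  def sumTailRecursive(root: Node): Int = {
    @tailrec
    def loop(stack: List[Node], acc: Int): Int = stack match {
      case Nil       => acc
      case n :: rest => loop(n.children ::: rest, acc + n.value)
    }
    loop(List(root), 0)
  }

  def main(args: Array[String]): Unit = {
    val tree = Node(1, List(Node(2, Nil), Node(3, List(Node(4, Nil)))))
    println(sumRecursive(tree))
    println(sumTailRecursive(tree))
  }
}
```

Both functions return 10 for the small tree in `main`; only the second one is safe on a tree that is, say, 100,000 levels deep.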

For example, I recently wrote an iterative implementation of Tarjan's topological sort of a directed graph (the depth-first-search variant). The overall algorithm is described well enough on Wikipedia. As its pseudocode shows, the recursion occurs within a loop over a node's neighbors, and there is additional work to be done (marking the node as finished and adding it to the output) after all of that recursion is complete.

Here's my Scala implementation. I won't promise that it's the best code, but it seems to work. I have more comments in the actual code, but I've stripped them here for brevity.

import scala.annotation.tailrec

def topologicalSort[A](edgeMap: Map[A, Set[A]]): Seq[A] = {
  @tailrec
  def helper(unprocessed: Seq[Seq[A]], inProgress: Set[A], finished: Set[A], result: Seq[A]): Seq[A] = {
    unprocessed match {
      case (hd +: tl) +: rest => // [ [hd, ...], ... ]
        if (finished(hd)) {
          helper(tl +: rest, inProgress, finished, result)
        } else {
          if (inProgress(hd)) {
            throw new Exception("Graph contains a cycle")
          }
          val referencedVertices = edgeMap(hd)
          helper(referencedVertices.toSeq +: (hd +: tl) +: rest, inProgress + hd, finished, result)
        }
      case Nil +: (hd +: tl) +: rest => // [ [], [hd, ...], ... ]
        helper(tl +: rest, inProgress - hd, finished + hd, hd +: result)
      case Nil +: Nil => // [ [] ]
        result
    }
  }

  helper(edgeMap.keys.toSeq +: Nil, Set.empty, Set.empty, Nil)
}

(Assume that edgeMap contains one key for every vertex in the graph, even if the corresponding value is the empty set. This is an invariant that is enforced elsewhere.)
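The post doesn't show where that invariant is enforced. As one possible approach, here is a sketch with a hypothetical helper, `withAllVertices` (not from the original code), that adds an empty entry for every vertex that only ever appears as an edge target. Without such an entry, `edgeMap(hd)` in the sort would throw a `NoSuchElementException`.

```scala
object EdgeMapInvariant {
  // Hypothetical helper (not from the original post): adds an empty entry for
  // every vertex that appears only as an edge target, so that every vertex in
  // the graph has its own key in the map.
  def withAllVertices[A](edgeMap: Map[A, Set[A]]): Map[A, Set[A]] = {
    val targets = edgeMap.values.flatten.toSet
    targets.foldLeft(edgeMap) { (acc, v) =>
      if (acc.contains(v)) acc else acc + (v -> Set.empty[A])
    }
  }

  def main(args: Array[String]): Unit = {
    // "c" is referenced but has no key of its own until we normalize.
    val raw = Map("a" -> Set("b", "c"), "b" -> Set("c"))
    println(withAllVertices(raw).keySet)
  }
}
```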

That unprocessed parameter is, essentially, the call stack. The outer Seq is used as a stack, while each inner Seq is used as a queue. Whenever we decide to visit a node, we push a new "frame" containing that node's referenced vertices onto the front of the stack (and add the node to inProgress). Whenever the frame at the front of the stack is empty, it means that we have finished recursing into the current node's children and can move that "current node" (encoded as the first element in the next frame) from inProgress to finished. And when we reach a state where the stack contains just one frame, and that frame is empty, we are done.

While this implementation won't blow the stack (at least, I don't think it will... it successfully sorted a graph with a path 10k vertices long), I wouldn't necessarily describe it as easy to understand. In my actual source code, the comments are nearly as long as the implementation. That in and of itself isn't a problem, but it's a shame that the source isn't more readable. (Though I'll freely admit that perhaps the lack of readability is my own fault.)
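A quick way to check the stack-safety claim is sketched below. It repeats the tail-recursive `topologicalSort` in condensed form so the block compiles on its own, with `@tailrec` so the compiler confirms the rewrite into a loop. It then sorts a hypothetical 10,000-vertex path graph built by `pathGraph`. Because `unprocessed` lives on the heap, the length of the path should not matter. A plain recursive DFS, by contrast, may overflow on a long enough path, depending on the JVM's stack size.

```scala
import scala.annotation.tailrec

object TopoSortStackCheck {
  // The tail-recursive topologicalSort from above, condensed and repeated so
  // this block compiles on its own.
  def topologicalSort[A](edgeMap: Map[A, Set[A]]): Seq[A] = {
    @tailrec
    def helper(unprocessed: Seq[Seq[A]], inProgress: Set[A], finished: Set[A], result: Seq[A]): Seq[A] =
      unprocessed match {
        case (hd +: tl) +: rest =>
          if (finished(hd)) helper(tl +: rest, inProgress, finished, result)
          else if (inProgress(hd)) throw new Exception("Graph contains a cycle")
          else helper(edgeMap(hd).toSeq +: (hd +: tl) +: rest, inProgress + hd, finished, result)
        case Nil +: (hd +: tl) +: rest =>
          helper(tl +: rest, inProgress - hd, finished + hd, hd +: result)
        case Nil +: Nil =>
          result
      }

    helper(edgeMap.keys.toSeq +: Nil, Set.empty, Set.empty, Nil)
  }

  // Hypothetical test graph: a single path 0 -> 1 -> ... -> (n - 1), with an
  // empty entry for the last vertex to satisfy the one-key-per-vertex invariant.
  def pathGraph(n: Int): Map[Int, Set[Int]] =
    (0 until n).map(i => i -> (if (i + 1 < n) Set(i + 1) else Set.empty[Int])).toMap

  def main(args: Array[String]): Unit = {
    val sorted = topologicalSort(pathGraph(10000))
    // A path has exactly one topological order: 0, 1, ..., 9999.
    println(sorted == (0 until 10000))
  }
}
```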

In fact, an astute reader might notice that the match expression is missing a case. What happens if the sequence looks like this:

[ [], [], ... ]

That is, why don't we have a pattern match clause that looks like this:

case Nil +: Nil +: rest => ???

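This particular case can never occur. An invariant of this implementation is that, apart from the first queue in the stack, no other queue can be empty: every queue below the top still holds, at its head, the vertex whose children are currently being visited, and that vertex is only removed when the empty queue above it is popped. Again, this fact is pointed out in a comment... a comment that isn't needed in the non-tail recursive version.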
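For comparison, here is an illustrative sketch of what a non-tail recursive version might look like. It is not part of the original implementation. It follows the DFS outline from Wikipedia and uses local vars for the marks to stay short. The work after the loop over children (clearing the in-progress mark, marking the node finished, prepending it to the result) simply sits after the recursive calls, and there is no impossible case to explain. The trade-off is that call-stack depth equals the length of the longest path explored.

```scala
object RecursiveTopoSort {
  // Illustrative sketch: plain (non-tail) recursive DFS topological sort.
  // Assumes, like the original, that edgeMap has a key for every vertex.
  def topologicalSort[A](edgeMap: Map[A, Set[A]]): Seq[A] = {
    var inProgress = Set.empty[A]
    var finished = Set.empty[A]
    var result = List.empty[A]

    def visit(v: A): Unit = {
      if (inProgress(v)) throw new Exception("Graph contains a cycle")
      if (!finished(v)) {
        inProgress += v
        edgeMap(v).foreach(visit) // recurse in a loop over the children...
        inProgress -= v           // ...then do more work once they are all done
        finished += v
        result = v :: result
      }
    }

    edgeMap.keys.foreach(visit)
    result
  }
}
```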

The author points out that tail recursion causes the compiler to rewrite your apparently recursive function as a loop. He also demonstrates a case where tail recursion is shorter (and, I'd agree, more readable) than an explicit loop. But there are cases where the opposite is true. One that I've come across a few times is what I'll call "partitioning by type". Suppose you have a union type (in Scala, a sealed trait with one case class per variant):

case class Point(x: Float, y: Float) // assumed definition; Point isn't shown in this post

sealed trait Shape
case class Circle(c: Point, r: Float) extends Shape
case class Rectangle(lowerLeft: Point, w: Float, h: Float) extends Shape
case class Triangle(v1: Point, v2: Point, v3: Point) extends Shape
case class Polygon(vs: Seq[Point]) extends Shape

And suppose you have a Seq[Shape], but you would like to split it into independent lists: a Seq[Circle], a Seq[Rectangle], a Seq[Triangle], and a Seq[Polygon]. A tail recursive implementation might look like this:

import scala.annotation.tailrec

type GroupShapesByTypeResult = (Seq[Circle], Seq[Rectangle], Seq[Triangle], Seq[Polygon])

def groupShapesByType(shapes: Seq[Shape]): GroupShapesByTypeResult = {
  @tailrec
  def helper(remaining: Seq[Shape], circles: Seq[Circle], rectangles: Seq[Rectangle],
             triangles: Seq[Triangle], polygons: Seq[Polygon]): GroupShapesByTypeResult = {
    remaining match {
      case Nil =>
        (circles, rectangles, triangles, polygons)
      case (hd: Circle) +: rest =>
        helper(rest, circles :+ hd, rectangles, triangles, polygons)
      case (hd: Rectangle) +: rest =>
        helper(rest, circles, rectangles :+ hd, triangles, polygons)
      case (hd: Triangle) +: rest =>
        helper(rest, circles, rectangles, triangles :+ hd, polygons)
      case (hd: Polygon) +: rest =>
        helper(rest, circles, rectangles, triangles, polygons :+ hd)
    }
  }

  helper(shapes, Nil, Nil, Nil, Nil)
}

Not terribly readable. But wait. We don't need to manage recursion ourselves; we could just use a fold:

def groupShapesByType(shapes: Seq[Shape]): GroupShapesByTypeResult = {
  val init: GroupShapesByTypeResult = (Nil, Nil, Nil, Nil)

  shapes.foldLeft(init) {
    (acc, shape) =>
      val (circles, rectangles, triangles, polygons) = acc
      shape match {
        case c: Circle =>
          (circles :+ c, rectangles, triangles, polygons)
        case r: Rectangle =>
          (circles, rectangles :+ r, triangles, polygons)
        case t: Triangle =>
          (circles, rectangles, triangles :+ t, polygons)
        case p: Polygon =>
          (circles, rectangles, triangles, polygons :+ p)
      }
  }
}

We've traded explicit loop management for more complex destructuring. Arguably more readable, but still not great. OK, what if we abandoned this functional approach (Scala is multi-paradigm after all) and went with an explicit loop and mutability instead:

def groupShapesByType(shapes: Seq[Shape]): GroupShapesByTypeResult = {
  var circles: Seq[Circle] = Nil
  var rectangles: Seq[Rectangle] = Nil
  var triangles: Seq[Triangle] = Nil
  var polygons: Seq[Polygon] = Nil

  for (shape <- shapes) {
    shape match {
      case c: Circle => circles = circles :+ c
      case r: Rectangle => rectangles = rectangles :+ r
      case t: Triangle => triangles = triangles :+ t
      case p: Polygon => polygons = polygons :+ p
    }
  }

  (circles, rectangles, triangles, polygons)
}

Since the patterns and corresponding bodies are so much simpler, I took the liberty of combining them into single lines. I don't think it hurts readability. I don't know for sure, but I would even expect this implementation to run faster than either of the other two. The tail recursive version needs to successively chop the front off our list. And the version with foldLeft needs to constantly decompose and rebuild the loop state variable. This implementation just walks an iterable and updates the corresponding sequence. Persistent collections are awesome, folds are awesome, but walking an iterator and updating vars is hard to beat.
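One caveat applies to all three versions above: each accumulator starts as `Nil`, which is a `List`, and `:+` on a `List` copies the whole list, so every append costs time linear in the list's current length. If that matters, one option is to keep the explicit loop but accumulate into builders, as in the sketch below. Starting the accumulators as `Vector.empty` would be another option. The `Point` definition is assumed, as before, and none of this has been benchmarked.

```scala
object GroupWithBuilders {
  // Shape definitions from above; Point is an assumed definition.
  final case class Point(x: Float, y: Float)
  sealed trait Shape
  final case class Circle(c: Point, r: Float) extends Shape
  final case class Rectangle(lowerLeft: Point, w: Float, h: Float) extends Shape
  final case class Triangle(v1: Point, v2: Point, v3: Point) extends Shape
  final case class Polygon(vs: Seq[Point]) extends Shape

  type GroupShapesByTypeResult = (Seq[Circle], Seq[Rectangle], Seq[Triangle], Seq[Polygon])

  def groupShapesByType(shapes: Seq[Shape]): GroupShapesByTypeResult = {
    // Builders append in amortized constant time, unlike `:+` on a List,
    // which copies the whole list on every append.
    val circles = Seq.newBuilder[Circle]
    val rectangles = Seq.newBuilder[Rectangle]
    val triangles = Seq.newBuilder[Triangle]
    val polygons = Seq.newBuilder[Polygon]

    for (shape <- shapes) {
      shape match {
        case c: Circle    => circles += c
        case r: Rectangle => rectangles += r
        case t: Triangle  => triangles += t
        case p: Polygon   => polygons += p
      }
    }

    // result() hands back an immutable Seq for each group, in input order.
    (circles.result(), rectangles.result(), triangles.result(), polygons.result())
  }
}
```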

Again, I'm not trying to refute anything that the original post's author is saying. Tail recursion is great. But tail recursion comes with the cost of complexity. For situations with relatively shallow recursion trees (and with a clear upper bound to the recursion depth), I'm actually OK with non-tail recursion. For example, using non-tail recursion to traverse an XML document that is known to be fairly flat is perfectly fine. It might even be fine for traversing a parsed AST for a programming language. Sure, most programming languages allow expressions to be nested to arbitrary depths, but most code written by reasonable humans has a soft upper bound on how deeply those expressions are nested.
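As an illustration of the AST case, here is a minimal sketch with a hypothetical three-node expression type (not taken from any real parser). `eval` is plainly non-tail recursive, because the `+` or `*` still has to happen after both recursive calls return. Its stack depth is bounded by the nesting depth of the expression, which is small for typical hand-written code.

```scala
object ExprEval {
  // Hypothetical expression AST, for illustration only.
  sealed trait Expr
  final case class Num(value: Int) extends Expr
  final case class Add(left: Expr, right: Expr) extends Expr
  final case class Mul(left: Expr, right: Expr) extends Expr

  // Non-tail recursion: work remains (the + or *) after the recursive calls
  // return, so call-stack depth equals the nesting depth of the expression.
  def eval(e: Expr): Int = e match {
    case Num(v)    => v
    case Add(l, r) => eval(l) + eval(r)
    case Mul(l, r) => eval(l) * eval(r)
  }

  def main(args: Array[String]): Unit = {
    // (1 + 2) * (3 + 4)
    val expr = Mul(Add(Num(1), Num(2)), Add(Num(3), Num(4)))
    println(eval(expr)) // prints 21
  }
}
```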

If you can naturally express your algorithm with tail recursion, go for it! But if it's unnatural, consider whether tail recursion is actually needed.