
UPDATE 2024.03.16: Added corrected code that produces the correct output but is still not tail-recursive.


How can I create a tail recursive merge method in Scala on a self-referential tree structure (or is it even possible)?

I have been working on this problem for several days now. I've read articles about approaching it, even in other languages. I've even submitted it to various AI assistants (Bard, Copilot, AskCodi, etc.), and they return non-working code that STILL cannot be compiled with the @tailrec annotation.

I must be missing some simple mental leap in converting the merge method in the Node case class to be tail-recursive. Any guidance would be appreciated.

I'd especially appreciate anything (links to books, videos, articles, etc.) that offers a meta-cognitive way to "think through" this style of solution, where a self-referential structure is constructed from the bottom up.

And finally, I now know some problems cannot be solved with tail recursion. Is this the case with this one? And if so, why?


CORRECTED CODE:

This produces the desired result but isn't tail-recursive.

object Node {
  val Terminal: Node = Node(true, Map.empty)
}

final case class Node(
    isWord: Boolean
  , nodeByLetter: Map[Char, Node]
) {
  require(
    isWord || nodeByLetter.nonEmpty
    , s"either isWord [$isWord] or nodeByLetter.nonEmpty [${nodeByLetter.nonEmpty}] must be true")

  def merge(that: Node): Node = {
    //@tailrec
    def recursive(cursor: (Node, Node) = (this, that)): Node = {
      cursor match {
        case (Node.Terminal, Node.Terminal) =>
          Node.Terminal
        case (left, Node.Terminal) =>
          if (left.isWord)
            left
          else
            left.copy(isWord = true)
        case (Node.Terminal, right) =>
          if (right.isWord)
            right
          else
            right.copy(isWord = true)
        case (left, right) =>
          val lettersToMerge =
            left.nodeByLetter
              .keySet
              .filter(
                letter =>
                  right.nodeByLetter.keySet.contains(letter)
                    && (left.nodeByLetter(letter) != right.nodeByLetter(letter)))
          if (lettersToMerge.isEmpty)
            Node(
                left.isWord || right.isWord
              , right.nodeByLetter ++ left.nodeByLetter)
          else {
            val nodeKeysAll = (left.nodeByLetter.keySet ++ right.nodeByLetter.keySet)
              .toList
              .sorted
            val nodes = nodeKeysAll
              .map(
                letter =>
                  if (lettersToMerge.contains(letter)) {
                    //this call fails the @tailrec annotation
                    recursive(left.nodeByLetter(letter), right.nodeByLetter(letter))
                  } else
                    left.nodeByLetter.getOrElse(letter, right.nodeByLetter(letter))
              )
            val nodeByLetter = {
              nodes
                .zip(nodeKeysAll)
                .map(_.swap)
                .toMap
            }

            Node(
                left.isWord || right.isWord
              , nodeByLetter
            )
          }
      }
    }

    recursive()
  }
}

When the @tailrec line in method merge is uncommented, the line...

recursive(left.nodeByLetter(letter), right.nodeByLetter(letter))

... highlights with a red squiggly (in IntelliJ) and reports the error...

Recursive call not in tail position (in @tailrec annotated method).
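To make the error concrete, here is a minimal illustration with hypothetical names (TailPosition, sumTo, sumToAcc) that are not part of the question's code. A call is in tail position only if its result is returned as-is, with nothing left to do afterwards. In merge, the recursive call sits inside the lambda passed to .map, and its result is then zipped into a Map and wrapped in a new Node, so work always remains after it. For a single linear recursion, the usual fix is to carry that pending work in an accumulator argument:

```scala
import scala.annotation.tailrec

object TailPosition {
  // NOT tail-recursive: after sumTo(n - 1) returns, "+ n" still has to run,
  // so annotating this method with @tailrec fails with "not in tail position".
  def sumTo(n: Int): Long =
    if (n == 0) 0L else sumTo(n - 1) + n

  // Tail-recursive: the pending "+ n" travels in the accumulator, so the
  // recursive call is the very last action and scalac compiles it to a loop.
  @tailrec
  def sumToAcc(n: Int, acc: Long = 0L): Long =
    if (n == 0) acc else sumToAcc(n - 1, acc + n)
}
```

The accumulator trick works here because each call leaves exactly one piece of pending work. In merge, a node can have several children to merge and must be rebuilt only after all of them finish, so a single accumulator is not enough.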


Here's the sample data I am using to ensure that the resulting function works:

object Main {
  def main(args: Array[String]): Unit = {
    //cat
    val t = Node.Terminal
    val at = Node(false, Map('t' -> t))
    val cat = Node(false, Map('a' -> at))
    val catRoot = Node(false, Map('c' -> cat))
    //camp - intentionally not in alpha order
    val p = Node.Terminal
    val mp = Node(true, Map('p' -> p))
    val amp = Node(false, Map('m' -> mp))
    val camp = Node(false, Map('a' -> amp))
    val campRoot = Node(false, Map('c' -> camp))


    val root = catRoot.merge(campRoot)
    println("----------------")
    println("root: " + root)
  }
}

And the output should look like this:

----------------
root: Node(false,Map(c -> Node(false,Map(a -> Node(false,Map(t -> Node(true,Map()), m -> Node(true,Map(p -> Node(true,Map())))))))))

ORIGINAL POSTED CODE WAS INCORRECT.

It doesn't produce the desired result, let alone being tail-recursive. I've left it here per the Stack Overflow rules regarding "updating a question". The corrected code is above.

case class Node(isWord: Boolean, nodeByLetter: Map[Char, Node]) {
  //@tailrec
  final def merge(that: Node): Node = {
    val mergedIsWord = this.isWord || that.isWord
    val mergedNodes =
      (this.nodeByLetter.keySet ++ that.nodeByLetter.keySet)
        .map(letter =>
          (
              letter
            , (this.nodeByLetter.get(letter), that.nodeByLetter.get(letter)) match {
                case (Some(thisNode), Some(thatNode)) =>
                  thisNode.merge(thatNode)
                case (Some(thisNode), None) =>
                  thisNode
                case (None, Some(thatNode)) =>
                  thatNode
                case _ =>
                  throw new IllegalStateException("should never get here")
              }))
        .toMap

    Node(mergedIsWord, mergedNodes)
  }
}
Comments:

  • If your function needs to recursively call itself two or more times during a single invocation, then converting the calls to tail calls is going to be a problem. Yes, any algorithm can be converted to CPS and thus made tail-recursive, but that's not what people usually want. Commented Jan 30, 2024 at 16:54
  • Continuation passing style Commented Jan 30, 2024 at 18:46
  • I guess that is just a bug in the IntelliJ warning. Again, there is no real recursion there, and there is no point in trying to make that method @tailrec. About the cats Monoid thing, check this: scastie.scala-lang.org/BalmungSan/0F8OJaJJS4KDdB8ZHk3tAg/2 Commented Jan 30, 2024 at 19:19
  • While it's not recursive on the same method, you have recursion here nonetheless, and with large enough data, it will result in a StackOverflowError. If you have such big data (check it!), I would use Cats Eval to split this large method into a series of smaller methods. Eval uses a trampoline, so it is stack-safe and would not overflow the stack on recursion. Commented Jan 31, 2024 at 14:24
  • @chaotic3quilibrium Again, that is expected: your function is not tail-recursive (I would even say it is not really recursive at all, since there is no direct recursion, and I simply can't be bothered to think about how OOP affects recursion). Making all the necessary changes to make it properly tail-recursive is really complex; you would basically need to create your own in-heap stack of pending steps. Thus, I ask again, why do you care about making it @tailrec that much? Trees are usually not very deep, so in practice this kind of algorithm will rarely break the stack. Commented Jan 31, 2024 at 19:19

1 Answer

Summary:

The answer to the parenthetical question in the OP title...

...(or is it even possible)?

...is "Yes".

The answer to the full question in the OP title...

How can I create a tail recursive merge method in Scala on a self-referential tree structure?

...is to "move to a heap-based strategy whenever anything (even another recursive call) must still run after a recursive call returns."
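As a minimal sketch of what "heap-based" can mean, assuming a stripped-down copy of the question's Node (no require check) and a hypothetical ExplicitStack.countWords helper: the subtrees still to visit are kept in an ordinary List instead of in JVM stack frames, so the only recursive call is the last action and @tailrec compiles. This is a simpler traversal, not the merge itself; a merge written this way would also need stack entries meaning "rebuild this parent once its children are done", which is the bookkeeping a trampoline handles for you.

```scala
import scala.annotation.tailrec

// Stripped-down copy of the question's Node, assumed for this sketch
// (the require check is omitted for brevity).
final case class Node(isWord: Boolean, nodeByLetter: Map[Char, Node])

object ExplicitStack {
  // Counts the words stored under `root`. Subtrees still to visit are kept
  // in `pending`, a List on the heap, instead of in JVM stack frames.
  def countWords(root: Node): Int = {
    @tailrec
    def loop(pending: List[Node], count: Int): Int = pending match {
      case Nil => count
      case node :: rest =>
        // Push this node's children, record whether it ends a word,
        // and make the recursive call the very last thing we do.
        loop(
          node.nodeByLetter.values.toList ::: rest,
          if (node.isWord) count + 1 else count
        )
    }
    loop(List(root), 0)
  }
}
```

For the merged cat/camp trie from the question's sample output (which stores "cat", "cam", and "camp"), countWords returns 3, and a chain 100,000 letters deep is handled without growing the call stack.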

Details:

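The DAWG (Directed Acyclic Word Graph) problem is solvable, but not (easily) with the @tailrec annotation. Instead, it is more simply solved with an FP recursion technique called a trampoline: rather than calling itself directly, the function returns a description of the next step (a value on the heap), and a loop runs those steps one at a time, so the JVM stack does not grow with the recursion depth. The concept is closely related to CPS (Continuation Passing Style): the pending work is expressed as continuations passed to flatMap, and the trampoline is what makes running them stack-safe.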
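A minimal sketch of the trampoline idea, using the same standard-library scala.util.control.TailCalls API as the code below (done, tailcall, map/flatMap, result), but applied to a hypothetical binary Tree rather than the DAWG. Like merge, sizeNaive makes two recursive calls and combines their results afterwards, so it cannot be @tailrec; sizeTramp describes the same computation as TailRec values that result then runs in a loop:

```scala
import scala.util.control.TailCalls._

// Hypothetical binary tree, used only to illustrate the technique.
sealed trait Tree
case object Leaf extends Tree
final case class Branch(left: Tree, right: Tree) extends Tree

object TrampolineDemo {
  // Plain recursion: "+" runs after both calls return, so neither call is in
  // tail position and every level of the tree costs a JVM stack frame.
  def sizeNaive(t: Tree): Int = t match {
    case Leaf         => 1
    case Branch(l, r) => sizeNaive(l) + sizeNaive(r) + 1
  }

  // Trampolined: each pending step is a TailRec value on the heap, and
  // `.result` drives those steps in a loop, keeping stack depth bounded.
  def sizeTramp(t: Tree): TailRec[Int] = t match {
    case Leaf => done(1)
    case Branch(l, r) =>
      for {
        ls <- tailcall(sizeTramp(l))
        rs <- tailcall(sizeTramp(r))
      } yield ls + rs + 1
  }

  // Builds a left-leaning tree of the given depth without recursion.
  def deepTree(depth: Int): Tree =
    (1 to depth).foldLeft(Leaf: Tree)((acc, _) => Branch(acc, Leaf))
}
```

A tree of depth d built this way has 2d + 1 nodes, so TrampolineDemo.sizeTramp(TrampolineDemo.deepTree(100000)).result evaluates to 200001, whereas sizeNaive on the same tree would likely throw StackOverflowError with default JVM stack settings.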

Thanks to this well-structured article on what a trampoline is and how to use one (except for its hand-waving over the mind-bending sequence method, which I worked through on Reddit), I was able to derive the fully working answer detailed below. Here's another great (newbie-oriented) article I found.

import scala.annotation.tailrec
import scala.util.control.TailCalls._

object Node {
  val Terminal: Node = Node(true, Map.empty)

  def encode(word: String): Node =
    word
      .reverse
      .foldLeft(Node.Terminal) {
        (node, letter) =>
          Node(false, Map(letter -> node))
      }
}

final case class Node(
    isWord: Boolean
  , nodeByLetter: Map[Char, Node]
) {
  require(
    isWord || nodeByLetter.nonEmpty
    , s"either isWord [$isWord] or nodeByLetter.nonEmpty [${nodeByLetter.nonEmpty}] must be true")

  def find(chars: String): Boolean = {
    @tailrec
    def recursive(charsRemaining: String = chars, node: Node = this): Boolean =
      charsRemaining match {
        case "" => node.isWord
        case charsRemainder =>
          node.nodeByLetter.get(charsRemainder.head) match {
            case Some(nextNode) => recursive(charsRemainder.tail, nextNode)
            case None => false
          }
      }

    recursive()
  }

  def merge(that: Node): Node = {
    def sequence[A](listTailRecA: List[TailRec[A]]): TailRec[List[A]] =
      listTailRecA
        .reverse
        .foldLeft(done(Nil): TailRec[List[A]]) {
          (tailRecListA, tailRecA) =>
            tailRecA map ((_: A) :: (_: List[A])).curried flatMap tailRecListA.map
        }

    def recursive(cursor: (Node, Node) = (this, that)): TailRec[Node] = {
      cursor match {
        case (Node.Terminal, Node.Terminal) =>
          done(Node.Terminal)
        case (left, Node.Terminal) =>
          if (left.isWord)
            done(left)
          else
            done(left.copy(isWord = true))
        case (Node.Terminal, right) =>
          if (right.isWord)
            done(right)
          else
            done(right.copy(isWord = true))
        case (left, right) =>
          val lettersToMerge =
            left.nodeByLetter
              .keySet
              .filter(
                letter =>
                  right.nodeByLetter.keySet.contains(letter)
                    && (left.nodeByLetter(letter) != right.nodeByLetter(letter)))
          if (lettersToMerge.isEmpty)
            done(Node(
                left.isWord || right.isWord
              , right.nodeByLetter ++ left.nodeByLetter))
          else {
            val nodeKeysAll = (left.nodeByLetter.keySet ++ right.nodeByLetter.keySet)
              .toList
              .sorted
            val listTailRecNode = nodeKeysAll
              .map(
                letter =>
                  if (lettersToMerge.contains(letter))
                    tailcall(recursive(left.nodeByLetter(letter), right.nodeByLetter(letter)))
                  else
                    done(left.nodeByLetter.getOrElse(letter, right.nodeByLetter(letter)))
              )
            val tailRecListNode = sequence(listTailRecNode)

            tailRecListNode
              .map(
                ln => {
                  val nodeByLetter = ln.zip(nodeKeysAll).map(_.swap).toMap

                  Node(
                      left.isWord || right.isWord
                    , nodeByLetter
                  )
                })
          }
      }
    }

    recursive().result
  }
}
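If the curried sequence helper above is hard to read, here is a sketch of what should be an equivalent formulation (the SequenceDemo wrapper is hypothetical): it folds from the right and uses a for-comprehension, so each element's computation runs first, then the rest of the list, and the element is prepended, preserving the original order.

```scala
import scala.util.control.TailCalls._

object SequenceDemo {
  // Turns a list of suspended computations into one suspended computation
  // of a list, preserving element order.
  def sequence[A](xs: List[TailRec[A]]): TailRec[List[A]] =
    xs.foldRight(done(Nil): TailRec[List[A]]) { (ta, acc) =>
      // Run this element, then the rest of the list, then prepend.
      for {
        a  <- ta
        as <- acc
      } yield a :: as
    }
}
```

Since List.foldRight in the standard library processes the reversed list, this does the same work as .reverse.foldLeft in the answer; only the combining step is spelled out with flatMap/map instead of a curried ::.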