Saturday, July 20, 2024

Binary Trees with In-order Iterators (Part 2)

This is the sixth blog post in a series about developing correct implementations of basic data structures and algorithms using the Deduce language and proof checker.

This post continues where we left off in the previous post, in which we implemented binary trees and in-order tree iterators.

Our goal in this post is to prove that we correctly implemented the iterator operations:

ti2tree : < E > fn TreeIter<E> -> Tree<E>
ti_first : < E > fn Tree<E>,E,Tree<E> -> TreeIter<E>
ti_get : < E > fn TreeIter<E> -> E
ti_next : < E > fn TreeIter<E> -> TreeIter<E>
ti_index : < E > fn(TreeIter<E>) -> Nat

The first operation, ti2tree, requires us to first obtain a tree iterator, for example with ti_first, so ti2tree does not have a correctness criterion all of its own; instead, the proof of its correctness will be part of the correctness of the other operations.

So we skip to the proof of correctness for ti_first.

Correctness of ti_first

Let us make explicit the specification of ti_first:

Specification: The ti_first(A, x, B) function returns an iterator pointing to the first node, with respect to in-order traversal, of the tree TreeNode(A, x, B).

Also, recall that we said the following about ti2tree and ti_first: creating an iterator from a tree using ti_first and then applying ti2tree produces the original tree.

So we have two properties to prove about ti_first. For the first property, we need a way to formalize "the first node with respect to in-order traversal". This is where the ti_index operation comes in. If ti_first returns the first node, then its index should be 0. (One might worry that if ti_index is incorrect, then this property would not force ti_first to be correct. Not to worry, we will prove that ti_index is correct!) So we have the following theorem:

theorem ti_first_index: all E:type, A:Tree<E>, x:E, B:Tree<E>.
  ti_index(ti_first(A, x, B)) = 0
proof
  arbitrary E:type, A:Tree<E>, x:E, B:Tree<E>
  definition ti_first
  ?
end

After expanding the definition of ti_first, we are left with the following goal. So we need to prove a lemma about the first_path auxiliary function.

    ti_index(first_path(A,x,B,empty)) = 0

Here is a first attempt to formulate the lemma.

lemma first_path_index: all E:type. all A:Tree<E>. all y:E, B:Tree<E>.
  ti_index(first_path(A,y,B, empty)) = 0

However, because first_path is recursive, we will need to prove this by induction on A. But looking at the second clause in the definition of first_path, the path argument grows, so our induction hypothesis, which requires the path argument to be empty, will not be applicable. As is often the case, we need to generalize the lemma. Let’s replace empty with an arbitrary path as follows.

lemma first_path_index: all E:type. all A:Tree<E>. all y:E, B:Tree<E>, path:List<Direction<E>>.
  ti_index(first_path(A,y,B, path)) = 0

But now this lemma is false. Consider the following situation in which the current node y is 5 and the path is L,R (going from node 5 up to node 3).

Diagram for lemma first path index

The index of node 5 is not 0; it is 5! Instead, the index of node 5 is equal to the number of nodes that come before 5 according to in-order traversal. We can obtain that portion of the tree using functions that we have already defined, in particular take_path followed by plug_tree. So we can formulate the lemma as follows.

lemma first_path_index: all E:type. all A:Tree<E>. all y:E, B:Tree<E>, path:List<Direction<E>>.
  ti_index(first_path(A,y,B, path)) = num_nodes(plug_tree(take_path(path), EmptyTree))
proof
  arbitrary E:type
  induction Tree<E>
  case EmptyTree {
    arbitrary y:E, B:Tree<E>, path:List<Direction<E>>
    ?
  }
  case TreeNode(L, x, R) suppose IH {
    arbitrary y:E, B:Tree<E>, path:List<Direction<E>>
    ?
  }
end

For the case A = EmptyTree, the goal simply follows from the definitions of first_path, ti_index, and ti_take.

    conclude ti_index(first_path(EmptyTree,y,B,path))
           = num_nodes(plug_tree(take_path(path),EmptyTree))
                by definition {first_path, ti_index, ti_take}.

For the case A = TreeNode(L, x, R), after expanding the definition of first_path, we need to prove:

  ti_index(first_path(L,x,R,node(LeftD(y,B),path)))
= num_nodes(plug_tree(take_path(path),EmptyTree))

But that follows from the induction hypothesis and the definition of take_path.

    definition {first_path}
    equations
          ti_index(first_path(L,x,R,node(LeftD(y,B),path)))
        = num_nodes(plug_tree(take_path(node(LeftD(y,B),path)),EmptyTree))
                by IH[x, R, node(LeftD(y,B), path)]
    ... = num_nodes(plug_tree(take_path(path),EmptyTree))
                by definition take_path.

Here is the completed proof of the first_path_index lemma.

lemma first_path_index: all E:type. all A:Tree<E>. all y:E, B:Tree<E>, path:List<Direction<E>>.
  ti_index(first_path(A,y,B, path)) = num_nodes(plug_tree(take_path(path), EmptyTree))
proof
  arbitrary E:type
  induction Tree<E>
  case EmptyTree {
    arbitrary y:E, B:Tree<E>, path:List<Direction<E>>
    conclude ti_index(first_path(EmptyTree,y,B,path))
           = num_nodes(plug_tree(take_path(path),EmptyTree))
                by definition {first_path, ti_index, ti_take}.
  }
  case TreeNode(L, x, R) suppose IH {
    arbitrary y:E, B:Tree<E>, path:List<Direction<E>>
    definition {first_path}
    equations
          ti_index(first_path(L,x,R,node(LeftD(y,B),path)))
        = num_nodes(plug_tree(take_path(node(LeftD(y,B),path)),EmptyTree))
                by IH[x, R, node(LeftD(y,B), path)]
    ... = num_nodes(plug_tree(take_path(path),EmptyTree))
                by definition take_path.
  }
end

Returning to the proof of ti_first_index, we need to prove that ti_index(first_path(A,x,B,empty)) = 0. So we apply the first_path_index lemma and then the definitions of take_path, plug_tree, and num_nodes. Here is the completed proof of ti_first_index.

theorem ti_first_index: all E:type, A:Tree<E>, x:E, B:Tree<E>.
  ti_index(ti_first(A, x, B)) = 0
proof
  arbitrary E:type, A:Tree<E>, x:E, B:Tree<E>
  definition ti_first
  equations  ti_index(first_path(A,x,B,empty))
           = num_nodes(plug_tree(take_path(empty),EmptyTree))
                       by first_path_index[E][A][x,B,empty]
       ... = 0      by definition {take_path, plug_tree, num_nodes}.
end

Our next task is to prove that creating an iterator from a tree using ti_first and then applying ti2tree produces the original tree.

theorem ti_first_stable: all E:type, A:Tree<E>, x:E, B:Tree<E>.
  ti2tree(ti_first(A, x, B)) = TreeNode(A, x, B)
proof
  arbitrary E:type, A:Tree<E>, x:E, B:Tree<E>
  definition ti_first
  ?
end

After expanding the definition of ti_first, we are left to prove that

ti2tree(first_path(A,x,B,empty)) = TreeNode(A,x,B)

So we need to prove another lemma about first_path and again we need to generalize the empty path to an arbitrary path. Let us consider again the situation where the current node x is 5.

Diagram for lemma first path index

The result of first_path(A,x,B,path) will be an iterator positioned at node 4, and the result of ti2tree will be the whole tree, not just TreeNode(A,x,B) as in the above equation. However, we can construct the whole tree from the path and TreeNode(A,x,B) using the plug_tree function. So we have the following lemma to prove.

lemma first_path_stable:
  all E:type. all A:Tree<E>. all y:E, B:Tree<E>, path:List<Direction<E>>.
  ti2tree(first_path(A, y, B, path)) = plug_tree(path, TreeNode(A, y, B))
proof
  arbitrary E:type
  induction Tree<E>
  case EmptyTree {
    arbitrary y:E, B:Tree<E>, path:List<Direction<E>>
    ?
  }
  case TreeNode(L, x, R) suppose IH_L, IH_R {
    arbitrary y:E, B:Tree<E>, path:List<Direction<E>>
    ?
  }
end

In the case A = EmptyTree, we prove the equation using the definitions of first_path and ti2tree.

    equations  ti2tree(first_path(EmptyTree,y,B,path))
             = ti2tree(TrItr(path,EmptyTree,y,B))       by definition first_path.
         ... = plug_tree(path,TreeNode(EmptyTree,y,B))  by definition ti2tree.

In the case A = TreeNode(L, x, R), we need to prove that

  ti2tree(first_path(TreeNode(L,x,R),y,B,path))
= plug_tree(path,TreeNode(TreeNode(L,x,R),y,B))

We probably need to expand the definition of first_path, but doing so in your head is hard. So we can instead ask Deduce to do it. We start by constructing an equation with a bogus right-hand side and apply the definition of first_path.

    equations
          ti2tree(first_path(TreeNode(L,x,R),y,B,path))
        = EmptyTree
             by definition first_path ?
    ... = plug_tree(path,TreeNode(TreeNode(L,x,R),y,B))
             by ?

Deduce responds with

incomplete proof
Goal:
    ti2tree(first_path(L,x,R,node(LeftD(y,B),path))) = EmptyTree

in which the left-hand side has expanded the definition of first_path. So we cut and paste that into our proof and move on to the next step.

    equations
          ti2tree(first_path(TreeNode(L,x,R),y,B,path))
        = ti2tree(first_path(L,x,R,node(LeftD(y,B),path)))
             by definition first_path.
    ... = plug_tree(path,TreeNode(TreeNode(L,x,R),y,B))
             by ?

We now have something that matches the induction hypothesis, so we instantiate it and ask Deduce to tell us the new right-hand side.

    equations
          ti2tree(first_path(TreeNode(L,x,R),y,B,path))
        = ti2tree(first_path(L,x,R,node(LeftD(y,B),path)))
             by definition first_path.
    ... = EmptyTree
             by IH_L[x,R,node(LeftD(y,B),path)]
    ... = plug_tree(path,TreeNode(TreeNode(L,x,R),y,B))
             by ?

Deduce responds with

expected
ti2tree(first_path(L,x,R,node(LeftD(y,B),path))) = EmptyTree
but only have
ti2tree(first_path(L,x,R,node(LeftD(y,B),path))) = plug_tree(node(LeftD(y,B),path),TreeNode(L,x,R))

So we cut and paste the right-hand side of the induction hypothesis to replace EmptyTree.

    equations
          ti2tree(first_path(TreeNode(L,x,R),y,B,path))
        = ti2tree(first_path(L,x,R,node(LeftD(y,B),path)))
             by definition first_path.
    ... = plug_tree(node(LeftD(y,B),path),TreeNode(L,x,R))
             by IH_L[x,R,node(LeftD(y,B),path)]
    ... = plug_tree(path,TreeNode(TreeNode(L,x,R),y,B))
             by ?

The final step of the proof is easy; we just apply the definition of plug_tree. Here is the completed proof of first_path_stable.

lemma first_path_stable:
  all E:type. all A:Tree<E>. all y:E, B:Tree<E>, path:List<Direction<E>>.
  ti2tree(first_path(A, y, B, path)) = plug_tree(path, TreeNode(A, y, B))
proof
  arbitrary E:type
  induction Tree<E>
  case EmptyTree {
    arbitrary y:E, B:Tree<E>, path:List<Direction<E>>
    equations  ti2tree(first_path(EmptyTree,y,B,path))
             = ti2tree(TrItr(path,EmptyTree,y,B))       by definition first_path.
         ... = plug_tree(path,TreeNode(EmptyTree,y,B))  by definition ti2tree.
  }
  case TreeNode(L, x, R) suppose IH_L, IH_R {
    arbitrary y:E, B:Tree<E>, path:List<Direction<E>>
    equations
          ti2tree(first_path(TreeNode(L,x,R),y,B,path))
        = ti2tree(first_path(L,x,R,node(LeftD(y,B),path)))
             by definition first_path.
    ... = plug_tree(node(LeftD(y,B),path),TreeNode(L,x,R))
             by IH_L[x,R,node(LeftD(y,B),path)]
    ... = plug_tree(path,TreeNode(TreeNode(L,x,R),y,B))
             by definition plug_tree.
  }
end

Returning to the ti_first_stable theorem, the equation follows from our first_path_stable lemma and the definition of plug_tree.

theorem ti_first_stable: all E:type, A:Tree<E>, x:E, B:Tree<E>.
  ti2tree(ti_first(A, x, B)) = TreeNode(A, x, B)
proof
  arbitrary E:type, A:Tree<E>, x:E, B:Tree<E>
  definition ti_first
  equations  ti2tree(first_path(A,x,B,empty))
           = plug_tree(empty,TreeNode(A,x,B))  by first_path_stable[E][A][x,B,empty]
       ... = TreeNode(A,x,B)                   by definition plug_tree.
end

Correctness of ti_next

We start by writing down a more careful specification of ti_next.

Specification: The ti_next(iter) operation returns an iterator whose position is one more than the position of iter with respect to in-order traversal, assuming iter is not at the end of the in-order traversal.

To make this specification formal, we can again use ti_index to talk about the position of the iterator. So we begin to prove the following theorem ti_next_index, taking the usual initial steps in the proof as guided by the formula to be proved and the definition of ti_next, which performs a switch on the right child R of the current node.

theorem ti_next_index: all E:type, iter : TreeIter<E>.
  if suc(ti_index(iter)) < num_nodes(ti2tree(iter))
  then ti_index(ti_next(iter)) = suc(ti_index(iter))
proof
  arbitrary E:type, iter : TreeIter<E>
  suppose prem: suc(ti_index(iter)) < num_nodes(ti2tree(iter))
  switch iter {
    case TrItr(path, L, x, R) suppose iter_eq {
      definition ti_next
      switch R {
        case EmptyTree suppose R_eq {
          ?
        }
        case TreeNode(RL, y, RR) suppose R_eq {
          ?
        }
      }
    }
  }
end

In the case R = EmptyTree, ti_next calls the auxiliary function next_up, so we need to prove the following.

ti_index(next_up(path,L,x,EmptyTree)) = suc(ti_index(TrItr(path,L,x,EmptyTree)))

As usual, we must create a lemma that generalizes this equation.

Proving the next_up_index lemma

Looking at the definition of next_up, we see that the recursive call grows the fourth argument, so we must replace the EmptyTree in the needed equation with an arbitrary tree R:

ti_index(next_up(path,L,x,R)) = suc(ti_index(TrItr(path,L,x,R)))

But this equation is not true in general. Consider the situation below where the current node x is node 1 in our example tree. The index of the node that next_up reaches from node 1 is 3, but the index of node 1 is 1, and of course adding one to that gives 2, not 3!

Diagram for path to node 1

So we need to change this equation to account for the situation where R is not empty, but instead an arbitrary subtree. The solution is to add the number of nodes in R to the right-hand side:

ti_index(next_up(path,L,x,R)) = suc(ti_index(TrItr(path,L,x,R))) + num_nodes(R)

One more addition is necessary to formulate the lemma. The above equation is only meaningful when the index on the right-hand side is in bounds, that is, when it is smaller than the number of nodes in the tree. So we formulate the lemma next_up_index as follows and take a few obvious steps into the proof.

lemma next_up_index: all E:type. all path:List<Direction<E>>. all A:Tree<E>, x:E, B:Tree<E>.
  if suc(ti_index(TrItr(path, A, x, B)) + num_nodes(B)) < num_nodes(ti2tree(TrItr(path, A, x, B)))
  then ti_index(next_up(path, A, x, B)) = suc(ti_index(TrItr(path, A,x,B)) + num_nodes(B))
proof
  arbitrary E:type
  induction List<Direction<E>>
  case empty {
    arbitrary A:Tree<E>, x:E, B:Tree<E>
    suppose prem: suc(ti_index(TrItr(empty,A,x,B)) + num_nodes(B)) 
                  < num_nodes(ti2tree(TrItr(empty,A,x,B)))
    ?
  }
  case node(f, path') suppose IH {
    arbitrary A:Tree<E>, x:E, B:Tree<E>
    suppose prem
    switch f {
      case LeftD(y, R) {
        ?
      }
      case RightD(L, y) suppose f_eq {
        ?
      }
    }
  }
end

In the case path = empty, the premise is false because there are no nodes that come afterwards in the in-order traversal. In particular, the premise implies the following contradictory inequality.

    have AB_l_AB: suc(num_nodes(A) + num_nodes(B)) < suc(num_nodes(A) + num_nodes(B))
      by definition {ti_index, ti_take, take_path, plug_tree, ti2tree, num_nodes} 
         in prem
    conclude false  by apply less_irreflexive to AB_l_AB

Next consider the case path = node(LeftD(y, R), path'). After expanding all the relevant definitions, we need to prove that

  num_nodes(plug_tree(take_path(path'), TreeNode(A,x,B))) 
= suc(num_nodes(plug_tree(take_path(path'), A)) + num_nodes(B))

We need a lemma that relates num_nodes and plug_tree. So we pause the current proof for the following exercise.

Exercise: prove the num_nodes_plug lemma

lemma num_nodes_plug: all E:type. all path:List<Direction<E>>. all t:Tree<E>.
  num_nodes(plug_tree(path, t)) = num_nodes(plug_tree(path, EmptyTree)) + num_nodes(t)
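
In case it helps to get started, here is one possible proof skeleton (a sketch, not the full proof): induction on the path, with a switch on the direction in the node case, and the remaining goals left as holes to fill in.

lemma num_nodes_plug: all E:type. all path:List<Direction<E>>. all t:Tree<E>.
  num_nodes(plug_tree(path, t)) = num_nodes(plug_tree(path, EmptyTree)) + num_nodes(t)
proof
  arbitrary E:type
  induction List<Direction<E>>
  case empty {
    arbitrary t:Tree<E>
    ?
  }
  case node(f, path') suppose IH {
    arbitrary t:Tree<E>
    switch f {
      case LeftD(x, R) {
        ?
      }
      case RightD(L, x) {
        ?
      }
    }
  }
end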

Back to the next_up_index lemma

We use num_nodes_plug on both the left and right-hand sides of the equation, and apply the definition of num_nodes.

    rewrite num_nodes_plug[E][take_path(path')][TreeNode(A,x,B)]
    rewrite num_nodes_plug[E][take_path(path')][A]
    definition num_nodes

After that it suffices to prove the following.

  num_nodes(plug_tree(take_path(path'),EmptyTree)) + suc(num_nodes(A) + num_nodes(B)) 
= suc((num_nodes(plug_tree(take_path(path'),EmptyTree)) + num_nodes(A)) + num_nodes(B))

This equation is rather big, so let’s squint at it by giving names to its parts. (This is a new version of define that I’m experimenting with.)

    define_ X = num_nodes(plug_tree(take_path(path'),EmptyTree))
    define_ Y = num_nodes(A)
    define_ Z = num_nodes(B)

Now it’s easy to see that our goal is true using some simple arithmetic.

    conclude X + suc(Y + Z) = suc((X + Y) + Z)
        by rewrite add_suc[X][Y+Z] | add_assoc[X][Y,Z].

Finally, consider the case path = node(RightD(L, y), path'). After expanding the definition of next_up, we need to prove

  ti_index(next_up(path',L,y,TreeNode(A,x,B))) 
= suc(ti_index(TrItr(node(RightD(L,y),path'),A,x,B)) + num_nodes(B))

The left-hand side matches the induction hypothesis, so we have

    equations
      ti_index(next_up(path',L,y,TreeNode(A,x,B))) 
        = suc(ti_index(TrItr(path',L,y,TreeNode(A,x,B))) + num_nodes(TreeNode(A,x,B)))
            by apply IH[L,y,TreeNode(A,x,B)] 
               to definition {ti_index, ti_take, num_nodes, ti2tree} ?
    ... = suc(ti_index(TrItr(node(RightD(L,y),path'),A,x,B)) + num_nodes(B))
            by ?

But we need to prove the premise of the induction hypothesis. We can do that as follows, with many uses of num_nodes_plug and some arithmetic that we package up into lemma XYZW_equal.

    have IH_prem: suc(num_nodes(plug_tree(take_path(path'),L)) 
                      + suc(num_nodes(A) + num_nodes(B))) 
                  < num_nodes(plug_tree(path',TreeNode(L,y,TreeNode(A,x,B))))
      by rewrite num_nodes_plug[E][take_path(path')][L]
          | num_nodes_plug[E][path'][TreeNode(L,y,TreeNode(A,x,B))]
         definition {num_nodes, num_nodes}
         define_ X = num_nodes(plug_tree(take_path(path'),EmptyTree))
         define_ Y = num_nodes(L) define_ Z = num_nodes(A) define_ W = num_nodes(B)
         define_ P = num_nodes(plug_tree(path',EmptyTree))
         suffices suc((X + Y) + suc(Z + W)) < P + suc(Y + suc(Z + W))
         have prem2: suc((X + suc(Y + Z)) + W) < P + suc(Y + suc(Z + W))
           by enable {X,Y,Z,W,P}
              definition {num_nodes, num_nodes} in
              rewrite num_nodes_plug[E][take_path(path')][TreeNode(L,y,A)]
                    | num_nodes_plug[E][path'][TreeNode(L,y,TreeNode(A,x,B))] in
              definition {ti_index, ti_take, take_path, ti2tree, plug_tree} in
              rewrite f_eq in prem
         rewrite XYZW_equal[X,Y,Z,W]
         prem2

Here is the proof of XYZW_equal.

lemma XYZW_equal: all X:Nat, Y:Nat, Z:Nat, W:Nat.
  suc((X + Y) + suc(Z + W)) = suc((X + suc(Y + Z)) + W)
proof
  arbitrary X:Nat, Y:Nat, Z:Nat, W:Nat
  enable {operator+}
  equations
        suc((X + Y) + suc(Z + W))
      = suc(suc(X + Y) + (Z + W))      by rewrite add_suc[X+Y][Z+W].
  ... = suc(suc(((X + Y) + Z) + W))    by rewrite add_assoc[X+Y][Z,W].
  ... = suc(suc((X + (Y + Z)) + W))    by rewrite add_assoc[X][Y,Z].
  ... = suc((X + suc(Y + Z)) + W)      by rewrite add_suc[X][Y+Z].
end

Getting back to the equational proof, it remains to prove that

  suc(ti_index(TrItr(path',L,y,TreeNode(A,x,B))) + num_nodes(TreeNode(A,x,B)))
= suc(ti_index(TrItr(node(RightD(L,y),path'),A,x,B)) + num_nodes(B))

which we can do with yet more uses of num_nodes_plug and XYZW_equal.

    ... = suc(num_nodes(plug_tree(take_path(path'),L)) + suc(num_nodes(A) + num_nodes(B)))
          by definition {ti_index, ti_take, num_nodes}.
    ... = suc((num_nodes(plug_tree(take_path(path'),EmptyTree)) + num_nodes(L))
              + suc(num_nodes(A) + num_nodes(B)))
          by rewrite num_nodes_plug[E][take_path(path')][L].
    ... = suc((num_nodes(plug_tree(take_path(path'),EmptyTree)) 
              + suc(num_nodes(L) + num_nodes(A))) + num_nodes(B))
          by define_ X = num_nodes(plug_tree(take_path(path'),EmptyTree))
             define_ Y = num_nodes(L) define_ Z = num_nodes(A) define_ W = num_nodes(B)
             define_ P = num_nodes(plug_tree(path',EmptyTree))
             conclude suc((X + Y) + suc(Z + W)) = suc((X + suc(Y + Z)) + W)
                 by XYZW_equal[X,Y,Z,W]
    ... = suc(num_nodes(plug_tree(take_path(path'),TreeNode(L,y,A))) + num_nodes(B))
          by rewrite num_nodes_plug[E][take_path(path')][TreeNode(L,y,A)]
             definition {num_nodes, num_nodes}.
    ... = suc(ti_index(TrItr(node(RightD(L,y),path'),A,x,B)) + num_nodes(B))
          by definition {ti_index, ti_take, take_path, plug_tree}.

That completes the last case of the proof of next_up_index. Here’s the completed proof.

lemma next_up_index: all E:type. all path:List<Direction<E>>. all A:Tree<E>, x:E, B:Tree<E>.
  if suc(ti_index(TrItr(path, A, x, B)) + num_nodes(B)) < num_nodes(ti2tree(TrItr(path, A, x, B)))
  then ti_index(next_up(path, A, x, B)) = suc(ti_index(TrItr(path, A,x,B)) + num_nodes(B))
proof
  arbitrary E:type
  induction List<Direction<E>>
  case empty {
    arbitrary A:Tree<E>, x:E, B:Tree<E>
    suppose prem: suc(ti_index(TrItr(empty,A,x,B)) + num_nodes(B)) 
                  < num_nodes(ti2tree(TrItr(empty,A,x,B)))
    have AB_l_AB: suc(num_nodes(A) + num_nodes(B)) < suc(num_nodes(A) + num_nodes(B))
      by definition {ti_index, ti_take, take_path, plug_tree, ti2tree, num_nodes} 
         in prem
    conclude false  by apply less_irreflexive to AB_l_AB
  }
  case node(f, path') suppose IH {
    arbitrary A:Tree<E>, x:E, B:Tree<E>
    suppose prem
    switch f {
      case LeftD(y, R) {
        definition {next_up, ti_index, ti_take, take_path}
        rewrite num_nodes_plug[E][take_path(path')][TreeNode(A,x,B)]
        rewrite num_nodes_plug[E][take_path(path')][A]
        definition num_nodes
        define_ X = num_nodes(plug_tree(take_path(path'),EmptyTree))
        define_ Y = num_nodes(A)
        define_ Z = num_nodes(B)
        conclude X + suc(Y + Z) = suc((X + Y) + Z)
            by rewrite add_suc[X][Y+Z] | add_assoc[X][Y,Z].
      }
      case RightD(L, y) suppose f_eq {
        definition {next_up}
        have IH_prem: suc(num_nodes(plug_tree(take_path(path'),L)) 
                          + suc(num_nodes(A) + num_nodes(B))) 
                      < num_nodes(plug_tree(path',TreeNode(L,y,TreeNode(A,x,B))))
          by rewrite num_nodes_plug[E][take_path(path')][L]
              | num_nodes_plug[E][path'][TreeNode(L,y,TreeNode(A,x,B))]
             definition {num_nodes, num_nodes}
             define_ X = num_nodes(plug_tree(take_path(path'),EmptyTree))
             define_ Y = num_nodes(L) define_ Z = num_nodes(A) define_ W = num_nodes(B)
             define_ P = num_nodes(plug_tree(path',EmptyTree))
             suffices suc((X + Y) + suc(Z + W)) < P + suc(Y + suc(Z + W))
             have prem2: suc((X + suc(Y + Z)) + W) < P + suc(Y + suc(Z + W))
               by enable {X,Y,Z,W,P}
                  definition {num_nodes, num_nodes} in
                  rewrite num_nodes_plug[E][take_path(path')][TreeNode(L,y,A)]
                        | num_nodes_plug[E][path'][TreeNode(L,y,TreeNode(A,x,B))] in
                  definition {ti_index, ti_take, take_path, ti2tree, plug_tree} in
                  rewrite f_eq in prem
             rewrite XYZW_equal[X,Y,Z,W]
             prem2
        equations
              ti_index(next_up(path',L,y,TreeNode(A,x,B))) 
            = suc(ti_index(TrItr(path',L,y,TreeNode(A,x,B))) + num_nodes(TreeNode(A,x,B)))
                by apply IH[L,y,TreeNode(A,x,B)] 
                   to definition {ti_index, ti_take, num_nodes, ti2tree} IH_prem
        ... = suc(num_nodes(plug_tree(take_path(path'),L)) + suc(num_nodes(A) + num_nodes(B)))
              by definition {ti_index, ti_take, num_nodes}.
        ... = suc((num_nodes(plug_tree(take_path(path'),EmptyTree)) + num_nodes(L))
                  + suc(num_nodes(A) + num_nodes(B)))
              by rewrite num_nodes_plug[E][take_path(path')][L].
        ... = suc((num_nodes(plug_tree(take_path(path'),EmptyTree)) 
                  + suc(num_nodes(L) + num_nodes(A))) + num_nodes(B))
              by define_ X = num_nodes(plug_tree(take_path(path'),EmptyTree))
                 define_ Y = num_nodes(L) define_ Z = num_nodes(A) define_ W = num_nodes(B)
                 define_ P = num_nodes(plug_tree(path',EmptyTree))
                 conclude suc((X + Y) + suc(Z + W)) = suc((X + suc(Y + Z)) + W)
                     by XYZW_equal[X,Y,Z,W]
        ... = suc(num_nodes(plug_tree(take_path(path'),TreeNode(L,y,A))) + num_nodes(B))
              by rewrite num_nodes_plug[E][take_path(path')][TreeNode(L,y,A)]
                 definition {num_nodes, num_nodes}.
        ... = suc(ti_index(TrItr(node(RightD(L,y),path'),A,x,B)) + num_nodes(B))
              by definition {ti_index, ti_take, take_path, plug_tree}.
      }
    }
  }
end

Back to the proof of ti_next_index

With the next_up_index lemma complete, we can get back to proving the ti_next_index theorem. Recall that we were in the case R = EmptyTree and needed to prove the following.

ti_index(next_up(path,L,x,EmptyTree)) = suc(ti_index(TrItr(path,L,x,EmptyTree)))

To use the next_up_index lemma, we need to prove its premise:

    have next_up_index_prem:
        suc(ti_index(TrItr(path,L,x,EmptyTree)) + num_nodes(EmptyTree))
        < num_nodes(ti2tree(TrItr(path,L,x,EmptyTree)))
      by enable num_nodes
         rewrite add_zero[ti_index(TrItr(path,L,x,EmptyTree))]
         rewrite iter_eq | R_eq in prem

We can finish the proof of the equation using the definition of num_nodes and the add_zero property.

    equations
          ti_index(next_up(path,L,x,EmptyTree))
        = suc(ti_index(TrItr(path,L,x,EmptyTree)) + num_nodes(EmptyTree))
          by apply next_up_index[E][path][L, x, EmptyTree] to next_up_index_prem
    ... = suc(ti_index(TrItr(path,L,x,EmptyTree)))
          by definition num_nodes
             rewrite add_zero[ti_index(TrItr(path,L,x,EmptyTree))].

The next case in the proof of ti_next_index is for R = TreeNode(RL, y, RR). We need to prove

  ti_index(first_path(RL,y,RR,node(RightD(L,x),path))) 
= suc(ti_index(TrItr(path,L,x,TreeNode(RL,y,RR))))

We can start by applying the first_path_index lemma, which gives us

equations
      ti_index(first_path(RL,y,RR,node(RightD(L,x),path))) 
    = num_nodes(plug_tree(take_path(node(RightD(L,x),path)),EmptyTree))

We have opportunities to expand take_path and then plug_tree.

... = num_nodes(plug_tree(take_path(path),TreeNode(L,x,EmptyTree)))
        by definition {take_path,plug_tree}.

We can separate out the TreeNode(L,x,EmptyTree) using num_nodes_plug.

... = num_nodes(plug_tree(take_path(path),EmptyTree)) + suc(num_nodes(L))
        by rewrite num_nodes_plug[E][take_path(path)][TreeNode(L,x,EmptyTree)]
           definition {num_nodes, num_nodes}
           rewrite add_zero[num_nodes(L)].

Then we can move the L back into the plug_tree with num_nodes_plug.

... = suc(num_nodes(plug_tree(take_path(path),L)))
       by rewrite add_suc[num_nodes(plug_tree(take_path(path),EmptyTree))][num_nodes(L)]
          rewrite num_nodes_plug[E][take_path(path)][L].

We conclude the equational reasoning with the definition of ti_index and ti_take.

... = suc(ti_index(TrItr(path,L,x,TreeNode(RL,y,RR))))
        by definition {ti_index, ti_take}.

Here is the complete proof of ti_next_index.

theorem ti_next_index: all E:type, iter : TreeIter<E>.
  if suc(ti_index(iter)) < num_nodes(ti2tree(iter))
  then ti_index(ti_next(iter)) = suc(ti_index(iter))
proof
  arbitrary E:type, iter : TreeIter<E>
  suppose prem: suc(ti_index(iter)) < num_nodes(ti2tree(iter))
  switch iter {
    case TrItr(path, L, x, R) suppose iter_eq {
      definition ti_next
      switch R {
        case EmptyTree suppose R_eq {
          have next_up_index_prem:
              suc(ti_index(TrItr(path,L,x,EmptyTree)) + num_nodes(EmptyTree))
              < num_nodes(ti2tree(TrItr(path,L,x,EmptyTree)))
            by enable num_nodes
               rewrite add_zero[ti_index(TrItr(path,L,x,EmptyTree))]
               rewrite iter_eq | R_eq in prem
          equations
                ti_index(next_up(path,L,x,EmptyTree))
              = suc(ti_index(TrItr(path,L,x,EmptyTree)) + num_nodes(EmptyTree))
                by apply next_up_index[E][path][L, x, EmptyTree] to next_up_index_prem
          ... = suc(ti_index(TrItr(path,L,x,EmptyTree)))
                by definition num_nodes
                   rewrite add_zero[ti_index(TrItr(path,L,x,EmptyTree))].
        }
        case TreeNode(RL, y, RR) suppose R_eq {
          equations
                ti_index(first_path(RL,y,RR,node(RightD(L,x),path))) 
              = num_nodes(plug_tree(take_path(node(RightD(L,x),path)),EmptyTree))
                  by first_path_index[E][RL][y,RR,node(RightD(L,x),path)]
          ... = num_nodes(plug_tree(take_path(path),TreeNode(L,x,EmptyTree)))
                  by definition {take_path,plug_tree}.
          ... = num_nodes(plug_tree(take_path(path),EmptyTree)) + suc(num_nodes(L))
                  by rewrite num_nodes_plug[E][take_path(path)][TreeNode(L,x,EmptyTree)]
                     definition {num_nodes, num_nodes}
                     rewrite add_zero[num_nodes(L)].
          ... = suc(num_nodes(plug_tree(take_path(path),L)))
                 by rewrite add_suc[num_nodes(plug_tree(take_path(path),EmptyTree))][num_nodes(L)]
                    rewrite num_nodes_plug[E][take_path(path)][L].
          ... = suc(ti_index(TrItr(path,L,x,TreeNode(RL,y,RR))))
                  by definition {ti_index, ti_take}.

        }
      }
    }
  }
end

Proof of ti_next_stable

The second correctness condition for ti_next(iter) is that it is stable with respect to ti2tree. Following the definition of ti_next, we switch on the iterator and then on the right child of the current node.

theorem ti_next_stable: all E:type, iter:TreeIter<E>.
  ti2tree(ti_next(iter)) = ti2tree(iter)
proof
  arbitrary E:type, iter:TreeIter<E>
  switch iter {
    case TrItr(path, L, x, R) {
      switch R {
        case EmptyTree {
          definition {ti2tree, ti_next}
          ?
        }
        case TreeNode(RL, y, RR) {
          definition {ti2tree, ti_next}
          ?
        }
      }
    }
  }
end

For the case R = EmptyTree, we need to prove the following, which amounts to proving that next_up is stable.

ti2tree(next_up(path,L,x,EmptyTree)) = plug_tree(path,TreeNode(L,x,EmptyTree))

We’ll pause the current proof to prove the next_up_stable lemma.

Exercise: next_up_stable lemma

lemma next_up_stable: all E:type. all path:List<Direction<E>>. all A:Tree<E>, y:E, B:Tree<E>.
  ti2tree(next_up(path, A, y, B)) = plug_tree(path, TreeNode(A,y,B))
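
If you would like a nudge, one possible starting skeleton mirrors the structure of next_up itself: induction on the path and a switch on the direction in the node case. Roughly speaking, one direction case should follow from the definitions alone and the other should use the induction hypothesis.

lemma next_up_stable: all E:type. all path:List<Direction<E>>. all A:Tree<E>, y:E, B:Tree<E>.
  ti2tree(next_up(path, A, y, B)) = plug_tree(path, TreeNode(A,y,B))
proof
  arbitrary E:type
  induction List<Direction<E>>
  case empty {
    arbitrary A:Tree<E>, y:E, B:Tree<E>
    ?
  }
  case node(f, path') suppose IH {
    arbitrary A:Tree<E>, y:E, B:Tree<E>
    switch f {
      case LeftD(x, R) {
        ?
      }
      case RightD(L, x) {
        ?
      }
    }
  }
end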

Back to ti_next_stable

Now we conclude the R = EmptyTree case of the ti_next_stable theorem.

    conclude ti2tree(next_up(path,L,x,EmptyTree))
       = plug_tree(path,TreeNode(L,x,EmptyTree))
      by next_up_stable[E][path][L,x,EmptyTree]

In the case R = TreeNode(RL, y, RR), we need to prove the following, which is to say that first_path is stable. Thankfully we already proved that lemma!

    conclude ti2tree(first_path(RL,y,RR,node(RightD(L,x),path))) 
           = plug_tree(path,TreeNode(L,x,TreeNode(RL,y,RR)))
      by rewrite first_path_stable[E][RL][y,RR,node(RightD(L,x),path)]
         definition {plug_tree}.

Here is the completed proof of ti_next_stable.

theorem ti_next_stable: all E:type, iter:TreeIter<E>.
  ti2tree(ti_next(iter)) = ti2tree(iter)
proof
  arbitrary E:type, iter:TreeIter<E>
  switch iter {
    case TrItr(path, L, x, R) {
      switch R {
        case EmptyTree {
          definition {ti2tree, ti_next}
          conclude ti2tree(next_up(path,L,x,EmptyTree))
             = plug_tree(path,TreeNode(L,x,EmptyTree))
            by next_up_stable[E][path][L,x,EmptyTree]
        }
        case TreeNode(RL, y, RR) {
          definition {ti2tree, ti_next}
          conclude ti2tree(first_path(RL,y,RR,node(RightD(L,x),path))) 
                 = plug_tree(path,TreeNode(L,x,TreeNode(RL,y,RR)))
            by rewrite first_path_stable[E][RL][y,RR,node(RightD(L,x),path)]
               definition {plug_tree}.
        }
      }
    }
  }
end

Correctness of ti_get and ti_index

Recall that ti_get(iter) should return the data in the current node of iter and ti_index should return the position of iter as a natural number with respect to in-order traversal. Thus, if we apply in_order to the tree, the element at position ti_index(iter) should be the same as ti_get(iter). So we have the following theorem to prove.

theorem ti_index_get_in_order: all E:type, iter:TreeIter<E>, a:E.
  ti_get(iter) = nth(in_order(ti2tree(iter)), a)(ti_index(iter))
proof
  arbitrary E:type, iter:TreeIter<E>, a:E
  switch iter {
    case TrItr(path, L, x, R) {
      definition {ti2tree, ti_get, ti_index, ti_take}
      ?
    }
  }
end

After expanding several definitions, we are left to prove

x = nth(in_order(plug_tree(path,TreeNode(L,x,R))),a)
       (num_nodes(plug_tree(take_path(path),L)))

We see num_nodes applied to plug_tree, so we can use the num_nodes_plug lemma

      rewrite num_nodes_plug[E][take_path(path)][L]

The goal now is to prove

x = nth(in_order(plug_tree(path, TreeNode(L,x,R))),a)
       (num_nodes(plug_tree(take_path(path), EmptyTree)) + num_nodes(L))

The next step to take is not so obvious. Perhaps one hint is that we have the following theorem about nth from List.pf that also involves addition in the index argument of nth.

theorem nth_append_back: all T:type. all xs:List<T>. all ys:List<T>, i:Nat, d:T.
  nth(append(xs, ys), d)(length(xs) + i) = nth(ys, d)(i)

So we would need to prove a lemma that relates in_order and plug_tree to append. Now the take_path function returns the part of the tree before the path, so perhaps it can be used to create the xs in nth_append_back. But what about ys? It seems like we need a function that returns the part of the tree after the path. Let us call this function drop_path.

function drop_path<E>(List<Direction<E>>) -> List<Direction<E>> {
  drop_path(empty) = empty
  drop_path(node(f, path')) =
    switch f {
      case RightD(L, x) {
        drop_path(path')
      }
      case LeftD(x, R) {
        node(LeftD(x, R), drop_path(path'))
      }
    }
}

So using take_path and drop_path, we should be able to come up with an equation for in_order(plug_tree(path, TreeNode(A, x, B))). The part of the tree before x should be take_path(path) followed by the subtree A. The part of the tree after x should be the subtree B followed by drop_path(path).

lemma in_order_plug_take_drop: all E:type. all path:List<Direction<E>>. all A:Tree<E>, x:E, B:Tree<E>.
  in_order(plug_tree(path, TreeNode(A, x, B)))
  = append(in_order(plug_tree(take_path(path), A)), 
           node(x, in_order(plug_tree(drop_path(path), B))))

It turns out that to prove this, we will also need a lemma about the combination of plug_tree and take_path:

lemma in_order_plug_take: all E:type. all path:List<Direction<E>>. all t:Tree<E>.
  in_order(plug_tree(take_path(path), t)) 
  = append( in_order(plug_tree(take_path(path),EmptyTree)), in_order(t))

and a lemma about the combination of plug_tree and drop_path:

lemma in_order_plug_drop: all E:type. all path:List<Direction<E>>. all t:Tree<E>.
  in_order(plug_tree(drop_path(path), t)) = append( in_order(t), in_order(plug_tree(drop_path(path),EmptyTree)))

Exercise: prove the in_order_plug... lemmas

Prove the three lemmas in_order_plug_take_drop, in_order_plug_take, and in_order_plug_drop.
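
In case a starting point helps, here is one possible skeleton for the second lemma (only a sketch; the other two lemmas can be set up the same way): induction on the path, with a switch on the direction in the node case.

lemma in_order_plug_take: all E:type. all path:List<Direction<E>>. all t:Tree<E>.
  in_order(plug_tree(take_path(path), t)) 
  = append( in_order(plug_tree(take_path(path),EmptyTree)), in_order(t))
proof
  arbitrary E:type
  induction List<Direction<E>>
  case empty {
    arbitrary t:Tree<E>
    ?
  }
  case node(f, path') suppose IH {
    arbitrary t:Tree<E>
    switch f {
      case LeftD(x, R) {
        ?
      }
      case RightD(L, x) {
        ?
      }
    }
  }
end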

Back to the proof of ti_index_get_in_order

Our goal was to prove

x = nth(in_order(plug_tree(path,TreeNode(L,x,R))), a)
       (num_nodes(plug_tree(take_path(path),EmptyTree)) + num_nodes(L))

So we use lemma in_order_plug_take_drop to get the following

  in_order(plug_tree(path,TreeNode(L,x,R)))
= append(in_order(plug_tree(take_path(path),L)), node(x, in_order(plug_tree(drop_path(path),R))))

and then lemma in_order_plug_take separates out the L.

  in_order(plug_tree(take_path(path), L))
= append(in_order(plug_tree(take_path(path),EmptyTree)), in_order(L))

So rewriting with the above equations

    rewrite in_order_plug_take_drop[E][path][L,x,R]
    rewrite in_order_plug_take[E][path][L]

transforms our goal to

x = nth(append(append(in_order(plug_tree(take_path(path),EmptyTree)), in_order(L)),
               node(x,in_order(plug_tree(drop_path(path),R)))),a)
       (num_nodes(plug_tree(take_path(path),EmptyTree)) + num_nodes(L))

Recall that our plan is to use the nth_append_back lemma, in which the index argument to nth is length(xs), but in the above we have the index expressed in terms of num_nodes. The following exercise proves a theorem that relates length and in_order to num_nodes.

Exercise: prove the length_in_order theorem

theorem length_in_order: all E:type. all t:Tree<E>.
  length(in_order(t)) = num_nodes(t)
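
As a hint, one natural skeleton does induction on the tree; the TreeNode case will likely also want a fact relating length and append from List.pf, such as the length_append lemma that we use below.

theorem length_in_order: all E:type. all t:Tree<E>.
  length(in_order(t)) = num_nodes(t)
proof
  arbitrary E:type
  induction Tree<E>
  case EmptyTree {
    ?
  }
  case TreeNode(L, x, R) suppose IH_L, IH_R {
    ?
  }
end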

Back to ti_index_get_in_order

Now we rewrite with the length_in_order lemma a couple times, give some short names to these big expressions, and apply length_append from List.pf.

      rewrite symmetric length_in_order[E][L]
            | symmetric length_in_order[E][plug_tree(take_path(path),EmptyTree)]
      define_ X = in_order(plug_tree(take_path(path),EmptyTree))
      define_ Y = in_order(L)
      define_ Z = in_order(plug_tree(drop_path(path),R))
      rewrite symmetric length_append[E][X][Y]

Now we’re in a position to use nth_append_back.

x = nth(append(append(X,Y), node(x, Z)), a)
       (length(append(X,Y)))

In particular, nth_append_back[E][append(X,Y)][node(x,Z), 0, a] gives us

  nth(append(append(X,Y), node(x,Z)),a)(length(append(X,Y)) + 0) 
= nth(node(x,Z),a)(0)

With that we prove the goal using add_zero and the definition of nth.

  conclude x = nth(append(append(X,Y), node(x,Z)), a)(length(append(X,Y)))
    by rewrite (rewrite add_zero[length(append(X,Y))] in
                nth_append_back[E][append(X,Y)][node(x,Z), 0, a])
       definition nth.

Here is the complete proof of ti_index_get_in_order.

theorem ti_index_get_in_order: all E:type, z:TreeIter<E>, a:E.
  ti_get(z) = nth(in_order(ti2tree(z)), a)(ti_index(z))
proof
  arbitrary E:type, z:TreeIter<E>, a:E
  switch z {
    case TrItr(path, L, x, R) {
      definition {ti2tree, ti_get, ti_index, ti_take}
      rewrite num_nodes_plug[E][take_path(path)][L]
      
      suffices x = nth(in_order(plug_tree(path,TreeNode(L,x,R))),a)
                      (num_nodes(plug_tree(take_path(path),EmptyTree)) + num_nodes(L))
      rewrite in_order_plug_take_drop[E][path][L,x,R]
      rewrite in_order_plug_take[E][path][L]
      
      suffices x = nth(append(append(in_order(plug_tree(take_path(path),EmptyTree)),
                                     in_order(L)),
                              node(x,in_order(plug_tree(drop_path(path),R)))),a)
                      (num_nodes(plug_tree(take_path(path),EmptyTree)) + num_nodes(L))
      rewrite symmetric length_in_order[E][L]
            | symmetric length_in_order[E][plug_tree(take_path(path),EmptyTree)]
      define_ X = in_order(plug_tree(take_path(path),EmptyTree))
      define_ Y = in_order(L)
      define_ Z = in_order(plug_tree(drop_path(path),R))
      rewrite symmetric length_append[E][X][Y]
      
      conclude x = nth(append(append(X,Y), node(x,Z)), a)(length(append(X,Y)))
        by rewrite (rewrite add_zero[length(append(X,Y))] in
                    nth_append_back[E][append(X,Y)][node(x,Z), 0, a])
           definition nth.
    }
  }
end

This concludes the proofs of correctness for the in-order iterator and the five operations ti2tree, ti_first, ti_get, ti_next, and ti_index.

Exercise: Prove that ti_prev is correct

In the previous post there was an exercise to implement ti_prev, which moves the iterator backwards one position with respect to in-order traversal. This exercise is to prove that your implementation of ti_prev is correct. There are two theorems to prove. The first one makes sure that ti_prev reduces the index of the iterator by one.

theorem ti_prev_index: all E:type, iter : TreeIter<E>.
  if 0 < ti_index(iter)
  then ti_index(ti_prev(iter)) = pred(ti_index(iter))

The second theorem makes sure that the resulting iterator is still an iterator for the same tree.

theorem ti_prev_stable: all E:type, iter:TreeIter<E>.
  ti2tree(ti_prev(iter)) = ti2tree(iter)
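
The details of these proofs depend on your implementation of ti_prev, but one way to begin, mirroring the ti_next_index proof above, is a sketch along the following lines, switching on the iterator and then proceeding according to how your ti_prev dispatches.

theorem ti_prev_index: all E:type, iter : TreeIter<E>.
  if 0 < ti_index(iter)
  then ti_index(ti_prev(iter)) = pred(ti_index(iter))
proof
  arbitrary E:type, iter : TreeIter<E>
  suppose prem: 0 < ti_index(iter)
  switch iter {
    case TrItr(path, L, x, R) suppose iter_eq {
      ?
    }
  }
end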

Thursday, July 18, 2024

Binary Trees with In-order Iterators (Part 1)

This is the fifth blog post in a series about developing correct implementations of basic data structures and algorithms using the Deduce language and proof checker.

In this blog post we study binary trees, that is, trees in which each node has at most two children. We study the in-order tree traversal, as that will become important when we study binary search trees. Furthermore, we implement tree iterators that keep track of a location within the tree and can move forward with respect to the in-order traversal. We shall prove that our implementation of tree iterators is correct in Part 2 of this blog post.

Binary Trees

We begin by defining a union for binary trees:

union Tree<E> {
  EmptyTree
  TreeNode(Tree<E>, E, Tree<E>)
}

For example, we can represent the following binary tree

Diagram of a Binary Tree

with a bunch of tree nodes like so:

define T0 = TreeNode(EmptyTree, 0, EmptyTree)
define T2 = TreeNode(EmptyTree, 2, EmptyTree)
define T1 = TreeNode(T0, 1, T2)
define T4 = TreeNode(EmptyTree, 4, EmptyTree)
define T5 = TreeNode(T4, 5, EmptyTree)
define T7 = TreeNode(EmptyTree, 7, EmptyTree)
define T6 = TreeNode(T5, 6, T7)
define T3 = TreeNode(T1, 3, T6)

We define the height of a tree with the following recursive function.

function height<E>(Tree<E>) -> Nat {
  height(EmptyTree) = 0
  height(TreeNode(L, x, R)) = suc(max(height(L), height(R)))
}

The example tree has height 4.

assert height(T3) = 4

We count the number of nodes in a binary tree with the num_nodes function.

function num_nodes<E>(Tree<E>) -> Nat {
  num_nodes(EmptyTree) = 0
  num_nodes(TreeNode(L, x, R)) = suc(num_nodes(L) + num_nodes(R))
}

The example tree has 8 nodes.

assert num_nodes(T3) = 8

In-order Tree Traversal

Now for the main event of this blog post, the in-order tree traversal. The idea of this traversal is that for each node in the tree, we follow this recipe:

  1. process the left subtree
  2. process the current node
  3. process the right subtree

What it means to process a node can be different for different instantiations of the in-order traversal. But to make things concrete, we study an in-order traversal that produces a list. So here is our definition of the in_order function.

function in_order<E>(Tree<E>) -> List<E> {
  in_order(EmptyTree) = empty
  in_order(TreeNode(L, x, R)) = append(in_order(L), node(x, in_order(R)))
}

The result of in_order for T3 is the list 0,1,2,3,4,5,6,7. As you can see, we chose the data values in T3 to match their position within the in-order traversal.

assert in_order(T3) = interval(8, 0)

In-order Tree Iterators

A tree iterator keeps track of a position within a tree. Our goal is to create a data structure to represent a tree iterator and to implement the following operations on iterators, which we describe below.

ti2tree : < E > fn TreeIter<E> -> Tree<E>
ti_first : < E > fn Tree<E>,E,Tree<E> -> TreeIter<E>
ti_get : < E > fn TreeIter<E> -> E
ti_next : < E > fn TreeIter<E> -> TreeIter<E>
ti_index : < E > fn(TreeIter<E>) -> Nat

  • The ti2tree operator returns the tree that the iterator is traversing.

  • The ti_first operator returns an iterator pointing to the first node (with respect to the in-order traversal) of a non-empty tree. We represent non-empty trees with three things: the left subtree, the data in the root node, and the right subtree.

  • The ti_get operator returns the data of the node at the current position.

  • The ti_next operator moves the iterator forward by one position.

  • The ti_index operator returns the position of the iterator as a natural number.

Here is an example of creating an iterator for T3 and moving it forward.

define iter0 = ti_first(T1, 3, T6)
assert ti_get(iter0) = 0
assert ti_index(iter0) = 0

define iter3 = ti_next(ti_next(ti_next(iter0)))
assert ti_get(iter3) = 3
assert ti_index(iter3) = 3

define iter7 = ti_next(ti_next(ti_next(ti_next(iter3))))
assert ti_get(iter7) = 7
assert ti_index(iter7) = 7

Iterator Representation

We represent a position in the tree by recording a path of left-or-right decisions. For example, to represent the position of node 4 of the example tree, we record the path R,L,L (R for right and L for left).

Diagram of the iterator at position 4

When we come to implement the ti_next operation, we will sometimes need to climb the tree. For example, to get from 4 to 5. To make that easier, we will store the path in reverse. So the path to node 4 will be stored as L,L,R.

It would seem natural to store an iterator’s path separately from the tree, but doing so would complicate many of the upcoming proofs because only certain paths make sense for certain trees. Instead, we combine the path and the tree into a single data structure called a zipper (Huet, The Zipper, Journal of Functional Programming, Vol 7. Issue 5, 1997). The idea is to attach extra data to the left and right decisions and to store the subtree at the current position. So we define a union named Direction with constructors for left and right, and we define a union named TreeIter that contains a path and the non-empty tree at the current position.

union Direction<E> {
  LeftD(E, Tree<E>)
  RightD(Tree<E>, E)
}

union TreeIter<E> {
  TrItr(List<Direction<E>>, Tree<E>, E, Tree<E>)
}
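
For example, here is the zipper for node 4 of the example tree, written out by hand purely for illustration (the name iter4_zipper is not part of the iterator interface). Both subtrees at node 4 are empty, and the reversed path records that 4 is the left child of 5 (whose right subtree is empty), that 5 is the left child of 6 (whose right subtree is T7), and that 6 is the right child of 3 (whose left subtree is T1).

define iter4_zipper = TrItr(node(LeftD(5, EmptyTree), node(LeftD(6, T7), node(RightD(T1, 3), empty))),
                            EmptyTree, 4, EmptyTree)

Once ti2tree is defined below, plugging this path back together with TreeNode(EmptyTree, 4, EmptyTree) reconstructs the whole tree T3.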

The ti2tree Operation

Of the tree iterator operations, we will first implement ti2tree because it will help to explain this zipper-style representation. We start by defining the auxiliary function plug_tree, which reconstructs a tree from a path and the subtree at the specified position. The plug_tree function is defined by recursion on the path, so it moves upward in the tree with each recursive call. Consider the case for LeftD(x, R) below. To plug tree t into the path node(LeftD(x, R), path'), we use the extra data stored in LeftD(x, R) to create TreeNode(t, x, R), which we then pass to the recursive call to plug the new tree node into the rest of the path.

function plug_tree<E>(List<Direction<E>>, Tree<E>) -> Tree<E> {
  plug_tree(empty, t) = t
  plug_tree(node(f, path'), t) =
    switch f {
      case LeftD(x, R) {
        plug_tree(path', TreeNode(t, x, R))
      }
      case RightD(L, x) {
        plug_tree(path', TreeNode(L, x, t))
      }
    }
}

The ti2tree operator simply invokes plug_tree.

function ti2tree<E>(TreeIter<E>) -> Tree<E> {
  ti2tree(TrItr(path, L, x, R)) = plug_tree(path, TreeNode(L, x, R))
}

Creating an iterator from a tree using ti_first and then applying ti2tree produces the original tree. Furthermore, moving an iterator does not change the tree that it is traversing, so ti2tree returns T3 for iterators iter0, iter3, and iter7.

assert ti2tree(iter0) = T3
assert ti2tree(iter3) = T3
assert ti2tree(iter7) = T3

The ti_first Operation

Recall that the ti_first operation returns an iterator pointing to the first node (with respect to the in-order traversal) of a non-empty tree. For example, applying ti_first to T3 should give us node 0. The idea to implement ti_first is simple: we walk down the tree going left at each step, until we get to a leaf.

To implement ti_first we define the auxiliary function first_path that takes a non-empty tree and the path-so-far and proceeds going to the left down the tree. (The first_path function will also come in handy when implementing ti_next.)

function first_path<E>(Tree<E>, E, Tree<E>, List<Direction<E>>) -> TreeIter<E> {
  first_path(EmptyTree, x, R, path) = TrItr(path, EmptyTree, x, R)
  first_path(TreeNode(LL, y, LR), x, R, path) = first_path(LL, y, LR, node(LeftD(x, R), path))
}

We implement ti_first simply as a call to first_path where the path-so-far is empty.

define ti_first : < E > fn Tree<E>,E,Tree<E> -> TreeIter<E>
    = λ L,x,R { first_path(L, x, R, empty) }

As promised above, applying ti_first to T3 gives us node 0.

assert ti_get(ti_first(T1, 3, T6)) = 0

The ti_get Operation

Recall that the ti_get operator should return the data of the node at the current position. This is straightforward to implement because that data is stored directly in the tree iterator.

function ti_get<E>(TreeIter<E>) -> E {
  ti_get(TrItr(path, L, x, R)) = x
}

The ti_next Operation

Recall that the ti_next operator moves the iterator forward by one position with respect to the in-order traversal. This operation is non-trivial to implement. Consider again our example tree.

Diagram of a Binary Tree

Suppose the current node is 2. Then the next node is 3, which requires climbing a fair ways up the tree. On the other hand, if the current node is 3, then the next node is 4, way back down the tree. So there are two different scenarios that we need to handle.

  1. If the current node has a right child, then the next node is the first node of the right child’s subtree (with respect to in-order traversal). For example, node 3 has right child 6, and the first node of that subtree is 4.

  2. If the current node does not have a right child, then the next node is the ancestor after the first left branch. For example, node 2 does not have a right child, so we go up the tree. We go up to 1 via a right branch and then up to 3 via a left branch, so 3 is the next node of 2.

For (1) we already have first_path, so we just need an auxiliary function for (2), which we call next_up. This function takes a path and the current non-empty subtree and returns the iterator for the next position. If the direction is RightD, we keep going up the tree. If the direction is LeftD(x, R), we stop and return an iterator for the parent node x.

function next_up<E>(List<Direction<E>>, Tree<E>, E, Tree<E>) -> TreeIter<E> {
  next_up(empty, A, z, B) = TrItr(empty, A, z, B)
  next_up(node(f, path'), A, z, B) =
    switch f {
      case RightD(L, x) {
        next_up(path', L, x, TreeNode(A, z, B))
      }
      case LeftD(x, R) {
        TrItr(path', TreeNode(A, z, B), x, R)
      }
    }
}

Now that we have both next_up and first_path, we implement ti_next by checking whether the right child R is empty. If it is, we invoke next_up, and if not, we invoke first_path.

function ti_next<E>(TreeIter<E>) -> TreeIter<E> {
  ti_next(TrItr(path, L, x, R)) =
    switch R {
      case EmptyTree {
        next_up(path, L, x, R)
      }
      case TreeNode(RL, y, RR) {
        first_path(RL, y, RR, node(RightD(L, x), path))
      }
    }
}

To see ti_next in action, in the following we go from position 2 up to position 3 and then back down to position 4.

define iter2 = ti_next(ti_next(iter0))
assert ti_get(iter2) = 2

define iter3_ = ti_next(iter2)
assert ti_get(iter3_) = 3

define iter4 = ti_next(iter3_)
assert ti_get(iter4) = 4

The ti_index Operation

Recall that the ti_index operator returns the position of the iterator as a natural number. More specifically, ti_index returns the position of the current node with respect to the in-order traversal. The following demonstrates this invariant on iter0 and iter7.

define L0 = in_order(ti2tree(iter0))
define i0 = ti_index(iter0)
assert ti_get(iter0) = nth(L0, 42)(i0)

define L7 = in_order(ti2tree(iter7))
define i7 = ti_index(iter7)
assert ti_get(iter7) = nth(L7, 42)(i7)

The idea for implementing ti_index is that we’ll count how many nodes are in the portion of the tree that comes before the current position. We define an auxiliary function that constructs this portion of the tree, calling it ti_take because it is reminiscent of the take(n, ls) function in List.pf, which returns the prefix of list ls of length n. Furthermore, we use a second auxiliary function named take_path that applies this idea to the path of the iterator. So to implement the take_path function, we throw away the subtrees to the right of the path (by removing LeftD(x, R)) and we keep the subtrees to the left of the path (by keeping RightD(L, x)).

function take_path<E>(List<Direction<E>>) -> List<Direction<E>> {
  take_path(empty) = empty
  take_path(node(f, path')) =
    switch f {
      case RightD(L, x) {
        node(RightD(L,x), take_path(path'))
      }
      case LeftD(x, R) {
        take_path(path')
      }
    }
}

We implement ti_take by applying take_path to the path of the iterator and then plugging the left subtree L into the result. (The node x and the subtree R do not come before x with respect to in-order traversal, so they are left out.)

function ti_take<E>(TreeIter<E>) -> Tree<E> {
  ti_take(TrItr(path, L, x, R)) = plug_tree(take_path(path), L)
}

Finally, we implement ti_index by counting the number of nodes in the tree returned by ti_take.

define ti_index : < E > fn(TreeIter<E>) -> Nat = λ iter { num_nodes(ti_take(iter))}
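
As a quick sanity check in the same spirit as the earlier assertions (recall that the data values in T3 match their in-order positions, and that iter2 and iter4 were defined above), we would expect the following to hold.

assert ti_index(iter2) = 2
assert ti_index(iter4) = 4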

Exercise: Implement and test the ti_prev Operation

The ti_prev operation (for previous) moves the iterator backward by one position with respect to in-order traversal.

ti_prev : < E > fn TreeIter<E> -> TreeIter<E>

Implement and test the ti_prev operation.

Conclusion

This completes the implementation of the 5 tree iterator operations. In Part 2 of this blog post, we will prove that these operations are correct.

Sunday, June 30, 2024

Merge Sort with Leftovers, Correctly

Merge Sort with Leftovers

This is the fourth blog post in a series about developing correct implementations of basic data structures and algorithms using the Deduce language and proof checker.

In this blog post we study a fast sorting algorithm, Merge Sort. This classic algorithm splits the input list in half, recursively sorts each half, and then merges the two results back into a single sorted list.

The specification of Merge Sort is the same as Insertion Sort.

Specification: The merge_sort(xs) function returns a list that contains the same elements as xs but the elements in the result are in sorted order.

We follow the write-test-prove approach to develop a correct implementation of merge_sort.

Write the merge_sort function

The classic implementation of merge_sort would be something like the following.

function merge_sort(List<Nat>) -> List<Nat> {
  merge_sort(empty) = empty
  merge_sort(node(x,xs')) =
    let p = split(node(x,xs'))
    merge(merge_sort(first(p)), merge_sort(second(p)))
}

Unfortunately, Deduce rejects the above function definition because Deduce uses a very simple restriction to ensure the termination of recursive functions: a recursive call may only be made on a part of the input. In this case, the recursive call may only be applied to the sublist xs', not first(p) or second(p).

How can we work around this restriction? There’s an old trick that goes by many names (gas, fuel, etc.), which is to add another parameter of type Nat and use that for termination. Let us use the name msort for the following, and then we define merge_sort in terms of msort.

function msort(Nat, List<Nat>) -> List<Nat> {
  msort(0, xs) = xs
  msort(suc(n'), xs) =
    let p = split(xs)
    merge(msort(n', first(p)), msort(n', second(p)))
}

define merge_sort : fn List<Nat> -> List<Nat>
  = λxs{ msort(log(length(xs)), xs) }

In the above definition of merge_sort, we need to supply enough gas so that msort won’t prematurely run out. Here we use the logarithm (base 2, rounding up) defined in Log.pf.
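
For example, an input list with 8 elements needs three rounds of halving, since 2³ = 8. Assuming that log from Log.pf is the base-2 logarithm rounded up, as just described, and that pow2 (used in the tests below) computes powers of two, we would expect:

assert log(8) = 3
assert pow2(3) = 8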

This definition of merge_sort and msort is fine; it has O(n log(n)) time complexity, so it is efficient. However, the use of split rubs me the wrong way because it requires traversing half of the input list. The use of split would be necessary if one wanted to use parallelism to speed up the code, performing the two recursive calls in parallel. However, we are currently only interested in a single-threaded implementation.

Suppose you just finished baking a pie and intend to eat half now and half tomorrow night. One approach would be to split it in half and then eat one of the halves. Another approach is to just start eating the pie and stop when half of it is gone. That’s the approach that we will take with the next version of msort.

Specification: The msort(n,xs) function sorts the first min(2ⁿ,length(xs)) many elements of xs and returns a pair containing (1) the sorted list and (2) the leftovers that were not yet sorted.

function msort(Nat, List<Nat>) -> Pair< List<Nat>, List<Nat> > {
  msort(0, xs) =
    switch xs {
      case empty { pair(empty, empty) }
      case node(x, xs') { pair(node(x, empty), xs') }
    }
  msort(suc(n'), xs) =
    let p1 = msort(n', xs)
    let p2 = msort(n', second(p1))
    let ys = first(p1)
    let zs = first(p2)
    pair(merge(length(ys) + length(zs), ys, zs), second(p2))
}

In the above case for suc(n'), the first recursive call to msort produces the pair p1 that includes a sorted list and the leftovers. We sort the leftovers with the second recursive call to msort. We return (1) the merge of the two sorted sublists and (2) the leftovers from the second recursive call to msort.
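
To make the leftovers concrete, consider one unit of gas: msort(1, xs) sorts only the first 2¹ = 2 elements of xs and returns the rest untouched in the second component. For a hypothetical four-element list, a test along the following lines should pass.

define L_3142 = node(3, node(1, node(4, node(2, empty))))
assert first(msort(1, L_3142)) = node(1, node(3, empty))
assert second(msort(1, L_3142)) = node(4, node(2, empty))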

With the code for msort complete, we can turn to merge_sort. Similar to the previous version, we invoke msort with the input list xs and use the logarithm of the list length for the gas. This version of msort returns a pair, with the sorted results in the first component. The second component of the pair is an empty list because we supplied enough gas.

define merge_sort : fn List<Nat> -> List<Nat>
    = λxs{ first(msort(log(length(xs)), xs)) }

So far, we have neglected the implementation of merge. Here’s its specification.

Specification: The merge(xs,ys) function takes two sorted lists and returns a sorted list that contains just the elements from the two input lists.

Here’s the classic implementation of merge. The idea is to compare the two elements at the front of each list and use the lower of the two as the first element of the output. Then do the recursive call with the two lists, minus the element that was chosen. Again, we use an extra gas parameter to ensure termination. To ensure that we have enough gas, we will choose the sum of the lengths of the two input lists.

function merge(Nat, List<Nat>, List<Nat>) -> List<Nat> {
  merge(0, xs, ys) = empty
  merge(suc(n), xs, ys) =
    switch xs {
      case empty { ys }
      case node(x, xs') {
        switch ys {
          case empty {
            node(x, xs')
          }
          case node(y, ys') {
            if x ≤ y then
              node(x, merge(n, xs', node(y, ys')))
            else
              node(y, merge(n, node(x, xs'), ys'))
          }
        }
      }
    }
}

Test

We have three functions to test: merge, msort, and merge_sort.

Test merge

We test that the result of merge is sorted and that it contains all the elements from the two input lists, which we check using count.

define L_1337 = node(1, node(3, node(3, node(7, empty))))
define L_2348 = node(2, node(3, node(4, node(8, empty))))
define L_12333478 = merge(length(L_1337) + length(L_2348), L_1337, L_2348)
assert sorted(L_12333478)
assert all_elements(append(L_1337, L_2348),
  λx{count(L_1337)(x) + count(L_2348)(x) = count(L_12333478)(x) })

Test msort

In the following tests, we vary the gas from 0 to 3, varying how much of the input list L18 gets sorted in the call to msort. The take(n,xs) function returns the first n elements of xs and drop(n,xs) drops the first n elements of xs and returns the remaining portion of xs.

define L18 = append(L_1337, L_2348)

define p0 = msort(0, L18)
define t0 = take(pow2(0), L18)
define d0 = drop(pow2(0), L18)
assert sorted(first(p0))
assert all_elements(t0, λx{count(t0)(x) = count(first(p0))(x) })
assert all_elements(d0, λx{count(d0)(x) = count(second(p0))(x) })

define p1 = msort(1, L18)
define t1 = take(pow2(1), L18)
define d1 = drop(pow2(1), L18)
assert sorted(first(p1))
assert all_elements(t1, λx{count(t1)(x) = count(first(p1))(x) })
assert all_elements(d1, λx{count(d1)(x) = count(second(p1))(x) })

define p2 = msort(2, L18)
define t2 = take(pow2(2), L18)
define d2 = drop(pow2(2), L18)
assert sorted(first(p2))
assert all_elements(t2, λx{count(t2)(x) = count(first(p2))(x) })
assert all_elements(d2, λx{count(d2)(x) = count(second(p2))(x) })

define p3 = msort(3, L18)
define t3 = take(pow2(3), L18)
define d3 = drop(pow2(3), L18)
assert sorted(first(p3))
assert all_elements(t3, λx{count(t3)(x) = count(first(p3))(x) })
assert all_elements(d3, λx{count(d3)(x) = count(second(p3))(x) })

Test merge_sort

Next we test that merge_sort returns a sorted list that contains the same elements as the input list. For input, we reuse the list L18 from above.

define s_L18 = merge_sort(L18)
assert sorted(s_L18)
assert all_elements(L18, λx{count(L18)(x) = count(s_L18)(x) })

We can bundle several tests, with varying-length inputs, into one assert by using all_elements and interval.

assert all_elements(interval(3, 0),
    λn{ let xs = reverse(interval(n, 0))
        let ls = merge_sort(xs)
        sorted(ls) and
        all_elements(xs, λx{count(xs)(x) = count(ls)(x)})
    })

Prove

Compared to the proof of correctness for insertion_sort, we have considerably more work to do for merge_sort. Instead of two functions, we have three functions to consider: merge, msort, and merge_sort. Furthermore, these functions are more complex than insert and insertion_sort. Nevertheless, we are up to the challenge!

Prove correctness of merge

The specification of merge has two parts: one part saying that the elements of the output must be the elements of the two input lists, and another part saying that the output must be sorted, provided the two input lists are sorted.

Here is how we state the theorem for the first part.

theorem mset_of_merge: all n:Nat. all xs:List<Nat>, ys:List<Nat>.
  if length(xs) + length(ys) = n
  then mset_of(merge(n, xs, ys)) = mset_of(xs) ⨄ mset_of(ys)

Here is the theorem stating that the output of merge is sorted.

theorem merge_sorted: all n:Nat. all xs:List<Nat>, ys:List<Nat>.
  if sorted(xs) and sorted(ys) and
     length(xs) + length(ys) = n
  then sorted(merge(n, xs, ys))

Prove the mset_of_merge theorem

We begin with the proof of mset_of_merge. Because merge(n, xs, ys) is recursive on the natural number n, we proceed by induction on Nat.

  induction Nat
  case 0 {
    arbitrary xs:List<Nat>, ys:List<Nat>
    suppose prem: length(xs) + length(ys) = 0
    ?
  }
  case suc(n') suppose IH {
    ?
  }

In the case for n = 0, we need to prove

  mset_of(merge(0,xs,ys)) = mset_of(xs) ⨄ mset_of(ys)

and merge(0,xs,ys) returns empty, so we need to show that mset_of(xs) ⨄ mset_of(ys) is the empty multiset. From the premise prem, both xs and ys must be empty.

  have lxs_lys_z: length(xs) = 0 and length(ys) = 0
    by apply add_to_zero[length(xs)][length(ys)] to prem
  have xs_mt: xs = empty
    by apply length_zero_empty[Nat,xs] to lxs_lys_z
  have ys_mt: ys = empty
    by apply length_zero_empty[Nat,ys] to lxs_lys_z

After rewriting with those equalities and applying the definitions of merge and mset_of:

  rewrite xs_mt | ys_mt
  definition {merge, mset_of}

it remains to prove m_fun(λ{0}) = m_fun(λ{0}) ⨄ m_fun(λ{0}) (the sum of two empty multisets is the empty multiset), which we prove with the theorem m_sum_empty from MultiSet.pf.

  symmetric m_sum_empty[Nat, m_fun(λx{0}) :MultiSet<Nat>]

In the case for n = suc(n'), we need to prove

  mset_of(merge(suc(n'),xs,ys)) = mset_of(xs) ⨄ mset_of(ys)

Looking at the suc clause of merge, there is a switch on xs and then on ys. So our proof will be structured analogously.

  switch xs {
    case empty {
      ?
    }
    case node(x, xs') suppose xs_xxs {
      ?
    }
  }

In the case for xs = empty, we conclude simply by use of the definitions of merge and mset_of and the fact that combining mset_of(ys) with the empty multiset produces mset_of(ys).

  case empty {
    definition {merge, mset_of}
    conclude mset_of(ys) = m_fun(λx{0}) ⨄ mset_of(ys)
      by symmetric empty_m_sum[Nat, mset_of(ys)]
  }

In the case for xs = node(x, xs'), merge performs a switch on ys, so our proof does too.

  switch ys {
    case empty {
      ?
    }
    case node(y, ys') suppose ys_yys {
      ?
    }

The case for ys = empty is similar to the case for xs = empty. We conclude by use of the definitions of merge and mset_of and the fact that combining the multiset of node(x, xs') with the empty multiset produces the same multiset.

  definition {merge, mset_of}
  conclude m_one(x) ⨄ mset_of(xs')
         = m_one(x) ⨄ mset_of(xs') ⨄ m_fun(λ{0})
    by rewrite m_sum_empty[Nat, m_one(x) ⨄ mset_of(xs')].

In the case for ys = node(y, ys'), we continue to follow the structure of merge and switch on x ≤ y.

  definition merge
  switch x ≤ y {
    case true suppose xy_true {
      ?
    }
    case false suppose xy_false {
      ?
    }
  }

In the case for (x ≤ y) = true, the goal becomes

m_one(x) ⨄ mset_of(merge(n',xs',node(y,ys'))) 
= m_one(x) ⨄ mset_of(xs') ⨄ m_one(y) ⨄ mset_of(ys')

This follows from the induction hypothesis instantiated with xs' and node(y,ys'):

  mset_of(merge(n',xs',node(y,ys')))
= mset_of(xs') ⨄ mset_of(node(y, ys'))

Filling in the details, we prove this case as follows.

  case true suppose xy_true {
    definition mset_of
    have sxs_sys_sn: suc(length(xs')) + suc(length(ys')) = suc(n')
      by enable length rewrite xs_xxs | ys_yys in prem
    have len_xs_yys: length(xs') + length(node(y,ys')) = n'
      by enable {operator +,length}
         injective suc sxs_sys_sn
    have IH': mset_of(merge(n',xs',node(y,ys')))
            = mset_of(xs') ⨄ mset_of(node(y, ys'))
      by apply IH[xs', node(y, ys')] to len_xs_yys
    rewrite IH'
    definition mset_of
    rewrite m_sum_assoc[Nat, m_one(x), mset_of(xs'),
                        (m_one(y) ⨄ mset_of(ys'))].
  }

In the case for (x ≤ y) = false, the goal becomes

  m_one(y) ⨄ mset_of(merge(n',node(x,xs'),ys')) 
= m_one(x) ⨄ mset_of(xs') ⨄ m_one(y) ⨄ mset_of(ys')

The induction hypothesis instantiated with node(x,xs') and ys' is

  mset_of(merge(n',node(x,xs'),ys'))
= mset_of(node(x,xs')) ⨄ mset_of(ys')

So the goal follows from the fact that multiset sum is associative and commutative. Here is the complete proof of mset_of_merge.

theorem mset_of_merge: all n:Nat. all xs:List<Nat>, ys:List<Nat>.
  if length(xs) + length(ys) = n
  then mset_of(merge(n, xs, ys)) = mset_of(xs) ⨄ mset_of(ys)
proof
  induction Nat
  case 0 {
    arbitrary xs:List<Nat>, ys:List<Nat>
    suppose prem: length(xs) + length(ys) = 0
    have lxs_lys_z: length(xs) = 0 and length(ys) = 0
      by apply add_to_zero[length(xs)][length(ys)] to prem
    have xs_mt: xs = empty
      by apply length_zero_empty[Nat,xs] to lxs_lys_z
    have ys_mt: ys = empty
      by apply length_zero_empty[Nat,ys] to lxs_lys_z
    rewrite xs_mt | ys_mt
    definition {merge, mset_of}
    symmetric m_sum_empty[Nat, m_fun(λx{0}) :MultiSet<Nat>]
  }
  case suc(n') suppose IH {
    arbitrary xs:List<Nat>, ys:List<Nat>
    suppose prem: length(xs) + length(ys) = suc(n')
    switch xs {
      case empty {
        definition {merge, mset_of}
        conclude mset_of(ys) = m_fun(λx{0}) ⨄ mset_of(ys)
          by symmetric empty_m_sum[Nat, mset_of(ys)]
      }
      case node(x, xs') suppose xs_xxs {
        switch ys {
          case empty {
            definition {merge, mset_of}
            conclude m_one(x) ⨄ mset_of(xs')
                   = m_one(x) ⨄ mset_of(xs') ⨄ m_fun(λ{0})
              by rewrite m_sum_empty[Nat, m_one(x) ⨄ mset_of(xs')].
          }
          case node(y, ys') suppose ys_yys {
            definition merge
            switch x ≤ y {
              case true suppose xy_true {
                definition mset_of
                have sxs_sys_sn: suc(length(xs')) + suc(length(ys')) = suc(n')
                  by enable length rewrite xs_xxs | ys_yys in prem
                have len_xs_yys: length(xs') + length(node(y,ys')) = n'
                  by enable {operator +,length}
                     injective suc sxs_sys_sn
                have IH': mset_of(merge(n',xs',node(y,ys')))
                        = mset_of(xs') ⨄ mset_of(node(y, ys'))
                  by apply IH[xs', node(y, ys')] to len_xs_yys
                rewrite IH'
                definition mset_of
                rewrite m_sum_assoc[Nat, m_one(x), mset_of(xs'),
                                    (m_one(y) ⨄ mset_of(ys'))].
              }
              case false suppose xy_false {
                definition mset_of
                have sxs_sys_sn: suc(length(xs')) + suc(length(ys')) = suc(n')
                  by enable length rewrite xs_xxs | ys_yys in prem
                have len_xxs_ys: length(node(x,xs')) + length(ys') = n'
                  by enable {operator +,length}
                     injective suc
                     rewrite add_suc[length(xs')][length(ys')] in
                     sxs_sys_sn
                have IH': mset_of(merge(n',node(x,xs'),ys'))
                        = mset_of(node(x,xs')) ⨄ mset_of(ys')
                  by apply IH[node(x,xs'), ys'] to len_xxs_ys
                equations
                        m_one(y) ⨄ mset_of(merge(n',node(x,xs'),ys'))
                      = m_one(y) ⨄ ((m_one(x) ⨄ mset_of(xs')) ⨄ mset_of(ys'))
                      by rewrite IH' definition mset_of.
                  ... = m_one(y) ⨄ (m_one(x) ⨄ (mset_of(xs') ⨄ mset_of(ys')))
                      by rewrite m_sum_assoc[Nat, m_one(x), mset_of(xs'),
                                             mset_of(ys')].
                  ... = (m_one(y) ⨄ m_one(x)) ⨄ (mset_of(xs') ⨄ mset_of(ys'))
                      by rewrite m_sum_assoc[Nat, m_one(y), m_one(x),
                               (mset_of(xs') ⨄ mset_of(ys'))].
                  ... = (m_one(x) ⨄ m_one(y)) ⨄ (mset_of(xs') ⨄ mset_of(ys'))
                      by rewrite m_sum_commutes[Nat, m_one(x), m_one(y)].
                  ... = m_one(x) ⨄ (m_one(y) ⨄ (mset_of(xs') ⨄ mset_of(ys')))
                      by rewrite m_sum_assoc[Nat, m_one(x), m_one(y),
                          (mset_of(xs') ⨄ mset_of(ys'))].
                  ... = m_one(x) ⨄ ((m_one(y) ⨄ mset_of(xs')) ⨄ mset_of(ys'))
                      by rewrite m_sum_assoc[Nat, m_one(y), mset_of(xs'),
                          mset_of(ys')].
                  ... = m_one(x) ⨄ ((mset_of(xs') ⨄ m_one(y)) ⨄ mset_of(ys'))
                      by rewrite m_sum_commutes[Nat, m_one(y), mset_of(xs')].
                  ... = m_one(x) ⨄ (mset_of(xs') ⨄ (m_one(y) ⨄ mset_of(ys')))
                      by rewrite m_sum_assoc[Nat, mset_of(xs'), m_one(y),
                         mset_of(ys')].
                  ... = (m_one(x) ⨄ mset_of(xs')) ⨄ (m_one(y) ⨄ mset_of(ys'))
                      by rewrite m_sum_assoc[Nat, m_one(x), mset_of(xs'),
                          (m_one(y) ⨄ mset_of(ys'))].
              }
            }
          }
        }
      }
    }
  }
end

The mset_of_merge theorem also holds for sets, using the set_of function. We prove the following set_of_merge theorem as a corollary of mset_of_merge.

theorem set_of_merge: all xs:List<Nat>, ys:List<Nat>.
  set_of(merge(length(xs) + length(ys), xs, ys)) = set_of(xs) ∪ set_of(ys)
proof
  arbitrary xs:List<Nat>, ys:List<Nat>
  have mset_of_merge: mset_of(merge(length(xs) + length(ys), xs, ys))
                    = mset_of(xs) ⨄ mset_of(ys)
    by apply mset_of_merge[length(xs) + length(ys)][xs, ys] to .
  equations
    set_of(merge(length(xs) + length(ys), xs, ys))
        = set_of_mset(mset_of(merge(length(xs) + length(ys), xs, ys)))
          by symmetric som_mset_eq_set[Nat]
                             [merge(length(xs) + length(ys), xs, ys)]
    ... = set_of_mset(mset_of(xs)) ∪ set_of_mset(mset_of(ys))
          by rewrite mset_of_merge  som_union[Nat,mset_of(xs),mset_of(ys)]
    ... = set_of(xs) ∪ set_of(ys)
          by rewrite som_mset_eq_set[Nat][xs] | som_mset_eq_set[Nat][ys].
end

Prove the merge_sorted theorem

Next up is the merge_sorted theorem. The structure of the proof will be similar to the one for mset_of_merge, because they both follow the structure of merge. So we begin with induction on Nat.

theorem merge_sorted: all n:Nat. all xs:List<Nat>, ys:List<Nat>.
  if sorted(xs) and sorted(ys) and length(xs) + length(ys) = n
  then sorted(merge(n, xs, ys))
proof
  induction Nat
  case 0 {
    ?
  }
  case suc(n') suppose IH {
    arbitrary xs:List<Nat>, ys:List<Nat>
    suppose prem
    definition merge
    switch xs {
      case empty {
        ?
      }
      case node(x, xs') suppose xs_xxs {
        switch ys {
          case empty {
            ?
          }
          case node(y, ys') suppose ys_yys {
            switch x ≤ y {
              case true suppose xy_true {
                ?
              }
              case false suppose xy_false {
                ?
              }
            }
          }
        }
      }
    }
  }
end

In the case for n = 0, we need to prove sorted(merge(0, xs, ys)). But merge(0, xs, ys) = empty, and sorted(empty) is trivially true. So we conclude the case for n = 0 as follows.

  case 0 {
    arbitrary xs:List<Nat>, ys:List<Nat>
    suppose _
    definition merge
    conclude sorted(empty) by definition sorted.
  }

We move on to the case for n = suc(n') and xs = empty. Here merge returns ys, and we already know that ys is sorted from the premise.

    case empty {
      conclude sorted(ys) by prem
    }

In the case for xs = node(x, xs') and ys = empty, the merge function returns node(x, xs') (that is, xs), and we already know that xs is sorted from the premise.

  case empty {
    conclude sorted(node(x,xs'))  by rewrite xs_xxs in prem
  }

In the case for ys = node(y, ys') and (x ≤ y) = true, the merge function returns node(x, merge(n',xs',node(y,ys'))). So we need to prove the following.

  sorted(merge(n',xs',node(y,ys'))) and
  all_elements(merge(n',xs',node(y,ys')),λb{x ≤ b})

To prove the first, we invoke the induction hypothesis instantiated with xs' and node(y,ys') as follows.

  have s_xs: sorted(xs')
    by enable sorted rewrite xs_xxs in prem
  have s_yys: sorted(node(y,ys'))
    by rewrite ys_yys in prem
  have len_xs_yys: length(xs') + length(node(y,ys')) = n'
    by enable {operator +,length}
       have sxs: suc(length(xs')) + suc(length(ys')) = suc(n')
          by rewrite xs_xxs | ys_yys in prem
       injective suc sxs
  have IH_xs_yys: sorted(merge(n',xs',node(y,ys')))
    by apply IH[xs',node(y,ys')]
       to s_xs, s_yys, len_xs_yys

It remains to prove that x is less-or-equal to all the elements in the rest of the output list:

  all_elements(merge(n',xs',node(y,ys')),λb{x ≤ b})

The theorem all_elements_eq_member in List.pf says

  all_elements(xs,P) = (all x:T. if x ∈ set_of(xs) then P(x))

which, combined with the set_of_merge corollary above, simplifies our goal to

  all z:Nat. (if z ∈ set_of(xs') ∪ set_of(node(y,ys')) then x ≤ z)

So we have a few cases to consider and need to prove x ≤ z in each one. Consider the case where z ∈ set_of(xs'). Then we can deduce x ≤ z from the fact that node(x, xs') is sorted.

  have x_le_xs: all_elements(xs', λb{x ≤ b})
    by definition sorted in rewrite xs_xxs in prem
  conclude x ≤ z by
    apply all_elements_member[Nat][xs'][z, λb{x ≤ b}]
    to x_le_xs, z_in_xs

Next, consider the case where y = z. Then we can immediately conclude because x ≤ y.

Finally, consider when z ∈ set_of(ys'). Because node(y,ys') is sorted, we know y ≤ z. Then combined with x ≤ y, we conclude that x ≤ z by transitivity.

  have y_le_ys: all_elements(ys', λb{y ≤ b})
    by definition sorted in rewrite ys_yys in prem
  have y_z: y ≤ z
    by apply all_elements_member[Nat][ys'][z,λb{y ≤ b}]
       to y_le_ys, z_in_ys
  have x_y: x ≤ y by rewrite xy_true.
  conclude x ≤ z
    by apply less_equal_trans[x][y,z] to x_y, y_z

The last case to consider is for ys = node(y, ys') and (x ≤ y) = false. The reasoning is similar to the case for (x ≤ y) = true, so we’ll skip the detailed explanation.

Here’s the completed proof of merge_sorted.

theorem merge_sorted: all n:Nat. all xs:List<Nat>, ys:List<Nat>.
  if sorted(xs) and sorted(ys) and length(xs) + length(ys) = n
  then sorted(merge(n, xs, ys))
proof
  induction Nat
  case 0 {
    arbitrary xs:List<Nat>, ys:List<Nat>
    suppose _
    definition merge
    conclude sorted(empty) by definition sorted.
  }
  case suc(n') suppose IH {
    arbitrary xs:List<Nat>, ys:List<Nat>
    suppose prem
    definition merge
    switch xs {
      case empty {
        conclude sorted(ys) by prem
      }
      case node(x, xs') suppose xs_xxs {
        switch ys {
          case empty {
            conclude sorted(node(x,xs'))  by rewrite xs_xxs in prem
          }
          case node(y, ys') suppose ys_yys {
            /* Apply the induction hypothesis
             * to prove sorted(merge(n',xs',node(y,ys')))
             */
            have s_xs: sorted(xs')
              by enable sorted rewrite xs_xxs in prem
            have s_yys: sorted(node(y,ys'))
              by rewrite ys_yys in prem
            have len_xs_yys: length(xs') + length(node(y,ys')) = n'
              by enable {operator +,length}
                 have sxs: suc(length(xs')) + suc(length(ys')) = suc(n')
                    by rewrite xs_xxs | ys_yys in prem
                 injective suc sxs
            have IH_xs_yys: sorted(merge(n',xs',node(y,ys')))
              by apply IH[xs',node(y,ys')]
                 to s_xs, s_yys, len_xs_yys

            /* Apply the induction hypothesis
             * to prove sorted(merge(n',node(x,xs'),ys'))
             */
            have len_xxs_ys: length(node(x,xs')) + length(ys') = n'
              by definition {operator +,length}
                 rewrite symmetric len_xs_yys
                 definition length
                 rewrite add_suc[length(xs')][length(ys')].
            have s_xxs: sorted(node(x, xs'))
              by enable sorted rewrite xs_xxs in prem
            have s_ys: sorted(ys')
              by definition sorted in rewrite ys_yys in prem
            have IH_xxs_ys: sorted(merge(n',node(x,xs'),ys'))
              by apply IH[node(x,xs'),ys']
                 to s_xxs, s_ys, len_xxs_ys

            have x_le_xs: all_elements(xs', λb{x ≤ b})
              by definition sorted in rewrite xs_xxs in prem
            have y_le_ys: all_elements(ys', λb{y ≤ b})
              by definition sorted in rewrite ys_yys in prem
            
            switch x ≤ y {
              case true suppose xy_true {
                definition sorted
                suffices sorted(merge(n',xs',node(y,ys'))) and
                         all_elements(merge(n',xs',node(y,ys')), λb{x ≤ b})
                IH_xs_yys, 
                conclude all_elements(merge(n',xs',node(y,ys')),λb{x ≤ b})  by
                  rewrite all_elements_eq_member
                     [Nat,merge(n',xs',node(y,ys')),λb{x ≤ b}]
                  rewrite symmetric len_xs_yys
                  rewrite set_of_merge[xs',node(y,ys')]
                  arbitrary z:Nat
                  suppose z_in_xs_yys: z ∈ set_of(xs') ∪ set_of(node(y,ys'))
                  suffices x ≤ z
                  cases apply member_union
                               [Nat,z,set_of(xs'),set_of(node(y,ys'))]
                        to z_in_xs_yys
                  case z_in_xs: z ∈ set_of(xs') {
                    conclude x ≤ z by
                      apply all_elements_member[Nat][xs'][z, λb{x ≤ b}]
                      to x_le_xs, z_in_xs
                  }
                  case z_in_ys: z ∈ set_of(node(y,ys')) {
                    cases apply member_union[Nat,z,single(y),set_of(ys')]
                          to definition set_of in z_in_ys
                    case z_sy: z ∈ single(y) {
                      have y_z: y = z
                          by definition {operator ∈, single, rep} in z_sy
                      conclude x ≤ z by rewrite symmetric y_z | xy_true.
                    }
                    case z_in_ys: z ∈ set_of(ys') {
                      have y_z: y ≤ z
                        by apply all_elements_member[Nat][ys'][z,λb{y ≤ b}]
                           to y_le_ys, z_in_ys
                      have x_y: x ≤ y by rewrite xy_true.
                      conclude x ≤ z
                          by apply less_equal_trans[x][y,z] to x_y, y_z
                    }
                  }
              }
              case false suppose xy_false {
                have not_x_y: not (x ≤ y)
                  by suppose xs rewrite xy_false in xs
                have y_x: y ≤ x
                  by apply less_implies_less_equal[y][x] to
                     (apply not_less_equal_greater[x,y] to not_x_y)
                definition sorted
                suffices sorted(merge(n',node(x,xs'),ys')) and
                         all_elements(merge(n',node(x,xs'),ys'),λb{y ≤ b})
                IH_xxs_ys, 
                conclude all_elements(merge(n',node(x,xs'),ys'),λb{y ≤ b}) by
                  rewrite all_elements_eq_member
                     [Nat,merge(n',node(x,xs'),ys'),λb{y ≤ b}]
                  rewrite symmetric len_xxs_ys
                  rewrite set_of_merge[node(x,xs'),ys']
                  arbitrary z:Nat
                  suppose z_in_xxs_ys: z ∈ set_of(node(x,xs')) ∪ set_of(ys')
                  suffices y ≤ z
                  cases apply member_union
                               [Nat,z,set_of(node(x,xs')),set_of(ys')]
                        to z_in_xxs_ys
                  case z_in_xxs: z ∈ set_of(node(x,xs')) {
                    have z_in_sx_or_xs: z ∈ single(x) or z ∈ set_of(xs')
                      by apply member_union[Nat,z,single(x),set_of(xs')]
                         to definition set_of in z_in_xxs
                    cases z_in_sx_or_xs
                    case z_in_sx: z ∈ single(x) {
                      have x_z: x = z
                          by definition {operator ∈, single, rep} in z_in_sx
                      conclude y ≤ z  by rewrite symmetric x_z  y_x
                    }
                    case z_in_xs: z ∈ set_of(xs') {
                      have x_z: x ≤ z
                        by apply all_elements_member[Nat][xs'][z,λb{x ≤ b}]
                           to x_le_xs, z_in_xs
                      conclude y ≤ z 
                         by apply less_equal_trans[y][x,z] to y_x, x_z
                    }
                  }
                  case z_in_ys: z ∈ set_of(ys') {
                    conclude y ≤ z by
                      apply all_elements_member[Nat][ys'][z,λb{y ≤ b}]
                      to y_le_ys, z_in_ys
                  }
              }
            }
          }
        }
      }
    }
  }
end

Prove correctness of msort

First we show that the two lists produced by msort contain the same elements as the input list.

theorem mset_of_msort: all n:Nat. all xs:List<Nat>.
  mset_of(first(msort(n, xs)))  ⨄  mset_of(second(msort(n, xs))) = mset_of(xs)
proof
  induction Nat
  case 0 {
    arbitrary xs:List<Nat>
    definition msort
    switch xs {
      case empty {
        definition {first, second}
        suffices mset_of(empty) ⨄ mset_of(empty) = mset_of(empty)
        definition {mset_of}
        rewrite m_sum_empty[Nat,m_fun(λx{0})].
      }
      case node(x, xs') {
        definition {first, second, mset_of}
        suffices m_one(x) ⨄ mset_of(empty) ⨄ mset_of(xs')
               = m_one(x) ⨄ mset_of(xs')
        definition {mset_of}
        rewrite m_sum_empty[Nat,m_one(x)].
      }
    }
  }
  case suc(n') suppose IH {
    arbitrary xs:List<Nat>
    definition {msort, first, second}
    
    let ys = first(msort(n',xs))
    let ls = second(msort(n',xs))
    rewrite have first(msort(n',xs)) = ys  by definition ys.
    rewrite have second(msort(n',xs)) = ls  by definition ls.
    
    let zs = first(msort(n', ls))
    let ms = second(msort(n', ls))
    rewrite have first(msort(n', ls)) = zs by definition zs.
    rewrite have second(msort(n', ls)) = ms by definition ms.

    equations
          mset_of(merge(length(ys) + length(zs),ys,zs)) ⨄ mset_of(ms)
        = (mset_of(ys) ⨄ mset_of(zs)) ⨄ mset_of(ms)
          by rewrite (apply mset_of_merge[length(ys) + length(zs)][ys,zs] to .).
    ... = mset_of(ys) ⨄ (mset_of(zs) ⨄ mset_of(ms))
          by rewrite m_sum_assoc[Nat, mset_of(ys), mset_of(zs), mset_of(ms)].
    ... = mset_of(ys) ⨄ mset_of(ls)
          by rewrite have mset_of(zs) ⨄ mset_of(ms) = mset_of(ls)
                     by definition {zs, ms} IH[ls].
    ... = mset_of(xs)
          by definition {ys, ls} IH[xs]
  }
end

Next, we prove that the first output list is sorted. We make use of the merge_sorted theorem in this proof.

theorem msort_sorted: all n:Nat. all xs:List<Nat>. 
  sorted(first(msort(n, xs)))
proof
  induction Nat
  case 0 {
    arbitrary xs:List<Nat>
    switch xs {
      case empty {
        definition {msort, first}
        conclude sorted(empty)  by definition sorted.
      }
      case node(x, xs') {
        definition {msort, first}
        conclude sorted(node(x,empty))
            by definition {sorted, sorted, all_elements}.
      }
    }
  }
  case suc(n') suppose IH {
    arbitrary xs:List<Nat>
    let ys = first(msort(n',xs))
    let zs = first(msort(n',second(msort(n',xs))))
    have IH1: sorted(ys)  by definition ys IH[xs]
    have IH2: sorted(zs)  by definition zs IH[second(msort(n',xs))]
    definition {msort, first}
    definition {ys, zs} in
    apply merge_sorted[length(ys) + length(zs)][ys, zs] to IH1, IH2
  }
end

It remains to show that the first output of msort has length min(2ⁿ, length(xs)). Instead of using min, I separated the proof into two cases depending on whether 2ⁿ ≤ length(xs). However, I first needed to prove that the lengths of the two output lists add up to the length of the input list.

theorem msort_length: all n:Nat. all xs:List<Nat>.
  length(first(msort(n, xs)))  +  length(second(msort(n, xs))) = length(xs)

The proof of msort_length required a theorem that the length of the output of merge is the sum of the lengths of the inputs.

theorem merge_length: all n:Nat. all xs:List<Nat>, ys:List<Nat>.
  if length(xs) + length(ys) = n
  then length(merge(n, xs, ys)) = n

So in the case where the length of the input list is at least 2ⁿ, the first output of msort has length 2ⁿ.

theorem msort_length_less_equal: all n:Nat. all xs:List<Nat>.
  if pow2(n) ≤ length(xs)
  then length(first( msort(n, xs) )) = pow2(n)
proof
  induction Nat
  case 0 {
    arbitrary xs:List<Nat>
    suppose prem
    switch xs {
      case empty suppose xs_mt {
        conclude false
            by definition {pow2, length, operator≤} in
               rewrite xs_mt in prem
      }
      case node(x, xs') suppose xs_xxs {
        definition {msort,first}
        conclude length(node(x,empty)) = pow2(0)
            by definition {length, length, pow2}.
      }
    }
  }
  case suc(n') suppose IH {
    arbitrary xs:List<Nat>
    suppose prem
    have len_xs: pow2(n') + pow2(n') ≤ length(xs)
      by rewrite add_zero[pow2(n')] in
         definition {pow2, operator*, operator*,operator*} in prem
    definition {pow2, msort, first}

    let ys = first(msort(n',xs))
    let ls = second(msort(n',xs))
    have ys_def: first(msort(n',xs)) = ys  by definition ys.
    have ls_def: second(msort(n',xs)) = ls  by definition ls.
    rewrite ys_def | ls_def
    
    let zs = first(msort(n', ls))
    let ms = second(msort(n', ls))
    have zs_def: first(msort(n', ls)) = zs by definition zs.
    have ms_def: second(msort(n', ls)) = ms by definition ms.
    rewrite zs_def | ms_def

    have p2n_le_xs: pow2(n') ≤ length(xs)
      by have p2n_le_2p2n: pow2(n') ≤ pow2(n') + pow2(n')
           by less_equal_add[pow2(n')][pow2(n')]
         apply less_equal_trans[pow2(n')][pow2(n') + pow2(n'), length(xs)]
         to p2n_le_2p2n, len_xs

    have len_ys: length(ys) = pow2(n')
      by rewrite ys_def in apply IH[xs] to p2n_le_xs
      
    have len_ys_ls_eq_xs: length(ys) + length(ls) = length(xs)
      by rewrite ys_def | ls_def in msort_length[n'][xs]

    have p2n_le_ls: pow2(n') ≤ length(ls)
      by have pp_pl: pow2(n') + pow2(n') ≤ pow2(n') + length(ls)
           by rewrite symmetric len_ys_ls_eq_xs | len_ys in len_xs
         apply less_equal_left_cancel[pow2(n')][pow2(n'), length(ls)] to pp_pl
            
    have len_zs: length(zs) = pow2(n')
      by rewrite zs_def in apply IH[ls] to p2n_le_ls

    have len_ys_zs: length(ys) + length(zs) = 2 * pow2(n')
      by rewrite len_ys | len_zs
         definition {operator*,operator*,operator*}
         rewrite add_zero[pow2(n')].

    conclude length(merge(length(ys) + length(zs),ys,zs)) = 2 * pow2(n')
      by rewrite len_ys_zs
         apply merge_length[2 * pow2(n')][ys, zs] to len_ys_zs
  }
end

When the length of the input list is less than 2ⁿ, the length of the first output is the same as the length of the input.

theorem msort_length_less: all n:Nat. all xs:List<Nat>.
  if length(xs) < pow2(n)
  then length(first( msort(n, xs) )) = length(xs)
proof
  induction Nat
  case 0 {
    arbitrary xs:List<Nat>
    suppose prem
    switch xs {
      case empty suppose xs_mt {
        definition {msort, length, first}.
      }
      case node(x, xs') suppose xs_xxs {
        definition {msort,first, length, length}
        have xs_0: length(xs') = 0
          by definition {operator ≤, length, operator<, pow2} in 
             rewrite xs_xxs in prem
        rewrite xs_0.
      }
    }
  }
  case suc(n') suppose IH {
    arbitrary xs:List<Nat>
    suppose prem
    definition{msort, first}

    let ys = first(msort(n',xs))
    let ls = second(msort(n',xs))
    have ys_def: first(msort(n',xs)) = ys  by definition ys.
    have ls_def: second(msort(n',xs)) = ls  by definition ls.
    rewrite ys_def | ls_def
    
    let zs = first(msort(n', ls))
    let ms = second(msort(n', ls))
    have zs_def: first(msort(n', ls)) = zs by definition zs.
    have ms_def: second(msort(n', ls)) = ms by definition ms.
    rewrite zs_def | ms_def

    have xs_le_two_p2n: length(xs) < pow2(n') + pow2(n')
      by rewrite add_zero[pow2(n')] in
         definition {pow2, operator*,operator*,operator*} in prem

    have ys_ls_eq_xs: length(ys) + length(ls) = length(xs)
      by rewrite ys_def | ls_def in msort_length[n'][xs]

    have pn_xs_or_xs_pn: pow2(n') ≤ length(xs) or length(xs) < pow2(n')
      by dichotomy[pow2(n'), length(xs)]
    cases pn_xs_or_xs_pn
    case pn_xs: pow2(n') ≤ length(xs) {
    
      have ys_pn: length(ys) = pow2(n')
          by rewrite ys_def in apply msort_length_less_equal[n'][xs] to pn_xs

      have ls_l_pn: length(ls) < pow2(n')
          by have pn_ls_l_2pn: pow2(n') + length(ls) < pow2(n') + pow2(n')
               by rewrite symmetric ys_ls_eq_xs | ys_pn in xs_le_two_p2n
             apply less_left_cancel[pow2(n'), length(ls), pow2(n')] to pn_ls_l_2pn

      have len_zs: length(zs) = length(ls)
          by rewrite zs_def in apply IH[ls] to ls_l_pn

      equations
        length(merge(length(ys) + length(zs),ys,zs))
            = length(ys) + length(zs)
              by apply merge_length[length(ys) + length(zs)][ys,zs] to .
        ... = length(ys) + length(ls)
              by rewrite len_zs.
        ... = length(xs)
              by ys_ls_eq_xs
    }
    case xs_pn: length(xs) < pow2(n') {
    
      have len_ys: length(ys) = length(xs)
        by rewrite ys_def in apply IH[xs] to xs_pn

      have len_ls: length(ls) = 0
        by apply left_cancel[length(ys)][length(ls), 0] to
           rewrite add_zero[length(ys)] | len_ys
           rewrite len_ys in ys_ls_eq_xs

      have ls_l_pn: length(ls) < pow2(n')
        by rewrite len_ls  pow_positive[n'] 
      
      have len_zs: length(zs) = 0
        by rewrite zs_def | len_ls in apply IH[ls] to ls_l_pn

      equations
        length(merge(length(ys) + length(zs),ys,zs))
          = length(ys) + length(zs)
            by apply merge_length[length(ys) + length(zs)][ys, zs] to .
      ... = length(xs)
            by rewrite len_zs | add_zero[length(ys)] | len_ys.
    }
  }
end

Prove correctness of merge_sort

The proof that merge_sort produces a sorted list is a straightforward corollary of the msort_sorted theorem.

theorem merge_sort_sorted: all xs:List<Nat>.
  sorted(merge_sort(xs))
proof
  arbitrary xs:List<Nat>
  definition merge_sort
  msort_sorted[log(length(xs))][xs]
end

The proof that the contents of the output of merge_sort are the same as the input is a bit more involved. After expanding the definition of merge_sort, we need to show that

mset_of(first(msort(log(length(xs)),xs))) = mset_of(xs)

which means we need to show that all the elements in xs end up in the first output and that there are not any leftovers. Let ys be the first output of msort and ls be the leftovers. The theorem less_equal_pow_log in Log.pf tells us that length(xs) ≤ pow2(log(length(xs))). So in the case where they are equal, we can use the msort_length_less_equal theorem to show that length(ys) = length(xs). In the case where length(xs) is strictly smaller, we use the msort_length_less theorem to prove that length(ys) = length(xs). Finally, we show that the length of ls is zero by use of msort_length and some properties of arithmetic like left_cancel (in Nat.pf).

Here is the proof of mset_of_merge_sort in full.

theorem mset_of_merge_sort: all xs:List<Nat>.
  mset_of(merge_sort(xs)) = mset_of(xs)
proof
  arbitrary xs:List<Nat>
  definition merge_sort
  let n = log(length(xs))
  have n_def: log(length(xs)) = n  by definition n.
  let ys = first(msort(n,xs))
  have ys_def: first(msort(n,xs)) = ys  by definition ys.
  let ls = second(msort(n,xs))
  have ls_def: second(msort(n,xs)) = ls  by definition ls.

  have len_xs: length(xs) ≤ pow2(n)
    by rewrite symmetric n_def
       less_equal_pow_log[length(xs)]
  have len_ys: length(ys) = length(xs)
    by cases apply less_equal_implies_less_or_equal[length(xs)][pow2(n)]
             to len_xs
       case len_xs_less {
         rewrite ys_def in apply msort_length_less[n][xs] to len_xs_less
       }
       case len_xs_equal {
         have pn_le_xs: pow2(n) ≤ length(xs)
           by rewrite len_xs_equal  less_equal_refl[pow2(n)]
         have len_ys_pow2: length(ys) = pow2(n)
           by rewrite symmetric ys_def
              apply msort_length_less_equal[n][xs] to pn_le_xs
         transitive len_ys_pow2 (symmetric len_xs_equal)
       }
  have len_ys_ls_eq_xs: length(ys) + length(ls) = length(xs)
    by rewrite ys_def | ls_def in msort_length[n][xs]
  have len_ls: length(ls) = 0
    by apply left_cancel[length(ys)][length(ls), 0] to
       rewrite add_zero[length(ys)] | len_ys
       rewrite len_ys in len_ys_ls_eq_xs
  have ls_mt: ls = empty
    by apply length_zero_empty[Nat, ls] to len_ls

  have ys_ls_eq_xs: mset_of(ys)  ⨄  mset_of(ls) = mset_of(xs)
    by rewrite ys_def | ls_def in mset_of_msort[n][xs]

  rewrite n_def
  rewrite ys_def
  equations
    mset_of(ys)
        = mset_of(ys)  ⨄  m_fun(λx{0})
          by rewrite m_sum_empty[Nat, mset_of(ys)].
    ... = mset_of(ys)  ⨄  mset_of(ls)
          by rewrite ls_mt definition mset_of.
    ... = mset_of(xs)
          by ys_ls_eq_xs
end

Exercise: merge_length and msort_length

Prove the following theorems.

theorem merge_length: all n:Nat. all xs:List<Nat>, ys:List<Nat>.
  if length(xs) + length(ys) = n
  then length(merge(n, xs, ys)) = n

theorem msort_length: all n:Nat. all xs:List<Nat>.
  length(first(msort(n, xs)))  +  length(second(msort(n, xs))) = length(xs)

Exercise: classic Merge Sort

Test and prove the correctness of the classic definition of merge_sort, which we repeat here.

function msort(Nat, List<Nat>) -> List<Nat> {
  msort(0, xs) = xs
  msort(suc(n'), xs) =
    let p = split(xs)
    merge(msort(n', first(p)), msort(n', second(p)))
}

define merge_sort : fn List<Nat> -> List<Nat>
  = λxs{ msort(log(length(xs)), xs) }

You will need to define the split function.
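
In case it helps, here is one possible sketch of split (one of several reasonable definitions, not necessarily the one you should settle on): it deals the elements of the input alternately into the two output lists, so the recursive call is on the sublist xs' and Deduce's termination restriction is satisfied.

function split(List<Nat>) -> Pair< List<Nat>, List<Nat> > {
  split(empty) = pair(empty, empty)
  split(node(x, xs')) =
    let p = split(xs')
    /* put x in the first output and swap the two outputs,
       so the elements alternate between them */
    pair(node(x, second(p)), first(p))
}

With this definition the two halves interleave the elements of the input rather than preserving a prefix and a suffix, which is fine for sorting because merge does not care how the elements were distributed.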