Sunday, June 30, 2024

Merge Sort with Leftovers, Correctly

Merge Sort with Leftovers

This is the fourth blog post in a series about developing correct implementations of basic data structures and algorithms using the Deduce language and proof checker.

In this blog post we study a fast sorting algorithm, Merge Sort. This classic algorithm splits the input list in half, recursively sorts each half, and then merges the two results back into a single sorted list.

The specification of Merge Sort is the same as for Insertion Sort.

Specification: The merge_sort(xs) function returns a list that contains the same elements as xs but the elements in the result are in sorted order.

We follow the write-test-prove approach to develop a correct implementation of merge_sort.

Write the merge_sort function

The classic implementation of merge_sort would be something like the following.

function merge_sort(List<Nat>) -> List<Nat> {
  merge_sort(empty) = empty
  merge_sort(node(x,xs')) =
    define p = split(node(x,xs'))
    merge(merge_sort(first(p)), merge_sort(second(p)))
}

Unfortunately, Deduce rejects the above function definition because Deduce uses a very simple restriction to ensure the termination of recursive functions: a recursive call may only be made on a part of the input. In this case, the recursive call may only be applied to the sublist xs', not first(p) or second(p).

How can we work around this restriction? There’s an old trick that goes by many names (gas, fuel, etc.), which is to add another parameter of type Nat and use that for termination. Let us use the name msort for the following, and then we define merge_sort in terms of msort.

function msort(Nat, List<Nat>) -> List<Nat> {
  msort(0, xs) = xs
  msort(suc(n'), xs) =
    define p = split(xs)
    merge(msort(n', first(p)), msort(n', second(p)))
}

define merge_sort : fn List<Nat> -> List<Nat>
  = λxs{ msort(log(length(xs)), xs) }

In the above definition of merge_sort, we need to supply enough gas so that msort won’t prematurely run out. Here we use the logarithm (base 2, rounding up) defined in Log.pf.
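To see why the logarithm is enough gas, here is a small Python model of the gas-based version. It is only a sketch for experimentation, not Deduce code: the helpers split, merge, and ceil_log2 are our own stand-ins, and we assume that Deduce's split and log behave similarly (in particular, ceil_log2 maps 0 to 0 here, which may differ from Log.pf).

```python
def split(xs):
    """Split a list into two halves; the first half gets the extra element."""
    mid = (len(xs) + 1) // 2
    return xs[:mid], xs[mid:]


def merge(xs, ys):
    """Merge two sorted lists into one sorted list."""
    out, i, j = [], 0, 0
    while i < len(xs) and j < len(ys):
        if xs[i] <= ys[j]:
            out.append(xs[i])
            i += 1
        else:
            out.append(ys[j])
            j += 1
    return out + xs[i:] + ys[j:]


def msort(gas, xs):
    """Gas-limited merge sort: with no gas left, return the list unchanged."""
    if gas == 0:
        return xs
    left, right = split(xs)
    return merge(msort(gas - 1, left), msort(gas - 1, right))


def ceil_log2(n):
    """Base-2 logarithm rounded up (assumption: 0 is mapped to 0)."""
    return (n - 1).bit_length() if n > 0 else 0


def merge_sort(xs):
    """Supply just enough gas for the recursion to reach lists of length <= 1."""
    return msort(ceil_log2(len(xs)), xs)
```

Each unit of gas halves the list, so after ceil_log2(len(xs)) levels every sublist has at most one element and is trivially sorted. With too little gas the result is wrong: in this model msort(1, [3, 2, 1, 0]) returns [1, 0, 3, 2], because the two unsorted halves are merged as if they were sorted.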

This definition of merge_sort and msort is fine: it has O(n log(n)) time complexity, so it is efficient. However, the use of split rubs me the wrong way because it requires traversing half of the input list. The use of split would be necessary if one wanted to use parallelism to speed up the code, performing the two recursive calls in parallel. However, we are currently only interested in a single-threaded implementation.

Suppose you just finished baking a pie and intend to eat half now and half tomorrow night. One approach would be to split it in half and then eat one of the halves. Another approach is to just start eating the pie and stop when half of it is gone. That’s the approach that we will take with the next version of msort.

Specification: The msort(n,xs) function sorts the first min(2ⁿ,length(xs)) many elements of xs and returns a pair containing (1) the sorted list and (2) the leftovers that were not yet sorted.

function msort(Nat, List<Nat>) -> Pair< List<Nat>, List<Nat> > {
  msort(0, xs) =
    switch xs {
      case empty { pair(empty, empty) }
      case node(x, xs') { pair(node(x, empty), xs') }
    }
  msort(suc(n'), xs) =
    define p1 = msort(n', xs)
    define p2 = msort(n', second(p1))
    define ys = first(p1)
    define zs = first(p2)
    pair(merge(length(ys) + length(zs), ys, zs), second(p2))
}

In the above case for suc(n'), the first recursive call to msort produces the pair p1 that includes a sorted list and the leftovers. We sort the leftovers with the second recursive call to msort. We return (1) the merge of the two sorted sublists and (2) the leftovers from the second recursive call to msort.
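The flow of leftovers is easy to trace in a small Python model. This is a sketch under assumptions, not the Deduce code: it uses heapq.merge from the Python standard library as a stand-in for merge, and Python list slices copy their data, whereas the Deduce version simply hands back the tail of the linked list.

```python
import heapq


def msort(n, xs):
    """Sort the first min(2**n, len(xs)) elements of xs.

    Returns a pair (sorted_prefix, leftovers).
    """
    if n == 0:
        # Take at most one element; everything else is left over.
        return xs[:1], xs[1:]
    ys, rest = msort(n - 1, xs)          # sort the first chunk
    zs, leftovers = msort(n - 1, rest)   # sort the next chunk from the leftovers
    return list(heapq.merge(ys, zs)), leftovers


print(msort(1, [3, 1, 2]))  # ([1, 3], [2])
print(msort(2, [3, 1, 2]))  # ([1, 2, 3], [])
```

With n = 1 only the first two elements are sorted and 2 is left over; with n = 2 the second recursive call picks up the leftover 2 and nothing remains.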

With the code for msort complete, we can turn to merge_sort. Similar to the previous version, we invoke msort with the input list xs and use the logarithm of list length for the gas. This msort returns a pair, with the sorted results in the first component. The second component of the pair is an empty list because we supplied enough gas.

define merge_sort : fn List<Nat> -> List<Nat>
    = λxs{ first(msort(log(length(xs)), xs)) }

So far, we have neglected the implementation of merge. Here’s its specification.

Specification: The merge(xs,ys) function takes two sorted lists and returns a sorted list that contains just the elements from the two input lists.

Here’s the classic implementation of merge. The idea is to compare the two elements at the front of each list and use the lower of the two as the first element of the output. Then do the recursive call with the two lists, minus the element that was chosen. Again, we use an extra gas parameter to ensure termination. To ensure that we have enough gas, we will choose the sum of the lengths of the two input lists.

function merge(Nat, List<Nat>, List<Nat>) -> List<Nat> {
  merge(0, xs, ys) = empty
  merge(suc(n), xs, ys) =
    switch xs {
      case empty { ys }
      case node(x, xs') {
        switch ys {
          case empty {
            node(x, xs')
          }
          case node(y, ys') {
            if x ≤ y then
              node(x, merge(n, xs', node(y, ys')))
            else
              node(y, merge(n, node(x, xs'), ys'))
          }
        }
      }
    }
}
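The role of the gas parameter in merge can be explored with the following Python model. It is a sketch that mirrors the clauses above one for one; it is not Deduce code, and the list slicing makes it slower than the linked-list original.

```python
def merge(gas, xs, ys):
    """Gas-limited merge of two sorted lists (a model of the Deduce clauses)."""
    if gas == 0:
        return []                      # out of gas: stop producing output
    if not xs:
        return ys
    if not ys:
        return xs
    if xs[0] <= ys[0]:
        return [xs[0]] + merge(gas - 1, xs[1:], ys)
    return [ys[0]] + merge(gas - 1, xs, ys[1:])


print(merge(8, [1, 3, 3, 7], [2, 3, 4, 8]))  # [1, 2, 3, 3, 3, 4, 7, 8]
print(merge(2, [1, 3], [2, 4]))              # [1, 2]
```

The second call shows what goes wrong with too little gas: the output is cut short and elements are lost. That is why the correctness theorems for merge need a premise relating the gas to the lengths of the two inputs.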

Test

We have three functions to test: merge, msort, and merge_sort.

Test merge

We test that the result of merge is sorted and that it contains all the elements from the two input lists, which we check using count.

define L_1337 = node(1, node(3, node(3, node(7, empty))))
define L_2348 = node(2, node(3, node(4, node(8, empty))))
define L_12333478 = merge(length(L_1337) + length(L_2348), L_1337, L_2348)
assert sorted(L_12333478)
assert all_elements(L_1337 ++ L_2348,
  λx{count(L_1337)(x) + count(L_2348)(x) = count(L_12333478)(x) })

Test msort

In the following tests, we vary the gas from 0 to 3, varying how much of the input list L18 gets sorted in the call to msort. The take(n,xs) function returns the first n elements of xs and drop(n,xs) drops the first n elements of xs and returns the remaining portion of xs.

define L18 = L_1337 ++ L_2348

define p0 = msort(0, L18)
define t0 = take(pow2(0), L18)
define d0 = drop(pow2(0), L18)
assert sorted(first(p0))
assert all_elements(t0, λx{count(t0)(x) = count(first(p0))(x) })
assert all_elements(d0, λx{count(d0)(x) = count(second(p0))(x) })

define p1 = msort(1, L18)
define t1 = take(pow2(1), L18)
define d1 = drop(pow2(1), L18)
assert sorted(first(p1))
assert all_elements(t1, λx{count(t1)(x) = count(first(p1))(x) })
assert all_elements(d1, λx{count(d1)(x) = count(second(p1))(x) })

define p2 = msort(2, L18)
define t2 = take(pow2(2), L18)
define d2 = drop(pow2(2), L18)
assert sorted(first(p2))
assert all_elements(t2, λx{count(t2)(x) = count(first(p2))(x) })
assert all_elements(d2, λx{count(d2)(x) = count(second(p2))(x) })

define p3 = msort(3, L18)
define t3 = take(pow2(3), L18)
define d3 = drop(pow2(3), L18)
assert sorted(first(p3))
assert all_elements(t3, λx{count(t3)(x) = count(first(p3))(x) })
assert all_elements(d3, λx{count(d3)(x) = count(second(p3))(x) })

Test merge_sort

Next we test that merge_sort returns a sorted list that contains the same elements as the input list. For input, we reuse the list L18 from above.

define s_L18 = merge_sort(L18)
assert sorted(s_L18)
assert all_elements(L18, λx{count(L18)(x) = count(s_L18)(x) })

We can bundle several tests, with varying-length inputs, into one assert by using all_elements and interval.

assert all_elements(interval(3, 0),
    λn{ define xs = reverse(interval(n, 0))
        define ls = merge_sort(xs)
        sorted(ls) and
        all_elements(xs, λx{count(xs)(x) = count(ls)(x)})
    })

Prove

Compared to the proof of correctness for insertion_sort, we have considerably more work to do for merge_sort. Instead of two functions, we have three functions to consider: merge, msort, and merge_sort. Furthermore, these functions are more complex than insert and insertion_sort. Nevertheless, we are up to the challenge!

Prove correctness of merge

The specification of merge has two parts: one part says that the elements of the output must be the elements of the two input lists, and the other part says that the output must be sorted, provided the two input lists are sorted.

Here is how we state the theorem for the first part.

theorem mset_of_merge: all n:Nat. all xs:List<Nat>, ys:List<Nat>.
  if length(xs) + length(ys) = n
  then mset_of(merge(n, xs, ys)) = mset_of(xs) ⨄ mset_of(ys)

Here is the theorem stating that the output of merge is sorted.

theorem merge_sorted: all n:Nat. all xs:List<Nat>, ys:List<Nat>.
  if sorted(xs) and sorted(ys) and
     length(xs) + length(ys) = n
  then sorted(merge(n, xs, ys))
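Before investing in the proofs, it can be reassuring to check both statements exhaustively on small inputs. The following Python sketch does that for a model of merge; the names is_sorted and check_merge_theorems are our own, collections.Counter plays the role of mset_of, and passing the check is of course no substitute for the proofs.

```python
from collections import Counter
from itertools import product


def merge(gas, xs, ys):
    """Gas-limited merge, modeled on the Deduce definition."""
    if gas == 0:
        return []
    if not xs:
        return ys
    if not ys:
        return xs
    if xs[0] <= ys[0]:
        return [xs[0]] + merge(gas - 1, xs[1:], ys)
    return [ys[0]] + merge(gas - 1, xs, ys[1:])


def is_sorted(xs):
    """True when every element is <= its successor."""
    return all(a <= b for a, b in zip(xs, xs[1:]))


def check_merge_theorems(max_len=3, values=(0, 1, 2)):
    """Check both merge properties on all small lists; return the pair count."""
    lists = [list(t) for k in range(max_len + 1)
             for t in product(values, repeat=k)]
    checked = 0
    for xs, ys in product(lists, repeat=2):
        out = merge(len(xs) + len(ys), xs, ys)
        # Model of mset_of_merge: the output has the elements of both inputs.
        assert Counter(out) == Counter(xs) + Counter(ys)
        # Model of merge_sorted: sorted inputs give a sorted output.
        if is_sorted(xs) and is_sorted(ys):
            assert is_sorted(out)
        checked += 1
    return checked


check_merge_theorems()
```

Note that the multiset property is checked for all inputs, sorted or not, which matches the theorem: mset_of_merge has no sortedness premise.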

Prove the mset_of_merge theorem

We begin with the proof of mset_of_merge. Because merge(n, xs, ys) is recursive on the natural number n, we proceed by induction on Nat.

  induction Nat
  case 0 {
    arbitrary xs:List<Nat>, ys:List<Nat>
    suppose prem: length(xs) + length(ys) = 0
    ?
  }
  case suc(n') suppose IH {
    ?
  }

In the case for n = 0, we need to prove

  mset_of(merge(0,xs,ys)) = mset_of(xs) ⨄ mset_of(ys)

and merge(0,xs,ys) returns empty, so we need to show that mset_of(xs) ⨄ mset_of(ys) is the empty multiset. From the premise prem, both xs and ys must be empty.

  // <<mset_of_merge_case_zero_xs_ys_empty>> =
  have lxs_lys_z: length(xs) = 0 and length(ys) = 0
    by apply add_to_zero[length(xs)][length(ys)] to prem
  have xs_mt: xs = empty
    by apply length_zero_empty[Nat,xs] to lxs_lys_z
  have ys_mt: ys = empty
    by apply length_zero_empty[Nat,ys] to lxs_lys_z

After rewriting with those equalities and applying the definition of merge and mset_of:

  // <<mset_of_merge_case_zero_suffices>> =
  suffices mset_of(merge(0, empty, empty)) = mset_of(empty) ⨄ mset_of(empty)
      with rewrite xs_mt | ys_mt
  suffices m_fun[Nat](λ{0}) = m_fun[Nat](λ{0}) ⨄ m_fun[Nat](λ{0})
      with definition {merge, mset_of}

it remains to prove m_fun(λ{0}) = m_fun(λ{0}) ⨄ m_fun(λ{0}) (the sum of two empty multisets is the empty multiset), which we prove with the theorem m_sum_empty from MultiSet.pf.

  // <<mset_of_merge_case_zero_conclusion>> =
  symmetric m_sum_empty[Nat, m_fun[Nat](λx{0})]

In the case for n = suc(n'), we need to prove

  mset_of(merge(suc(n'),xs,ys)) = mset_of(xs) ⨄ mset_of(ys)

Looking at the suc clause of merge, there is a switch on xs and then on ys. So our proof will be structured analogously.

  switch xs for merge {
    case empty {
      ?
    }
    case node(x, xs') suppose xs_xxs {
      ?
    }
  }

In the case for xs = empty, we conclude by the definition of mset_of and the fact that combining mset_of(ys) with the empty multiset produces mset_of(ys).

    // <<mset_of_merge_case_suc_empty>> =
    suffices mset_of(ys) = m_fun(λx{0}) ⨄ mset_of(ys)
        with definition {mset_of}
    symmetric empty_m_sum[Nat, mset_of(ys)]

In the case for xs = node(x, xs'), merge performs a switch on ys, so our proof does too.

  switch ys for merge {
    case empty {
      ?
    }
    case node(y, ys') suppose ys_yys {
      ?
    }
  }

The case for ys = empty is similar to the case for xs = empty. We conclude by use of the definitions of merge and mset_of and the fact that combining mset_of(node(x, xs')) with the empty multiset produces mset_of(node(x, xs')).

    // <<mset_of_merge_case_suc_node_empty>> =
    suffices m_one(x) ⨄ mset_of(xs')
           = (m_one(x) ⨄ mset_of(xs')) ⨄ m_fun(λ{0})
        with definition {mset_of}
    rewrite m_sum_empty[Nat, m_one(x) ⨄ mset_of(xs')]

In the case for ys = node(y, ys'), we continue to follow the structure of merge and switch on x ≤ y.

  switch x ≤ y {
    case true suppose xy_true {
      ?
    }
    case false suppose xy_false {
      ?
    }
  }

In the case for (x ≤ y) = true, the goal becomes

  mset_of(node(x, merge(n', xs', node(y, ys')))) 
= mset_of(node(x, xs')) ⨄ mset_of(node(y, ys'))

Which follows from the conclusion of the induction hypothesis (instantiated with xs' and node(y,ys')):

  mset_of(merge(n',xs',node(y,ys')))
= mset_of(xs') ⨄ mset_of(node(y, ys'))

The induction hypothesis is a conditional, so we first must prove its premise as follows.

    // <<mset_of_merge_x_le_y_IH_prem>> =
    have IH_prem: length(xs') + length(node(y,ys')) = n'
      by enable {operator +, operator +,length}
         have suc_len: suc(length(xs')) + suc(length(ys')) = suc(n')
                by rewrite xs_xxs | ys_yys in prem
         injective suc suc_len

We conclude this case with the following equational reasoning, using the induction hypothesis in the second step.

    // <<mset_of_merge_x_le_y_equations>> =
    equations
          mset_of(node(x, merge(n', xs', node(y, ys')))) 
        = m_one(x) ⨄ mset_of(merge(n',xs',node(y,ys')))
            by definition mset_of
    ... = m_one(x) ⨄ (mset_of(xs') ⨄ mset_of(node(y, ys')))
            by rewrite (apply IH[xs', node(y, ys')] to IH_prem)
    ... = m_one(x) ⨄ (mset_of(xs') ⨄ (m_one(y) ⨄ mset_of(ys')))
            by definition mset_of
    ... = (m_one(x) ⨄ mset_of(xs')) ⨄ (m_one(y) ⨄ mset_of(ys'))
            by rewrite m_sum_assoc[Nat, m_one(x), mset_of(xs'),
                                  (m_one(y) ⨄ mset_of(ys'))]
    ... = mset_of(node(x, xs')) ⨄ mset_of(node(y, ys'))
            by definition mset_of

In the case for (x ≤ y) = false, the goal becomes

  mset_of(node(y, merge(n', node(x, xs'), ys'))) 
= mset_of(node(x, xs')) ⨄ mset_of(node(y, ys'))

The conclusion of the induction hypothesis (instantiated with node(x,xs') and ys') is

  mset_of(merge(n',node(x,xs'),ys'))
= mset_of(node(x,xs')) ⨄ mset_of(ys')

so the goal will follow from the fact that multiset sum is associative and commutative.

We first prove the premise of the induction hypothesis.

    // <<mset_of_merge_x_g_y_IH_prem>> =
    have IH_prem: length(node(x,xs')) + length(ys') = n'
      by enable {operator +, operator +, length}
         have suc_len: suc(length(xs')) + suc(length(ys')) = suc(n')
              by rewrite xs_xxs | ys_yys in prem
         injective suc
         rewrite add_suc[length(xs')][length(ys')] in suc_len

Then we proceed with applying the induction hypothesis in the second step, followed by the equational reasoning about multiset sum.

    // <<mset_of_merge_x_g_y_equations>> =
    equations
            mset_of(node(y, merge(n', node(x, xs'), ys')))
          = m_one(y) ⨄ mset_of(merge(n',node(x,xs'),ys'))
              by definition mset_of
      ... = m_one(y) ⨄ (mset_of(node(x,xs')) ⨄ mset_of(ys'))
              by rewrite (apply IH[node(x,xs'), ys'] to IH_prem)
      ... = m_one(y) ⨄ ((m_one(x) ⨄ mset_of(xs')) ⨄ mset_of(ys'))
              by definition mset_of
      ... = ((m_one(x) ⨄ mset_of(xs')) ⨄ mset_of(ys')) ⨄ m_one(y)
              by m_sum_commutes[Nat, m_one(y), (m_one(x) ⨄ mset_of(xs')) ⨄ mset_of(ys')]
      ... = (m_one(x) ⨄ mset_of(xs')) ⨄ (mset_of(ys') ⨄ m_one(y))
              by m_sum_assoc[Nat, m_one(x) ⨄ mset_of(xs'), mset_of(ys'), m_one(y)]
      ... = (m_one(x) ⨄ mset_of(xs')) ⨄ (m_one(y) ⨄ mset_of(ys'))
              by rewrite m_sum_commutes[Nat, mset_of(ys'), m_one(y)]
      ... = mset_of(node(x, xs')) ⨄ mset_of(node(y, ys'))
              by definition mset_of

Here is the completed proof of mset_of_merge.

theorem mset_of_merge: all n:Nat. all xs:List<Nat>, ys:List<Nat>.
  if length(xs) + length(ys) = n
  then mset_of(merge(n, xs, ys)) = mset_of(xs) ⨄ mset_of(ys)
proof
  induction Nat
  case 0 {
    arbitrary xs:List<Nat>, ys:List<Nat>
    suppose prem: length(xs) + length(ys) = 0
    <<mset_of_merge_case_zero_xs_ys_empty>>
    <<mset_of_merge_case_zero_suffices>>
    <<mset_of_merge_case_zero_conclusion>>
  }
  case suc(n') suppose IH {
    arbitrary xs:List<Nat>, ys:List<Nat>
    suppose prem: length(xs) + length(ys) = suc(n')
    switch xs for merge {
      case empty {
        <<mset_of_merge_case_suc_empty>>
      }
      case node(x, xs') suppose xs_xxs {
        switch ys for merge {
          case empty {
            <<mset_of_merge_case_suc_node_empty>>
          }
          case node(y, ys') suppose ys_yys {
            switch x ≤ y {
              case true suppose xy_true {
                <<mset_of_merge_x_le_y_IH_prem>>
                <<mset_of_merge_x_le_y_equations>>
              }
              case false suppose xy_false {
                <<mset_of_merge_x_g_y_IH_prem>>
                <<mset_of_merge_x_g_y_equations>>
              }
            }
          }
        }
      }
    }
  }
end

The mset_of_merge theorem also holds for sets, using the set_of function. We prove the following set_of_merge theorem as a corollary of mset_of_merge.

theorem set_of_merge: all xs:List<Nat>, ys:List<Nat>.
  set_of(merge(length(xs) + length(ys), xs, ys)) = set_of(xs) ∪ set_of(ys)
proof
  arbitrary xs:List<Nat>, ys:List<Nat>
  equations
    set_of(merge(length(xs) + length(ys), xs, ys))
        = set_of_mset(mset_of(merge(length(xs) + length(ys), xs, ys)))
            by symmetric som_mset_eq_set[Nat]
                             [merge(length(xs) + length(ys), xs, ys)]
    ... = set_of_mset(mset_of(xs) ⨄ mset_of(ys))
            by rewrite mset_of_merge[length(xs) + length(ys)][xs, ys]
    ... = set_of_mset(mset_of(xs)) ∪ set_of_mset(mset_of(ys))
            by som_union[Nat, mset_of(xs), mset_of(ys)]
    ... = set_of(xs) ∪ set_of(ys)
            by rewrite som_mset_eq_set[Nat][xs] | som_mset_eq_set[Nat][ys]
end

Prove the merge_sorted theorem

Next up is the merge_sorted theorem.

theorem merge_sorted: all n:Nat. all xs:List<Nat>, ys:List<Nat>.
  if sorted(xs) and sorted(ys) and length(xs) + length(ys) = n
  then sorted(merge(n, xs, ys))

The structure of the proof will be similar to the one for mset_of_merge, because they both follow the structure of merge. So we begin with induction on Nat.

In the case for n = 0, we need to prove sorted(merge(0, xs, ys)). But merge(0, xs, ys) = empty, and sorted(empty) is trivially true. So we conclude the case for n = 0 as follows.

    // <<merge_sorted_case_zero>> =
    arbitrary xs:List<Nat>, ys:List<Nat>
    suppose _
    suffices sorted(empty)  with definition merge
    definition sorted

We move on to the case for n = suc(n') and xs = empty. Here merge returns ys, and we already know that ys is sorted from the premise.

    // <<merge_sorted_case_suc_empty>> =
    conclude sorted(ys) by prem

In the case for xs = node(x, xs') and ys = empty, the merge function returns node(x, xs') (that is, xs), and we already know that xs is sorted from the premise.

    // <<merge_sorted_case_suc_node_empty>> =
    conclude sorted(node(x,xs'))  by rewrite xs_xxs in prem

In the case for ys = node(y, ys') and (x ≤ y) = true, the merge function returns node(x, merge(n',xs',node(y,ys'))). So we need to prove the following.

    // <<merge_sorted_less_equal_suffices>> =
    suffices sorted(merge(n',xs',node(y,ys'))) and
             all_elements(merge(n',xs',node(y,ys')), λb{x ≤ b}) 
                 with definition sorted

To prove the first, we invoke the induction hypothesis instantiated with xs' and node(y,ys') as follows.

    // <<merge_sorted_IH_xs_yys>> =
    have s_xs: sorted(xs')
      by definition sorted in rewrite xs_xxs in prem
    have s_yys: sorted(node(y,ys'))
      by rewrite ys_yys in prem
    have len_xs_yys: length(xs') + length(node(y,ys')) = n'
      by enable {operator +, operator +, length}
         have sxs: suc(length(xs')) + suc(length(ys')) = suc(n')
            by rewrite xs_xxs | ys_yys in prem
         injective suc sxs
    have IH_xs_yys: sorted(merge(n',xs',node(y,ys')))
      by apply IH[xs',node(y,ys')] to s_xs, s_yys, len_xs_yys

It remains to prove that x is less-or-equal to all the elements in the rest of the output list:

  all_elements(merge(n',xs',node(y,ys')),λb{x ≤ b})

The theorem all_elements_eq_member in List.pf says

  all_elements(xs,P) = (all x:T. if x ∈ set_of(xs) then P(x))

which combined with the set_of_merge corollary above, simplifies our goal as follows

    // <<x_le_merge_suffices>> =
    suffices (all z:Nat. (if z ∈ set_of(xs') ∪ set_of(node(y, ys')) then x ≤ z))
        with rewrite all_elements_eq_member[Nat, merge(n', xs', node(y,ys')),
                                            λb{x ≤ b}]
                   | symmetric len_xs_yys | set_of_merge[xs',node(y,ys')]
    arbitrary z:Nat
    suppose z_in_xs_yys: z ∈ set_of(xs') ∪ set_of(node(y,ys'))

So we have a few cases to consider and need to prove x ≤ z in each one. Consider the case where z ∈ set_of(xs'). Because node(x, xs') is sorted, we know x is less-or-equal to every element of xs':

  // <<merge_sorted_x_le_xs>> =
  have x_le_xs: all_elements(xs', λb{x ≤ b})
    by definition sorted in rewrite xs_xxs in prem

so x is less-or-equal to z, being one of the elements in xs'.

  // <<merge_sorted_z_in_xs>> =
  conclude x ≤ z by
    apply all_elements_member[Nat][xs'][z, λb{x ≤ b}]
    to x_le_xs, z_in_xs

Next, consider the case where z ∈ single(y) and therefore y = z. Then we can immediately conclude because x ≤ y.

    // <<merge_sorted_z_in_y>> =
    have y_z: y = z   by definition {operator ∈, single, rep} in z_sy
    conclude x ≤ z    by rewrite symmetric y_z | xy_true

Finally, consider when z ∈ set_of(ys'). Because node(y,ys') is sorted, we know that y is less-or-equal to all elements of ys'.

    // <<merge_sorted_y_le_ys>> =
    have y_le_ys: all_elements(ys', λb{y ≤ b})
      by definition sorted in rewrite ys_yys in prem

Therefore we have y ≤ z. Combined with x ≤ y, we conclude that x ≤ z by transitivity.

    // <<merge_sorted_z_in_ys>> =
    have y_z: y ≤ z
      by apply all_elements_member[Nat][ys'][z,λb{y ≤ b}]
         to y_le_ys, z_in_ys
    have x_y: x ≤ y by rewrite xy_true
    conclude x ≤ z
        by apply less_equal_trans[x][y,z] to x_y, y_z

The last case to consider is for ys = node(y, ys') and (x ≤ y) = false. The reasoning is similar to the case for (x ≤ y) = true, so we skip the detailed explanation and refer the reader to the proof below.

Here’s the completed proof of merge_sorted.

theorem merge_sorted: all n:Nat. all xs:List<Nat>, ys:List<Nat>.
  if sorted(xs) and sorted(ys) and length(xs) + length(ys) = n
  then sorted(merge(n, xs, ys))
proof
  induction Nat
  case 0 {
    <<merge_sorted_case_zero>>
  }
  case suc(n') suppose IH {
    arbitrary xs:List<Nat>, ys:List<Nat>
    suppose prem
    switch xs for merge {
      case empty {
        <<merge_sorted_case_suc_empty>>
      }
      case node(x, xs') suppose xs_xxs {
        switch ys for merge {
          case empty {
            <<merge_sorted_case_suc_node_empty>>
          }
          case node(y, ys') suppose ys_yys {
            <<merge_sorted_IH_xs_yys>>
            <<merge_sorted_x_le_xs>>
            <<merge_sorted_y_le_ys>>
            switch x ≤ y {
              case true suppose xy_true {
                <<merge_sorted_less_equal_suffices>>
                have x_le_merge: all_elements(merge(n',xs',node(y,ys')),λb{x ≤ b}) by
                    <<x_le_merge_suffices>>
                    suffices x ≤ z  by .
                    cases apply member_union[Nat] to z_in_xs_yys
                    case z_in_xs: z ∈ set_of(xs') {
                      <<merge_sorted_z_in_xs>>
                    }
                    case z_in_ys: z ∈ set_of(node(y,ys')) {
                      cases apply member_union[Nat] to definition set_of in z_in_ys
                      case z_sy: z ∈ single(y) {
                        <<merge_sorted_z_in_y>>
                      }
                      case z_in_ys: z ∈ set_of(ys') {
                        <<merge_sorted_z_in_ys>>
                      }
                    }
                IH_xs_yys, x_le_merge
              }
              case false suppose xy_false {
              
                /* Apply the induction hypothesis
                 * to prove sorted(merge(n',node(x,xs'),ys'))
                 */
                have len_xxs_ys: length(node(x,xs')) + length(ys') = n'
                  by enable {operator +, operator +, length}
                     have suc_len: suc(length(xs') + suc(length(ys'))) = suc(n')
                       by rewrite xs_xxs | ys_yys in prem
                     injective suc
                     rewrite add_suc[length(xs')][length(ys')] in suc_len
                have s_xxs: sorted(node(x, xs'))
                  by enable sorted rewrite xs_xxs in prem
                have s_ys: sorted(ys')
                  by definition sorted in rewrite ys_yys in prem
                have IH_xxs_ys: sorted(merge(n',node(x,xs'),ys'))
                  by apply IH[node(x,xs'),ys'] to s_xxs, s_ys, len_xxs_ys

                have not_x_y: not (x ≤ y)
                  by suppose xs rewrite xy_false in xs
                have y_x: y ≤ x
                  by apply less_implies_less_equal[y][x] to
                     (apply not_less_equal_greater[x,y] to not_x_y)
                suffices sorted(merge(n',node(x,xs'),ys')) and
                         all_elements(merge(n',node(x,xs'),ys'),λb{y ≤ b}) 
                    with definition sorted
                have y_le_merge: all_elements(merge(n',node(x,xs'),ys'),λb{y ≤ b}) by
                    suffices (all z:Nat. (if z ∈ set_of(node(x, xs')) ∪ set_of(ys') then y ≤ z))
                        with rewrite all_elements_eq_member[Nat,merge(n',node(x,xs'),ys'),λb{y ≤ b}]
                                   | symmetric len_xxs_ys | set_of_merge[node(x,xs'),ys']
                    arbitrary z:Nat
                    suppose z_in_xxs_ys: z ∈ set_of(node(x,xs')) ∪ set_of(ys')
                    suffices y ≤ z  by.
                    cases apply member_union[Nat] to z_in_xxs_ys
                    case z_in_xxs: z ∈ set_of(node(x,xs')) {
                      have z_in_sx_or_xs: z ∈ single(x) or z ∈ set_of(xs')
                        by apply member_union[Nat] to definition set_of in z_in_xxs
                      cases z_in_sx_or_xs
                      case z_in_sx: z ∈ single(x) {
                        have x_z: x = z  by definition {operator ∈, single, rep} in z_in_sx
                        conclude y ≤ z  by rewrite x_z in y_x
                      }
                      case z_in_xs: z ∈ set_of(xs') {
                        have x_z: x ≤ z
                          by apply all_elements_member[Nat][xs'][z,λb{x ≤ b}]
                             to x_le_xs, z_in_xs
                        conclude y ≤ z 
                           by apply less_equal_trans[y][x,z] to y_x, x_z
                      }
                    }
                    case z_in_ys: z ∈ set_of(ys') {
                      conclude y ≤ z by
                        apply all_elements_member[Nat][ys'][z,λb{y ≤ b}]
                        to y_le_ys, z_in_ys
                    }
                IH_xxs_ys, y_le_merge
              }
            }
          }
        }
      }
    }
  }
end

Prove correctness of msort

First we show that the two lists produced by msort contain the same elements as the input list.

theorem mset_of_msort: all n:Nat. all xs:List<Nat>.
  mset_of(first(msort(n, xs)))  ⨄  mset_of(second(msort(n, xs))) = mset_of(xs)
proof
  induction Nat
  case 0 {
    arbitrary xs:List<Nat>
    switch xs for msort {
      case empty {
        suffices m_fun[Nat](λ{0}) ⨄ m_fun[Nat](λ{0}) = m_fun[Nat](λ{0})
            with definition {first, second, mset_of}
        rewrite m_sum_empty[Nat,m_fun(λx{0})]
      }
      case node(x, xs') {
        suffices (m_one(x) ⨄ m_fun[Nat](λ{0})) ⨄ mset_of(xs')
               = m_one(x) ⨄ mset_of(xs')
            with definition {first, second, mset_of, mset_of}
        rewrite m_sum_empty[Nat,m_one(x)]
      }
    }
  }
  case suc(n') suppose IH {
    arbitrary xs:List<Nat>
    suffices mset_of(merge(length(first(msort(n', xs))) 
                           + length(first(msort(n', second(msort(n', xs))))),
                           first(msort(n', xs)),
                           first(msort(n', second(msort(n', xs)))))) 
             ⨄ mset_of(second(msort(n', second(msort(n', xs))))) 
             = mset_of(xs)
        with definition {msort, first, second}
    define ys = first(msort(n',xs))
    define ls = second(msort(n',xs))
    define zs = first(msort(n', ls))
    define ms = second(msort(n', ls))
    equations
          mset_of(merge(length(ys) + length(zs),ys,zs)) ⨄ mset_of(ms)
        = (mset_of(ys) ⨄ mset_of(zs)) ⨄ mset_of(ms)
          by rewrite (mset_of_merge[length(ys) + length(zs)][ys,zs])
    ... = mset_of(ys) ⨄ (mset_of(zs) ⨄ mset_of(ms))
          by rewrite m_sum_assoc[Nat, mset_of(ys), mset_of(zs), mset_of(ms)]
    ... = mset_of(ys) ⨄ mset_of(ls)
          by rewrite conclude mset_of(zs) ⨄ mset_of(ms) = mset_of(ls)
                     by enable {zs, ms} IH[ls]
    ... = mset_of(xs)
          by enable {ys, ls} IH[xs]
  }
end

Next, we prove that the first output list is sorted. We make use of the merge_sorted theorem in this proof.

theorem msort_sorted: all n:Nat. all xs:List<Nat>. 
  sorted(first(msort(n, xs)))
proof
  induction Nat
  case 0 {
    arbitrary xs:List<Nat>
    switch xs {
      case empty {
        suffices sorted(empty)  with definition {msort, first}
        definition sorted
      }
      case node(x, xs') {
        suffices sorted(node(x,empty))
            with definition {msort, first}
        definition {sorted, sorted, all_elements}
      }
    }
  }
  case suc(n') suppose IH {
    arbitrary xs:List<Nat>
    suffices sorted(merge(length(first(msort(n', xs))) 
                          + length(first(msort(n', second(msort(n', xs))))), 
                          first(msort(n', xs)), 
                          first(msort(n', second(msort(n', xs))))))
        with definition {msort, first}
    define ys = first(msort(n', xs))
    define ls = second(msort(n', xs))
    define zs = first(msort(n', ls))
    have IH1: sorted(ys)  by enable {ys}  IH[xs]
    have IH2: sorted(zs)  by enable {zs}  IH[ls]
    conclude sorted(merge(length(ys) + length(zs), ys, zs))
      by apply merge_sorted[length(ys) + length(zs)][ys, zs] to IH1, IH2
  }
end

It remains to show that the first output of msort is of length min(2ⁿ,length(xs)). Instead of using min, I separated the proof into a couple of cases depending on whether 2ⁿ ≤ length(xs). However, I first needed to prove that the lengths of the two output lists add up to the length of the input list.

theorem msort_length: all n:Nat. all xs:List<Nat>.
  length(first(msort(n, xs)))  +  length(second(msort(n, xs))) = length(xs)

The proof of msort_length required a theorem that the length of the output of merge is the sum of the lengths of the inputs.

theorem merge_length: all n:Nat. all xs:List<Nat>, ys:List<Nat>.
  if length(xs) + length(ys) = n
  then length(merge(n, xs, ys)) = n
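The length claims can be sanity-checked on a Python model of msort before proving them. The sketch below is an assumption-laden stand-in rather than Deduce code: heapq.merge from the standard library replaces merge, and ceil_log2 is our own helper for the rounded-up base-2 logarithm (mapping 0 to 0).

```python
import heapq


def msort(n, xs):
    """Model of msort: returns (sorted prefix of length min(2**n, len(xs)), leftovers)."""
    if n == 0:
        return xs[:1], xs[1:]
    ys, rest = msort(n - 1, xs)
    zs, leftovers = msort(n - 1, rest)
    return list(heapq.merge(ys, zs)), leftovers


def ceil_log2(k):
    """Base-2 logarithm rounded up (assumption: 0 is mapped to 0)."""
    return (k - 1).bit_length() if k > 0 else 0


def check_lengths(max_len=12, max_gas=5):
    """Check the length properties on reversed lists of every small length."""
    for length in range(max_len):
        xs = list(range(length, 0, -1))
        for n in range(max_gas):
            first, second = msort(n, xs)
            # Model of msort_length: nothing is lost or duplicated.
            assert len(first) + len(second) == len(xs)
            # The combined claim: the first output has min(2**n, len(xs)) elements.
            assert len(first) == min(2 ** n, len(xs))
        # With gas ceil_log2(len(xs)) there are no leftovers.
        first, second = msort(ceil_log2(length), xs)
        assert second == [] and first == sorted(xs)
    return True


check_lengths()
```

The last assertion is the reason merge_sort may simply take the first component of the pair: once 2ⁿ is at least the length of the input, the leftovers are empty.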

So in the case when the length of the input list is at least 2ⁿ, the first output of msort is of length 2ⁿ.

theorem msort_length_less_equal: all n:Nat. all xs:List<Nat>.
  if pow2(n) ≤ length(xs)
  then length(first( msort(n, xs) )) = pow2(n)
proof
  induction Nat
  case 0 {
    arbitrary xs:List<Nat>
    suppose prem
    switch xs {
      case empty suppose xs_mt {
        conclude false
            by definition {pow2, length, operator≤} in
               rewrite xs_mt in prem
      }
      case node(x, xs') suppose xs_xxs {
        suffices length(node(x,empty)) = pow2(0)
            with definition {msort,first}
        definition {length, length, pow2, operator+, operator+}
      }
    }
  }
  case suc(n') suppose IH {
    arbitrary xs:List<Nat>
    suppose prem
    have len_xs: pow2(n') + pow2(n') ≤ length(xs)
      by rewrite add_zero[pow2(n')] in
         definition {pow2, operator*, operator*,operator*} in prem
    suffices length(merge(length(first(msort(n', xs))) 
                            + length(first(msort(n', second(msort(n', xs))))), 
                          first(msort(n', xs)), 
                          first(msort(n', second(msort(n', xs))))))
             = 2 * pow2(n')
        with definition {pow2, msort, first}
    define ys = first(msort(n',xs))
    define ls = second(msort(n',xs))
    define zs = first(msort(n', ls))
    define ms = second(msort(n', ls))
    have len_ys: length(ys) = pow2(n') by {
         have p2n_le_xs: pow2(n') ≤ length(xs) by
             have p2n_le_2p2n: pow2(n') ≤ pow2(n') + pow2(n') by
                 less_equal_add[pow2(n')][pow2(n')]
             apply less_equal_trans[pow2(n')][pow2(n') + pow2(n'), length(xs)]
             to p2n_le_2p2n, len_xs
         enable {ys} 
         apply IH[xs] to p2n_le_xs
    }
    have len_zs: length(zs) = pow2(n') by {
         have len_ys_ls_eq_xs: length(ys) + length(ls) = length(xs)
           by enable {ys, ls} msort_length[n'][xs]
         have p2n_le_ls: pow2(n') ≤ length(ls)
           by have pp_pl: pow2(n') + pow2(n') ≤ pow2(n') + length(ls)
                by rewrite symmetric len_ys_ls_eq_xs | len_ys in len_xs
              apply less_equal_left_cancel[pow2(n')][pow2(n'), length(ls)] to pp_pl
         enable {zs} 
         apply IH[ls] to p2n_le_ls
    }
    have len_ys_zs: length(ys) + length(zs) = 2 * pow2(n') by {
      equations
        length(ys) + length(zs) 
          = pow2(n') + pow2(n')    by rewrite len_ys | len_zs
      ... = 2 * pow2(n')           by symmetric two_mult[pow2(n')]
    }
    equations
          length(merge(length(ys) + length(zs), ys, zs))
        = length(merge(2 * pow2(n'), ys, zs))   by rewrite len_ys_zs
    ... = 2 * pow2(n')                          by apply merge_length[2 * pow2(n')][ys, zs] to len_ys_zs
  }
end

When the length of the input list is less than 2ⁿ, the length of the first output is the same as the length of the input.

theorem msort_length_less: all n:Nat. all xs:List<Nat>.
  if length(xs) < pow2(n)
  then length(first( msort(n, xs) )) = length(xs)
proof
  induction Nat
  case 0 {
    arbitrary xs:List<Nat>
    suppose prem
    switch xs {
      case empty suppose xs_mt {
        definition {msort, length, first}
      }
      case node(x, xs') suppose xs_xxs {
        suffices 1 + 0 = 1 + length(xs')
            with definition {msort, first, length, length}
        have xs_0: length(xs') = 0
            by definition {operator ≤, length, operator+, operator+, operator<, 
                           pow2, operator ≤, operator ≤} in 
               rewrite xs_xxs in prem
        rewrite xs_0
      }
    }
  }
  case suc(n') suppose IH {
    arbitrary xs:List<Nat>
    suppose prem
    suffices length(merge(length(first(msort(n', xs))) 
                          + length(first(msort(n', second(msort(n', xs))))), 
                          first(msort(n', xs)), 
                          first(msort(n', second(msort(n', xs))))))
             = length(xs)
        with definition{msort, first}
    define ys = first(msort(n',xs))
    define ls = second(msort(n',xs))
    define zs = first(msort(n', ls))
    define ms = second(msort(n', ls))

    have xs_le_two_p2n: length(xs) < pow2(n') + pow2(n')
      by rewrite add_zero[pow2(n')] in
         definition {pow2, operator*,operator*,operator*} in prem

    have ys_ls_eq_xs: length(ys) + length(ls) = length(xs)
      by enable {ys,ls} msort_length[n'][xs]

    have pn_xs_or_xs_pn: pow2(n') ≤ length(xs) or length(xs) < pow2(n')
      by dichotomy[pow2(n'), length(xs)]
    cases pn_xs_or_xs_pn
    case pn_xs: pow2(n') ≤ length(xs) {
    
      have ys_pn: length(ys) = pow2(n')
          by enable {ys} apply msort_length_less_equal[n'][xs] to pn_xs

      have ls_l_pn: length(ls) < pow2(n')
          by have pn_ls_l_2pn: pow2(n') + length(ls) < pow2(n') + pow2(n')
               by rewrite symmetric ys_ls_eq_xs | ys_pn in xs_le_two_p2n
             apply less_left_cancel[pow2(n'), length(ls), pow2(n')] to pn_ls_l_2pn

      have len_zs: length(zs) = length(ls)
          by enable {zs} apply IH[ls] to ls_l_pn

      equations
        length(merge(length(ys) + length(zs),ys,zs))
            = length(ys) + length(zs)
              by merge_length[length(ys) + length(zs)][ys,zs]
        ... = length(ys) + length(ls)
              by rewrite len_zs
        ... = length(xs)
              by ys_ls_eq_xs
    }
    case xs_pn: length(xs) < pow2(n') {
    
      have len_ys: length(ys) = length(xs)
        by enable {ys} apply IH[xs] to xs_pn

      have len_ls: length(ls) = 0
        by apply left_cancel[length(ys)][length(ls), 0] to
           suffices length(ys) + length(ls) = length(ys)
               by rewrite add_zero[length(ys)]
           rewrite symmetric len_ys in ys_ls_eq_xs

      have ls_l_pn: length(ls) < pow2(n')
        by rewrite symmetric len_ls in pow_positive[n'] 
      
      have len_zs: length(zs) = 0
        by enable {zs} rewrite len_ls in apply IH[ls] to ls_l_pn

      equations
        length(merge(length(ys) + length(zs),ys,zs))
          = length(ys) + length(zs)
            by merge_length[length(ys) + length(zs)][ys, zs]
      ... = length(xs)
            by rewrite len_zs | add_zero[length(ys)] | len_ys
    }
  }
end
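
The two cases of the length property can be checked on a small input. This is only a sketch with a hypothetical five-element list, assuming msort, first, length, and pow2 are in scope as used in the theorems above.

```deduce
// Hypothetical test list; not part of the original development.
define list_54321 = node(5, node(4, node(3, node(2, node(1, empty)))))
// pow2(2) = 4 ≤ 5, so the first output has exactly pow2(2) elements
// (an instance of msort_length_less_equal).
assert length(first(msort(2, list_54321))) = pow2(2)
// 5 < pow2(3) = 8, so the whole input ends up in the first output
// (an instance of msort_length_less).
assert length(first(msort(3, list_54321))) = length(list_54321)
```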

Prove correctness of merge_sort

The proof that merge_sort produces a sorted list is a straightforward corollary of the msort_sorted theorem.

theorem merge_sort_sorted: all xs:List<Nat>.
  sorted(merge_sort(xs))
proof
  arbitrary xs:List<Nat>
  suffices sorted(first(msort(log(length(xs)), xs)))
      with definition merge_sort
  msort_sorted[log(length(xs))][xs]
end

The proof that the contents of the output of merge_sort are the same as the input is a bit more involved. If we unfold the definition of merge_sort, we then need to show that

mset_of(first(msort(log(length(xs)),xs))) = mset_of(xs)

which means we need to show that all the elements in xs end up in the first output and that there are not any leftovers. Let ys be the first output of msort and ls be the leftovers. The theorem less_equal_pow_log in Log.pf tells us that length(xs) ≤ pow2(log(length(xs))). So in the case where they are equal, we can use the msort_length_less_equal theorem to show that length(ys) = length(xs). In the case where length(xs) is strictly smaller, we use the msort_length_less theorem to prove that length(ys) = length(xs). Finally, we show that the length of ls is zero by use of msort_length and some properties of arithmetic like left_cancel (in Nat.pf).

Here is the proof of mset_of_merge_sort in full.

theorem mset_of_merge_sort: all xs:List<Nat>.
  mset_of(merge_sort(xs)) = mset_of(xs)
proof
  arbitrary xs:List<Nat>
  suffices mset_of(first(msort(log(length(xs)), xs))) = mset_of(xs)
      with definition merge_sort
  define n = log(length(xs))
  define ys = first(msort(n,xs))
  define ls = second(msort(n,xs))

  have len_xs: length(xs) ≤ pow2(n)
    by enable {n} less_equal_pow_log[length(xs)]
  have len_ys: length(ys) = length(xs)
    by cases apply less_equal_implies_less_or_equal[length(xs)][pow2(n)]
             to len_xs
       case len_xs_less: length(xs) < pow2(n) {
         enable {ys} apply msort_length_less[n][xs] to len_xs_less
       }
       case len_xs_equal: length(xs) = pow2(n) {
         have pn_le_xs: pow2(n) ≤ length(xs)
           by apply equal_implies_less_equal to symmetric len_xs_equal
         have len_ys_pow2: length(ys) = pow2(n)
           by enable {ys} apply msort_length_less_equal[n][xs] to pn_le_xs
         transitive len_ys_pow2 (symmetric len_xs_equal)
       }
  have len_ys_ls_eq_xs: length(ys) + length(ls) = length(xs)
    by enable {ys, ls} msort_length[n][xs]
  have len_ls: length(ls) = 0
    by apply left_cancel[length(ys)][length(ls), 0] to
       suffices length(ys) + length(ls) = length(ys)
           with rewrite add_zero[length(ys)]
       rewrite symmetric len_ys in len_ys_ls_eq_xs
  have ls_mt: ls = empty
    by apply length_zero_empty[Nat, ls] to len_ls

  have ys_ls_eq_xs: mset_of(ys)  ⨄  mset_of(ls) = mset_of(xs)
    by enable {ys, ls} mset_of_msort[n][xs]

  equations
          mset_of(ys)
        = mset_of(ys)  ⨄  m_fun(λx{0})   by rewrite m_sum_empty[Nat,mset_of(ys)]
    ... = mset_of(ys)  ⨄  mset_of(empty) by definition mset_of
    ... = mset_of(ys)  ⨄  mset_of(ls)    by rewrite ls_mt 
    ... = mset_of(xs)                    by ys_ls_eq_xs
end
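
As a complement to the two correctness theorems, here is a small test in the style of the earlier posts. It is a sketch under assumptions: the list name is hypothetical, count is the function defined in the Insertion Sort post, and merge_sort, msort, log, sorted, and all_elements are assumed to be in scope.

```deduce
// Hypothetical test list; not part of the original development.
define list_50418 = node(5, node(0, node(4, node(1, node(8, empty)))))
define sorted_50418 = merge_sort(list_50418)
// The output is sorted (merge_sort_sorted).
assert sorted(sorted_50418)
// Every element occurs the same number of times (mset_of_merge_sort).
assert all_elements(list_50418,
  λx{ count(sorted_50418)(x) = count(list_50418)(x) })
// With log(length(xs)) units of fuel there are no leftovers.
assert length(second(msort(log(length(list_50418)), list_50418))) = 0
```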

Exercise: merge_length and msort_length

Prove the following theorems.

theorem merge_length: all n:Nat. all xs:List<Nat>, ys:List<Nat>.
  if length(xs) + length(ys) = n
  then length(merge(n, xs, ys)) = n

theorem msort_length: all n:Nat. all xs:List<Nat>.
  length(first(msort(n, xs)))  +  length(second(msort(n, xs))) = length(xs)

Exercise: classic Merge Sort

Test and prove the correctness of the classic definition of merge_sort, which we repeat here.

function msort(Nat, List<Nat>) -> List<Nat> {
  msort(0, xs) = xs
  msort(suc(n'), xs) =
    define p = split(xs)
    merge(msort(n', first(p)), msort(n', second(p)))
}

define merge_sort : fn List<Nat> -> List<Nat>
  = λxs{ msort(log(length(xs)), xs) }

You will need to define the split function.

Monday, June 17, 2024

Insertion Sort, Correctly

Insertion Sort

This is the third blog post in a series about developing correct implementations of basic data structures and algorithms using the Deduce language and proof checker.

In this blog post we study a simple but slow sorting algorithm, Insertion Sort. (We will study the faster Merge Sort in the next blog post.) Insertion Sort is, roughly speaking, how many people sort the cards in their hand when playing a card game. The basic idea is to take one card at a time and place it into the correct location amongst the cards that you’ve already sorted.

Specification: The insertion_sort(xs) function returns a list that contains the same elements as xs but the elements in the result are in sorted order.

Of course, to make this specification precise, we need to define "sorted". There are several ways to go with this formal definition. Here is one that works well for me. It requires each element in the list to be less-or-equal to all the elements that come after it.

function sorted(List<Nat>) -> bool {
  sorted(empty) = true
  sorted(node(x, xs)) =
    sorted(xs) and all_elements(xs, λy{ x ≤ y })
}
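
To make the definition concrete, here is a small sketch with a hypothetical list, assuming sorted and all_elements are in scope.

```deduce
// Hypothetical example list; not part of the original development.
define list_1225 = node(1, node(2, node(2, node(5, empty))))
// Every element is ≤ everything after it; duplicates are allowed.
assert sorted(list_1225)
// By contrast, node(2, node(1, empty)) is not sorted, because
// all_elements(node(1, empty), λy{ 2 ≤ y }) is false.
```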

We follow the write-test-prove approach to develop a correct implementation of insertion_sort. We then propose an exercise for the reader.

Write the insertion_sort function

Because insertion_sort operates on the recursive type List, we’ll try to implement insertion_sort as a recursive function.

function insertion_sort(List<Nat>) -> List<Nat> {
  insertion_sort(empty) = ?
  insertion_sort(node(x, xs')) = ?
}

In the case for the empty list, we need to return a list with the same contents, so we had better return empty.

function insertion_sort(List<Nat>) -> List<Nat> {
  insertion_sort(empty) = empty
  insertion_sort(node(x, xs')) = ?
}

In the case for node(x, xs'), we can make the recursive call insertion_sort(xs') to sort the rest of the list.

function insertion_sort(List<Nat>) -> List<Nat> {
  insertion_sort(empty) = empty
  insertion_sort(node(x, xs')) = ... insertion_sort(xs') ...
}

But what do we do with the element x? This is where we need to define an auxiliary function that inserts x into the appropriate location within the result of sorting the rest of the list. We’ll choose the name insert for this auxiliary function. Here is the completed code for insertion_sort.

function insertion_sort(List<Nat>) -> List<Nat> {
  insertion_sort(empty) = empty
  insertion_sort(node(x, xs')) = insert(insertion_sort(xs'), x)
}

Of course, we’ll follow the write-test-prove approach to develop the insert function. The first thing we need to do is write down the specification. The specification of insert will play an important role in the proof of correctness of insertion_sort, because we’ll use the correctness theorems about insert in the proof of the correctness theorems about insertion_sort. With this in mind, here’s a specification for insert.

Specification: The insert(xs, y) function takes a sorted list xs and value y as input and returns a sorted list that contains y and the elements of xs.

Next we write the code for insert. This function also has a List as input, so we define yet another recursive function.

function insert(List<Nat>,Nat) -> List<Nat> {
  insert(empty, y) = ?
  insert(node(x, xs), y) = ?
}

For the case empty, we must return a list that contains y, so it must be node(y, empty).

function insert(List<Nat>,Nat) -> List<Nat> {
  insert(empty, y) = node(y, empty)
  insert(node(x, xs), y) = ?
}

In the case for node(x, xs'), we need to check whether y should come before x. So we test y ≤ x and if that’s the case, we return node(y, node(x, xs')). Otherwise, y belongs somewhere later in the sequence, so we make the recursive call and return node(x, insert(xs', y)).

function insert(List<Nat>,Nat) -> List<Nat> {
  insert(empty, y) = node(y, empty)
  insert(node(x, xs'), y) =
    if y ≤ x then
      node(y, node(x, xs'))
    else
      node(x, insert(xs', y))
}

Test

This time we have two functions to test, insert and insertion_sort. We start with insert because if there are bugs in insert, then it will be confusing to find out about them when testing insertion_sort.

Test insert

Looking at the specification for insert, we need to check whether the resulting list is sorted and we need to check that the resulting list contains the elements from the input and the inserted element. We could use the search function that we developed in the previous blog post to check whether the elements from the input list are in the output list. However, doing that would ignore a subtle issue, which is that there can be one or more occurrences of the same element in the input list, and the output list should have the same number of occurrences. To take this into account, we need a new function to count the number of occurrences of an element.

function count<T>(List<T>) -> fn T->Nat {
  count(empty) = λy{ 0 }
  count(node(x, xs')) = λy{
    if y = x then 
      suc(count(xs')(y))
    else
      count(xs')(y) 
  }
}
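
A quick check of count on a list with a repeated element may help. This is a sketch with a hypothetical list name, assuming the count function above is in scope.

```deduce
// Hypothetical example list; not part of the original development.
define list_3837 = node(3, node(8, node(3, node(7, empty))))
assert count(list_3837)(3) = 2   // 3 occurs twice
assert count(list_3837)(8) = 1   // 8 occurs once
assert count(list_3837)(5) = 0   // 5 does not occur
```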

Here’s a test that checks whether insert produces a sorted list with the correct count for every element on the input list as well as the inserted element.

define list_1234 = node(1, node(2, node(3, node(4, empty))))
define list_12334 = insert(list_1234, 3)
assert sorted(list_12334)
assert all_elements(node(3, list_1234), λx{
  if x = 3 then
    count(list_12334)(x) = suc(count(list_1234)(x))
  else
    count(list_12334)(x) = count(list_1234)(x)
  })

It’s a good idea to test corner cases, that is, inputs that trigger different code paths through the insert function. As there is a case for the empty list in the code, that’s a good test case to consider.

define list_3 = insert(empty, 3)
assert sorted(list_3)
assert all_elements(node(3, empty), λx{
  if x = 3 then
    count(list_3)(x) = suc(count(empty : List<Nat>)(x))
  else
    count(list_3)(x) = count(empty : List<Nat>)(x)
  })
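
Two more code paths worth exercising are insertion at the very front and at the very end of a non-empty list. The following sketch uses hypothetical list names and assumes insert, sorted, count, and length are in scope.

```deduce
// Hypothetical test lists; not part of the original development.
define list_246 = node(2, node(4, node(6, empty)))
// 1 ≤ 2 holds at the head, so insert takes the first branch immediately.
define list_1246 = insert(list_246, 1)
assert sorted(list_1246)
assert count(list_1246)(1) = 1
// 7 ≤ x fails for every x, so insert recurs down to the empty case.
define list_2467 = insert(list_246, 7)
assert sorted(list_2467)
assert count(list_2467)(7) = 1
assert length(list_2467) = suc(length(list_246))
```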

Ideally we would also test with hundreds of randomly-generated lists. Adding support for random number generation is high on the TODO list for Deduce.

Test insertion_sort

If we refer back to the specification of insertion_sort, we need to check that the output list is sorted and that it contains the same elements as the input list.

define list_8373 = node(8, node(3, node(7, node(3, empty))))
define list_3378 = insertion_sort(list_8373)
assert sorted( list_3378 )
assert all_elements(list_8373, λx{count(list_3378)(x) = count(list_8373)(x)})
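
A reverse-sorted input forces every call to insert to traverse the whole list it is given, so it makes a useful extra test. This is a sketch with hypothetical names, assuming the definitions above are in scope.

```deduce
// Hypothetical test list; not part of the original development.
define list_4321 = node(4, node(3, node(2, node(1, empty))))
define list_4321_sorted = insertion_sort(list_4321)
assert sorted(list_4321_sorted)
assert all_elements(list_4321,
  λx{ count(list_4321_sorted)(x) = count(list_4321)(x) })
assert length(list_4321_sorted) = length(list_4321)
```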

Prove

The next step in the process is to prove the correctness of insert and insertion_sort with respect to their specification.

Prove correctness of insert

Our first task is to prove that insert(xs, y) produces a list that contains y and the elements of xs. In our tests, we used the count function to accomplish this. Note that count returns a function fn T->Nat, which is the same thing as a multiset. The file MultiSet.pf defines the MultiSet<T> type and operations on them such as m_one(x) for creating a multiset that only contains x and the operator A ⨄ B for combining two multisets. The file List.pf defines a function mset_of that converts a list into a multiset.

function mset_of<T>(List<T>) -> MultiSet<T> {
  mset_of(empty) = m_fun(λx{0})
  mset_of(node(x, xs)) = m_one(x) ⨄ mset_of(xs)
}

So we express the requirements on the contents of insert(xs, y) as follows: converting insert(xs, y) into a multiset is the same as converting xs into a multiset and then adding y. The proof is relatively straightforward, making use of several theorems about multisets from MultiSet.pf.

theorem insert_contents: all xs:List<Nat>. all y:Nat.
  mset_of(insert(xs,y)) = m_one(y) ⨄ mset_of(xs)
proof
  induction List<Nat>
  case empty {
    arbitrary y:Nat
    conclude mset_of(insert(empty,y)) = m_one(y) ⨄ mset_of(empty)
        by definition {insert, mset_of, mset_of}
  }
  case node(x, xs') suppose IH {
    arbitrary y:Nat
    switch y ≤ x for insert {
      case true suppose yx_true {
        conclude mset_of(node(y,node(x,xs'))) = m_one(y) ⨄ mset_of(node(x,xs'))
            by definition {mset_of, mset_of}
      }
      case false suppose yx_false {
        equations
              mset_of(node(x,insert(xs',y))) 
            = m_one(x) ⨄ mset_of(insert(xs',y))
              by definition mset_of
        ... = m_one(x) ⨄ (m_one(y) ⨄ mset_of(xs'))
              by rewrite IH[y]
        ... = (m_one(x) ⨄ m_one(y)) ⨄ mset_of(xs')
              by rewrite m_sum_assoc<Nat>[m_one(x),m_one(y),mset_of(xs')]
        ... = (m_one(y) ⨄ m_one(x)) ⨄ mset_of(xs')
              by rewrite m_sum_commutes<Nat>[ m_one(x), m_one(y)]
        ... = m_one(y) ⨄ (m_one(x) ⨄ mset_of(xs'))
              by rewrite m_sum_assoc<Nat>[m_one(y),m_one(x),mset_of(xs')]
        ... = m_one(y) ⨄ mset_of(node(x,xs'))
              by definition mset_of
      }
    }
  }
end

Our second task is to prove that insert produces a sorted list, assuming the input list is sorted.

theorem insert_sorted: all xs:List<Nat>. all y:Nat.
  if sorted(xs) then sorted(insert(xs, y))
proof
  ?
end

Because insert is a recursive function, we proceed by induction on its first argument xs.

  induction List<Nat>

The case for xs = empty is a straightforward use of definitions.

    // <<insert_sorted_case_empty>> =
    arbitrary y:Nat
    suppose _
    conclude sorted(insert(empty,y))
        by definition {insert, sorted, sorted, all_elements}

Here’s the beginning of the case for xs = node(x, xs').

  case node(x, xs') suppose IH {
    arbitrary y:Nat
    suppose s_xxs: sorted(node(x,xs'))
    suffices sorted(insert(node(x,xs'),y))  by .
    ?
  }

In the goal we see an opportunity to use the definition of insert. However, insert branches on whether y ≤ x, so we use a switch-for statement to do the same in our proof.

  switch y ≤ x for insert {
    case true suppose yx_true {
      ?
    }
    case false suppose yx_false {
      ?
    }
  }

In the case when y ≤ x is true, we apply the relevant definitions to arrive at the following four subgoals.

    // <<insert_sorted_case_node_less_defs>> =
    suffices sorted(xs') 
         and all_elements(xs', λb{x ≤ b}) 
         and y ≤ x
         and all_elements(xs', λb{y ≤ b})
             with definition {sorted, sorted, all_elements}

The first two of these follow from the premise sorted(node(x,xs')).

  // <<insert_sorted_case_node_s_xs__x_le_xs>> =
  have s_xs: sorted(xs') by definition sorted in s_xxs
  have x_le_xs': all_elements(xs',λb{(x ≤ b)}) by definition sorted in s_xxs

The third is true in the current case.

  // <<insert_sorted_y_le_x>> =
  have y_le_x: y ≤ x by rewrite yx_true

The fourth, which states that y is less-or-equal to all the elements in xs', follows transitively from y ≤ x and the fact that x is less-or-equal to all the elements in xs' (x_le_xs'), using the theorem all_elements_implies (in List.pf):

theorem all_elements_implies: 
  all T:type. all xs:List<T>. all P: fn T -> bool, Q: fn T -> bool.
  if all_elements(xs,P) and (all z:T. if P(z) then Q(z)) 
  then all_elements(xs,Q)

To satisfy the second premise of all_elements_implies, we use y ≤ x to prove that if x is less-or-equal to some element z, then so is y.

  // <<insert_sorted_x_le_implies_y_le>> =
  have x_le_implies_y_le: all z:Nat. (if x ≤ z then y ≤ z)
    by arbitrary z:Nat  suppose x_le_z: x ≤ z
       conclude y ≤ z by apply less_equal_trans[y][x,z] to y_le_x , x_le_z

Now we apply all_elements_implies to prove all_elements(xs',λb{(y ≤ b)}).

  // <<insert_sorted_y_le_xs>> =
  have y_le_xs': all_elements(xs',λb{(y ≤ b)})
    by apply all_elements_implies<Nat>[xs']
             [λb{(x ≤ b)} : fn Nat->bool, λb{(y ≤ b)} : fn Nat->bool]
       to x_le_xs', x_le_implies_y_le

and then conclude this case for when y ≤ x.

  // <<insert_sorted_case_node_le_conclusion>> =
  s_xs, x_le_xs', y_le_x, y_le_xs'

Next we turn our attention to the case for when y ≤ x is false. After applying the definition of insert, Deduce tells us that we need to prove the following.

    // <<insert_sorted_case_node_g_def>> =
    suffices sorted(insert(xs',y)) 
         and all_elements(insert(xs',y),λb{x ≤ b})
             with definition sorted

The first follows from the induction hypothesis. (Though we need to move the proof of s_xs out of the y ≤ x case so that we can use it here.)

  // <<insert_sorted_s_xs_y>> =
  have s_xs'_y: sorted(insert(xs',y)) by apply IH[y] to s_xs

The second requires more thinking. We know that x ≤ y in this case by the following reasoning.

  // <<insert_sorted_x_le_y>> =
  have x_le_y: x ≤ y
      by have not_yx: not (y ≤ x)  by suppose yx rewrite yx_false in yx
         apply not_less_equal_less_equal to not_yx

We have already proved that x is less-or-equal to all the elements in xs'. So we know that x is less-or-equal to all the elements in node(y, xs') by the definition of all_elements.

  // <<insert_sorted_x_le_y_xs>> =
  have x_le_y_xs': all_elements(node(y, xs'),λb{(x ≤ b)})
      by suffices x ≤ y and all_elements(xs', λb{x ≤ b}) 
              with definition all_elements
         x_le_y, x_le_xs'

However, what we need to prove is that x is less-or-equal to all the elements in insert(xs', y). But the all_elements function shouldn’t care about the ordering of the elements in the list, and indeed there is the following theorem in List.pf:

theorem all_elements_set_of:
  all T:type, xs:List<T>, ys:List<T>, P:fn T -> bool.
  if set_of(xs) = set_of(ys)
  then all_elements(xs, P) = all_elements(ys, P)

So we need to show that set_of(insert(xs',y)) = set_of(node(y,xs')). Thankfully, we already showed that this is true for mset_of in the insert_contents theorem, and multiset equality implies set equality (also from List.pf):

theorem mset_equal_implies_set_equal: 
  all T:type, xs:List<T>, ys:List<T>.
  if mset_of(xs) = mset_of(ys)
  then set_of(xs) = set_of(ys)

So we use these three theorems to prove the following.

theorem all_elements_insert_node:
  all xs:List<Nat>, x:Nat, P:fn Nat->bool.
  all_elements(insert(xs,x), P) = all_elements(node(x,xs), P)
proof
  arbitrary xs:List<Nat>, x:Nat, P:fn Nat->bool
  have m_xs_x: mset_of(insert(xs, x)) = mset_of(node(x, xs))
      by suffices mset_of(insert(xs, x)) = m_one(x) ⨄ mset_of(xs)
             with definition mset_of
         insert_contents[xs][x]
  have ixsx_xxs: set_of(insert(xs, x)) = set_of(node(x, xs))
     by apply mset_equal_implies_set_equal<Nat>[insert(xs, x), node(x, xs)] 
        to m_xs_x
  apply all_elements_set_of<Nat>[ insert(xs,x), node(x, xs), P]
  to ixsx_xxs
end

We apply this theorem to prove that x is less-or-equal to all the elements in insert(xs',y).

  // <<insert_sorted_x_le_xs_y>> =
  have x_le_xs'_y: all_elements(insert(xs',y), λb{x ≤ b})
      by _rewrite all_elements_insert_node[xs',y,λb{x≤b}:fn Nat->bool]
         x_le_y_xs'

Now we have the two facts we need to conclude this final case of the proof of insert_sorted.

  // <<insert_sorted_case_node_g_conclusion>> =
  conclude sorted(insert(xs',y)) and
           all_elements(insert(xs',y),λb{x ≤ b})
      by s_xs'_y, x_le_xs'_y

Here is the complete proof of insert_sorted.

theorem insert_sorted: all xs:List<Nat>. all y:Nat.
  if sorted(xs) then sorted(insert(xs, y))
proof
  induction List<Nat>
  case empty {
    <<insert_sorted_case_empty>>
  }
  case node(x, xs') suppose IH {
    arbitrary y:Nat
    suppose s_xxs: sorted(node(x,xs'))
    suffices sorted(insert(node(x,xs'),y))  by .
    <<insert_sorted_case_node_s_xs__x_le_xs>>
    switch y ≤ x for insert {
      case true suppose yx_true {
        suffices sorted(node(y,node(x,xs')))  by .
        <<insert_sorted_case_node_less_defs>>
        <<insert_sorted_y_le_x>>
        <<insert_sorted_x_le_implies_y_le>>
        <<insert_sorted_y_le_xs>>
        <<insert_sorted_case_node_le_conclusion>>
      }
      case false suppose yx_false {
        <<insert_sorted_case_node_g_def>>
        <<insert_sorted_s_xs_y>>
        <<insert_sorted_x_le_y>>
        <<insert_sorted_x_le_y_xs>>
        <<insert_sorted_x_le_xs_y>>
        <<insert_sorted_case_node_g_conclusion>>
      }
    }
  }
end

Prove the correctness of insertion_sort

Referring back to the specification of insertion_sort(xs), we need to prove that (1) it outputs a list that contains the same elements as xs, and (2) the output is sorted.

As we did for insert, we use multisets and mset_of to express the requirement on the contents of the output of insertion_sort.

theorem insertion_sort_contents: all xs:List<Nat>.
  mset_of(insertion_sort(xs)) = mset_of(xs)

The insertion_sort(xs) function is recursive, so we proceed by induction on xs. In the case for xs = empty, we conclude the following using the definitions of insertion_sort and mset_of.

    // <<insertion_sort_contents_case_empty>> =
    conclude mset_of(insertion_sort(empty)) = mset_of(empty)
      by definition {insertion_sort, mset_of}

In the case for xs = node(x, xs'), after applying the definitions of insertion_sort and mset_of, it suffices to show that

    // <<insertion_sort_contents_case_node_defs>> =
    suffices mset_of(insert(insertion_sort(xs'),x)) 
           = m_one(x) ⨄ mset_of(xs')
        with definition {insertion_sort, mset_of}

The goal follows from the insert_contents theorem and the induction hypothesis as follows.

  // <<insertion_sort_contents_case_node_equations>> =
  equations
          mset_of(insert(insertion_sort(xs'),x)) 
        = m_one(x) ⨄ mset_of(insertion_sort(xs'))
          by insert_contents[insertion_sort(xs')][x]
    ... = m_one(x) ⨄ mset_of(xs')
          by rewrite IH

Here is the complete proof of insertion_sort_contents.

theorem insertion_sort_contents: all xs:List<Nat>.
  mset_of(insertion_sort(xs)) = mset_of(xs)
proof
  induction List<Nat>
  case empty {
    <<insertion_sort_contents_case_empty>>
  }
  case node(x, xs') suppose IH {
    <<insertion_sort_contents_case_node_defs>>
    <<insertion_sort_contents_case_node_equations>>
  }
end

Finally, we prove that insertion_sort(xs) produces a sorted list. Of course the proof is by induction on xs. The case for empty follows from the relevant definitions. The case for node(x, xs') follows from the insert_sorted theorem and the induction hypothesis.

theorem insertion_sort_sorted: all xs:List<Nat>. 
  sorted( insertion_sort(xs) )
proof
  induction List<Nat>
  case empty {
    conclude sorted(insertion_sort(empty))
        by definition {insertion_sort, sorted}
  }
  case node(x, xs') suppose IH: sorted( insertion_sort(xs') ) {
    suffices sorted(insert(insertion_sort(xs'),x))
        with definition {insertion_sort, sorted}
    apply insert_sorted[insertion_sort(xs')][x] to IH
  }
end

Exercise: tail-recursive variant of insertion_sort

The insertion_sort function uses more computer memory than necessary because it uses one frame on the procedure call stack for every element in the input list. This can be avoided if we instead implement Insertion Sort using a tail-recursive function, that is, a function that immediately returns after the recursive call. For this exercise, formulate a tail-recursive version of insertion_sort, test it, and prove that it is correct.

As a hint, define an auxiliary function isort(xs,ys) that takes a list xs and an already sorted list ys and returns a sorted list that includes the contents of both xs and ys.

function isort(List<Nat>, List<Nat>) -> List<Nat> {
  FILL IN HERE
}

Once you have defined isort, you can implement Insertion Sort as follows.

define insertion_sort2 : fn List<Nat> -> List<Nat>
    = λxs{ isort(xs, empty) }

To prove the correctness of insertion_sort2, prove that the result is sorted

theorem insertion_sort2_sorted: all xs:List<Nat>. 
  sorted( insertion_sort2(xs) )
proof
  ?
end

and prove that the output includes all of the same elements as in the input (the correct number of times).

theorem insertion_sort2_contents: all xs:List<Nat>. 
  mset_of( insertion_sort2(xs) ) = mset_of(xs)
proof
  ?
end

Friday, June 14, 2024

Sequential Search, Correctly

Sequential Search

This is the second blog post in a series about developing correct implementations of basic data structures and algorithms using the Deduce language and proof checker.

In this blog post we’ll study a classic and simple algorithm known as Sequential Search (a.k.a. Linear Search). The basic idea of the algorithm is to look for the location of a particular item within a linked list, traversing the list front to back. Here is the specification of this search function.

Specification: The search(xs, y) function returns a natural number i such that i ≤ length(xs). If i < length(xs), then i is the index of the first occurrence of y in the list xs. If i = length(xs), y is not in the list xs.

We follow the write-test-prove approach to develop a correct implementation of search. We then propose two exercises for the reader.

Write the search function

Before diving into the code for search, let us look again at the definition of the List type.

union List<T> {
  empty
  node(T, List<T>)
}

We say that List is a recursive union because one of its constructors has a parameter that is also of the List type (namely, the second parameter of the node constructor).

In general, when defining a function with a parameter that is a recursive union, first consider making that function a recursive function that pattern-matches on that parameter.

For example, with search, we choose for the List<Nat> to be the first parameter so that we can pattern-match on it as follows.

function search(List<Nat>, Nat) -> Nat {
  search(empty, y) = ?
  search(node(x, xs'), y) = ?
}

Let us consider the case for the empty list. Looking at the specification of search, we need to return 0, the length of the empty list, because y is not in the empty list.

function search(List<Nat>, Nat) -> Nat {
  search(empty, y) = 0
  search(node(x, xs'), y) = ?
}

In the case for node(x, xs'), we can check whether x = y. If so, then we should return 0 because y is at index 0 of node(x, xs') and that is certainly the first occurrence of y in node(x, xs').

function search(List<Nat>, Nat) -> Nat {
  search(empty, y) = 0
  search(node(x, xs'), y) =
    if x = y then
      0
    else
      ?
}

If x ≠ y, then we need to search the rest of the list xs' for y. We can make the recursive call search(xs', y), but then we need to decide how to adapt its result to produce a result that makes sense for node(x, xs'). The only way to reason about the result of a recursive call is to use the specification of the function. The specification of search splits into two cases on the result: (1) search(xs', y) < length(xs') and (2) length(xs') ≤ search(xs', y). In case (1), search(xs',y) is returning the index of the first y inside xs'. Because x ≠ y, that location will also be the first y inside node(x, xs'). However, we need to add one to the index to take into account that we’re adding a node to the front of the list. So for case (1), the result should be suc(search(xs', y)). In case (2), search(xs',y) did not find y in xs', so it is returning length(xs'). Because x ≠ y, we need to indicate that y is also not found in node(x, xs'), so we need to return length(node(x, xs')). Thus, we need to add one to the index, so the result should again be suc(search(xs', y)).

Here is the completed code for search.
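
Putting the cases above together gives the following definition. It is reconstructed from the reasoning in this section, where both the found-later and not-found cases return suc(search(xs', y)), so treat it as a sketch rather than the exact original listing.

```deduce
function search(List<Nat>, Nat) -> Nat {
  // y is not in the empty list, so return its length, 0.
  search(empty, y) = 0
  search(node(x, xs'), y) =
    if x = y then
      // y is at index 0, which is certainly its first occurrence.
      0
    else
      // Shift the result for xs' by one to account for the node in front.
      suc(search(xs', y))
}
```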

Test the search function

Focusing on the specification of search, there are several things that we should test. First, we should test whether search always returns a number that is less-or-equal to the length of the list. We can use all_elements and interval to automate the testing over a bunch of values, some of which are in the list and some are not.

define list_1223 = node(1, node(2, node(2, node(3, empty))))

assert all_elements(interval(0, 5),
  λx{ search(list_1223, x) ≤ length(list_1223) })

Most importantly, we should test whether search finds the correct index of the elements in the list. To do that we can make use of nth to look up the element at a given index.

assert all_elements(list_1223,
  λx{ nth(list_1223, 0)( search(list_1223, x) ) = x })

Next, we should test whether search finds the first occurrence. We can do this by iterating over all the indexes and checking that what search returns is an index that is less than or equal to the current index.

assert all_elements(interval(0, length(list_1223)),
   λi{ search(list_1223, nth(list_1223, 0)(i)) ≤ i })

Finally, we check that search fails gracefully when the value being searched for is not present in the list.

assert search(list_1223, 0) = length(list_1223)
assert search(list_1223, 4) = length(list_1223)
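
As a complement to the property-style checks above, a few concrete cases can also be spelled out by hand. The expected indices below are worked out from list_1223 (the list 1, 2, 2, 3) and the specification of search; they are an illustrative addition, not part of the original test suite.

```deduce
assert search(list_1223, 1) = 0   // 1 is at the head of the list
assert search(list_1223, 2) = 1   // the first of the two 2s
assert search(list_1223, 3) = 3   // the last element
```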

Prove search Correct

We break down the specification of search into four parts and prove four theorems.

Prove search is less-or-equal length

The first part of the specification of search says that the search(xs, y) function returns a natural number i such that i ≤ length(xs). Because search is recursive, we’re going to prove this by induction on its first parameter xs.

theorem search_length: all xs:List<Nat>. all y:Nat.
  search(xs, y) ≤ length(xs)
proof
  induction List<Nat>
  case empty {
    ?
  }
  case node(x, xs') 
    suppose IH: all y:Nat. search(xs',y) ≤ length(xs') 
  {
    ?
  }
end

In the case for xs = empty, Deduce tells us that we need to prove

Goal:
    all y:Nat. search(empty,y) ≤ length(empty)

So we start with arbitrary y:Nat and then conclude using the definitions of search, length, and operator ≤.

    // <<search_length_case_empty>> =
    arbitrary y:Nat
    conclude search(empty,y) ≤ length[Nat](empty)
        by definition {search, length, operator ≤}

In these blog posts we use a literate programming tool named Entangled to translate the markdown files into Deduce proof files. Entangled lets us label chunks of proof and then paste them into larger proofs. So that you can see the label names, we include them in comments, as in the <<search_length_case_empty>> label above.

In the case for xs = node(x, xs'), Deduce tells us that we need to prove

Goal:
    all y:Nat. search(node(x,xs'),y) ≤ length(node(x,xs'))

So we start with arbitrary y:Nat and note that the definition of search has an if-then-else, so we proceed with a switch-for statement.

    arbitrary y:Nat
    switch x = y for search {
      case true {
        ?
      }
      case false {
        ?
      }
    }

In the case for x = y, the goal becomes

0 ≤ length(node(x, xs'))

so we need to use the definition of length and then we can complete the proof using the definition of ≤.

    // <<search_length_case_node_eq>> =
    suffices 0 ≤ 1 + length(xs')  with definition length
    definition operator ≤

In the case for x ≠ y, after applying the definitions of length, ≤, and +, it remains to prove that search(xs', y) ≤ length(xs'). But that is just the induction hypothesis.

    // <<search_length_case_node_not_eq>> =
    suffices search(xs', y) ≤ length(xs')
        with definition {length, operator ≤, operator+, operator+}
    IH[y]

Putting all of the pieces together, we have a complete proof of search_length.

theorem search_length: all xs:List<Nat>. all y:Nat.
  search(xs, y) ≤ length(xs)
proof
  induction List<Nat>
  case empty {
    <<search_length_case_empty>>
  }
  case node(x, xs') 
    suppose IH: all y:Nat. search(xs',y) ≤ length(xs') 
  {
    arbitrary y:Nat
    switch x = y for search {
      case true {
        <<search_length_case_node_eq>>
      }
      case false {
        <<search_length_case_node_not_eq>>
      }
    }
  }
end

Prove search(xs, y) finds an occurrence of y

The specification of search(xs, y) says that if the result is less than length(xs), then the result is the index of the first occurrence of y in xs. First off, this means that search(xs, y) is indeed an index for y, which we can express using nth as follows.

nth(xs, 0)( search(xs, y) ) = y

So we can formulate the following theorem, which we’ll prove by induction on xs.

theorem search_present: all xs:List<Nat>. all y:Nat.
  if search(xs, y) < length(xs)
  then nth(xs, 0)( search(xs, y) ) = y
proof
  induction List<Nat>
  case empty {
    ?
  }
  case node(x, xs') suppose IH {
    ?
  }
end

In the case for xs = empty, we proceed in a goal-directed way using arbitrary for the all y and then suppose for the if.

    arbitrary y:Nat
    suppose prem: search(empty,y) < length[Nat](empty)
    ?

Then we need to prove

nth(empty, 0)(search(empty, y)) = y

but that looks impossible! So hopefully the premise is also false, which will let us finish this case using the principle of explosion. Indeed, applying all of the relevant definitions to the premise yields false.

    // <<search_present_case_empty>> =
    arbitrary y:Nat
    suppose prem: search(empty,y) < length[Nat](empty)
    conclude false by definition {search, length, operator <, operator ≤} 
                      in prem

Moving on to the case for xs = node(x, xs'), we again begin with arbitrary and suppose.

    arbitrary y:Nat
    suppose sxxs_len: search(node(x,xs'),y) < length(node(x,xs'))
    ?

Deduce tells us that we need to prove

Goal:
    nth(node(x,xs'),0)(search(node(x,xs'),y)) = y

We see search applied to a node argument and note again that the body of search contains an if-then-else, so we proceed with a switch-for statement.

    switch x = y for search {
      case true suppose xy_true {
        ?
      }
      case false suppose xy_false {
        ?
      }
    }

In the case where x = y, Deduce tells us that we need to prove

Goal:
    nth(node(x,xs'),0)(0) = y

We conclude using the definition of nth and the fact that x = y.

    // <<search_present_case_node_eq>> =
    suffices x = y with definition nth
    rewrite xy_true

In the case where x ≠ y, we need to prove

Goal:
    nth(node(x,xs'),0)(suc(search(xs',y))) = y

Now if we apply the definitions of nth and pred, the goal becomes:

    // <<search_present_case_node_nth_pred>> =
    suffices nth(xs', 0)(search(xs', y)) = y
        with definition {nth, pred}

This looks a lot like the conclusion of our induction hypothesis:

Givens:
    ...
    IH: all y:Nat. (if search(xs',y) < length(xs') 
                    then nth(xs',0)(search(xs',y)) = y)

So we just need to prove the premise of the IH, that search(xs',y) < length(xs'). Thankfully, that can be proved from the premise search(node(x,xs'),y) < length(node(x,xs')).

  // <<search_present_IH_premise>> =
    have sxs_len: search(xs',y) < length(xs')
      by enable {search, length, operator <, operator ≤, 
                 operator+, operator+}
         rewrite xy_false in sxxs_len

We conclude by applying the induction hypothesis.

  // <<search_present_apply_IH>> =
  conclude nth(xs',0)(search(xs',y)) = y
    by apply IH[y] to sxs_len

Here is the complete proof of search_present.

theorem search_present: all xs:List<Nat>. all y:Nat.
  if search(xs, y) < length(xs)
  then nth(xs, 0)( search(xs, y) ) = y
proof
  induction List<Nat>
  case empty {
    <<search_present_case_empty>>
  }
  case node(x, xs') suppose IH {
    arbitrary y:Nat
    suppose sxxs_len: search(node(x,xs'),y) < length(node(x,xs'))
    switch x = y for search {
      case true suppose xy_true {
        <<search_present_case_node_eq>>
      }
      case false suppose xy_false {
        <<search_present_case_node_nth_pred>>
        <<search_present_IH_premise>>
        <<search_present_apply_IH>>
      }
    }
  }
end

Prove search(xs, y) finds the first occurrence of y

Going back to the specification of search(xs, y), it says that if the result is less than length(xs), then the result is the index of the first occurrence of y in xs. To be the first means that the result is no larger than the index of any occurrence of y. We express that in the following theorem.

theorem search_first: all xs:List<Nat>. all y:Nat, i:Nat.
  if search(xs, y) < length(xs) and nth(xs, 0)(i) = y
  then search(xs, y) ≤ i

We proceed by induction on xs. We can handle the case for xs = empty in the same way as in search_present; the premise is false.

    // <<search_first_case_empty>> =
    arbitrary y:Nat, i:Nat
    suppose prem: search(empty,y) < length[Nat](empty) and nth(empty,0)(i) = y
    conclude false by definition {search, length, operator <, operator ≤} 
                      in prem

In the case for xs = node(x, xs'), we proceed in a goal-directed fashion with an arbitrary and suppose.

  case node(x, xs') suppose IH {
    arbitrary y:Nat, i:Nat
    suppose prem: search(node(x,xs'),y) < length(node(x,xs')) 
                  and nth(node(x,xs'),0)(i) = y
    ?
  }

Deduce responds with

Goal:
    search(node(x,xs'),y) ≤ i

We apply the definition of search and switch on x = y with a switch-for statement.

  switch x = y for search {
    case true {
      ?
    }
    case false suppose xy_false {
      ?
    }
  }

In the case where x = y, the result of search is 0, so we just need to prove that 0 ≤ i, which follows from the definition of ≤.

    // <<search_first_case_node_true>> =
    conclude 0 ≤ i   by definition operator ≤

In the case where x ≠ y, we need to prove

Goal:
    suc(search(xs',y)) ≤ i

What do we know about i? The premise nth(node(x,xs'),0)(i) = y tells us that i ≠ 0: if i were 0, then nth(node(x,xs'),0)(0) would be x, contradicting x ≠ y. So i is the successor of some other number i'.

    // <<search_first_case_node_false_1>> =
    have not_iz: not (i = 0)
      by suppose i_z 
         conclude false by rewrite i_z | xy_false in 
                           definition nth in prem
    obtain i' where i_si: i = suc(i') from apply not_zero_suc to not_iz
    suffices suc(search(xs', y)) ≤ suc(i')  with rewrite i_si

Now we can further simplify the goal with the definition of ≤.

    // <<search_first_case_node_false_2>> =
    suffices search(xs', y) ≤ i'   with definition operator≤ 

The goal looks like the conclusion of the induction hypothesis instantiated at i'.

Givens:
    ...
    IH: all y:Nat, i:Nat. (if search(xs',y) < length(xs') and nth(xs',0)(i) = y 
                           then search(xs',y) ≤ i)

So we need to prove the two premises of the IH. They follow from the given prem:

Givens:
    prem: search(node(x,xs'),y) < length(node(x,xs')) 
          and nth(node(x,xs'),0)(i) = y

In particular, the first premise of IH follows from the first conjunct of prem.

    // <<search_first_IH_prem_1>> =
    have IH_prem_1: search(xs',y) < length(xs')
      by enable {search, length, operator <, operator ≤, 
                 operator+, operator+}
         rewrite xy_false in (conjunct 0 of prem)

The second premise of the IH follows from the second conjunct of prem.

    // <<search_first_IH_prem_2>> =
    have IH_prem_2: nth(xs',0)(i') = y
      by enable {nth, pred} rewrite i_si in (conjunct 1 of prem)

We conclude the case for i = suc(i') by applying the induction hypothesis.

    // <<search_first_apply_IH>> =
    apply IH[y,i'] to IH_prem_1, IH_prem_2

Here is the complete proof of search_first.

theorem search_first: all xs:List<Nat>. all y:Nat, i:Nat.
  if search(xs, y) < length(xs) and nth(xs, 0)(i) = y
  then search(xs, y) ≤ i
proof
  induction List<Nat>
  case empty {
    <<search_first_case_empty>>
  }
  case node(x, xs') suppose IH {
    arbitrary y:Nat, i:Nat
    suppose prem: search(node(x,xs'),y) < length(node(x,xs')) 
                  and nth(node(x,xs'),0)(i) = y
    switch x = y for search {
      case true {
        <<search_first_case_node_true>>
      }
      case false suppose xy_false {
        <<search_first_case_node_false_1>>
        <<search_first_case_node_false_2>>
        <<search_first_IH_prem_1>>
        <<search_first_IH_prem_2>>
        <<search_first_apply_IH>>
      }
    }
  }
end

Prove that search fails only when it should

The last sentence in the specification for search(xs, y) says that if i = length(xs), then y is not in the list xs. How do we express that y is not in the list? In some sense, that is what search is for, but it would be vacuous to prove a theorem that says search returns length(xs) if search returns length(xs). Instead we need an alternative and intuitive way to express membership in a list.

One approach to expressing list membership that works well is to convert the list to a set and then use set membership. The file Set.pf defines the Set type and operations on sets such as membership, union, and intersection. The Set.pf file also proves many theorems about these operations. The following set_of function converts a list into a set.

function set_of<T>(List<T>) -> Set<T> {
  set_of(empty) = ∅
  set_of(node(x, xs)) = single(x) ∪ set_of(xs)
}

We can now express our last correctness theorem for search as follows.

theorem search_absent: all xs:List<Nat>. all y:Nat, d:Nat.
  if search(xs, y) = length(xs)
  then not (y ∈ set_of(xs))

We proceed by induction on xs. In the case for xs = empty, we take the following goal-directed steps

  case empty {
    arbitrary y:Nat, d:Nat
    suppose _
    ?
  }

and Deduce responds with

Goal:
    not y ∈ set_of(empty)

which we prove using the definition of set_of and the empty_no_members theorem from Set.pf.

    // <<search_absent_case_empty>> =
    arbitrary y:Nat, d:Nat
    suppose _
    suffices not (y ∈ ∅) with definition set_of
    empty_no_members[Nat,y]

Turning to the case for xs = node(x, xs'), we take several goal-directed steps.

  case node(x, xs') suppose IH {
    arbitrary y:Nat, d:Nat
    suppose s_xxs_len_xxs: search(node(x,xs'),y) = length(node(x,xs'))
    suffices not (y ∈ single(x) ∪ set_of(xs'))  with definition set_of
    ?
  }

Now we need to prove a not formula:

Goal:
    not (y ∈ single(x) ∪ set_of(xs'))

So we assume y ∈ single(x) ∪ set_of(xs') and then prove false (a contradiction).

  suppose y_in_x_union_xs: y ∈ single(x) ∪ set_of(xs')

The main information we have to work with is the premise s_xxs_len_xxs above, concerning search(node(x,xs'), y). Thinking about the code for search, we know it will branch on whether x = y, so we better switch on that.

  switch x = y {
    case true suppose xy_true {
      ?
    }
    case false suppose xy_false {
      ?
    }
  }

In the case where x = y, we have search(node(x,xs'),y) = 0 but length(node(x,xs')) is 1 + length(xs'), so we have a contradiction.

    // <<search_absent_case_node_equal>> =
    have xy: x = y by rewrite xy_true
    have s_yxs_len_yxs: search(node(y,xs'),y) = length(node(y,xs'))
        by rewrite xy in s_xxs_len_xxs
    have zero_1_plus: 0 = 1 + length(xs')
        by definition {search, length} in s_yxs_len_yxs
    conclude false  by definition {operator+} in zero_1_plus

In the case where x ≠ y, we can show that y ∈ set_of(xs') and then invoke the induction hypothesis to obtain the contradiction. In particular, the premise y_in_x_union_xs gives us y ∈ single(x) or y ∈ set_of(xs'). But x ≠ y implies not (y ∈ single(x)). So it must be that y ∈ set_of(xs') (using or_not from Base.pf).

  // <<search_absent_case_node_notequal_y_in_xs>> =
  have ysx_or_y_xs: y ∈ single(x) or y ∈ set_of(xs')
      by apply member_union[Nat] to y_in_x_union_xs
  have not_ysx: not (y ∈ single(x))
    by suppose ysx
       rewrite xy_false in
       apply single_equal[Nat] to ysx
  have y_xs: y ∈ set_of(xs')
    by apply or_not[y ∈ single(x), y ∈ set_of(xs')] 
       to ysx_or_y_xs, not_ysx

To satisfy the premise of the induction hypothesis, we prove the following.

    // <<search_absent_IH_prem>> =
    have sxs_lxs: search(xs',y) = length(xs')
      by injective suc
         rewrite xy_false in
         definition {search,length,operator+,operator+} in
         s_xxs_len_xxs

So we apply the induction hypothesis to get y ∉ set_of(xs'), which contradicts y ∈ set_of(xs').

  // <<search_absent_apply_IH>> =
  have y_not_xs: not (y ∈ set_of(xs'))
    by apply IH[y,d] to sxs_lxs
  conclude false  by apply y_not_xs to y_xs

Here is the complete proof of search_absent.

theorem search_absent: all xs:List<Nat>. all y:Nat, d:Nat.
  if search(xs, y) = length(xs)
  then not (y ∈ set_of(xs))
proof
  induction List<Nat>
  case empty {
    <<search_absent_case_empty>>
  }
  case node(x, xs') suppose IH {
    arbitrary y:Nat, d:Nat
    suppose s_xxs_len_xxs: search(node(x,xs'),y) = length(node(x,xs'))
    suffices not (y ∈ single(x) ∪ set_of(xs'))  with definition set_of
    suppose y_in_x_union_xs: y ∈ single(x) ∪ set_of(xs')
    switch x = y {
      case true suppose xy_true {
        <<search_absent_case_node_equal>>
      }
      case false suppose xy_false {
        <<search_absent_case_node_notequal_y_in_xs>>
        <<search_absent_IH_prem>>
        <<search_absent_apply_IH>>
      }
    }
  }
end

Exercise search_last

Apply the write-test-prove approach to develop a correct implementation of the search_last(xs, y) function, which is like search(xs, y) except that it finds the last occurrence of y in xs instead of the first.

In particular, you need to

  • write a specification for search_last,
  • write the code for search_last,
  • test search_last on diverse inputs, and
  • prove that search_last is correct.

function search_last(List<Nat>, Nat) -> Nat {
    FILL IN HERE
}

Exercise search_if

The search_if(xs, P) function is a generalization of search(xs, y). Instead of searching for the first occurrence of element y, the search_if function searches for the location of the first element that satisfies predicate P (i.e. an element y in xs such that P(y) is true). Apply the write-test-prove approach to develop a correct implementation of search_if.

In particular, you need to

  • write a specification for search_if,
  • write the code for search_if,
  • test search_if on diverse inputs, and
  • prove that search_if is correct.

function search_if<T>(List<T>, fn T->bool) -> Nat {
    FILL IN HERE
}