Monday, June 17, 2024

Insertion Sort, Correctly

Insertion Sort

This is the third blog post in a series about developing correct implementations of basic data structures and algorithms using the Deduce language and proof checker.

In this blog post we study a simple but slow sorting algorithm, Insertion Sort. (We will study the faster Merge Sort in the next blog post.) Insertion Sort is, roughly speaking, how many people sort the cards in their hand when playing a card game. The basic idea is to take one card at a time and place it into the correct location amongst the cards that you’ve already sorted.

Specification: The insertion_sort(xs) function returns a list that contains the same elements as xs but the elements in the result are in sorted order.

Of course, to make this specification precise, we need to define "sorted". There are several ways to go with this formal definition. Here is one that works well for me. It requires each element in the list to be less-or-equal to all the elements that come after it.

function sorted(List<Nat>) -> bool {
  sorted(empty) = true
  sorted(node(x, xs)) =
    sorted(xs) and all_elements(xs, λy{ x ≤ y })
}
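To make the definition concrete, here is a rough Python rendering of the same predicate. It is only an illustrative sketch: the name is_sorted and the use of Python lists in place of Deduce's List<Nat> are assumptions made for this example, not part of the Deduce development.

```python
def is_sorted(xs):
    """Mirror of the Deduce `sorted` function on a Python list."""
    if not xs:                       # sorted(empty) = true
        return True
    x, rest = xs[0], xs[1:]          # xs = node(x, rest)
    # sorted(rest) and all_elements(rest, λy{ x ≤ y })
    return is_sorted(rest) and all(x <= y for y in rest)


print(is_sorted([1, 2, 2, 5]))   # True
print(is_sorted([1, 3, 2]))      # False
```

Note that the head is compared against every later element, not just its neighbor, which is what makes the later proofs about all_elements line up with this definition.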

We follow the write-test-prove approach to develop a correct implementation of insertion_sort. We then propose an exercise for the reader.

Write the insertion_sort function

Because insertion_sort operates on the recursive type List, we’ll try to implement insertion_sort as a recursive function.

function insertion_sort(List<Nat>) -> List<Nat> {
  insertion_sort(empty) = ?
  insertion_sort(node(x, xs')) = ?
}

In the case for the empty list, we need to return a list with the same contents, so we better return empty.

function insertion_sort(List<Nat>) -> List<Nat> {
  insertion_sort(empty) = empty
  insertion_sort(node(x, xs')) = ?
}

In the case for node(x, xs'), we can make the recursive call insertion_sort(xs') to sort the rest of the list.

function insertion_sort(List<Nat>) -> List<Nat> {
  insertion_sort(empty) = empty
  insertion_sort(node(x, xs')) = ... insertion_sort(xs') ...
}

But what do we do with the element x? This is where we need to define an auxiliary function that inserts x into the appropriate location within the result of sorting the rest of the list. We’ll choose the name insert for this auxiliary function. Here is the completed code for insertion_sort.

function insertion_sort(List<Nat>) -> List<Nat> {
  insertion_sort(empty) = empty
  insertion_sort(node(x, xs')) = insert(insertion_sort(xs'), x)
}

Of course, we’ll follow the write-test-prove approach to develop the insert function. The first thing we need to do is write down the specification. The specification of insert will play an important role in the proof of correctness of insertion_sort, because we’ll use the correctness theorems about insert in the proof of the correctness theorems about insertion_sort. With this in mind, here’s a specification for insert.

Specification: The insert(xs, y) function takes a sorted list xs and value y as input and returns a sorted list that contains y and the elements of xs.

Next we write the code for insert. This function also has a List as input, so we define yet another recursive function.

function insert(List<Nat>,Nat) -> List<Nat> {
  insert(empty, y) = ?
  insert(node(x, xs), y) = ?
}

For the case empty we must return a list that contains y, so it must be node(y, empty).

function insert(List<Nat>,Nat) -> List<Nat> {
  insert(empty, y) = node(y, empty)
  insert(node(x, xs), y) = ?
}

In the case for node(x, xs'), we need to check whether y should come before x. So we test y ≤ x and if that’s the case, we return node(y, node(x, xs')). Otherwise, y belongs somewhere later in the sequence, so we make the recursive call and return node(x, insert(xs', y)).

function insert(List<Nat>,Nat) -> List<Nat> {
  insert(empty, y) = node(y, empty)
  insert(node(x, xs'), y) =
    if y ≤ x then
      node(y, node(x, xs'))
    else
      node(x, insert(xs', y))
}
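For readers who want to experiment outside of Deduce, the two functions can be modeled in Python roughly as follows. This is a sketch under the assumption that Python lists stand in for List<Nat>; it follows the same case structure as the Deduce code but is not the verified implementation.

```python
def insert(xs, y):
    """Insert y into the sorted list xs, mirroring the Deduce `insert`."""
    if not xs:                        # insert(empty, y) = node(y, empty)
        return [y]
    x, rest = xs[0], xs[1:]           # xs = node(x, rest)
    if y <= x:                        # y belongs in front of x
        return [y] + xs
    return [x] + insert(rest, y)      # otherwise y goes somewhere later


def insertion_sort(xs):
    """Sort xs by inserting its head into the sorted tail."""
    if not xs:                        # insertion_sort(empty) = empty
        return []
    return insert(insertion_sort(xs[1:]), xs[0])


print(insert([1, 2, 3, 4], 3))        # [1, 2, 3, 3, 4]
print(insertion_sort([8, 3, 7, 3]))   # [3, 3, 7, 8]
```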

Test

This time we have two functions to test, insert and insertion_sort. We start with insert because if there are bugs in insert, then it will be confusing to find out about them when testing insertion_sort.

Test insert

Looking at the specification for insert, we need to check whether the resulting list is sorted and we need to check that the resulting list contains the elements from the input and the inserted element. We could use the search function that we developed in the previous blog post to check whether the elements from the input list are in the output list. However, doing that would ignore a subtle issue, which is that there can be one or more occurrences of the same element in the input list, and the output list should have the same number of occurrences. To take this into account, we need a new function to count the number of occurrences of an element.

function count<T>(List<T>) -> fn T->Nat {
  count(empty) = λy{ 0 }
  count(node(x, xs')) = λy{
    if y = x then 
      suc(count(xs')(y))
    else
      count(xs')(y) 
  }
}

Here’s a test that checks whether insert produces a sorted list with the correct count for every element in the input list as well as the inserted element.

define list_1234 = node(1, node(2, node(3, node(4, empty))))
define list_12334 = insert(list_1234, 3)
assert sorted(list_12334)
assert all_elements(node(3, list_1234), λx{
  if x = 3 then
    count(list_12334)(x) = suc(count(list_1234)(x))
  else
    count(list_12334)(x) = count(list_1234)(x)
  })
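The curried shape of count, where count(xs) is itself a function of y, may be easier to see in a Python model. The sketch below assumes Python lists in place of List<T> and re-implements insert only so that the block runs on its own; it repeats the check from the test above.

```python
def count(xs):
    """Return a function mapping y to its number of occurrences in xs."""
    if not xs:                         # count(empty) = λy{ 0 }
        return lambda y: 0
    x, rest = xs[0], xs[1:]
    # One more than the count in the tail when y matches the head.
    return lambda y: (1 if y == x else 0) + count(rest)(y)


def insert(xs, y):
    """Python stand-in for the Deduce `insert` (assumed, for illustration)."""
    if not xs:
        return [y]
    if y <= xs[0]:
        return [y] + xs
    return [xs[0]] + insert(xs[1:], y)


list_1234 = [1, 2, 3, 4]
list_12334 = insert(list_1234, 3)
for x in [3] + list_1234:
    extra = 1 if x == 3 else 0         # only the inserted value gains a copy
    assert count(list_12334)(x) == count(list_1234)(x) + extra

print(count(list_12334)(3))            # 2
```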

It’s a good idea to test corner cases, that is, inputs that trigger different code paths through the insert function. As there is a case for the empty list in the code, that’s a good test case to consider.

define list_3 = insert(empty, 3)
assert sorted(list_3)
assert all_elements(node(3, empty), λx{
  if x = 3 then
    count(list_3)(x) = suc(count(empty : List<Nat>)(x))
  else
    count(list_3)(x) = count(empty : List<Nat>)(x)
  })

Ideally we would also test with hundreds of randomly-generated lists. Adding support for random number generation is high on the TODO list for Deduce.
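Until that support exists, the idea of randomized testing can at least be sketched in Python. Everything here is an assumption made for illustration: the helper names check_insert_randomly and is_sorted, the value ranges, and the use of collections.Counter as the multiset of a list. The insert function is a Python stand-in for the Deduce one.

```python
import random
from collections import Counter


def insert(xs, y):
    """Python stand-in for the Deduce `insert` (assumed, for illustration)."""
    if not xs:
        return [y]
    if y <= xs[0]:
        return [y] + xs
    return [xs[0]] + insert(xs[1:], y)


def is_sorted(xs):
    # Checking adjacent pairs suffices because <= is transitive.
    return all(a <= b for a, b in zip(xs, xs[1:]))


def check_insert_randomly(trials=200, seed=0):
    """Return True if insert meets its specification on random sorted inputs."""
    rng = random.Random(seed)          # fixed seed keeps failures reproducible
    for _ in range(trials):
        xs = sorted(rng.randrange(20) for _ in range(rng.randrange(10)))
        y = rng.randrange(20)
        out = insert(xs, y)
        if not is_sorted(out):
            return False
        if Counter(out) != Counter(xs) + Counter([y]):
            return False
    return True


print(check_insert_randomly())
```

The inputs are sorted before calling insert because the specification only promises a sorted result for sorted input.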

Test insertion_sort

If we refer back to the specification of insertion_sort, we need to check that the output list is sorted and that it contains the same elements as the input list.

define list_8373 = node(8, node(3, node(7, node(3, empty))))
define list_3378 = insertion_sort(list_8373)
assert sorted( list_3378 )
assert all_elements(list_8373, λx{count(list_3378)(x) = count(list_8373)(x)})

Prove

The next step in the process is to prove the correctness of insert and insertion_sort with respect to their specifications.

Prove correctness of insert

Our first task is to prove that insert(xs, y) produces a list that contains y and the elements of xs. In our tests, we used the count function to accomplish this. Note that count returns a function fn T->Nat, which is the same thing as a multiset. The file MultiSet.pf defines the MultiSet<T> type and operations on them such as m_one(x) for creating a multiset that only contains x and the operator A ⨄ B for combining two multisets. The file List.pf defines a function mset_of that converts a list into a multiset.

function mset_of<T>(List<T>) -> MultiSet<T> {
  mset_of(empty) = m_fun(λ{0})
  mset_of(node(x, xs)) = m_one(x) ⨄ mset_of(xs)
}

So we express the requirements on the contents of insert(xs, y) as follows: converting insert(xs, y) into a multiset is the same as converting xs into a multiset and then adding y. The proof is relatively straightforward, making use of several theorems about multisets from MultiSet.pf.

theorem insert_contents: all xs:List<Nat>. all y:Nat.
  mset_of(insert(xs,y)) = m_one(y) ⨄ mset_of(xs)
proof
  induction List<Nat>
  case empty {
    arbitrary y:Nat
    conclude mset_of(insert(empty,y)) = m_one(y) ⨄ mset_of(empty)
        by definition {insert, mset_of, mset_of}.
  }
  case node(x, xs') suppose IH {
    arbitrary y:Nat
    definition insert
    switch y ≤ x {
      case true suppose yx_true {
        conclude mset_of(node(y,node(x,xs'))) = m_one(y) ⨄ mset_of(node(x,xs'))
            by definition {mset_of,mset_of}.
      }
      case false suppose yx_false {
        equations
              mset_of(node(x,insert(xs',y))) 
            = m_one(x) ⨄ mset_of(insert(xs',y))
              by definition mset_of.
        ... = m_one(x) ⨄ (m_one(y) ⨄ mset_of(xs'))
              by rewrite IH[y].
        ... = (m_one(x) ⨄ m_one(y)) ⨄ mset_of(xs')
              by rewrite m_sum_assoc[Nat,m_one(x),m_one(y),mset_of(xs')].
        ... = (m_one(y) ⨄ m_one(x)) ⨄ mset_of(xs')
              by rewrite m_sum_sym[Nat, m_one(x), m_one(y)].
        ... = m_one(y) ⨄ (m_one(x) ⨄ mset_of(xs'))
              by rewrite m_sum_assoc[Nat,m_one(y),m_one(x),mset_of(xs')].
        ... = m_one(y) ⨄ mset_of(node(x,xs'))
              by definition mset_of.
      }
    }
  }
end
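As a sanity check on the statement of insert_contents, the same equation can be evaluated on concrete data in Python. This sketch assumes that collections.Counter is an adequate model of MultiSet<Nat>, with Counter([x]) playing the role of m_one(x) and + playing the role of ⨄; the function names mirror the Deduce ones but are otherwise hypothetical.

```python
from collections import Counter


def insert(xs, y):
    """Python stand-in for the Deduce `insert` (assumed, for illustration)."""
    if not xs:
        return [y]
    if y <= xs[0]:
        return [y] + xs
    return [xs[0]] + insert(xs[1:], y)


def mset_of(xs):
    """Multiset of a list, following the recursive Deduce definition."""
    if not xs:
        return Counter()                          # the empty multiset
    return Counter([xs[0]]) + mset_of(xs[1:])     # m_one(x) ⨄ mset_of(xs')


xs, y = [1, 2, 3, 4], 3
lhs = mset_of(insert(xs, y))
rhs = Counter([y]) + mset_of(xs)                  # m_one(y) ⨄ mset_of(xs)
print(lhs == rhs)                                 # True
```

A check like this only covers one input, of course; the proof above covers all of them.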

Our second task is to prove that insert produces a sorted list, assuming the input list is sorted.

theorem insert_sorted: all xs:List<Nat>. all y:Nat.
  if sorted(xs) then sorted(insert(xs, y))
proof
  ?
end

Because insert is a recursive function, we proceed by induction on its first argument xs.

  induction List<Nat>

The case for xs = empty is a straightforward use of definitions.

  case empty {
    arbitrary y:Nat
    suppose _
    conclude sorted(insert(empty,y))
        by definition {insert, sorted, sorted, all_elements}.
  }

Here’s the beginning of the case for xs = node(x, xs').

  case node(x, xs') suppose IH {
    arbitrary y:Nat
    suppose s_xxs: sorted(node(x,xs'))
    suffices sorted(insert(node(x,xs'),y))
    definition insert
    ?
  }

As we can see in the goal, insert branches on whether y ≤ x.

Goal:
    sorted(if y ≤ x then node(y,node(x,xs')) 
           else node(x,insert(xs',y)))

So our proof also branches on y ≤ x.

  switch y ≤ x {
    case true {
      ?
    }
    case false {
      ?
    }
  }

In the case when y ≤ x is true, the goal simplifies to sorted(node(y,node(x,xs'))). After applying the relevant definitions,

  definition {sorted, sorted, all_elements}

we need to prove

  sorted(xs') and
  all_elements(xs',λb{x ≤ b}) and
  y ≤ x and
  all_elements(xs',λb{y ≤ b})

The first two of these follow from the premise sorted(node(x,xs')).

  have s_xs: sorted(xs') by definition sorted in s_xxs
  have x_le_xs': all_elements(xs',λb{(x ≤ b)}) by definition sorted in s_xxs

The third is true in the current case.

  have y_le_x: y ≤ x by rewrite yx_true.

The fourth, which states that y is less-or-equal to all the elements in xs', follows transitively from y ≤ x and the fact that x is less-or-equal to all the elements in xs' (x_le_xs'), using the theorem all_elements_implies (in List.pf):

theorem all_elements_implies: 
  all T:type. all xs:List<T>. all P: fn T -> bool, Q: fn T -> bool.
  if all_elements(xs,P) and (all z:T. if P(z) then Q(z)) 
  then all_elements(xs,Q)

To satisfy the second premise of all_elements_implies, we use y ≤ x to prove that if x is less-or-equal to any element z, then so is y.

  have x_le_implies_y_le: all z:Nat. (if x ≤ z then y ≤ z)
    by arbitrary z:Nat  suppose x_le_z: x ≤ z
       conclude y ≤ z by apply less_equal_trans[y][x,z] to y_le_x , x_le_z

Now we apply all_elements_implies to prove all_elements(xs',λb{(y ≤ b)}) and then conclude this case for when y ≤ x.

  have y_le_xs': all_elements(xs',λb{(y ≤ b)})
    by apply all_elements_implies[Nat][xs']
             [λb{(x ≤ b)} : fn Nat->bool, λb{(y ≤ b)} : fn Nat->bool]
       to x_le_xs', x_le_implies_y_le
  s_xs, x_le_xs', y_le_x, y_le_xs'

Next we turn our attention to the case for when y ≤ x is false. After applying the definitions of insert and sorted, Deduce tells us that we need to prove the following.

  sorted(insert(xs',y)) and
  all_elements(insert(xs',y), λb{x ≤ b})

The first follows from the induction hypothesis. (Though we need to move the proof of s_xs out of the y ≤ x case so that we can use it here.)

  have s_xs'_y: sorted(insert(xs',y)) by apply IH[y] to s_xs

The second requires more thinking. We know that x ≤ y in this case and we already proved that x is less-or-equal to all the elements in xs'. So we know that

all_elements(node(y, xs'), λb{x ≤ b})

but what we need is

all_elements(insert(xs', y), λb{x ≤ b})

Here are the proofs of what we know so far.

  have x_le_y: x ≤ y
      by have not_yx: not (y ≤ x)  by suppose yx rewrite yx_false in yx
         have x_l_y: x < y   by apply or_not[y ≤ x, x < y] 
                                to dichotomy[y,x], not_yx
         apply less_implies_less_equal[x][y] to x_l_y
  have x_le_y_xs': all_elements(node(y, xs'),λb{(x ≤ b)})
         by definition all_elements  x_le_y, x_le_xs'

Now the all_elements function shouldn’t care about the ordering of the elements in the list, and indeed there is the following theorem in List.pf:

theorem all_elements_set_of:
  all T:type, xs:List<T>, ys:List<T>, P:fn T -> bool.
  if set_of(xs) = set_of(ys)
  then all_elements(xs, P) = all_elements(ys, P)

So we need to show that set_of(insert(xs',y)) = set_of(node(y,xs')). Thankfully, we already showed that this is true for mset_of in the insert_contents theorem, and multiset equality implies set equality (also from List.pf):

theorem mset_equal_implies_set_equal: 
  all T:type, xs:List<T>, ys:List<T>.
  if mset_of(xs) = mset_of(ys)
  then set_of(xs) = set_of(ys)

So we use these three theorems to prove the following.

theorem all_elements_insert_node:
  all xs:List<Nat>, x:Nat, P:fn Nat->bool.
  all_elements(insert(xs,x), P) = all_elements(node(x,xs), P)
proof
  arbitrary xs:List<Nat>, x:Nat, P:fn Nat->bool
  have m_xs_x: mset_of(insert(xs, x)) = mset_of(node(x, xs))
      by definition mset_of insert_contents[xs][x]
  have ixsx_xxs: set_of(insert(xs, x)) = set_of(node(x, xs))
     by apply mset_equal_implies_set_equal[Nat,insert(xs, x), node(x, xs)] 
        to m_xs_x
  apply all_elements_set_of[Nat, insert(xs,x), node(x, xs), P]
  to ixsx_xxs
end
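The content of all_elements_insert_node, that a predicate holds on insert(xs, x) exactly when it holds on node(x, xs), can be illustrated in Python. The helper names all_elements and agrees and the sample predicates are assumptions chosen for this sketch, and insert is a Python stand-in for the Deduce function.

```python
def insert(xs, y):
    """Python stand-in for the Deduce `insert` (assumed, for illustration)."""
    if not xs:
        return [y]
    if y <= xs[0]:
        return [y] + xs
    return [xs[0]] + insert(xs[1:], y)


def all_elements(xs, pred):
    """True when pred holds for every element of xs."""
    return all(pred(x) for x in xs)


def agrees(xs, x, pred):
    # Both sides of the equation in all_elements_insert_node.
    return all_elements(insert(xs, x), pred) == all_elements([x] + xs, pred)


predicates = [
    lambda b: 1 <= b,        # holds for every element below
    lambda b: b % 2 == 0,    # fails for some elements
    lambda b: b < 6,         # fails only for the larger elements
]
print(all(agrees([2, 5, 9], 6, p) for p in predicates))   # True
```

The two lists contain the same elements in a different order, and all_elements does not depend on order, which is the intuition behind going through set_of.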

We apply this theorem to prove that x is less-or-equal to all the elements in insert(xs',y) and then we conclude this final case of the proof of insert_sorted.

  have x_le_xs'_y: all_elements(insert(xs',y), λb{x ≤ b})
      by rewrite all_elements_insert_node[xs',y,λb{x≤b}:fn Nat->bool]
         x_le_y_xs'
  conclude sorted(insert(xs',y)) and
           all_elements(insert(xs',y),λb{x ≤ b})
      by s_xs'_y, x_le_xs'_y

Here is the complete proof of insert_sorted.

theorem insert_sorted: all xs:List<Nat>. all y:Nat.
  if sorted(xs) then sorted(insert(xs, y))
proof
  induction List<Nat>
  case empty {
    arbitrary y:Nat
    suppose _
    conclude sorted(insert(empty,y))
        by definition {insert, sorted, sorted, all_elements}.
  }
  case node(x, xs') suppose IH {
    arbitrary y:Nat
    suppose s_xxs: sorted(node(x,xs'))
    have s_xs: sorted(xs') by definition sorted in s_xxs
    have x_le_xs': all_elements(xs',λb{(x ≤ b)}) 
        by definition sorted in s_xxs
    suffices sorted(insert(node(x,xs'),y))
    definition insert
    switch y ≤ x {
      case true suppose yx_true {
        suffices sorted(node(y,node(x,xs')))
        definition {sorted, sorted, all_elements}
        have y_le_x: y ≤ x by rewrite yx_true.
        have x_le_implies_y_le: all z:Nat. (if x ≤ z then y ≤ z)
          by arbitrary z:Nat  suppose x_le_z: x ≤ z
             conclude y ≤ z
               by apply less_equal_trans[y][x,z] to y_le_x , x_le_z
        have y_le_xs': all_elements(xs',λb{(y ≤ b)})
          by apply all_elements_implies[Nat][xs']
                  [λb{(x ≤ b)} : fn Nat->bool, λb{(y ≤ b)} : fn Nat->bool]
             to x_le_xs', x_le_implies_y_le
        s_xs, x_le_xs', y_le_x, y_le_xs'
      }
      case false suppose yx_false {
        definition sorted
        have s_xs'_y: sorted(insert(xs',y)) by apply IH[y] to s_xs
        have x_le_y: x ≤ y
            by have not_yx: not (y ≤ x)  by suppose yx rewrite yx_false in yx
               have x_l_y: x < y   by apply or_not[y ≤ x, x < y] 
                                      to dichotomy[y,x], not_yx
               apply less_implies_less_equal[x][y] to x_l_y
        have x_le_y_xs': all_elements(node(y, xs'),λb{(x ≤ b)})
               by definition all_elements  x_le_y, x_le_xs'
        have x_le_xs'_y: all_elements(insert(xs',y), λb{x ≤ b})
            by rewrite all_elements_insert_node[xs',y,λb{x≤b}:fn Nat->bool]
               x_le_y_xs'
        conclude sorted(insert(xs',y)) and
                 all_elements(insert(xs',y),λb{x ≤ b})
            by s_xs'_y, x_le_xs'_y
      }
    }
  }
end

Prove the correctness of insertion_sort

Referring back to the specification of insertion_sort(xs), we need to prove that (1) it outputs a list that contains the same elements as xs, and (2) the output is sorted.

As we did for insert, we use multisets and mset_of to express the requirement on the contents of the output of insertion_sort.

theorem insertion_sort_contents: all xs:List<Nat>.
  mset_of(insertion_sort(xs)) = mset_of(xs)

The insertion_sort(xs) function is recursive, so we proceed by induction on xs. In the case for xs = empty, we conclude the following using the definitions of insertion_sort and mset_of.

  case empty {
    conclude mset_of(insertion_sort(empty)) = mset_of(empty)
      by definition {insertion_sort, mset_of}.
  }

In the case for xs = node(x, xs'), after applying the definitions of insertion_sort and mset_of, we need to show that

mset_of(insert(insertion_sort(xs'),x)) = m_one(x) ⨄ mset_of(xs')

This follows from the insert_contents theorem and the induction hypothesis as follows.

  equations
          mset_of(insert(insertion_sort(xs'),x)) 
        = m_one(x) ⨄ mset_of(insertion_sort(xs'))
          by insert_contents[insertion_sort(xs')][x]
    ... = m_one(x) ⨄ mset_of(xs')
          by rewrite IH.

Here is the complete proof of insertion_sort_contents.

theorem insertion_sort_contents: all xs:List<Nat>.
  mset_of(insertion_sort(xs)) = mset_of(xs)
proof
  induction List<Nat>
  case empty {
    conclude mset_of(insertion_sort(empty)) = mset_of(empty)
      by definition {insertion_sort, mset_of}.
  }
  case node(x, xs') suppose IH {
    definition {insertion_sort, mset_of}
    equations
            mset_of(insert(insertion_sort(xs'),x)) 
          = m_one(x) ⨄ mset_of(insertion_sort(xs'))
            by insert_contents[insertion_sort(xs')][x]
      ... = m_one(x) ⨄ mset_of(xs')
            by rewrite IH.
  }
end

Finally, we prove that insertion_sort(xs) produces a sorted list. Of course the proof is by induction on xs. The case for empty follows from the relevant definitions. The case for node(x, xs') follows from the insert_sorted theorem and the induction hypothesis.

theorem insertion_sort_sorted: all xs:List<Nat>. 
  sorted( insertion_sort(xs) )
proof
  induction List<Nat>
  case empty {
    conclude sorted(insertion_sort(empty))
        by definition {insertion_sort, sorted}.
  }
  case node(x, xs') suppose IH: sorted( insertion_sort(xs') ) {
    definition {insertion_sort, sorted}
    conclude sorted(insert(insertion_sort(xs'),x))
        by apply insert_sorted[insertion_sort(xs')][x]
           to IH
  }
end
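The two theorems, insertion_sort_contents and insertion_sort_sorted, can also be cross-checked by brute force on a Python model of the code. The sketch below enumerates every list up to a small length over a small range of values; the bounds, the helper name check_all_small_lists, and the use of collections.Counter for multisets are assumptions of this example. Such a check complements the proofs but does not replace them, since it covers only finitely many inputs.

```python
from collections import Counter
from itertools import product


def insert(xs, y):
    """Python stand-in for the Deduce `insert` (assumed, for illustration)."""
    if not xs:
        return [y]
    if y <= xs[0]:
        return [y] + xs
    return [xs[0]] + insert(xs[1:], y)


def insertion_sort(xs):
    if not xs:
        return []
    return insert(insertion_sort(xs[1:]), xs[0])


def is_sorted(xs):
    return all(a <= b for a, b in zip(xs, xs[1:]))


def check_all_small_lists(max_len=5, max_val=3):
    """Check both theorems on every list of length <= max_len over 0..max_val."""
    for n in range(max_len + 1):
        for xs in product(range(max_val + 1), repeat=n):
            out = insertion_sort(list(xs))
            if not is_sorted(out):              # insertion_sort_sorted
                return False
            if Counter(out) != Counter(xs):     # insertion_sort_contents
                return False
    return True


print(check_all_small_lists())   # True
```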

Exercise: tail-recursive variant of insertion_sort

The insertion_sort function uses more computer memory than necessary because it uses one frame on the procedure call stack for every element in the input list. This can be avoided if we instead implement Insertion Sort using a tail-recursive function, that is, a function that returns immediately after its recursive call. For this exercise, formulate a tail-recursive version of insertion_sort and prove that it is correct.

As a hint, define an auxiliary function isort(xs,ys) that takes a list xs and an already sorted list ys and returns a sorted list that includes the contents of both xs and ys.

function isort(List<Nat>, List<Nat>) -> List<Nat> {
  FILL IN HERE
}

Once you have defined isort, you can implement Insertion Sort as follows.

define insertion_sort2 : fn List<Nat> -> List<Nat>
    = λxs{ isort(xs, empty) }

Friday, June 14, 2024

Sequential Search, Correctly

Sequential Search

This is the second blog post in a series about developing correct implementations of basic data structures and algorithms using the Deduce language and proof checker.

In this blog post we’ll study a classic and simple algorithm known as Sequential Search (a.k.a. Linear Search). The basic idea of the algorithm is to look for the location of a particular item within a linked list, traversing the list front to back. Here is the specification of this search function.

Specification: The search(xs, y) function returns a natural number i such that i ≤ length(xs). If i < length(xs), then i is the index of the first occurrence of y in the list xs. If i = length(xs), y is not in the list xs.

We follow the write-test-prove approach to develop a correct implementation of search. We then propose two exercises for the reader.

Write the search function

Before diving into the code for search, let us look again at the definition of the List type.

union List<T> {
  empty
  node(T, List<T>)
}

We say that List is a recursive union because one of its constructors has a parameter that is also of the List type (namely, the second parameter of the node constructor).

In general, when defining a function with a parameter that is a recursive union, first consider making that function a recursive function that pattern-matches on that parameter.

For example, with search, we choose for the List<Nat> to be the first parameter so that we can pattern-match on it as follows.

function search(List<Nat>, Nat) -> Nat {
  search(empty, y) = ?
  search(node(x, xs'), y) = ?
}

Let us consider the case for the empty list. Looking at the specification of search, we need to return 0, the length of the empty list, because y is not in the empty list.

function search(List<Nat>, Nat) -> Nat {
  search(empty, y) = 0
  search(node(x, xs'), y) = ?
}

In the case for node(x, xs'), we can check whether x = y. If so, then we should return 0 because y is at index 0 of node(x, xs') and that is certainly the first occurrence of y in node(x, xs').

function search(List<Nat>, Nat) -> Nat {
  search(empty, y) = 0
  search(node(x, xs'), y) =
    if x = y then
      0
    else
      ?
}

If x ≠ y, then we need to search the rest of the list xs' for y. We can make the recursive call search(xs', y), but then we need to decide how to adapt its result to produce a result that makes sense for node(x, xs'). The only way to reason about the result of a recursive call is to use the specification of the function. The specification of search splits into two cases on the result: (1) search(xs', y) < length(xs') and (2) length(xs') ≤ search(xs', y). In case (1), search(xs',y) is returning the index of the first y inside xs'. Because x ≠ y, that location will also be the first y inside node(x, xs'). However, we need to add one to the index to take into account that we’re adding a node to the front of the list. So for case (1), the result should be suc(search(xs', y)). In case (2), search(xs',y) did not find y in xs', so it is returning length(xs'). Because x ≠ y, we need to indicate that y is also not found in node(x, xs'), so we need to return length(node(x, xs')). Thus, we need to add one to the index, so the result should again be suc(search(xs', y)).

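Here is the completed code for search.

function search(List<Nat>, Nat) -> Nat {
  search(empty, y) = 0
  search(node(x, xs'), y) =
    if x = y then
      0
    else
      suc(search(xs', y))
}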

Test the search function

Focusing on the specification of search, there are several things that we should test. First, we should test whether search always returns a number that is less-or-equal to the length of the list. We can use all_elements and interval to automate the testing over a bunch of values, some of which are in the list and some are not.

define list_1223 = node(1, node(2, node(2, node(3, empty))))

assert all_elements(interval(0, 5),
  λx{ search(list_1223, x) ≤ length(list_1223) })

Most importantly, we should test whether search finds the correct index of the elements in the list. To do that we can make use of nth to look up the element at a given index.

assert all_elements(list_1223,
  λx{ nth(list_1223, 0)( search(list_1223, x) ) = x })

Next, we should test whether search finds the first occurrence. We can do this by iterating over all the indexes and checking that what search returns is an index that is less-than or equal to the current index.

assert all_elements(interval(0, length(list_1223)),
   λi{ search(list_1223, nth(list_1223, 0)(i)) ≤ i })

Finally, we check that search fails gracefully when the value being searched for is not present in the list.

assert search(list_1223, 0) = length(list_1223)
assert search(list_1223, 4) = length(list_1223)
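The four checks above can be mirrored in a Python model of search. This is a sketch under stated assumptions: Python lists stand in for List<Nat>, plain indexing stands in for nth, and the candidate values 0 through 4 are this example's own choice of values inside and outside the list rather than a claim about what interval(0, 5) denotes in Deduce.

```python
def search(xs, y):
    """Index of the first y in xs, or len(xs) if y is absent."""
    if not xs:                        # search(empty, y) = 0
        return 0
    if xs[0] == y:                    # found at the head
        return 0
    return 1 + search(xs[1:], y)      # suc(search(xs', y))


list_1223 = [1, 2, 2, 3]

# The result never exceeds the length of the list.
assert all(search(list_1223, x) <= len(list_1223) for x in range(0, 5))
# The result is an index of the value searched for.
assert all(list_1223[search(list_1223, x)] == x for x in list_1223)
# The result is the first such index.
assert all(search(list_1223, list_1223[i]) <= i
           for i in range(len(list_1223)))
# Absent values yield the length of the list.
assert search(list_1223, 0) == len(list_1223)
assert search(list_1223, 4) == len(list_1223)

print([search(list_1223, x) for x in range(0, 5)])   # [4, 0, 1, 3, 4]
```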

Prove search Correct

We break down the specification of search into four parts and prove four theorems.

Prove search is less-or-equal length

The first part of the specification of search says that the search(xs, y) function returns a natural number i such that i ≤ length(xs). Because search is recursive, we’re going to prove this by induction on its first parameter xs.

theorem search_length: all xs:List<Nat>. all y:Nat.
  search(xs, y) ≤ length(xs)
proof
  induction List<Nat>
  case empty {
    ?
  }
  case node(x, xs') 
    suppose IH: all y:Nat. search(xs',y) ≤ length(xs') 
  {
    ?
  }
end

In the case for xs = empty, Deduce tells us that we need to prove

Goal:
    all y:Nat. search(empty,y) ≤ length(empty)

So we start with arbitrary y:Nat and then conclude using the definitions of search, length, and operator ≤.

  case empty {
    arbitrary y:Nat
    conclude search(empty,y) ≤ length(empty)
        by definition {search, length, operator ≤}.
  }

In the case for xs = node(x, xs'), Deduce tells us that we need to prove

Goal:
    all y:Nat. search(node(x,xs'),y) ≤ length(node(x,xs'))

So we start with arbitrary y:Nat and use the definitions of search and length.

  case node(x, xs') 
    suppose IH: all y:Nat. search(xs',y) ≤ length(xs') 
  {
    arbitrary y:Nat
    definition {search, length}
    ?
  }

The goal is transformed to the following, with the body of search expanded on the left of the ≤ and the body of length expanded on the right.

Goal:
    if x = y then 0 else suc(search(xs',y)) ≤ suc(length(xs'))

In general, it is a good idea to let the structure of the code direct the structure of your proof. In this case, the code for search is a conditional on x = y, so in our proof we can switch on x = y as follows.

  case node(x, xs') 
    suppose IH: all y:Nat. search(xs',y) ≤ length(xs') 
  {
    arbitrary y:Nat
    definition {search, length}
    switch x = y {
      case true {
        ?
      }
      case false {
        ?
      }
    }
  }

In the case for x = y, the left-hand side of the ≤ becomes 0, so we can conclude by the definition of operator ≤.

  case true {
    conclude 0 ≤ suc(length(xs'))  by definition operator ≤.
  }

In the case for x ≠ y, the left-hand side of the ≤ becomes suc(search(xs',y)), so we have suc on both sides of the ≤. Therefore we apply the definition of ≤ and it remains to prove the following.

Goal:
    search(xs',y) ≤ length(xs')

We conclude the proof of the false case by using the induction hypothesis.

  case false {
    definition operator ≤
    conclude search(xs',y) ≤ length(xs')
      by IH[y]
  }

Putting all of the pieces together, we have a complete proof of search_length.

theorem search_length: all xs:List<Nat>. all y:Nat.
  search(xs, y) ≤ length(xs)
proof
  induction List<Nat>
  case empty {
    arbitrary y:Nat
    conclude search(empty,y) ≤ length(empty)
        by definition {search, length, operator ≤}.
  }
  case node(x, xs') 
    suppose IH: all y:Nat. search(xs',y) ≤ length(xs') 
  {
    arbitrary y:Nat
    definition {search, length}
    switch x = y {
      case true {
        conclude 0 ≤ suc(length(xs'))  by definition operator ≤.
      }
      case false {
        definition operator ≤
        conclude search(xs',y) ≤ length(xs')
          by IH[y]
      }
    }
  }
end

Prove search(xs, y) finds an occurrence of y

The specification of search(xs, y) says that if the result is less-than length(xs), then the result is the index of the first occurrence of y in xs. First off, this means that search(xs, y) is indeed an index for y, which we can express using nth as follows.

nth(xs, 0)( search(xs, y) ) = y

So we can formulate the following theorem, which we’ll prove by induction on xs.

theorem search_present: all xs:List<Nat>. all y:Nat.
  if search(xs, y) < length(xs)
  then nth(xs, 0)( search(xs, y) ) = y
proof
  induction List<Nat>
  case empty {
    arbitrary y:Nat
    ?
  }
  case node(x, xs') suppose IH {
    arbitrary y:Nat
    ?
  }
end

In the case for xs = empty, Deduce tells us that we need to prove

Goal:
    (if search(empty,y) < length(empty) then nth(empty,0)(search(empty,y)) = y)

Proceeding in a goal-directed way, we suppose the premise but then realize that the premise is false. So we can conclude using the principle of explosion.

  case empty {
    arbitrary y:Nat
    suppose prem: search(empty,y) < length(empty)
    conclude false by definition {search, length, operator <, operator ≤} 
                      in prem
  }

Moving on to the case for xs = node(x, xs'), we again suppose the premise.

  case node(x, xs') suppose IH {
    arbitrary y:Nat
    suppose sxxs_len: search(node(x,xs'),y) < length(node(x,xs'))
    ?
  }

Deduce tells us that we need to prove

Goal:
    nth(node(x,xs'),0)(search(node(x,xs'),y)) = y

We see search applied to a node argument, so we can expand its definition. (We could also expand nth but we postpone doing that for the sake of readability.)

Goal:
    nth(node(x,xs'),0)((if x = y then 0 else suc(search(xs',y)))) = y

Similar to the proof of search_length, we now need to switch on x = y.

    definition {search}
    switch x = y {
      case true suppose xy_true {
        ?
      }
      case false suppose xy_false {
        ?
      }
    }

In the case where x = y, Deduce tells us that we need to prove

Goal:
    nth(node(x,xs'),0)(0) = y

We conclude using the definition of nth and the fact that x = y.

      case true suppose xy_true {
        conclude nth(node(x,xs'),0)(0) = y
          by definition nth rewrite xy_true.
      }

In the case where x ≠ y, we need to prove

Goal:
    nth(node(x,xs'),0)(suc(search(xs',y))) = y

Now if we apply the definitions of nth and pred, the goal becomes the following.

Goal:
    nth(xs',0)(search(xs',y)) = y

This looks a lot like the conclusion of our induction hypothesis:

Givens:
    ...
    IH: all y:Nat. (if search(xs',y) < length(xs') 
                    then nth(xs',0)(search(xs',y)) = y)

So we just need to prove the premise, that search(xs',y) < length(xs'). Thankfully, that can be proved from the premise search(node(x,xs'),y) < length(node(x,xs')).

  have sxs_len: search(xs',y) < length(xs')
    by enable {search, length, operator <, operator ≤}
       rewrite xy_false in sxxs_len

We conclude by applying the induction hypothesis.

  conclude nth(xs',0)(search(xs',y)) = y
    by apply IH[y] to sxs_len

Here is the complete proof of search_present.

theorem search_present: all xs:List<Nat>. all y:Nat.
  if search(xs, y) < length(xs)
  then nth(xs, 0)( search(xs, y) ) = y
proof
  induction List<Nat>
  case empty {
    arbitrary y:Nat
    suppose prem: search(empty,y) < length(empty)
    conclude false by definition {search, length, operator <, operator ≤} 
                      in prem
  }
  case node(x, xs') suppose IH {
    arbitrary y:Nat
    suppose sxxs_len: search(node(x,xs'),y) < length(node(x,xs'))
    definition {search}
    switch x = y {
      case true suppose xy_true {
        conclude nth(node(x,xs'),0)(0) = y
          by definition nth rewrite xy_true.
      }
      case false suppose xy_false {
        definition {nth, pred}
        have sxs_len: search(xs',y) < length(xs')
          by enable {search, length, operator <, operator ≤}
             rewrite xy_false in sxxs_len
        conclude nth(xs',0)(search(xs',y)) = y
          by apply IH[y] to sxs_len
      }
    }
  }
end
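To see what search_present says on a concrete input, here is a small sketch using assert statements. The name list_789 is made up for this illustration, and we assume search is defined as earlier in this post (it returns 0 when the head matches and otherwise one more than the result for the rest of the list).

```deduce
// list_789 is a hypothetical test list, not part of the development above.
define list_789 : List<Nat> = node(7, node(8, node(9, empty)))

// 8 is present, so the index returned by search holds 8.
assert search(list_789, 8) = 1
assert nth(list_789, 0)(search(list_789, 8)) = 8

// 6 is absent, so search returns the length and nth falls back to the default 0.
assert search(list_789, 6) = 3
assert nth(list_789, 0)(search(list_789, 6)) = 0
```

The last two assertions suggest why the theorem needs the premise search(xs, y) < length(xs): without it, nth returns the default value rather than y.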

Prove search(xs, y) finds the first occurrence of y

Going back to the specification of search(xs, y), it says that if the result is less than length(xs), then the result is the index of the first occurrence of y in xs. To be the first means that the result is less than or equal to the index of any occurrence of y. We express that in the following theorem.

theorem search_first: all xs:List<Nat>. all y:Nat, i:Nat.
  if search(xs, y) < length(xs) and nth(xs, 0)(i) = y
  then search(xs, y) ≤ i
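
Before proving this, it can help to test the statement on a list with a repeated element. The following is a sketch; list_dup is a name invented for this example, and search is assumed to be defined as earlier in this post.

```deduce
// list_dup is a hypothetical test list in which 5 occurs at positions 1 and 3.
define list_dup : List<Nat> = node(4, node(5, node(6, node(5, empty))))

assert nth(list_dup, 0)(1) = 5
assert nth(list_dup, 0)(3) = 5

// search reports the earlier of the two positions.
assert search(list_dup, 5) = 1
assert search(list_dup, 5) ≤ 3
```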

We proceed by induction on xs. We can handle the case for xs = empty in the same way as in search_present; the premise is false.

  induction List<Nat>
  case empty {
    arbitrary y:Nat, i:Nat
    suppose prem: search(empty,y) < length(empty) and nth(empty,0)(i) = y
    conclude false by definition {search, length, operator <, operator ≤} 
                      in prem
  }

In the case for xs = node(x, xs'), we proceed in a goal-directed fashion with an arbitrary and suppose.

  case node(x, xs') suppose IH {
    arbitrary y:Nat, i:Nat
    suppose prem: search(node(x,xs'),y) < length(node(x,xs')) 
                  and nth(node(x,xs'),0)(i) = y
    ?
  }

Deduce responds with

Goal:
    search(node(x,xs'),y) ≤ i

We apply the definition of search and then switch on x = y.

  definition search
  switch x = y {
    case true {
      ?
    }
    case false suppose xy_false {
      ?
    }
  }

In the case where x = y, the result of search is 0, so we just need to prove that 0 ≤ i, which follows from the definition of ≤.

  case true {
    conclude 0 ≤ i   by definition operator ≤.
  }

In the case where x ≠ y, we need to prove

Goal:
    suc(search(xs',y)) ≤ i

What do we know about i? The premise nth(node(x,xs'),0)(i) = y, together with x ≠ y, tells us that i ≠ 0. So we can switch on i and use the principle of explosion to handle the case where i = 0.

  case 0 suppose i_z: i = 0 {
    conclude false
      by enable nth rewrite i_z | xy_false in prem
  }

We are left with the case where i = suc(i'), which we record as the given i_si. Using the definition of ≤, the goal becomes

Goal:
    search(xs',y) ≤ i'

This looks like the conclusion of the induction hypothesis instantiated at i'.

Givens:
    ...
    IH: all y:Nat, i:Nat. (if search(xs',y) < length(xs') and nth(xs',0)(i) = y then search(xs',y) ≤ i)

So we need to prove the two premises of the IH. They follow from the given prem:

Givens:
    prem: search(node(x,xs'),y) < length(node(x,xs')) and nth(node(x,xs'),0)(i) = y,

In particular, the first premise of IH follows from the first conjunct of prem.

  have sxs_len: search(xs',y) < length(xs')
    by enable {search, length, operator <, operator ≤}
       rewrite xy_false in (conjunct 0 of prem)

The second premise of the IH follows from the second conjunct of prem.

  have nth_i_y: nth(xs',0)(i') = y
    by enable {nth, pred} rewrite i_si in (conjunct 1 of prem)

We conclude the case for i = suc(i') by applying the induction hypothesis.

  conclude search(xs',y) ≤ i'   by apply IH[y,i'] to sxs_len, nth_i_y

Here is the complete proof of search_first.

theorem search_first: all xs:List<Nat>. all y:Nat, i:Nat.
  if search(xs, y) < length(xs) and nth(xs, 0)(i) = y
  then search(xs, y) ≤ i
proof
  induction List<Nat>
  case empty {
    arbitrary y:Nat, i:Nat
    suppose prem: search(empty,y) < length(empty) and nth(empty,0)(i) = y
    conclude false by definition {search, length, operator <, operator ≤} 
                      in prem
  }
  case node(x, xs') suppose IH {
    arbitrary y:Nat, i:Nat
    suppose prem: search(node(x,xs'),y) < length(node(x,xs')) 
                  and nth(node(x,xs'),0)(i) = y
    definition search
    switch x = y {
      case true {
        conclude 0 ≤ i   by definition operator ≤.
      }
      case false suppose xy_false {
        switch i {
          case 0 suppose i_z: i = 0 {
            conclude false
              by enable nth rewrite i_z | xy_false in prem
          }
          case suc(i') suppose i_si: i = suc(i') {
            definition operator ≤
            have sxs_len: search(xs',y) < length(xs')
              by enable {search, length, operator <, operator ≤}
                 rewrite xy_false in (conjunct 0 of prem)
            have nth_i_y: nth(xs',0)(i') = y
              by enable {nth, pred} rewrite i_si in (conjunct 1 of prem)
            conclude search(xs',y) ≤ i'   by apply IH[y,i'] to sxs_len, nth_i_y
          }
        }
      }
    }
  }
end

Prove that search fails only when it should

The last sentence in the specification for search(xs, y) says that if i = length(xs), then y is not in the list xs. How do we express that y is not in the list? In some sense, that is what search is for, but it would be vacuous to prove a theorem that says search returns length(xs) if search returns length(xs). Instead we need an alternative and intuitive way to express membership in a list.

One approach to expressing list membership that works well is to convert the list to a set and then use set membership. The file Set.pf defines the Set type and operations on sets such as membership, union, and intersection. The Set.pf file also proves many theorems about these operations. The following set_of function converts a list into a set.

function set_of<T>(List<T>) -> Set<T> {
  set_of(empty) = ∅
  set_of(node(x, xs)) = single(x) ∪ set_of(xs)
}

We can now express our last correctness theorem for search as follows.

theorem search_absent: all xs:List<Nat>. all y:Nat, d:Nat.
  if search(xs, y) = length(xs)
  then not (y ∈ set_of(xs))
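
As a quick sanity check before the proof, here is a sketch that tests the statement on one list. The name list_246 is invented for this example, and we are assuming that assert can evaluate membership in sets built from single and ∪, which may depend on how Set.pf represents sets.

```deduce
// list_246 is a hypothetical test list.
define list_246 : List<Nat> = node(2, node(4, node(6, empty)))

// 5 is absent: search returns the length, and 5 is not in the set.
assert search(list_246, 5) = length(list_246)
assert not (5 ∈ set_of(list_246))

// 4 is present: search returns a smaller index, and 4 is in the set.
assert search(list_246, 4) = 1
assert 4 ∈ set_of(list_246)
```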

We proceed by induction on xs. In the case for xs = empty, we take the following goal-directed steps

  case empty {
    arbitrary y:Nat, d:Nat
    suppose _
    ?
  }

and Deduce responds with

Goal:
    (if y ∈ set_of(empty) then false)

which we prove using the empty_no_members theorem from Set.pf.

  conclude not (y ∈ set_of(empty))
      by definition {set_of} empty_no_members[Nat,y]

Turning to the case for xs = node(x, xs'), we take several goal-directed steps.

  case node(x, xs') suppose IH {
    arbitrary y:Nat, d:Nat
    suppose s_xxs_len_xxs: search(node(x,xs'),y) = length(node(x,xs'))
    definition set_of
    ?

We need to show that y is not in node(x, xs'), which amounts to the following.

Goal:
    (if y ∈ single(x) ∪ set_of(xs') then false)

Towards proving a contradiction, we can assume y ∈ single(x) ∪ set_of(xs').

  suppose y_in_x_union_xs: y ∈ single(x) ∪ set_of(xs')

Now the main information we have to work with is the premise s_xxs_len_xxs above, concerning search(node(x,xs'), y). Thinking about the code for search, we know it will branch on whether x = y, so we better switch on that.

  switch x = y {
    case true suppose xy {
      ?
    }
    case false suppose xy_false {
      ?
    }
  }

In the case where x = y, we have search(node(x,xs'),y) = 0 but length(node(x,xs')) is at least 1, so we have a contradiction.

  case true suppose xy {
    have s_xxs_0: search(node(x,xs'),y) = 0
        by definition search  rewrite xy.
    have z_len_xxs: 0 = length(node(x,xs'))
        by rewrite s_xxs_0 in s_xxs_len_xxs
    conclude false  by definition length in z_len_xxs
  }

In the case where x ≠ y, we can show that y ∈ set_of(xs') and then invoke the induction hypothesis to obtain the contradiction. In particular, the premise y_in_x_union_xs gives us the following.

  have ysx_or_y_xs: y ∈ single(x) or y ∈ set_of(xs')
      by apply member_union[Nat,y,single(x),set_of(xs')]
         to y_in_x_union_xs

But x ≠ y implies not (y ∈ single(x)).

  have not_ysx: not (y ∈ single(x))
    by suppose ysx
       rewrite xy_false in
       apply single_equal[Nat,x,y] to ysx

So it must be that y ∈ set_of(xs') (using or_not from Base.pf).

  have y_xs: y ∈ set_of(xs')
    by apply or_not[y ∈ single(x), y ∈ set_of(xs')] 
       to ysx_or_y_xs, not_ysx

To satisfy the premise of the induction hypothesis, we prove the following.

  have sxs_lxs: search(xs',y) = length(xs')
    by injective suc
       rewrite xy_false in definition {search,length} in
       s_xxs_len_xxs

So we apply the induction hypothesis to get y ∉ set_of(xs'), which contradicts y ∈ set_of(xs').

  have y_not_xs: not (y ∈ set_of(xs'))
    by apply IH[y,d] to sxs_lxs
  conclude false  by apply y_not_xs to y_xs

Here is the complete proof of search_absent.

theorem search_absent: all xs:List<Nat>. all y:Nat, d:Nat.
  if search(xs, y) = length(xs)
  then not (y ∈ set_of(xs))
proof
  induction List<Nat>
  case empty {
    arbitrary y:Nat, d:Nat
    suppose _
    conclude not (y ∈ set_of(empty))
        by definition {set_of} empty_no_members[Nat,y]
  }
  case node(x, xs') suppose IH {
    arbitrary y:Nat, d:Nat
    suppose s_xxs_len_xxs: search(node(x,xs'),y) = length(node(x,xs'))
    definition set_of
    suppose y_in_x_union_xs: y ∈ single(x) ∪ set_of(xs')
    switch x = y {
      case true suppose xy {
        have s_xxs_0: search(node(x,xs'),y) = 0
            by definition search  rewrite xy.
        have z_len_xxs: 0 = length(node(x,xs'))
            by rewrite s_xxs_0 in s_xxs_len_xxs
        conclude false  by definition length in z_len_xxs
      }
      case false suppose xy_false {
        have ysx_or_y_xs: y ∈ single(x) or y ∈ set_of(xs')
            by apply member_union[Nat,y,single(x),set_of(xs')]
               to y_in_x_union_xs
        have not_ysx: not (y ∈ single(x))
          by suppose ysx
             rewrite xy_false in
             apply single_equal[Nat,x,y] to ysx
        have y_xs: y ∈ set_of(xs')
          by apply or_not[y ∈ single(x), y ∈ set_of(xs')] 
             to ysx_or_y_xs, not_ysx
        have sxs_lxs: search(xs',y) = length(xs')
          by injective suc
             rewrite xy_false in definition {search,length} in
             s_xxs_len_xxs
        have y_not_xs: not (y ∈ set_of(xs'))
          by apply IH[y,d] to sxs_lxs
        conclude false  by apply y_not_xs to y_xs
      }
    }
  }
end

Exercise search_last

Apply the write-test-prove approach to develop a correct implementation of the search_last(xs, y) function, which is like search(xs, y) except that it finds the last occurrence of y in xs instead of the first.

In particular, you need to

  • write a specification for search_last,
  • write the code for search_last,
  • test search_last on diverse inputs, and
  • prove that search_last is correct.

function search_last(List<Nat>, Nat) -> Nat {
    FILL IN HERE
}
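
As a starting point for the testing step, here is a sketch of a few assertions. They assume one possible specification, namely that search_last returns the index of the last occurrence and, like search, returns length(xs) when y is absent. Your specification may differ, in which case the tests should change with it. The name list_5_twice is made up for this example.

```deduce
// list_5_twice is a hypothetical test list in which 5 occurs at positions 1 and 3.
define list_5_twice : List<Nat> = node(4, node(5, node(6, node(5, empty))))

// Assumed behavior: report the later of the two positions.
assert search_last(list_5_twice, 5) = 3
// An element that occurs once is found where it is.
assert search_last(list_5_twice, 4) = 0
// Assumed convention for absence: return the length of the list.
assert search_last(list_5_twice, 7) = length(list_5_twice)
```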

Exercise search_if

The search_if(xs, P) function is a generalization of search(xs, y). Instead of searching for the first occurrence of element y, the search_if function searches for the location of the first element that satisfies the predicate P (i.e. an element y in xs such that P(y) is true). Apply the write-test-prove approach to develop a correct implementation of search_if.

In particular, you need to

  • write a specification for search_if,
  • write the code for search_if,
  • test search_if on diverse inputs, and
  • prove that search_if is correct.

function search_if<T>(List<T>, fn T->bool) -> Nat {
    FILL IN HERE
}
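
To illustrate how a predicate can be passed to search_if, here is a sketch of some tests. They assume that search_if returns length(xs) when no element satisfies P, mirroring search. That convention is our assumption, not part of the exercise statement. The name list_456 is made up for this example.

```deduce
// list_456 is a hypothetical test list.
define list_456 : List<Nat> = node(4, node(5, node(6, empty)))

// The first element that is at least 5 sits at position 1.
assert search_if(list_456, λx{ 5 ≤ x }) = 1
// Searching for equality with 6 should agree with search(list_456, 6).
assert search_if(list_456, λx{ x = 6 }) = 2
// Assumed convention: no element satisfies the predicate, so return the length.
assert search_if(list_456, λx{ x = 9 }) = length(list_456)
```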

Wednesday, June 12, 2024

Data Structures and Algorithms, Correctly

Prelude

This is the first in what I hope to be a sequence of blog posts about (1) data structures and algorithms, (2) an approach to constructing correct code, and (3) achieving a deeper understanding of testing, logic, and proof, all of which are needed for constructing correct code. These blog posts take a functional programming approach to data structures and algorithms because, in that setting, there are software tools that make sure that our proofs about correctness are themselves correct! In particular, these posts will use the Deduce language for writing programs, testing them, and proving theorems. Unlike most functional languages and proof assistants, the syntax of Deduce is meant to be easy to learn for people familiar with languages such as Java or Python. The README.md file in the Deduce github repository provides an introduction to Deduce. We recommend reading that first.

https://github.com/jsiek/deduce/tree/main

These blog posts will cover a limited number of the data structures and algorithms, as the pace will be slower due to the increased focus on correctness. The rough plan is to cover the following topics.

  • Linked Lists (this post)
  • Sequential Search
  • Insertion Sort
  • Merge Sort
  • Binary Trees
  • Binary Search Trees
  • Balanced Binary Search Trees
  • Heaps and Priority Queues

Introduction to Linked Lists

A linked list is a data structure that represents a sequence of elements. Each element is stored inside a node and each node also stores a link to the next node, or to the special empty value that signifies the end of the list. In Deduce we can implement linked lists with the following union type.

union List<T> {
  empty
  node(T, List<T>)
}

For example, the sequence of numbers 1, 2, 3 is represented by the following linked list.

define list_123 : List<Nat> = node(1, node(2, node(3, empty)))

Next we introduce two fundamental operations on linked lists. The first operation is length, which returns the number of elements in a given list. The length of an empty list is 0 and the length of a list that starts with a node is one more than the length of the list starting at the next node.

function length<E>(List<E>) -> Nat {
  length(empty) = 0
  length(node(n, next)) = suc(length(next))
}

Of course, the length of list_123 is 3. We can ask Deduce to check this fact using the assert statement.

assert length(list_123) = 3

The return type of length is Nat which stands for natural number (that is, the non-negative integers). The suc(n) constructor represents 1 + n and is short for successor. The pred(n) function is short for predecessor and computes n - 1, except that pred(0) = 0.

import Nat
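
Here is a small sketch that checks the behavior of suc and pred described above, including the special case pred(0) = 0.

```deduce
// suc adds one, pred subtracts one but stops at zero.
assert suc(2) = 3
assert pred(3) = 2
assert pred(0) = 0
```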

The second fundamental operation on linked lists is nth(xs,d)(i), which retrieves the element at position i in the list xs. However, if i is greater than or equal to the length of xs, then nth returns the default value d.

function nth<T>(List<T>, T) -> (fn Nat -> T) {
  nth(empty, default) = λi{default}
  nth(node(x, xs), default) = λi{
    if i = 0 then
      x
    else
      nth(xs, default)(pred(i))
  }
}

Here are examples of applying nth to the list 1, 2, 3, using 0 as the default value.

assert nth(list_123, 0)(0) = 1
assert nth(list_123, 0)(1) = 2
assert nth(list_123, 0)(2) = 3
assert nth(list_123, 0)(3) = 0

We have formulated the nth operation in an unusual way. It has two parameters and returns a function of one parameter that returns an element T. We could have instead made nth take three parameters and directly return an element T. We made this design choice because it means we can use nth with several other functions and theorems that work with functions of the type fn Nat -> T.
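
To make the two-stage formulation concrete, the following sketch applies nth to its first two arguments only and gives the resulting function a name. The name get_123 is made up for this illustration.

```deduce
// get_123 is a hypothetical name for nth applied to list_123 with default 0.
// It is a function from positions to elements.
define get_123 = nth(list_123, 0)

assert get_123(0) = 1
assert get_123(2) = 3
// Past the end of the list, the default value comes back.
assert get_123(3) = 0
```

A function such as get_123 can be handed to anything that expects a fn Nat -> T, which is the flexibility this design choice is after.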

Correct Software via Write, Test, and Prove

We recommend a three step process to constructing correct software.

  1. Write down the specification and the code for a subcomponent, such as a function.
  2. Test the function on a diverse choice of inputs. If all the tests pass, proceed to step 3, otherwise return to step 1.
  3. Prove that the function is correct with respect to its specification.

We recognize that once step 3 is complete, step 2 is obsolete because a proof of correctness supersedes any amount of testing. However there is a good reason to perform testing even when you are planning to do a proof of correctness. More often than not, your code will have one or more bugs. Testing is a fast way to detect most of the bugs. When you detect a bug, you’ll need to revise the code and then re-run the tests. On the other hand, proving correctness is a much slower way to detect bugs. You will spend a relatively long time to get part-way through a proof and realize that there is no way to finish. You’ll then need to revise the code. But because of the changes in the code, much of the proof will need to change. So you’ll spend a significant amount of time refactoring the parts of the proof that you’ve already completed before continuing on to the new parts. Therefore, to reduce the number of relatively-costly proof attempts, it is a good idea to first spend a relatively short amount of time to test and fix the code.

Example: Intervals

As an example of the write-test-prove approach, we consider the interval function.

Specification: interval(count, start) returns a list of natural numbers of length count, where the element at position i is i + start.

For example, interval(3,5) produces the list 5, 6, 7:

assert interval(3, 5) = node(5, node(6, node(7, empty)))

Write interval

A straightforward way to implement interval in Deduce is to define it as a function that pattern-matches on the count.

function interval(Nat, Nat) -> List<Nat> {
  interval(0, n) = ?
  interval(suc(k), n) = ?
}

For the clause where count = 0, we must return a list of length 0. So our only choice is the empty list.

  interval(0, n) = empty

For the clause where count = suc(k), we must return a list of length suc(k). So it has at least one node.

  interval(suc(k), n) = node(?, ?)

The specification tells us that the element at position 0 of the return value is n + 0 or simply n.

  interval(suc(k), n) = node(n, ?)

The next of this node should be a list of length k that starts with the element n + 1. Thankfully we can construct such a list with a recursive call to interval.

  interval(suc(k), n) = node(n, interval(k, suc(n)))

Putting these pieces together, we have the following complete definition of interval.

function interval(Nat, Nat) -> List<Nat> {
  interval(0, n) = empty
  interval(suc(k), n) = node(n, interval(k, suc(n)))
}
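
To see the recursion at work, here is a sketch that unfolds interval(3, 5) by hand in the comments, followed by an assertion on a second, smaller input.

```deduce
// Unfolding the two clauses of interval step by step:
//   interval(3, 5)
//     = node(5, interval(2, 6))
//     = node(5, node(6, interval(1, 7)))
//     = node(5, node(6, node(7, interval(0, 8))))
//     = node(5, node(6, node(7, empty)))
assert interval(2, 0) = node(0, node(1, empty))
```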

Test interval

Let us test that our definition of interval is behaving the way we expect it to. In general, one should test many variations of each input to a function. Here we test with the values 0, 1 and 2 for the first parameter and 0 and 3 for the second parameter.

assert length(interval(0, 0)) = 0

assert length(interval(1, 0)) = 1
assert nth(interval(1, 0), 7)(0) = 0 + 0

assert length(interval(2, 0)) = 2
assert nth(interval(2, 0), 7)(0) = 0 + 0
assert nth(interval(2, 0), 7)(1) = 1 + 0

assert length(interval(0, 3)) = 0

assert length(interval(1, 3)) = 1
assert nth(interval(1, 3), 7)(0) = 0 + 3

assert length(interval(2, 3)) = 2
assert nth(interval(2, 3), 7)(0) = 0 + 3
assert nth(interval(2, 3), 7)(1) = 1 + 3

Yeah! All of these assert statements execute without error.

We have formulated these assert statements in a subtly different way than above. When we tested the length and nth functions, we wrote assert statements that compared the results to our expected output. Here we have instead written the assert statements based on the specification of interval(count, start). The specification says that the length of the output should be the same as the count parameter. So in the above we wrote assert statements that check whether the length is the same as the count. Furthermore, the specification says that the element at position i of the output is i + start. So we have used the nth function to check, for every position i in the output list, whether the element is i + start.

The benefit of writing tests based on the specification is that it reduces the possibility of discrepancies between the specification and the tests. After all, what it means for a function to be correct is that it behaves according to its specification, not that it passes some ad-hoc tests based on a loose interpretation of the specification.

In general, when a test fails, it often means that either the implementation of the function-under-test is incorrect, or the test itself is incorrect. A careful reading of the function’s specification will help you figure out which is at fault. Unfortunately, it is also possible for the specification to be incorrect! The good thing about the testing approach described here is that it helps to reveal inconsistencies between the specification, the tests, and the implementation.

Prove interval Correct

Once we have finished testing interval we can move on to proving that interval is correct for all inputs. Looking back at the specification of interval, there are two parts. We will prove each part with a separate theorem.

Prove the interval_length theorem

The first part of the specification says that interval(count, start) returns a list of length count. We want to prove that this is true for all possible choices of count and start, so we shall use Deduce’s all formula. Recall that there are two ways to prove an all formula in Deduce: 1) using arbitrary or 2) using induction. When proving a theorem about a recursive function, one typically needs to use induction for the first parameter of the function, in this case count. So our initial plan is to use induction for count and arbitrary for start. Because we are going to use different proof methods for each variable, we need to use a separate all formula for each one, as follows.

theorem interval_length:
  all count:Nat. all start:Nat. length(interval(count, start)) = count
proof
  ?
end

There is also the question of whether all count:Nat should come before or after all start:Nat. It is always safe to first choose the variable for which you’re using induction. If you make the other choice, the induction hypothesis will be weaker, which sometimes is convenient but other times prevents the proof from going through.

Now let us start the proof. We proceed by induction on the count.

theorem interval_length:
  all count:Nat. all start:Nat. length(interval(count, start)) = count
proof
  induction Nat
  case 0 {
    ?
  }
  case suc(count') suppose IH {
    ?
  }
end

In the case for count = 0, Deduce tells us that we need to prove

  all start:Nat. length(interval(0,start)) = 0

As mentioned earlier, we’ll use arbitrary for start.

  case 0 {
    arbitrary start:Nat
    ?
  }

So now we need to prove

  length(interval(0,start)) = 0

Of course, by definition we have interval(0,start) = empty and length(empty) = 0, so we can conclude using those definitions.

  case 0 {
    arbitrary start:Nat
    conclude length(interval(0, start)) = 0
        by definition {interval, length}.
  }

Turning to the case count = suc(count'), Deduce tells us the goal for this case and the induction hypothesis.

incomplete proof
Goal:
    all start:Nat. length(interval(suc(count'),start)) = suc(count')
Givens:
    IH: all start:Nat. length(interval(count',start)) = count'

To improve readability of the proof, I often like to copy the formula for the induction hypothesis and paste it into the suppose as shown below.

  case suc(count') 
    suppose IH: all start:Nat. length(interval(count', start)) = count' 
  {
    ?
  }

For the proof of this case, we again start with arbitrary to handle all start then use the definitions of interval and length.

  case suc(count')
    suppose IH: all start:Nat. length(interval(count', start)) = count'
  {
    arbitrary start:Nat
    definition {interval, length}
    ?
  }

Deduce tells us that we need to prove the following.

suc(length(interval(count',suc(start)))) = suc(count')

Here the induction hypothesis IH comes to the rescue. If we instantiate the all start with suc(start), we get

length(interval(count',suc(start))) = count'

which is just what we need to conclude.

  case suc(count') 
    suppose IH: all start:Nat. length(interval(count', start)) = count' 
  {
    arbitrary start:Nat
    definition {interval, length}
    conclude suc(length(interval(count',suc(start)))) = suc(count')
        by rewrite IH[suc(start)].
  }

Putting the two cases together, we have the following completed proof that the output of interval has the appropriate length.

theorem interval_length:
  all count:Nat. all start:Nat. length(interval(count, start)) = count
proof
  induction Nat
  case 0 {
    arbitrary start:Nat
    conclude length(interval(0, start)) = 0
        by definition {interval, length}.
  }
  case suc(count')
    suppose IH: all start:Nat. length(interval(count', start)) = count' 
  {
    arbitrary start:Nat
    definition {interval, length}
    conclude suc(length(interval(count',suc(start)))) = suc(count')
        by rewrite IH[suc(start)].
  }
end

Prove the interval_nth theorem

The second part of the specification of interval says that the element at position i of the output is i + start. Of course, there is no element at position i if i is too big, so our theorem needs to be conditional, with the premise i < count.

theorem interval_nth: all count:Nat. all start:Nat, d:Nat, i:Nat.
  if i < count
  then nth(interval(count, start), d)(i) = i + start
proof
   ?
end
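
To see why the premise i < count is needed, consider this small sketch. We use 9 as the default value so that it is easy to tell apart from the elements of the list.

```deduce
// In range (2 < 3): the element at position 2 is 2 + 5.
assert nth(interval(3, 5), 9)(2) = 2 + 5

// Out of range (3 is not less than 3): nth returns the default 9,
// which differs from 3 + 5, so the equation would fail without the premise.
assert nth(interval(3, 5), 9)(3) = 9
```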

Because this proof is about a recursive function whose first parameter is of type Nat, we proceed by induction on Nat.

  induction Nat
  case 0 {
    ?
  }
  case suc(count') suppose IH {
    ?
  }

In the case count = 0, Deduce tells us that we need to prove

all start:Nat, d:Nat, i:Nat.
    if i < 0 then nth(interval(0,start),d)(i) = i + start

So we can start the proof of this case with arbitrary and suppose, then use the definitions of interval and nth.

  case 0 {
    arbitrary start:Nat, d:Nat, i:Nat
    suppose i_l_z: i < 0
    definition {interval, nth}
    ?
  }

Deduce responds with

Goal:
    d = i + start
Givens:
    i_l_z: i < 0

Now we are in a strange situation. The goal seems rather difficult to prove because we don’t know anything about start and d. The givens (aka. assumptions) are also strange. How can the natural number i be less than 0? Of course it cannot. Thus, i < 0 implies false and then we can use the principle of explosion, which states that false implies anything, to prove that d = i + start.

  case 0 {
    arbitrary start:Nat, d:Nat, i:Nat
    suppose i_l_z: i < 0
    definition {interval, nth}
    conclude false  by definition {operator <, operator ≤} in i_l_z
  }

Next we turn to the case for count = suc(count'). Deduce tells us the formula for the induction hypothesis, so we paste that into the suppose IH. Looking at the goal formula, we begin the proof with arbitrary, suppose, and use the definitions of interval and nth.

  case suc(count') 
    suppose IH: all start:Nat, d:Nat, i:Nat. 
        if i < count' then nth(interval(count',start),d)(i) = i + start
  {
    arbitrary start:Nat, d:Nat, i:Nat
    suppose i_l_sc: i < suc(count')
    definition {interval, nth}
    ?
  }

Deduce responds with the following.

Goal:
    if i = 0 then start
    else nth(interval(count',suc(start)),d)(pred(i)) = i + start

What we’re seeing here is that the nth function uses an if-then-else with i = 0 as the condition. So to reason about this goal, we need to break our proof down into two cases, when i = 0 and i ≠ 0. One convenient way to do that in Deduce is with a switch.

  switch i {
    case 0 {
      ?
    }
    case suc(i') suppose i_sc: i = suc(i') {
      ?
    }
  }

Let us proceed with the case for i = 0. Deduce simplifies the goal and responds with

Goal:
    start = 0 + start

which follows directly from the definition of addition.

  case 0 {
    conclude start = 0 + start   by definition operator +.
  }

In the case for i = suc(i'), Deduce tells us that we need to prove

  nth(interval(count',suc(start)),d)(pred(suc(i'))) = suc(i') + start

This looks quite similar to the induction hypothesis instantiated with suc(start), d, and i':

  if i' < count' 
  then nth(interval(count',suc(start)),d)(i') = i' + suc(start)

One difference is pred(suc(i')) versus i', but they are equal by the definition of pred.

  case suc(i') suppose i_sc: i = suc(i') {
    definition pred
    ?
  }

Deduce responds with

Goal:
    nth(interval(count',suc(start)),d)(i') = suc(i') + start

So if we use the induction hypothesis, then we will just need to prove that i' + suc(start) = suc(i') + start, which is certainly true and will just require a little reasoning about addition. But to use the induction hypothesis, we need to prove that i' < count'. This follows from the givens i_l_sc: i < suc(count') and i_sc: i = suc(i') and the definitions of < and ≤.

  case suc(i') suppose i_sc: i = suc(i') {
    definition pred
    have i_l_cnt: i' < count'  by enable {operator <, operator ≤}
                                  rewrite i_sc in i_l_sc
    ?
  }

Now we can complete the proof of this case by linking together a few equations, starting with the induction hypothesis, then using the add_suc theorem from Nat.pf (which states that m + suc(n) = suc(m + n)), and finally using the definition of addition (which states that suc(n) + m = suc(n + m)).

  equations
    nth(interval(count',suc(start)),d)(i') 
        = i' + suc(start)        by apply IH[suc(start), d, i'] to i_l_cnt
    ... = suc(i' + start)        by add_suc[i'][start]
    ... = suc(i') + start        by definition operator +.
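
As a sanity check on the last two steps of this chain, here is a sketch that instantiates them with the sample values i' = 1 and start = 5, chosen only for illustration.

```deduce
// The add_suc step: m + suc(n) = suc(m + n) with m = 1 and n = 5.
assert 1 + suc(5) = suc(1 + 5)
// The definition of addition: suc(n) + m = suc(n + m) with n = 1 and m = 5.
assert suc(1 + 5) = suc(1) + 5
```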

Putting together all these pieces, we have the following complete proof of the interval_nth theorem. At this point we know that the interval function is 100% correct!

theorem interval_nth: all count:Nat. all start:Nat, d:Nat, i:Nat.
  if i < count
  then nth(interval(count, start), d)(i) = i + start
proof
  induction Nat
  case 0 {
    arbitrary start:Nat, d:Nat, i:Nat
    suppose i_l_z: i < 0
    definition {interval, nth}
    conclude false  by definition {operator <, operator ≤} in i_l_z
  }
  case suc(count') 
    suppose IH: all start:Nat, d:Nat, i:Nat. 
        if i < count' then nth(interval(count',start),d)(i) = i + start
  {
    arbitrary start:Nat, d:Nat, i:Nat
    suppose i_l_sc: i < suc(count')
    definition {interval, nth}
    switch i {
      case 0 {
        conclude start = 0 + start   by definition operator +.
      }
      case suc(i') suppose i_sc: i = suc(i') {
        definition pred
        have i_l_cnt: i' < count'  by enable {operator <, operator ≤}
                                      rewrite i_sc in i_l_sc
        equations
          nth(interval(count',suc(start)),d)(i') 
              = i' + suc(start)    by apply IH[suc(start), d, i'] to i_l_cnt
          ... = suc(i' + start)    by add_suc[i'][start]
          ... = suc(i') + start    by definition operator +.
      }
    }
  }
end

Exercise: Define Append

Create a function named append that satisfies the following specification.

Specification: append combines two lists into a single list. The elements of the output list must be ordered in a way that 1) the elements from the first input list come before the elements of the second list, and 2) the ordering of elements must preserve the internal ordering of each input.

function append<E>(List<E>, List<E>) -> List<E> {
  FILL IN HERE
}

Exercise: Test Append

Write assert statements to test the append function that you have defined. Formulate the assertions to closely match the specification above. Refer to the assertions that we wrote above to test interval to see an example of how to write the tests.

More Automation in Tests

An added benefit of formulating the assertions based on the specification is that it enables us to automate our testing. In the following code we append the list 1, 2, 3 with 4, 5 and then check the resulting list using only two assert statements. The first assert checks whether the front part of the result matches the first input list and the second assert checks whether the back part of the result matches the second input list. We make use of another function named all_elements that we describe next.

define list_45 : List<Nat> = node(4, node(5, empty))
define list_1_5 = append(list_123, list_45)
assert all_elements(interval(3, 0),
                    λi{ nth(list_1_5, 0)(i) = nth(list_123,0)(i) })
assert all_elements(interval(2, 0),
                    λi{ nth(list_1_5, 0)(3 + i) = nth(list_45,0)(i) })

The all_elements function takes a list and a function and checks whether applying the function to every element of the list always produces true.

function all_elements<T>(List<T>, fn (T) -> bool) -> bool {
  all_elements(empty, P) = true
  all_elements(node(x, xs'), P) = P(x) and all_elements(xs', P)
}
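
Here is a brief sketch of all_elements on its own, applied to list_123 with two predicates: one that holds for every element and one that fails on the second element.

```deduce
// Every element of 1, 2, 3 is at most 3.
assert all_elements(list_123, λx{ x ≤ 3 })

// Not every element equals 1: the predicate fails at 2, so the result is false.
assert not (all_elements(list_123, λx{ x = 1 }))
```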

Going a step further, we can adapt the tests to apply to longer lists by automating the creation of the input lists. Here we increase the combined size to 20 elements. We could go with longer lists, but Deduce currently has a slow interpreter, so the assertions would take a long time (e.g., a minute for 100 elements).

define num_elts = 20
define first_elts = 12
define second_elts = 8
define first_list = interval(first_elts,1)
define second_list = interval(second_elts, first_elts + 1)
define output_list = append(first_list, second_list)
assert all_elements(interval(first_elts, 0), 
          λi{ nth(output_list, 0)(i) = nth(first_list,0)(i) })
assert all_elements(interval(second_elts, 0),
          λi{ nth(output_list, 0)(first_elts + i) = nth(second_list,0)(i) })

Exercise: Prove that Append is Correct

Prove that append satisfies its specification on all possible inputs. First, we need to translate the specification into a Deduce formula. We can do this by generalizing the above assertions. Instead of using specific lists and specific indices, we use all formulas to talk about all possible lists and indices. Also, for convenience, we split up correctness into two theorems, one about the first input list xs and the other about the second input list ys. We recommend that your proofs use induction on List<T>.

theorem nth_append_front:
  all T:type. all xs:List<T>. all ys:List<T>, i:Nat, d:T.
  if i < length(xs)
  then nth(append(xs, ys), d)(i) = nth(xs, d)(i)
proof
  FILL IN HERE
end

theorem nth_append_back: 
  all T:type. all xs:List<T>. all ys:List<T>, i:Nat, d:T.
  nth(append(xs, ys), d)(length(xs) + i) = nth(ys, d)(i)
proof
  FILL IN HERE
end
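
It may seem odd that nth_append_front has a premise while nth_append_back does not. The following sketch suggests why, assuming you have completed the definition of append and that list_123 and list_45 are defined as above.

```deduce
// Front, out of range: position 3 of the output holds 4 (from list_45),
// but nth(list_123, 0)(3) is the default 0. So the premise i < length(xs) is needed.
assert nth(append(list_123, list_45), 0)(3) = 4
assert nth(list_123, 0)(3) = 0

// Back, in range: position length(list_123) + 1 of the output is position 1 of list_45.
assert nth(append(list_123, list_45), 0)(length(list_123) + 1) = nth(list_45, 0)(1)

// Back, out of range: both sides fall back to the default 0, so the equation still holds.
assert nth(append(list_123, list_45), 0)(length(list_123) + 2) = nth(list_45, 0)(2)
```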