Monday, June 17, 2024

Insertion Sort, Correctly

Insertion Sort

This is the third blog post in a series about developing correct implementations of basic data structures and algorithms using the Deduce language and proof checker.

In this blog post we study a simple but slow sorting algorithm, Insertion Sort. (We will study the faster Merge Sort in the next blog post.) Insertion Sort is, roughly speaking, how many people sort the cards in their hand when playing a card game. The basic idea is to take one card at a time and place it into the correct location amongst the cards that you’ve already sorted.

Specification: The insertion_sort(xs) function returns a list that contains the same elements as xs but the elements in the result are in sorted order.

Of course, to make this specification precise, we need to define "sorted". There are several ways to go with this formal definition. Here is one that works well for me. It requires each element in the list to be less-or-equal to all the elements that come after it.

function sorted(List<Nat>) -> bool {
  sorted(empty) = true
  sorted(node(x, xs)) =
    sorted(xs) and all_elements(xs, λy{ x ≤ y })
}
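
As a quick sanity check of this definition, here is a small sketch of two tests. The example lists are our own, and we assume that Nat.pf and List.pf (which provides all_elements) have been imported. The first list is sorted even though it contains a duplicate, because the definition uses ≤ and not <. The second list is not sorted, because 2 ≤ 1 is false.

// sorted: 1 ≤ 2 and 2 ≤ 2
assert sorted(node(1, node(2, node(2, empty))))
// not sorted: the first element 2 is greater than the later element 1
assert not sorted(node(2, node(1, empty)))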

We follow the write-test-prove approach to develop a correct implementation of insertion_sort. We then propose an exercise for the reader.

Write the insertion_sort function

Because insertion_sort operates on the recursive type List, we’ll try to implement insertion_sort as a recursive function.

function insertion_sort(List<Nat>) -> List<Nat> {
  insertion_sort(empty) = ?
  insertion_sort(node(x, xs')) = ?
}

In the case for the empty list, we need to return a list with the same contents, so we had better return empty.

function insertion_sort(List<Nat>) -> List<Nat> {
  insertion_sort(empty) = empty
  insertion_sort(node(x, xs')) = ?
}

In the case for node(x, xs'), we can make the recursive call insertion_sort(xs') to sort the rest of the list.

function insertion_sort(List<Nat>) -> List<Nat> {
  insertion_sort(empty) = empty
  insertion_sort(node(x, xs')) = ... insertion_sort(xs') ...
}

But what do we do with the element x? This is where we need to define an auxiliary function that inserts x into the appropriate location within the result of sorting the rest of the list. We’ll choose the name insert for this auxiliary function. Here is the completed code for insertion_sort.

function insertion_sort(List<Nat>) -> List<Nat> {
  insertion_sort(empty) = empty
  insertion_sort(node(x, xs')) = insert(insertion_sort(xs'), x)
}

Of course, we’ll follow the write-test-prove approach to develop the insert function. The first thing we need to do is write down the specification. The specification of insert will play an important role in the proof of correctness of insertion_sort, because we’ll use the correctness theorems about insert in the proof of the correctness theorems about insertion_sort. With this in mind, here’s a specification for insert.

Specification: The insert(xs, y) function takes a sorted list xs and value y as input and returns a sorted list that contains y and the elements of xs.

Next we write the code for insert. This function also has a List as input, so we define yet another recursive function.

function insert(List<Nat>,Nat) -> List<Nat> {
  insert(empty, y) = ?
  insert(node(x, xs), y) = ?
}

For the case empty we must return a list that contains y, so it must be node(y, empty).

function insert(List<Nat>,Nat) -> List<Nat> {
  insert(empty, y) = node(y, empty)
  insert(node(x, xs), y) = ?
}

In the case for node(x, xs'), we need to check whether y should come before x. So we test y ≤ x and if that’s the case, we return node(y, node(x, xs')). Otherwise, y belongs somewhere later in the sequence, so we make the recursive call and return node(x, insert(xs', y)).

function insert(List<Nat>,Nat) -> List<Nat> {
  insert(empty, y) = node(y, empty)
  insert(node(x, xs'), y) =
    if y ≤ x then
      node(y, node(x, xs'))
    else
      node(x, insert(xs', y))
}
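
To see how the two functions cooperate, the following sketch unfolds insertion_sort on a two-element list by hand and then records the result as a test. The input list is our own example, and we assume that assert can compare two lists with =.

// insertion_sort(node(2, node(1, empty)))
//   = insert(insertion_sort(node(1, empty)), 2)
//   = insert(insert(insertion_sort(empty), 1), 2)
//   = insert(insert(empty, 1), 2)
//   = insert(node(1, empty), 2)
//   = node(1, insert(empty, 2))       because 2 ≤ 1 is false
//   = node(1, node(2, empty))
assert insertion_sort(node(2, node(1, empty))) = node(1, node(2, empty))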

Test

This time we have two functions to test, insert and insertion_sort. We start with insert because if there are bugs in insert, then it will be confusing to find out about them when testing insertion_sort.

Test insert

Looking at the specification for insert, we need to check whether the resulting list is sorted and we need to check that the resulting list contains the elements from the input and the inserted element. We could use the search function that we developed in the previous blog post to check whether the elements from the input list are in the output list. However, doing that would ignore a subtle issue, which is that there can be one or more occurrences of the same element in the input list, and the output list should have the same number of occurrences. To take this into account, we need a new function to count the number of occurrences of an element.

function count<T>(List<T>) -> fn T->Nat {
  count(empty) = λy{ 0 }
  count(node(x, xs')) = λy{
    if y = x then 
      suc(count(xs')(y))
    else
      count(xs')(y) 
  }
}
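
Before relying on count in other tests, it may be worth checking count itself. The following is a minimal sketch with an example list of our own (list_313 is a name we made up for illustration). It covers an element that occurs twice, an element that occurs once, and an element that does not occur at all.

define list_313 = node(3, node(1, node(3, empty)))
// 3 occurs twice, 1 occurs once, 7 does not occur
assert count(list_313)(3) = 2
assert count(list_313)(1) = 1
assert count(list_313)(7) = 0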

Here’s a test that checks whether insert produces a sorted list with the correct count for every element on the input list as well as the inserted element.

define list_1234 = node(1, node(2, node(3, node(4, empty))))
define list_12334 = insert(list_1234, 3)
assert sorted(list_12334)
assert all_elements(node(3, list_1234), λx{
  if x = 3 then
    count(list_12334)(x) = suc(count(list_1234)(x))
  else
    count(list_12334)(x) = count(list_1234)(x)
  })

It’s a good idea to test corner cases, that is, inputs that trigger different code paths through the insert function. As there is a case for the empty list in the code, that’s a good test case to consider.

define list_3 = insert(empty, 3)
assert sorted(list_3)
assert all_elements(node(3, empty), λx{
  if x = 3 then
    count(list_3)(x) = suc(count(empty : List<Nat>)(x))
  else
    count(list_3)(x) = count(empty : List<Nat>)(x)
  })
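
Two more corner cases seem worth covering: inserting an element that is smaller than everything in the list, which takes the then branch immediately, and inserting an element that is larger than everything, which takes the else branch at every node until reaching empty. The following sketch reuses list_1234 from above; the names list_01234 and list_12345 are our own.

// 0 ≤ 1, so 0 is placed at the front without any recursive call
define list_01234 = insert(list_1234, 0)
assert sorted(list_01234)
assert count(list_01234)(0) = 1
assert all_elements(list_1234, λx{count(list_01234)(x) = count(list_1234)(x)})

// 5 is larger than every element, so it ends up at the back
define list_12345 = insert(list_1234, 5)
assert sorted(list_12345)
assert count(list_12345)(5) = 1
assert all_elements(list_1234, λx{count(list_12345)(x) = count(list_1234)(x)})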

Ideally we would also test with hundreds of randomly-generated lists. Adding support for random number generation is high on the TODO list for Deduce.

Test insertion_sort

If we refer back to the specification of insertion_sort, we need to check that the output list is sorted and that it contains the same elements as the input list.

define list_8373 = node(8, node(3, node(7, node(3, empty))))
define list_3378 = insertion_sort(list_8373)
assert sorted( list_3378 )
assert all_elements(list_8373, λx{count(list_3378)(x) = count(list_8373)(x)})
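
In the same spirit as the corner cases for insert, here is a sketch of two further tests for insertion_sort: the empty list, and a list in reverse order, for which every call to insert has to walk to the end of its input. The names list_4321 and list_sorted_4321 are our own.

// the empty list exercises the first clause of insertion_sort
assert sorted(insertion_sort(empty))

// a reverse-sorted input
define list_4321 = node(4, node(3, node(2, node(1, empty))))
define list_sorted_4321 = insertion_sort(list_4321)
assert sorted(list_sorted_4321)
assert all_elements(list_4321, λx{count(list_sorted_4321)(x) = count(list_4321)(x)})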

Prove

The next step in the process is to prove the correctness of insert and insertion_sort with respect to their specification.

Prove correctness of insert

Our first task is to prove that insert(xs, y) produces a list that contains y and the elements of xs. In our tests, we used the count function to accomplish this. Note that count returns a function of type fn T->Nat, which is the same thing as a multiset. The file MultiSet.pf defines the MultiSet<T> type and operations on multisets, such as m_one(x) for creating a multiset that contains only x and the operator A ⨄ B for combining two multisets. The file List.pf defines a function mset_of that converts a list into a multiset.

function mset_of<T>(List<T>) -> MultiSet<T> {
  mset_of(empty) = m_fun(λ{0})
  mset_of(node(x, xs)) = m_one(x) ⨄ mset_of(xs)
}

So we express the requirements on the contents of insert(xs, y) as follows: converting insert(xs, y) into a multiset is the same as converting xs into a multiset and then adding y. The proof is relatively straightforward, making use of several theorems about multisets from MultiSet.pf.

theorem insert_contents: all xs:List<Nat>. all y:Nat.
  mset_of(insert(xs,y)) = m_one(y) ⨄ mset_of(xs)
proof
  induction List<Nat>
  case empty {
    arbitrary y:Nat
    conclude mset_of(insert(empty,y)) = m_one(y) ⨄ mset_of(empty)
        by definition {insert, mset_of, mset_of}
  }
  case node(x, xs') suppose IH {
    arbitrary y:Nat
    switch y ≤ x for insert {
      case true suppose yx_true {
        conclude mset_of(node(y,node(x,xs'))) = m_one(y) ⨄ mset_of(node(x,xs'))
            by definition {mset_of, mset_of}
      }
      case false suppose yx_false {
        equations
              mset_of(node(x,insert(xs',y))) 
            = m_one(x) ⨄ mset_of(insert(xs',y))
              by definition mset_of
        ... = m_one(x) ⨄ (m_one(y) ⨄ mset_of(xs'))
              by rewrite IH[y]
        ... = (m_one(x) ⨄ m_one(y)) ⨄ mset_of(xs')
              by rewrite m_sum_assoc<Nat>[m_one(x),m_one(y),mset_of(xs')]
        ... = (m_one(y) ⨄ m_one(x)) ⨄ mset_of(xs')
              by rewrite m_sum_commutes<Nat>[ m_one(x), m_one(y)]
        ... = m_one(y) ⨄ (m_one(x) ⨄ mset_of(xs'))
              by rewrite m_sum_assoc<Nat>[m_one(y),m_one(x),mset_of(xs')]
        ... = m_one(y) ⨄ mset_of(node(x,xs'))
              by definition mset_of
      }
    }
  }
end
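
To illustrate how insert_contents is meant to be used later on, here is a sketch that instantiates it at a concrete list. The theorem name insert_contents_example and the choice of list are our own; the instantiation syntax follows the uses of insert_contents[...][...] elsewhere in this post.

theorem insert_contents_example: all y:Nat.
  mset_of(insert(node(1, empty), y)) = m_one(y) ⨄ mset_of(node(1, empty))
proof
  arbitrary y:Nat
  // instantiate xs with node(1, empty) and then y with y
  insert_contents[node(1, empty)][y]
end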

Our second task is to prove that insert produces a sorted list, assuming the input list is sorted.

theorem insert_sorted: all xs:List<Nat>. all y:Nat.
  if sorted(xs) then sorted(insert(xs, y))
proof
  ?
end

Because insert is a recursive function, we proceed by induction on its first argument xs.

  induction List<Nat>

The case for xs = empty is a straightforward use of definitions.

    // <<insert_sorted_case_empty>> =
    arbitrary y:Nat
    suppose _
    conclude sorted(insert(empty,y))
        by definition {insert, sorted, sorted, all_elements}

Here’s the beginning of the case for xs = node(x, xs').

  case node(x, xs') suppose IH {
    arbitrary y:Nat
    suppose s_xxs: sorted(node(x,xs'))
    suffices sorted(insert(node(x,xs'),y))  by .
    ?
  }

In the goal we see an opportunity to use the definition of insert. However, insert branches on whether y ≤ x, so we use a switch-for statement to do the same in our proof.

  switch y ≤ x for insert {
    case true suppose yx_true {
      ?
    }
    case false suppose yx_false {
      ?
    }
  }

In the case when y ≤ x is true, we apply the relevant definitions to arrive at the following four subgoals.

    // <<insert_sorted_case_node_less_defs>> =
    suffices sorted(xs') 
         and all_elements(xs', λb{x ≤ b}) 
         and y ≤ x
         and all_elements(xs', λb{y ≤ b})
             with definition {sorted, sorted, all_elements}

The first two of these follow from the premise sorted(node(x,xs')).

  // <<insert_sorted_case_node_s_xs__x_le_xs>> =
  have s_xs: sorted(xs') by definition sorted in s_xxs
  have x_le_xs': all_elements(xs',λb{(x ≤ b)}) by definition sorted in s_xxs

The third is true in the current case.

  // <<insert_sorted_y_le_x>> =
  have y_le_x: y ≤ x by rewrite yx_true

The fourth, which states that y is less-or-equal to all the elements in xs', follows by transitivity from y ≤ x and the fact that x is less-or-equal to all the elements in xs' (x_le_xs'), using the theorem all_elements_implies (in List.pf):

theorem all_elements_implies: 
  all T:type. all xs:List<T>. all P: fn T -> bool, Q: fn T -> bool.
  if all_elements(xs,P) and (all z:T. if P(z) then Q(z)) 
  then all_elements(xs,Q)

To satisfy the second premise of all_elements_implies, we use y ≤ x to prove that if x is less-or-equal to some element z, then so is y.

  // <<insert_sorted_x_le_implies_y_le>> =
  have x_le_implies_y_le: all z:Nat. (if x ≤ z then y ≤ z)
    by arbitrary z:Nat  suppose x_le_z: x ≤ z
       conclude y ≤ z by apply less_equal_trans[y][x,z] to y_le_x , x_le_z

Now we apply all_elements_implies to prove all_elements(xs',λb{(y ≤ b)}).

  // <<insert_sorted_y_le_xs>> =
  have y_le_xs': all_elements(xs',λb{(y ≤ b)})
    by apply all_elements_implies<Nat>[xs']
             [λb{(x ≤ b)} : fn Nat->bool, λb{(y ≤ b)} : fn Nat->bool]
       to x_le_xs', x_le_implies_y_le

and then conclude this case for when y ≤ x.

  // <<insert_sorted_case_node_le_conclusion>> =
  s_xs, x_le_xs', y_le_x, y_le_xs'

Next we turn our attention to the case for when y ≤ x is false. After applying the definition of insert, Deduce tells us that we need to prove the following.

    // <<insert_sorted_case_node_g_def>> =
    suffices sorted(insert(xs',y)) 
         and all_elements(insert(xs',y),λb{x ≤ b})
             with definition sorted

The first follows from the induction hypothesis. (Though we need to move the proof of s_xs out of the y ≤ x case so that we can use it here.)

  // <<insert_sorted_s_xs_y>> =
  have s_xs'_y: sorted(insert(xs',y)) by apply IH[y] to s_xs

The second requires more thinking. We know that x ≤ y in this case by the following reasoning.

  // <<insert_sorted_x_le_y>> =
  have x_le_y: x ≤ y
      by have not_yx: not (y ≤ x)  by suppose yx rewrite yx_false in yx
         apply not_less_equal_less_equal to not_yx

We have already proved that x is less-or-equal to all the elements in xs'. So we know that x is less-or-equal to all the elements in node(y, xs') by the definition of all_elements.

  // <<insert_sorted_x_le_y_xs>> =
  have x_le_y_xs': all_elements(node(y, xs'),λb{(x ≤ b)})
      by suffices x ≤ y and all_elements(xs', λb{x ≤ b}) 
              with definition all_elements
         x_le_y, x_le_xs'

However, what we need to prove is that x is less-or-equal to all the elements in insert(xs', y). But the all_elements function shouldn’t care about the ordering of the elements in the list, and indeed there is the following theorem in List.pf:

theorem all_elements_set_of:
  all T:type, xs:List<T>, ys:List<T>, P:fn T -> bool.
  if set_of(xs) = set_of(ys)
  then all_elements(xs, P) = all_elements(ys, P)

So we need to show that set_of(insert(xs',y)) = set_of(node(y,xs')). Thankfully, we already showed that this is true for mset_of in the insert_contents theorem, and multiset equality implies set equality (also from List.pf):

theorem mset_equal_implies_set_equal: 
  all T:type, xs:List<T>, ys:List<T>.
  if mset_of(xs) = mset_of(ys)
  then set_of(xs) = set_of(ys)

So we use these three theorems to prove the following.

theorem all_elements_insert_node:
  all xs:List<Nat>, x:Nat, P:fn Nat->bool.
  all_elements(insert(xs,x), P) = all_elements(node(x,xs), P)
proof
  arbitrary xs:List<Nat>, x:Nat, P:fn Nat->bool
  have m_xs_x: mset_of(insert(xs, x)) = mset_of(node(x, xs))
      by suffices mset_of(insert(xs, x)) = m_one(x) ⨄ mset_of(xs)
             with definition mset_of
         insert_contents[xs][x]
  have ixsx_xxs: set_of(insert(xs, x)) = set_of(node(x, xs))
     by apply mset_equal_implies_set_equal<Nat>[insert(xs, x), node(x, xs)] 
        to m_xs_x
  apply all_elements_set_of<Nat>[ insert(xs,x), node(x, xs), P]
  to ixsx_xxs
end
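
As a sanity check of all_elements_insert_node, the following sketch tests both sides of the equation on a concrete list, once with a predicate that holds for every element and once with a predicate that fails. It reuses list_1234 from the tests of insert; the predicates are our own choice.

// every element of 1,2,3,3,4 and of 3,1,2,3,4 is at least 1
assert all_elements(insert(list_1234, 3), λb{1 ≤ b})
assert all_elements(node(3, list_1234), λb{1 ≤ b})
// the element 1 falsifies 2 ≤ b on both sides
assert not all_elements(insert(list_1234, 3), λb{2 ≤ b})
assert not all_elements(node(3, list_1234), λb{2 ≤ b})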

We apply this theorem to prove that x is less-or-equal to all the elements in insert(xs',y).

  // <<insert_sorted_x_le_xs_y>> =
  have x_le_xs'_y: all_elements(insert(xs',y), λb{x ≤ b})
      by _rewrite all_elements_insert_node[xs',y,λb{x≤b}:fn Nat->bool]
         x_le_y_xs'

Now we have the two facts we need to conclude this final case of the proof of insert_sorted.

  // <<insert_sorted_case_node_g_conclusion>> =
  conclude sorted(insert(xs',y)) and
           all_elements(insert(xs',y),λb{x ≤ b})
      by s_xs'_y, x_le_xs'_y

Here is the complete proof of insert_sorted.

theorem insert_sorted: all xs:List<Nat>. all y:Nat.
  if sorted(xs) then sorted(insert(xs, y))
proof
  induction List<Nat>
  case empty {
    <<insert_sorted_case_empty>>
  }
  case node(x, xs') suppose IH {
    arbitrary y:Nat
    suppose s_xxs: sorted(node(x,xs'))
    suffices sorted(insert(node(x,xs'),y))  by .
    <<insert_sorted_case_node_s_xs__x_le_xs>>
    switch y ≤ x for insert {
      case true suppose yx_true {
        suffices sorted(node(y,node(x,xs')))  by .
        <<insert_sorted_case_node_less_defs>>
        <<insert_sorted_y_le_x>>
        <<insert_sorted_x_le_implies_y_le>>
        <<insert_sorted_y_le_xs>>
        <<insert_sorted_case_node_le_conclusion>>
      }
      case false suppose yx_false {
        <<insert_sorted_case_node_g_def>>
        <<insert_sorted_s_xs_y>>
        <<insert_sorted_x_le_y>>
        <<insert_sorted_x_le_y_xs>>
        <<insert_sorted_x_le_xs_y>>
        <<insert_sorted_case_node_g_conclusion>>
      }
    }
  }
end

Prove the correctness of insertion_sort

Referring back to the specification of insertion_sort(xs), we need to prove that (1) it outputs a list that contains the same elements as xs, and (2) the output is sorted.

As we did for insert, we use multisets and mset_of to express the requirement on the contents of the output of insertion_sort.

theorem insertion_sort_contents: all xs:List<Nat>.
  mset_of(insertion_sort(xs)) = mset_of(xs)

The insertion_sort(xs) function is recursive, so we proceed by induction on xs. In the case for xs = empty, we conclude the following using the definitions of insertion_sort and mset_of.

    // <<insertion_sort_contents_case_empty>> =
    conclude mset_of(insertion_sort(empty)) = mset_of(empty)
      by definition {insertion_sort, mset_of}

In the case for xs = node(x, xs'), after applying the definitions of insertion_sort and mset_of, it suffices to show the following.

    // <<insertion_sort_contents_case_node_defs>> =
    suffices mset_of(insert(insertion_sort(xs'),x)) 
           = m_one(x) ⨄ mset_of(xs')
        with definition {insertion_sort, mset_of}

The goal follows from the insert_contents theorem and the induction hypothesis as follows.

  // <<insertion_sort_contents_case_node_equations>> =
  equations
          mset_of(insert(insertion_sort(xs'),x)) 
        = m_one(x) ⨄ mset_of(insertion_sort(xs'))
          by insert_contents[insertion_sort(xs')][x]
    ... = m_one(x) ⨄ mset_of(xs')
          by rewrite IH

Here is the complete proof of insertion_sort_contents.

theorem insertion_sort_contents: all xs:List<Nat>.
  mset_of(insertion_sort(xs)) = mset_of(xs)
proof
  induction List<Nat>
  case empty {
    <<insertion_sort_contents_case_empty>>
  }
  case node(x, xs') suppose IH {
    <<insertion_sort_contents_case_node_defs>>
    <<insertion_sort_contents_case_node_equations>>
  }
end

Finally, we prove that insertion_sort(xs) produces a sorted list. Of course the proof is by induction on xs. The case for empty follows from the relevant definitions. The case for node(x, xs') follows from the insert_sorted theorem and the induction hypothesis.

theorem insertion_sort_sorted: all xs:List<Nat>. 
  sorted( insertion_sort(xs) )
proof
  induction List<Nat>
  case empty {
    conclude sorted(insertion_sort(empty))
        by definition {insertion_sort, sorted}
  }
  case node(x, xs') suppose IH: sorted( insertion_sort(xs') ) {
    suffices sorted(insert(insertion_sort(xs'),x))
        with definition {insertion_sort, sorted}
    apply insert_sorted[insertion_sort(xs')][x] to IH
  }
end
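
The two theorems together cover both halves of the specification. If a single statement is preferred, they could be packaged as in the following sketch. The name insertion_sort_correct is our own, and we assume that a conjunction can be proved by giving the proofs of its two parts separated by a comma, as in the proof of insert_sorted above.

theorem insertion_sort_correct: all xs:List<Nat>.
  sorted(insertion_sort(xs))
  and (mset_of(insertion_sort(xs)) = mset_of(xs))
proof
  arbitrary xs:List<Nat>
  // first conjunct: the output is sorted; second conjunct: same contents
  insertion_sort_sorted[xs], insertion_sort_contents[xs]
end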

Exercise: tail-recursive variant of insertion_sort

The insertion_sort function uses more computer memory than necessary because it uses one frame on the procedure call stack for every element in the input list. This can be avoided if we instead implement Insertion Sort using a tail-recursive function, that is, a function that returns immediately after the recursive call. For this exercise, formulate a tail-recursive version of insertion_sort, test it, and prove that it is correct.

As a hint, define an auxiliary function isort(xs,ys) that takes a list xs and an already sorted list ys and returns a sorted list that includes the contents of both xs and ys.

function isort(List<Nat>, List<Nat>) -> List<Nat> {
  FILL IN HERE
}

Once you have defined isort, you can implement Insertion Sort as follows.

define insertion_sort2 : fn List<Nat> -> List<Nat>
    = λxs{ isort(xs, empty) }
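
For the testing step of the exercise, one possible starting point is to repeat the test of insertion_sort with insertion_sort2 in its place. This sketch reuses list_8373 from the earlier test; the name list_3378_v2 is our own, and the asserts can only pass once isort has been filled in.

define list_3378_v2 = insertion_sort2(list_8373)
assert sorted(list_3378_v2)
assert all_elements(list_8373, λx{count(list_3378_v2)(x) = count(list_8373)(x)})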

To prove the correctness of insertion_sort2, prove that the result is sorted

theorem insertion_sort2_sorted: all xs:List<Nat>. 
  sorted( insertion_sort2(xs) )
proof
  ?
end

and prove that the output includes all of the same elements as in the input (the correct number of times).

theorem insertion_sort2_contents: all xs:List<Nat>. 
  mset_of( insertion_sort2(xs) ) = mset_of(xs)
proof
  ?
end
