This past semester, for my programming languages class, we had to implement heap sort in Standard ML. This is my implementation, which I release under the GNU General Public License v3. I split the code into two files: shared.ml, which contains functions that might be useful for other sorting algorithms, and heap_sort.ml, which contains the functions and code specific to heap sort.

Note for students: my professor asked me to state that, if you are taking his programming languages class (CS 655 at the University of Kentucky) and try to use this code, (1) you’ll get in trouble for using someone else’s work, and (2) you have to include the GPL and my copyright notice, which would be a really big hint that it’s not entirely your work. ;)

shared.ml

(* Copyright 2009 Sarah Vessels
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>. *)

(* Three-way comparison for integers: the result is ~1 when the first argument
 * is smaller, 0 when both are equal, and 1 when the first is larger.  The
 * other *_compare functions in this file follow the same convention. *)
fun int_compare (a : int, b) =
  case Int.compare (a, b) of
      LESS => ~1
    | EQUAL => 0
    | GREATER => 1;

(* Comparison function for reals; see int_compare.  Equality is tested with
 * Real.== because real is not an equality type in Standard ML '97, so the
 * polymorphic = cannot portably be applied to reals.  If either argument is
 * NaN, neither test succeeds and the result is 1. *)
fun real_compare (a : real, b) =
  if a < b then ~1
  else if Real.== (a, b) then 0
  else 1;

(* Three-way comparison for characters (~1 / 0 / 1, as in int_compare),
 * ordered by character code via the Basis's Char.compare. *)
fun char_compare (a : char, b) =
  case Char.compare (a, b) of
      LESS => ~1
    | EQUAL => 0
    | GREATER => 1;

(* Three-way comparison for booleans (~1 / 0 / 1, as in int_compare).  The
 * ordering is an arbitrary choice: true is treated as less than false. *)
fun bool_compare (a : bool, b) =
  if a = b then 0
  else if a then ~1   (* a is true, so b must be false *)
  else 1;             (* a is false, so b must be true *)

(* Three-way comparison for strings (~1 / 0 / 1, as in int_compare), using
 * the lexicographic order of the Basis's String.compare. *)
fun str_compare (a : string, b) =
  case String.compare (a, b) of
      LESS => ~1
    | EQUAL => 0
    | GREATER => 1;

(* leq (cmp, a, b) is true exactly when cmp reports that a is not greater
 * than b, i.e. when cmp (a, b) is zero or negative. *)
fun leq (cmp, a, b) = cmp (a, b) <= 0;

(* geq (cmp, a, b) is true exactly when cmp reports that a is not less than
 * b, i.e. when cmp (a, b) is zero or positive. *)
fun geq (cmp, a, b) = cmp (a, b) >= 0;

(* Using the given comparison function, this returns true if a < b, false
 * otherwise.  Any negative result from the comparison function counts as
 * "less than", so comparators such as fn (x, y) => x - y work as well as
 * ones returning exactly ~1 / 0 / 1 (leq and geq likewise look only at the
 * sign of the result). *)
fun lt (cmp, a, b) = cmp (a, b) < 0;

(* Using the given comparison function, this returns true if a > b, false
 * otherwise.  Any positive result from the comparison function counts as
 * "greater than", so comparators such as fn (x, y) => x - y work as well as
 * ones returning exactly ~1 / 0 / 1 (leq and geq likewise look only at the
 * sign of the result). *)
fun gt (cmp, a, b) = cmp (a, b) > 0;

(* eq (cmp, a, b) is true exactly when cmp reports a and b as equal, i.e.
 * when cmp (a, b) is zero. *)
fun eq (cmp, a, b) = cmp (a, b) = 0;

(* Returns the elements of the given list from index i through index j,
 * inclusive.  The two indices may be given in either order.  If either index
 * lies outside 0 .. length arr - 1, the empty list is returned. *)
fun sub_list (arr, i, j) =
  let
    val n = length arr
    fun in_bounds k = k >= 0 andalso k < n

    (* Normalise the order so that lo <= hi. *)
    val lo = Int.min (i, j)
    val hi = Int.max (i, j)
  in
    if in_bounds i andalso in_bounds j then
      (* Skip the first lo elements, then keep the hi - lo + 1 that follow. *)
      List.take (List.drop (arr, lo), hi - lo + 1)
    else
      []
  end;

(* Given a list arr and indices i and j, this returns a copy of arr with the
 * values found at indices i and j exchanged.  The indices may be given in
 * either order; when i = j the original list is returned.
 *
 * Both indices must be valid (0 <= k <= length arr - 1); otherwise the
 * exception Subscript is raised, the same exception List.nth raises for a bad
 * index.  The check is done up front, before any element is looked up. *)
fun swap (arr, i, j) =
  let
    val last_index = (length arr) - 1
    fun valid k = k >= 0 andalso k <= last_index
  in
    if not (valid i andalso valid j) then
      raise Subscript
    else if i = j then
      (* Same index twice: nothing to exchange. *)
      arr
    else
      let
        val i_el = List.nth (arr, i)
        val j_el = List.nth (arr, j)

        (* Rebuild the list in a single pass, substituting the two values at
         * their exchanged positions and copying every other element. *)
        fun rebuild (_, []) = []
          | rebuild (k, x :: rest) =
              (if k = i then j_el else if k = j then i_el else x)
              :: rebuild (k + 1, rest)
      in
        rebuild (0, arr)
      end
  end;

(* set_value (arr, index, value) returns arr with the element at position
 * index replaced by value.  If index is outside 0 .. length arr - 1, arr is
 * returned unmodified. *)
fun set_value (arr, index, value) =
  if index < 0 orelse index >= length arr then
    arr
  else
    (* Everything before index, then the new value, then everything after. *)
    List.take (arr, index) @ (value :: List.drop (arr, index + 1));

heap_sort.ml

(* Copyright 2009 Sarah Vessels
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>. *)
use "shared.ml";

(* The heap is stored in the list itself using the 0-based positions that
 * List.nth expects: the root is at index 0, and the children of the node at
 * index i are at indices 2i+1 and 2i+2.  (The 1-based formulas 2i and 2i+1
 * would make index 0 its own left child.) *)

(* Given i, this will return the index of i's left child. *)
val left_index = fn i => (2 * i) + 1;

(* Given a list and an index i, this will return the value at i's left child.
 * Raises Subscript if that child lies past the end of the list. *)
fun left_child (arr, i) = List.nth (arr, left_index i);

(* Given i, this will return the index of i's right child. *)
val right_index = fn i => (2 * i) + 2;

(* Given a list and an index i, this will return the value at i's right
 * child.  Raises Subscript if that child lies past the end of the list. *)
fun right_child (arr, i) = List.nth (arr, right_index i);

(* Given a comparison function, a list arr, an index i and heap_size (the
 * last list index that belongs to the heap), this returns the entire list
 * with the value at i sifted down, so that the heap rooted at i is a max heap
 * (assuming the subtrees below i already are).  Children are located with
 * left_index / right_index; a child whose index exceeds heap_size is ignored.
 * At each step the root is swapped with its larger child when that child is
 * greater than the root, and the sift continues from the child's index. *)
fun max_heapify_w_size (compare, arr, i, heap_size) =
  let
    val root = List.nth (arr, i)
    val l = left_index i
    val r = right_index i
    val has_left = l <= heap_size
    val has_right = r <= heap_size

    (* The child, if any, whose value must move up into position i.  The
     * left child wins ties against the right child. *)
    val target =
      if has_left
         andalso gt (compare, left_child (arr, i), root)
         andalso (not has_right
                  orelse leq (compare, right_child (arr, i), left_child (arr, i)))
      then SOME l
      else if has_right andalso gt (compare, right_child (arr, i), root)
      then SOME r
      else NONE
  in
    case target of
        NONE => arr   (* root already dominates its children *)
      | SOME c => max_heapify_w_size (compare, swap (arr, i, c), c, heap_size)
  end;

(* max_heapify (compare, arr, i) sifts the value at index i down, treating
 * the whole list as the heap (its last heap index is length arr - 1).  It is
 * max_heapify_w_size with that heap size filled in. *)
fun max_heapify (compare, arr, i) =
  let
    val last_index = length arr - 1
  in
    max_heapify_w_size (compare, arr, i, last_index)
  end;

(* build_max_heap_rec (compare, arr, i) max-heapifies the list at roots
 * i, i-1, ..., 0 (in that order) and returns the result.  Working from the
 * bottom up means each root's subtrees have already been heapified when it
 * is sifted down.  A negative i means no roots remain, so arr is returned
 * unchanged; this also covers the starting index ~1 that an empty list
 * produces ((0 - 1) div 2), which previously made List.nth raise Subscript. *)
fun build_max_heap_rec (compare, arr, i) =
  if i < 0 then
    arr
  else
    build_max_heap_rec (compare, max_heapify (compare, arr, i), i - 1);

(* build_max_heap (compare, arr) rearranges the whole list into a max heap
 * under compare.  It heapifies every root from index (length arr - 1) div 2
 * down to 0 by way of build_max_heap_rec. *)
fun build_max_heap (compare, arr) =
  build_max_heap_rec (compare, arr, (length arr - 1) div 2);

(* heap_sort_rec (compare, arr, i, heap_size) finishes a heap sort.  arr is
 * expected to hold a max heap in indices 0 .. heap_size, with any elements
 * after heap_size already in their final positions; heap_sort calls it with
 * i = heap_size.  Each step swaps the maximum (index 0) into position i,
 * shrinks the heap by one and sifts the new root down.
 *
 * When i < 1 at most one element is left in the heap, which is trivially
 * sorted, so arr is returned as-is.  Stopping at any i < 1, rather than only
 * at exactly 1, lets a one-element heap terminate instead of running i below
 * zero and raising Subscript. *)
fun heap_sort_rec (compare, arr, i, heap_size) =
  if i < 1 then
    arr
  else
    let
      (* Position i now holds its final value, so it leaves the heap. *)
      val dec_heap_size = heap_size - 1
      val next = max_heapify_w_size (compare, swap (arr, 0, i), 0, dec_heap_size)
    in
      if i = 1 then
        next
      else
        heap_sort_rec (compare, next, i - 1, dec_heap_size)
    end;

(* Sorts the list arr into ascending order with respect to compare (a
 * three-way comparison function returning a negative, zero or positive int,
 * such as the *_compare functions in shared.ml) and returns the sorted list.
 * The list is first turned into a max heap with build_max_heap; then
 * heap_sort_rec repeatedly moves the maximum to the end of the shrinking heap.
 * Lists with fewer than two elements are already sorted and are returned
 * directly.  Previously the empty list and one-element lists raised Subscript. *)
fun heap_sort (compare, arr) =
  let
    val heap_size = (length arr) - 1
  in
    if heap_size < 1 then
      arr
    else
      heap_sort_rec (compare, build_max_heap (compare, arr), heap_size, heap_size)
  end;

(* Example unsorted lists and the results of heap-sorting them.  The sorted
 * lists are not printed explicitly because mosml already shows the value of
 * each val binding when this file is run; only section headings are printed. *)
print "Sorting integers:\n";
val unsorted_list = [16, 4, 10, 14, 7, 9, 3, 2, 8, 1];
val sorted = heap_sort (int_compare, unsorted_list);

val unsorted_list = [5, 4, 3, 2, 1, 0, ~1, ~2, ~3, ~4, ~5];
val sorted = heap_sort (int_compare, unsorted_list);

print "Sorting strings:\n";
val unsorted_list = ["my", "mother", "told", "me", "to", "pick", "the", "very"];
val sorted = heap_sort (str_compare, unsorted_list);

val unsorted_list = ["best", "one", "and", "you", "are", "not", "not", "it"];
val sorted = heap_sort (str_compare, unsorted_list);

print "Sorting Booleans:\n";
val unsorted_list = [false, true];
val sorted = heap_sort (bool_compare, unsorted_list);

print "Sorting characters:\n";
val unsorted_list = [#"z", #"k", #"b", #"c", #"f"];
val sorted = heap_sort (char_compare, unsorted_list);

print "Sorting reals:\n";
val unsorted_list = [1.3, ~0.225, 3.1415, 87.5];
val sorted = heap_sort (real_compare, unsorted_list);