Merge Sort

You just saw two different sorting algorithms: insertion sort and selection sort. We then implemented insertion sort. Recall that insertion sort is $O(n^2)$ in the worst case, but $O(n)$ in the best case.

The best-case performance is great: it’s no worse than checking whether the list is already sorted. Since many real-life datasets are “mostly” sorted to begin with, this is a great property for a sorting algorithm to have.

The worst-case performance is concerning, though. We saw how quickly $O(n^2)$ can blow up when we implemented distinct. A worst-case $O(n^2)$ runtime doesn’t bode well for sorting large datasets that are random, or that are mostly in reverse order. And anyway, haven’t we just seen a data structure that can find “the right place” to put a new element in $O(\log_2(n))$ time, rather than the $O(n)$ scan that insertion sort performs in its inner loop? Surely we can do better…

Python’s Built-in Sorting

Python provides two different built-in sorts:

sorted(lst)  # built-in function: returns a new sorted list, leaves lst unchanged
lst.sort()   # list method: sorts lst in place and returns None

Both run the same algorithm; the difference is that sorted returns a new list, while lst.sort() modifies the existing list in place. The algorithm is called Timsort, after Tim Peters, who invented it for Python.
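A minimal illustration of that difference, using only the two built-ins named above:

```python
lst = [3, 1, 2]

# sorted() builds and returns a brand-new sorted list
new_lst = sorted(lst)
print(new_lst)   # [1, 2, 3]
print(lst)       # [3, 1, 2]  (unchanged)

# list.sort() rearranges lst itself and returns None
returned = lst.sort()
print(lst)       # [1, 2, 3]
print(returned)  # None
```

A common slip is writing lst = lst.sort(), which replaces the list with None.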

For some reason, I feel a fondness for Timsort.

A (seemingly?) Different Problem

Anyway, here’s a question. Suppose we have 2 sorted lists. How could we combine them to produce a sorted list with the same contents? Here’s an example:

assert merge([2,4,6], [3,5,7]) == [2,3,4,5,6,7]

How could we implement merge efficiently? We could start by just concatenating the lists together and then sorting the result:

def merge(list1, list2):
    return sorted(list1+list2)

This will work, but what’s the runtime? The concatenation (+) is linear in the total length of the two lists. The cost of sorting depends on which sorting algorithm runs underneath: if it were insertion sort, merge would be quadratic in the worst case.
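To make the quadratic claim concrete, here is a sketch that counts comparisons when the concatenate-then-sort merge uses insertion sort. The names insertion_sort_count and naive_merge_count are hypothetical helpers for this illustration, and the insertion sort is one common swap-based variant, not necessarily the exact version implemented earlier.

```python
def insertion_sort_count(items):
    """Insertion-sort a copy of items; return (sorted_copy, comparisons)."""
    items = list(items)
    comparisons = 0
    for i in range(1, len(items)):
        j = i
        # Swap items[j] leftward while its left neighbour is larger.
        while j > 0:
            comparisons += 1
            if items[j - 1] > items[j]:
                items[j - 1], items[j] = items[j], items[j - 1]
                j -= 1
            else:
                break
    return items, comparisons


def naive_merge_count(list1, list2):
    """Concatenate-then-sort merge, using insertion sort as the sort."""
    return insertion_sort_count(list1 + list2)


k = 100
low, high = list(range(k)), list(range(k, 2 * k))

# Best case: everything in list1 is smaller, so list1 + list2 is already sorted.
_, best = naive_merge_count(low, high)
# Worst case: everything in list1 is larger, so each element of list2
# must be swapped past all k elements of list1.
_, worst = naive_merge_count(high, low)
print(best, worst)  # 199 10198
```

With two lists of k elements each, the best case grows like 2k while the worst case grows like k², so doubling k roughly quadruples the worst-case count.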

Maybe we can do better. After all, we’re not using an important fact: the input lists are both sorted already. Is this fact useful? Think about how you might loop through the input lists to construct the output list.


The smallest element of the two lists must be the first element of one of the two lists.

Compare the first elements: 2 is less than 3. Because both lists are sorted, 2 must be the smallest element overall, and we learned that with a single comparison: constant time. So we can loop through the lists, looking only at the front of each list in every iteration. Writing each state as (result so far) (rest of list1) (rest of list2):

[] [2,4,6] [3,5,7]
--> [2] [4,6] [3,5,7]
--> [2,3] [4,6] [5,7]
--> [2,3,4] [6] [5,7]
--> [2,3,4,5] [6] [7]
--> [2,3,4,5,6] [] [7]
--> [2,3,4,5,6,7] [] []

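How long does this merge take? Each step makes at most one comparison and moves one element into the result, so for $n$ total elements it needs $O(n)$ operations, even in the worst case. So far, this might not appear very useful. But let’s continue anyway.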
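The front-of-list idea can be sketched directly, with a comparison counter to check the $O(n)$ claim. This is an illustration under stated assumptions, not the implementation developed below: it swaps in collections.deque so that removing the front element (popleft) takes constant time, whereas list.pop(0) has to shift every remaining element. The name merge_fronts is hypothetical.

```python
from collections import deque


def merge_fronts(list1, list2):
    """Merge two sorted lists by repeatedly comparing only their front
    elements. Returns (merged_list, number_of_comparisons)."""
    # deque.popleft() removes the front element in O(1) time.
    front1, front2 = deque(list1), deque(list2)
    result = []
    comparisons = 0
    while front1 and front2:
        comparisons += 1
        if front1[0] < front2[0]:
            result.append(front1.popleft())
        else:
            result.append(front2.popleft())
    # One deque is now empty; the other's leftovers are sorted and no
    # smaller than anything already in result, so append them all.
    result.extend(front1)
    result.extend(front2)
    return result, comparisons


print(merge_fronts([2, 4, 6], [3, 5, 7]))  # ([2, 3, 4, 5, 6, 7], 5)
```

Every comparison moves exactly one element into the result, so a merge of n elements never makes more than n comparisons.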

Implementing merge

Let’s try to implement this idea. But first, we’ll write tests.

assert merge([2,4,6], [3,5,7]) == [2,3,4,5,6,7]
assert merge([], []) == [] 
assert merge([1], []) == [1] 
assert merge([], [1]) == [1] 
assert merge([1], [2]) == [1,2]
assert merge([1,2,3], [4]) == [1,2,3,4]
assert merge([4], [1,2,3]) == [1,2,3,4]
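Hand-picked tests like these can be complemented by randomized testing. The sketch below is one possible approach, not part of the original material: the hypothetical helper check_merge generates random sorted lists and compares any merge function against sorted(list1 + list2), which serves as a trusted (if slower) reference. To show the harness working on its own, it is run here on the standard library’s heapq.merge.

```python
import heapq
import random


def check_merge(merge_fn, trials=1000, seed=0):
    """Randomized test: for sorted inputs a and b, merge_fn(a, b) must
    equal sorted(a + b). Returns True if every trial passes."""
    rng = random.Random(seed)  # fixed seed, so failures are reproducible
    for _ in range(trials):
        # Short lists of small integers, so empty lists and ties come up often.
        a = sorted(rng.randint(-5, 5) for _ in range(rng.randint(0, 6)))
        b = sorted(rng.randint(-5, 5) for _ in range(rng.randint(0, 6)))
        expected = sorted(a + b)
        # Pass copies so a merge_fn that mutates its inputs can't corrupt a, b.
        actual = merge_fn(list(a), list(b))
        if actual != expected:
            print("FAILED:", a, b, "->", actual, "expected", expected)
            return False
    return True


# heapq.merge returns an iterator, so wrap it to get a list back.
print(check_merge(lambda a, b: list(heapq.merge(a, b))))  # True
```

Once merge is written, check_merge(merge) runs it through the same random trials.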

Something I wonder…

Should we have any tests where the input lists are not sorted? Why or why not?

Getting started

Let’s write a skeletal merge function. Our core idea is to keep two counters, one for each input list, that track where we are in that list. We’ll keep looping as long as at least one of the lists still has elements left to feed into the merge.

To do this, we’ll use a while loop. This will make it easier to keep track of where we are in each list, and move forward in one (not both) of them at a time.

def merge(list1, list2):
    result = [] # eventually holds the merged list
    index1 = 0 # start at beginning of list1
    index2 = 0 # start at beginning of list2
    while index1 < len(list1) or index2 < len(list2):
        pass # ???
    return result

Now what? In each iteration, we need to take either the next element of list1 or the next element of list2, whichever is smaller:

def merge(list1, list2):
    result = []
    index1 = 0
    index2 = 0
    while index1 < len(list1) or index2 < len(list2):
        if list1[index1] < list2[index2]:
            result.append(list1[index1])
            index1 += 1
        else:
            result.append(list2[index2])
            index2 += 1
    return result

We’re making progress! But there’s still a problem: eventually we’ll run off the end of one of the lists. (And what if the lists had different lengths to begin with?)

So we need to add a guard to prevent reading past the end of a list. There are a few ways to do this, but we’ll start by adding an or to the if condition, so that we take from list1 whenever list2 has run out:

def merge(list1, list2):
    result = []
    index1 = 0
    index2 = 0
    while index1 < len(list1) or index2 < len(list2): 
        # if we're out of elements in list2 *or* list1 has the smaller element
        if index2 >= len(list2) or list1[index1] < list2[index2]:
            result.append(list1[index1])
            index1 += 1
        else:
            result.append(list2[index2])
            index2 += 1
    return result
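The guard works because Python’s or short-circuits: it evaluates its left operand first and only evaluates the right one if the left is false. A small standalone illustration (the list and index are made up for the example):

```python
lst = [10, 20]
i = 2  # one past the last valid index

# `or` evaluates left to right and stops at the first true operand,
# so lst[i] is never read here and no IndexError is raised.
print(i >= len(lst) or lst[i] < 15)  # True

# With the operands swapped, lst[i] is evaluated first and fails.
try:
    lst[i] < 15 or i >= len(lst)
except IndexError as err:
    print("IndexError:", err)  # IndexError: list index out of range
```

The order of the operands matters: the bounds check has to come before the list access it protects.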

Hopefully our tests all pass, now!

(They don’t. What’s missing?)


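Unfortunately, this program isn’t quite right. What happens if it’s list1 that runs out of elements first? In that case the guard index2 >= len(list2) is false, so Python goes on to evaluate list1[index1], which raises an IndexError. For example, merge([1], [2]) takes 1 from list1 and then, on the next iteration, tries to read list1[1].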

The easiest fix is to give each out-of-elements case its own branch, checked before we compare any elements. This adds some code length, yes, but let’s get the program correct before we make it elegant.

def merge(list1, list2):
    result = []
    index1 = 0
    index2 = 0
    while index1 < len(list1) or index2 < len(list2): 
        # list2 is out of elements 
        if index2 >= len(list2):
            result.append(list1[index1])
            index1 += 1
        # list1 is out of elements
        elif index1 >= len(list1):
            result.append(list2[index2])
            index2 += 1
        # list1 has the smaller element
        elif list1[index1] < list2[index2]:
            result.append(list1[index1])
            index1 += 1
        # list2 has the smaller element
        else:
            result.append(list2[index2])
            index2 += 1
    return result
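Since the goal was to get merge correct before making it elegant, here is one possible tidier variant, offered as a sketch rather than the course’s own final version. It loops only while both lists still have elements, then copies over whatever remains with slices. It also uses <= instead of <, so that on ties the element from list1 comes first; this keeps equal elements in their original left-to-right order, which is what makes a merge-based sort stable.

```python
def merge(list1, list2):
    """Merge two sorted lists into a new sorted list."""
    result = []
    index1, index2 = 0, 0
    # Compare fronts only while both lists still have elements.
    while index1 < len(list1) and index2 < len(list2):
        # <= takes from list1 on ties, preserving the order of equal elements.
        if list1[index1] <= list2[index2]:
            result.append(list1[index1])
            index1 += 1
        else:
            result.append(list2[index2])
            index2 += 1
    # At most one of these slices is non-empty; its elements are already
    # sorted and no smaller than anything in result.
    result.extend(list1[index1:])
    result.extend(list2[index2:])
    return result


assert merge([2, 4, 6], [3, 5, 7]) == [2, 3, 4, 5, 6, 7]
# Ties: 1.0 == 1, but the one from list1 (the float) comes out first.
print(merge([1.0], [1]))  # [1.0, 1]
```

The slicing copies the leftover elements, which is still linear in their number, so the overall merge stays $O(n)$.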