Merge Sort
You just saw two different sorting algorithms: insertion sort and selection sort. We then implemented insertion sort. Recall that insertion sort is $O(n^2)$ in the worst case, but $O(n)$ in the best case. Both of these cases depend on how the inner loop is structured.
- If the inner loop goes backwards, the best case is on a sorted list, because the program will do one comparison for each element (the element vs. the largest member of the sorted prefix), discover that the element is in its proper place already, and move on.
- In the worst case, the inner loop of comparisons and swaps continues until the element is placed at the beginning of the list. So, if the inner loop goes backwards, the worst case is a reverse-sorted list.
The best-case performance is great: not any worse than checking whether the list is sorted in the first place. Since many datasets in real life are partially sorted to begin with, this is a great property for a sorting algorithm to have.
The worst-case performance is concerning, though. We saw how quickly $O(n^2)$ can blow up when we implemented distinct. A worst-case $O(n^2)$ runtime wouldn’t bode well for sorting large datasets that are random, or that are mostly in reverse order. And, anyway, haven’t we just seen a data structure (Binary Search Trees) that can find the right place to put a new element in $O(\log_2(n))$ time, rather than $O(n)$ (which is what insertion sort uses in its inner loop)? Surely we can do better.
One option would be to just use a BST in place of the inner loop, writing what amounts to:
- Put all elements of the input list into a BST. (Think about how to make sure the resulting BST is balanced! It’s possible.)
- Do an in-order tree traversal to extract the values in order.
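As a quick illustration, here is a minimal sketch of this "tree sort" idea. For simplicity it uses a plain, unbalanced BST, so this version alone doesn’t guarantee the balanced-tree bound discussed below (a self-balancing tree, like an AVL tree, would); the names tree_sort and bst_insert are just for this sketch.
class Node:
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None

def bst_insert(root, value):
    # walk down until we find an empty spot for the new value
    if root is None:
        return Node(value)
    if value < root.value:
        root.left = bst_insert(root.left, value)
    else:
        root.right = bst_insert(root.right, value)
    return root

def in_order(root, out):
    # left subtree, then this node, then right subtree: sorted order
    if root is not None:
        in_order(root.left, out)
        out.append(root.value)
        in_order(root.right, out)

def tree_sort(lst):
    root = None
    for value in lst:
        root = bst_insert(root, value)
    out = []
    in_order(root, out)
    return out

assert tree_sort([3, 1, 2, 1]) == [1, 1, 2, 3]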
What’s the worst-case runtime now? We’re running an insert operation, worst-case $O(\log_2(n))$ in a balanced tree, $n$ times, and then the traversal itself is linear. So we end up with worst-case $n \times O(\log_2(n))$. But what about the best case? Is it still $O(n)$?
No. If we use a BST in place of the inner loop, the algorithm can no longer get “lucky” and stop early. We’ll spend time proportional to $\log_2(n)$ for each element (although there’s some imprecision in how I’m phrasing that, since the number of elements in the tree would grow to $n$ only gradually as we added to it).
Still, though, we now know that a worst-case $n \times O(\log_2(n))$ algorithm exists for sorting, so we’re getting faster in the limit.
How does Python sort?
Python provides two different built-in sorts:
sorted(lst) # Python built-in function, returns new list
lst.sort()  # list method, sorts the list in place
These both run the same algorithm, but one sorts the existing list in place and the other returns a new list. The algorithm is called Timsort, named after Tim Peters, who invented it for Python. Timsort is complicated, because it combines ideas from insertion sort with a different algorithm that we’ll learn about today.
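To see the difference concretely, here’s a small check you can run yourself:
nums = [3, 1, 2]
assert sorted(nums) == [1, 2, 3]
assert nums == [3, 1, 2]  # sorted() left the original untouched

nums.sort()               # sorts in place and returns None
assert nums == [1, 2, 3]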
A (seemingly?) Different Problem
Anyway, here’s a question. Suppose we have 2 sorted lists. How could we combine them to produce a sorted list with the same contents? Here’s an example:
assert merge([2,4,6], [3,5,7]) == [2,3,4,5,6,7]
How could we implement merge efficiently? We could start by just concatenating the lists together and then sorting the result:
def merge(list1, list2):
    return sorted(list1 + list2)
This will work, but what’s the runtime? The concatenation (+) will be linear in the sum of the lengths of the lists. The call to sorted will depend on the runtime of the underlying sort. If it were insertion sort, then merge would be worst-case quadratic.
Maybe we can do better. After all, we’re not using an important fact: the input lists are both sorted already. Is this fact useful? Think about how you might loop through the input lists to construct the output list.
Think, then click!
The smallest element of the two lists must be the first element of one of the two lists.
Look at the first elements: 2 is less than 3, and we know the lists are sorted. So we know that 2 is the smallest element after just one comparison: constant time. We can therefore loop through the lists like this, looking only at the front of each list in every iteration:
[2,4,6] [3,5,7] --> []
[4,6] [3,5,7] --> [2]
[4,6] [5,7] --> [2,3]
[6] [5,7] --> [2,3,4]
[6] [7] --> [2,3,4,5]
[] [7] --> [2,3,4,5,6]
[] [] --> [2,3,4,5,6,7]
How long did it take to run this merge? $O(n)$ operations, where $n$ is the combined length of the two lists, even in the worst case. So far, this might not appear very useful. But let’s continue anyway.
Implementing merge
Let’s try to implement this idea. But first, we’ll write tests.
assert merge([2,4,6], [3,5,7]) == [2,3,4,5,6,7]
assert merge([], []) == []
assert merge([1], []) == [1]
assert merge([], [1]) == [1]
assert merge([1], [2]) == [1,2]
assert merge([1,2,3], [4]) == [1,2,3,4]
assert merge([4], [1,2,3]) == [1,2,3,4]
Something I wonder…
Should we have any tests where the input lists are not sorted? Why or why not?
I would argue that it depends. On the face of it, we shouldn’t, since merge is only supposed to work on sorted input lists. So unsorted inputs are out of scope for the function; it has no obligation to merge them correctly. A common approach is to just say that we make no guarantees.
But if our documentation for merge happens to say what merge will do if given invalid input (e.g., “if given an input list that is unsorted, merge will raise a ValueError”), we would write a test to confirm that the promised behavior does indeed occur. So it’s all about what’s expected by someone using the function.
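For instance, if merge did make that promise, the test might look something like this. This is a hypothetical sketch using pytest’s raises helper; the merge we write below makes no such promise.
import pytest

def test_merge_rejects_unsorted_input():
    # only meaningful if merge documents this ValueError behavior
    with pytest.raises(ValueError):
        merge([3, 1, 2], [4, 5])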
Getting started
Let’s write a skeletal merge function. Our core idea is to keep two counters simultaneously that tell us where we are in each of the input lists. We’ll keep looping so long as at least one of those counters still has list elements to feed into the merge.
To do this, we’ll use a while loop. This will make it easier to keep track of where we are in each list, and to move forward in one (not both) of them at a time.
def merge(list1, list2):
    result = []  # eventually holds the merged list
    index1 = 0   # start at beginning of list1
    index2 = 0   # start at beginning of list2
    while index1 < len(list1) or index2 < len(list2):
        pass  # ???
    return result
Now what? We need to either take the next element of list1 or the next element of list2:
def merge(list1, list2):
    result = []
    index1 = 0
    index2 = 0
    while index1 < len(list1) or index2 < len(list2):
        if list1[index1] < list2[index2]:
            result.append(list1[index1])
            index1 += 1
        else:
            result.append(list2[index2])
            index2 += 1
    return result
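In fact, this version already fails our very first test. Tracing it shows why:
merge([2, 4, 6], [3, 5, 7])
# after 2, 3, 4, 5, and 6 are taken, index1 == 3 == len(list1),
# so the next check of list1[index1] raises an IndexError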
We’re making progress! But there’s still a problem: eventually we’ll run off the end of one of the lists, as the trace above shows. (And what if the lists were different sizes to begin with?) So we need to add a guard to prevent trying to read past the end of a list. There are a few ways to do this, but we’ll add an or to the first condition:
def merge(list1, list2):
    result = []
    index1 = 0
    index2 = 0
    while index1 < len(list1) or index2 < len(list2):
        # if we're out of elements in list2 *or* list1 has the smaller element
        if index2 >= len(list2) or list1[index1] < list2[index2]:
            result.append(list1[index1])
            index1 += 1
        else:
            result.append(list2[index2])
            index2 += 1
    return result
Hopefully our tests all pass, now!
(They don’t. What’s missing?)
Think, then click!
What happens if it’s list1 that runs out of elements first? In that case, we’ll still end up trying to access list1[index1] and get an error.
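Two of our own tests hit exactly this case against the version above:
merge([], [1])         # list1 is empty, so list1[0] raises an IndexError
merge([1, 2, 3], [4])  # list1 empties first, so list1[3] raises an IndexError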
This is easiest to fix by just adding another if branch to catch the out-of-elements cases. This adds some code length, yes, but let’s get the program correct before we make it elegant.
def merge(list1, list2):
    result = []
    index1 = 0
    index2 = 0
    while index1 < len(list1) or index2 < len(list2):
        # list2 is out of elements
        if index2 >= len(list2):
            result.append(list1[index1])
            index1 += 1
        # list1 is out of elements
        elif index1 >= len(list1):
            result.append(list2[index2])
            index2 += 1
        # list1 has the smaller element
        elif list1[index1] < list2[index2]:
            result.append(list1[index1])
            index1 += 1
        # list2 has the smaller element
        else:
            result.append(list2[index2])
            index2 += 1
    return result
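With all three guards in place, every assert above passes. A couple of extra checks worth trying (new tests I’m adding here, covering duplicates and unequal lengths):
assert merge([1, 1, 2], [1, 3]) == [1, 1, 1, 2, 3]
assert merge([5], [1, 2, 3, 4]) == [1, 2, 3, 4, 5]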
Next time we’ll see how to turn this tiny helper into a sorting algorithm.