These notes do not exactly match the lecture capture, because we merged some OOP and recursive functions content.
Recursive Functions on Trees
Writing Recursive Functions
Last time we talked about the HTMLTree type we’ll be using to represent trees. Today, we’ll discuss how to write functions that process tree-structured data. Concretely, we’ll use two examples that do computation on HTML documents.
A key insight is that we can’t really use nested for loops to do this computation. Nesting loops 2, 3, 4, … layers deep will only let us explore 2, 3, 4, … levels of the HTML document, but a document can be nested arbitrarily deep. So we’ll need a different technique, one that fits (and exploits!) the recursive tree structure of HTMLTrees. The data are recursive, so recursion is a natural thing to try first.
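To make the depth limitation concrete, here is a minimal sketch that compares a two-level nested loop with a recursive walk on a document four levels deep. It assumes a stand-in HTMLTree dataclass with tag, children, and text fields, which is a hypothetical simplification of the course’s class. The helper names tags_two_levels and tags_all_levels are illustrative only.

```python
from dataclasses import dataclass, field

# Hypothetical stand-in for the course's HTMLTree: a tag, a list of child
# trees, and a text field that only 'text' nodes use.
@dataclass
class HTMLTree:
    tag: str
    children: list["HTMLTree"] = field(default_factory=list)
    text: str = ""

def tags_two_levels(doc: HTMLTree) -> list[str]:
    """Collect tags with two nested loops: never looks deeper than grandchildren."""
    tags = [doc.tag]
    for child in doc.children:
        tags.append(child.tag)
        for grandchild in child.children:
            tags.append(grandchild.tag)
    return tags

def tags_all_levels(doc: HTMLTree) -> list[str]:
    """Collect tags recursively: each call handles one node and delegates its children."""
    tags = [doc.tag]
    for child in doc.children:
        tags = tags + tags_all_levels(child)
    return tags

# Hand-built '<html><strong><p>world</p></strong></html>': four levels deep
deep = HTMLTree("html", [HTMLTree("strong", [HTMLTree("p", [HTMLTree("text", [], "world")])])])

print(tags_two_levels(deep))  # ['html', 'strong', 'p']  (the text node is missed)
print(tags_all_levels(deep))  # ['html', 'strong', 'p', 'text']
```

Adding a third loop would find the text node here, but a fifth level of nesting would defeat it again. The recursive version needs no changes.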
Example 1
Challenge: given an HTMLTree document and a tag as input, return the number of times that tag appears in the document.
How would you approach this problem?
Let’s start by writing some examples in a test file! It’s a good idea to mix the simplest possible examples with some more complex ones. Here are some I wrote to prepare for this lecture:
def test_count_tag():
    assert count_tag(HTMLTree('html', []), 'p') == 0
    assert count_tag(parse('<html></html>'), 'p') == 0
    assert count_tag(parse('<html></html>'), 'html') == 1
    assert count_tag(parse('<html><p>hello</p><strong><p>world</p></strong></html>'), 'p') == 2
If we didn’t have that final test, our suite of tests would pass an implementation that just looked at the top-most tag!
Now let’s get started implementing. The trick will be delegation: we’ll write a function that handles just the top tag of a tree, and delegates work on its children to other invocations of the same function. Here’s a good way to get started:
def count_tag(doc: HTMLTree, goal: str) -> int:
    if doc.tag == goal:
        return 1
    else:
        return 0
This isn’t done, but it’s not wrong: we really do want to check the top-level tag and see whether it’s what we’re looking for. And on super simple examples, this function would work just fine! The trick will be getting it to delegate. Let’s look at an example from our test suite: '<html><p>hello</p><strong><p>world</p></strong></html>'. I’ll omit the low-level memory details and just describe the structure of the tree: the html node has two children, p and strong. The first p has a single child, the text hello; the strong has a single child, another p, whose only child is the text world.
Note that our helper code turns plain text into HTMLTree objects with their tag field set to 'text', which is why the leaves of this tree are 'text' nodes.
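For a text picture of this tree, here is a sketch that builds it by hand and prints one line per node, indented by depth. It assumes a minimal stand-in HTMLTree dataclass with tag, children, and text fields, and it assumes that parse produces this same shape. The class definition and the render helper are illustrative, not the course’s code. Note that render itself already uses the delegation idea: it handles one node and passes the children to recursive calls.

```python
from dataclasses import dataclass, field

# Hypothetical stand-in for the course's HTMLTree class.
@dataclass
class HTMLTree:
    tag: str
    children: list["HTMLTree"] = field(default_factory=list)
    text: str = ""

# Hand-built equivalent of '<html><p>hello</p><strong><p>world</p></strong></html>'
tree = HTMLTree("html", [
    HTMLTree("p", [HTMLTree("text", [], "hello")]),
    HTMLTree("strong", [
        HTMLTree("p", [HTMLTree("text", [], "world")]),
    ]),
])

def render(doc: HTMLTree, depth: int = 0) -> list[str]:
    """Return one line per node, indented two spaces per level of depth."""
    label = f"text: {doc.text!r}" if doc.tag == "text" else doc.tag
    lines = ["  " * depth + label]
    for child in doc.children:
        lines = lines + render(child, depth + 1)
    return lines

print("\n".join(render(tree)))
# Prints:
# html
#   p
#     text: 'hello'
#   strong
#     p
#       text: 'world'
```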
If we call the above function on this tree, it will only check the html tag and ignore everything below it. So we need to do a bit more work: visit each child subtree (here, the p and the strong) and add their results to our total.
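Here is a quick check of that claim, sketched with the same kind of hypothetical stand-in HTMLTree dataclass and a hand-built copy of the example tree. The first draft gets the root right but misses every p below it. This is also why the final test in test_count_tag matters.

```python
from dataclasses import dataclass, field

# Hypothetical stand-in for the course's HTMLTree class.
@dataclass
class HTMLTree:
    tag: str
    children: list["HTMLTree"] = field(default_factory=list)
    text: str = ""

def count_tag_top_only(doc: HTMLTree, goal: str) -> int:
    """The first draft: checks only the root and never looks at its children."""
    if doc.tag == goal:
        return 1
    else:
        return 0

# Hand-built '<html><p>hello</p><strong><p>world</p></strong></html>'
tree = HTMLTree("html", [
    HTMLTree("p", [HTMLTree("text", [], "hello")]),
    HTMLTree("strong", [HTMLTree("p", [HTMLTree("text", [], "world")])]),
])

print(count_tag_top_only(tree, "html"))  # 1: the root matches
print(count_tag_top_only(tree, "p"))     # 0, although the document has two p tags
```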
def count_tag(doc: HTMLTree, goal: str) -> int:
    """return the number of times a particular tag appears in an HTML tree document"""
    count = 0
    # check this node's own tag
    if doc.tag == goal:
        count = count + 1
    # delegate each child subtree to another invocation of count_tag
    for child in doc.children:
        count = count + count_tag(child, goal)
    return count
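As you think about how this function executes, remember that each invocation of count_tag is separate: each one has its own count variable and its own input. Every recursive call receives a child of the current node, which is a strictly smaller subtree, so the calls eventually reach nodes with no children. There the for loop does nothing and the function simply returns. That’s why this technique always finishes with an answer.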
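One way to watch the separate invocations is to add a print when each call starts and finishes. The sketch below is count_tag plus tracing. It assumes a hypothetical stand-in HTMLTree dataclass, and the extra depth parameter exists only to indent the output.

```python
from dataclasses import dataclass, field

# Hypothetical stand-in for the course's HTMLTree class.
@dataclass
class HTMLTree:
    tag: str
    children: list["HTMLTree"] = field(default_factory=list)
    text: str = ""

def count_tag_traced(doc: HTMLTree, goal: str, depth: int = 0) -> int:
    """count_tag with entry/exit prints; depth only controls indentation."""
    indent = "  " * depth
    count = 0  # a fresh local variable belonging to this invocation only
    print(f"{indent}enter {doc.tag!r}")
    if doc.tag == goal:
        count = count + 1
    for child in doc.children:
        count = count + count_tag_traced(child, goal, depth + 1)
    print(f"{indent}leave {doc.tag!r} with count = {count}")
    return count

# Hand-built '<html><p>hello</p><strong><p>world</p></strong></html>'
tree = HTMLTree("html", [
    HTMLTree("p", [HTMLTree("text", [], "hello")]),
    HTMLTree("strong", [HTMLTree("p", [HTMLTree("text", [], "world")])]),
])

print(count_tag_traced(tree, "p"))  # the last line printed is 2
```

In the trace, the strong call reports a count of 1, which it got from its own p child. The html call then adds that 1 to the 1 it already received from its first p child and returns 2. No call ever sees another call’s count variable.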
Variation: with comprehensions
With more Python experience, you might start using comprehensions here instead of a for loop:
def count_tag(doc: HTMLTree, goal: str) -> int:
    """return the number of times a particular tag appears in an HTML tree document"""
    subtotal = sum([count_tag(subdoc, goal) for subdoc in doc.children])
    if doc.tag == goal:
        return subtotal + 1
    return subtotal
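Since sum accepts any iterable, the square brackets are optional. Dropping them turns the list comprehension into a generator expression, which feeds values to sum one at a time without building an intermediate list. Here is a minimal sketch under the same hypothetical stand-in HTMLTree assumption. It also folds the if into a conditional expression.

```python
from dataclasses import dataclass, field

# Hypothetical stand-in for the course's HTMLTree class.
@dataclass
class HTMLTree:
    tag: str
    children: list["HTMLTree"] = field(default_factory=list)
    text: str = ""

def count_tag(doc: HTMLTree, goal: str) -> int:
    """Count occurrences of goal, summing the children with a generator expression."""
    here = 1 if doc.tag == goal else 0
    # sum of an empty generator is 0, so leaves need no special case
    return here + sum(count_tag(subdoc, goal) for subdoc in doc.children)

# Hand-built '<html><p>hello</p><strong><p>world</p></strong></html>'
tree = HTMLTree("html", [
    HTMLTree("p", [HTMLTree("text", [], "hello")]),
    HTMLTree("strong", [HTMLTree("p", [HTMLTree("text", [], "world")])]),
])

print(count_tag(tree, "p"))  # 2
```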
Example 2
Challenge: given an HTMLTree document as input, return a list containing all of the plain text elements of the document, in the order they appear.
How would you approach this problem?
As before, let’s start by writing some examples to help us get a feel for the computation we’ll need to do. Here are some that I wrote:
def test_all_text():
    assert all_text(parse('<html></html>')) == []
    assert all_text(parse('hello')) == ['hello']
    assert all_text(parse('<html>hello</html>')) == ['hello']
    assert all_text(parse('<html><p>hello</p><strong><p>world</p></strong></html>')) == ['hello', 'world']
It looks like we are going to need to identify the nodes tagged with 'text' and somehow accumulate their contents. As before, we’ll start with a version of the function that only works locally, at the top level:
def all_text(doc: HTMLTree) -> list[str]:
    text = []
    if doc.tag == "text":
        text.append(doc.text)
    return text
And, again, this works fine on single-element documents! But (also like before) we need to do a bit more work to process the child subtrees of doc. We’ll delegate the problem to other invocations of the all_text function.
def all_text(doc: HTMLTree) -> list[str]:
    text = []
    if doc.tag == "text":
        text.append(doc.text)
    for child in doc.children:
        text.append(all_text(child)) # uh oh!
    return text
Note that I’ve marked a line in the code above with a comment that says “uh oh!” There’s a mistake here: what is it?
Calling text.append on a value adds that value as a single new element of the list. So the whole list returned by all_text(child) gets added as one element of the original list, instead of each of its text elements being appended individually.
Try it out! This kind of error is common when writing recursive functions that build up lists, so it’s useful to recognize the behavior.
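Here’s a sketch of trying it out. It first shows the difference on plain lists, then runs the buggy function on a hand-built copy of the example tree, again assuming a hypothetical stand-in HTMLTree dataclass. Each level of recursion wraps the result in one more list.

```python
from dataclasses import dataclass, field

# On plain lists: append adds its argument as ONE element; + splices elements in.
xs = ["a"]
xs.append(["b", "c"])
print(xs)  # ['a', ['b', 'c']]

ys = ["a"] + ["b", "c"]
print(ys)  # ['a', 'b', 'c']

# Hypothetical stand-in for the course's HTMLTree class.
@dataclass
class HTMLTree:
    tag: str
    children: list["HTMLTree"] = field(default_factory=list)
    text: str = ""

def all_text_buggy(doc: HTMLTree) -> list:
    text = []
    if doc.tag == "text":
        text.append(doc.text)
    for child in doc.children:
        text.append(all_text_buggy(child))  # uh oh: appends the child's whole list
    return text

# Hand-built '<html><p>hello</p><strong><p>world</p></strong></html>'
tree = HTMLTree("html", [
    HTMLTree("p", [HTMLTree("text", [], "hello")]),
    HTMLTree("strong", [HTMLTree("p", [HTMLTree("text", [], "world")])]),
])

print(all_text_buggy(tree))  # [[['hello']], [[['world']]]]
```

The word world ends up one list deeper than hello because it sits one level deeper in the tree, under the extra strong node.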
Here’s a fixed version:
def all_text(doc: HTMLTree) -> list[str]:
    text = []
    if doc.tag == "text":
        text.append(doc.text)
    for child in doc.children:
        # concatenate the child's list onto ours, element by element
        text = text + all_text(child)
    return text
We could also have looped over all_text(child) and appended each element individually, or called text.extend(all_text(child)), which does the same thing in one step.
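Python lists have an extend method that appends each element of its argument individually. Here is a sketch of the fixed function written with extend and checked against the same cases as test_all_text. The trees are built by hand instead of with parse, and a hypothetical stand-in HTMLTree dataclass is assumed.

```python
from dataclasses import dataclass, field

# Hypothetical stand-in for the course's HTMLTree class.
@dataclass
class HTMLTree:
    tag: str
    children: list["HTMLTree"] = field(default_factory=list)
    text: str = ""

def all_text(doc: HTMLTree) -> list[str]:
    """Collect the text of every 'text' node, in document order."""
    text = []  # local to this invocation, so mutating it is safe
    if doc.tag == "text":
        text.append(doc.text)
    for child in doc.children:
        text.extend(all_text(child))  # add the child's elements, not the list itself
    return text

# Hand-built '<html><p>hello</p><strong><p>world</p></strong></html>'
tree = HTMLTree("html", [
    HTMLTree("p", [HTMLTree("text", [], "hello")]),
    HTMLTree("strong", [HTMLTree("p", [HTMLTree("text", [], "world")])]),
])

print(all_text(tree))  # ['hello', 'world']
```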
Variation: with comprehensions
If you like comprehensions, here’s the (fixed) function written in that style:
def all_text(doc: HTMLTree) -> list[str]:
    if doc.tag == "text":
        return [doc.text]
    return [text for subdoc in doc.children for text in all_text(subdoc)]
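A comprehension with two for clauses reads left to right, just like nested for loops. That is how the comprehension above flattens the children’s lists into one list. Here is a small sketch on plain lists; the names rows, flat, and flat_loop are illustrative.

```python
rows = [["a", "b"], [], ["c"]]

# The for clauses run left to right, as if they were nested loops.
flat = [item for row in rows for item in row]

# The same computation written as explicit nested loops.
flat_loop = []
for row in rows:
    for item in row:
        flat_loop.append(item)

print(flat)  # ['a', 'b', 'c']
```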
Notice that here we get by without creating an accumulator list like text and adding to it! Why is that?