Recursive Functions on Trees

Logistics

Hours and Drill

Nobody came to my hours on Monday. So I’m reminding you that I moved them to Monday so that more of you would be able to make it! If you can’t make it but need to talk, reach out.

There is a drill today!

Writing Recursive Functions

Last time we talked about the HTMLTree type we’ll be using to represent trees. Today, we’ll discuss how to write functions that process tree-structured data. Concretely, we’ll work through two examples that compute over HTML documents.

A key insight is that we can’t really use nested for loops to do this computation. Nesting 2, 3, 4, … layers deep will only let us explore 2, 3, 4, … levels of the HTML document. So we’ll need to use a different technique, one that fits (and exploits!) the recursive tree structure of the HTMLTrees.
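To see the limitation concretely, here’s a minimal sketch of a loop-only attempt at counting the nodes with a given tag. The HTMLTree class below is a simplified stand-in (its field names tag, children, and text are an assumption based on how the type is used in these notes), and count_tag_nested_loops is a hypothetical name made up for illustration. With two nested loops it can see children and grandchildren, but nothing deeper:

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class HTMLTree:
    """Simplified stand-in for the course's HTMLTree (field names assumed)."""
    tag: str
    children: list = field(default_factory=list)
    text: Optional[str] = None


def count_tag_nested_loops(doc: HTMLTree, goal: str) -> int:
    """Count goal tags at the root, its children, and its grandchildren only."""
    count = 0
    if doc.tag == goal:
        count += 1
    for child in doc.children:
        if child.tag == goal:
            count += 1
        for grandchild in child.children:
            if grandchild.tag == goal:
                count += 1
            # Anything below the grandchildren is never visited.
    return count


# <html><div><div><p>deep</p></div></div></html>, built by hand
deep_p = HTMLTree('p', [HTMLTree('text', [], 'deep')])
doc = HTMLTree('html', [HTMLTree('div', [HTMLTree('div', [deep_p])])])

div_count = count_tag_nested_loops(doc, 'div')  # both divs are within reach
p_count = count_tag_nested_loops(doc, 'p')      # the p is one level too deep
print(div_count, p_count)
```

Here the p sits three levels below html, so the loops never reach it. Adding a third loop would fix this document, but not one that nests a level deeper still.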

Example 1

Challenge: given an HTMLTree document and a tag as input, return the number of times that tag appears in the document.

How would you approach this problem?

Let’s start by writing some examples in a test file! It’s a good idea to mix the simplest possible examples with some more complex ones. Here are some I wrote to prepare for this lecture:

def test_count_tag():
  assert count_tag(HTMLTree('html', []), 'p') == 0
  assert count_tag(parse('<html></html>'), 'p') == 0 
  assert count_tag(parse('<html></html>'), 'html') == 1
  assert count_tag(parse('<html><p>hello</p><strong><p>world</p></strong></html>'), 'p') == 2

If we didn’t have that final test, our suite of tests would pass an implementation that just looked at the top-most tag!

Now let’s get started implementing. The trick will be delegation: we’ll write a function that handles just the top tag of a tree, and delegates work on its children to other invocations of the same function. Here’s a good way to get started:

def count_tag(doc: HTMLTree, goal: str) -> int:    
    if doc.tag == goal:
        return 1
    else:
        return 0

This isn’t done, but it’s not wrong: we really do want to check the top-level tag and see whether it’s what we’re looking for. And, on super simple examples, this function would work just fine! The trick will be getting it to delegate. Let’s look at an example from our test suite: '<html><p>hello</p><strong><p>world</p></strong></html>'. I’ll omit the low-level memory details, and just draw the structure of the tree:

Note that our helper code turns plain text into HTMLTree objects with their tag field set to 'text', which is why the bottom nodes look the way they do.

If we call the above function on this tree, it will only detect the html tag, but not the others. So we need to do a bit more work: visit the two child subtrees, and add the results to our total.
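We can check this claim by building the same tree by hand as nested constructor calls. This is only a sketch: the HTMLTree class is a simplified stand-in whose field names are assumed from how the type is used in these notes, and count_tag_top_only is a hypothetical name for the unfinished version above.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class HTMLTree:
    """Simplified stand-in for the course's HTMLTree (field names assumed)."""
    tag: str
    children: list = field(default_factory=list)
    text: Optional[str] = None


def count_tag_top_only(doc: HTMLTree, goal: str) -> int:
    """The unfinished version: checks only the top tag."""
    if doc.tag == goal:
        return 1
    else:
        return 0


# <html><p>hello</p><strong><p>world</p></strong></html>, built by hand.
# Plain text becomes a node whose tag is 'text'.
doc = HTMLTree('html', [
    HTMLTree('p', [HTMLTree('text', [], 'hello')]),
    HTMLTree('strong', [
        HTMLTree('p', [HTMLTree('text', [], 'world')]),
    ]),
])

html_count = count_tag_top_only(doc, 'html')  # finds the top-level tag
p_count = count_tag_top_only(doc, 'p')        # misses both nested p tags
print(html_count, p_count)
```

The second result should be 2, which is exactly what the final test in our suite checks for.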

def count_tag(doc: HTMLTree, goal: str) -> int:
    """return the number of times a particular tag appears in an HTML tree document"""
    count = 0
    if doc.tag == goal:
        count = count + 1
    for child in doc.children: 
        count = count + count_tag(child, goal)
    return count    

As you think about how this function executes, remember that each invocation of count_tag is separate: the count variable is different for each, as is the input. Each recursive call receives a child subtree, which is strictly smaller than doc, so the recursion eventually reaches nodes with no children. Those invocations make no further calls, which is why this technique will eventually return an answer.
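One way to convince yourself that the invocations are separate is to record what each one returns. The sketch below is an illustration, not part of the course code: count_tag_traced and its log parameter are hypothetical additions, and HTMLTree is a simplified stand-in with assumed field names.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class HTMLTree:
    """Simplified stand-in for the course's HTMLTree (field names assumed)."""
    tag: str
    children: list = field(default_factory=list)
    text: Optional[str] = None


def count_tag_traced(doc: HTMLTree, goal: str, log: list) -> int:
    """Same logic as count_tag, but records (tag, result) as each call finishes."""
    count = 0  # a fresh variable for this invocation only
    if doc.tag == goal:
        count = count + 1
    for child in doc.children:
        count = count + count_tag_traced(child, goal, log)
    log.append((doc.tag, count))
    return count


# <html><p>hello</p><strong><p>world</p></strong></html>, built by hand
doc = HTMLTree('html', [
    HTMLTree('p', [HTMLTree('text', [], 'hello')]),
    HTMLTree('strong', [
        HTMLTree('p', [HTMLTree('text', [], 'world')]),
    ]),
])

log = []
total = count_tag_traced(doc, 'p', log)
for tag, result in log:
    print(tag, result)
```

The text nodes finish first with a count of 0, since they have no children to delegate to. The html invocation finishes last with the total of 2.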

…with comprehensions instead:

With more Python experience, you might start using comprehensions here, instead of a for loop:

def count_tag(doc: HTMLTree, goal: str) -> int:
    """return the number of times a particular tag appears in an HTML tree document"""
    subtotal = sum([count_tag(subdoc, goal) for subdoc in doc.children])
    if doc.tag == goal:
        return subtotal + 1
    return subtotal
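It’s worth checking what this version does at a node with no children, since there’s no explicit base case. The comprehension produces an empty list there, and sum of an empty list is 0. Here’s a small sketch that runs the function on a single leaf, using a simplified stand-in for HTMLTree (field names assumed):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class HTMLTree:
    """Simplified stand-in for the course's HTMLTree (field names assumed)."""
    tag: str
    children: list = field(default_factory=list)
    text: Optional[str] = None


def count_tag(doc: HTMLTree, goal: str) -> int:
    """Return the number of times a particular tag appears in the tree."""
    subtotal = sum([count_tag(subdoc, goal) for subdoc in doc.children])
    if doc.tag == goal:
        return subtotal + 1
    return subtotal


leaf = HTMLTree('text', [], 'hello')  # no children, so no recursive calls

empty_sum = sum([])                   # what the leaf computes as its subtotal
leaf_p_count = count_tag(leaf, 'p')
leaf_text_count = count_tag(leaf, 'text')
print(empty_sum, leaf_p_count, leaf_text_count)
```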

Example 2

Challenge: given an HTMLTree document as input, return a list containing all of the plain text elements of the document.

How would you approach this problem?

As before, let’s start by writing some examples to help us get a feel for the computation we’ll need to do. Here are some that I wrote:

def test_all_text():
  assert all_text(parse('<html></html>')) == []  
  assert all_text(parse('hello')) == ['hello']  
  assert all_text(parse('<html>hello</html>')) == ['hello']    
  assert all_text(parse('<html><p>hello</p><strong><p>world</p></strong></html>')) == ['hello', 'world']

It looks like we are going to need to identify the nodes tagged with text, and somehow accumulate their contents. Like before, we’ll start with a version of the function that only works locally, at the top level:

def all_text(doc: HTMLTree) -> list[str]:
    text = []
    if doc.tag == "text":
        text.append(doc.text)    
    return text

And, again, this works fine on single-element documents! But (also like before) we need to do a bit more work to process the child sub-trees of doc. We’ll delegate the problem to other invocations of the all_text function.

def all_text(doc: HTMLTree) -> list[str]:
    text = []
    if doc.tag == "text":
        text.append(doc.text)
    for child in doc.children:
        text.append(all_text(child)) # uh oh! 
    return text

Note that I’ve marked a line in the above with a comment that says “uh oh!” There’s a mistake being made here: what is it?

Think, then click!

Calling text.append on a value will add whatever that value is as a new element in the list. That is, we’ll get the list all_text(child) added as a single element of the original list, rather than having all the new text elements individually appended to the original list.

Try it out! This mistake is common when writing this kind of function, so it will be useful to recognize the behavior.
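If you want to see the behavior without setting up the full helper code, here’s a minimal sketch. It assumes a simplified stand-in for HTMLTree (field names assumed), and all_text_buggy is just the append-based function under a hypothetical name:

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class HTMLTree:
    """Simplified stand-in for the course's HTMLTree (field names assumed)."""
    tag: str
    children: list = field(default_factory=list)
    text: Optional[str] = None


def all_text_buggy(doc: HTMLTree) -> list:
    """The append-based version: each child's result is added as ONE element."""
    text = []
    if doc.tag == "text":
        text.append(doc.text)
    for child in doc.children:
        text.append(all_text_buggy(child))  # adds a whole list as one element
    return text


# <html><p>hello</p><strong><p>world</p></strong></html>, built by hand
doc = HTMLTree('html', [
    HTMLTree('p', [HTMLTree('text', [], 'hello')]),
    HTMLTree('strong', [
        HTMLTree('p', [HTMLTree('text', [], 'world')]),
    ]),
])

result = all_text_buggy(doc)
print(result)  # [[['hello']], [[['world']]]]
```

Notice that the nesting of the lists mirrors the nesting of the tree: 'world' is one level deeper in the document than 'hello', and it ends up wrapped in one more list.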

Here’s a fixed version:

def all_text(doc: HTMLTree) -> list[str]:
    text = []
    if doc.tag == "text":
        text.append(doc.text)
    for child in doc.children:
        text = text + all_text(child)
    return text

We could also have looped over all_text(child) and added each element individually.
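Here’s a sketch of that alternative, along with a version that uses the built-in list.extend method to do the same thing in one step. The function names all_text_loop and all_text_extend are hypothetical, and HTMLTree is a simplified stand-in with assumed field names.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class HTMLTree:
    """Simplified stand-in for the course's HTMLTree (field names assumed)."""
    tag: str
    children: list = field(default_factory=list)
    text: Optional[str] = None


def all_text_loop(doc: HTMLTree) -> list[str]:
    """Append each string from the child's result individually."""
    text = []
    if doc.tag == "text":
        text.append(doc.text)
    for child in doc.children:
        for item in all_text_loop(child):
            text.append(item)
    return text


def all_text_extend(doc: HTMLTree) -> list[str]:
    """Use list.extend, which adds every element of its argument."""
    text = []
    if doc.tag == "text":
        text.append(doc.text)
    for child in doc.children:
        text.extend(all_text_extend(child))
    return text


# <html><p>hello</p><strong><p>world</p></strong></html>, built by hand
doc = HTMLTree('html', [
    HTMLTree('p', [HTMLTree('text', [], 'hello')]),
    HTMLTree('strong', [
        HTMLTree('p', [HTMLTree('text', [], 'world')]),
    ]),
])

print(all_text_loop(doc))    # ['hello', 'world']
print(all_text_extend(doc))  # ['hello', 'world']
```

Both versions modify the existing text list in place, whereas text = text + all_text(child) builds a new list on every iteration. For documents of the size we’re working with, either approach is fine.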

…with comprehensions instead:

If you like comprehensions, here’s the (fixed) function written in that style:

def all_text(doc: HTMLTree) -> list[str]:
    if doc.tag == "text":
        return [doc.text]
    return [text for subdoc in doc.children for text in all_text(subdoc)]

Notice that here we manage to get by without creating an accumulator list like text and updating it as we go! Why is that?