```python
print(print(1), print(2))
```
Output:
```
1
2
None None
```
# print() is a non-pure function
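The output follows from the order of evaluation. Python evaluates both argument calls first, left to right. Each inner print displays its argument as a side effect and returns None, and the outer print then displays those two return values. The short sketch below makes the return value visible and contrasts it with a pure function (abs is used only as an illustration):

```python
# print() displays its argument as a side effect; its return value is None.
result = print(1)   # displays 1
print(result)       # displays None

# A pure function such as abs() only returns a value and displays nothing.
value = abs(-2)
print(value)        # displays 2
```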

Casual Literary Notes

Note: When we say condition is a predicate function, we mean that it is a function that will return True or False.
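For instance, a predicate is simply a function whose return value is True or False, and a higher-order function can accept one as its condition argument. A minimal sketch (count_if and is_even are illustrative names, not from the course materials):

```python
def is_even(x):
    """A predicate: returns True or False."""
    return x % 2 == 0

def count_if(condition, n):
    """Count the integers i from 1 to n for which condition(i) is true."""
    count = 0
    for i in range(1, n + 1):
        if condition(i):
            count += 1
    return count

print(count_if(is_even, 10))               # 5
print(count_if(lambda x: x % 3 == 0, 10))  # 3
```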

Recursion Visualizer

Python Tutor code visualizer: Visualize code in Python, JavaScript, C, C++, and Java

Questions

Q5: Count Stair Ways

Imagine that you want to go up a flight of stairs that has n steps, where n is a positive integer. You can take either one or two steps each time you move. In how many ways can you go up the entire flight of stairs?

You’ll write a function count_stair_ways to answer this question. Before you write any code, consider:

  • How many ways are there to go up a flight of stairs with n = 1 step? What about n = 2 steps? Try writing or drawing out some other examples and see if you notice any patterns.

Solution: When there is only one step, there is only one way to go up. When there are two steps, we can go up in two ways: take a single 2-step, or take two 1-steps.

  • What is the base case for this question? What is the smallest input?

Solution: We actually have two base cases! Our first base case is when there is one step left. n = 1 is the smallest input because 1 is the smallest positive integer.

Our second base case is when there are two steps left. The primary solution (found below) cannot solve count_stair_ways(2) recursively because count_stair_ways(0) is undefined.

(virfib has two base cases for a similar reason: virfib(1) cannot be solved recursively because virfib(-1) is undefined.)

Alternate solution: Our first base case is when there are no steps left. This means we reached the top of the stairs with our last action.

Our second base case is when we have overstepped. This means our last action was invalid; in other words, we took two steps when only one step remained.

Solution: count_stair_ways(n - 1) is the number of ways to go up n - 1 stairs. Equivalently, count_stair_ways(n - 1) is the number of ways to go up n stairs if our first action is taking one step.

count_stair_ways(n - 2) is the number of ways to go up n - 2 stairs. Equivalently, count_stair_ways(n - 2) is the number of ways to go up n stairs if our first action is taking two steps.

Now, fill in the code for count_stair_ways:

Solution

```python
def count_stair_ways(n):
    """Returns the number of ways to climb up a flight of
    n stairs, moving either one step or two steps at a time.
    >>> count_stair_ways(1)
    1
    >>> count_stair_ways(2)
    2
    >>> count_stair_ways(4)
    5
    """
    if n == 1:
        return 1
    elif n == 2:
        return 2
    return count_stair_ways(n-1) + count_stair_ways(n-2)
```

Here’s an alternate solution that corresponds to the alternate base cases:

```python
def count_stair_ways_alt(n):
    """Returns the number of ways to climb up a flight of
    n stairs, moving either 1 step or 2 steps at a time.
    >>> count_stair_ways_alt(4)
    5
    """
    if n == 0:
        return 1
    elif n < 0:
        return 0
    return count_stair_ways_alt(n-1) + count_stair_ways_alt(n-2)
```

You can use [Recursion Visualizer](https://www.recursionvisualizer.com/?function_definition=def count_stair_ways(n)%3A if n %3D%3D 1%3A return 1 elif n %3D%3D 2%3A return 2 return count_stair_ways(n-1) %2B count_stair_ways(n-2)&function_call=count_stair_ways(4)) to step through the call structure of count_stair_ways(4) for the primary solution.
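A quick sanity check: the sketch below repeats both definitions so it runs on its own, confirms that the two versions agree, and shows that the counts for n = 1, 2, 3, ... follow the recurrence count(n) = count(n - 1) + count(n - 2), the same one that produces the Fibonacci numbers.

```python
def count_stair_ways(n):
    # Primary version: base cases at one and two remaining steps.
    if n == 1:
        return 1
    elif n == 2:
        return 2
    return count_stair_ways(n - 1) + count_stair_ways(n - 2)

def count_stair_ways_alt(n):
    # Alternate version: landing exactly on the top counts as one way,
    # overstepping counts as zero ways.
    if n == 0:
        return 1
    elif n < 0:
        return 0
    return count_stair_ways_alt(n - 1) + count_stair_ways_alt(n - 2)

counts = [count_stair_ways(n) for n in range(1, 9)]
print(counts)  # [1, 2, 3, 5, 8, 13, 21, 34]
print(all(count_stair_ways(n) == count_stair_ways_alt(n) for n in range(1, 9)))  # True
```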

Q6: Subsequences

A subsequence of a sequence S is a subset of elements from S, in the same order they appear in S. Consider the list [1, 2, 3]. Here are a few of its subsequences: [], [1, 3], [2], and [1, 2, 3].

Write a function that takes in a list and returns all possible subsequences of that list. The subsequences should be returned as a list of lists, where each nested list is a subsequence of the original input.

In order to accomplish this, you might first want to write a function insert_into_all that takes an item and a list of lists, adds the item to the beginning of each nested list, and returns the resulting list.

Solution

```python
def insert_into_all(item, nested_list):
    """Return a new list consisting of all the lists in nested_list,
    but with item added to the front of each. You can assume that
    nested_list is a list of lists.

    >>> nl = [[], [1, 2], [3]]
    >>> insert_into_all(0, nl)
    [[0], [0, 1, 2], [0, 3]]
    """
    return [[item] + lst for lst in nested_list]

def subseqs(s):
    """Return a nested list (a list of lists) of all subsequences of S.
    The subsequences can appear in any order. You can assume S is a list.

    >>> seqs = subseqs([1, 2, 3])
    >>> sorted(seqs)
    [[], [1], [1, 2], [1, 2, 3], [1, 3], [2], [2, 3], [3]]
    >>> subseqs([])
    [[]]
    """
    if not s:
        return [[]]
    else:
        subset = subseqs(s[1:])
        return insert_into_all(s[0], subset) + subset
```
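The recursion splits the subsequences into those that contain s[0] and those that do not, so a list of k distinct elements has 2**k subsequences. As an independent cross-check, itertools.combinations keeps elements in their original order, so it should produce the same subsequences. A sketch (subseqs_by_combinations is a hypothetical helper used only for testing):

```python
from itertools import combinations

def insert_into_all(item, nested_list):
    return [[item] + lst for lst in nested_list]

def subseqs(s):
    # Every subsequence either contains s[0] or it doesn't.
    if not s:
        return [[]]
    rest = subseqs(s[1:])
    return insert_into_all(s[0], rest) + rest

def subseqs_by_combinations(s):
    """Hypothetical cross-check: each k-element combination keeps the original order."""
    return [list(c) for k in range(len(s) + 1) for c in combinations(s, k)]

s = [1, 2, 3]
print(sorted(subseqs(s)) == sorted(subseqs_by_combinations(s)))  # True
print(len(subseqs([1, 2, 3, 4])))  # 16, that is 2 ** 4
```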

Discussion 6: Generators | CS 61A Fall 2025

Q3: Partitions

Tree-recursive generator functions have a similar structure to regular tree-recursive functions. They are useful for iterating over all possibilities. Instead of building a list of results and returning it, just yield each result.

You’ll need to identify a recursive decomposition: how to express the answer in terms of recursive calls that are simpler. Ask yourself what will be yielded by a recursive call, then how to use those results.

Definition. For positive integers n and m, a partition of n using parts up to size m is an addition expression of positive integers up to m in non-decreasing order that sums to n.

Implement partition_gen, a generator function that takes positive integers n and m. It yields the partitions of n using parts up to size m as strings.

Reminder: For the partitions function we studied in lecture (video), the recursive decomposition was to enumerate all ways of partitioning n using at least one m and then to enumerate all ways with no m (only m-1 and lower).

Hint: For the base case, yield a partition with just one element, n. Make sure you yield a string.

Hint: The first recursive case uses at least one m, and so you will need to yield a string that starts with p but also includes m. The second recursive case only uses parts up to size m-1. (You can implement the second case in one line using yield from.)

```python
def partition_gen(n, m):
    """Yield the partitions of n using parts up to size m.

    >>> for partition in sorted(partition_gen(6, 4)):
    ...     print(partition)
    1 + 1 + 1 + 1 + 1 + 1
    1 + 1 + 1 + 1 + 2
    1 + 1 + 1 + 3
    1 + 1 + 2 + 2
    1 + 1 + 4
    1 + 2 + 3
    2 + 2 + 2
    2 + 4
    3 + 3
    """
    assert n > 0 and m > 0
    if n == m:
        yield str(m)
    if n - m > 0:
        # Use at least one m: extend every partition of n - m with a final m.
        # The separator must be ' + ' (with spaces) to match the doctest.
        for p in partition_gen(n - m, m):
            yield p + ' + ' + str(m)
    if m > 1:
        # Use no m at all: only parts up to size m - 1.
        yield from partition_gen(n, m - 1)
```

Discussion Time. Work together to explain why this implementation of partition_gen does not include base cases for n < 0, n == 0, or m == 0 even though the original implementation of partitions from lecture (video) had all three.
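Because partition_gen is a generator, partitions are produced only on demand. The sketch below takes just the first partition, then checks the total number of partitions against a tree-recursive counter in the style of the lecture's partitions function (count_partitions is a hypothetical name, written with the three base cases mentioned in the discussion question):

```python
def partition_gen(n, m):
    # Same generator as above.
    assert n > 0 and m > 0
    if n == m:
        yield str(m)
    if n - m > 0:
        for p in partition_gen(n - m, m):
            yield p + ' + ' + str(m)
    if m > 1:
        yield from partition_gen(n, m - 1)

def count_partitions(n, m):
    """Count partitions of n using parts up to size m."""
    if n == 0:
        return 1
    elif n < 0 or m == 0:
        return 0
    return count_partitions(n - m, m) + count_partitions(n, m - 1)

gen = partition_gen(6, 4)
print(next(gen))  # 2 + 4  (only one partition has been computed so far)
print(sum(1 for _ in partition_gen(6, 4)) == count_partitions(6, 4))  # True
```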

Q4: Squares

Implement the generator function squares, which takes positive integers total and k. It yields all lists of perfect squares greater than or equal to k*k that sum to total. Each list is in non-increasing order (large to small).

```python
def squares(total, k):
    """Yield the ways in which perfect squares greater than or equal to k*k sum to total.

    >>> list(squares(10, 1))  # All lists of perfect squares that sum to 10
    [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [4, 1, 1, 1, 1, 1, 1], [4, 4, 1, 1], [9, 1]]
    >>> list(squares(20, 2))  # Only use perfect squares greater than or equal to 4 (2*2).
    [[4, 4, 4, 4, 4], [16, 4]]
    """
    assert total > 0 and k > 0
    if total == k * k:
        yield [total]
    elif total > k * k:
        # Use k*k at least once (Python names are case-sensitive: k, not K).
        for s in squares(total - k * k, k):
            yield s + [k * k]
        # Use only squares larger than k*k.
        yield from squares(total, k + 1)
```
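A short property check (a sketch, not part of the assignment; is_valid is a hypothetical helper): every list that squares yields should contain only perfect squares of at least k*k, sum to total, and be in non-increasing order.

```python
from math import isqrt

def squares(total, k):
    # Same generator as above.
    assert total > 0 and k > 0
    if total == k * k:
        yield [total]
    elif total > k * k:
        for s in squares(total - k * k, k):
            yield s + [k * k]
        yield from squares(total, k + 1)

def is_valid(ways, total, k):
    """Check squareness, minimum part, sum, and non-increasing order of each list."""
    for s in ways:
        if any(isqrt(x) ** 2 != x for x in s):
            return False
        if sum(s) != total or min(s) < k * k:
            return False
        if any(a < b for a, b in zip(s, s[1:])):
            return False
    return True

print(is_valid(list(squares(30, 1)), 30, 1))  # True
print(list(squares(20, 2)))  # [[4, 4, 4, 4, 4], [16, 4]]
```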

Lab 5: Mutability, Iterators | CS 61A Fall 2025

```python
# Tree Data Abstraction

def tree(label, branches=[]):
    """Construct a tree with the given label value and a list of branches."""
    for branch in branches:
        assert is_tree(branch), 'branches must be trees'
    return [label] + list(branches)

def label(tree):
    """Return the label value of a tree."""
    return tree[0]

def branches(tree):
    """Return the list of branches of the given tree."""
    return tree[1:]

def is_tree(tree):
    """Returns True if the given tree is a tree, and False otherwise."""
    if type(tree) != list or len(tree) < 1:
        return False
    for branch in branches(tree):
        if not is_tree(branch):
            return False
    return True

def is_leaf(tree):
    """Returns True if the given tree's list of branches is empty, and False
    otherwise.
    """
    return not branches(tree)

def print_tree(t, indent=0):
    """Print a representation of this tree in which each node is
    indented by two spaces times its depth from the root.

    >>> print_tree(tree(1))
    1
    >>> print_tree(tree(1, [tree(2)]))
    1
      2
    >>> numbers = tree(1, [tree(2), tree(3, [tree(4), tree(5)]), tree(6, [tree(7)])])
    >>> print_tree(numbers)
    1
      2
      3
        4
        5
      6
        7
    """
    print('  ' * indent + str(label(t)))
    for b in branches(t):
        print_tree(b, indent + 1)

def copy_tree(t):
    """Returns a copy of t. Only for testing purposes.

    >>> t = tree(5)
    >>> copy = copy_tree(t)
    >>> t = tree(6)
    >>> print_tree(copy)
    5
    """
    return tree(label(t), [copy_tree(b) for b in branches(t)])
```
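Under this abstraction a tree is just a nested Python list whose first element is the label. The sketch below repeats simplified versions of the constructor and selectors (without the is_tree validation) so it runs on its own, and shows the list behind the example tree numbers. Code outside the abstraction should still go through label and branches instead of indexing the list directly.

```python
def tree(label, branches=[]):
    return [label] + list(branches)

def label(t):
    return t[0]

def branches(t):
    return t[1:]

numbers = tree(1, [tree(2), tree(3, [tree(4), tree(5)]), tree(6, [tree(7)])])
print(numbers)                                # [1, [2], [3, [4], [5]], [6, [7]]]
print(label(numbers))                         # 1
print([label(b) for b in branches(numbers)])  # [2, 3, 6]
```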

Q7: Sprout Leaves

Define a function sprout_leaves that takes in a tree t and a list of leaf labels leaves. It returns a new tree that is identical to t, but in which each old leaf node has new branches, one for each leaf label in leaves.

For example, say we have the tree t = tree(1, [tree(2), tree(3, [tree(4)])]):

```
  1
 / \
2   3
    |
    4
```

Calling sprout_leaves(t, [5, 6]) gives the following tree:

```
      1
    /   \
   2     3
  / \    |
 5   6   4
        / \
       5   6
```
```python
def sprout_leaves(t, leaves):
    """Sprout new leaves containing the labels in leaves at each leaf of
    the original tree t and return the resulting tree.

    >>> t1 = tree(1, [tree(2), tree(3)])
    >>> print_tree(t1)
    1
      2
      3
    >>> new1 = sprout_leaves(t1, [4, 5])
    >>> print_tree(new1)
    1
      2
        4
        5
      3
        4
        5

    >>> t2 = tree(1, [tree(2, [tree(3)])])
    >>> print_tree(t2)
    1
      2
        3
    >>> new2 = sprout_leaves(t2, [6, 1, 2])
    >>> print_tree(new2)
    1
      2
        3
          6
          1
          2
    """
    if is_leaf(t):
        return tree(label(t), [tree(leaf) for leaf in leaves])
    return tree(label(t), [sprout_leaves(b, leaves) for b in branches(t)])
```

Q8: Pruning Leaves

Implement prune_leaves, which takes a tree t and a tuple of values vals. It returns a version of t with all its leaves whose labels are in vals removed. Do not remove non-leaf nodes and do not remove leaves that do not match any of the items in vals. Return None if pruning the tree results in there being no nodes left in the tree.

```python
def prune_leaves(t, vals):
    """Return a version of t with all leaves that have a label
    that appears in vals removed. Return None if the entire tree is
    pruned away.

    >>> t = tree(2)
    >>> print(prune_leaves(t, (1, 2)))
    None
    >>> numbers = tree(1, [tree(2), tree(3, [tree(4), tree(5)]), tree(6, [tree(7)])])
    >>> print_tree(numbers)
    1
      2
      3
        4
        5
      6
        7
    >>> print_tree(prune_leaves(numbers, (3, 4, 6, 7)))
    1
      2
      3
        5
      6
    """
    if is_leaf(t):
        if label(t) in vals:
            return None
        else:
            return t
    return tree(label(t), [prune_leaves(b, vals) for b in branches(t)
                           if prune_leaves(b, vals) is not None])
```
```python
def prune_leaves(t, vals):
    """Return a version of t with all leaves that have a label
    that appears in vals removed. Return None if the entire tree is
    pruned away.

    >>> t = tree(2)
    >>> print(prune_leaves(t, (1, 2)))
    None
    >>> numbers = tree(1, [tree(2), tree(3, [tree(4), tree(5)]), tree(6, [tree(7)])])
    >>> print_tree(numbers)
    1
      2
      3
        4
        5
      6
        7
    >>> print_tree(prune_leaves(numbers, (3, 4, 6, 7)))
    1
      2
      3
        5
      6
    """
    if is_leaf(t) and (label(t) in vals):
        return None
    new_branches = []
    for b in branches(t):
        new_branch = prune_leaves(b, vals)
        if new_branch:
            new_branches.append(new_branch)
    return tree(label(t), new_branches)
```

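Solution 1 calls prune_leaves twice on every branch b: once in the if filter and once to build the list element. Every level of the tree therefore doubles the work, so the number of calls grows exponentially with the depth of the tree. The first version looks slick and concise, but it is actually a **performance killer**.

The second version clearly wins, but its base case is a little implicit (although the logic is correct): a leaf whose label is not in vals falls through to the loop, finds no branches, and returns tree(label(t), []), which is an equal copy of the leaf.

The ideal hybrid version looks like this: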

```python
def prune_leaves_perfect(t, vals):
    # 1. Base case: spelled out clearly, as in the first version.
    if is_leaf(t):
        if label(t) in vals:
            return None
        return t

    # 2. Recursive step: efficient, as in the second version.
    new_branches = []
    for b in branches(t):
        pruned_branch = prune_leaves_perfect(b, vals)  # Key point: compute it only once!
        if pruned_branch is not None:
            new_branches.append(pruned_branch)

    # 3. If every branch was pruned away, this node becomes a leaf but is still kept:
    #    only nodes that were leaves in the original tree are removed, and None is
    #    returned only when no nodes are left (which the base case above handles).
    return tree(label(t), new_branches)
```
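To see the cost difference described above, the sketch below counts calls on a path-shaped tree. chain and count_calls are hypothetical helpers, and the tree functions are simplified versions without validation. With nothing to prune, Solution 1 makes 2**d - 1 calls on a chain of d nodes, while the single-call version makes exactly d.

```python
def tree(label, branches=[]):
    return [label] + list(branches)

def label(t):
    return t[0]

def branches(t):
    return t[1:]

def is_leaf(t):
    return not branches(t)

calls = 0  # number of times a prune function has been entered

def prune_twice(t, vals):
    """Solution 1: evaluates prune_twice(b, vals) twice per branch."""
    global calls
    calls += 1
    if is_leaf(t):
        return None if label(t) in vals else t
    return tree(label(t), [prune_twice(b, vals) for b in branches(t)
                           if prune_twice(b, vals) is not None])

def prune_once(t, vals):
    """Hybrid version: each branch is pruned exactly once."""
    global calls
    calls += 1
    if is_leaf(t):
        return None if label(t) in vals else t
    new_branches = []
    for b in branches(t):
        pruned = prune_once(b, vals)
        if pruned is not None:
            new_branches.append(pruned)
    return tree(label(t), new_branches)

def chain(depth):
    """A path-shaped tree 1 -> 2 -> ... -> depth."""
    t = tree(depth)
    for i in range(depth - 1, 0, -1):
        t = tree(i, [t])
    return t

def count_calls(prune, t):
    global calls
    calls = 0
    prune(t, ())  # empty vals: nothing is pruned
    return calls

for d in (5, 10, 15):
    print(d, count_calls(prune_twice, chain(d)), count_calls(prune_once, chain(d)))
```

This prints 5 31 5, then 10 1023 10, then 15 32767 15: one extra level of depth doubles the work of Solution 1 but adds only one call to the hybrid version.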

Q9: Path Sum

Define a function pathsum, which takes in a tree of numbers t and a number n. It returns True if there is a path from the root to a leaf such that the sum of the numbers along that path is n and False otherwise.

```python
def pathsum(t, n):
    """
    >>> my_tree = tree(2, [tree(3, [tree(5), tree(7)]), tree(4)])
    >>> pathsum(my_tree, 12)  # 2 -> 3 -> 7
    True
    >>> pathsum(my_tree, 5)  # A path that doesn't reach a leaf such as 2 -> 3 doesn't count
    False
    """
    if is_leaf(t):
        return n == label(t)
    for branch in branches(t):
        if pathsum(branch, n - label(t)):
            return True
    return False
```

```python
def pathsum(t, n):
    """
    >>> my_tree = tree(2, [tree(3, [tree(5), tree(7)]), tree(4)])
    >>> pathsum(my_tree, 12)  # 2 -> 3 -> 7
    True
    >>> pathsum(my_tree, 5)  # A path that doesn't reach a leaf such as 2 -> 3 doesn't count
    False
    """
    if is_leaf(t) and n == label(t):
        return True
    else:
        return any([pathsum(b, n - label(t)) for b in branches(t)])
```

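If you really like how concise any is, you only need to change two characters to turn it into the ideal version!

Remove the square brackets [] (or replace them with parentheses) to turn the argument into a generator expression. With a list comprehension, every recursive call runs before any even sees the list; with a generator expression, any stops at the first True and skips the remaining branches.

```python
# This version gets full marks!
return any(pathsum(b, n - label(t)) for b in branches(t))
```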

```python
def pathsum(t, n):
    """
    >>> my_tree = tree(2, [tree(3, [tree(5), tree(7)]), tree(4)])
    >>> pathsum(my_tree, 12)  # 2 -> 3 -> 7
    True
    >>> pathsum(my_tree, 5)  # A path that doesn't reach a leaf such as 2 -> 3 doesn't count
    False
    """
    if is_leaf(t):
        return n == label(t)
    else:
        return any(pathsum(b, n - label(t)) for b in branches(t))
```
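The difference between the two any calls is easy to observe with a function that records its calls. In this sketch (is_big and checked are illustrative names), the list comprehension evaluates every element before any runs, while the generator expression lets any stop at the first True:

```python
checked = []  # records every argument is_big is called with

def is_big(x):
    checked.append(x)
    return x > 1

any([is_big(x) for x in [1, 2, 3]])  # the whole list is built first: 1, 2, 3
eager = list(checked)

checked.clear()
any(is_big(x) for x in [1, 2, 3])    # stops as soon as is_big(2) is True
lazy = list(checked)

print(eager)  # [1, 2, 3]
print(lazy)   # [1, 2]
```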

Q10: Perfectly Balanced

Implement sum_tree, which returns the sum of all the labels in tree t.

```python
def sum_tree(t):
    """Add all elements in a tree.

    >>> t = tree(4, [tree(2, [tree(3)]), tree(6)])
    >>> sum_tree(t)
    15
    """
    total = 0
    for b in branches(t):
        total += sum_tree(b)
    return label(t) + total
```

```python
def sum_tree(t):
    """Add all elements in a tree.

    >>> t = tree(4, [tree(2, [tree(3)]), tree(6)])
    >>> sum_tree(t)
    15
    """
    return label(t) + sum([sum_tree(b) for b in branches(t)])
```

Then, implement balanced, which returns whether every branch of t has the same total sum and whether the branches themselves are also balanced.

[Figure: an example of a balanced tree]
  • For example, the tree above is balanced because each branch has the same total sum, and each branch is also itself balanced.
```python
def balanced(t):
    """Checks if each branch has same sum of all elements and
    if each branch is balanced.

    >>> t = tree(1, [tree(3), tree(1, [tree(2)]), tree(1, [tree(1), tree(1)])])
    >>> balanced(t)
    True
    >>> t = tree(1, [t, tree(1)])
    >>> balanced(t)
    False
    >>> t = tree(1, [tree(4), tree(1, [tree(2), tree(1)]), tree(1, [tree(3)])])
    >>> balanced(t)
    False
    """
    for b in branches(t):
        if (sum_tree(b) != sum_tree(branches(t)[0])) or not balanced(b):
            return False
    return True
```
```python
def balanced(t):
    """Checks if each branch has same sum of all elements and
    if each branch is balanced.

    >>> t = tree(1, [tree(3), tree(1, [tree(2)]), tree(1, [tree(1), tree(1)])])
    >>> balanced(t)
    True
    >>> t = tree(1, [t, tree(1)])
    >>> balanced(t)
    False
    >>> t = tree(1, [tree(4), tree(1, [tree(2), tree(1)]), tree(1, [tree(3)])])
    >>> balanced(t)
    False
    """
    # Logic breakdown:
    # all([balanced(b) ...]) asks every branch "is your own subtree balanced?"
    #   If even one of them answers False, all returns False.
    # set([sum_tree(b) ...]) collects the total sum of every branch into a set, which
    #   removes duplicates. If all branches weigh the same, the set holds exactly one
    #   number (length 1); if the node is a leaf with no branches, the set is empty (length 0).
    # len(...) <= 1: as long as the set has at most one element, all branch sums agree
    #   (or there are no branches at all).
    return all([balanced(b) for b in branches(t)]) and len(set([sum_tree(b) for b in branches(t)])) <= 1
```

Homework 4: Sequences, Data Abstraction, Trees | CS 61A Fall 2025

Q5: Finding Berries!

The squirrels on campus need your help! There are a lot of trees on campus and the squirrels would like to know which ones contain berries. Define the function berry_finder, which takes in a tree and returns True if the tree contains a node with the value 'berry' and False otherwise.

```python
def berry_finder(t):
    """Returns True if t contains a node with the value 'berry' and
    False otherwise.

    >>> scrat = tree('berry')
    >>> berry_finder(scrat)
    True
    >>> sproul = tree('roots', [tree('branch1', [tree('leaf'), tree('berry')]), tree('branch2')])
    >>> berry_finder(sproul)
    True
    >>> numbers = tree(1, [tree(2), tree(3, [tree(4), tree(5)]), tree(6, [tree(7)])])
    >>> berry_finder(numbers)
    False
    >>> t = tree(1, [tree('berry',[tree('not berry')])])
    >>> berry_finder(t)
    True
    """
    if label(t) == 'berry':
        return True
    return any(berry_finder(b) for b in branches(t))
```

```python
def berry_finder(t):
    """Returns True if t contains a node with the value 'berry' and
    False otherwise.

    >>> scrat = tree('berry')
    >>> berry_finder(scrat)
    True
    >>> sproul = tree('roots', [tree('branch1', [tree('leaf'), tree('berry')]), tree('branch2')])
    >>> berry_finder(sproul)
    True
    >>> numbers = tree(1, [tree(2), tree(3, [tree(4), tree(5)]), tree(6, [tree(7)])])
    >>> berry_finder(numbers)
    False
    >>> t = tree(1, [tree('berry',[tree('not berry')])])
    >>> berry_finder(t)
    True
    """
    if label(t) == 'berry':
        return True
    for b in branches(t):
        if berry_finder(b):
            return True
    return False
```

Q6: Maximum Path Sum

Write a function that takes in a tree of positive numbers and returns the maximum sum of the labels along any root-to-leaf path in the tree. A root-to-leaf path is a sequence of nodes starting at the root and ending at some leaf of the tree.

```python
def max_path_sum(t):
    """Return the maximum root-to-leaf path sum of a tree.
    >>> t = tree(1, [tree(5, [tree(1), tree(3)]), tree(10)])
    >>> max_path_sum(t)  # 1, 10
    11
    >>> t2 = tree(5, [tree(4, [tree(1), tree(3)]), tree(2, [tree(10), tree(3)])])
    >>> max_path_sum(t2)  # 5, 2, 10
    17
    """
    if is_leaf(t):
        return label(t)
    return label(t) + max([max_path_sum(b) for b in branches(t)])
```

Discussion 5: Trees | CS 61A Fall 2025

[Figure: Example Tree]

Q2: Has Path

Implement has_path, which takes a tree t and a list p. It returns whether there is a path from the root of t with labels p. For example, t1 has a path from its root with labels [3, 5, 6] but not [3, 4, 6] or [5, 6].

Important: Before trying to implement this function, discuss these questions from lecture about the recursive call of a tree processing function:

  • What small initial choice can I make (such as which branch to explore)?
  • What recursive call should I make for each option?
  • How can I combine the results of those recursive calls?
    • What type of values do they return?
    • What do those return values mean?
```python
def has_path(t, p):
    """Return whether tree t has a path from the root with labels p.

    >>> t2 = tree(5, [tree(6), tree(7)])
    >>> t1 = tree(3, [tree(4), t2])
    >>> has_path(t1, [5, 6])  # This path is not from the root of t1
    False
    >>> has_path(t2, [5, 6])  # This path is from the root of t2
    True
    >>> has_path(t1, [3, 5])  # This path does not go to a leaf, but that's ok
    True
    >>> has_path(t1, [3, 5, 6])  # This path goes to a leaf
    True
    >>> has_path(t1, [3, 4, 5, 6])  # There is no path with these labels
    False
    """
    if p == [label(t)]:  # when len(p) is 1
        return True
    elif label(t) != p[0]:
        return False
    else:
        return any(has_path(b, p[1:]) for b in branches(t))
```

```python
def has_path(t, p):
    """Return whether tree t has a path from the root with labels p.

    >>> t2 = tree(5, [tree(6), tree(7)])
    >>> t1 = tree(3, [tree(4), t2])
    >>> has_path(t1, [5, 6])  # This path is not from the root of t1
    False
    >>> has_path(t2, [5, 6])  # This path is from the root of t2
    True
    >>> has_path(t1, [3, 5])  # This path does not go to a leaf, but that's ok
    True
    >>> has_path(t1, [3, 5, 6])  # This path goes to a leaf
    True
    >>> has_path(t1, [3, 4, 5, 6])  # There is no path with these labels
    False
    """
    if p == [label(t)]:  # when len(p) is 1
        return True
    elif label(t) != p[0]:
        return False
    else:
        for b in branches(t):
            if has_path(b, p[1:]):
                return True
        return False
```

Q3: Find Path

Implement find_path, which takes a tree t with unique labels and a value x. It returns a list containing the labels of the nodes along a path from the root of t to a node labeled x.

If x is not a label in t, return None. Assume that the labels of t are unique.

```python
def find_path(t, x):
    """
    >>> t2 = tree(5, [tree(6), tree(7)])
    >>> t1 = tree(3, [tree(4), t2])
    >>> find_path(t1, 5)
    [3, 5]
    >>> find_path(t1, 4)
    [3, 4]
    >>> find_path(t1, 6)
    [3, 5, 6]
    >>> find_path(t2, 6)
    [5, 6]
    >>> print(find_path(t1, 2))
    None
    """
    if label(t) == x:
        return [x]
    for b in branches(t):
        path = find_path(b, x)
        if path:
            return [label(t)] + path
    return None
```
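Every label in t lies on some path from the root, so has_path(t, find_path(t, x)) should be True for every label x in t. A small sketch of that consistency check (all_labels is a hypothetical helper, and the tree functions are repeated in simplified form so the block runs on its own):

```python
def tree(label, branches=[]):
    return [label] + list(branches)

def label(t):
    return t[0]

def branches(t):
    return t[1:]

def has_path(t, p):
    if p == [label(t)]:
        return True
    elif label(t) != p[0]:
        return False
    return any(has_path(b, p[1:]) for b in branches(t))

def find_path(t, x):
    if label(t) == x:
        return [x]
    for b in branches(t):
        path = find_path(b, x)
        if path:
            return [label(t)] + path
    return None

def all_labels(t):
    """Every label in t, root first."""
    return [label(t)] + [x for b in branches(t) for x in all_labels(b)]

t2 = tree(5, [tree(6), tree(7)])
t1 = tree(3, [tree(4), t2])
print(all(has_path(t1, find_path(t1, x)) for x in all_labels(t1)))  # True
print(find_path(t1, 7))  # [3, 5, 7]
```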

Q4: Only Paths

Implement only_paths, which takes a tree of numbers t and a number n. It returns a new tree with only the nodes of t that are on a path from the root to a leaf with labels that sum to n, or None if no path sums to n.

Here is an illustration of the doctest examples involving t.

[Figure: only_paths doctest examples]

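This only_paths problem is a classic “tree pruning” problem.

The hardest part is not going down the tree (the recursion) but coming back up. If a path turns out to be a dead end, it is not enough for the nodes below to report “no”; each node must also tell the level above “I have become a dead end too, cut me off!” ✂️

What goes wrong in this attempt:

```python
def only_paths(t, n):
    if is_leaf(t) and label(t) == n:
        return t
    new_branches = []
    for b in branches(t):
        new_branch = only_paths(b, n - label(t))
        if new_branch is not None:
            new_branches.append(new_branch)
    # 🚨 The fatal bug is here! 🚨
    return tree(label(t), new_branches)
```

The “dead stump” problem:

Scenario: suppose you are node 3 and your target is 10. You are not a leaf, so you send your branches off to look for the remaining 7. All of them fail and return None, so your new_branches ends up as the empty list [].

What you do: you happily return tree(3, []).

Consequence: even though all of your children are gone, you are still alive! You have become a leaf with no children, yet no path through you sums to 10. The level above sees that you returned a tree object (not None), assumes your path worked, and attaches you to its result. The pruned tree ends up full of **fake leaves** (dead ends that should not be there). The fix is to return None whenever a non-leaf node ends up with no surviving branches.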


Solution:

```python
def only_paths(t, n):
    """Return a tree with only the nodes of t along paths from the root to a leaf of t
    for which the node labels of the path sum to n. If no paths sum to n, return None.

    >>> print_tree(only_paths(tree(5, [tree(2), tree(1, [tree(2)]), tree(1, [tree(1)])]), 7))
    5
      2
      1
        1
    >>> t = tree(3, [tree(4), tree(1, [tree(3, [tree(2)]), tree(2, [tree(1)]), tree(5), tree(3)])])
    >>> print_tree(only_paths(t, 7))
    3
      4
      1
        2
          1
        3
    >>> print_tree(only_paths(t, 9))
    3
      1
        3
          2
        5
    >>> print(only_paths(t, 3))
    None
    """
    if is_leaf(t) and label(t) == n:
        return t
    new_branches = [only_paths(b, n - label(t)) for b in branches(t)]
    if any(new_branches):
        return tree(label(t), [b for b in new_branches if b is not None])
```

```python
def only_paths(t, n):
    """Return a tree with only the nodes of t along paths from the root to a leaf of t
    for which the node labels of the path sum to n. If no paths sum to n, return None.

    >>> print_tree(only_paths(tree(5, [tree(2), tree(1, [tree(2)]), tree(1, [tree(1)])]), 7))
    5
      2
      1
        1
    >>> t = tree(3, [tree(4), tree(1, [tree(3, [tree(2)]), tree(2, [tree(1)]), tree(5), tree(3)])])
    >>> print_tree(only_paths(t, 7))
    3
      4
      1
        2
          1
        3
    >>> print_tree(only_paths(t, 9))
    3
      1
        3
          2
        5
    >>> print(only_paths(t, 3))
    None
    """
    if is_leaf(t) and label(t) == n:
        return t
    new_branches = []
    for b in branches(t):
        new_branch = only_paths(b, n - label(t))
        if new_branch is not None:
            new_branches.append(new_branch)
    if new_branches:
        return tree(label(t), new_branches)
```
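```python
def only_paths(t, n):
    # 1. A leaf that makes the path sum come out exactly to n is a success:
    #    return a new tree holding just this leaf.
    if label(t) == n and is_leaf(t):
        return tree(label(t))

    # 2. (Optional) If the current label already exceeds n and all labels are positive,
    #    no path below can succeed, so prune here.
    #    Note: a common mistake is to reverse this comparison; it must be label(t) > n.
    if label(t) > n:
        return None

    # 3. Keep searching the branches for the remaining amount, n - label(t).
    new_branches = []
    for b in branches(t):
        new_branch = only_paths(b, n - label(t))
        if new_branch is not None:
            new_branches.append(new_branch)

    # 4. If any branch succeeded, attach the surviving branches under this node.
    if new_branches:
        return tree(label(t), new_branches)

    # 5. Nothing was found (including when new_branches is empty): return None.
    return None
```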

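Alright! Now that you have cracked the “dead stump” problem, here is an advanced variant for you! 😈

Its logical skeleton is very similar to only_paths, but the focus shifts from the sum of the labels to the depth of the path (depth is also a central idea in path-finding algorithms).

Let's call it long_paths (the long road) 🛣️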


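📝 Problem description

Implement the function long_paths(t, n).

  • Input: a tree t and an integer n.
  • Behavior: prune the tree, keeping only the nodes that lie on some root-to-leaf path of length at least n. (The length of a path is the number of nodes on it, counting both the root and the leaf.)
  • Returns: the new, pruned tree. If no path in the tree has length at least n, return None.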

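🎨 Visualized test cases

Suppose we have the following tree t:

```
    1
   / \
  2   3
     / \
    4   5
   /
  6
```

  • Path 1 -> 2 has length 2
  • Path 1 -> 3 -> 5 has length 3
  • Path 1 -> 3 -> 4 -> 6 has length 4

🧪 Test case 1: n = 3 (keep paths of length >= 3)

  • 1 -> 2 (length 2) is too short, so prune it! ✂️
  • 1 -> 3 -> 5 (length 3): keep!
  • 1 -> 3 -> 4 -> 6 (length 4): keep!

Resulting tree:

```
  1
   \
    3
   / \
  4   5
 /
6
```

🧪 Test case 2: n = 4 (keep paths of length >= 4)

  • 1 -> 3 -> 5 (length 3) is now too short as well, so prune it! ✂️
  • Only the longest path remains.

Resulting tree:

```
1
 \
  3
   \
    4
     \
      6
```

🧪 Test case 3: n = 10

  • No path is long enough.
  • Return None.

💻 Fill-in-the-code challenge

Complete the code below (make sure to avoid the dead stump problem!):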

```python
def long_paths(t, n):
    """
    >>> t = tree(1, [tree(2), tree(3, [tree(4, [tree(6)]), tree(5)])])
    >>> print_tree(long_paths(t, 2))
    1
      2
      3
        4
          6
        5
    >>> print_tree(long_paths(t, 3))
    1
      3
        4
          6
        5
    >>> print_tree(long_paths(t, 4))
    1
      3
        4
          6
    >>> print(long_paths(t, 10))
    None
    """
    # 1. Base case: at a leaf, n <= 1 means this path is long enough
    #    (or the requirement was low to begin with).
    if is_leaf(t) and n <= 1:
        return t

    # 2. Recursion: if this node is not a leaf, how should the remaining length n change?
    #    Hint: this node already counts as one step, so each branch still needs n - 1 nodes.
    new_branches = [long_paths(b, n - 1) for b in branches(t)]

    # 3. Filter: keep only the branches that did not return None.
    filtered_branches = [b for b in new_branches if b is not None]

    # 4. Key check: what if no child survived and this level alone does not reach n?
    #    (Note: if n <= 1, this node alone is enough; no children are needed.)
    if filtered_branches or n <= 1:
        return tree(label(t), filtered_branches)

    # 5. Otherwise return None (implicitly).
```

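🧠 Two key hints

  1. Recursing on n: in only_paths we pass n - label(t) down (subtraction). Here, each level down uses up one node of the required length, so we pass n - 1.
  2. The dead-stump check: this problem has one more small trap than only_paths. In only_paths, if all of my children die, I must die too (unless I am the base case). Here, however, n <= 1 means **this node alone already makes the path long enough**, so I should survive even when filtered_branches is empty. Be careful with the if condition in step 4! (Strictly speaking, when n <= 1 every branch survives anyway, because each recursive call then receives n - 1 <= 0; the or n <= 1 clause mainly makes that intent explicit.)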
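Using the nested-list representation, the three visualized test cases can be checked directly. A sketch (the tree functions are repeated in simplified form, without is_tree validation, so the block runs on its own):

```python
def tree(label, branches=[]):
    return [label] + list(branches)

def label(t):
    return t[0]

def branches(t):
    return t[1:]

def is_leaf(t):
    return not branches(t)

def long_paths(t, n):
    # Same logic as the completed solution above.
    if is_leaf(t) and n <= 1:
        return t
    kept = [long_paths(b, n - 1) for b in branches(t)]
    kept = [b for b in kept if b is not None]
    if kept or n <= 1:
        return tree(label(t), kept)
    return None

t = tree(1, [tree(2), tree(3, [tree(4, [tree(6)]), tree(5)])])
print(long_paths(t, 3))   # [1, [3, [4, [6]], [5]]]
print(long_paths(t, 4))   # [1, [3, [4, [6]]]]
print(long_paths(t, 10))  # None
```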

Discussion 4: Tree Recursion | CS 61A Fall 2025

Q1: Insect Combinatorics

An insect is inside an m by n grid. The insect starts at the bottom-left corner (1, 1) and wants to end up at the top-right corner (m, n). The insect can only move up or to the right. Write a function paths that takes the height and width of a grid and returns the number of paths the insect can take from the start to the end. (There is a closed-form solution to this problem, but try to answer it with recursion.)

Insect grids.

In the 2 by 2 grid, the insect has two paths from the start to the end. In the 3 by 3 grid, the insect has six paths (only three are shown above).

Hint: What happens if the insect hits the upper or rightmost edge of the grid?

```python
def paths(m, n):
    """Return the number of paths from one corner of an
    M by N grid to the opposite corner.

    >>> paths(2, 2)
    2
    >>> paths(5, 7)
    210
    >>> paths(117, 1)
    1
    >>> paths(1, 157)
    1
    """
    if m == 1 or n == 1:
        return 1
    return paths(m - 1, n) + paths(m, n - 1)
    # Base case: Look at the two visual examples given. Since the insect
    # can only move to the right or up, once it hits either the rightmost edge
    # or the upper edge, it has a single remaining path -- the insect has
    # no choice but to go straight up or straight right (respectively) at that point.
    # There is no way for it to backtrack by going left or down.
    # The recursive case is that there are paths from the square to the right through
    # an (m, n-1) grid and paths from the square above through an (m-1, n) grid.
```
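```python
def paths(m, n):
    """Return the number of paths from one corner of an
    M by N grid to the opposite corner.

    >>> paths(2, 2)
    2
    >>> paths(5, 7)
    210
    >>> paths(117, 1)
    1
    >>> paths(1, 157)
    1
    """
    if m < 1 or n < 1:
        # Stepped off the grid (a row or column of size 0): no paths.
        # Stopping at 0 avoids extra calls that walk along an empty row or column.
        return 0
    if m == n == 1:
        # Reached the target corner.
        return 1
    return paths(m-1, n) + paths(m, n-1)
```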
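The problem mentions a closed-form solution. Under the standard counting argument (every path is some ordering of m - 1 up moves and n - 1 right moves), the count is a binomial coefficient, which gives an independent check on the recursion. A sketch (paths_closed_form is a hypothetical name; math.comb requires Python 3.8 or later):

```python
from math import comb

def paths(m, n):
    # Tree-recursive count, as in the first solution above.
    if m == 1 or n == 1:
        return 1
    return paths(m - 1, n) + paths(m, n - 1)

def paths_closed_form(m, n):
    # Choose which of the (m - 1) + (n - 1) moves are the up moves.
    return comb(m + n - 2, m - 1)

print(paths(5, 7), paths_closed_form(5, 7))  # 210 210
print(all(paths(m, n) == paths_closed_form(m, n)
          for m in range(1, 8) for n in range(1, 8)))  # True
```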

Q2: Max Product

Implement max_product, which takes a list of numbers and returns the maximum product that can be formed by multiplying together non-consecutive elements of the list. Assume that all numbers in the input list are greater than or equal to 1.

Hint: First try multiplying the first element by the max_product of everything after the first two elements (skipping the second element because it is consecutive with the first), then try skipping the first element and finding the max_product of the rest. To find which of these options is better, use max.

```python
def max_product(s):
    """Return the maximum product of non-consecutive elements of s.

    >>> max_product([10, 3, 1, 9, 2])  # 10 * 9
    90
    >>> max_product([5, 10, 5, 10, 5])  # 5 * 5 * 5
    125
    >>> max_product([])  # The product of no numbers is 1
    1
    """
    if s == []:
        return 1
    with_first = s[0] * max_product(s[2:])
    without_first = max_product(s[1:])
    return max(with_first, without_first)
```
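The hint's two choices (use the first element and skip the second, or skip the first) together cover every valid selection. A brute-force sketch that tries every set of indices with no two adjacent gives a way to double-check the recursion on small inputs (max_product_brute_force is a hypothetical testing helper; math.prod requires Python 3.8 or later):

```python
from itertools import combinations
from math import prod

def max_product(s):
    # Recursive solution from above.
    if not s:
        return 1
    with_first = s[0] * max_product(s[2:])
    without_first = max_product(s[1:])
    return max(with_first, without_first)

def max_product_brute_force(s):
    """Try every index set with no two adjacent indices and keep the best product."""
    best = 1  # the product of no numbers
    for k in range(1, len(s) + 1):
        for idx in combinations(range(len(s)), k):
            if all(j - i > 1 for i, j in zip(idx, idx[1:])):
                best = max(best, prod(s[i] for i in idx))
    return best

for s in ([10, 3, 1, 9, 2], [5, 10, 5, 10, 5], [], [2, 7, 3, 4, 6, 1]):
    print(s, max_product(s), max_product_brute_force(s))
```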

Q3: Sum Fun

Implement sums(n, m), which takes a total n and maximum m. It returns a list of all lists:

  1. that sum to n,
  2. that contain only positive numbers up to m, and
  3. in which no two adjacent numbers are the same.

Two lists with the same numbers in a different order should both be returned.

Here’s a recursive approach that matches the template below: build up the result list by building all lists that sum to n and start with k, for each k from 1 to m. For example, the result of sums(5, 3) is made up of three lists:

  • [[1, 3, 1]] starts with 1,
  • [[2, 1, 2], [2, 3]] start with 2, and
  • [[3, 2]] starts with 3.

Hint: Use [k] + s for a number k and list s to build a list that starts with k and then has all the elements of s.

**Hint (first blank):** k is the first number in a list that sums to n, and rest is the rest of that list, so build a list that sums to n.

**Hint (second blank):** Call sums to build all of the lists that sum to n-k so that they can be used to construct lists that sum to n by putting a k on the front.

**Hint (third blank):** Here is where you ensure that “no two adjacent numbers are the same.” Since k will be the first number in the list you’re building, it must not be equal to the first element of rest (which will be the second number in the list you’re building).

```python
def sums(n, m):
    """Return lists that sum to n containing positive numbers up to m that
    have no adjacent repeats.

    >>> sums(5, 1)
    []
    >>> sums(5, 2)
    [[2, 1, 2]]
    >>> sums(5, 3)
    [[1, 3, 1], [2, 1, 2], [2, 3], [3, 2]]
    >>> sums(5, 5)
    [[1, 3, 1], [1, 4], [2, 1, 2], [2, 3], [3, 2], [4, 1], [5]]
    >>> sums(6, 3)
    [[1, 2, 1, 2], [1, 2, 3], [1, 3, 2], [2, 1, 2, 1], [2, 1, 3], [2, 3, 1], [3, 1, 2], [3, 2, 1]]
    """
    if n < 0:
        return []
    if n == 0:
        sums_to_zero = []      # The only way to sum to zero using positives
        return [sums_to_zero]  # Return a list of all the ways to sum to zero
    result = []
    for k in range(1, m + 1):
        result = result + [[k] + rest for rest in sums(n - k, m) if rest == [] or k != rest[0]]
    return result
```

Q4: A Perfect Question

This question was Fall 2023 Midterm 2 Question 4(a). The original exam version had an extra blank (where total < k * k appears below), but also included some guidance via multiple choice options and hints.

Definition. A perfect square is k*k for some integer k.

Implement fit, which takes positive integers total and n. It returns True or False indicating whether there are n positive perfect squares that sum to total. The perfect squares need not be unique.

```python
def fit(total, n):
    """Return whether there are n positive perfect squares that sum to total.

    >>> [fit(4, 1), fit(4, 2), fit(4, 3), fit(4, 4)]  # 1*(2*2) for n=1; 4*(1*1) for n=4
    [True, False, False, True]
    >>> [fit(12, n) for n in range(3, 8)]  # 3*(2*2), 3*(1*1)+3*3, 4*(1*1)+2*(2*2)
    [True, True, False, True, False]
    >>> [fit(32, 2), fit(32, 3), fit(32, 4), fit(32, 5)]  # 2*(4*4), 3*(1*1)+2*2+5*5
    [True, False, False, True]
    """
    def f(total, n, k):
        if total == k * k and n == 1:
            return True
        elif total < k * k:
            return False
        else:
            return f(total - k*k, n-1, k) or f(total, n, k+1)
    return f(total, n, 1)
```
Alternative version: the first base case instead checks whether the remaining total is exactly n copies of k*k.

```python
def fit(total, n):
    """Return whether there are n positive perfect squares that sum to total.

    >>> [fit(4, 1), fit(4, 2), fit(4, 3), fit(4, 4)]  # 1*(2*2) for n=1; 4*(1*1) for n=4
    [True, False, False, True]
    >>> [fit(12, n) for n in range(3, 8)]  # 3*(2*2), 3*(1*1)+3*3, 4*(1*1)+2*(2*2)
    [True, True, False, True, False]
    >>> [fit(32, 2), fit(32, 3), fit(32, 4), fit(32, 5)]  # 2*(4*4), 3*(1*1)+2*2+5*5
    [True, False, False, True]
    """
    def f(total, n, k):
        if total == n * k * k:  # the rest is n copies of k*k
            return True
        elif total < k * k:
            return False
        else:
            return f(total - k*k, n-1, k) or f(total, n, k+1)
    return f(total, n, 1)
```
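Either version of fit can be cross-checked with an exhaustive search. The sketch below is a testing aid, not the exam's intended approach: fit_brute is a hypothetical name, and it tries every combination (with repetition) of n squares no larger than total, which is fine for doctest-sized inputs but grows quickly.

```python
from itertools import combinations_with_replacement
from math import isqrt


def fit_brute(total, n):
    """Hypothetical reference for fit: try every multiset of n positive
    squares that are each at most total."""
    squares = [k * k for k in range(1, isqrt(total) + 1)]
    return any(sum(combo) == total
               for combo in combinations_with_replacement(squares, n))


print([fit_brute(12, n) for n in range(3, 8)])  # [True, True, False, True, False]
```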

Computer Aided Typing Software (Cats) | CS 61A Fall 2025

Problem 6 (3 pts)

Implement furry_fixes, a diff function that could be passed into the diff_function parameter in autocorrect. This function takes in two strings and returns the minimum number of characters that must be changed in the typed word in order to transform it into the source word. If the strings are not of equal length, the difference in lengths is added to the total change count.

Here are some examples:

```
>>> big_limit = 10
>>> furry_fixes("nice", "rice", big_limit) # Substitute: n -> r
1
>>> furry_fixes("range", "rungs", big_limit) # Substitute: a -> u, e -> s
2
>>> furry_fixes("pill", "pillage", big_limit) # Don't substitute anything, length difference of 3.
3
>>> furry_fixes("goodbye", "good", big_limit) # Don't substitute anything, length difference of 3.
3
>>> furry_fixes("roses", "arose", big_limit) # Substitute: r -> a, o -> r, s -> o, e -> s, s -> e
5
>>> furry_fixes("rose", "hello", big_limit) # Substitute: r->h, o->e, s->l, e->l, length difference of 1.
5
```

Important: You may not use while, for, or list comprehensions in your implementation. Use recursion.

If the number of characters that must change is greater than limit, then furry_fixes should return any number larger than limit and should minimize the amount of computation needed to do so.

Why is there a limit? From Problem 5, we know that autocorrect will reject any source word whose difference with the typed word is greater than limit. It doesn’t matter if the difference is greater than limit by 1 or by 100; autocorrect will reject it just the same. Therefore, as soon as we know the difference is above limit, it makes sense to stop making recursive calls, saving time, even if the returned difference won’t be exactly correct.

These two calls to furry_fixes should take about the same amount of time to evaluate:

```
>>> limit = 4
>>> furry_fixes("roses", "arose", limit) > limit
True
>>> furry_fixes("rosesabcdefghijklm", "arosenopqrstuvwxyz", limit) > limit
True
```

To ensure that you are correctly saving time by stopping the recursion after limit is reached, there is an autograder test that measures the performance of your solution based on the number of function calls that it makes. If you fail this test, consider adding a base case related to the limit.

Hint: you will need more than one base case to solve this problem.

```python
def furry_fixes(typed: str, source: str, limit: int) -> int:
    """A diff function for autocorrect that determines how many letters
    in TYPED need to be substituted to create SOURCE, then adds the difference in
    their lengths to this value and returns the result.

    Arguments:
    typed: a starting word
    source: a string representing a desired goal word
    limit: a number representing an upper bound on the number of chars that must change

    >>> big_limit = 10
    >>> furry_fixes("nice", "rice", big_limit)    # Substitute: n -> r
    1
    >>> furry_fixes("range", "rungs", big_limit)  # Substitute: a -> u, e -> s
    2
    >>> furry_fixes("pill", "pillage", big_limit) # Don't substitute anything, length difference of 3.
    3
    >>> furry_fixes("roses", "arose", big_limit)  # Substitute: r -> a, o -> r, s -> o, e -> s, s -> e
    5
    >>> furry_fixes("rose", "hello", big_limit)   # Substitute: r->h, o->e, s->l, e->l, length difference of 1.
    5
    """
    # BEGIN PROBLEM 6
    if limit < 0:
        return limit + 1  # already over the limit: stop recursing
    elif (typed == "") or (source == ""):
        return len(typed) + len(source)  # the remaining length difference
    elif typed[0] == source[0]:
        return furry_fixes(typed[1:], source[1:], limit)
    else:
        return 1 + furry_fixes(typed[1:], source[1:], limit-1)
```
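To see the limit at work, the furry_fixes solution above can be instrumented with a call counter. This is a sketch for experimentation only: the counter and count_calls are hypothetical additions, and the autograder measures calls in its own way. Both calls from the problem statement stop after the sixth call, once the limit drops below zero, regardless of how long the strings are.

```python
counter = {"calls": 0}


def furry_fixes(typed, source, limit):
    """The furry_fixes solution above, plus a call counter."""
    counter["calls"] += 1
    if limit < 0:
        return limit + 1
    elif typed == "" or source == "":
        return len(typed) + len(source)
    elif typed[0] == source[0]:
        return furry_fixes(typed[1:], source[1:], limit)
    else:
        return 1 + furry_fixes(typed[1:], source[1:], limit - 1)


def count_calls(typed, source, limit):
    """Hypothetical helper: return (result, number of furry_fixes calls)."""
    counter["calls"] = 0
    result = furry_fixes(typed, source, limit)
    return result, counter["calls"]


print(count_calls("roses", "arose", 4))                            # (5, 6)
print(count_calls("rosesabcdefghijklm", "arosenopqrstuvwxyz", 4))  # (5, 6)
```

With a limit of 10, the longer pair makes 12 calls instead, because more mismatches are allowed before the recursion stops.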

Problem 7 (3 pts)

Implement minimum_mewtations, a more advanced diff function that can be used in autocorrect, which returns the minimum number of edit operations needed to transform the typed word into the source word.

There are three kinds of edit operations, with some examples:

  1. Add a letter to typed.
    • Adding "k" to "itten" gives us "kitten".
  2. Remove a letter from typed.
    • Removing "s" from "scat" gives us "cat".
  3. Substitute a letter in typed for another.
    • Substituting "z" with "j" in "zaguar" gives us "jaguar".

Each edit operation increases the difference between two words by 1.

```
>>> big_limit = 10
>>> minimum_mewtations("cats", "scat", big_limit)       # cats -> scats -> scat
2
>>> minimum_mewtations("purng", "purring", big_limit)   # purng -> purrng -> purring
2
>>> minimum_mewtations("ckiteus", "kittens", big_limit) # ckiteus -> kiteus -> kitteus -> kittens
3
```

We have provided a template of an implementation in cats.py. You may modify the template however you want or delete it entirely.

Hint: One of the recursive calls in minimum_mewtations will be similar to furry_fixes. However, because minimum_mewtations considers three specific types of edits (add, remove, substitute), there will need to be additional recursive calls to handle each of these cases.
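One way to read the hint: when typed[0] != source[0], each edit applied at the front of typed leaves a smaller (typed, source) pair to solve. The hypothetical helper below only names those three subproblems for illustration; the problem forbids helper functions inside minimum_mewtations, so there the recursive calls should be written inline.

```python
def front_edits(typed, source):
    """Hypothetical helper (illustration only): the subproblem left after
    applying each edit at the front of TYPED when typed[0] != source[0]."""
    return {
        # Add source[0] to the front of typed: that letter now matches and is
        # consumed from source, leaving all of typed vs. source[1:].
        "add": (typed, source[1:]),
        # Remove typed[0]: the rest of typed must still become all of source.
        "remove": (typed[1:], source),
        # Substitute typed[0] with source[0]: both first letters are consumed.
        "substitute": (typed[1:], source[1:]),
    }


print(front_edits("cats", "scat"))
# {'add': ('cats', 'cat'), 'remove': ('ats', 'scat'), 'substitute': ('ats', 'cat')}
```

For "cats" and "scat", the add subproblem ("cats", "cat") needs only one more removal, which matches the expected answer of 2.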

If the number of edits required is greater than limit, then minimum_mewtations should return any number larger than limit (such as limit + 1) and should stop making recursive calls once the limit is reached to save time.

These two calls to minimum_mewtations should take about the same amount of time to evaluate:

```
>>> limit = 2
>>> minimum_mewtations("ckiteus", "kittens", limit) > limit
True
>>> minimum_mewtations("ckiteusabcdefghijklm", "kittensnopqrstuvwxyz", limit) > limit
True
```

To ensure that your code stops making recursive calls after the limit is reached, there is an autograder test that measures the performance of your solution based on the number of function calls that it makes.

Important: You should not use any helper functions in your implementation of minimum_mewtations. Otherwise the autograder test might fail.

Important: Remember to remove the following line of code when you are ready to test your implementation:

```python
assert False, 'Remove this line'
```
```python
def minimum_mewtations(typed: str, source: str, limit: int) -> int:
    """
    A diff function for autocorrect that computes the edit distance from TYPED to SOURCE.
    This function takes in a string TYPED, a string SOURCE, and a number LIMIT.

    Arguments:
    typed: a starting word
    source: a string representing a desired goal word
    limit: a number representing an upper bound on the number of edits

    >>> big_limit = 10
    >>> minimum_mewtations("cats", "scat", big_limit)       # cats -> scats -> scat
    2
    >>> minimum_mewtations("purng", "purring", big_limit)   # purng -> purrng -> purring
    2
    >>> minimum_mewtations("ckiteus", "kittens", big_limit) # ckiteus -> kiteus -> kitteus -> kittens
    3
    """
    if typed == source:
        return 0
    elif limit < 0:
        return limit + 1
    elif (typed == "") or (source == ""):
        return len(typed) + len(source)
    elif typed[0] == source[0]:
        return minimum_mewtations(typed[1:], source[1:], limit)
    else:
        add = 1 + minimum_mewtations(typed, source[1:], limit - 1)
        remove = 1 + minimum_mewtations(typed[1:], source, limit - 1)
        substitute = 1 + minimum_mewtations(typed[1:], source[1:], limit - 1)
        return min(add, remove, substitute)
```

Dynamic programming (iterative, bottom-up) version:

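```python
def editDistance(typed, source):
    """Bottom-up edit distance from TYPED to SOURCE (no limit)."""
    typedL = len(typed) + 1
    sourceL = len(source) + 1
    # dptable[r][c] is the edit distance from typed[:c] to source[:r]
    dptable = [[0] * typedL for _ in range(sourceL)]
    for c in range(typedL):
        dptable[0][c] = c  # remove all c letters of typed[:c]
    for r in range(sourceL):
        dptable[r][0] = r  # add all r letters of source[:r]
    for r in range(1, sourceL):
        for c in range(1, typedL):
            if typed[c-1] == source[r-1]:
                dptable[r][c] = dptable[r-1][c-1]  # last letters match: no edit
            else:
                dptable[r][c] = 1 + min(dptable[r][c-1],    # remove typed[c-1]
                                        dptable[r-1][c-1],  # substitute typed[c-1] with source[r-1]
                                        dptable[r-1][c])    # add source[r-1]
    return dptable[sourceL-1][typedL-1]
```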

(Optional) Problem EC (0 pt)

Note: This problem is optional and is not worth any points. It is meant to be an extra challenge for those who are interested in improving the efficiency of their code. Only attempt this problem if you have completed all other problems in the project.

During Office Hours and Project Parties, the staff will prioritize helping students with required questions. We will not be offering help with this question unless the queue is empty. In this problem, you will implement memoization decorators that will increase the efficiency of our program by “remembering” the results of particularly intensive operations.

Make sure you’re familiar with decorators and memoization. If you would like a refresher, see the Decorators and Memoization sections below.

Decorators

A Python decorator allows you to modify the behavior of a pre-existing function without changing the function’s source code.

Specifically, a decorator function is a higher-order function that:

  • Takes the original function as an input
  • Returns a new function with modified functionality
  • Gives that new function the same parameters as the original function

An example of a decorator that executes a one-input function twice is shown below:

```
>>> def do_twice(original_function):
...     def repeat(x):
...         original_function(x)
...         original_function(x)
...     return repeat
...
```

We can apply this function in multiple contexts:

```
# Printing a value twice
>>> @do_twice
... def print_value(x):
...     print(x)
...
>>> print_value(5)
5
5
# Adding an item to a list twice
>>> lst = []
>>> @do_twice
... def add_to_list(item):
...     lst.append(item)
...
>>> add_to_list(5)
>>> lst
[5, 5]
```

Additionally, note that we could also directly call the decorator function instead of using the @ notation (i.e. print_value = do_twice(print_value)). However, it’s typically useful to place decorators directly above the function that we are modifying since they better describe how these functions are being changed in our code.
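The same idea extends to functions of any arity. The sketch below is a variant of do_twice, not the one from the problem: it forwards *args and **kwargs, returns the second call's result, and uses functools.wraps so the wrapper keeps the original function's name and docstring. It also applies the decorator by a direct call, which is equivalent to writing @do_twice above the definition; add_pair and log are hypothetical names for the demo.

```python
import functools


def do_twice(original_function):
    """Variant of do_twice that forwards any arguments and returns the
    result of the second call."""
    @functools.wraps(original_function)  # copy __name__, __doc__, etc.
    def repeat(*args, **kwargs):
        original_function(*args, **kwargs)
        return original_function(*args, **kwargs)
    return repeat


log = []


def add_pair(a, b):
    """Record a + b in log and return it."""
    log.append(a + b)
    return a + b


# Calling the decorator directly, exactly what @do_twice would do:
add_pair = do_twice(add_pair)
print(add_pair(1, 2))     # 3
print(log)                # [3, 3]
print(add_pair.__name__)  # add_pair
```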

Memoization

Notice that the diff functions we wrote in the previous questions are very inefficient: you will likely find that the computer will make the same recursive call multiple times. For a function with multiple arguments and three recursive calls, this can be harder to see. It can be easier to first see this with a function like fib that is defined in lecture.

(Figure: Fib Tree, the tree of recursive calls made by fib(5).)

Notice how many redundant recursive calls there are in the tree diagram above. Our goal is to have our program store past results of evaluated recursive calls so that we can reuse them if the same recursive call comes up again. For example, the first branch of fib(5) calls fib(3), which has not yet been evaluated, so we must go through all of its recursive calls to find its return value. However, when we encounter the call to fib(3) that is a branch of fib(4), we have already computed its return value! If we store and retrieve that information in something called a cache, we can avoid the needless computation: we no longer need to make the recursive calls to its branches fib(1) and fib(2). This is the concept of memoization: store the results of expensive computations in a cache, and retrieve them from the cache whenever the same computation is repeated.
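The savings are easy to measure on fib. The sketch below assumes a fib in the style of the lecture version (fib(0) = 0, fib(1) = 1) and uses a simplified one-argument memo; the project's memo also handles several arguments via *args. Each version counts how many times its body runs.

```python
def memo(f):
    """Simplified one-argument memoization decorator (a stand-in for the
    project's memo)."""
    cache = {}

    def memoized(n):
        if n not in cache:
            cache[n] = f(n)  # first time: compute and store
        return cache[n]      # afterwards: reuse the stored value

    return memoized


calls = {"naive": 0, "memoized": 0}


def fib_naive(n):
    calls["naive"] += 1
    if n <= 1:
        return n
    return fib_naive(n - 2) + fib_naive(n - 1)


@memo
def fib_memoized(n):
    calls["memoized"] += 1
    if n <= 1:
        return n
    # This name refers to the decorated function, so these calls hit the cache.
    return fib_memoized(n - 2) + fib_memoized(n - 1)


print(fib_naive(5), calls["naive"])        # 5 15
print(fib_memoized(5), calls["memoized"])  # 5 6
```

With memoization, each fib(k) for k from 0 to 5 is evaluated exactly once, which is why the count drops from 15 to 6.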

We will be working with two memoization decorators. memo is a general all-purpose decorator that memoizes the function it annotates. If memo encounters an input it has not seen, it stores the calculated result in its cache. If memo receives an input it has already seen, it returns the stored value from the cache directly, without doing any extra computation. We have provided you with the full implementation of memo.

Your task is to implement memo_diff. memo_diff is a higher-order function that takes in a diff_function and returns another diff function called memoized that, like all diff functions, takes in typed, source, and limit. memoized should do the following:

  • When memoized sees a (typed, source) pair for the first time, it should calculate the difference using diff_function and cache that value along with the limit used as a (value, limit) tuple pair.
  • If memoized encounters the (typed, source) pair again, it should return the memoized value if the provided limit is less than or equal to the cached limit. Otherwise, the difference should be recalculated, recached, and returned.

Important: When implementing this function, make sure you store pairs of values in the cache with a tuple, not a list. Dictionary keys must be hashable: immutable types such as tuples and strings are, lists are not, so the (typed, source) key in particular must be a tuple. If you’re curious about why memo_diff differs from memo and is implemented this way, see the section below:

More information

How do memo and memo_diff differ? Although memo stores only the result of a function call, memo_diff takes into account an additional constraint, limit, that affects whether the cached result can be used or not. When the memo_diff function is called with a (typed, source) pair, it doesn’t just check if the pair has been seen before; it also checks if the limit is less than or equal to the cached limit. This is an additional check that memo does not perform.

Why is limit handled this way? We already know that the limit represents the maximum difference that a diff function cares about: differences above the limit might as well be the same. So a diff function returns an accurate difference when the true difference is at most the limit, and possibly an inaccurate one (but still greater than the limit) when it is above the limit. Therefore, we can trust a cached difference value if it was calculated with a limit at least as high as the current one, but we can’t trust one calculated with a lower limit.

For example, the result of the first call below would allow us to predict the result of the second call. The higher limit provides us with more information. However, the second call would not allow us to predict the first one.

```
>>> minimum_mewtations("hello", "hasldfasdfsffsfasdf", 100)
17
>>> minimum_mewtations("hello", "hasldfasdfsffsfasdf", 2)
3
```

Once you’ve implemented memo_diff, finish by:

  1. Decorating autocorrect with memo.
  2. Decorating minimum_mewtations with memo_diff.

Running autocorrect and minimum_mewtations should now be much faster!

Note: If you are failing the autograder tests involving call_count, your minimum_mewtations implementation (from Q7) probably does not have the tightest possible base cases yet and needs further optimization. The Q7 tests are not meant to be strict, so even if you passed them, your base cases might still not be the tightest. Make sure you are not making unnecessary recursive calls. We are strict about this here because tight base cases are crucial for the efficiency of your code.

Important: Try it yourself first! Only consult the following common mistakes section if you have been stuck on one test case for a while. Otherwise, you might not learn as much from the project.

Common Mistakes

  • Consider the case minimum_mewtations(typed = "maooo", source = "mao", limit = 0): since no transformations are allowed and the two words are not the same, how quickly can your function figure out that the result is impossible?
  • Consider the case minimum_mewtations(typed = "habc", source = "hmao", limit = some_limit_greater_than_zero): Given that both strings start with the same character h, what is the most effective approach in this situation? Should the function even attempt to “add” (resulting in habc and mao) or “remove” (resulting in abc and hmao)? Does your implementation take advantage of this optimization?

Note: The autograder takes a bit of time to run, but it should not take longer than 10 seconds.

```python
def memo(f):
    """A general memoization decorator."""
    cache = {}

    def memoized(*args):
        immutable_args = deep_convert_to_tuple(args)  # convert *args into a tuple representation
        if immutable_args not in cache:
            result = f(*immutable_args)
            cache[immutable_args] = result
            return result
        return cache[immutable_args]

    return memoized
```


```python
def memo_diff(diff_function):
    """A memoization function."""
    cache = {}

    def memoized(typed, source, limit):
        # BEGIN PROBLEM EC
        "*** YOUR CODE HERE ***"
        args = (typed, source)
        if (args in cache) and limit > cache[args][1]:
            # Cached with a lower limit: cannot be trusted, so recompute.
            result = diff_function(typed, source, limit)
            cache[args] = (result, limit)
            return result
        elif (args in cache) and limit <= cache[args][1]:
            return cache[args][0]
        else:
            result = diff_function(typed, source, limit)
            cache[args] = (result, limit)
            return result
        # END PROBLEM EC

    return memoized
```

Refactored version (the two branches that call diff_function are merged):

```python
def memo_diff(diff_function):
    """A memoization function."""
    cache = {}

    def memoized(typed, source, limit):
        # BEGIN PROBLEM EC
        "*** YOUR CODE HERE ***"
        args = (typed, source)
        if (args in cache) and limit <= cache[args][1]:
            return cache[args][0]
        else:
            result = diff_function(typed, source, limit)
            cache[args] = (result, limit)
            return result
        # END PROBLEM EC
    return memoized
```
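The limit rule can be checked with a stub. In this sketch, memo_diff follows the refactored version above, and length_diff is a hypothetical stand-in diff function that records the limit every time its body actually runs.

```python
def memo_diff(diff_function):
    """Same caching rule as the refactored memo_diff above."""
    cache = {}

    def memoized(typed, source, limit):
        key = (typed, source)
        if key in cache and limit <= cache[key][1]:
            return cache[key][0]
        result = diff_function(typed, source, limit)
        cache[key] = (result, limit)
        return result

    return memoized


evaluations = []  # limits with which the wrapped body actually ran


@memo_diff
def length_diff(typed, source, limit):
    """Hypothetical stand-in diff function: just the difference in lengths."""
    evaluations.append(limit)
    return abs(len(typed) - len(source))


length_diff("cat", "cats", 5)  # computed and cached as (1, 5)
length_diff("cat", "cats", 3)  # 3 <= 5: served from the cache
length_diff("cat", "cats", 8)  # 8 > 5: recomputed, cache now holds (1, 8)
length_diff("cat", "cats", 6)  # 6 <= 8: served from the cache
print(evaluations)             # [5, 8]
```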
Optimized version with an early length-difference check:

```python
def minimum_mewtations(typed: str, source: str, limit: int) -> int:
    """
    A diff function for autocorrect that computes the edit distance from TYPED to SOURCE.
    This function takes in a string TYPED, a string SOURCE, and a number LIMIT.

    Arguments:
    typed: a starting word
    source: a string representing a desired goal word
    limit: a number representing an upper bound on the number of edits

    >>> big_limit = 10
    >>> minimum_mewtations("cats", "scat", big_limit)       # cats -> scats -> scat
    2
    >>> minimum_mewtations("purng", "purring", big_limit)   # purng -> purrng -> purring
    2
    >>> minimum_mewtations("ckiteus", "kittens", big_limit) # ckiteus -> kiteus -> kitteus -> kittens
    3
    """
    # The edit distance is at least the difference in lengths, so stop as soon
    # as that alone exceeds the limit. This also covers a negative limit,
    # because abs(...) >= 0 > limit, so no separate `limit < 0` case is needed.
    if abs(len(typed) - len(source)) > limit:
        return limit + 1
    elif typed == source:
        return 0
    elif (typed == "") or (source == ""):
        return len(typed) + len(source)
    elif typed[0] == source[0]:
        return minimum_mewtations(typed[1:], source[1:], limit)
    else:
        add = 1 + minimum_mewtations(typed, source[1:], limit - 1)
        remove = 1 + minimum_mewtations(typed[1:], source, limit - 1)
        substitute = 1 + minimum_mewtations(typed[1:], source[1:], limit - 1)
        return min(add, remove, substitute)
```
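As a rough check that decorating minimum_mewtations with memo_diff saves work, the sketch below builds two copies of the optimized solution, one plain and one memoized, and counts how many times each body runs. make_mewtations is a hypothetical helper for this experiment only; in cats.py you would simply put @memo_diff above minimum_mewtations. The exact counts depend on the inputs, so the example prints them rather than predicting them.

```python
def memo_diff(diff_function):
    """Same caching rule as the refactored memo_diff above."""
    cache = {}

    def memoized(typed, source, limit):
        key = (typed, source)
        if key in cache and limit <= cache[key][1]:
            return cache[key][0]
        result = diff_function(typed, source, limit)
        cache[key] = (result, limit)
        return result

    return memoized


def make_mewtations(decorator=None):
    """Hypothetical helper: build a fresh copy of the optimized
    minimum_mewtations plus a counter of how many times its body runs."""
    counter = {"calls": 0}

    def mewtations(typed, source, limit):
        counter["calls"] += 1
        if abs(len(typed) - len(source)) > limit:
            return limit + 1
        if typed == source:
            return 0
        if typed == "" or source == "":
            return len(typed) + len(source)
        if typed[0] == source[0]:
            return mewtations(typed[1:], source[1:], limit)
        add = 1 + mewtations(typed, source[1:], limit - 1)
        remove = 1 + mewtations(typed[1:], source, limit - 1)
        substitute = 1 + mewtations(typed[1:], source[1:], limit - 1)
        return min(add, remove, substitute)

    if decorator is not None:
        # Rebinding the name makes the recursive calls above go through the
        # decorator, just like @memo_diff on a top-level function.
        mewtations = decorator(mewtations)
    return mewtations, counter


plain, plain_count = make_mewtations()
memoized, memo_count = make_mewtations(memo_diff)
print(plain("ckiteus", "kittens", 10), plain_count["calls"])
print(memoized("ckiteus", "kittens", 10), memo_count["calls"])
```

Both copies return 3 here. Memoization helps because different edit orders, such as "add then remove" and "remove then add", reach the same (typed, source) pair with the same limit.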