import ast
import grader
from grader.utils import read_code
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ["template_test", "compare_trees", "is_underscore", "is_wildcard", "pprint_ast", "dump"]
def dump(T):
    """Render the ASTs inside a (possibly nested) list/tuple as strings.

    AST nodes are formatted with ``pprint_ast``. Lists and tuples (exact
    types only) are rebuilt with the same container type and their items
    converted recursively. Any other value is returned unchanged.
    """
    if isinstance(T, ast.AST):
        return pprint_ast(T)
    container_type = type(T)
    if container_type not in (list, tuple):
        # Leaf value (str, number, None, ...): nothing to render.
        return T
    return container_type(dump(item) for item in T)
def is_underscore(node):
    """Return True for a ``____`` blank: a Name made only of underscores.

    At least three underscores are required, so ``_`` and ``__`` are ordinary
    identifiers. Used to let a template position be filled with any expression.
    """
    if not isinstance(node, ast.Name):
        return False
    identifier = node.id
    return len(identifier) >= 3 and set(identifier) == {"_"}
def is_wildcard(node):
    """ Function matching a ... expression (used as a wildcard expression)

    Returns True when ``node`` is an expression statement whose whole value
    is ``...``. Anything else, including non-AST values, gives False.
    """
    if not isinstance(node, ast.Expr):
        return False
    value = node.value
    # Python 3.8+ parses ``...`` as ast.Constant(value=Ellipsis). The old
    # ``type(value) == ast.Ellipsis`` check never matched that, so ``...``
    # wildcards were silently treated as ordinary statements.
    if isinstance(value, ast.Constant):
        return value.value is Ellipsis
    # Older Pythons produce a dedicated ast.Ellipsis node. Look the class up
    # through vars() so newer Pythons (where the name is deprecated or
    # removed) neither emit a DeprecationWarning nor raise AttributeError.
    legacy_ellipsis = vars(ast).get("Ellipsis")
    return legacy_ellipsis is not None and type(value) is legacy_ellipsis
def matching_ast_lists(original_list, compared_list, i=0, j=0):
    """ Returns if the ast lists are identical or not (allowing for wildcards)

    ``original_list`` is the template side. Each of its ``...`` wildcard
    statements (see ``is_wildcard``) absorbs zero or more consecutive items
    of ``compared_list``. Non-wildcard items are compared pairwise with
    ``compare_trees``.

    Args:
        original_list: template list of nodes (may contain wildcards).
        compared_list: submitted list of nodes.
        i: start offset into ``original_list``.
        j: start offset into ``compared_list``.

    Returns:
        ``[]`` when ``original_list[i:]`` matches ``compared_list[j:]``.
        Otherwise a non-empty list of ``(expected, got)`` pairs. When one side
        runs out before the other can be matched, the pair is the two whole
        lists.
    """
    # The plain recursion revisits the same (a, b) state once per path
    # through the wildcards, which grows combinatorially with the number of
    # wildcards (and repeats every nested compare_trees call). The result for
    # a state depends only on the two lists, which do not change during one
    # call, so memoize per top-level call.
    memo = {}
    original_len = len(original_list)
    compared_len = len(compared_list)

    def match(a, b):
        key = (a, b)
        if key in memo:
            return memo[key]
        if a == original_len and b == compared_len:
            result = []
        elif b == compared_len:
            # Submission exhausted: only trailing wildcards may remain.
            if all(is_wildcard(x) for x in original_list[a:]):
                result = []
            else:
                result = [(original_list, compared_list)]
        elif a == original_len:
            # Template exhausted while the submission still has items.
            result = [(original_list, compared_list)]
        elif is_wildcard(original_list[a]):
            # First try ending the wildcard here. If the rest does not
            # match, let it swallow compared_list[b] and try again.
            result = match(a + 1, b)
            if result:
                result = match(a, b + 1)
        else:
            result = (compare_trees(original_list[a], compared_list[b])
                      + match(a + 1, b + 1))
        memo[key] = result
        return result

    return match(i, j)
def compare_trees(original_tree, compared_tree):
    """ Compares two AST trees with each other.

    Returns a list of differences, pairs of (expected, got). An empty list
    means the trees match.

    In addition to doing straight-forward comparison, it allows
    for two kinds of wildcard expressions:

    1) ____ expressions which can be filled with a single expression.
       A ``____`` on its own line is an expression statement, so it matches
       only expression statements (not, e.g., assignments).
    2) ... expressions which can be filled with any number of
       valid expressions/statements. Used in bodies.

    Args:
        original_tree: the template side (AST node, list of nodes, or leaf).
        compared_tree: the submitted side, same kinds of values.
    """
    if isinstance(original_tree, ast.AST) and isinstance(compared_tree, ast.AST):
        # Blanks shown by ____ match any single node.
        if is_underscore(original_tree):
            return []
        # Different types of nodes can never match.
        if type(original_tree) is not type(compared_tree):
            return [(original_tree, compared_tree)]
        result = []
        for field_name, expected_new in ast.iter_fields(original_tree):
            got_new = getattr(compared_tree, field_name, None)
            if isinstance(expected_new, (list, ast.AST)):
                # Child nodes and node lists are compared recursively.
                result.extend(compare_trees(expected_new, got_new))
            elif (type(expected_new) is not type(got_new)
                  or expected_new != got_new):
                # Primitive field (identifier, constant value, flag) differs;
                # report the enclosing node.
                result.append((original_tree, compared_tree))
        return result
    if isinstance(original_tree, list) and isinstance(compared_tree, list):
        return matching_ast_lists(original_tree, compared_tree)
    # Non-AST leaves that appear inside node lists, e.g. the identifier
    # strings of ``global x`` or the None entries of
    # ``arguments.kw_defaults``. These used to be reported as differences
    # even when equal; compare them like primitive fields instead.
    if type(original_tree) is type(compared_tree) and original_tree == compared_tree:
        return []
    return [(original_tree, compared_tree)]
def load_ast_from_file(file_path):
    """Read ``file_path`` with ``read_code`` and return its parsed module AST.

    Errors raised by ``read_code`` or ``ast.parse`` (e.g. SyntaxError)
    propagate to the caller unchanged.
    """
    return ast.parse(read_code(file_path))
def template_test(template_file=None, template_code=None,
                  description="Program should match the template."):
    """Build a grader test that checks a program's AST against a template.

    The template source is read from ``template_file`` (via ``read_code``)
    when one is given, which overrides ``template_code``. It may contain
    ``____`` blanks and ``...`` wildcards, as handled by ``compare_trees``.
    The template is parsed once, when this function is called.

    Returns ``_inner`` as wrapped by ``grader.test``,
    ``grader.set_description`` and ``grader.expose_ast``.
    """
    code = template_code if template_file is None else read_code(template_file)
    assert code is not None
    parsed_template = ast.parse(code)

    @grader.test
    @grader.set_description(description)
    @grader.expose_ast
    def _inner(m, AST):
        differences = compare_trees(parsed_template, AST)
        if differences:
            # Log the submission tree, the template tree and the raw
            # differences so a mismatch can be diagnosed from the test log.
            m.log(pprint_ast(AST))
            m.log(pprint_ast(parsed_template))
            m.log(dump(differences))
        assert not differences, (
            "Program code does not match template.\n\nTemplate code:\n{}"
            .format(code)
        )

    return _inner
def next_in(str, char_set):
    """Lazily yield ``(index, char)`` for each char of ``str`` in ``char_set``."""
    yield from ((pos, c) for pos, c in enumerate(str) if c in char_set)
def pair_with_next(iterator):
    """Yield each item of ``iterator`` paired with the item that follows it.

    The last item is paired with ``None``. An empty iterable yields nothing.
    For example, ``[1, 2, 3]`` gives ``(1, 2), (2, 3), (3, None)``.
    """
    it = iter(iterator)
    # Take the first item explicitly instead of using None as a "no previous
    # item" sentinel. The sentinel approach crashed on empty input (unbound
    # loop variable) and dropped pairs whenever an item was itself None.
    try:
        prev = next(it)
    except StopIteration:
        return
    for current in it:
        yield prev, current
        prev = current
    yield prev, None
def traverse(tree_string, indent=4):
    """Yield fragments of ``tree_string`` laid out one bracket level per line.

    ``pprint_ast`` uses this on ``ast.dump`` output. Every ``(``/``[`` starts a
    new line indented ``indent`` spaces deeper. Every ``)``/``]`` is placed on
    its own line at the parent level. An empty pair such as ``[]`` or ``()``
    stays together on one line. After a closing bracket a new line is started
    unless another closing bracket follows immediately.

    Args:
        tree_string: text to lay out, normally ``ast.dump(tree)``.
        indent: number of spaces per nesting level.

    Yields:
        String fragments; ``"".join`` them to get the formatted text.

    NOTE(review): brackets are located by a plain character scan
    (``next_in``), so a bracket inside a quoted string constant in the dump
    is treated as structure too. Text after the last bracket is never
    yielded; ``ast.dump`` output presumably always ends with ``)``.
    """
    # (index, char) of every bracket character, in order.
    cut_offs = list(next_in(tree_string, '([])'))
    level = 0
    block = " " * indent
    # Start of the text that has not been emitted yet.
    at = 0
    # True when the current cut-off (a closing bracket) was already emitted
    # together with its opening bracket as an empty pair.
    skipNext = False
    # N is the following (index, char) cut-off, or None for the last one.
    for (i, ch), N in pair_with_next(cut_offs):
        # Emit the plain text between the previous bracket and this one.
        yield tree_string[at:i]
        at = i+1
        if skipNext:
            skipNext = False
            yield "\n" + block * level
            continue
        if ch in '([':
            yield ch
            # Opening bracket immediately followed by a closing one: emit the
            # empty pair inline and skip the closing cut-off next iteration.
            if N and N[0] == i+1 and N[1] in "])":
                yield N[1]
                skipNext = True
            else:
                level += 1
                yield "\n" + block * level
        if ch in '])':
            level -= 1
            # Closing bracket goes on its own line at the parent level.
            yield "\n" + block * level + ch
            # Start a fresh line unless another closing bracket follows
            # directly, so consecutive closers stack one per line.
            if not (N and N[0] == i + 1 and N[1] in "])"):
                yield "\n" + block * level
def pprint_ast(tree):
    """Return ``ast.dump(tree)`` re-formatted with one bracket level per line."""
    fragments = traverse(ast.dump(tree))
    return "".join(fragments)