# SPDX-FileCopyrightText: 2009 Fermi Research Alliance, LLC
# SPDX-License-Identifier: Apache-2.0
# unparse.py imported from astunparse: https://github.com/simonpercivall/astunparse
# The original Python 2/3 file relied on six; it was converted to pure Python 3
# Minimal changes applied
# Alternatives would be to pip install the package (adds six requirement) or
# create a simplified function by extracting only the AST nodes used
"""Usage (for testing, will do roundtrips on the files):
unparse.py <path to source files>
unparse.py --testdir <path to source files directories>
"""
import ast
import os
import sys
import tokenize
from io import StringIO
# Large float and imaginary literals get turned into infinities in the AST.
# repr() of such a value contains "inf", which is not a valid literal, so the
# unparser substitutes INFSTR for it: a decimal literal whose exponent is one
# larger than sys.float_info.max_10_exp.
INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1)
def interleave(inter, f, seq):
    """Call f on each item in seq, calling inter() in between.

    Args:
        inter: zero-argument callable invoked between consecutive items.
        f: one-argument callable invoked on every item, in order.
        seq: any iterable; it is consumed exactly once.

    Nothing is called when seq is empty, and inter() is never called
    before the first item or after the last one.
    """
    seq = iter(seq)
    # Only the next() call is guarded: a StopIteration raised by f itself
    # is a real error and must propagate instead of being mistaken for an
    # empty sequence.
    try:
        first = next(seq)
    except StopIteration:
        return
    f(first)
    for x in seq:
        inter()
        f(x)
class Unparser:
    """Recursively traverse an AST and write source code for it to a file.

    Original formatting is disregarded: the text is regenerated from the
    tree alone. Every supported node class ``T`` is handled by a method
    named ``_T``; `dispatch` selects the method from the node's class name.
    """

    def __init__(self, tree, file=sys.stdout):
        """Unparser(tree, file=sys.stdout) -> None.

        Print the source for tree to file, terminate it with a newline
        and flush the file.
        """
        self.f = file
        self.future_imports = []
        self._indent = 0
        self.dispatch(tree)
        print("", file=self.f)
        self.f.flush()

    def fill(self, text=""):
        """Start a new line at the current indentation level and write text on it."""
        self.f.write("\n" + " " * self._indent + text)

    def write(self, text):
        """Append a piece of text to the current line."""
        self.f.write(str(text))

    def enter(self):
        """Print ':', and increase the indentation."""
        self.write(":")
        self._indent += 1

    def leave(self):
        """Decrease the indentation level."""
        self._indent -= 1

    def dispatch(self, tree):
        """Dispatcher function, dispatching tree type T to method _T.

        A list of nodes is unparsed element by element, in order. A node
        whose class has no matching ``_T`` method raises AttributeError.
        """
        if isinstance(tree, list):
            for t in tree:
                self.dispatch(t)
            return
        meth = getattr(self, "_" + tree.__class__.__name__)
        meth(tree)

    ############### Unparsing methods ######################
    # There should be one method per concrete grammar type #
    # Constructors should be grouped by sum type. Ideally, #
    # this would follow the order in the grammar, but      #
    # currently doesn't.                                   #
    ########################################################

    def _Module(self, tree):
        for stmt in tree.body:
            self.dispatch(stmt)

    def _Interactive(self, tree):
        for stmt in tree.body:
            self.dispatch(stmt)

    def _Expression(self, tree):
        self.dispatch(tree.body)

    # stmt

    def _Expr(self, tree):
        self.fill()
        self.dispatch(tree.value)

    def _NamedExpr(self, tree):
        self.write("(")
        self.dispatch(tree.target)
        self.write(" := ")
        self.dispatch(tree.value)
        self.write(")")

    def _Import(self, t):
        self.fill("import ")
        interleave(lambda: self.write(", "), self.dispatch, t.names)

    def _ImportFrom(self, t):
        # A from __future__ import may affect unparsing, so record it.
        if t.module and t.module == "__future__":
            self.future_imports.extend(n.name for n in t.names)
        self.fill("from ")
        # level may be None on hand-built nodes; treat that as an absolute import.
        self.write("." * (t.level or 0))
        if t.module:
            self.write(t.module)
        self.write(" import ")
        interleave(lambda: self.write(", "), self.dispatch, t.names)

    def _Assign(self, t):
        self.fill()
        # Chained assignment: every target is followed by " = ".
        for target in t.targets:
            self.dispatch(target)
            self.write(" = ")
        self.dispatch(t.value)

    def _AugAssign(self, t):
        self.fill()
        self.dispatch(t.target)
        self.write(" " + self.binop[t.op.__class__.__name__] + "= ")
        self.dispatch(t.value)

    def _AnnAssign(self, t):
        self.fill()
        # A non-simple annotated Name target was written in parentheses.
        parenthesize = not t.simple and isinstance(t.target, ast.Name)
        if parenthesize:
            self.write("(")
        self.dispatch(t.target)
        if parenthesize:
            self.write(")")
        self.write(": ")
        self.dispatch(t.annotation)
        if t.value:
            self.write(" = ")
            self.dispatch(t.value)

    def _Return(self, t):
        self.fill("return")
        if t.value:
            self.write(" ")
            self.dispatch(t.value)

    def _Pass(self, t):
        self.fill("pass")

    def _Break(self, t):
        self.fill("break")

    def _Continue(self, t):
        self.fill("continue")

    def _Delete(self, t):
        self.fill("del ")
        interleave(lambda: self.write(", "), self.dispatch, t.targets)

    def _Assert(self, t):
        self.fill("assert ")
        self.dispatch(t.test)
        if t.msg:
            self.write(", ")
            self.dispatch(t.msg)

    def _Exec(self, t):
        # Legacy Python 2 statement node.
        self.fill("exec ")
        self.dispatch(t.body)
        if t.globals:
            self.write(" in ")
            self.dispatch(t.globals)
        if t.locals:
            self.write(", ")
            self.dispatch(t.locals)

    def _Print(self, t):
        # Legacy Python 2 statement node.
        self.fill("print ")
        do_comma = False
        if t.dest:
            self.write(">>")
            self.dispatch(t.dest)
            do_comma = True
        for e in t.values:
            if do_comma:
                self.write(", ")
            else:
                do_comma = True
            self.dispatch(e)
        if not t.nl:
            self.write(",")

    def _Global(self, t):
        self.fill("global ")
        interleave(lambda: self.write(", "), self.write, t.names)

    def _Nonlocal(self, t):
        self.fill("nonlocal ")
        interleave(lambda: self.write(", "), self.write, t.names)

    def _Await(self, t):
        self.write("(")
        self.write("await")
        if t.value:
            self.write(" ")
            self.dispatch(t.value)
        self.write(")")

    def _Yield(self, t):
        self.write("(")
        self.write("yield")
        if t.value:
            self.write(" ")
            self.dispatch(t.value)
        self.write(")")

    def _YieldFrom(self, t):
        self.write("(")
        self.write("yield from")
        if t.value:
            self.write(" ")
            self.dispatch(t.value)
        self.write(")")

    def _Raise(self, t):
        self.fill("raise")
        if not t.exc:
            assert not t.cause
            return
        self.write(" ")
        self.dispatch(t.exc)
        if t.cause:
            self.write(" from ")
            self.dispatch(t.cause)

    def _Try(self, t):
        self.fill("try")
        self.enter()
        self.dispatch(t.body)
        self.leave()
        for ex in t.handlers:
            self.dispatch(ex)
        if t.orelse:
            self.fill("else")
            self.enter()
            self.dispatch(t.orelse)
            self.leave()
        if t.finalbody:
            self.fill("finally")
            self.enter()
            self.dispatch(t.finalbody)
            self.leave()

    # Removed _TryExcept and _TryFinally that are only in Python 2
    # Python 3 has only _Try including both the except and finally clauses

    def _ExceptHandler(self, t):
        self.fill("except")
        if t.type:
            self.write(" ")
            self.dispatch(t.type)
        if t.name:
            self.write(" as ")
            self.write(t.name)
        self.enter()
        self.dispatch(t.body)
        self.leave()

    def _ClassDef(self, t):
        self.write("\n")
        for deco in t.decorator_list:
            self.fill("@")
            self.dispatch(deco)
        self.fill("class " + t.name)
        # The parentheses are always written, even without bases or keywords.
        self.write("(")
        comma = False
        for e in t.bases:
            if comma:
                self.write(", ")
            else:
                comma = True
            self.dispatch(e)
        for e in t.keywords:
            if comma:
                self.write(", ")
            else:
                comma = True
            self.dispatch(e)
        if sys.version_info[:2] < (3, 5):
            if t.starargs:
                if comma:
                    self.write(", ")
                else:
                    comma = True
                self.write("*")
                self.dispatch(t.starargs)
            if t.kwargs:
                if comma:
                    self.write(", ")
                else:
                    comma = True
                self.write("**")
                self.dispatch(t.kwargs)
        self.write(")")
        self.enter()
        self.dispatch(t.body)
        self.leave()

    def _FunctionDef(self, t):
        self.__FunctionDef_helper(t, "def")

    def _AsyncFunctionDef(self, t):
        self.__FunctionDef_helper(t, "async def")

    def __FunctionDef_helper(self, t, fill_suffix):
        """Write a (possibly async) function definition introduced by fill_suffix."""
        self.write("\n")
        for deco in t.decorator_list:
            self.fill("@")
            self.dispatch(deco)
        def_str = fill_suffix + " " + t.name + "("
        self.fill(def_str)
        self.dispatch(t.args)
        self.write(")")
        if getattr(t, "returns", False):
            self.write(" -> ")
            self.dispatch(t.returns)
        self.enter()
        self.dispatch(t.body)
        self.leave()

    def _For(self, t):
        self.__For_helper("for ", t)

    def _AsyncFor(self, t):
        self.__For_helper("async for ", t)

    def __For_helper(self, fill, t):
        """Write a (possibly async) for loop whose header starts with fill."""
        self.fill(fill)
        self.dispatch(t.target)
        self.write(" in ")
        self.dispatch(t.iter)
        self.enter()
        self.dispatch(t.body)
        self.leave()
        if t.orelse:
            self.fill("else")
            self.enter()
            self.dispatch(t.orelse)
            self.leave()

    def _If(self, t):
        self.fill("if ")
        self.dispatch(t.test)
        self.enter()
        self.dispatch(t.body)
        self.leave()
        # collapse nested ifs into equivalent elifs.
        while t.orelse and len(t.orelse) == 1 and isinstance(t.orelse[0], ast.If):
            t = t.orelse[0]
            self.fill("elif ")
            self.dispatch(t.test)
            self.enter()
            self.dispatch(t.body)
            self.leave()
        # final else
        if t.orelse:
            self.fill("else")
            self.enter()
            self.dispatch(t.orelse)
            self.leave()

    def _While(self, t):
        self.fill("while ")
        self.dispatch(t.test)
        self.enter()
        self.dispatch(t.body)
        self.leave()
        if t.orelse:
            self.fill("else")
            self.enter()
            self.dispatch(t.orelse)
            self.leave()

    def _generic_With(self, t, async_=False):
        """Write a with statement; async_ selects the ``async with`` spelling."""
        self.fill("async with " if async_ else "with ")
        if hasattr(t, "items"):
            interleave(lambda: self.write(", "), self.dispatch, t.items)
        else:
            # Node shape without an `items` list: a single context expression.
            self.dispatch(t.context_expr)
            if t.optional_vars:
                self.write(" as ")
                self.dispatch(t.optional_vars)
        self.enter()
        self.dispatch(t.body)
        self.leave()

    def _With(self, t):
        self._generic_With(t)

    def _AsyncWith(self, t):
        self._generic_With(t, async_=True)

    # expr

    def _Bytes(self, t):
        self.write(repr(t.s))

    def _Str(self, tree):
        self.write(repr(tree.s))

    def _quote_fstring(self, body):
        """Return body wrapped in a quote style that does not occur inside it.

        Triple quotes are required when body contains a line break. If
        every candidate quote occurs in body, fall back to repr(body).
        """
        if "\n" in body or "\r" in body:
            quote_types = ["'''", '"""']
        else:
            quote_types = ["'", '"', '"""', "'''"]
        for quote_type in quote_types:
            if quote_type not in body:
                return "{quote_type}{v}{quote_type}".format(quote_type=quote_type, v=body)
        return repr(body)

    def _JoinedStr(self, t):
        # JoinedStr(expr* values)
        self.write("f")
        string = StringIO()
        self._fstring_JoinedStr(t, string.write)
        # Deviation from `unparse.py`: Try to find an unused quote.
        # This change is made to handle _very_ complex f-strings.
        self.write(self._quote_fstring(string.getvalue()))

    def _FormattedValue(self, t):
        # FormattedValue(expr value, int? conversion, expr? format_spec)
        # A replacement field met outside a JoinedStr is written as an
        # f-string of its own.
        self.write("f")
        string = StringIO()
        self._fstring_FormattedValue(t, string.write)
        self.write(self._quote_fstring(string.getvalue()))

    def _fstring_JoinedStr(self, t, write):
        """Pass each part of the f-string t to its ``_fstring_<Class>`` writer."""
        for value in t.values:
            meth = getattr(self, "_fstring_" + type(value).__name__)
            meth(value, write)

    def _fstring_Str(self, t, write):
        # Literal braces must be doubled inside an f-string.
        value = t.s.replace("{", "{{").replace("}", "}}")
        write(value)

    def _fstring_Constant(self, t, write):
        assert isinstance(t.value, str)
        # Literal braces must be doubled inside an f-string.
        value = t.value.replace("{", "{{").replace("}", "}}")
        write(value)

    def _fstring_FormattedValue(self, t, write):
        """Write the replacement field t as ``{expr!conv:spec}`` through write."""
        write("{")
        # Unparse the embedded expression with a nested unparser; the nested
        # instance terminates its output with a newline, which is dropped.
        buffer = StringIO()
        type(self)(t.value, buffer)
        expr = buffer.getvalue().rstrip("\n")
        if expr.startswith("{"):
            # Separate the pair of opening braces as "{ {" so that it is
            # not read as an escaped literal brace.
            write(" ")
        write(expr)
        conversion = getattr(t, "conversion", -1)
        if conversion is not None and conversion != -1:
            conversion = chr(conversion)
            assert conversion in "rsa"
            write("!{conversion}".format(conversion=conversion))
        format_spec = getattr(t, "format_spec", None)
        if format_spec:
            write(":")
            meth = getattr(self, "_fstring_" + type(format_spec).__name__)
            meth(format_spec, write)
        write("}")

    def _Name(self, t):
        self.write(t.id)

    def _NameConstant(self, t):
        self.write(repr(t.value))

    def _Repr(self, t):
        # Legacy Python 2 backquote expression node.
        self.write("`")
        self.dispatch(t.value)
        self.write("`")

    def _write_constant(self, value):
        """Write the repr of a constant, spelling infinities as INFSTR."""
        if isinstance(value, (float, complex)):
            # Substitute overflowing decimal literal for AST infinities.
            self.write(repr(value).replace("inf", INFSTR))
        else:
            self.write(repr(value))

    def _Constant(self, t):
        value = t.value
        if isinstance(value, tuple):
            self.write("(")
            if len(value) == 1:
                # A one-element tuple needs its trailing comma.
                self._write_constant(value[0])
                self.write(",")
            else:
                interleave(lambda: self.write(", "), self._write_constant, value)
            self.write(")")
        elif value is Ellipsis:  # instead of `...` for Py2 compatibility
            self.write("...")
        else:
            # `kind` may be missing on hand-built nodes, so read it defensively.
            if getattr(t, "kind", None) == "u":
                self.write("u")
            self._write_constant(t.value)

    def _Num(self, t):
        repr_n = repr(t.n)
        self.write(repr_n.replace("inf", INFSTR))

    def _List(self, t):
        self.write("[")
        interleave(lambda: self.write(", "), self.dispatch, t.elts)
        self.write("]")

    def _ListComp(self, t):
        self.write("[")
        self.dispatch(t.elt)
        for gen in t.generators:
            self.dispatch(gen)
        self.write("]")

    def _GeneratorExp(self, t):
        self.write("(")
        self.dispatch(t.elt)
        for gen in t.generators:
            self.dispatch(gen)
        self.write(")")

    def _SetComp(self, t):
        self.write("{")
        self.dispatch(t.elt)
        for gen in t.generators:
            self.dispatch(gen)
        self.write("}")

    def _DictComp(self, t):
        self.write("{")
        self.dispatch(t.key)
        self.write(": ")
        self.dispatch(t.value)
        for gen in t.generators:
            self.dispatch(gen)
        self.write("}")

    def _comprehension(self, t):
        if getattr(t, "is_async", False):
            self.write(" async for ")
        else:
            self.write(" for ")
        self.dispatch(t.target)
        self.write(" in ")
        self.dispatch(t.iter)
        for if_clause in t.ifs:
            self.write(" if ")
            self.dispatch(if_clause)

    def _IfExp(self, t):
        self.write("(")
        self.dispatch(t.body)
        self.write(" if ")
        self.dispatch(t.test)
        self.write(" else ")
        self.dispatch(t.orelse)
        self.write(")")

    def _Set(self, t):
        assert t.elts  # should be at least one element
        self.write("{")
        interleave(lambda: self.write(", "), self.dispatch, t.elts)
        self.write("}")

    def _Dict(self, t):
        self.write("{")

        def write_key_value_pair(k, v):
            self.dispatch(k)
            self.write(": ")
            self.dispatch(v)

        def write_item(item):
            k, v = item
            if k is None:
                # for dictionary unpacking operator in dicts {**{'y': 2}}
                # see PEP 448 for details
                self.write("**")
                self.dispatch(v)
            else:
                write_key_value_pair(k, v)

        interleave(lambda: self.write(", "), write_item, zip(t.keys, t.values))
        self.write("}")

    def _Tuple(self, t):
        self.write("(")
        if len(t.elts) == 1:
            # A one-element tuple needs its trailing comma.
            elt = t.elts[0]
            self.dispatch(elt)
            self.write(",")
        else:
            interleave(lambda: self.write(", "), self.dispatch, t.elts)
        self.write(")")

    unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}

    def _UnaryOp(self, t):
        self.write("(")
        self.write(self.unop[t.op.__class__.__name__])
        self.write(" ")
        self.dispatch(t.operand)
        self.write(")")

    binop = {
        "Add": "+",
        "Sub": "-",
        "Mult": "*",
        "MatMult": "@",
        "Div": "/",
        "Mod": "%",
        "LShift": "<<",
        "RShift": ">>",
        "BitOr": "|",
        "BitXor": "^",
        "BitAnd": "&",
        "FloorDiv": "//",
        "Pow": "**",
    }

    def _BinOp(self, t):
        self.write("(")
        self.dispatch(t.left)
        self.write(" " + self.binop[t.op.__class__.__name__] + " ")
        self.dispatch(t.right)
        self.write(")")

    cmpops = {
        "Eq": "==",
        "NotEq": "!=",
        "Lt": "<",
        "LtE": "<=",
        "Gt": ">",
        "GtE": ">=",
        "Is": "is",
        "IsNot": "is not",
        "In": "in",
        "NotIn": "not in",
    }

    def _Compare(self, t):
        self.write("(")
        self.dispatch(t.left)
        for o, e in zip(t.ops, t.comparators):
            self.write(" " + self.cmpops[o.__class__.__name__] + " ")
            self.dispatch(e)
        self.write(")")

    boolops = {ast.And: "and", ast.Or: "or"}

    def _BoolOp(self, t):
        self.write("(")
        s = " %s " % self.boolops[t.op.__class__]
        interleave(lambda: self.write(s), self.dispatch, t.values)
        self.write(")")

    def _Attribute(self, t):
        self.dispatch(t.value)
        # Special case: 3.__abs__() is a syntax error, so if t.value
        # is an integer literal then we need to either parenthesize
        # it or add an extra space to get 3 .__abs__().
        owner = t.value
        if isinstance(owner, ast.Constant):
            # Read `.value`: the `.n` alias on Constant is deprecated.
            literal = owner.value
        elif type(owner).__name__ == "Num":
            # Legacy numeric node; matched by name so the deprecated
            # ast.Num attribute is never touched.
            literal = getattr(owner, "n", None)
        else:
            literal = None
        if isinstance(literal, int):
            self.write(" ")
        self.write(".")
        self.write(t.attr)

    def _Call(self, t):
        self.dispatch(t.func)
        self.write("(")
        comma = False
        for e in t.args:
            if comma:
                self.write(", ")
            else:
                comma = True
            self.dispatch(e)
        for e in t.keywords:
            if comma:
                self.write(", ")
            else:
                comma = True
            self.dispatch(e)
        if sys.version_info[:2] < (3, 5):
            if t.starargs:
                if comma:
                    self.write(", ")
                else:
                    comma = True
                self.write("*")
                self.dispatch(t.starargs)
            if t.kwargs:
                if comma:
                    self.write(", ")
                else:
                    comma = True
                self.write("**")
                self.dispatch(t.kwargs)
        self.write(")")

    def _Subscript(self, t):
        self.dispatch(t.value)
        self.write("[")
        index = t.slice
        if isinstance(index, ast.Tuple) and any(isinstance(e, ast.Slice) for e in index.elts):
            # An index such as x[1:2, 3] can arrive as a plain Tuple holding
            # Slice nodes. Writing it through _Tuple would give x[(1:2, 3)],
            # which is a syntax error, so the elements are written bare.
            if len(index.elts) == 1:
                self.dispatch(index.elts[0])
                self.write(",")
            else:
                interleave(lambda: self.write(", "), self.dispatch, index.elts)
        else:
            self.dispatch(index)
        self.write("]")

    def _Starred(self, t):
        self.write("*")
        self.dispatch(t.value)

    # slice

    def _Ellipsis(self, t):
        self.write("...")

    def _Index(self, t):
        self.dispatch(t.value)

    def _Slice(self, t):
        if t.lower:
            self.dispatch(t.lower)
        self.write(":")
        if t.upper:
            self.dispatch(t.upper)
        if t.step:
            self.write(":")
            self.dispatch(t.step)

    def _ExtSlice(self, t):
        interleave(lambda: self.write(", "), self.dispatch, t.dims)

    # argument

    def _arg(self, t):
        self.write(t.arg)
        if t.annotation:
            self.write(": ")
            self.dispatch(t.annotation)

    # others

    def _arguments(self, t):
        first = True
        # normal arguments
        all_args = getattr(t, "posonlyargs", []) + t.args
        # Defaults belong to the trailing arguments; pad the front with None.
        defaults = [None] * (len(all_args) - len(t.defaults)) + t.defaults
        for index, elements in enumerate(zip(all_args, defaults), 1):
            a, d = elements
            if first:
                first = False
            else:
                self.write(", ")
            self.dispatch(a)
            if d:
                self.write("=")
                self.dispatch(d)
            # Close the positional-only section after its last argument.
            if index == len(getattr(t, "posonlyargs", ())):
                self.write(", /")
        # varargs, or bare '*' if no varargs but keyword-only arguments present
        if t.vararg or getattr(t, "kwonlyargs", False):
            if first:
                first = False
            else:
                self.write(", ")
            self.write("*")
            if t.vararg:
                if hasattr(t.vararg, "arg"):
                    self.write(t.vararg.arg)
                    if t.vararg.annotation:
                        self.write(": ")
                        self.dispatch(t.vararg.annotation)
                else:
                    # vararg given as a plain name rather than an arg node
                    self.write(t.vararg)
                    if getattr(t, "varargannotation", None):
                        self.write(": ")
                        self.dispatch(t.varargannotation)
        # keyword-only arguments
        if getattr(t, "kwonlyargs", False):
            for a, d in zip(t.kwonlyargs, t.kw_defaults):
                if first:
                    first = False
                else:
                    self.write(", ")
                self.dispatch(a)
                if d:
                    self.write("=")
                    self.dispatch(d)
        # kwargs
        if t.kwarg:
            if first:
                first = False
            else:
                self.write(", ")
            if hasattr(t.kwarg, "arg"):
                self.write("**" + t.kwarg.arg)
                if t.kwarg.annotation:
                    self.write(": ")
                    self.dispatch(t.kwarg.annotation)
            else:
                # kwarg given as a plain name rather than an arg node
                self.write("**" + t.kwarg)
                if getattr(t, "kwargannotation", None):
                    self.write(": ")
                    self.dispatch(t.kwargannotation)

    def _keyword(self, t):
        if t.arg is None:
            # starting from Python 3.5 this denotes a kwargs part of the invocation
            self.write("**")
        else:
            self.write(t.arg)
            self.write("=")
        self.dispatch(t.value)

    def _Lambda(self, t):
        self.write("(")
        self.write("lambda ")
        self.dispatch(t.args)
        self.write(": ")
        self.dispatch(t.body)
        self.write(")")

    def _alias(self, t):
        self.write(t.name)
        if t.asname:
            self.write(" as " + t.asname)

    def _withitem(self, t):
        self.dispatch(t.context_expr)
        if t.optional_vars:
            self.write(" as ")
            self.dispatch(t.optional_vars)
def roundtrip(filename, output=sys.stdout):
    """Parse the Python source file filename and unparse it to output.

    The file is decoded with the encoding detected by the tokenize
    module, compiled to an AST only (no code object is produced), and
    the regenerated source is written to output by Unparser.
    """
    # tokenize.open() detects the source encoding and opens the file as text.
    with tokenize.open(filename) as source_file:
        text = source_file.read()
    module_ast = compile(text, filename, "exec", ast.PyCF_ONLY_AST, dont_inherit=True)
    Unparser(module_ast, output)
def testdir(a):
try:
names = [n for n in os.listdir(a) if n.endswith(".py")]
except OSError:
print("Directory not readable: %s" % a, file=sys.stderr)
else:
for n in names:
fullname = os.path.join(a, n)
if os.path.isfile(fullname):
output = StringIO()
print("Testing %s" % fullname)
try:
roundtrip(fullname, output)
except Exception as e:
print(" Failed to compile, exception is %s" % repr(e))
elif os.path.isdir(fullname):
testdir(fullname)
def main(args):
    """Run the round-trip test on the paths given in args.

    Args:
        args: command-line arguments without the program name. If the
            first one is ``--testdir``, the remaining ones are directories
            handed to testdir(); otherwise every argument is a source
            file handed to roundtrip().

    With no arguments the module usage text is printed to stderr.
    """
    if not args:
        # Show the usage text instead of failing with IndexError on args[0].
        print(__doc__, file=sys.stderr)
        return
    if args[0] == "--testdir":
        for a in args[1:]:
            testdir(a)
    else:
        for a in args:
            roundtrip(a)
# Script entry point: pass the command-line arguments (without the program
# name) to main() when this file is executed directly.
if __name__ == "__main__":
    main(sys.argv[1:])