changeset 0:3107bd7ca4af

Initial import
author Sigurd Meldgaard <stm@daimi.au.dk>
date Tue, 11 Nov 2008 10:12:17 +0100
parents
children 10b59d5be10e
files flow.py pretty_print.py secret_ifs.py test.py util.py
diffstat 5 files changed, 600 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/flow.py	Tue Nov 11 10:12:17 2008 +0100
@@ -0,0 +1,138 @@
+import inspect
+import ast
+import pretty_print
+import secret_ifs
+
+
+class FunctionFinder(ast.NodeVisitor):
+
+    def visit_FunctionDef(self, node):
+        self.generic_visit(node)
+        print node.name
+
+
+class Node:
+    count = 0
+
+    def __init__(self, msg = ""):
+        self.out=[]
+        Node.count += 1
+        self.nr = str(Node.count)
+        self.msg = msg
+
+
+class Graph:
+
+    def __init__(self, entry, ex):
+        self.entry = entry
+        self.ex = ex
+
+    def to_dot(self):
+        taken = set()
+        print "digraph G {"
+        print "  in -> %s" % self.entry.nr
+        print '  in [shape = plaintext, label=""]'
+        print "  %s -> out" % self.ex.nr
+        print '  out [shape = plaintext, label=""]'
+        stack = [self.entry]
+        while(True):
+            while(True):
+                if len(stack)==0:
+                    current = None
+                    break
+                current = stack.pop()
+                if current not in taken:
+                    break
+            if not current:
+                break
+            print '    %s [label="%s"]' % (current.nr, current.msg)
+            taken.add(current)
+            for i in current.out:
+                print "    %s->%s;" % (current.nr, i.nr)
+
+            stack += current.out # list concatenation
+        print "}"
+
+
+def flow(module):
+    function = module.body[0]
+    return flow_body(function.body, "def %s:" % function.name)
+
+
+def flow_body(body, msg="", exit_function=None,
+              exit_loop=None):
+    entry = Node("enter %s" % msg)
+    ex = Node("exit %s" % msg)
+    if not exit_function:
+        exit_function = ex
+    current = entry
+    for stm in body:
+        if isinstance(stm, ast.If):
+            body_g = flow_body(stm.body,
+                               msg="if(%s)" % pretty_print.expr_string(stm.test),
+                               exit_function=exit_function,
+                               exit_loop=exit_loop)
+            else_g = flow_body(stm.orelse,
+                               msg="if not(%s)" % pretty_print.expr_string(stm.test),
+                               exit_function=exit_function,
+                               exit_loop=exit_loop)
+            current.out.append(body_g.entry)
+            current.out.append(else_g.entry)
+            current = Node("endif(%s)" % pretty_print.expr_string(stm.test))
+            body_g.ex.out.append(current)
+            else_g.ex.out.append(current)
+        elif isinstance(stm, ast.While):
+            loop_done = Node()
+            body_g = flow_body(stm.body,
+                               "while(%s)" % pretty_print.expr_string(stm.test),
+                               exit_function=exit_function,
+                               exit_loop=ex)
+            body_g.ex.out.append(body_g.entry)
+            current.out.append(body_g.ex)
+            current = body_g.ex
+
+        elif isinstance(stm, ast.Break):
+#            print current.msg
+            current.out.append(exit_loop)
+            return Graph(entry, Node("Dummy"))
+        elif isinstance(stm, ast.Return):
+            current.out.append(exit_function)
+            return Graph(entry, Node("Dummy"))
+    current.out.append(ex)
+    return Graph(entry, ex)
+
+
+def ideal_functionality(f):
+    source_str = inspect.getsource(f)
+    source_ast = ast.parse(source_str)
+    FunctionFinder().visit(source_ast)
+    pretty_print.PrettyPrinter().visit(source_ast)
+    t = secret_ifs.TransformIfs().visit(source_ast)
+    pretty_print.PrettyPrinter().visit(t)
+
+
+#flow(test)
+flow_graph = flow(ast.parse("""
+def a(a, b, c):
+    a = 2
+    if c:
+        b = 2
+        while True:
+            d = 3
+            if a:
+                break
+            else:
+                pass
+            if b:
+                return
+            else:
+                pass
+    else:
+        print("hej")
+"""))
+"""
+def a(b, c):
+    while(b):
+        break
+"""
+flow_graph.to_dot()
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pretty_print.py	Tue Nov 11 10:12:17 2008 +0100
@@ -0,0 +1,339 @@
+"""
+Pretty printing of the Python built-in ast-type.
+
+The goal is to print a program with the exact same semantics as the
+program parsed into the ast.
+"""
+
+import ast
+
+
+def arguments_string(arguments):
+    """ Returns a string representing the argument list, surrounded by
+    parenthesis.
+    """
+    if(arguments.vararg):
+        assert False, "Not implemented"
+    if(arguments.kwarg):
+        assert False, "Not implemented"
+    args = ", ".join([i.id for i in arguments.args])
+    return "(" + args + ")"
+
+
+def op_string(op):
+    """ Returns a string representing the operator *op*.
+    """
+    ops = {ast.Add: "+",
+           ast.Sub: "-",
+           ast.Mult: "*",
+           ast.Div: "/",
+           ast.Mod: "%",
+           ast.Pow: "**",
+           ast.LShift: "<<",
+           ast.RShift: ">>",
+           ast.BitOr: "|",
+           ast.BitXor: "^",
+           ast.BitAnd: "&",
+           ast.FloorDiv: "//",
+           ast.Invert: "~",
+           ast.Not: "not",
+           ast.UAdd: "+",
+           ast.USub: "-",
+           ast.Eq: "==",
+           ast.NotEq: "!=",
+           ast.Lt: "<",
+           ast.LtE: "<=",
+           ast.Gt: ">",
+           ast.GtE: ">=",
+           ast.Is: "is",
+           ast.IsNot: "is not",
+           ast.In: "in",
+           ast.NotIn: "not in",
+           ast.And: "and",
+           ast.Or: "or",
+           }
+    assert op.__class__ in ops, "Not implemented %s" % op
+    return ops[op.__class__]
+
+
+def op_precedence(op):
+    """ Returns the operator precedence of *op*
+    Made after the reference found here:
+    http://www.ibiblio.org/g2swap/byteofpython/read/operator-precedence.html
+    """
+
+    ops = {
+        ast.Or: 2,
+        ast.And: 3,
+        ast.Not: 4,
+        ast.In: 5,
+        ast.NotIn: 5,
+        ast.Is: 6,
+        ast.IsNot: 6,
+        ast.NotEq: 7,
+        ast.Lt: 7,
+        ast.LtE: 7,
+        ast.Gt: 7,
+        ast.GtE: 7,
+        ast.Eq: 7,
+        ast.BitOr: 8,
+        ast.BitXor: 9,
+        ast.BitAnd: 10,
+        ast.LShift: 11,
+        ast.RShift: 11,
+        ast.Add: 12,
+        ast.Sub: 12,
+        ast.Mult: 13,
+        ast.Div: 13,
+        ast.Mod: 13,
+        ast.FloorDiv: 13,
+        ast.UAdd: 14,
+        ast.USub: 14,
+        ast.Invert: 15,
+        ast.Pow: 16,
+           }
+    return ops[op.__class__]
+
+
+def expr_string(exp, prec=1):
+    """Returns a string representation of the expression.
+    *prec* is the precedence of the surrounding of the expression,
+    and determines if parenthesis are neccesary.
+    """
+    if hasattr(exp, "op"):
+        my_precedence = op_precedence(exp.op)
+    if(isinstance(exp, ast.BoolOp)):
+        assert len(exp.values) == 2, "Not implemented"
+
+        r = "%s %s %s" % \
+                 (expr_string(exp.values[0], my_precedence),
+                  op_string(exp.op),
+                  expr_string(exp.values[1]), my_precedence),
+    elif(isinstance(exp, ast.BinOp)):
+        r = "%s %s %s" % \
+                    (expr_string(exp.left, my_precedence),
+                     op_string(exp.op),
+                     expr_string(exp.right, my_precedence))
+    elif(isinstance(exp, ast.Compare)):
+        my_precedence = op_precedence(exp.ops[0])
+        if len(exp.ops)>1:
+            assert False, "Not implemented"
+        if len(exp.comparators)>1:
+            assert False, "Not implemented"
+        r = "%s %s %s" % \
+            (expr_string(exp.left), op_string(exp.ops[0]),
+             expr_string(exp.comparators[0]))
+    elif(isinstance(exp, ast.UnaryOp)):
+        r = " %s (%s)" % \
+            (op_string(exp.op), expr_string(exp.operand),
+             op_precedence(exp.op))
+    elif(isinstance(exp, ast.Lambda)):
+        my_precedence = 1
+        r = "lambda%s : %s" % \
+            (arguments_string(exp.args),
+             expr_string(exp.body, my_precedence))
+    elif(isinstance(exp, ast.IfExp)):
+        assert False, "Not implemented"
+    elif(isinstance(exp, ast.Dict)):
+        my_precedence = 25
+        r = "{"+" ".join(["%s : %s" % (expr_string(i),
+                                       expr_string(j)) for (i, j) in
+                          zip(exp.keys, exp.values)])+"}"
+    elif(isinstance(exp, ast.ListComp)):
+        assert False, "Not implemented"
+    elif(isinstance(exp, ast.Name)):
+        my_precedence = 100
+        r = exp.id
+    elif(isinstance(exp, ast.Num)):
+        my_precedence = 100
+        r = str(exp.n)
+    elif(isinstance(exp, ast.Yield)):
+        my_precedence = 22 # Binds like a function call, Is this correct?
+        r = "yield(%s)" % expr_string(exp.value)
+    elif(isinstance(exp, ast.Call)):
+        my_precedence = 22
+        args_list = []
+        if len(exp.args)>0:
+            args_list.append(", ".join([expr_string(i, 1)
+                                        for i in exp.args]))
+        if len(exp.keywords)>0:
+            args_list.append(", ".join([i.arg +
+                                        "=" + expr_string(i.value, 1)
+                                        for i in exp.keywords]))
+        if exp.starargs:
+            args_list.append("*"+expr_string(exp.starargs, 1))
+        if exp.kwargs:
+            args_list.append("**"+expr_string(exp.kwargs, 1))
+        r = "%s(%s)" % \
+            (expr_string(exp.func, my_precedence),
+             ", ".join(args_list))
+    elif(isinstance(exp, dict)):
+        print exp
+        assert False, "Type error"
+    else:
+        assert False, "Not implemented of type %s" % type(exp)
+
+    if my_precedence <= prec:
+        return "(" + r + ")"
+    else:
+        return r
+    # TODO:
+#             | GeneratorExp(expr elt, comprehension* generators)
+#             | Repr(expr value)
+#             | Num(object n) -- a number as a PyObject.
+#             | Str(string s) -- need to specify raw, unicode, etc?
+#             -- other literals? bools?
+#
+#             -- the following expression can appear in assignment context
+#             | Attribute(expr value, identifier attr, expr_context ctx)
+#             | Subscript(expr value, slice slice, expr_context ctx)
+#             | Name(identifier id, expr_context ctx)
+#             | List(expr* elts, expr_context ctx)
+#             | Tuple(expr* elts, expr_context ctx)
+#
+#              attributes (int lineno, int col_offset)
+#
+#        expr_context = Load | Store | Del | AugLoad | AugStore | Param
+#
+#        slice = Ellipsis | Slice(expr? lower, expr? upper, expr? step)
+#              | ExtSlice(slice* dims)
+#              | Index(expr value)
+#
+#
+#        comprehension = (expr target, expr iter, expr* ifs)
+#
+#        -- not sure what to call the first argument for raise and except
+#        excepthandler = ExceptHandler(expr? type, expr? name, stmt* body)
+#                        attributes (int lineno, int col_offset)
+#
+#        arguments = (expr* args, identifier? vararg,
+#                    identifier? kwarg, expr* defaults)
+#
+#        -- keyword arguments supplied to call
+#        keyword = (identifier arg, expr value)
+#
+#        -- import name with optional 'as' alias.
+#        alias = (identifier name, identifier? asname)
+
+
+class PrettyPrinter(ast.NodeVisitor):
+    indent = 0
+
+    def print_body(self, body):
+        self.indent += 1
+        for i in body:
+            self.visit(i)
+        self.indent -= 1
+
+    def print_indented(self, str):
+        print (" "*(self.indent*4))+str
+
+    def visit_ClassDef(self, node):
+        for i in node.decorator_list:
+            self.print_indented("@%s" % expr_string(i))
+        self.print_indented("class %s(%s)" % (node.name,
+                                              [i.name for i in node.bases]))
+        self.print_body(node.body)
+
+    def visit_Return(self, node):
+        self.print_indented("return %s" % expr_string(node.value))
+
+    def visit_Delete(self, node):
+        assert False, "Not implemented"
+
+    def visit_Assign(self, node):
+        assert len(node.targets) == 1, "Not implemented"
+        self.print_indented("%s = %s" % (expr_string(node.targets[0]),
+                                         expr_string(node.value)))
+
+    def visit_AugAssign(self, node):
+        self.print_indented("%s %s= %s" % (expr_string(node.target),
+                                           op_string(node.op),
+                                           expr_string(node.value)))
+
+    def visit_Print(self, node):
+        if node.dest or not node.nl:
+            assert False, "Not implemented"
+        else:
+            self.print_indented("print %s" % \
+                               (", ".join([expr_string(i)
+                                           for i in node.values])))
+
+    def visit_For(self, node):
+        self.print_indented("for %s in %s:" % (expr_string(node.target),
+                                              expr_string(node.iter)))
+        self.print_body(node.body)
+        if(len(node.orelse)>0):
+            self.print_indented("else:")
+            self.print_body(node.orelse)
+
+    def visit_While(self, node):
+        self.print_indented("while(%s):" % expr_string(node.test))
+        self.print_body(node.body)
+        if(len(node.orelse)>0):
+            self.print_indented("else:")
+            self.print_body(node.orelse)
+
+    def visit_If(self, node):
+        self.print_indented("if(%s):" % expr_string(node.test))
+        self.print_body(node.body)
+        if(len(node.orelse)>0):
+            self.print_indented("else:")
+            self.print_body(node.orelse)
+
+    def visit_With(self, node):
+        assert False, "Not implemented"
+
+    def visit_Raise(self, node):
+        assert False, "Not implemented"
+
+    def visit_TryExcept(self, node):
+        assert False, "Not implemented"
+
+    def visit_TryFinally(self, node):
+        assert False, "Not implemented"
+
+    def visit_Assert(self, node):
+        if(node.msg):
+            self.print_indented("assert %s : %s", (expr_string(node.test),
+                                                   expr_string(node.msg)))
+        else:
+            self.print_indented("assert %s", expr_string(node.test))
+
+    def visit_Import(self, node):
+        assert False, "Not implemented"
+
+    def visit_ImportFrom(self, node):
+        assert False, "Not implemented"
+
+    def visit_Exec(self, node):
+        assert False, "Not implemented"
+
+    def visit_Global(self, node):
+        self.print_indented("global " + ", ".join(node.names))
+
+    def visit_Expr(self, node):
+        self.print_indented(expr_string(node.value))
+
+    def visit_Pass(self, node):
+        self.print_indented("pass")
+
+    def visit_Break(self, node):
+        self.print_indented("break")
+
+    def visit_Continue(self, node):
+        self.print_indented("continue")
+
+    def visit_FunctionDef(self, node):
+        for i in node.decorator_list:
+            self.print_indented("@%s" % expr_string(i))
+        self.print_indented("def %s%s:" % (node.name,
+                                             arguments_string(node.args)))
+        self.print_body(node.body)
+
+
+def pprint(module):
+    """Pretty prints the module represented by the ast node in
+    ***module***
+    """
+    PrettyPrinter().visit(module)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/secret_ifs.py	Tue Nov 11 10:12:17 2008 +0100
@@ -0,0 +1,82 @@
+"""
+Defines an ast-visitor *TransformIfs* finding *if*s branching on
+secret values, and transforming them into equivalent series of
+assignments
+"""
+
+import ast
+from util import error
+
+
+class Assignments(ast.NodeVisitor):
+
+    def __init__(self):
+        self.assigned={}
+
+    def visit_Assign(self, node):
+        varname = node.targets[0].id
+        if varname in self.assigned:
+            error("Error: cannot assign several times"+
+                  "to \"%s\" in a secret if" % varname, node)
+        ## TODO only single-target assignment
+        self.assigned[varname] = node.value
+
+
+def get_assignments(body):
+    visitor = Assignments()
+    [visitor.visit(i) for i in body]
+    return visitor.assigned
+
+
+def only_assignments(body):
+    for i in body:
+        if not (isinstance(i, ast.Assign) or
+                isinstance(i, ast.Pass)):
+            error("Error,\
+ inside secret ifs, we only allow assignments, but found %s" % type(i),
+                  i)
+
+
+class TransformIfs(ast.NodeTransformer):
+    """
+    >>> from ast import *
+    >>> from pretty_print import *
+    >>> prog = parse("def f(x):\n\tif(x):\n\t\ta=1\n\telse:\n\t\ta=2")
+    >>> TransfromIfs().visit(prog)
+    >>> pprint(prog)
+    def f(x):
+        cond0 = x
+        a = cond0 * 1 + (1 - cond0) * 2
+    """
+    cond_counter = 0
+
+    def visit_If(self, node):
+        #if node.test is secret:
+        only_assignments(node.body)
+        assigned_then = get_assignments(node.body)
+        assigned_else = get_assignments(node.orelse)
+        print assigned_then
+        print assigned_else
+        all_assigned = set(assigned_then.keys() + assigned_else.keys())
+        r = []
+        condname = ast.Name(id="cond%d" % self.cond_counter)
+        self.cond_counter += 1
+        r.append(ast.Assign(targets=[condname],
+                            value=node.test))
+        for i in all_assigned:
+            then_value = assigned_then.get(i, ast.Name(id=i))
+            else_value = assigned_else.get(i, ast.Name(id=i))
+            r.append(ast.Assign(targets=[ast.Name(id=i)], value=
+                                ast.BinOp(left=
+                                          ast.BinOp(left=condname,
+                                                    op=ast.Mult(),
+                                                    right=then_value),
+                                          op=ast.Add(),
+                                          right=
+                                          ast.BinOp(left=
+                                                    ast.BinOp(left=ast.Num(1),
+                                                              op=ast.Sub(),
+                                                              right=condname),
+                                                    op=ast.Mult(),
+                                                    right=else_value))))
+        return r # A list of statements will be merged into the list
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test.py	Tue Nov 11 10:12:17 2008 +0100
@@ -0,0 +1,30 @@
+import flow
+
+
+def blah(f):
+    return f
+
+
+def test(x, y, z):
+
+    @blah
+    def inner(y):
+        print x+y
+        if(x):
+            k = 9
+        else:
+            k = 7
+            b = y
+        y = 2
+    x += 2
+    r = 3
+    return r
+
+#@flow
+def a(c):
+    y = 4
+    if(c>c):
+        pass
+    else:
+        y = 5
+        y = 3
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/util.py	Tue Nov 11 10:12:17 2008 +0100
@@ -0,0 +1,11 @@
+""" Utilities for the static analysis tool
+"""
+
+
+def error(msg, loc=None):
+    """Print the error message.
+    *loc* is an ast-node at the given location."""
+    loc_tuple = (loc.lineno, loc.col_offset)
+    if loc:
+        print "at %s" % (loc_tuple, ),
+    print msg