changeset 2:b8efe2734cc6

Lavet egentlig flow-analyse, mangler oprydning
author Sigurd Meldgaard <stm@daimi.au.dk>
date Wed, 12 Nov 2008 15:12:11 +0100
parents 10b59d5be10e
children b6dce61b934a
files flow.py
diffstat 1 files changed, 187 insertions(+), 53 deletions(-) [+]
line wrap: on
line diff
--- a/flow.py	Tue Nov 11 11:02:11 2008 +0100
+++ b/flow.py	Wed Nov 12 15:12:11 2008 +0100
@@ -54,52 +54,184 @@
         print "}"
 
 
-def flow(module):
-    function = module.body[0]
-    return flow_body(function.body, "def %s:" % function.name)
+def init_statements(node, initial, combine):
+    """
+    Prepares all statements under node for further analysis
+    """
+    global first
+    first = node.body[0].body[0]
+    first.outsecret = initial
+    for child in ast.walk(node):
+        if(isinstance(child, ast.stmt)):
+            pretty_print.pprint(first)
+            child.children = set()
+            child.parents = set()
+            child.insecret = set()
+            child.outsecret = combine(child, initial)
+
+
+def analyze(node, join, combine, initial):
+    """
+    Doing a simple forward analysis using the iterative worklist
+    algorithm
+    """
+    # initialization
+    worklist = []
+    init_statements(node, initial, combine)
+    Flow().flow(node)
+    for child in ast.walk(node):
+        if(isinstance(child, ast.stmt)):
+            worklist.append(child)
+
+    for child in ast.walk(node):
+        if isinstance(child, ast.stmt):
+            print "hep"
+            pretty_print.pprint(child)
+            print(child.outsecret)
+
+    # main cycle
+    while len(worklist) != 0:
+        print "-" * 80
+        x = worklist.pop()
+        pretty_print.pprint(x)
+        oldout = x.outsecret
+        print "oldout", oldout
+        print "join", join(x.parents)
+        x.outsecret = combine(x, join(x.parents))
+        print x.outsecret
+        if oldout != x.outsecret:
+            for child in x.children:
+                print "Adding:"
+                pretty_print.pprint(child)
+                worklist.append(child)
+                child.insecret = x.outsecret
+    print "-" * 80
+    for child in ast.walk(node):
+        if isinstance(child, ast.stmt):
+            print "hep"
+            pretty_print.pprint(child)
+            print(child.outsecret)
 
 
-def flow_body(body, msg="", exit_function=None,
-              exit_loop=None):
-    entry = Node("enter %s" % msg)
-    ex = Node("exit %s" % msg)
-    if not exit_function:
-        exit_function = ex
-    current = entry
-    for stm in body:
-        if isinstance(stm, ast.If):
-            body_g = flow_body(stm.body,
-                               msg="if(%s)" % pretty_print.expr_string(stm.test),
-                               exit_function=exit_function,
-                               exit_loop=exit_loop)
-            else_g = flow_body(stm.orelse,
-                               msg="if not(%s)" % pretty_print.expr_string(stm.test),
-                               exit_function=exit_function,
-                               exit_loop=exit_loop)
-            current.out.append(body_g.entry)
-            current.out.append(else_g.entry)
-            current = Node("endif(%s)" % pretty_print.expr_string(stm.test))
-            body_g.ex.out.append(current)
-            else_g.ex.out.append(current)
-        elif isinstance(stm, ast.While):
-            loop_done = Node()
-            body_g = flow_body(stm.body,
-                               "while(%s)" % pretty_print.expr_string(stm.test),
-                               exit_function=exit_function,
-                               exit_loop=ex)
-            body_g.ex.out.append(body_g.entry)
-            current.out.append(body_g.ex)
-            current = body_g.ex
+def expr_secret(exp, secret_variables):
+    "Returns True if the expression exp should be considered secret"
+    if(isinstance(exp, ast.BoolOp)):
+        assert len(exp.values) == 2, "Not implemented"
+        return (expr_secret(exp.values[0], secret_variables)
+                 or expr_secret(exp.values[1], secret_variables))
+    elif(isinstance(exp, ast.BinOp)):
+        return (expr_secret(exp.left, secret_variables)
+                 or expr_secret(exp.right, secret_variables))
+    elif(isinstance(exp, ast.Compare)):
+        if len(exp.ops)>1:
+            assert False, "Not implemented"
+        if len(exp.comparators)>1:
+            assert False, "Not implemented"
+            return (expr_secret(exp.left, secret_variables)
+                 or expr_secret(exp.comparators[0], secret_variables))
+    elif(isinstance(exp, ast.UnaryOp)):
+            return (expr_secret(exp.operand, secret_variables))
+    elif(isinstance(exp, ast.Lambda)):
+        assert False, "Not implemented"
+    elif(isinstance(exp, ast.IfExp)):
+        assert False, "Not implemented"
+    elif(isinstance(exp, ast.Dict)):
+        return any([expr_secret(i, secret_variables) for i in exp.values])
+    elif(isinstance(exp, ast.ListComp)):
+        assert False, "Not implemented"
+    elif(isinstance(exp, ast.Name)):
+        return exp.id in secret_variables
+    elif(isinstance(exp, ast.Num)):
+        return False
+    elif(isinstance(exp, ast.Yield)):
+        assert False, "Not implemented"
+    elif(isinstance(exp, ast.Call)):
+        if(not (isinstance(exp.func, ast.Name)
+           and eval(exp.func.id).secret)):
+            if(any([expr_secret(i, secret_variables) for i in exp.args])):
+                util.error("Call of non-secret value with secret argument")
+            else:
+                return False
+        else:
+            return True # For now
+    else:
+        assert False, "Not implemented of type %s" % type(exp)
+
+
+def secret_join(in_nodes):
+    """This is a may-analysis, so take the union"""
+    r = set()
+    for i in in_nodes:
+        r|=i.outsecret
+    return r
+
 
-        elif isinstance(stm, ast.Break):
-#            print current.msg
-            current.out.append(exit_loop)
-            return Graph(entry, Node("Dummy"))
-        elif isinstance(stm, ast.Return):
-            current.out.append(exit_function)
-            return Graph(entry, Node("Dummy"))
-    current.out.append(ex)
-    return Graph(entry, ex)
+def secret_combine(node, ins):
+    r = ins
+    global first
+    if(node is first): # Todo: Hack
+        r = node.outsecret
+    if(isinstance(node, ast.Assign)):
+        if(expr_secret(node.value, ins)):
+            r = r | set([node.targets[0].id])
+        else:
+            print "fr", r
+            r = r - set([node.targets[0].id])
+            print "efter", r
+    return r
+
+
+class Flow():
+
+    def __init__(self):
+        self.to_loop_exit = set()
+        self.to_function_exit = set()
+
+    def flow(self, module):
+        function = module.body[0]
+        current = self.flow_body(function.body, )
+
+    def make_edge(self, from_nodes, to_node):
+        for from_node in from_nodes:
+            from_node.children = from_node.children | set([to_node])
+            to_node.parents = to_node.parents | set([from_node])
+
+    def flow_body(self, body=None,
+                  entry=None):
+        if not entry:
+            entry = set()
+        current = entry
+        for stm in body:
+            self.make_edge(current, stm)
+            if isinstance(stm, ast.If):
+                out_of_then = self.flow_body(stm.body, current)
+                out_of_else = self.flow_body(stm.orelse, current)
+                current = out_of_then + out_of_else
+            elif isinstance(stm, ast.While):
+                # TODO, does not express the condition being evaluated twice
+                # We have to rewrite the expression to make it right.
+                before = current
+                out_of_while = self.flow_body(stm.body, current)
+                self.make_edge(out_of_while, stm)
+                current = before | out_of_while | self.to_loop_exit
+                self.to_loop_exit = []
+            elif isinstance(stm, ast.Break):
+                self.to_loop_exit = self.to_loop_exit | set([stm])
+                return set()
+            elif isinstance(stm, ast.Return):
+                self.to_loop_exit = self.to_loop_exit | set([stm])
+                return set()
+            elif isinstance(stm, ast.Yield):
+                assert False, "Not supported"
+            elif isinstance(stm, ast.FunctionDef):
+                assert False, "Not supported"
+            elif isinstance(stm, ast.TryExcept):
+                assert False, "Not supported"
+            elif isinstance(stm, ast.TryFinally):
+                assert False, "Not supported"
+            else:
+                current = set([stm])
+        return current
 
 
 def ideal_functionality(f):
@@ -107,12 +239,12 @@
     source_ast = ast.parse(source_str)
     FunctionFinder().visit(source_ast)
     pretty_print.PrettyPrinter().visit(source_ast)
-    t = secret_ifs.TransformIfs().visit(source_ast)
-    pretty_print.PrettyPrinter().visit(t)
+    arguments = set(i.id for i in source_ast.body[0].args.args)
+    analyze(source_ast, secret_join, secret_combine, arguments)
+#    t = secret_ifs.TransformIfs().visit(source_ast)
+#    pretty_print.PrettyPrinter().visit(t)
 
-
-#flow(test)
-flow_graph = flow(ast.parse("""
+"""
 def a(a, b, c):
     a = 2
     if c:
@@ -129,10 +261,12 @@
                 pass
     else:
         print("hej")
-"""))
 """
+@ideal_functionality
 def a(b, c):
-    while(b):
-        break
-"""
-flow_graph.to_dot()
+    d = b + 2
+    e = b + 2
+    f = 3
+    e = 2
+    f = e + 4
+#flow_graph.to_dot()