changeset 4:5f850f623c69

Major refactorings
author Sigurd Meldgaard <stm@daimi.au.dk>
date Thu, 13 Nov 2008 12:55:15 +0100
parents b6dce61b934a
children 3917d9bb1e21
files __init__.py compatibility_check.py flow.py ideal_functionality.py secret_annotator.py secret_ifs.py
diffstat 5 files changed, 276 insertions(+), 123 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/compatibility_check.py	Thu Nov 13 12:55:15 2008 +0100
@@ -0,0 +1,58 @@
+import ast
+
+from util import error
+
+"""
+We only want to treat a small subset of Python in our
+analysis.
+
+Therefore we run it though this Visitor ruling out any
+constructs that we do not treat yet.
+
+In all other phases we assume that the functions treated are confined
+to this subset.
+"""
+
+
+class CompatibilityChecker(ast.NodeVisitor):
+
+    def __init__(self):
+        self.inside_function_def = False
+
+    def visit_BoolOp(self, node):
+        if not len(node.values) == 2:
+            error("Too many arguments", node)
+
+    def visit_Compare(self, node):
+        if len(node.ops)>1:
+            error("Too many arguments", node)
+        if len(node.comparators)>1:
+            error("Too many arguments", node)
+
+    def visit_Lambda(self, node):
+        error("Not implemented", node)
+
+    def visit_IfExp(self, node):
+        error("Not implemented", node)
+
+    def visit_ListComp(self, node):
+        error("Not implemented", node)
+
+    def visit_Yield(self, node):
+        error("Not implemented", node)
+
+    def visit_FunctionDef(self, node):
+        if self.inside_function_def:
+            error("Nested functions not allowed", node)
+        self.inside_function_def = True
+        self.generic_visit(node)
+        self.inside_function_def = False
+
+    def visit_TryExcept(self, node):
+        error("Not implemented", node)
+
+    def visit_TryFinally(self, node):
+        error("Not implemented", node)
+
+    def visit_Delete(self, node):
+        error("Not implemented", node)
--- a/flow.py	Wed Nov 12 18:21:31 2008 +0100
+++ b/flow.py	Thu Nov 13 12:55:15 2008 +0100
@@ -4,65 +4,31 @@
 import secret_ifs
 
 
-
-
-def init_statements(node, initial, combine):
-    """
-    Prepares all statements under node for further analysis
-    """
-    global first
-    first = node.body[0].body[0]
-    first.outsecret = initial
-    for child in ast.walk(node):
-        if(isinstance(child, ast.stmt)):
-            pretty_print.pprint(first)
-            child.children = set()
-            child.parents = set()
-            child.insecret = set()
-            child.outsecret = combine(child, initial)
-
-
-def analyze(node, join, combine, initial):
+def analyze(node, join, combine, key):
     """
     Doing a simple forward analysis using the iterative worklist
-    algorithm
+    algorithm. Parametrized by the functions join and combine, and
+    with the value initial.  The lattice-points are stored in each
+    statement-node in_values and out_values, under the *key*.
     """
     # initialization
     worklist = []
-    init_statements(node, initial, combine)
+
     Flow().flow(node)
     for child in ast.walk(node):
         if(isinstance(child, ast.stmt)):
+            child.in_values[key] = combine(child, set())
             worklist.append(child)
 
-    for child in ast.walk(node):
-        if isinstance(child, ast.stmt):
-            print "hep"
-            pretty_print.pprint(child)
-            print(child.outsecret)
-
     # main cycle
     while len(worklist) != 0:
-        print "-" * 80
         x = worklist.pop()
-        pretty_print.pprint(x)
-        oldout = x.outsecret
-        print "oldout", oldout
-        print "join", join(x.parents)
-        x.outsecret = combine(x, join(x.parents))
-        print x.outsecret
-        if oldout != x.outsecret:
+        oldout = x.out_values[key]
+        x.out_values[key] = combine(x, join(x.parents))
+        if oldout != x.out_values[key]:
             for child in x.children:
-                print "Adding:"
-                pretty_print.pprint(child)
                 worklist.append(child)
-                child.insecret = x.outsecret
-    print "-" * 80
-    for child in ast.walk(node):
-        if isinstance(child, ast.stmt):
-            print "hep"
-            pretty_print.pprint(child)
-            print(child.outsecret)
+                child.in_values[key] = x.out_values[key]
 
 
 class Flow():
@@ -71,9 +37,11 @@
         self.to_loop_exit = set()
         self.to_function_exit = set()
 
-    def flow(self, module):
-        function = module.body[0]
-        current = self.flow_body(function.body, )
+    def flow(self, function):
+        #function.body.append(ast.Pass)
+        self.flow_body(function.body)
+        # Todo - edges from function returns...
+        # Neccesary?
 
     def make_edge(self, from_nodes, to_node):
         for from_node in from_nodes:
@@ -90,7 +58,7 @@
             if isinstance(stm, ast.If):
                 out_of_then = self.flow_body(stm.body, current)
                 out_of_else = self.flow_body(stm.orelse, current)
-                current = out_of_then + out_of_else
+                current = out_of_then | out_of_else
             elif isinstance(stm, ast.While):
                 # TODO, does not express the condition being evaluated twice
                 # We have to rewrite the expression to make it right.
@@ -98,59 +66,15 @@
                 out_of_while = self.flow_body(stm.body, current)
                 self.make_edge(out_of_while, stm)
                 current = before | out_of_while | self.to_loop_exit
-                self.to_loop_exit = []
+                self.to_loop_exit = set()
             elif isinstance(stm, ast.Break):
                 self.to_loop_exit = self.to_loop_exit | set([stm])
                 return set()
             elif isinstance(stm, ast.Return):
                 self.to_loop_exit = self.to_loop_exit | set([stm])
                 return set()
-            elif isinstance(stm, ast.Yield):
-                assert False, "Not supported"
-            elif isinstance(stm, ast.FunctionDef):
-                assert False, "Not supported"
-            elif isinstance(stm, ast.TryExcept):
-                assert False, "Not supported"
-            elif isinstance(stm, ast.TryFinally):
-                assert False, "Not supported"
             else:
                 current = set([stm])
         return current
 
 
-def ideal_functionality(f):
-    source_str = inspect.getsource(f)
-    source_ast = ast.parse(source_str)
-    FunctionFinder().visit(source_ast)
-    pretty_print.PrettyPrinter().visit(source_ast)
-    arguments = set(i.id for i in source_ast.body[0].args.args)
-    analyze(source_ast, secret_join, secret_combine, arguments)
-#    t = secret_ifs.TransformIfs().visit(source_ast)
-#    pretty_print.PrettyPrinter().visit(t)
-
-"""
-def a(a, b, c):
-    a = 2
-    if c:
-        b = 2
-        while True:
-            d = 3
-            if a:
-                break
-            else:
-                pass
-            if b:
-                return
-            else:
-                pass
-    else:
-        print("hej")
-"""
-@ideal_functionality
-def a(b, c):
-    d = b + 2
-    e = b + 2
-    f = 3
-    e = 2
-    f = e + 4
-#flow_graph.to_dot()
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/ideal_functionality.py	Thu Nov 13 12:55:15 2008 +0100
@@ -0,0 +1,75 @@
+import inspect
+import ast
+from collections import defaultdict
+
+import pretty_print
+import secret_annotator
+import secret_ifs
+import compatibility_check
+
+
+def init_statements(node):
+    """
+    Prepares all statements under node for further analysis.
+    """
+    for child in ast.walk(node):
+        if(isinstance(child, ast.stmt)):
+            # children and parents are used for flow-graph edges
+            child.children = set()
+            child.parents = set()
+            # in_values and out_values are used for storing
+            # lattice-points in for static analyses using the monotone
+            # framework.
+            child.in_values = defaultdict(set)
+            child.out_values = defaultdict(set)
+
+
+def ideal_functionality(f):
+    source_str = inspect.getsource(f)
+    source_ast = ast.parse(source_str) # Returns a module
+    function_ast = source_ast.body[0] # We want only the function
+    init_statements(function_ast)
+    # We don't want recursive applications of this decorator
+    #function_ast.decorator_list.remove[0]
+    compatibility_check.CompatibilityChecker().visit(function_ast)
+    secret_annotator.secret_analysis(function_ast)
+    t = secret_ifs.TransformIfs().visit(function_ast)
+    pretty_print.PrettyPrinter().visit(t)
+
+
+"""
+def a(a, b, c):
+    a = 2
+    if c:
+        b = 2
+        while True:
+            d = 3
+            if a:
+                break
+            else:
+                pass
+            if b:
+                return
+            else:
+                pass
+    else:
+        print("hej")
+"""
+
+
+@ideal_functionality
+def fun(b, c):
+    d = b + 2
+    if d:
+        e = 2
+    a = 2
+    mark_secret(a)
+    if a:
+        e = 4
+    else:
+        e = 8
+    f = 3
+    if f:
+        e = 2
+    f = e + 4
+#flow_graph.to_dot()
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/secret_annotator.py	Thu Nov 13 12:55:15 2008 +0100
@@ -0,0 +1,80 @@
+import util
+import ast
+import flow
+
+
+def expr_secret(exp, secret_variables):
+    """
+    Returns True if the expression exp should be considered secret
+    given that the variables in *secret_variables* are secret.
+    """
+    if(isinstance(exp, ast.BoolOp)):
+        return (expr_secret(exp.values[0], secret_variables)
+                 or expr_secret(exp.values[1], secret_variables))
+    elif(isinstance(exp, ast.BinOp)):
+        return (expr_secret(exp.left, secret_variables)
+                 or expr_secret(exp.right, secret_variables))
+    elif(isinstance(exp, ast.Compare)):
+        return (expr_secret(exp.left, secret_variables)
+                or expr_secret(exp.comparators[0], secret_variables))
+    elif(isinstance(exp, ast.UnaryOp)):
+            return (expr_secret(exp.operand, secret_variables))
+    elif(isinstance(exp, ast.Dict)):
+        return any([expr_secret(i, secret_variables) for i in exp.values])
+    elif(isinstance(exp, ast.Name)):
+        return exp.id in secret_variables
+    elif(isinstance(exp, ast.Num)):
+        return False
+    elif(isinstance(exp, ast.Call)):
+        if(not (isinstance(exp.func, ast.Name)
+           and eval(exp.func.id).secret)):
+            if(any([expr_secret(i, secret_variables) for i in exp.args])):
+                util.error("Call of non-secret value with secret argument",
+                           exp)
+            else:
+                return False
+        else:
+            return True # For now
+    else:
+        assert False, "Not implemented of type %s" % type(exp)
+
+
+def secret_join(in_nodes):
+    """This is a may-analysis, so take the union"""
+    r = set()
+    for i in in_nodes:
+        r|=i.out_values["secret"]
+    return r
+
+
+def secret_combine(node, ins):
+    """
+    If this is an assignment, check if the computed expression is
+    secret, if that is the case, the assigned value is secret as
+    well. Otherwise it becomes non-secret.
+
+    All other statements just pass their value through.
+    """
+    ins
+    if(getattr(node, "imported_secrets", False)):
+        ins = ins | node.imported_secrets
+    if(isinstance(node, ast.Assign)):
+        if(expr_secret(node.value, ins)):
+            ins = ins | set([node.targets[0].id])
+        else:
+            print "fr", ins
+            ins = ins - set([node.targets[0].id])
+            print "efter", ins
+    if(isinstance(node, ast.Expr) and
+       isinstance(node.value, ast.Call) and
+       isinstance(node.value.func, ast.Name) and
+       node.value.func.id == "mark_secret"):
+        # Todo - handle non-names...
+        ins = ins | set([i.id for i in node.value.args])
+    return r
+
+
+def secret_analysis(function):
+    arguments = set(i.id for i in function.args.args)
+    function.body[0].imported_secrets = arguments
+    flow.analyze(function, secret_join, secret_combine, "secret")
--- a/secret_ifs.py	Wed Nov 12 18:21:31 2008 +0100
+++ b/secret_ifs.py	Thu Nov 13 12:55:15 2008 +0100
@@ -1,11 +1,13 @@
 """
-Defines an ast-visitor *TransformIfs* finding *if*s branching on
+Defines an ast-transformer *TransformIfs* finding *if*s branching on
 secret values, and transforming them into equivalent series of
 assignments
 """
 
 import ast
-from static.util import error
+
+from util import error
+import secret_annotator
 
 
 class Assignments(ast.NodeVisitor):
@@ -33,7 +35,8 @@
         if not (isinstance(i, ast.Assign) or
                 isinstance(i, ast.Pass)):
             error("Error,\
- inside secret ifs, we only allow assignments, but found %s" % type(i),
+ inside secret if statements, we only allow \
+assignments, but found a %s." % type(i),
                   i)
 
 
@@ -53,31 +56,44 @@
     """
     cond_counter = 0
 
+    def __init__(self):
+        self.changed = False
+
+    def reset(self):
+        self.changed = False
+
     def visit_If(self, node):
-        #if node.test is secret:
-        only_assignments(node.body)
-        assigned_then = get_assignments(node.body)
-        assigned_else = get_assignments(node.orelse)
-        all_assigned = set(assigned_then.keys() + assigned_else.keys())
-        r = []
-        condname = ast.Name(id="cond%d" % self.cond_counter)
-        self.cond_counter += 1
-        r.append(ast.Assign(targets=[condname],
-                            value=node.test))
-        for i in all_assigned:
-            then_value = assigned_then.get(i, ast.Name(id=i))
-            else_value = assigned_else.get(i, ast.Name(id=i))
-            r.append(ast.Assign(targets=[ast.Name(id=i)], value=
-                                ast.BinOp(left=
-                                          ast.BinOp(left=condname,
-                                                    op=ast.Mult(),
-                                                    right=then_value),
-                                          op=ast.Add(),
-                                          right=
-                                          ast.BinOp(left=
-                                                    ast.BinOp(left=ast.Num(1),
-                                                              op=ast.Sub(),
-                                                              right=condname),
-                                                    op=ast.Mult(),
-                                                    right=else_value))))
-        return r # A list of statements will be merged into the list
+        print "-"*80, node.in_values["secret"]
+        if(secret_annotator.expr_secret(node.test, node.in_values["secret"])):
+            self.changed = True
+
+            only_assignments(node.body)
+            assigned_then = get_assignments(node.body)
+            assigned_else = get_assignments(node.orelse)
+            all_assigned = set(assigned_then.keys() + assigned_else.keys())
+            r = []
+            condname = ast.Name(id="cond%d" % self.cond_counter)
+            self.cond_counter += 1
+            r.append(ast.Assign(targets=[condname],
+                                value=node.test))
+            for i in all_assigned:
+                then_value = assigned_then.get(i, ast.Name(id=i))
+                else_value = assigned_else.get(i, ast.Name(id=i))
+                r.append(ast.Assign(targets=[ast.Name(id=i)], value=
+                                    ast.BinOp(left=
+                                              ast.BinOp(left=condname,
+                                                        op=ast.Mult(),
+                                                        right=then_value),
+                                              op=ast.Add(),
+                                              right=
+                                              ast.BinOp(left=
+                                                        ast.BinOp(left=
+                                                                  ast.Num(1),
+                                                                  op=ast.Sub(),
+                                                                  right=
+                                                                  condname),
+                                                        op=ast.Mult(),
+                                                        right=else_value))))
+            return r # A list of statements will be merged into the list
+        else:
+            return node