changeset 158:db7a0ad210df

secret_ifs: improved handling of secret ifs. Now works on nested ifs.
author Sigurd Meldgaard <stm@daimi.au.dk>
date Mon, 07 Dec 2009 14:20:15 +0100
parents cba610df903a
children 4aba2a1b340e
files pysmcl/ast_wrapper.py pysmcl/secret_ifs.py
diffstat 2 files changed, 113 insertions(+), 78 deletions(-) [+]
line wrap: on
line diff
--- a/pysmcl/ast_wrapper.py	Fri Dec 04 11:25:07 2009 +0100
+++ b/pysmcl/ast_wrapper.py	Mon Dec 07 14:20:15 2009 +0100
@@ -6,14 +6,15 @@
 Ensures that nodes of parse-trees have parent pointers.
 """
 
+def fix_ast_parents(n):
+    for i in ast.iter_child_nodes(n):
+        i.parent = n
+        fix_ast_parents(i)
+
+
 def parse(f):
-    def make_ast_parents(n):
-        for i in ast.iter_child_nodes(n):
-            i.parent = n
-            make_ast_parents(i)
-
     m = ast.parse(f)
-    make_ast_parents(m)
+    fix_ast_parents(m)
     return m
 
 def get_ancestor(node, type):
--- a/pysmcl/secret_ifs.py	Fri Dec 04 11:25:07 2009 +0100
+++ b/pysmcl/secret_ifs.py	Mon Dec 07 14:20:15 2009 +0100
@@ -4,13 +4,13 @@
 assignments
 """
 
-import pysmcl.ast_wrapper as ast
+import copy
 
 # Import PySMCL packages.
+import pysmcl.ast_wrapper as ast
 from pysmcl.util import error
 import pysmcl.secret_annotator
 
-
 class Assignments(ast.NodeVisitor):
 
     def __init__(self):
@@ -31,6 +31,16 @@
     return visitor.assigned
 
 
+def is_syntetic(statement):
+    """Temporary values are not used later, so they do not have to be
+    computed in both paths of a sourounding if.  Therefore we set a
+    .syntetic flag, it can be checked with this function"""
+    try:
+        return statement.syntetic
+    except AttributeError:
+        return False
+
+
 def only_assignments(body):
     for i in body:
         if not isinstance(i, (ast.Assign, ast.Pass)):
@@ -40,89 +50,113 @@
 
 
 class TransformIfs(ast.NodeTransformer):
-    cond_counter = 0
 
     def __init__(self):
-        self.changed = False
+        self.reset()
 
     def reset(self):
+        # As more things can be secret after the transformation, and
+        # we want to find a fixpoint, we keep track of wether the
+        # transformation triggered:
+        self.name_counts = {}
         self.changed = False
+        suffixes = [""]
+        env = []
+        
+        
+    def make_name(self, base_name):
+        """
+        Returns a fresh name, as a string.
+        """
+        a = self.name_counts[base_name] if base_name in self.name_counts else 0
+        name = base_name + str(a)
+        self.name_counts[base_name] = a+1
+        return name
 
     def visit_While(self, node):
-        if(pysmcl.secret_annotator.expr_secret(
-                node.test, node.in_values["secret"])):
+        if(pysmcl.secret_annotator.expr_secret(node.test)):
             error("While is not possible on a secret value", node)
         return node
 
-    def visit_If(self, node):
-        if(pysmcl.secret_annotator.expr_secret(node.test, node.in_values["secret"])):
-            self.changed = True
+    def do_body(self, body, suffix):
+        repl = {}
+        for stm in body:
+            if isinstance(stm, ast.Assign):
+                if not is_syntetic(stm):
+                    new_name = self.make_name(stm.targets[0].id + suffix)
+                    stm.syntetic = True
+                    repl[stm.targets[0].id] = new_name
+                    stm.targets[0].id = new_name
+                    for v in ast.walk(stm):
+                        if not isinstance(v, ast.Name):
+                            continue
+                        if (not (isinstance(v.parent, ast.Assign) and v.parent.targets[0] is v)
+                            and v.id in repl):
+                            v.id = repl[v.id]
+            else:
+                error("Non assignment within secret if", stm)
+        return repl
 
-            only_assignments(node.body)
-            assigned_then = get_assignments(node.body)
-            assigned_else = get_assignments(node.orelse)
-            all_assigned = set(assigned_then.keys() + assigned_else.keys())
-            r = []
-            # Todo: correct gensyms
-            condname_id = "cond%d" % self.cond_counter
+    def visit_If(self, node):
+        self.generic_visit(node)
+        if(pysmcl.secret_annotator.expr_secret(node.test)):   
+            self.changed = True
+            replacement = []
+            r_then = self.do_body(node.body, "_then")
+            r_else = self.do_body(node.orelse, "_else") if node.orelse else {}
+           
+            replacement +=node.body
+            if(node.orelse):
+                replacement += node.orelse
 
-            condname_store = ast.Name(id=condname_id,
-                                      lineno=node.test.lineno,
-                                      col_offset=node.test.col_offset,
-                                      ctx=ast.Store())
-            condname_load = ast.Name(id=condname_id,
-                                      lineno=node.test.lineno,
-                                      col_offset=node.test.col_offset,
-                                      ctx=ast.Load())
+            def loc(n):
+                return ast.copy_location(n, node)
+
+            cp = ast.copy_location
+            cond_load = loc(ast.Name(id = self.make_name("cond"), ctx=ast.Load))
 
-            self.cond_counter += 1
-            r.append(ast.Assign(targets=[condname_store],
-                                value=node.test,
-                                lineno=node.lineno,
-                                col_offset=node.col_offset))
-            for i in all_assigned:
-                then_value = assigned_then.get(i, ast.Name(id=i,
-                                                           lineno=node.lineno,
-                                                           col_offset=node.col_offset,
-                                                           ctx=ast.Load()))
-                else_value = assigned_else.get(i, ast.Name(id=i,
-                                                           lineno=node.lineno,
-                                                           col_offset=node.col_offset,
-                                                           ctx=ast.Load()))
-                r.append(ast.Assign(targets=[ast.Name(id=i,
-                                                      lineno=node.lineno,
-                                                      col_offset=node.col_offset,
-                                                      ctx=ast.Store())],
-                                    value=ast.BinOp(left=
-                                                    ast.BinOp(left=condname_load,
-                                                              op=ast.Mult(lineno=node.lineno,
-                                                                          col_offset=node.col_offset),
-                                                              right=then_value,
-                                                              lineno=node.lineno,
-                                                              col_offset=node.col_offset),
-                                                    op=ast.Add(lineno=node.lineno,
-                                                               col_offset=node.col_offset),
-                                                    right=
-                                                    ast.BinOp(left=
-                                                              ast.BinOp(left=
-                                                                        ast.Num(1,
-                                                                                lineno=node.lineno,
-                                                                                col_offset=node.col_offset),
-                                                                        op=ast.Sub(lineno=node.lineno,
-                                                                                col_offset=node.col_offset),
-                                                                        right=condname_load,
-                                                                        lineno=node.lineno,
-                                                                        col_offset=node.col_offset),
-                                                              op=ast.Mult(lineno=node.lineno,
-                                                                          col_offset=node.col_offset),
-                                                              right=else_value,
-                                                              lineno=node.lineno,
-                                                              col_offset=node.col_offset),
-                                                    lineno=node.lineno,
-                                                    col_offset=node.col_offset),
-                                    lineno=node.lineno,
-                                    col_offset=node.col_offset))
+
+            cond_store = copy.copy(cond_load)
+            cond_store.ctx = ast.Store
+            
+            cond_comp = loc(ast.Assign(targets = [cond_store],
+                                       value = node.test))
+
+            cond_comp.syntetic = True
+
+            replacement.append(cond_comp)
+
+            for id in set(r_then.keys()) | set(r_else.keys()):
+                r1 = r_then[id] if id in r_then else id
+                r2 = r_else[id] if id in r_else else id
+                
+                v1 = loc(ast.Name(id = r1, ctx = ast.Load))
+                v2 = loc(ast.Name(id = r2, ctx = ast.Load))
 
-            return r # A list of statements will be merged into the list
+                combine = loc(ast.Assign(targets = [loc(ast.Name(id=id, ctx=ast.Store))],
+                                         value =
+                                         loc(ast.BinOp(op = ast.Add(),
+                                                       left = v1,
+                                                       right = loc(ast.BinOp(op = ast.Mult(),
+                                                                             left = ast.BinOp(op = ast.Sub(),
+                                                                                              left = v2,
+                                                                                              right = v1),
+                                                                             right = cond_load))))))
+
+                # combine.syntetic = False the combination of the two
+                # branches is not syntetic, because it replaces the
+                # real assignments
+
+                replacement.append(combine)
+            for s in replacement:
+                s.parent = node.parent
+                ast.fix_ast_parents(s)
+            return replacement
         else:
             return node
+
+def lookup_env(env, id):
+    for i in reversed(env):
+        if id in env:
+            return env[id]
+    return None