changeset 237:95721959e087

secret_ifs: general cleanup
author Sigurd Meldgaard <stm@daimi.au.dk>
date Tue, 05 Jan 2010 13:44:42 +0100
parents e1442a016352
children 632c18c86fde
files pysmcl/secret_ifs.py
diffstat 1 files changed, 30 insertions(+), 26 deletions(-) [+]
line wrap: on
line diff
--- a/pysmcl/secret_ifs.py	Tue Jan 05 11:13:34 2010 +0100
+++ b/pysmcl/secret_ifs.py	Tue Jan 05 13:44:42 2010 +0100
@@ -8,7 +8,7 @@
 import pysmcl.ast_wrapper as ast
 import pysmcl.secret_annotator
 import pysmcl.setup
-from pysmcl.range_analysis import RangeVisitor
+from pysmcl.range_analysis import RangeVisitor, Range
 
 from pysmcl.util import error
 
@@ -62,10 +62,7 @@
         # transformation triggered:
         self.name_counts = {}
         self.changed = False
-        suffixes = [""]
-        env = []
-        
-        
+
     def make_name(self, base_name):
         """
         Returns a fresh name, as a string.
@@ -93,7 +90,8 @@
                     for v in ast.walk(stm):
                         if not isinstance(v, ast.Name):
                             continue
-                        if (not (isinstance(v.parent, ast.Assign) and v.parent.targets[0] is v)
+                        if (not (isinstance(v.parent, ast.Assign)
+                                 and v.parent.targets[0] is v)
                             and v.id in repl):
                             v.id = repl[v.id]
             elif isinstance(stm, ast.Pass):
@@ -104,15 +102,18 @@
 
     def visit_If(self, node):
         self.generic_visit(node)
+        print ast.get_ancestor(node, ast.stmt).in_values["secret"]
         if(pysmcl.secret_annotator.expr_secret(node.test)):
-            range_visitor = RangeVisitor(pysmcl.setup.Zp.modulus, node.in_values["range"])
-            if not range_visitor.visit(node.test).within(0,1):
-                error("The condition cannot be proven to be in the range (0, 1)", node.test)
+            range_visitor = RangeVisitor(node.in_values["range"])
+            if not range_visitor.visit(node.test).within(Range(0, 1)):
+                error("The condition cannot be "
+                      "proven to be in the range (0, 1)",
+                      node.test)
             self.changed = True
             replacement = []
             r_then = self.do_body(node.body, "_then")
             r_else = self.do_body(node.orelse, "_else") if node.orelse else {}
-           
+
             replacement +=node.body
             if(node.orelse):
                 replacement += node.orelse
@@ -120,15 +121,15 @@
             def loc(n):
                 return ast.copy_location(n, node)
 
-            cp = ast.copy_location
-            cond_load = loc(ast.Name(id = self.make_name("cond"), ctx=ast.Load))
+            cond_load = loc(ast.Name(
+                    id=self.make_name("cond"), ctx=ast.Load))
 
 
             cond_store = copy.copy(cond_load)
             cond_store.ctx = ast.Store
-            
-            cond_comp = loc(ast.Assign(targets = [cond_store],
-                                       value = node.test))
+
+            cond_comp = loc(ast.Assign(targets=[cond_store],
+                                       value=node.test))
 
             cond_comp.synthetic = True
 
@@ -137,21 +138,23 @@
             for id in set(r_then.keys()) | set(r_else.keys()):
                 r1 = r_then[id] if id in r_then else id
                 r2 = r_else[id] if id in r_else else id
-                
+
                 v1 = loc(ast.Name(id = r1, ctx = ast.Load))
                 v2 = loc(ast.Name(id = r2, ctx = ast.Load))
 
-                right_hand =loc(ast.BinOp(op = ast.Mult(),
-                                          left = loc(ast.BinOp(op = ast.Sub(),
-                                                               left = v2,
-                                                               right = v1)),
-                                                     right = cond_load))
+                right_hand = loc(ast.BinOp(
+                        op=ast.Mult(),
+                        left=loc(ast.BinOp(op=ast.Sub(),
+                                             left=v2,
+                                             right=v1)),
+                        right=cond_load))
 
-                combine = loc(ast.Assign(targets = [loc(ast.Name(id=id, ctx=ast.Store))],
-                                         value =
-                                         loc(ast.BinOp(op = ast.Add(),
-                                                       left = v1,
-                                                       right = right_hand))))
+                combine = loc(ast.Assign(
+                        targets=[loc(ast.Name(id=id, ctx=ast.Store))],
+                        value=
+                        loc(ast.BinOp(op=ast.Add(),
+                                      left=v1,
+                                      right=right_hand))))
 
                 # combine.synthetic = False the combination of the two
                 # branches is not synthetic, because it replaces the
@@ -165,6 +168,7 @@
         else:
             return node
 
+
 def lookup_env(env, id):
     for i in reversed(env):
         if id in env: