changeset 94:28adf540a9ba

pysmcl/range_analysis: make all tests pass
author Sigurd Meldgaard <stm@daimi.au.dk>
date Mon, 08 Jun 2009 12:29:37 +0200
parents c85711f38392
children f683af7aea18
files pysmcl/range_analysis.py
diffstat 1 files changed, 40 insertions(+), 68 deletions(-) [+]
line wrap: on
line diff
--- a/pysmcl/range_analysis.py	Mon Jun 08 12:29:02 2009 +0200
+++ b/pysmcl/range_analysis.py	Mon Jun 08 12:29:37 2009 +0200
@@ -20,9 +20,9 @@
 debug = False
 
 class Bottom(object):
-    """The element used to represent intervals that are not between 0 and p.
+    """The element used to represent intervals that are not between -(p//2) and p//2.
 
-    e.g. for the expression x = 2-5 yields the range for x (Bottom(), Bottom())
+    e.g. for p=7 the expression x = 2-5 yields the range for x Bottom()
     """
 
     # storage for the instance reference
@@ -43,6 +43,10 @@
     def __repr__(self):
         return "_|_"
 
+def full_range(p):
+    """The range represented modulo p"""
+    return (-(p//2), p//2)
+
 class RangeAnalysis(object):
     """The class which defines the range analysis.
 
@@ -81,7 +85,7 @@
         """The least upper bound of the node and the environment.
 
         node (ast.Node)
-        end (dict) A dictionary with variables and their current range. 
+        end (dict) A dictionary with variables and their current range.
           May be None if the env is uninitialized.
         """
         if debug:
@@ -105,16 +109,12 @@
 
             def visit_FunctionDef(self, node):
                 for arg in node.args.args:
-                    env[arg.id] = (0, self.prime)
+                    env[arg.id] = full_range(self.prime)
                 return env
 
             def visit_If(self, node):
                 return env
 
-            def visit_Expr(self, node):
-                env['_'] = range(self.prime, node.value, env)
-                return env
-
             def visit_While(self, node):
                 return env
 
@@ -150,60 +150,42 @@
         self.env = env
 
     def visit_Num(self, node):
-        if node.n > self.prime:
-            return (Bottom(), Bottom())
+        if node.n > self.prime // 2:
+            return Bottom()
+        if node.n < -(self.prime // 2):
+            return Bottom()
         return (node.n, node.n)
 
     def visit_Name(self, node):
         if node.id in self.env.keys():
             return self.env[node.id]
-        return (0, self.prime)
+        return full_range(self.prime)
 
     def visit_BinOp(self, node):
         # operator = Add | Sub | Mult | Div | Mod | Pow | LShift
         #          | RShift | BitOr | BitXor | BitAnd | FloorDiv
         left = self.visit(node.left)
         right = self.visit(node.right)
+        if left == Bottom() or right == Bottom():
+            return Bottom()
 
         if isinstance(node.op, ast.Add):
-            def liftAdd(a, b):
-                if (isinstance(a, Bottom) or
-                    isinstance(b, Bottom)):
-                    return Bottom()
-                c = a + b
-                if c > self.prime:
-                    return Bottom()
-                return c            
-            r0 = liftAdd(left[0], right[0])
-            r1 = liftAdd(left[1], right[1])
-            return (r0, r1)
+            r0 = left[0]+right[0]
+            r1 = left[1]+right[1]
 
-        if isinstance(node.op, ast.Sub):
-            def liftMinus(a,b):
-                if(isinstance(a, Bottom) or
-                   isinstance(b, Bottom)):
-                    return Bottom()
-                c = a-b
-                if(c < 0):
-                    return Bottom()
-                return c
-            r0 = liftMinus(left[0], right[1])
-            r1 = liftMinus(left[1], right[0])
-            return (r0, r1)
+        elif isinstance(node.op, ast.Sub):
+            r0 = left[0] - right[1]
+            r1 = left[1] - right[0]
+        elif isinstance(node.op, ast.Mult):
+            r0 = left[0] * right[0]
+            r1 = left[1] * right[1]
+        else:
+            raise Exception("Operator not implemented: ", node.op)
+        if(r0 > self.prime // 2 or r0 < -(self.prime // 2)
+           or r1 > self.prime // 2 or r1 < -(self.prime // 2)):
+            return Bottom()
+        return (r0, r1)
 
-        if isinstance(node.op, ast.Mult):
-            def liftMult(a, b):
-                if(isinstance(a, Bottom) or
-                   isinstance(b, Bottom)):
-                    return Bottom()
-                c = a * b
-                if(c > self.prime):
-                    return Bottom()
-                return c
-            r0 = liftMult(left[0], right[0])
-            r1 = liftMult(left[1], right[1])
-            return (r0, r1)
-        raise Exception("Operator not implemented: ", node.op)
 
     def visit_Compare(self, node):
         # cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot | In | NotIn
@@ -212,10 +194,10 @@
     def visit_Call(self, node):
         # TODO: Check that random and random_bit are bound to the expected functions
         if node.func.id == "random":
-            return (0, self.prime - 1)
+            return full_range(self.prime)
         if node.func.id == "random_bit":
             return (0, 1)
-        return (0, self.prime)
+        return full_range(self.prime)
 
 class TargetVisitor(ast.NodeVisitor):
     """TargetVisitor update range of target of an assignment in the given environment.
@@ -272,27 +254,17 @@
         print "  combined env:", env
     return env
 
-def combine_range((l1, h1), (l2, h2)):
+def combine_range(a, b):
     """Combine the two ranges."""
     if debug:
-        print "  combine_ranges:", (l1, h1), (l2, h2)
-    if (isinstance(l1, Bottom) or
-        isinstance(l2, Bottom)):
-        r0 = Bottom()
-    else:
-        if l1 < l2:
-            r0 = l1
-        else:
-            r0 = l2
+        print "  combine_ranges:", a, b
 
-    if (isinstance(h1, Bottom) or
-        isinstance(h2, Bottom)):
-        r1 = Bottom()
+    if a == Bottom() or b == Bottom():
+        r = Bottom()
     else:
-        if h1 > h2:
-            r1 = h1
-        else:
-            r1 = h2
+        (l1, h1), (l2, h2) = a, b
+        r = (min(l1, l2), max(h1, h2))
+
     if debug:
-        print "  combined ranges:", (r0, r1)
-    return (r0, r1)
+        print "  combined ranges:", r
+    return r