changeset 219:1683fdd52828

range_analysis: use a class for Ranges, use widening
author Sigurd Meldgaard <stm@daimi.au.dk>
date Wed, 23 Dec 2009 14:34:07 +0100
parents 35a38b49b5cd
children 01a6d95bf029
files pysmcl/range_analysis.py
diffstat 1 files changed, 109 insertions(+), 51 deletions(-) [+]
line wrap: on
line diff
--- a/pysmcl/range_analysis.py	Wed Dec 23 14:33:31 2009 +0100
+++ b/pysmcl/range_analysis.py	Wed Dec 23 14:34:07 2009 +0100
@@ -12,6 +12,80 @@
 # Switch to True for debug output printed to std. out.
 debug = False
 
+
+class RangeError(Exception):
+    pass
+
+
+def widen(n):
+    if(n < 0):
+        return -widen(-n)
+    if n < 128:
+        return n
+    else:
+        for i in range(3, 32):
+            power = 2**i
+            if n<power:
+                return power
+        raise RangeError()
+
+
+class RangeC(object):
+
+    def __init__(self, a, b):
+        self.a = widen(a)
+        self.b = widen(b)
+
+    def tuple(self):
+        """Returns a tuple representation of this range."""
+        return (self.a, self.b)
+
+    def combine(self, other):
+        """Combine the two ranges."""
+        if debug:
+            print "  combine_ranges:", self, other
+
+        if isinstance(other, Bottom):
+            return Bottom()
+        else:
+            (l1, h1), (l2, h2) = self.tuple(), other.tuple()
+            r = Range(min(l1, l2), max(h1, h2))
+
+        if debug:
+            print "  combined ranges:", r
+        return r
+
+    def within(self, other):
+        """ Returns true if the interval of self is completely
+        covered by that of other. """
+        return self.a >= other.a and self.b <= other.b
+
+    def __eq__(self, other):
+        if isinstance(other, Bottom):
+            return False
+        else:
+            return other.a == self.a and other.b == self.b
+
+    def __repr__(self):
+        return self.tuple().__repr__()
+
+    def __getitem__(self, i):
+        if i == 0:
+            return self.a
+        elif i == 1:
+            return self.b
+        else:
+            raise IndexError("Ranges can only be accessed"
+                             " in 0, 1 but was: %s" % i)
+
+
+def Range(a, b):
+    try:
+        return RangeC(a, b)
+    except RangeError:
+        return Bottom()
+
+
 class Bottom(object):
     """The element used to represent intervals that are not between
     -(p//2) and p//2.
@@ -34,12 +108,18 @@
             return isinstance(other, Bottom)
         return False
 
+    def combine(self, other):
+        """Combining with bottom yields bottom"""
+        return self
+
     def __repr__(self):
         return "_|_"
 
+
 def full_range(p):
     """The range represented modulo p"""
-    return (-(p//2), p//2)
+    return Range(-(p//2), p//2)
+
 
 class RangeAnalysis(object):
     """The class which defines the range analysis.
@@ -61,8 +141,9 @@
         function (FunctioDef)
         """
         flow.analyze(function, self.join, self.combine, self.key,
-                     combine_env({'True' : (1,1), 'False' : (0,0)},
-                                  initial_env),
+                     combine_env({'True': Range(1, 1),
+                                  'False': Range(0, 0)},
+                                 initial_env),
                      self.distribute)
 
     def join(self, in_nodes):
@@ -90,7 +171,8 @@
             if self.is_comparison(x.test):
                 a = dict(x.out_values[self.key])
                 old = a[x.test.left.id]
-                compared_value = RangeVisitor(self.prime, x.out_values[self.key]).visit(x.test.comparators[0])
+                compared_value = RangeVisitor(self.prime,
+                    x.out_values[self.key]).visit(x.test.comparators[0])
                 a[x.test.left.id] = (min(old[0], compared_value[1]),
                                          min(old[1], compared_value[1]))
                 print(a, x.test.left.id)
@@ -101,7 +183,6 @@
             for child in x.children:
                 child.in_values[self.key] = x.out_values[self.key]
 
-
     def combine(self, node, env):
         """The least upper bound of the node and the environment.
 
@@ -118,23 +199,26 @@
 
         if env is None:
             env = {}
+
         class Visitor(ast.NodeVisitor):
             """Visitor for Python statements."""
 
             def visit_Assign(self, node):
                 target = node.targets[0]
-                r = range(prime, node.value, env)
+                r = find_range(prime, node.value, env)
                 if(isinstance(target, ast.Name)):
                     env[target.id] = r
                 elif(isinstance(target, ast.Subscript)):
-                    env[target.value.id] = combine_range(env[target.value.id], r)
+                    print ast.dump(target)
+                    print target.lineno
+                    env[target.value.id] = env[target.value.id].combine(r)
                 else:
                     raise RuntimeError("Unsupported target of assignment")
                 return env
 
             def visit_For(self, node):
                 target = node.target
-                r = range(prime, node.iter, env)
+                r = find_range(prime, node.iter, env)
                 if(isinstance(target, ast.Name)):
                     env[target.id] = r
                 else:
@@ -148,6 +232,8 @@
                     for keyword in decorator.keywords:
                         if keyword.arg == 'range':
                             param_range = ast.literal_eval(keyword.value)
+                            for (k, r) in param_range.items():
+                                param_range[k] = Range(r[0], r[1])
                             break
                 for arg in node.args.args:
                     if not arg.id in param_range:
@@ -184,15 +270,14 @@
         return r
 
 
-def range(prime, node, env):
-    """range performs the computation of range of the given node.
-
-    node (ast.Node)
-    env (dict) An environment of variables and their range.
+def find_range(prime, node, env):
+    """range performs the computation of range of the given expression
+    in node. Assuming that variables have values as in the env dict.
     """
     rangeVisitor = RangeVisitor(prime, env)
     return rangeVisitor.visit(node)
 
+
 class RangeVisitor(ast.NodeVisitor):
     """RangeVisitor is the visitor which actually implements the range
     computation."""
@@ -206,11 +291,7 @@
         self.env = env
 
     def visit_Num(self, node):
-        if node.n > self.prime // 2:
-            return Bottom()
-        if node.n < -(self.prime // 2):
-            return Bottom()
-        return (node.n, node.n)
+        return Range(node.n, node.n)
 
     def visit_Name(self, node):
         if node.id in self.env.keys():
@@ -228,8 +309,9 @@
         else:
             r = self.visit(node.elts[0])
             for i in node.elts[1:]:
-                r = combine_range(r, self.visit(node.elts[0]))
+                r = r.combine(self.visit(node.elts[0]))
             return r
+
     def visit_BinOp(self, node):
         # operator = Add | Sub | Mult | Div | Mod | Pow | LShift
         #          | RShift | BitOr | BitXor | BitAnd | FloorDiv
@@ -252,32 +334,28 @@
             r0 = left[0] // right[1]
             r1 = left[1] // right[0]
         else:
-            raise Exception("Operator not implemented: ", node.op)
-        if(r0 > self.prime // 2 or r0 < -(self.prime // 2)
-           or r1 > self.prime // 2 or r1 < -(self.prime // 2)):
-            return Bottom()
-        return (r0, r1)
-
+            raise RuntimeError("Operator not implemented: ", node.op)
+        return Range(r0, r1)
 
     def visit_Compare(self, node):
         # cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot | In | NotIn
-        return (0, 1)
+        return Range(0, 1)
 
     def visit_Call(self, node):
         # TODO: Check that random and random_bit are bound to the
         # expected functions
         if node.func.id == "num_players":
-            return (setup.num_players, setup.num_players)
+            return Range(setup.num_players, setup.num_players)
         if node.func.id == "players":
-            return (1, setup.num_players)
+            return Range(1, setup.num_players)
         if node.func.id == "id":
-            return (1, setup.nr_of_players)
+            return Range(1, setup.nr_of_players)
         if node.func.id == "random":
             return full_range(self.prime)
         if node.func.id == "open":
             return self.visit(node.args[0])
         if node.func.id == "random_bit":
-            return (0, 1)
+            return Range(0, 1)
         return full_range(self.prime)
 
 
@@ -298,7 +376,7 @@
     env1_keys = set(env1.keys()) - set(env2.keys())
     env2_keys = set(env2.keys()) - set(env1.keys())
     for key in common_keys:
-        env[key] = combine_range(env1[key], env2[key])
+        env[key] = env1[key].combine(env2[key])
     for key in env1_keys:
         env[key] = env1[key]
     for key in env2_keys:
@@ -306,23 +384,3 @@
     if debug:
         print "  combined env:", env
     return env
-
-def combine_range(a, b):
-    """Combine the two ranges."""
-    if debug:
-        print "  combine_ranges:", a, b
-
-    if a == Bottom() or b == Bottom():
-        r = Bottom()
-    else:
-        (l1, h1), (l2, h2) = a, b
-        r = (min(l1, l2), max(h1, h2))
-
-    if debug:
-        print "  combined ranges:", r
-    return r
-
-def interval_within(a, b):
-    """ Returns true if the interval of a is completely covered by
-    that of b. """
-    return a[0] >= b[0] and a[1] <= b[1]