### changeset 219:1683fdd52828

range_analysis: use a class for Ranges, use widening
author Sigurd Meldgaard Wed, 23 Dec 2009 14:34:07 +0100 35a38b49b5cd 01a6d95bf029 pysmcl/range_analysis.py 1 files changed, 109 insertions(+), 51 deletions(-) [+]
line wrap: on
line diff
--- a/pysmcl/range_analysis.py	Wed Dec 23 14:33:31 2009 +0100
+++ b/pysmcl/range_analysis.py	Wed Dec 23 14:34:07 2009 +0100
@@ -12,6 +12,80 @@
# Switch to True for debug output printed to std. out.
debug = False

+
+class RangeError(Exception):
+    pass
+
+
+def widen(n):
+    if(n < 0):
+        return -widen(-n)
+    if n < 128:
+        return n
+    else:
+        for i in range(3, 32):
+            power = 2**i
+            if n<power:
+                return power
+        raise RangeError()
+
+
+class RangeC(object):
+
+    def __init__(self, a, b):
+        self.a = widen(a)
+        self.b = widen(b)
+
+    def tuple(self):
+        """Returns a tuple representation of this range."""
+        return (self.a, self.b)
+
+    def combine(self, other):
+        """Combine the two ranges."""
+        if debug:
+            print "  combine_ranges:", self, other
+
+        if isinstance(other, Bottom):
+            return Bottom()
+        else:
+            (l1, h1), (l2, h2) = self.tuple(), other.tuple()
+            r = Range(min(l1, l2), max(h1, h2))
+
+        if debug:
+            print "  combined ranges:", r
+        return r
+
+    def within(self, other):
+        """ Returns true if the interval of self is completely
+        covered by that of other. """
+        return self.a >= other.a and self.b <= other.b
+
+    def __eq__(self, other):
+        if isinstance(other, Bottom):
+            return False
+        else:
+            return other.a == self.a and other.b == self.b
+
+    def __repr__(self):
+        return self.tuple().__repr__()
+
+    def __getitem__(self, i):
+        if i == 0:
+            return self.a
+        elif i == 1:
+            return self.b
+        else:
+            raise IndexError("Ranges can only be accessed"
+                             " in 0, 1 but was: %s" % i)
+
+
+def Range(a, b):
+    try:
+        return RangeC(a, b)
+    except RangeError:
+        return Bottom()
+
+
class Bottom(object):
"""The element used to represent intervals that are not between
-(p//2) and p//2.
@@ -34,12 +108,18 @@
return isinstance(other, Bottom)
return False

+    def combine(self, other):
+        """Combining with bottom yields bottom"""
+        return self
+
def __repr__(self):
return "_|_"

+
def full_range(p):
"""The range represented modulo p"""
-    return (-(p//2), p//2)
+    return Range(-(p//2), p//2)
+

class RangeAnalysis(object):
"""The class which defines the range analysis.
@@ -61,8 +141,9 @@
function (FunctioDef)
"""
flow.analyze(function, self.join, self.combine, self.key,
-                     combine_env({'True' : (1,1), 'False' : (0,0)},
-                                  initial_env),
+                     combine_env({'True': Range(1, 1),
+                                  'False': Range(0, 0)},
+                                 initial_env),
self.distribute)

def join(self, in_nodes):
@@ -90,7 +171,8 @@
if self.is_comparison(x.test):
a = dict(x.out_values[self.key])
old = a[x.test.left.id]
-                compared_value = RangeVisitor(self.prime, x.out_values[self.key]).visit(x.test.comparators[0])
+                compared_value = RangeVisitor(self.prime,
+                    x.out_values[self.key]).visit(x.test.comparators[0])
a[x.test.left.id] = (min(old[0], compared_value[1]),
min(old[1], compared_value[1]))
print(a, x.test.left.id)
@@ -101,7 +183,6 @@
for child in x.children:
child.in_values[self.key] = x.out_values[self.key]

-
def combine(self, node, env):
"""The least upper bound of the node and the environment.

@@ -118,23 +199,26 @@

if env is None:
env = {}
+
class Visitor(ast.NodeVisitor):
"""Visitor for Python statements."""

def visit_Assign(self, node):
target = node.targets[0]
-                r = range(prime, node.value, env)
+                r = find_range(prime, node.value, env)
if(isinstance(target, ast.Name)):
env[target.id] = r
elif(isinstance(target, ast.Subscript)):
-                    env[target.value.id] = combine_range(env[target.value.id], r)
+                    print ast.dump(target)
+                    print target.lineno
+                    env[target.value.id] = env[target.value.id].combine(r)
else:
raise RuntimeError("Unsupported target of assignment")
return env

def visit_For(self, node):
target = node.target
-                r = range(prime, node.iter, env)
+                r = find_range(prime, node.iter, env)
if(isinstance(target, ast.Name)):
env[target.id] = r
else:
@@ -148,6 +232,8 @@
for keyword in decorator.keywords:
if keyword.arg == 'range':
param_range = ast.literal_eval(keyword.value)
+                            for (k, r) in param_range.items():
+                                param_range[k] = Range(r[0], r[1])
break
for arg in node.args.args:
if not arg.id in param_range:
@@ -184,15 +270,14 @@
return r

-def range(prime, node, env):
-    """range performs the computation of range of the given node.
-
-    node (ast.Node)
-    env (dict) An environment of variables and their range.
+def find_range(prime, node, env):
+    """range performs the computation of range of the given expression
+    in node. Assuming that variables have values as in the env dict.
"""
rangeVisitor = RangeVisitor(prime, env)
return rangeVisitor.visit(node)

+
class RangeVisitor(ast.NodeVisitor):
"""RangeVisitor is the visitor which actually implements the range
computation."""
@@ -206,11 +291,7 @@
self.env = env

def visit_Num(self, node):
-        if node.n > self.prime // 2:
-            return Bottom()
-        if node.n < -(self.prime // 2):
-            return Bottom()
-        return (node.n, node.n)
+        return Range(node.n, node.n)

def visit_Name(self, node):
if node.id in self.env.keys():
@@ -228,8 +309,9 @@
else:
r = self.visit(node.elts[0])
for i in node.elts[1:]:
-                r = combine_range(r, self.visit(node.elts[0]))
+                r = r.combine(self.visit(node.elts[0]))
return r
+
def visit_BinOp(self, node):
# operator = Add | Sub | Mult | Div | Mod | Pow | LShift
#          | RShift | BitOr | BitXor | BitAnd | FloorDiv
@@ -252,32 +334,28 @@
r0 = left[0] // right[1]
r1 = left[1] // right[0]
else:
-            raise Exception("Operator not implemented: ", node.op)
-        if(r0 > self.prime // 2 or r0 < -(self.prime // 2)
-           or r1 > self.prime // 2 or r1 < -(self.prime // 2)):
-            return Bottom()
-        return (r0, r1)
-
+            raise RuntimeError("Operator not implemented: ", node.op)
+        return Range(r0, r1)

def visit_Compare(self, node):
# cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot | In | NotIn
-        return (0, 1)
+        return Range(0, 1)

def visit_Call(self, node):
# TODO: Check that random and random_bit are bound to the
# expected functions
if node.func.id == "num_players":
-            return (setup.num_players, setup.num_players)
+            return Range(setup.num_players, setup.num_players)
if node.func.id == "players":
-            return (1, setup.num_players)
+            return Range(1, setup.num_players)
if node.func.id == "id":
-            return (1, setup.nr_of_players)
+            return Range(1, setup.nr_of_players)
if node.func.id == "random":
return full_range(self.prime)
if node.func.id == "open":
return self.visit(node.args[0])
if node.func.id == "random_bit":
-            return (0, 1)
+            return Range(0, 1)
return full_range(self.prime)

@@ -298,7 +376,7 @@
env1_keys = set(env1.keys()) - set(env2.keys())
env2_keys = set(env2.keys()) - set(env1.keys())
for key in common_keys:
-        env[key] = combine_range(env1[key], env2[key])
+        env[key] = env1[key].combine(env2[key])
for key in env1_keys:
env[key] = env1[key]
for key in env2_keys:
@@ -306,23 +384,3 @@
if debug:
print "  combined env:", env
return env
-
-def combine_range(a, b):
-    """Combine the two ranges."""
-    if debug:
-        print "  combine_ranges:", a, b
-
-    if a == Bottom() or b == Bottom():
-        r = Bottom()
-    else:
-        (l1, h1), (l2, h2) = a, b
-        r = (min(l1, l2), max(h1, h2))
-
-    if debug:
-        print "  combined ranges:", r
-    return r
-
-def interval_within(a, b):
-    """ Returns true if the interval of a is completely covered by
-    that of b. """
-    return a[0] >= b[0] and a[1] <= b[1]