Mercurial > pysmcl
changeset 219:1683fdd52828
range_analysis: use a class for Ranges, use widening
author | Sigurd Meldgaard <stm@daimi.au.dk> |
---|---|
date | Wed, 23 Dec 2009 14:34:07 +0100 |
parents | 35a38b49b5cd |
children | 01a6d95bf029 |
files | pysmcl/range_analysis.py |
diffstat | 1 files changed, 109 insertions(+), 51 deletions(-) [+] |
line wrap: on
line diff
--- a/pysmcl/range_analysis.py Wed Dec 23 14:33:31 2009 +0100 +++ b/pysmcl/range_analysis.py Wed Dec 23 14:34:07 2009 +0100 @@ -12,6 +12,80 @@ # Switch to True for debug output printed to std. out. debug = False + +class RangeError(Exception): + pass + + +def widen(n): + if(n < 0): + return -widen(-n) + if n < 128: + return n + else: + for i in range(3, 32): + power = 2**i + if n<power: + return power + raise RangeError() + + +class RangeC(object): + + def __init__(self, a, b): + self.a = widen(a) + self.b = widen(b) + + def tuple(self): + """Returns a tuple representation of this range.""" + return (self.a, self.b) + + def combine(self, other): + """Combine the two ranges.""" + if debug: + print " combine_ranges:", self, other + + if isinstance(other, Bottom): + return Bottom() + else: + (l1, h1), (l2, h2) = self.tuple(), other.tuple() + r = Range(min(l1, l2), max(h1, h2)) + + if debug: + print " combined ranges:", r + return r + + def within(self, other): + """ Returns true if the interval of self is completely + covered by that of other. """ + return self.a >= other.a and self.b <= other.b + + def __eq__(self, other): + if isinstance(other, Bottom): + return False + else: + return other.a == self.a and other.b == self.b + + def __repr__(self): + return self.tuple().__repr__() + + def __getitem__(self, i): + if i == 0: + return self.a + elif i == 1: + return self.b + else: + raise IndexError("Ranges can only be accessed" + " in 0, 1 but was: %s" % i) + + +def Range(a, b): + try: + return RangeC(a, b) + except RangeError: + return Bottom() + + class Bottom(object): """The element used to represent intervals that are not between -(p//2) and p//2. @@ -34,12 +108,18 @@ return isinstance(other, Bottom) return False + def combine(self, other): + """Combining with bottom yields bottom""" + return self + def __repr__(self): return "_|_" + def full_range(p): """The range represented modulo p""" - return (-(p//2), p//2) + return Range(-(p//2), p//2) + class RangeAnalysis(object): """The class which defines the range analysis. @@ -61,8 +141,9 @@ function (FunctioDef) """ flow.analyze(function, self.join, self.combine, self.key, - combine_env({'True' : (1,1), 'False' : (0,0)}, - initial_env), + combine_env({'True': Range(1, 1), + 'False': Range(0, 0)}, + initial_env), self.distribute) def join(self, in_nodes): @@ -90,7 +171,8 @@ if self.is_comparison(x.test): a = dict(x.out_values[self.key]) old = a[x.test.left.id] - compared_value = RangeVisitor(self.prime, x.out_values[self.key]).visit(x.test.comparators[0]) + compared_value = RangeVisitor(self.prime, + x.out_values[self.key]).visit(x.test.comparators[0]) a[x.test.left.id] = (min(old[0], compared_value[1]), min(old[1], compared_value[1])) print(a, x.test.left.id) @@ -101,7 +183,6 @@ for child in x.children: child.in_values[self.key] = x.out_values[self.key] - def combine(self, node, env): """The least upper bound of the node and the environment. @@ -118,23 +199,26 @@ if env is None: env = {} + class Visitor(ast.NodeVisitor): """Visitor for Python statements.""" def visit_Assign(self, node): target = node.targets[0] - r = range(prime, node.value, env) + r = find_range(prime, node.value, env) if(isinstance(target, ast.Name)): env[target.id] = r elif(isinstance(target, ast.Subscript)): - env[target.value.id] = combine_range(env[target.value.id], r) + print ast.dump(target) + print target.lineno + env[target.value.id] = env[target.value.id].combine(r) else: raise RuntimeError("Unsupported target of assignment") return env def visit_For(self, node): target = node.target - r = range(prime, node.iter, env) + r = find_range(prime, node.iter, env) if(isinstance(target, ast.Name)): env[target.id] = r else: @@ -148,6 +232,8 @@ for keyword in decorator.keywords: if keyword.arg == 'range': param_range = ast.literal_eval(keyword.value) + for (k, r) in param_range.items(): + param_range[k] = Range(r[0], r[1]) break for arg in node.args.args: if not arg.id in param_range: @@ -184,15 +270,14 @@ return r -def range(prime, node, env): - """range performs the computation of range of the given node. - - node (ast.Node) - env (dict) An environment of variables and their range. +def find_range(prime, node, env): + """range performs the computation of range of the given expression + in node. Assuming that variables have values as in the env dict. """ rangeVisitor = RangeVisitor(prime, env) return rangeVisitor.visit(node) + class RangeVisitor(ast.NodeVisitor): """RangeVisitor is the visitor which actually implements the range computation.""" @@ -206,11 +291,7 @@ self.env = env def visit_Num(self, node): - if node.n > self.prime // 2: - return Bottom() - if node.n < -(self.prime // 2): - return Bottom() - return (node.n, node.n) + return Range(node.n, node.n) def visit_Name(self, node): if node.id in self.env.keys(): @@ -228,8 +309,9 @@ else: r = self.visit(node.elts[0]) for i in node.elts[1:]: - r = combine_range(r, self.visit(node.elts[0])) + r = r.combine(self.visit(node.elts[0])) return r + def visit_BinOp(self, node): # operator = Add | Sub | Mult | Div | Mod | Pow | LShift # | RShift | BitOr | BitXor | BitAnd | FloorDiv @@ -252,32 +334,28 @@ r0 = left[0] // right[1] r1 = left[1] // right[0] else: - raise Exception("Operator not implemented: ", node.op) - if(r0 > self.prime // 2 or r0 < -(self.prime // 2) - or r1 > self.prime // 2 or r1 < -(self.prime // 2)): - return Bottom() - return (r0, r1) - + raise RuntimeError("Operator not implemented: ", node.op) + return Range(r0, r1) def visit_Compare(self, node): # cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot | In | NotIn - return (0, 1) + return Range(0, 1) def visit_Call(self, node): # TODO: Check that random and random_bit are bound to the # expected functions if node.func.id == "num_players": - return (setup.num_players, setup.num_players) + return Range(setup.num_players, setup.num_players) if node.func.id == "players": - return (1, setup.num_players) + return Range(1, setup.num_players) if node.func.id == "id": - return (1, setup.nr_of_players) + return Range(1, setup.nr_of_players) if node.func.id == "random": return full_range(self.prime) if node.func.id == "open": return self.visit(node.args[0]) if node.func.id == "random_bit": - return (0, 1) + return Range(0, 1) return full_range(self.prime) @@ -298,7 +376,7 @@ env1_keys = set(env1.keys()) - set(env2.keys()) env2_keys = set(env2.keys()) - set(env1.keys()) for key in common_keys: - env[key] = combine_range(env1[key], env2[key]) + env[key] = env1[key].combine(env2[key]) for key in env1_keys: env[key] = env1[key] for key in env2_keys: @@ -306,23 +384,3 @@ if debug: print " combined env:", env return env - -def combine_range(a, b): - """Combine the two ranges.""" - if debug: - print " combine_ranges:", a, b - - if a == Bottom() or b == Bottom(): - r = Bottom() - else: - (l1, h1), (l2, h2) = a, b - r = (min(l1, l2), max(h1, h2)) - - if debug: - print " combined ranges:", r - return r - -def interval_within(a, b): - """ Returns true if the interval of a is completely covered by - that of b. """ - return a[0] >= b[0] and a[1] <= b[1]