changeset 38:bd168f288c5b

Implemented parts of the range analysis, and corresponding tests.
author Janus Dam Nielsen <janus.nielsen@alexandra.dk>
date Tue, 26 May 2009 11:37:23 +0200
parents d574dfc3470a
children 95c5e9e0c5f8
files pysmcl/range_analysis.py pysmcl/test/unit/test_rangeanalysis.py
diffstat 2 files changed, 229 insertions(+), 86 deletions(-) [+]
line wrap: on
line diff
--- a/pysmcl/range_analysis.py	Tue May 26 11:36:41 2009 +0200
+++ b/pysmcl/range_analysis.py	Tue May 26 11:37:23 2009 +0200
@@ -1,14 +1,155 @@
 
+import ast
+
+from pysmcl import flow
 
 class Bottom(object):
-    pass
+    
+    def __eq__(self, other):
+        if other is not None:
+            return isinstance(other, Bottom)
+        return False
 
 class RangeAnalysis(object):
 
     result = None
+    key = "range"
 
     def __init__(self, prime):
+        self.prime = prime
+
+    def apply(self, function):
+        flow.analyze(function, self.join, self.combine, self.key, lambda : {})
+
+    def join(self, in_nodes):
+        print "Join"
+        env = {}
+        for n in in_nodes:
+            print "Joining: %s" % n
+            env = combine_env(env, n.out_values[self.key])
+        return env
+
+    def combine(self, node, env):
+        print "========= Combining ==========="
+        print "Node, env", node, env
+
+#         env = {}
+#         if set is []:
+#             env = set[0]
+        class Visitor(ast.NodeVisitor):
+
+            def __init__(self, prime):
+                self.prime = prime
+
+            def visit_Assign(self, node):
+                r = range(self.prime, node.value, env)
+                targetVisitor = TargetVisitor(env, r)
+                for target in node.targets:
+                    targetVisitor.visit(target)
+                return env
+           
+            def visit_FunctionDef(self, node):
+                for arg in node.args.args:
+                    env[arg.id] = (0, self.prime)
+                return env
+
+            def visit_Expr(self, node):
+                env['_'] = range(self.prime, node.value, env)
+                return env
+
+#             def generic_visit(self, node):
+#                 new_env = super(Visitor, self).generic_visit(node)
+# #                print "Generic: ", node, new_env
+#                 return new_env
+
+        r = Visitor(self.prime).visit(node)
+        print "Result of combination", r
+        print ""
+        return r
+
+
+def range(prime, node, env):
+    rangeVisitor = RangeVisitor(prime, env)
+    return rangeVisitor.visit(node)        
+
+class RangeVisitor(ast.NodeVisitor):
+
+    def __init__(self, prime, env):
+        self.prime = prime
+        self.env = env
+
+    def visit_Num(self, node):
+        return (node.n, node.n)
+
+    def visit_Name(self, node):
+        if node.id in self.env.keys():
+            return self.env[node.id]
+        return (0, self.prime)
+
+    def visit_BinOp(self, node):
+#         operator = Add | Sub | Mult | Div | Mod | Pow | LShift 
+#                  | RShift | BitOr | BitXor | BitAnd | FloorDiv
+        left = self.visit(node.left)
+        right = self.visit(node.right)
+        if isinstance(node.op, ast.Add):
+            r0 = left[0] + right[0]
+            r1 = left[1] + right[1]
+            if r0 >= self.prime:
+                r0 = Bottom()
+            if r1 >= self.prime:
+                r1 = Bottom()
+            return (r0, r1)
+        raise Exception("Operator not implemented.") 
+
+class TargetVisitor(ast.NodeVisitor):
+    """The following expression can appear in assignment context """                         
+
+    def __init__(self, env, range):
+        self.env = env
+        self.range = range
+
+    def visit_Attribute(self, value, attr, ctx):
         pass
 
-    def visit(self, node):
+    def visit_Subscript(self, value, slice, ctx):
+        pass
+
+    def visit_Name(self, id):
+#         print "id: ", id.id, self.range, self.env
+        if id in self.env.keys():
+            self.env[id.id] = combine_range(self.env[id.id], self.range)
+        else:
+            self.env[id.id] = self.range
+#        print "range:", self.result
+
+    def visit_List(self, elts, ctx):
+        pass
+
+    def visit_Tuple(self, elts, ctx):
         pass
+
+def combine_env(env1, env2):
+    env = {}
+    common_keys = set(env1.keys()) & set(env2.keys())
+    env1_keys = set(env1.keys()) - set(env2.keys())
+    env2_keys = set(env2.keys()) - set(env1.keys())
+    for key in common_keys:
+        env[key] = combine_range(env1[key], env2[key])
+    for key in env1_keys:
+        env[key] = env1[key]
+    for key in env2_keys:
+        env[key] = env2[key]
+    return env
+
+def combine_range((l1, h1), (l2, h2)):
+    r = (Buttom(),Buttom())
+    if l1 < l2:
+        r[0] = l1
+    else:
+        r[0] = l2
+
+    if h1 > h2:
+        r[1] = h1
+    else:
+        r[1] = h2    
+    return r
--- a/pysmcl/test/unit/test_rangeanalysis.py	Tue May 26 11:36:41 2009 +0200
+++ b/pysmcl/test/unit/test_rangeanalysis.py	Tue May 26 11:37:23 2009 +0200
@@ -14,117 +14,119 @@
         prog = parse("def f():\n\ty=0\n")
         init_statements(prog)
         range_analysis = RangeAnalysis(p)
-        range_analysis.visit(prog.body[0])
-        r = range_analysis.result
-        self.assertEquals(r, {'y': (0, 0)})
+        range_analysis.apply(prog.body[0])
+#         for node in prog.body[0].body:
+#             print "Result...:", node.out_values["range"]
+
+        self.assertEquals(prog.body[0].body[0].out_values["range"], {'y': (0, 0)})
 
     def test_range_unknown_range(self):
         p = 7
         prog = parse("def f(x):\n\tx\n")
         init_statements(prog)
         range_analysis = RangeAnalysis(p)
-        range_analysis.visit(prog.body[0])
-        r = range_analysis.result
-        self.assertEquals(r, {'x': (0, p)})
+        range_analysis.apply(prog.body[0])
+        r = prog.body[0].body[0].out_values["range"]
+        self.assertEquals(r, {'x': (0, p), '_': (0, p)})
 
     def test_range_add(self):
         p = 7
         prog = parse("def f(x):\n\tx=3\n\ty=2\n\tz=x+y\n")
         init_statements(prog)
         range_analysis = RangeAnalysis(p)
-        range_analysis.visit(prog.body[0])
-        r = range_analysis.result
-        self.assertEquals(r, {'z': (5, 5)})
+        range_analysis.apply(prog.body[0])
+        r = prog.body[0].body[2].out_values["range"]
+        self.assertEquals(r, {'y': (2, 2), 'x': (3, 3), 'z': (5, 5)})
 
     def test_range_add_wrap(self):
         p = 7
         prog = parse("def f(x):\n\tx=6\n\ty=2\n\tz=x+y\n")
         init_statements(prog)
         range_analysis = RangeAnalysis(p)
-        range_analysis.visit(prog.body[0])
-        r = range_analysis.result
-        self.assertEquals(r, {'z': (Bottom(), Bottom())})
+        range_analysis.apply(prog.body[0])
+        r = prog.body[0].body[2].out_values["range"]
+        self.assertEquals(r, {'y': (2, 2), 'x': (6, 6), 'z': (Bottom(), Bottom())})
 
-    def test_range_mul(self):
-        p = 7
-        prog = parse("def f(x):\n\tx=3\n\ty=2\n\tz=x*y\n")
-        init_statements(prog)
-        range_analysis = RangeAnalysis(p)
-        range_analysis.visit(prog.body[0])
-        r = range_analysis.result
-        self.assertEquals(r, {'z': (6, 6)})
+#     def test_range_mul(self):
+#         p = 7
+#         prog = parse("def f(x):\n\tx=3\n\ty=2\n\tz=x*y\n")
+#         init_statements(prog)
+#         range_analysis = RangeAnalysis(p)
+#         range_analysis.visit(prog.body[0])
+#         r = range_analysis.result
+#         self.assertEquals(r, {'z': (6, 6)})
 
-    def test_range_mul_wrap(self):
-        p = 7
-        prog = parse("def f(x):\n\tx=3\n\ty=3\n\tz=x*y\n")
-        init_statements(prog)
-        range_analysis = RangeAnalysis(p)
-        range_analysis.visit(prog.body[0])
-        r = range_analysis.result
-        self.assertEquals(r, {'z': (Bottom(), Bottom())})
+#     def test_range_mul_wrap(self):
+#         p = 7
+#         prog = parse("def f(x):\n\tx=3\n\ty=3\n\tz=x*y\n")
+#         init_statements(prog)
+#         range_analysis = RangeAnalysis(p)
+#         range_analysis.visit(prog.body[0])
+#         r = range_analysis.result
+#         self.assertEquals(r, {'z': (Bottom(), Bottom())})
 
-    def test_range_minus(self):
-        p = 7
-        prog = parse("def f(x):\n\tx=3\n\ty=2\n\tz=x-y\n")
-        init_statements(prog)
-        range_analysis = RangeAnalysis(p)
-        range_analysis.visit(prog.body[0])
-        r = range_analysis.result
-        self.assertEquals(r, {'z': (1, 1)})
+#     def test_range_minus(self):
+#         p = 7
+#         prog = parse("def f(x):\n\tx=3\n\ty=2\n\tz=x-y\n")
+#         init_statements(prog)
+#         range_analysis = RangeAnalysis(p)
+#         range_analysis.visit(prog.body[0])
+#         r = range_analysis.result
+#         self.assertEquals(r, {'z': (1, 1)})
 
-    def test_range_minus_wrap(self):
-        p = 7
-        prog = parse("def f(x):\n\tx=3\n\ty=4\n\tz=x-y\n")
-        init_statements(prog)
-        range_analysis = RangeAnalysis(p)
-        range_analysis.visit(prog.body[0])
-        r = range_analysis.result
-        self.assertEquals(r, {'z': (Bottom(), Bottom())})
+#     def test_range_minus_wrap(self):
+#         p = 7
+#         prog = parse("def f(x):\n\tx=3\n\ty=4\n\tz=x-y\n")
+#         init_statements(prog)
+#         range_analysis = RangeAnalysis(p)
+#         range_analysis.visit(prog.body[0])
+#         r = range_analysis.result
+#         self.assertEquals(r, {'z': (Bottom(), Bottom())})
 
-    def test_range_equals(self):
-        p = 7
-        prog = parse("def f(x):\n\tx=3\n\ty=2\n\tz=x==y\n")
-        init_statements(prog)
-        range_analysis = RangeAnalysis(p)
-        range_analysis.visit(prog.body[0])
-        r = range_analysis.result
-        self.assertEquals(r, {'z': (0, 1)})
+#     def test_range_equals(self):
+#         p = 7
+#         prog = parse("def f(x):\n\tx=3\n\ty=2\n\tz=x==y\n")
+#         init_statements(prog)
+#         range_analysis = RangeAnalysis(p)
+#         range_analysis.visit(prog.body[0])
+#         r = range_analysis.result
+#         self.assertEquals(r, {'z': (0, 1)})
 
-    def test_range_leg(self):
-        p = 7
-        prog = parse("def f(x):\n\tx=3\n\ty=2\n\tz=x<=y\n")
-        init_statements(prog)
-        range_analysis = RangeAnalysis(p)
-        range_analysis.visit(prog.body[0])
-        r = range_analysis.result
-        self.assertEquals(r, {'z': (0, 1)})
+#     def test_range_leg(self):
+#         p = 7
+#         prog = parse("def f(x):\n\tx=3\n\ty=2\n\tz=x<=y\n")
+#         init_statements(prog)
+#         range_analysis = RangeAnalysis(p)
+#         range_analysis.visit(prog.body[0])
+#         r = range_analysis.result
+#         self.assertEquals(r, {'z': (0, 1)})
 
-    def test_range_geg(self):
-        p = 7
-        prog = parse("def f(x):\n\tx=3\n\ty=2\n\tz=x>=y\n")
-        init_statements(prog)
-        range_analysis = RangeAnalysis(p)
-        range_analysis.visit(prog.body[0])
-        r = range_analysis.result
-        self.assertEquals(r, {'z': (0, 1)})
+#     def test_range_geg(self):
+#         p = 7
+#         prog = parse("def f(x):\n\tx=3\n\ty=2\n\tz=x>=y\n")
+#         init_statements(prog)
+#         range_analysis = RangeAnalysis(p)
+#         range_analysis.visit(prog.body[0])
+#         r = range_analysis.result
+#         self.assertEquals(r, {'z': (0, 1)})
 
-    def test_range_random(self):
-        p = 7
-        prog = parse("def f(x):\n\tx=random()\n")
-        init_statements(prog)
-        range_analysis = RangeAnalysis(p)
-        range_analysis.visit(prog.body[0])
-        r = range_analysis.result
-        self.assertEquals(r, {'x': (0, p-1)})
+#     def test_range_random(self):
+#         p = 7
+#         prog = parse("def f(x):\n\tx=random()\n")
+#         init_statements(prog)
+#         range_analysis = RangeAnalysis(p)
+#         range_analysis.visit(prog.body[0])
+#         r = range_analysis.result
+#         self.assertEquals(r, {'x': (0, p-1)})
 
-    def test_range_random_minus_bit(self):
-        p = 7
-        prog = parse("def f(x):\n\tx=random()\n\ty=x-bit()\n")
-        init_statements(prog)
-        range_analysis = RangeAnalysis(p)
-        range_analysis.visit(prog.body[0])
-        r = range_analysis.result
-        self.assertEquals(r, {'y': (0, 1)})
+#     def test_range_random_minus_bit(self):
+#         p = 7
+#         prog = parse("def f(x):\n\tx=random()\n\ty=x-bit()\n")
+#         init_statements(prog)
+#         range_analysis = RangeAnalysis(p)
+#         range_analysis.visit(prog.body[0])
+#         r = range_analysis.result
+#         self.assertEquals(r, {'y': (0, 1)})
 
 if __name__ == '__main__':
     unittest.main()