changeset 599:46bd23052806

Introduce a program counter per connection. This change does not alter behavior yet, but it is a first step towards resolving Issue 22. The next step is to make sendData and _expect_data increment just one program counter.
author Martin Geisler <mg@daimi.au.dk>
date Mon, 24 Mar 2008 01:04:29 +0100
parents 79411fe0da54
children 9b04a55123de e5273143bf1f c0251598d6a3
files viff/runtime.py viff/test/test_basic_runtime.py
diffstat 2 files changed, 100 insertions(+), 94 deletions(-)
--- a/viff/runtime.py	Sat Mar 22 15:57:14 2008 +0100
+++ b/viff/runtime.py	Mon Mar 24 01:04:29 2008 +0100
@@ -222,6 +222,51 @@
         #: deferred data.
         self.incoming_data = {}
 
+        #: Program counter for this connection.
+        #:
+        #: Whenever a share is sent over the network, it must be
+        #: uniquely identified so that the receiving player knows what
+        #: operation the share is a result of. This is done by
+        #: associating a X{program counter} with each operation.
+        #:
+        #: Keeping the program counter synchronized between all
+        #: players ought to be easy, but because of the asynchronous
+        #: nature of network protocols, all players might not reach
+        #: the same parts of the program at the same time.
+        #:
+        #: Consider two players M{A} and M{B} who are both waiting on
+        #: the variables C{a} and C{b}. Callbacks have been added to
+        #: C{a} and C{b}, and the question is what program counter the
+        #: callbacks should use when sending data out over the
+        #: network.
+        #:
+        #: Let M{A} receive input for C{a} and then for C{b} a little
+        #: later, and let M{B} receive the inputs in the reverse order so
+        #: that the input for C{b} arrives first. The goal is to keep
+        #: the program counters synchronized so that program counter
+        #: M{x} refers to the same operation on all players. Because
+        #: the inputs arrive in a different order at different players,
+        #: incrementing a simple global counter is not enough.
+        #:
+        #: Instead, a I{tree} is made, which follows the tree of
+        #: execution. At the top level the program counter starts at
+        #: C{[0]}. At the next operation it becomes C{[1]}, and so on.
+        #: If a callback is scheduled (see L{Runtime.schedule_callback})
+        #: at program counter C{[x, y, z]}, any calls it makes will be
+        #: numbered C{[x, y, z, 1]}, then C{[x, y, z, 2]}, and so on.
+        #:
+        #: Maintaining such a tree of program counters ensures that
+        #: different parts of the program execution never reuse the
+        #: same program counter for different variables.
+        #:
+        #: The L{increment_pc} decorator is responsible for
+        #: dynamically building the tree as the execution unfolds and
+        #: L{Runtime.schedule_callback} is responsible for scheduling
+        #: callbacks with the correct program counter.
+        #:
+        #: @type: C{list} of integers.
+        self.program_counter = [0]
+
     def connectionMade(self):
         #print "Transport:", self.transport
         self.sendString(str(self.factory.runtime.id))
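The docstring above describes the tree-shaped counter abstractly. The
following minimal sketch (plain Python, illustrative only, not part of
the patch) shows how such a counter evolves as operations nest:

    # Entering an operation bumps the last entry and descends one
    # level; leaving pops back up to the parent.
    pc = [0]

    def enter():
        pc[-1] += 1
        pc.append(0)

    def leave():
        pc.pop()

    enter()   # first top-level operation:  pc == [1, 0]
    enter()   # a nested operation:         pc == [1, 1, 0]
    leave()   #                             pc == [1, 1]
    leave()   #                             pc == [1]
    enter()   # second top-level operation: pc == [2, 0]
    leave()   #                             pc == [2]
    assert pc == [2]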
@@ -268,26 +313,18 @@
             # TODO: marshal.loads can raise EOFError, ValueError, and
             # TypeError. They should be handled somehow.
 
-    def sendData(self, program_counter, data_type, data):
-        send_data = (program_counter, data_type, data)
+    def sendData(self, data_type, data):
+        pc = tuple(self.program_counter)
+        send_data = (pc, data_type, data)
         self.sendString(marshal.dumps(send_data))
 
-    def sendShare(self, program_counter, share):
+    def sendShare(self, share):
         """Send a share.
 
-        The program counter and the share are marshalled and sent to
-        the peer.
-
-        @param program_counter: the program counter associated with
-        the share.
-
-        @return: C{self} so that C{sendShare} can be used as a
-        callback.
+        The share is marshalled and sent to the peer along with the
+        current program counter for this connection.
         """
-        #println("Sending to id=%d: program_counter=%s, share=%s",
-        #        self.id, program_counter, share)
-
-        self.sendData(program_counter, "share", share.value)
+        self.sendData("share", share.value)
 
     def loseConnection(self):
         """Disconnect this protocol instance."""
@@ -326,12 +363,14 @@
     @wrapper(method)
     def inc_pc_wrapper(self, *args, **kwargs):
         try:
-            self.program_counter[-1] += 1
-            self.program_counter.append(0)
+            for protocol in self.protocols.itervalues():
+                protocol.program_counter[-1] += 1
+                protocol.program_counter.append(0)
             #println("Calling %s: %s", method.func_name, self.program_counter)
             return method(self, *args, **kwargs)
         finally:
-            self.program_counter.pop()
+            for protocol in self.protocols.itervalues():
+                protocol.program_counter.pop()
     return inc_pc_wrapper
 
 
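For reference, a self-contained sketch of what the decorator does. It
uses functools.wraps in place of the wrapper helper, and FakeProtocol
and FakeRuntime are hypothetical stand-ins for the real classes:

    from functools import wraps

    class FakeProtocol:
        def __init__(self):
            self.program_counter = [0]

    class FakeRuntime:
        def __init__(self):
            self.protocols = {1: FakeProtocol(), 2: FakeProtocol()}

    def increment_pc(method):
        @wraps(method)
        def inc_pc_wrapper(self, *args, **kwargs):
            try:
                # Enter a new operation on every connection.
                for protocol in self.protocols.values():
                    protocol.program_counter[-1] += 1
                    protocol.program_counter.append(0)
                return method(self, *args, **kwargs)
            finally:
                # Leave the operation again on every connection.
                for protocol in self.protocols.values():
                    protocol.program_counter.pop()
        return inc_pc_wrapper

    @increment_pc
    def some_operation(runtime):
        return [p.program_counter[:] for p in runtime.protocols.values()]

    runtime = FakeRuntime()
    assert some_operation(runtime) == [[1, 0], [1, 0]]
    assert some_operation(runtime) == [[2, 0], [2, 0]]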
@@ -391,51 +430,6 @@
             from twisted.internet import defer
             defer.setDebugging(True)
 
-        #: Current program counter.
-        #:
-        #: Whenever a share is sent over the network, it must be
-        #: uniquely identified so that the receiving player known what
-        #: operation the share is a result of. This is done by
-        #: associating a X{program counter} with each operation.
-        #:
-        #: Keeping the program counter synchronized between all
-        #: players ought to be easy, but because of the asynchronous
-        #: nature of network protocols, all players might not reach
-        #: the same parts of the program at the same time.
-        #:
-        #: Consider two players M{A} and M{B} who are both waiting on
-        #: the variables C{a} and C{b}. Callbacks have been added to
-        #: C{a} and C{b}, and the question is what program counter the
-        #: callbacks should use when sending data out over the
-        #: network.
-        #:
-        #: Let M{A} receive input for C{a} and then for C{b} a little
-        #: later, and let M{B} receive the inputs in reversed order so
-        #: that the input for C{b} arrives first. The goal is to keep
-        #: the program counters synchronized so that program counter
-        #: M{x} refers to the same operation on all players. Because
-        #: the inputs arrive in different order at different players,
-        #: incrementing a simple global counter is not enough.
-        #:
-        #: Instead, a I{tree} is made, which follows the tree of
-        #: execution. At the top level the program counter starts at
-        #: C{[0]}. At the next operation it becomes C{[1]}, and so on.
-        #: If a callback is scheduled (see L{callback}) at program
-        #: counter C{[x, y, z]}, any calls it makes will be numbered
-        #: C{[x, y, z, 1]}, then C{[x, y, z, 2]}, and so on.
-        #:
-        #: Maintaining such a tree of program counters ensures that
-        #: different parts of the program execution never reuses the
-        #: same program counter for different variables.
-        #:
-        #: The L{increment_pc} decorator is responsible for
-        #: dynamically building the tree as the execution unfolds and
-        #: L{callback} is responsible for scheduling callbacks with
-        #: the correct program counter.
-        #:
-        #: @type: C{list} of integers.
-        self.program_counter = [0]
-
         #: Connections to the other players.
         #:
         #: @type: C{dict} from Player ID to L{ShareExchanger} objects.
@@ -512,19 +506,27 @@
         # program counter. Simply decorating callback with increase_pc
         # does not seem to work (the multiplication benchmark hangs).
         # This should be fixed.
-        saved_pc = self.program_counter[:]
+
+        def get_pcs():
+            return [(protocol, protocol.program_counter[:]) for protocol
+                    in self.protocols.itervalues()]
+        def set_pcs(pcs):
+            for protocol, pc in pcs:
+                protocol.program_counter = pc
+
+        saved_pcs = get_pcs()
         #println("Saved PC: %s for %s", saved_pc, func.func_name)
 
         @wrapper(func)
         def callback_wrapper(*args, **kwargs):
             """Wrapper for a callback which ensures a correct PC."""
             try:
-                current_pc = self.program_counter
-                self.program_counter = saved_pc
+                current_pcs = get_pcs()
+                set_pcs(saved_pcs)
                 #println("Callback PC: %s", self.program_counter)
                 return func(*args, **kwargs)
             finally:
-                self.program_counter = current_pc
+                set_pcs(current_pcs)
 
         #println("Adding %s to %s", func.func_name, deferred)
         deferred.addCallback(callback_wrapper, *args, **kwargs)
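The same save/restore pattern, sketched with a plain function in place
of a Twisted Deferred callback (make_pc_preserving_callback is a
hypothetical name, not VIFF API):

    def make_pc_preserving_callback(runtime, func):
        # Freeze every connection's counter at scheduling time ...
        saved_pcs = [(p, p.program_counter[:])
                     for p in runtime.protocols.values()]

        def wrapped(result):
            # ... and reinstate it when the callback fires, so data sent
            # from inside the callback uses the counter that was current
            # when the callback was scheduled.
            current_pcs = [(p, p.program_counter[:])
                           for p in runtime.protocols.values()]
            for protocol, pc in saved_pcs:
                protocol.program_counter = pc
            try:
                return func(result)
            finally:
                for protocol, pc in current_pcs:
                    protocol.program_counter = pc

        return wrapped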
@@ -541,7 +543,7 @@
         assert peer_id != self.id, "Do not expect data from yourself!"
-        # Convert self.program_counter to a hashable value in order to
-        # use it as a key in self.protocols[peer_id].incoming_data.
-        pc = tuple(self.program_counter)
+        # Convert the connection's program counter to a hashable value
+        # in order to use it as a key in incoming_data.
+        pc = tuple(self.protocols[peer_id].program_counter)
         key = (pc, data_type)
 
         data = self.protocols[peer_id].incoming_data.pop(key, None)
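The (pc, data_type) key implements a symmetric rendezvous: whichever
side arrives first, the locally expected Deferred or the remotely sent
value, is parked under the key, and the latecomer picks it up. A
sketch of the idea (assuming Twisted's Deferred, and that each key is
used exactly once per direction):

    from twisted.internet.defer import Deferred

    incoming_data = {}

    def expect_data(pc, data_type):
        key = (tuple(pc), data_type)
        deferred = Deferred()
        value = incoming_data.pop(key, None)
        if value is None:
            incoming_data[key] = deferred  # data not here yet: park it
        else:
            deferred.callback(value)       # data arrived first: fire now
        return deferred

    def data_received(pc, data_type, value):
        key = (tuple(pc), data_type)
        waiting = incoming_data.pop(key, None)
        if waiting is None:
            incoming_data[key] = value     # nobody waiting: park the value
        else:
            waiting.callback(value)        # fire the parked Deferred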
@@ -564,8 +566,7 @@
             return Share(self, field_element.field, field_element)
         else:
             share = self._expect_share(id, field_element.field)
-            pc = tuple(self.program_counter)
-            self.protocols[id].sendShare(pc, field_element)
+            self.protocols[id].sendShare(field_element)
             return share
 
     def _expect_share(self, peer_id, field):
@@ -639,8 +640,7 @@
             # Send share to all receivers.
             for id in receivers:
                 if id != self.id:
-                    pc = tuple(self.program_counter)
-                    self.protocols[id].sendShare(pc, share)
+                    self.protocols[id].sendShare(share)
             # Receive and recombine shares if this player is a receiver.
             if self.id in receivers:
                 deferreds = []
@@ -767,7 +767,8 @@
         n = self.num_players
 
         # Key used for PRSS.
-        key = tuple(self.program_counter)
+        key = tuple([tuple(p.program_counter) for p
+                     in self.protocols.itervalues()])
 
         # The shares for which we have all the keys.
         all_shares = []
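The nesting of tuple() calls is what makes the key usable: the
per-connection counters are lists, which are unhashable, so they are
frozen into nested tuples before being used as the PRSS key. A tiny
illustration:

    pcs = [[1, 2], [1, 2], [1, 3]]
    key = tuple(tuple(pc) for pc in pcs)
    assert key == ((1, 2), (1, 2), (1, 3))
    hash(key)   # fine; hash(pcs) would raise TypeError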
@@ -786,10 +787,9 @@
             correction = element - shared
             # if this player is inputter then broadcast correction value
             # TODO: more efficient broadcast?
-            pc = tuple(self.program_counter)
             for id in self.players:
                 if self.id != id:
-                    self.protocols[id].sendShare(pc, correction)
+                    self.protocols[id].sendShare(correction)
 
         # Receive correction value from inputters and compute share.
         result = []
@@ -823,7 +823,8 @@
             modulus = field.modulus
 
         # Key used for PRSS.
-        prss_key = tuple(self.program_counter)
+        prss_key = tuple([tuple(p.program_counter) for p
+                          in self.protocols.itervalues()])
         prfs = self.players[self.id].prfs(modulus)
         share = prss(self.num_players, self.id, field, prfs, prss_key)
 
@@ -877,14 +878,13 @@
         results = []
         for peer_id in inputters:
             if peer_id == self.id:
-                pc = tuple(self.program_counter)
                 shares = shamir.share(field(number), self.threshold,
                                       self.num_players)
                 for other_id, share in shares:
                     if other_id.value == self.id:
                         results.append(Share(self, share.field, share))
                     else:
-                        self.protocols[other_id.value].sendShare(pc, share)
+                        self.protocols[other_id.value].sendShare(share)
             else:
                 results.append(self._expect_share(peer_id, field))
 
@@ -909,7 +909,8 @@
         """
 
         result = Deferred()
-        pc = tuple(self.program_counter)
+        pc = tuple([tuple(p.program_counter) for p
+                    in self.protocols.itervalues()])
         n = self.num_players
         t = self.threshold
 
--- a/viff/test/test_basic_runtime.py	Sat Mar 22 15:57:14 2008 +0100
+++ b/viff/test/test_basic_runtime.py	Mon Mar 24 01:04:29 2008 +0100
@@ -26,9 +26,14 @@
 class ProgramCounterTest(RuntimeTestCase):
     """Program counter tests."""
 
+    def assert_pc(self, runtime, pc):
+        """Assert that all protocols has a given program counter."""
+        for p in runtime.protocols.itervalues():
+            self.assertEquals(p.program_counter, pc)
+
     @protocol
     def test_initial_value(self, runtime):
-        self.assertEquals(runtime.program_counter, [0])
+        self.assert_pc(runtime, [0])
 
     @protocol
     def test_simple_operation(self, runtime):
@@ -37,9 +42,9 @@
         Each call should increment the program counter by one.
         """
         runtime.synchronize()
-        self.assertEquals(runtime.program_counter, [1])
+        self.assert_pc(runtime, [1])
         runtime.synchronize()
-        self.assertEquals(runtime.program_counter, [2])
+        self.assert_pc(runtime, [2])
 
     @protocol
     def test_complex_operation(self, runtime):
@@ -51,9 +56,9 @@
         # Exclusive-or is calculated as x + y - 2 * x * y, so add,
         # sub, and mul are called.
         runtime.xor(self.Zp(0), self.Zp(1))
-        self.assertEquals(runtime.program_counter, [1])
+        self.assert_pc(runtime, [1])
         runtime.xor(self.Zp(0), self.Zp(1))
-        self.assertEquals(runtime.program_counter, [2])
+        self.assert_pc(runtime, [2])
 
     @protocol
     def test_callback(self, runtime):
@@ -64,13 +69,13 @@
         """
 
         def verify_program_counter(_):
-            self.assertEquals(runtime.program_counter, [0])
+            self.assert_pc(runtime, [0])
 
         d = Deferred()
         runtime.schedule_callback(d, verify_program_counter)
 
         runtime.synchronize()
-        self.assertEquals(runtime.program_counter, [1])
+        self.assert_pc(runtime, [1])
 
         # Now trigger verify_program_counter.
         d.callback(None)
@@ -87,26 +92,26 @@
             # First top-level call, so first entry is 1. No calls to
             # other methods decorated with increment_pc has been made,
             # so the second entry is 0.
-            self.assertEquals(runtime.program_counter, [1, 0])
+            self.assert_pc(runtime, [1, 0])
             method_b(runtime, 1)
 
-            self.assertEquals(runtime.program_counter, [1, 1])
+            self.assert_pc(runtime, [1, 1])
             method_b(runtime, 2)
 
             # At this point two sub-calls has been made:
-            self.assertEquals(runtime.program_counter, [1, 2])
+            self.assert_pc(runtime, [1, 2])
 
         @increment_pc
         def method_b(runtime, count):
             # This method is called twice from method_a:
-            self.assertEquals(runtime.program_counter, [1, count, 0])
+            self.assert_pc(runtime, [1, count, 0])
 
         # Zero top-level calls:
-        self.assertEquals(runtime.program_counter, [0])
+        self.assert_pc(runtime, [0])
         method_a(runtime)
 
         # One top-level call:
-        self.assertEquals(runtime.program_counter, [1])
+        self.assert_pc(runtime, [1])
 
     @protocol
     def test_multiple_callbacks(self, runtime):
@@ -115,11 +120,11 @@
         d2 = Deferred()
 
         def verify_program_counter(_, count):
-            self.assertEquals(runtime.program_counter, [1, count, 0])
+            self.assert_pc(runtime, [1, count, 0])
 
         @increment_pc
         def method_a(runtime):
-            self.assertEquals(runtime.program_counter, [1, 0])
+            self.assert_pc(runtime, [1, 0])
 
             runtime.schedule_callback(d1, verify_program_counter, 1)
             runtime.schedule_callback(d2, verify_program_counter, 2)