viff

changeset 599:46bd23052806

Introduce a program counter per connection. This change does not alter observable behavior, but it is a first step towards resolving Issue 22. The next step is to make sendData and _expect_data increment just one program counter.
author Martin Geisler <mg@daimi.au.dk>
date Mon, 24 Mar 2008 01:04:29 +0100
parents 79411fe0da54
children 9b04a55123de e5273143bf1f c0251598d6a3
files viff/runtime.py viff/test/test_basic_runtime.py
diffstat 2 files changed, 100 insertions(+), 94 deletions(-)
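
The tree of program counters that this changeset moves onto each connection can be illustrated with a small stand-alone sketch. Connection and Demo below are illustrative stand-ins (the real code lives in viff.runtime as ShareExchanger and Runtime), and the sketch is written against Python 3 with functools.wraps in place of the wrapper decorator used in viff/runtime.py:

    from functools import wraps

    class Connection(object):
        """Stand-in for a ShareExchanger: one program counter per connection."""
        def __init__(self):
            self.program_counter = [0]

    def increment_pc(method):
        """Enter a new level of the program counter tree for the call."""
        @wraps(method)
        def inc_pc_wrapper(self, *args, **kwargs):
            try:
                for protocol in self.protocols.values():
                    protocol.program_counter[-1] += 1
                    protocol.program_counter.append(0)
                return method(self, *args, **kwargs)
            finally:
                for protocol in self.protocols.values():
                    protocol.program_counter.pop()
        return inc_pc_wrapper

    class Demo(object):
        def __init__(self):
            self.protocols = {1: Connection(), 2: Connection()}

        @increment_pc
        def method_a(self):
            print(self.protocols[1].program_counter)  # [1, 0]
            self.method_b()                           # runs at [1, 1, 0]
            self.method_b()                           # runs at [1, 2, 0]

        @increment_pc
        def method_b(self):
            print(self.protocols[1].program_counter)  # [1, 1, 0], then [1, 2, 0]

    demo = Demo()
    demo.method_a()
    print(demo.protocols[1].program_counter)          # back to [1]

Both connections stay in lockstep because every decorated call updates the counter of every connection, which is what the new inc_pc_wrapper in viff/runtime.py does; the counter values printed here match the ones asserted in test_nested_calls below.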
line diff
     1.1 --- a/viff/runtime.py	Sat Mar 22 15:57:14 2008 +0100
     1.2 +++ b/viff/runtime.py	Mon Mar 24 01:04:29 2008 +0100
     1.3 @@ -222,6 +222,51 @@
     1.4          #: deferred data.
     1.5          self.incoming_data = {}
     1.6  
     1.7 +        #: Program counter for this connection.
     1.8 +        #:
     1.9 +        #: Whenever a share is sent over the network, it must be
     1.10 +        #: uniquely identified so that the receiving player knows what
    1.11 +        #: operation the share is a result of. This is done by
    1.12 +        #: associating a X{program counter} with each operation.
    1.13 +        #:
    1.14 +        #: Keeping the program counter synchronized between all
    1.15 +        #: players ought to be easy, but because of the asynchronous
    1.16 +        #: nature of network protocols, all players might not reach
    1.17 +        #: the same parts of the program at the same time.
    1.18 +        #:
    1.19 +        #: Consider two players M{A} and M{B} who are both waiting on
    1.20 +        #: the variables C{a} and C{b}. Callbacks have been added to
    1.21 +        #: C{a} and C{b}, and the question is what program counter the
    1.22 +        #: callbacks should use when sending data out over the
    1.23 +        #: network.
    1.24 +        #:
    1.25 +        #: Let M{A} receive input for C{a} and then for C{b} a little
    1.26 +        #: later, and let M{B} receive the inputs in reversed order so
    1.27 +        #: that the input for C{b} arrives first. The goal is to keep
    1.28 +        #: the program counters synchronized so that program counter
    1.29 +        #: M{x} refers to the same operation on all players. Because
    1.30 +        #: the inputs arrive in different order at different players,
    1.31 +        #: incrementing a simple global counter is not enough.
    1.32 +        #:
    1.33 +        #: Instead, a I{tree} is made, which follows the tree of
    1.34 +        #: execution. At the top level the program counter starts at
    1.35 +        #: C{[0]}. At the next operation it becomes C{[1]}, and so on.
    1.36 +        #: If a callback is scheduled (see L{callback}) at program
    1.37 +        #: counter C{[x, y, z]}, any calls it makes will be numbered
    1.38 +        #: C{[x, y, z, 1]}, then C{[x, y, z, 2]}, and so on.
    1.39 +        #:
    1.40 +        #: Maintaining such a tree of program counters ensures that
     1.41 +        #: different parts of the program execution never reuse the
    1.42 +        #: same program counter for different variables.
    1.43 +        #:
    1.44 +        #: The L{increment_pc} decorator is responsible for
    1.45 +        #: dynamically building the tree as the execution unfolds and
    1.46 +        #: L{Runtime.schedule_callback} is responsible for scheduling
    1.47 +        #: callbacks with the correct program counter.
    1.48 +        #:
    1.49 +        #: @type: C{list} of integers.
    1.50 +        self.program_counter = [0]
    1.51 +
    1.52      def connectionMade(self):
    1.53          #print "Transport:", self.transport
    1.54          self.sendString(str(self.factory.runtime.id))
    1.55 @@ -268,26 +313,18 @@
    1.56              # TODO: marshal.loads can raise EOFError, ValueError, and
    1.57              # TypeError. They should be handled somehow.
    1.58  
    1.59 -    def sendData(self, program_counter, data_type, data):
    1.60 -        send_data = (program_counter, data_type, data)
    1.61 +    def sendData(self, data_type, data):
    1.62 +        pc = tuple(self.program_counter)
    1.63 +        send_data = (pc, data_type, data)
    1.64          self.sendString(marshal.dumps(send_data))
    1.65  
    1.66 -    def sendShare(self, program_counter, share):
    1.67 +    def sendShare(self, share):
    1.68          """Send a share.
    1.69  
    1.70 -        The program counter and the share are marshalled and sent to
    1.71 -        the peer.
    1.72 -
    1.73 -        @param program_counter: the program counter associated with
    1.74 -        the share.
    1.75 -
    1.76 -        @return: C{self} so that C{sendShare} can be used as a
    1.77 -        callback.
    1.78 +        The share is marshalled and sent to the peer along with the
    1.79 +        current program counter for this connection.
    1.80          """
    1.81 -        #println("Sending to id=%d: program_counter=%s, share=%s",
    1.82 -        #        self.id, program_counter, share)
    1.83 -
    1.84 -        self.sendData(program_counter, "share", share.value)
    1.85 +        self.sendData("share", share.value)
    1.86  
    1.87      def loseConnection(self):
    1.88          """Disconnect this protocol instance."""
    1.89 @@ -326,12 +363,14 @@
    1.90      @wrapper(method)
    1.91      def inc_pc_wrapper(self, *args, **kwargs):
    1.92          try:
    1.93 -            self.program_counter[-1] += 1
    1.94 -            self.program_counter.append(0)
    1.95 +            for protocol in self.protocols.itervalues():
    1.96 +                protocol.program_counter[-1] += 1
    1.97 +                protocol.program_counter.append(0)
    1.98              #println("Calling %s: %s", method.func_name, self.program_counter)
    1.99              return method(self, *args, **kwargs)
   1.100          finally:
   1.101 -            self.program_counter.pop()
   1.102 +            for protocol in self.protocols.itervalues():
   1.103 +                protocol.program_counter.pop()
   1.104      return inc_pc_wrapper
   1.105  
   1.106  
   1.107 @@ -391,51 +430,6 @@
   1.108              from twisted.internet import defer
   1.109              defer.setDebugging(True)
   1.110  
   1.111 -        #: Current program counter.
   1.112 -        #:
   1.113 -        #: Whenever a share is sent over the network, it must be
   1.114 -        #: uniquely identified so that the receiving player known what
   1.115 -        #: operation the share is a result of. This is done by
   1.116 -        #: associating a X{program counter} with each operation.
   1.117 -        #:
   1.118 -        #: Keeping the program counter synchronized between all
   1.119 -        #: players ought to be easy, but because of the asynchronous
   1.120 -        #: nature of network protocols, all players might not reach
   1.121 -        #: the same parts of the program at the same time.
   1.122 -        #:
   1.123 -        #: Consider two players M{A} and M{B} who are both waiting on
   1.124 -        #: the variables C{a} and C{b}. Callbacks have been added to
   1.125 -        #: C{a} and C{b}, and the question is what program counter the
   1.126 -        #: callbacks should use when sending data out over the
   1.127 -        #: network.
   1.128 -        #:
   1.129 -        #: Let M{A} receive input for C{a} and then for C{b} a little
   1.130 -        #: later, and let M{B} receive the inputs in reversed order so
   1.131 -        #: that the input for C{b} arrives first. The goal is to keep
   1.132 -        #: the program counters synchronized so that program counter
   1.133 -        #: M{x} refers to the same operation on all players. Because
   1.134 -        #: the inputs arrive in different order at different players,
   1.135 -        #: incrementing a simple global counter is not enough.
   1.136 -        #:
   1.137 -        #: Instead, a I{tree} is made, which follows the tree of
   1.138 -        #: execution. At the top level the program counter starts at
   1.139 -        #: C{[0]}. At the next operation it becomes C{[1]}, and so on.
   1.140 -        #: If a callback is scheduled (see L{callback}) at program
   1.141 -        #: counter C{[x, y, z]}, any calls it makes will be numbered
   1.142 -        #: C{[x, y, z, 1]}, then C{[x, y, z, 2]}, and so on.
   1.143 -        #:
   1.144 -        #: Maintaining such a tree of program counters ensures that
   1.145 -        #: different parts of the program execution never reuses the
   1.146 -        #: same program counter for different variables.
   1.147 -        #:
   1.148 -        #: The L{increment_pc} decorator is responsible for
   1.149 -        #: dynamically building the tree as the execution unfolds and
   1.150 -        #: L{callback} is responsible for scheduling callbacks with
   1.151 -        #: the correct program counter.
   1.152 -        #:
   1.153 -        #: @type: C{list} of integers.
   1.154 -        self.program_counter = [0]
   1.155 -
   1.156          #: Connections to the other players.
   1.157          #:
   1.158          #: @type: C{dict} from Player ID to L{ShareExchanger} objects.
   1.159 @@ -512,19 +506,27 @@
   1.160          # program counter. Simply decorating callback with increase_pc
   1.161          # does not seem to work (the multiplication benchmark hangs).
   1.162          # This should be fixed.
   1.163 -        saved_pc = self.program_counter[:]
   1.164 +
   1.165 +        def get_pcs():
   1.166 +            return [(protocol, protocol.program_counter[:]) for protocol
   1.167 +                    in self.protocols.itervalues()]
   1.168 +        def set_pcs(pcs):
   1.169 +            for protocol, pc in pcs:
   1.170 +                protocol.program_counter = pc
   1.171 +
   1.172 +        saved_pcs = get_pcs()
   1.173          #println("Saved PC: %s for %s", saved_pc, func.func_name)
   1.174  
   1.175          @wrapper(func)
   1.176          def callback_wrapper(*args, **kwargs):
   1.177              """Wrapper for a callback which ensures a correct PC."""
   1.178              try:
   1.179 -                current_pc = self.program_counter
   1.180 -                self.program_counter = saved_pc
   1.181 +                current_pcs = get_pcs()
   1.182 +                set_pcs(saved_pcs)
   1.183                  #println("Callback PC: %s", self.program_counter)
   1.184                  return func(*args, **kwargs)
   1.185              finally:
   1.186 -                self.program_counter = current_pc
   1.187 +                set_pcs(current_pcs)
   1.188  
   1.189          #println("Adding %s to %s", func.func_name, deferred)
   1.190          deferred.addCallback(callback_wrapper, *args, **kwargs)
   1.191 @@ -541,7 +543,7 @@
   1.192          assert peer_id != self.id, "Do not expect data from yourself!"
   1.193          # Convert self.program_counter to a hashable value in order to
   1.194          # use it as a key in self.protocols[peer_id].incoming_data.
   1.195 -        pc = tuple(self.program_counter)
   1.196 +        pc = tuple(self.protocols[peer_id].program_counter)
   1.197          key = (pc, data_type)
   1.198  
   1.199          data = self.protocols[peer_id].incoming_data.pop(key, None)
   1.200 @@ -564,8 +566,7 @@
   1.201              return Share(self, field_element.field, field_element)
   1.202          else:
   1.203              share = self._expect_share(id, field_element.field)
   1.204 -            pc = tuple(self.program_counter)
   1.205 -            self.protocols[id].sendShare(pc, field_element)
   1.206 +            self.protocols[id].sendShare(field_element)
   1.207              return share
   1.208  
   1.209      def _expect_share(self, peer_id, field):
   1.210 @@ -639,8 +640,7 @@
   1.211              # Send share to all receivers.
   1.212              for id in receivers:
   1.213                  if id != self.id:
   1.214 -                    pc = tuple(self.program_counter)
   1.215 -                    self.protocols[id].sendShare(pc, share)
   1.216 +                    self.protocols[id].sendShare(share)
   1.217              # Receive and recombine shares if this player is a receiver.
   1.218              if self.id in receivers:
   1.219                  deferreds = []
   1.220 @@ -767,7 +767,8 @@
   1.221          n = self.num_players
   1.222  
   1.223          # Key used for PRSS.
   1.224 -        key = tuple(self.program_counter)
   1.225 +        key = tuple([tuple(p.program_counter) for p
   1.226 +                     in self.protocols.itervalues()])
   1.227  
   1.228          # The shares for which we have all the keys.
   1.229          all_shares = []
   1.230 @@ -786,10 +787,9 @@
   1.231              correction = element - shared
   1.232              # if this player is inputter then broadcast correction value
   1.233              # TODO: more efficient broadcast?
   1.234 -            pc = tuple(self.program_counter)
   1.235              for id in self.players:
   1.236                  if self.id != id:
   1.237 -                    self.protocols[id].sendShare(pc, correction)
   1.238 +                    self.protocols[id].sendShare(correction)
   1.239  
   1.240          # Receive correction value from inputters and compute share.
   1.241          result = []
   1.242 @@ -823,7 +823,8 @@
   1.243              modulus = field.modulus
   1.244  
   1.245          # Key used for PRSS.
   1.246 -        prss_key = tuple(self.program_counter)
   1.247 +        prss_key = tuple([tuple(p.program_counter) for p
   1.248 +                          in self.protocols.itervalues()])
   1.249          prfs = self.players[self.id].prfs(modulus)
   1.250          share = prss(self.num_players, self.id, field, prfs, prss_key)
   1.251  
   1.252 @@ -877,14 +878,13 @@
   1.253          results = []
   1.254          for peer_id in inputters:
   1.255              if peer_id == self.id:
   1.256 -                pc = tuple(self.program_counter)
   1.257                  shares = shamir.share(field(number), self.threshold,
   1.258                                        self.num_players)
   1.259                  for other_id, share in shares:
   1.260                      if other_id.value == self.id:
   1.261                          results.append(Share(self, share.field, share))
   1.262                      else:
   1.263 -                        self.protocols[other_id.value].sendShare(pc, share)
   1.264 +                        self.protocols[other_id.value].sendShare(share)
   1.265              else:
   1.266                  results.append(self._expect_share(peer_id, field))
   1.267  
   1.268 @@ -909,7 +909,8 @@
   1.269          """
   1.270  
   1.271          result = Deferred()
   1.272 -        pc = tuple(self.program_counter)
   1.273 +        pc = tuple([tuple(p.program_counter) for p
   1.274 +                    in self.protocols.itervalues()])
   1.275          n = self.num_players
   1.276          t = self.threshold
   1.277  
     2.1 --- a/viff/test/test_basic_runtime.py	Sat Mar 22 15:57:14 2008 +0100
     2.2 +++ b/viff/test/test_basic_runtime.py	Mon Mar 24 01:04:29 2008 +0100
     2.3 @@ -26,9 +26,14 @@
     2.4  class ProgramCounterTest(RuntimeTestCase):
     2.5      """Program counter tests."""
     2.6  
     2.7 +    def assert_pc(self, runtime, pc):
     2.8 +        """Assert that all protocols has a given program counter."""
     2.9 +        for p in runtime.protocols.itervalues():
    2.10 +            self.assertEquals(p.program_counter, pc)
    2.11 +
    2.12      @protocol
    2.13      def test_initial_value(self, runtime):
    2.14 -        self.assertEquals(runtime.program_counter, [0])
    2.15 +        self.assert_pc(runtime, [0])
    2.16  
    2.17      @protocol
    2.18      def test_simple_operation(self, runtime):
    2.19 @@ -37,9 +42,9 @@
    2.20          Each call should increment the program counter by one.
    2.21          """
    2.22          runtime.synchronize()
    2.23 -        self.assertEquals(runtime.program_counter, [1])
    2.24 +        self.assert_pc(runtime, [1])
    2.25          runtime.synchronize()
    2.26 -        self.assertEquals(runtime.program_counter, [2])
    2.27 +        self.assert_pc(runtime, [2])
    2.28  
    2.29      @protocol
    2.30      def test_complex_operation(self, runtime):
    2.31 @@ -51,9 +56,9 @@
    2.32          # Exclusive-or is calculated as x + y - 2 * x * y, so add,
    2.33          # sub, and mul are called.
    2.34          runtime.xor(self.Zp(0), self.Zp(1))
    2.35 -        self.assertEquals(runtime.program_counter, [1])
    2.36 +        self.assert_pc(runtime, [1])
    2.37          runtime.xor(self.Zp(0), self.Zp(1))
    2.38 -        self.assertEquals(runtime.program_counter, [2])
    2.39 +        self.assert_pc(runtime, [2])
    2.40  
    2.41      @protocol
    2.42      def test_callback(self, runtime):
    2.43 @@ -64,13 +69,13 @@
    2.44          """
    2.45  
    2.46          def verify_program_counter(_):
    2.47 -            self.assertEquals(runtime.program_counter, [0])
    2.48 +            self.assert_pc(runtime, [0])
    2.49  
    2.50          d = Deferred()
    2.51          runtime.schedule_callback(d, verify_program_counter)
    2.52  
    2.53          runtime.synchronize()
    2.54 -        self.assertEquals(runtime.program_counter, [1])
    2.55 +        self.assert_pc(runtime, [1])
    2.56  
    2.57          # Now trigger verify_program_counter.
    2.58          d.callback(None)
    2.59 @@ -87,26 +92,26 @@
    2.60              # First top-level call, so first entry is 1. No calls to
    2.61              # other methods decorated with increment_pc has been made,
    2.62              # so the second entry is 0.
    2.63 -            self.assertEquals(runtime.program_counter, [1, 0])
    2.64 +            self.assert_pc(runtime, [1, 0])
    2.65              method_b(runtime, 1)
    2.66  
    2.67 -            self.assertEquals(runtime.program_counter, [1, 1])
    2.68 +            self.assert_pc(runtime, [1, 1])
    2.69              method_b(runtime, 2)
    2.70  
    2.71              # At this point two sub-calls has been made:
    2.72 -            self.assertEquals(runtime.program_counter, [1, 2])
    2.73 +            self.assert_pc(runtime, [1, 2])
    2.74  
    2.75          @increment_pc
    2.76          def method_b(runtime, count):
    2.77              # This method is called twice from method_a:
    2.78 -            self.assertEquals(runtime.program_counter, [1, count, 0])
    2.79 +            self.assert_pc(runtime, [1, count, 0])
    2.80  
    2.81          # Zero top-level calls:
    2.82 -        self.assertEquals(runtime.program_counter, [0])
    2.83 +        self.assert_pc(runtime, [0])
    2.84          method_a(runtime)
    2.85  
    2.86          # One top-level call:
    2.87 -        self.assertEquals(runtime.program_counter, [1])
    2.88 +        self.assert_pc(runtime, [1])
    2.89  
    2.90      @protocol
    2.91      def test_multiple_callbacks(self, runtime):
    2.92 @@ -115,11 +120,11 @@
    2.93          d2 = Deferred()
    2.94  
    2.95          def verify_program_counter(_, count):
    2.96 -            self.assertEquals(runtime.program_counter, [1, count, 0])
    2.97 +            self.assert_pc(runtime, [1, count, 0])
    2.98  
    2.99          @increment_pc
   2.100          def method_a(runtime):
   2.101 -            self.assertEquals(runtime.program_counter, [1, 0])
   2.102 +            self.assert_pc(runtime, [1, 0])
   2.103  
   2.104              runtime.schedule_callback(d1, verify_program_counter, 1)
   2.105              runtime.schedule_callback(d2, verify_program_counter, 2)
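
For the schedule_callback part of the change, the save/restore of the per-connection counters can be sketched in isolation as follows. Twisted Deferreds are replaced by a plain list of pending callables, and Runtime/Connection are again illustrative stand-ins rather than the real viff.runtime classes:

    class Connection(object):
        def __init__(self):
            self.program_counter = [0]

    class Runtime(object):
        def __init__(self):
            self.protocols = {1: Connection(), 2: Connection()}
            self.pending = []   # stand-in for Deferred callbacks

        def schedule_callback(self, func, *args, **kwargs):
            # Snapshot the counter of every connection at scheduling time...
            saved_pcs = [(p, p.program_counter[:])
                         for p in self.protocols.values()]

            def callback_wrapper():
                # ...and reinstate the snapshot while the callback runs, so
                # any data it sends is tagged with the counters from the
                # point where it was scheduled.
                current_pcs = [(p, p.program_counter[:])
                               for p in self.protocols.values()]
                try:
                    for p, pc in saved_pcs:
                        p.program_counter = pc
                    return func(*args, **kwargs)
                finally:
                    for p, pc in current_pcs:
                        p.program_counter = pc

            self.pending.append(callback_wrapper)

    runtime = Runtime()
    runtime.schedule_callback(
        lambda: print(runtime.protocols[1].program_counter))  # prints [0]
    for p in runtime.protocols.values():
        p.program_counter[-1] += 1     # the program moves on to [1]
    runtime.pending.pop(0)()           # the callback still sees [0]
    print(runtime.protocols[1].program_counter)               # restored to [1]

This mirrors the get_pcs/set_pcs helpers introduced in Runtime.schedule_callback above; test_callback and test_multiple_callbacks exercise the same behaviour through real Deferreds.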