viff

changeset 1527:7dd55d319d2b

BeDOZa: Added size constraints to _mul.
author Janus Dam Nielsen <janus.nielsen@alexandra.dk>
date Thu, 29 Jul 2010 15:56:58 +0200
parents a87fd09f8c38
children e7b2fe7eb753
files viff/bedoza/bedoza_triple.py viff/test/test_bedoza_triple.py
diffstat 2 files changed, 73 insertions(+), 28 deletions(-) [+]
line diff
     1.1 --- a/viff/bedoza/bedoza_triple.py	Thu Jul 29 11:31:24 2010 +0200
     1.2 +++ b/viff/bedoza/bedoza_triple.py	Thu Jul 29 15:56:58 2010 +0200
     1.3 @@ -158,7 +158,7 @@
     1.4          #     receive c from player i and set 
     1.5          #         m^i=Decrypt(c)
     1.6      
     1.7 -    def _mul(self, inx, jnx, ais=None, cjs=None):
     1.8 +    def _mul(self, inx, jnx, n, ais=None, cjs=None):
     1.9          """Multiply each of the field elements in *ais* with the
    1.10          corresponding encrypted elements in *cjs*.
    1.11          
    1.12 @@ -167,6 +167,21 @@
    1.13          CKIND = 1
    1.14          DiKIND = 2
    1.15          DjKIND = 3
    1.16 +
    1.17 +        """The transmission_restraint_constant is the number of
    1.18 +        encrypted shares we can safely transmit in one call to
    1.19 +        sendData. The sendData method can only transmit up to
    1.20 +        65536 bytes.
    1.21 +        The constant has been imperically determined by running
    1.22 +        TripleGenerator.generate_triples.
    1.23 +        TODO: How can we allow a user of the runtime to adjust this
    1.24 +        constraint at a higher level of abstraction?
    1.25 +        """
    1.26 +        transmission_restraint_constant = 425
    1.27 +
    1.28 +        number_of_packets = n / transmission_restraint_constant
    1.29 +        if n % transmission_restraint_constant != 0:
    1.30 +            number_of_packets += 1
    1.31          
    1.32          self.runtime.increment_pc()
    1.33  
    1.34 @@ -176,61 +191,86 @@
    1.35          zis = []
    1.36          if self.runtime.id == inx:
    1.37              Nj_square = self.paillier.get_modulus_square(jnx)
    1.38 -            cs = []
    1.39 -            dis = []
    1.40 -            for ai, cj in zip(ais, cjs):
    1.41 +            all_cs = []
    1.42 +            all_dis = []
    1.43 +            for iny, (ai, cj) in enumerate(zip(ais, cjs)):
    1.44 +                if iny % transmission_restraint_constant == 0:
    1.45 +                    cs = []
    1.46 +                    all_cs.append(cs)
    1.47 +                    dis = []
    1.48 +                    all_dis.append(dis)
    1.49                  u = rand.randint(0, self.u_bound)
    1.50                  Ej_u = self.paillier.encrypt(u, jnx)
    1.51                  cs.append( (pow(cj, ai.value, Nj_square) * Ej_u) % Nj_square )
    1.52                  zi = self.Zp(-u)
    1.53                  zis.append(zi)
    1.54                  dis.append(self.paillier.encrypt(zi.value, inx))
    1.55 -            self.runtime.protocols[jnx].sendData(pc, CKIND, str(cs))
    1.56 +                
    1.57 +            for cs in all_cs:
    1.58 +                self.runtime.protocols[jnx].sendData(pc, CKIND, str(cs))
    1.59  
    1.60 -            for player_id in self.runtime.players:
    1.61 -                self.runtime.protocols[player_id].sendData(pc, DiKIND, str(dis))
    1.62 +            for dis in all_dis:
    1.63 +                for player_id in self.runtime.players:
    1.64 +                    self.runtime.protocols[player_id].sendData(pc, DiKIND, str(dis))
    1.65  
    1.66          if self.runtime.id == jnx:
    1.67 -            cs = Deferred()
    1.68 -            self.runtime._expect_data(inx, CKIND, cs)
    1.69 -            def decrypt(cs, pc, zis):
    1.70 +            all_cs = []
    1.71 +            for _ in xrange(number_of_packets):
    1.72 +                cs = Deferred()
    1.73 +                self.runtime._expect_data(inx, CKIND, cs)
    1.74 +                all_cs.append(cs)
    1.75 +                
    1.76 +            def decrypt(all_cs, pc, zis):
    1.77                  zjs = []
    1.78 -                djs = []
    1.79 -                for c in eval(cs):
    1.80 +                cs = reduce(lambda x, y: x + eval(y), all_cs, [])
    1.81 +                all_djs = []
    1.82 +                for iny, c in enumerate(cs):
    1.83 +                    if iny % transmission_restraint_constant == 0:
    1.84 +                        djs = []
    1.85 +                        all_djs.append(djs)
    1.86                      t = self.paillier.decrypt(c)
    1.87                      zj = self.Zp(t)
    1.88                      zjs.append(zj)
    1.89                      djs.append(self.paillier.encrypt(zj.value, jnx))
    1.90 -                for player_id in self.runtime.players:
    1.91 -                    self.runtime.protocols[player_id].sendData(pc, DjKIND, str(djs))
    1.92 +                for djs in all_djs:
    1.93 +                    for player_id in self.runtime.players:
    1.94 +                        self.runtime.protocols[player_id].sendData(pc, DjKIND, str(djs))
    1.95                  if not zis == []:
    1.96                      return [x + y for x, y in zip(zis, zjs)]
    1.97                  else:
    1.98                      return zjs 
    1.99 -            cs.addCallback(decrypt, pc, zis)
   1.100 -            deferreds.append(cs)
   1.101 +            all_cs_d = gatherResults(all_cs)
   1.102 +            all_cs_d.addCallback(decrypt, pc, zis)
   1.103 +            deferreds.append(all_cs_d)
   1.104          else:
   1.105              zis_deferred = Deferred()
   1.106              zis_deferred.callback(zis)
   1.107              deferreds.append(zis_deferred)
   1.108  
   1.109 -        dis = Deferred()
   1.110 -        self.runtime._expect_data(inx, DiKIND, dis)
   1.111 -        djs = Deferred()
   1.112 -        self.runtime._expect_data(jnx, DjKIND, djs)
   1.113 +        all_dis = []
   1.114 +        for _ in xrange(number_of_packets):
   1.115 +            dis = Deferred()
   1.116 +            self.runtime._expect_data(inx, DiKIND, dis)
   1.117 +            all_dis.append(dis)
   1.118 +        all_djs = []
   1.119 +        for _ in xrange(number_of_packets):
   1.120 +            djs = Deferred()
   1.121 +            self.runtime._expect_data(jnx, DjKIND, djs)
   1.122 +            all_djs.append(djs)
   1.123  
   1.124 -        deferreds.append(dis)
   1.125 -        deferreds.append(djs)
   1.126 +        deferreds.append(gatherResults(all_dis))
   1.127 +        deferreds.append(gatherResults(all_djs))
   1.128          r = gatherResults(deferreds)
   1.129          def wrap((values, dis, djs), inx, jnx):
   1.130 -            dis = eval(dis)
   1.131 -            djs = eval(djs)
   1.132 +            dis = reduce(lambda x, y: x + eval(y), dis, [])
   1.133 +            djs = reduce(lambda x, y: x + eval(y), djs, [])
   1.134              n_square_i = self.paillier.get_modulus_square(inx)
   1.135              n_square_j = self.paillier.get_modulus_square(jnx)
   1.136              N_squared_list = [self.paillier.get_modulus_square(player_id) for player_id in self.runtime.players]
   1.137              ps = []
   1.138 +            
   1.139              for v, di, dj in itertools.izip_longest(values, dis, djs, fillvalue=self.Zp(0)):
   1.140 -                value = v 
   1.141 +                value = v
   1.142                  enc_shares = len(self.runtime.players) * [1]
   1.143                  enc_shares[inx - 1] = (enc_shares[inx - 1] * di) % n_square_i
   1.144                  enc_shares[jnx - 1] = (enc_shares[jnx - 1] * dj) % n_square_j
   1.145 @@ -261,6 +301,7 @@
   1.146                  for jnx in xrange(0, len(self.runtime.players)):
   1.147                      deferreds.append(self._mul(inx + 1,
   1.148                                                 jnx + 1,
   1.149 +                                               len(a_values),
   1.150                                                 a_values,
   1.151                                                 b_enc_shares[jnx]))
   1.152                          
     2.1 --- a/viff/test/test_bedoza_triple.py	Thu Jul 29 11:31:24 2010 +0200
     2.2 +++ b/viff/test/test_bedoza_triple.py	Thu Jul 29 15:56:58 2010 +0200
     2.3 @@ -466,9 +466,11 @@
     2.4          cs = []
     2.5          for ai in ais:
     2.6              cs.append(triple_generator.paillier.encrypt(b2.value, 2))      
     2.7 +
     2.8 +        n = len(ais)
     2.9          
    2.10          if runtime.id == 1:
    2.11 -            r1 = triple_generator._mul(1, 2, ais, cs)
    2.12 +            r1 = triple_generator._mul(1, 2, n, ais, cs)
    2.13              def check1(partialShares):
    2.14                  for partialShare in partialShares:
    2.15                      zi = triple_generator.paillier.decrypt(partialShare.enc_shares[0])
    2.16 @@ -479,7 +481,7 @@
    2.17              r1.addCallback(check1)
    2.18              return r1
    2.19          else:
    2.20 -            r1 = triple_generator._mul(1, 2)
    2.21 +            r1 = triple_generator._mul(1, 2, n)
    2.22              def check(partialShares):
    2.23                  deferreds = []
    2.24                  for partialShare in partialShares:
    2.25 @@ -515,8 +517,10 @@
    2.26          cs = []
    2.27          for ai in ais:
    2.28              cs.append(triple_generator.paillier.encrypt(b2.value, 2))
    2.29 +
    2.30 +        n = len(ais)
    2.31          
    2.32 -        r1 = triple_generator._mul(2, 2, ais, cs)
    2.33 +        r1 = triple_generator._mul(2, 2, n, ais, cs)
    2.34          def check(partialShareContents):
    2.35              for partialShareContent in partialShareContents:
    2.36                  if runtime.id == 2: