changeset 1527:7dd55d319d2b

BeDOZa: Added size constraints to _mul.
author Janus Dam Nielsen <janus.nielsen@alexandra.dk>
date Thu, 29 Jul 2010 15:56:58 +0200
parents a87fd09f8c38
children e7b2fe7eb753
files viff/bedoza/bedoza_triple.py viff/test/test_bedoza_triple.py
diffstat 2 files changed, 73 insertions(+), 28 deletions(-) [+]
line wrap: on
line diff
--- a/viff/bedoza/bedoza_triple.py	Thu Jul 29 11:31:24 2010 +0200
+++ b/viff/bedoza/bedoza_triple.py	Thu Jul 29 15:56:58 2010 +0200
@@ -158,7 +158,7 @@
         #     receive c from player i and set 
         #         m^i=Decrypt(c)
     
-    def _mul(self, inx, jnx, ais=None, cjs=None):
+    def _mul(self, inx, jnx, n, ais=None, cjs=None):
         """Multiply each of the field elements in *ais* with the
         corresponding encrypted elements in *cjs*.
         
@@ -167,6 +167,21 @@
         CKIND = 1
         DiKIND = 2
         DjKIND = 3
+
+        """The transmission_restraint_constant is the number of
+        encrypted shares we can safely transmit in one call to
+        sendData. The sendData method can only transmit up to
+        65536 bytes.
+        The constant has been imperically determined by running
+        TripleGenerator.generate_triples.
+        TODO: How can we allow a user of the runtime to adjust this
+        constraint at a higher level of abstraction?
+        """
+        transmission_restraint_constant = 425
+
+        number_of_packets = n / transmission_restraint_constant
+        if n % transmission_restraint_constant != 0:
+            number_of_packets += 1
         
         self.runtime.increment_pc()
 
@@ -176,61 +191,86 @@
         zis = []
         if self.runtime.id == inx:
             Nj_square = self.paillier.get_modulus_square(jnx)
-            cs = []
-            dis = []
-            for ai, cj in zip(ais, cjs):
+            all_cs = []
+            all_dis = []
+            for iny, (ai, cj) in enumerate(zip(ais, cjs)):
+                if iny % transmission_restraint_constant == 0:
+                    cs = []
+                    all_cs.append(cs)
+                    dis = []
+                    all_dis.append(dis)
                 u = rand.randint(0, self.u_bound)
                 Ej_u = self.paillier.encrypt(u, jnx)
                 cs.append( (pow(cj, ai.value, Nj_square) * Ej_u) % Nj_square )
                 zi = self.Zp(-u)
                 zis.append(zi)
                 dis.append(self.paillier.encrypt(zi.value, inx))
-            self.runtime.protocols[jnx].sendData(pc, CKIND, str(cs))
+                
+            for cs in all_cs:
+                self.runtime.protocols[jnx].sendData(pc, CKIND, str(cs))
 
-            for player_id in self.runtime.players:
-                self.runtime.protocols[player_id].sendData(pc, DiKIND, str(dis))
+            for dis in all_dis:
+                for player_id in self.runtime.players:
+                    self.runtime.protocols[player_id].sendData(pc, DiKIND, str(dis))
 
         if self.runtime.id == jnx:
-            cs = Deferred()
-            self.runtime._expect_data(inx, CKIND, cs)
-            def decrypt(cs, pc, zis):
+            all_cs = []
+            for _ in xrange(number_of_packets):
+                cs = Deferred()
+                self.runtime._expect_data(inx, CKIND, cs)
+                all_cs.append(cs)
+                
+            def decrypt(all_cs, pc, zis):
                 zjs = []
-                djs = []
-                for c in eval(cs):
+                cs = reduce(lambda x, y: x + eval(y), all_cs, [])
+                all_djs = []
+                for iny, c in enumerate(cs):
+                    if iny % transmission_restraint_constant == 0:
+                        djs = []
+                        all_djs.append(djs)
                     t = self.paillier.decrypt(c)
                     zj = self.Zp(t)
                     zjs.append(zj)
                     djs.append(self.paillier.encrypt(zj.value, jnx))
-                for player_id in self.runtime.players:
-                    self.runtime.protocols[player_id].sendData(pc, DjKIND, str(djs))
+                for djs in all_djs:
+                    for player_id in self.runtime.players:
+                        self.runtime.protocols[player_id].sendData(pc, DjKIND, str(djs))
                 if not zis == []:
                     return [x + y for x, y in zip(zis, zjs)]
                 else:
                     return zjs 
-            cs.addCallback(decrypt, pc, zis)
-            deferreds.append(cs)
+            all_cs_d = gatherResults(all_cs)
+            all_cs_d.addCallback(decrypt, pc, zis)
+            deferreds.append(all_cs_d)
         else:
             zis_deferred = Deferred()
             zis_deferred.callback(zis)
             deferreds.append(zis_deferred)
 
-        dis = Deferred()
-        self.runtime._expect_data(inx, DiKIND, dis)
-        djs = Deferred()
-        self.runtime._expect_data(jnx, DjKIND, djs)
+        all_dis = []
+        for _ in xrange(number_of_packets):
+            dis = Deferred()
+            self.runtime._expect_data(inx, DiKIND, dis)
+            all_dis.append(dis)
+        all_djs = []
+        for _ in xrange(number_of_packets):
+            djs = Deferred()
+            self.runtime._expect_data(jnx, DjKIND, djs)
+            all_djs.append(djs)
 
-        deferreds.append(dis)
-        deferreds.append(djs)
+        deferreds.append(gatherResults(all_dis))
+        deferreds.append(gatherResults(all_djs))
         r = gatherResults(deferreds)
         def wrap((values, dis, djs), inx, jnx):
-            dis = eval(dis)
-            djs = eval(djs)
+            dis = reduce(lambda x, y: x + eval(y), dis, [])
+            djs = reduce(lambda x, y: x + eval(y), djs, [])
             n_square_i = self.paillier.get_modulus_square(inx)
             n_square_j = self.paillier.get_modulus_square(jnx)
             N_squared_list = [self.paillier.get_modulus_square(player_id) for player_id in self.runtime.players]
             ps = []
+            
             for v, di, dj in itertools.izip_longest(values, dis, djs, fillvalue=self.Zp(0)):
-                value = v 
+                value = v
                 enc_shares = len(self.runtime.players) * [1]
                 enc_shares[inx - 1] = (enc_shares[inx - 1] * di) % n_square_i
                 enc_shares[jnx - 1] = (enc_shares[jnx - 1] * dj) % n_square_j
@@ -261,6 +301,7 @@
                 for jnx in xrange(0, len(self.runtime.players)):
                     deferreds.append(self._mul(inx + 1,
                                                jnx + 1,
+                                               len(a_values),
                                                a_values,
                                                b_enc_shares[jnx]))
                         
--- a/viff/test/test_bedoza_triple.py	Thu Jul 29 11:31:24 2010 +0200
+++ b/viff/test/test_bedoza_triple.py	Thu Jul 29 15:56:58 2010 +0200
@@ -466,9 +466,11 @@
         cs = []
         for ai in ais:
             cs.append(triple_generator.paillier.encrypt(b2.value, 2))      
+
+        n = len(ais)
         
         if runtime.id == 1:
-            r1 = triple_generator._mul(1, 2, ais, cs)
+            r1 = triple_generator._mul(1, 2, n, ais, cs)
             def check1(partialShares):
                 for partialShare in partialShares:
                     zi = triple_generator.paillier.decrypt(partialShare.enc_shares[0])
@@ -479,7 +481,7 @@
             r1.addCallback(check1)
             return r1
         else:
-            r1 = triple_generator._mul(1, 2)
+            r1 = triple_generator._mul(1, 2, n)
             def check(partialShares):
                 deferreds = []
                 for partialShare in partialShares:
@@ -515,8 +517,10 @@
         cs = []
         for ai in ais:
             cs.append(triple_generator.paillier.encrypt(b2.value, 2))
+
+        n = len(ais)
         
-        r1 = triple_generator._mul(2, 2, ais, cs)
+        r1 = triple_generator._mul(2, 2, n, ais, cs)
         def check(partialShareContents):
             for partialShareContent in partialShareContents:
                 if runtime.id == 2: