changeset 1505:60cbf0d030c7

BeDOZa: Changed mul and fullmul to batch style processing.
author Janus Dam Nielsen <janus.nielsen@alexandra.dk>
date Mon, 19 Jul 2010 15:38:13 +0200
parents 129b326c3ff1
children 5d340ba92fff
files viff/bedoza_triple.py viff/test/test_bedoza_triple.py
diffstat 2 files changed, 179 insertions(+), 105 deletions(-) [+]
line wrap: on
line diff
--- a/viff/bedoza_triple.py	Mon Jul 19 10:46:09 2010 +0200
+++ b/viff/bedoza_triple.py	Mon Jul 19 15:38:13 2010 +0200
@@ -19,6 +19,8 @@
     TODO: Explain more.
 """
 
+import itertools
+
 from twisted.internet.defer import Deferred, gatherResults, succeed
 
 from viff.runtime import Runtime, Share, ShareList, gather_shares
@@ -293,7 +295,7 @@
         #     receive c from player i and set 
         #         m^i=Decrypt(c)
     
-    def _mul(self, inx, jnx, ai=None, cj=None):
+    def _mul(self, inx, jnx, ais=None, cjs=None):
         CKIND = 1
         DiKIND = 2
         DjKIND = 3
@@ -301,74 +303,111 @@
         self.runtime.increment_pc()
 
         pc = tuple(self.runtime.program_counter)
-            
+
         deferreds = []
+        zis = []
         if self.runtime.id == inx:
-            u = rand.randint(0, self.u_bound)
-            Ej_u = self.paillier.encrypt(u, jnx)
             Nj_square = self.paillier.get_modulus_square(jnx)
-            c = (pow(cj, ai.value, Nj_square) * Ej_u) % Nj_square
-            self.runtime.protocols[jnx].sendData(pc, CKIND, str(c))
-            zi = self.Zp(-u)
-            di = self.paillier.encrypt(zi.value, inx)
+            cs = []
+            dis = []
+            for ai, cj in zip(ais, cjs):
+                u = rand.randint(0, self.u_bound)
+                Ej_u = self.paillier.encrypt(u, jnx)
+                cs.append( (pow(cj, ai.value, Nj_square) * Ej_u) % Nj_square )
+                zi = self.Zp(-u)
+                zis.append(zi)
+                dis.append(self.paillier.encrypt(zi.value, inx))
+            self.runtime.protocols[jnx].sendData(pc, CKIND, str(cs))
+
             for player_id in self.runtime.players:
-                self.runtime.protocols[player_id].sendData(pc, DiKIND, str(di))
-            zi_deferred = Deferred()
-            zi_deferred.callback(zi)
-            deferreds.append(zi_deferred)
+                self.runtime.protocols[player_id].sendData(pc, DiKIND, str(dis))
 
         if self.runtime.id == jnx:
-            c = Deferred()
-            self.runtime._expect_data(inx, CKIND, c)
-            def decrypt(c, pc):
-                t = self.paillier.decrypt(long(c))
-                zj = self.Zp(t)
-                dj = self.paillier.encrypt(zj.value, jnx)
+            cs = Deferred()
+            self.runtime._expect_data(inx, CKIND, cs)
+            def decrypt(cs, pc, zis):
+                zjs = []
+                djs = []
+                for c in eval(cs):
+                    t = self.paillier.decrypt(c)
+                    zj = self.Zp(t)
+                    zjs.append(zj)
+                    djs.append(self.paillier.encrypt(zj.value, jnx))
                 for player_id in self.runtime.players:
-                    self.runtime.protocols[player_id].sendData(pc, DjKIND, str(dj))
-                return zj 
-            c.addCallback(decrypt, pc)
-            deferreds.append(c)
+                    self.runtime.protocols[player_id].sendData(pc, DjKIND, str(djs))
+                if not zis == []:
+                    return [x + y for x, y in zip(zis, zjs)]
+                else:
+                    return zjs 
+            cs.addCallback(decrypt, pc, zis)
+            deferreds.append(cs)
+        else:
+            zis_deferred = Deferred()
+            zis_deferred.callback(zis)
+            deferreds.append(zis_deferred)
 
-        di = Deferred()
-        self.runtime._expect_data(inx, DiKIND, di)
-        dj = Deferred()
-        self.runtime._expect_data(jnx, DjKIND, dj)
+        dis = Deferred()
+        self.runtime._expect_data(inx, DiKIND, dis)
+        djs = Deferred()
+        self.runtime._expect_data(jnx, DjKIND, djs)
 
-        deferreds.append(di)
-        deferreds.append(dj)
+        deferreds.append(dis)
+        deferreds.append(djs)
         r = gatherResults(deferreds)
-        def wrap(ls, inx, jnx):
-            value = reduce(lambda x, y: x + y, [self.Zp(0)] + ls[0:-2])
+        def wrap((values, dis, djs), inx, jnx):
+            dis = eval(dis)
+            djs = eval(djs)
             n_square_i = self.paillier.get_modulus_square(inx)
             n_square_j = self.paillier.get_modulus_square(jnx)
-            enc_shares = len(self.runtime.players) * [1]
-            enc_shares[inx - 1] = (enc_shares[inx - 1] * long(ls[-2])) % n_square_i
-            enc_shares[jnx - 1] = (enc_shares[jnx - 1] * long(ls[-1])) % n_square_j
             N_squared_list = [self.paillier.get_modulus_square(player_id) for player_id in self.runtime.players]
-            return PartialShareContents(value, enc_shares, N_squared_list)
+            ps = []
+            for v, di, dj in itertools.izip_longest(values, dis, djs, fillvalue=self.Zp(0)):
+                value = v 
+                enc_shares = len(self.runtime.players) * [1]
+                enc_shares[inx - 1] = (enc_shares[inx - 1] * di) % n_square_i
+                enc_shares[jnx - 1] = (enc_shares[jnx - 1] * dj) % n_square_j
+                ps.append(PartialShareContents(value, enc_shares, N_squared_list))
+            return ps
         r.addCallback(wrap, inx, jnx)
         return r
 
     def _full_mul(self, a, b):
         self.runtime.increment_pc()
         
-        def do_full_mul((contents_a, contents_b)):
+        def do_full_mul(shares):
+            """Share content belonging to ai, bi are at:
+            shares[i], shares[len(shares) + i].
+            """
             deferreds = []
+            len_shares = len(shares)
+            a_values = [s.value for s in shares[0:len_shares/2]]
+            b_enc_shares = []
+            for inx in self.runtime.players:              
+                b_enc_shares.append([s.enc_shares[inx - 1] for s in shares[len_shares/2:]])
             for inx in xrange(0, len(self.runtime.players)):
                 for jnx in xrange(0, len(self.runtime.players)):
-                    deferreds.append(self._mul(inx + 1, jnx + 1, contents_a.value, contents_b.enc_shares[jnx]))
-            def compute_share(partialShares):
-                partialShareContents = reduce(lambda x, y: x + y, partialShares)
-                pid = self.runtime.id
-                share = partialShareContents.enc_shares[pid - 1]
-                share = self.paillier.decrypt(share)
-                share = self.Zp(share)
-                return PartialShare(self.runtime, partialShareContents.value, partialShareContents.enc_shares)
+                    deferreds.append(self._mul(inx + 1,
+                                               jnx + 1,
+                                               a_values,
+                                               b_enc_shares[jnx]))
+                        
+            def compute_shares(partialShareContents, len_shares):
+                num_players = len(self.runtime.players)
+                pcs = len(partialShareContents[0]) * [None]
+                for ps in partialShareContents:
+                    for inx in xrange(0, len(ps)):
+                        if pcs[inx] == None:
+                            pcs[inx] = ps[inx]
+                        else:
+                            pcs[inx] += ps[inx]
+                partialShares = [PartialShare(self.runtime,
+                                              p.value,
+                                              p.enc_shares) for p in pcs]
+                return partialShares
             d = gatherResults(deferreds)
-            d.addCallback(compute_share)
+            d.addCallback(compute_shares, len_shares)
             return d
-        s = gatherResults([a, b])
+        s = gatherResults(a + b)
         self.runtime.schedule_callback(s, do_full_mul)
         return s
 
--- a/viff/test/test_bedoza_triple.py	Mon Jul 19 10:46:09 2010 +0200
+++ b/viff/test/test_bedoza_triple.py	Mon Jul 19 15:38:13 2010 +0200
@@ -332,39 +332,44 @@
 
         Zp = GF(p)
 
-        a1 = Zp(6)
+        ais = [Zp(6), Zp(6), Zp(6), Zp(6)]
         b2 = Zp(7)
-        c2 = triple_generator.paillier.encrypt(b2.value, 2)
+        cs = []
+        for ai in ais:
+            cs.append(triple_generator.paillier.encrypt(b2.value, 2))      
         
         if runtime.id == 1:
-            r1 = triple_generator._mul(1, 2, a1, c2)
-            def check1(partialShare):
-                zi = triple_generator.paillier.decrypt(partialShare.enc_shares[0])
-                self.assertEquals(partialShare.value.value, zi)
-                pc = tuple(runtime.program_counter)
-                runtime.protocols[2].sendData(pc, TEXT, str(zi))
+            r1 = triple_generator._mul(1, 2, ais, cs)
+            def check1(partialShares):
+                for partialShare in partialShares:
+                    zi = triple_generator.paillier.decrypt(partialShare.enc_shares[0])
+                    self.assertEquals(partialShare.value.value, zi)
+                    pc = tuple(runtime.program_counter)
+                    runtime.protocols[2].sendData(pc, TEXT, str(zi))
                 return True
             r1.addCallback(check1)
             return r1
         else:
             r1 = triple_generator._mul(1, 2)
-            def check(partialShare):
-                if runtime.id == 2:
-                    zj = triple_generator.paillier.decrypt(partialShare.enc_shares[1])
-                    self.assertEquals(partialShare.value.value, zj)
-                    def check_additivity(zi, zj):
-                        self.assertEquals((Zp(long(zi)) + zj).value, 8)
-                        return None
-                    d = Deferred()
-                    d.addCallback(check_additivity, partialShare.value)
-                    runtime._expect_data(1, TEXT, d)
-                    return d
-                else:
-                    self.assertEquals(partialShare.value, 0)
-                    self.assertNotEquals(partialShare.enc_shares[0], 0)
-                    self.assertNotEquals(partialShare.enc_shares[1], 0)
-                    self.assertEquals(partialShare.enc_shares[2], 1)
-                return True
+            def check(partialShares):
+                deferreds = []
+                for partialShare in partialShares:
+                    if runtime.id == 2:
+                        zj = triple_generator.paillier.decrypt(partialShare.enc_shares[1])
+                        self.assertEquals(partialShare.value.value, zj)
+                        def check_additivity(zi, zj):
+                            self.assertEquals((Zp(long(zi)) + zj).value, 8)
+                            return None
+                        d = Deferred()
+                        d.addCallback(check_additivity, partialShare.value)
+                        runtime._expect_data(1, TEXT, d)
+                        deferreds.append(d)
+                    else:
+                        self.assertEquals(partialShare.value, 0)
+                        self.assertNotEquals(partialShare.enc_shares[0], 0)
+                        self.assertNotEquals(partialShare.enc_shares[1], 0)
+                        self.assertEquals(partialShare.enc_shares[2], 1)
+                return gatherResults(deferreds)
             r1.addCallback(check)
             return r1
 
@@ -376,16 +381,19 @@
 
         Zp = GF(p)
 
-        a1 = Zp(6)
+        ais = [Zp(6), Zp(6), Zp(6), Zp(6)]
         b2 = Zp(7)
-        c2 = triple_generator.paillier.encrypt(b2.value, 2)
+        cs = []
+        for ai in ais:
+            cs.append(triple_generator.paillier.encrypt(b2.value, 2))
         
-        r1 = triple_generator._mul(2, 2, a1, c2)
+        r1 = triple_generator._mul(2, 2, ais, cs)
         def check(partialShareContents):
-            if runtime.id == 2:
-                zi_enc = Zp(triple_generator.paillier.decrypt(partialShareContents.enc_shares[1]))
-                self.assertEquals(zi_enc, partialShareContents.value)
-                self.assertEquals(partialShareContents.value, 8)
+            for partialShareContent in partialShareContents:
+                if runtime.id == 2:
+                    zi_enc = Zp(triple_generator.paillier.decrypt(partialShareContent.enc_shares[1]))
+                    self.assertEquals(zi_enc, partialShareContent.value)
+                    self.assertEquals(partialShareContent.value, 8)
             return True
             
         r1.addCallback(check)
@@ -406,20 +414,36 @@
 
         paillier = triple_generator.paillier
         
-        share_a = partial_share(random, runtime, GF(p), 6, paillier=paillier)
-        share_b = partial_share(random, runtime, GF(p), 7, paillier=paillier)
+        share_as = []
+        share_bs = []      
+        share_as.append(partial_share(random, runtime, GF(p), 6, paillier=paillier))
+        share_bs.append(partial_share(random, runtime, GF(p), 7, paillier=paillier))
+        share_as.append(partial_share(random, runtime, GF(p), 5, paillier=paillier))
+        share_bs.append(partial_share(random, runtime, GF(p), 4, paillier=paillier))
+        share_as.append(partial_share(random, runtime, GF(p), 2, paillier=paillier))
+        share_bs.append(partial_share(random, runtime, GF(p), 3, paillier=paillier))
 
-        share_z = triple_generator._full_mul(share_a, share_b)
-        def check(share):
+
+        share_zs = triple_generator._full_mul(share_as, share_bs)
+        def check(shares):
             def test_sum(ls):
-                vals = ls[0]
-                self.assertEquals(8, Zp(sum(vals)))
-            value = _convolute(runtime, share.value.value)
-            runtime.schedule_callback(gatherResults([value]), test_sum)
-            return True
+                self.assertEquals(8, Zp(sum(ls[0])))
+                self.assertEquals(3, Zp(sum(ls[1])))
+                self.assertEquals(6, Zp(sum(ls[2])))
+            values = []
+            for share in shares:
+                value = _convolute(runtime, share.value.value)
+                values.append(value)
+            d = gatherResults(values)
+            runtime.schedule_callback(d, test_sum)
+            return d
             
-        share_z.addCallback(check)
-        return share_z
+        def indirection(shares):
+            d = gatherResults(shares)
+            d.addCallback(check)
+            return d
+        share_zs.addCallback(indirection)
+        return share_zs
 
     @protocol
     def test_fullmul_encrypted_values_are_the_same_as_the_share(self, runtime):
@@ -431,27 +455,38 @@
         triple_generator = TripleGenerator(runtime, p, random)
 
         paillier = triple_generator.paillier
-        
-        share_a = partial_share(random, runtime, GF(p), 6, paillier=paillier)
-        share_b = partial_share(random, runtime, GF(p), 7, paillier=paillier)
 
-        share_z = triple_generator._full_mul(share_a, share_b)
-        def check(share):
-            def test_enc(enc_shares, value):
-                all_the_same, zi_enc = reduce(lambda x, y: (x[0] and x[1] == y, y), enc_shares, (True, enc_shares[0]))
-                zi_enc = triple_generator.paillier.decrypt(zi_enc)
-                self.assertEquals(value, Zp(zi_enc))
-                return True
+        share_as = []
+        share_bs = []      
+        share_as.append(partial_share(random, runtime, GF(p), 6, paillier=paillier))
+        share_bs.append(partial_share(random, runtime, GF(p), 7, paillier=paillier))
+        share_as.append(partial_share(random, runtime, GF(p), 5, paillier=paillier))
+        share_bs.append(partial_share(random, runtime, GF(p), 4, paillier=paillier))
+        share_as.append(partial_share(random, runtime, GF(p), 2, paillier=paillier))
+        share_bs.append(partial_share(random, runtime, GF(p), 3, paillier=paillier))
+
+        share_zs = triple_generator._full_mul(share_as, share_bs)
+        def check(shares):
             all_enc_shares = []
-            for inx, enc_share in enumerate(share.enc_shares):
-                d = _convolute(runtime, enc_share)
-                if runtime.id == inx + 1:
-                    d.addCallback(test_enc, share.value)
+            for share in shares:
+                def test_enc(enc_shares, value):
+                    all_the_same, zi_enc = reduce(lambda x, y: (x[0] and x[1] == y, y), enc_shares, (True, enc_shares[0]))
+                    zi_enc = triple_generator.paillier.decrypt(zi_enc)
+                    self.assertEquals(value, Zp(zi_enc))
+                    return True
+                for inx, enc_share in enumerate(share.enc_shares):
+                    d = _convolute(runtime, enc_share)
+                    if runtime.id == inx + 1:
+                        d.addCallback(test_enc, share.value)
                 all_enc_shares.append(d)
             return gatherResults(all_enc_shares)
-            
-        share_z.addCallback(check)
-        return share_z
+        
+        def indirection(shares):
+            d = gatherResults(shares)
+            d.addCallback(check)
+            return d
+        share_zs.addCallback(indirection)
+        return share_zs
 
 
 missing_package = None