viff

changeset 1505:60cbf0d030c7

BeDOZa: Changed mul and fullmul to batch style processing.
author Janus Dam Nielsen <janus.nielsen@alexandra.dk>
date Mon, 19 Jul 2010 15:38:13 +0200
parents 129b326c3ff1
children 5d340ba92fff
files viff/bedoza_triple.py viff/test/test_bedoza_triple.py
diffstat 2 files changed, 179 insertions(+), 105 deletions(-) [+]
line diff
     1.1 --- a/viff/bedoza_triple.py	Mon Jul 19 10:46:09 2010 +0200
     1.2 +++ b/viff/bedoza_triple.py	Mon Jul 19 15:38:13 2010 +0200
     1.3 @@ -19,6 +19,8 @@
     1.4      TODO: Explain more.
     1.5  """
     1.6  
     1.7 +import itertools
     1.8 +
     1.9  from twisted.internet.defer import Deferred, gatherResults, succeed
    1.10  
    1.11  from viff.runtime import Runtime, Share, ShareList, gather_shares
    1.12 @@ -293,7 +295,7 @@
    1.13          #     receive c from player i and set 
    1.14          #         m^i=Decrypt(c)
    1.15      
    1.16 -    def _mul(self, inx, jnx, ai=None, cj=None):
    1.17 +    def _mul(self, inx, jnx, ais=None, cjs=None):
    1.18          CKIND = 1
    1.19          DiKIND = 2
    1.20          DjKIND = 3
    1.21 @@ -301,74 +303,111 @@
    1.22          self.runtime.increment_pc()
    1.23  
    1.24          pc = tuple(self.runtime.program_counter)
    1.25 -            
    1.26 +
    1.27          deferreds = []
    1.28 +        zis = []
    1.29          if self.runtime.id == inx:
    1.30 -            u = rand.randint(0, self.u_bound)
    1.31 -            Ej_u = self.paillier.encrypt(u, jnx)
    1.32              Nj_square = self.paillier.get_modulus_square(jnx)
    1.33 -            c = (pow(cj, ai.value, Nj_square) * Ej_u) % Nj_square
    1.34 -            self.runtime.protocols[jnx].sendData(pc, CKIND, str(c))
    1.35 -            zi = self.Zp(-u)
    1.36 -            di = self.paillier.encrypt(zi.value, inx)
    1.37 +            cs = []
    1.38 +            dis = []
    1.39 +            for ai, cj in zip(ais, cjs):
    1.40 +                u = rand.randint(0, self.u_bound)
    1.41 +                Ej_u = self.paillier.encrypt(u, jnx)
    1.42 +                cs.append( (pow(cj, ai.value, Nj_square) * Ej_u) % Nj_square )
    1.43 +                zi = self.Zp(-u)
    1.44 +                zis.append(zi)
    1.45 +                dis.append(self.paillier.encrypt(zi.value, inx))
    1.46 +            self.runtime.protocols[jnx].sendData(pc, CKIND, str(cs))
    1.47 +
    1.48              for player_id in self.runtime.players:
    1.49 -                self.runtime.protocols[player_id].sendData(pc, DiKIND, str(di))
    1.50 -            zi_deferred = Deferred()
    1.51 -            zi_deferred.callback(zi)
    1.52 -            deferreds.append(zi_deferred)
    1.53 +                self.runtime.protocols[player_id].sendData(pc, DiKIND, str(dis))
    1.54  
    1.55          if self.runtime.id == jnx:
    1.56 -            c = Deferred()
    1.57 -            self.runtime._expect_data(inx, CKIND, c)
    1.58 -            def decrypt(c, pc):
    1.59 -                t = self.paillier.decrypt(long(c))
    1.60 -                zj = self.Zp(t)
    1.61 -                dj = self.paillier.encrypt(zj.value, jnx)
    1.62 +            cs = Deferred()
    1.63 +            self.runtime._expect_data(inx, CKIND, cs)
    1.64 +            def decrypt(cs, pc, zis):
    1.65 +                zjs = []
    1.66 +                djs = []
    1.67 +                for c in eval(cs):
    1.68 +                    t = self.paillier.decrypt(c)
    1.69 +                    zj = self.Zp(t)
    1.70 +                    zjs.append(zj)
    1.71 +                    djs.append(self.paillier.encrypt(zj.value, jnx))
    1.72                  for player_id in self.runtime.players:
    1.73 -                    self.runtime.protocols[player_id].sendData(pc, DjKIND, str(dj))
    1.74 -                return zj 
    1.75 -            c.addCallback(decrypt, pc)
    1.76 -            deferreds.append(c)
    1.77 +                    self.runtime.protocols[player_id].sendData(pc, DjKIND, str(djs))
    1.78 +                if not zis == []:
    1.79 +                    return [x + y for x, y in zip(zis, zjs)]
    1.80 +                else:
    1.81 +                    return zjs 
    1.82 +            cs.addCallback(decrypt, pc, zis)
    1.83 +            deferreds.append(cs)
    1.84 +        else:
    1.85 +            zis_deferred = Deferred()
    1.86 +            zis_deferred.callback(zis)
    1.87 +            deferreds.append(zis_deferred)
    1.88  
    1.89 -        di = Deferred()
    1.90 -        self.runtime._expect_data(inx, DiKIND, di)
    1.91 -        dj = Deferred()
    1.92 -        self.runtime._expect_data(jnx, DjKIND, dj)
    1.93 +        dis = Deferred()
    1.94 +        self.runtime._expect_data(inx, DiKIND, dis)
    1.95 +        djs = Deferred()
    1.96 +        self.runtime._expect_data(jnx, DjKIND, djs)
    1.97  
    1.98 -        deferreds.append(di)
    1.99 -        deferreds.append(dj)
   1.100 +        deferreds.append(dis)
   1.101 +        deferreds.append(djs)
   1.102          r = gatherResults(deferreds)
   1.103 -        def wrap(ls, inx, jnx):
   1.104 -            value = reduce(lambda x, y: x + y, [self.Zp(0)] + ls[0:-2])
   1.105 +        def wrap((values, dis, djs), inx, jnx):
   1.106 +            dis = eval(dis)
   1.107 +            djs = eval(djs)
   1.108              n_square_i = self.paillier.get_modulus_square(inx)
   1.109              n_square_j = self.paillier.get_modulus_square(jnx)
   1.110 -            enc_shares = len(self.runtime.players) * [1]
   1.111 -            enc_shares[inx - 1] = (enc_shares[inx - 1] * long(ls[-2])) % n_square_i
   1.112 -            enc_shares[jnx - 1] = (enc_shares[jnx - 1] * long(ls[-1])) % n_square_j
   1.113              N_squared_list = [self.paillier.get_modulus_square(player_id) for player_id in self.runtime.players]
   1.114 -            return PartialShareContents(value, enc_shares, N_squared_list)
   1.115 +            ps = []
   1.116 +            for v, di, dj in itertools.izip_longest(values, dis, djs, fillvalue=self.Zp(0)):
   1.117 +                value = v 
   1.118 +                enc_shares = len(self.runtime.players) * [1]
   1.119 +                enc_shares[inx - 1] = (enc_shares[inx - 1] * di) % n_square_i
   1.120 +                enc_shares[jnx - 1] = (enc_shares[jnx - 1] * dj) % n_square_j
   1.121 +                ps.append(PartialShareContents(value, enc_shares, N_squared_list))
   1.122 +            return ps
   1.123          r.addCallback(wrap, inx, jnx)
   1.124          return r
   1.125  
   1.126      def _full_mul(self, a, b):
   1.127          self.runtime.increment_pc()
   1.128          
   1.129 -        def do_full_mul((contents_a, contents_b)):
   1.130 +        def do_full_mul(shares):
   1.131 +            """Share content belonging to ai, bi are at:
   1.132 +            shares[i], shares[len(shares) + i].
   1.133 +            """
   1.134              deferreds = []
   1.135 +            len_shares = len(shares)
   1.136 +            a_values = [s.value for s in shares[0:len_shares/2]]
   1.137 +            b_enc_shares = []
   1.138 +            for inx in self.runtime.players:              
   1.139 +                b_enc_shares.append([s.enc_shares[inx - 1] for s in shares[len_shares/2:]])
   1.140              for inx in xrange(0, len(self.runtime.players)):
   1.141                  for jnx in xrange(0, len(self.runtime.players)):
   1.142 -                    deferreds.append(self._mul(inx + 1, jnx + 1, contents_a.value, contents_b.enc_shares[jnx]))
   1.143 -            def compute_share(partialShares):
   1.144 -                partialShareContents = reduce(lambda x, y: x + y, partialShares)
   1.145 -                pid = self.runtime.id
   1.146 -                share = partialShareContents.enc_shares[pid - 1]
   1.147 -                share = self.paillier.decrypt(share)
   1.148 -                share = self.Zp(share)
   1.149 -                return PartialShare(self.runtime, partialShareContents.value, partialShareContents.enc_shares)
   1.150 +                    deferreds.append(self._mul(inx + 1,
   1.151 +                                               jnx + 1,
   1.152 +                                               a_values,
   1.153 +                                               b_enc_shares[jnx]))
   1.154 +                        
   1.155 +            def compute_shares(partialShareContents, len_shares):
   1.156 +                num_players = len(self.runtime.players)
   1.157 +                pcs = len(partialShareContents[0]) * [None]
   1.158 +                for ps in partialShareContents:
   1.159 +                    for inx in xrange(0, len(ps)):
   1.160 +                        if pcs[inx] == None:
   1.161 +                            pcs[inx] = ps[inx]
   1.162 +                        else:
   1.163 +                            pcs[inx] += ps[inx]
   1.164 +                partialShares = [PartialShare(self.runtime,
   1.165 +                                              p.value,
   1.166 +                                              p.enc_shares) for p in pcs]
   1.167 +                return partialShares
   1.168              d = gatherResults(deferreds)
   1.169 -            d.addCallback(compute_share)
   1.170 +            d.addCallback(compute_shares, len_shares)
   1.171              return d
   1.172 -        s = gatherResults([a, b])
   1.173 +        s = gatherResults(a + b)
   1.174          self.runtime.schedule_callback(s, do_full_mul)
   1.175          return s
   1.176  
     2.1 --- a/viff/test/test_bedoza_triple.py	Mon Jul 19 10:46:09 2010 +0200
     2.2 +++ b/viff/test/test_bedoza_triple.py	Mon Jul 19 15:38:13 2010 +0200
     2.3 @@ -332,39 +332,44 @@
     2.4  
     2.5          Zp = GF(p)
     2.6  
     2.7 -        a1 = Zp(6)
     2.8 +        ais = [Zp(6), Zp(6), Zp(6), Zp(6)]
     2.9          b2 = Zp(7)
    2.10 -        c2 = triple_generator.paillier.encrypt(b2.value, 2)
    2.11 +        cs = []
    2.12 +        for ai in ais:
    2.13 +            cs.append(triple_generator.paillier.encrypt(b2.value, 2))      
    2.14          
    2.15          if runtime.id == 1:
    2.16 -            r1 = triple_generator._mul(1, 2, a1, c2)
    2.17 -            def check1(partialShare):
    2.18 -                zi = triple_generator.paillier.decrypt(partialShare.enc_shares[0])
    2.19 -                self.assertEquals(partialShare.value.value, zi)
    2.20 -                pc = tuple(runtime.program_counter)
    2.21 -                runtime.protocols[2].sendData(pc, TEXT, str(zi))
    2.22 +            r1 = triple_generator._mul(1, 2, ais, cs)
    2.23 +            def check1(partialShares):
    2.24 +                for partialShare in partialShares:
    2.25 +                    zi = triple_generator.paillier.decrypt(partialShare.enc_shares[0])
    2.26 +                    self.assertEquals(partialShare.value.value, zi)
    2.27 +                    pc = tuple(runtime.program_counter)
    2.28 +                    runtime.protocols[2].sendData(pc, TEXT, str(zi))
    2.29                  return True
    2.30              r1.addCallback(check1)
    2.31              return r1
    2.32          else:
    2.33              r1 = triple_generator._mul(1, 2)
    2.34 -            def check(partialShare):
    2.35 -                if runtime.id == 2:
    2.36 -                    zj = triple_generator.paillier.decrypt(partialShare.enc_shares[1])
    2.37 -                    self.assertEquals(partialShare.value.value, zj)
    2.38 -                    def check_additivity(zi, zj):
    2.39 -                        self.assertEquals((Zp(long(zi)) + zj).value, 8)
    2.40 -                        return None
    2.41 -                    d = Deferred()
    2.42 -                    d.addCallback(check_additivity, partialShare.value)
    2.43 -                    runtime._expect_data(1, TEXT, d)
    2.44 -                    return d
    2.45 -                else:
    2.46 -                    self.assertEquals(partialShare.value, 0)
    2.47 -                    self.assertNotEquals(partialShare.enc_shares[0], 0)
    2.48 -                    self.assertNotEquals(partialShare.enc_shares[1], 0)
    2.49 -                    self.assertEquals(partialShare.enc_shares[2], 1)
    2.50 -                return True
    2.51 +            def check(partialShares):
    2.52 +                deferreds = []
    2.53 +                for partialShare in partialShares:
    2.54 +                    if runtime.id == 2:
    2.55 +                        zj = triple_generator.paillier.decrypt(partialShare.enc_shares[1])
    2.56 +                        self.assertEquals(partialShare.value.value, zj)
    2.57 +                        def check_additivity(zi, zj):
    2.58 +                            self.assertEquals((Zp(long(zi)) + zj).value, 8)
    2.59 +                            return None
    2.60 +                        d = Deferred()
    2.61 +                        d.addCallback(check_additivity, partialShare.value)
    2.62 +                        runtime._expect_data(1, TEXT, d)
    2.63 +                        deferreds.append(d)
    2.64 +                    else:
    2.65 +                        self.assertEquals(partialShare.value, 0)
    2.66 +                        self.assertNotEquals(partialShare.enc_shares[0], 0)
    2.67 +                        self.assertNotEquals(partialShare.enc_shares[1], 0)
    2.68 +                        self.assertEquals(partialShare.enc_shares[2], 1)
    2.69 +                return gatherResults(deferreds)
    2.70              r1.addCallback(check)
    2.71              return r1
    2.72  
    2.73 @@ -376,16 +381,19 @@
    2.74  
    2.75          Zp = GF(p)
    2.76  
    2.77 -        a1 = Zp(6)
    2.78 +        ais = [Zp(6), Zp(6), Zp(6), Zp(6)]
    2.79          b2 = Zp(7)
    2.80 -        c2 = triple_generator.paillier.encrypt(b2.value, 2)
    2.81 +        cs = []
    2.82 +        for ai in ais:
    2.83 +            cs.append(triple_generator.paillier.encrypt(b2.value, 2))
    2.84          
    2.85 -        r1 = triple_generator._mul(2, 2, a1, c2)
    2.86 +        r1 = triple_generator._mul(2, 2, ais, cs)
    2.87          def check(partialShareContents):
    2.88 -            if runtime.id == 2:
    2.89 -                zi_enc = Zp(triple_generator.paillier.decrypt(partialShareContents.enc_shares[1]))
    2.90 -                self.assertEquals(zi_enc, partialShareContents.value)
    2.91 -                self.assertEquals(partialShareContents.value, 8)
    2.92 +            for partialShareContent in partialShareContents:
    2.93 +                if runtime.id == 2:
    2.94 +                    zi_enc = Zp(triple_generator.paillier.decrypt(partialShareContent.enc_shares[1]))
    2.95 +                    self.assertEquals(zi_enc, partialShareContent.value)
    2.96 +                    self.assertEquals(partialShareContent.value, 8)
    2.97              return True
    2.98              
    2.99          r1.addCallback(check)
   2.100 @@ -406,20 +414,36 @@
   2.101  
   2.102          paillier = triple_generator.paillier
   2.103          
   2.104 -        share_a = partial_share(random, runtime, GF(p), 6, paillier=paillier)
   2.105 -        share_b = partial_share(random, runtime, GF(p), 7, paillier=paillier)
   2.106 +        share_as = []
   2.107 +        share_bs = []      
   2.108 +        share_as.append(partial_share(random, runtime, GF(p), 6, paillier=paillier))
   2.109 +        share_bs.append(partial_share(random, runtime, GF(p), 7, paillier=paillier))
   2.110 +        share_as.append(partial_share(random, runtime, GF(p), 5, paillier=paillier))
   2.111 +        share_bs.append(partial_share(random, runtime, GF(p), 4, paillier=paillier))
   2.112 +        share_as.append(partial_share(random, runtime, GF(p), 2, paillier=paillier))
   2.113 +        share_bs.append(partial_share(random, runtime, GF(p), 3, paillier=paillier))
   2.114  
   2.115 -        share_z = triple_generator._full_mul(share_a, share_b)
   2.116 -        def check(share):
   2.117 +
   2.118 +        share_zs = triple_generator._full_mul(share_as, share_bs)
   2.119 +        def check(shares):
   2.120              def test_sum(ls):
   2.121 -                vals = ls[0]
   2.122 -                self.assertEquals(8, Zp(sum(vals)))
   2.123 -            value = _convolute(runtime, share.value.value)
   2.124 -            runtime.schedule_callback(gatherResults([value]), test_sum)
   2.125 -            return True
   2.126 +                self.assertEquals(8, Zp(sum(ls[0])))
   2.127 +                self.assertEquals(3, Zp(sum(ls[1])))
   2.128 +                self.assertEquals(6, Zp(sum(ls[2])))
   2.129 +            values = []
   2.130 +            for share in shares:
   2.131 +                value = _convolute(runtime, share.value.value)
   2.132 +                values.append(value)
   2.133 +            d = gatherResults(values)
   2.134 +            runtime.schedule_callback(d, test_sum)
   2.135 +            return d
   2.136              
   2.137 -        share_z.addCallback(check)
   2.138 -        return share_z
   2.139 +        def indirection(shares):
   2.140 +            d = gatherResults(shares)
   2.141 +            d.addCallback(check)
   2.142 +            return d
   2.143 +        share_zs.addCallback(indirection)
   2.144 +        return share_zs
   2.145  
   2.146      @protocol
   2.147      def test_fullmul_encrypted_values_are_the_same_as_the_share(self, runtime):
   2.148 @@ -431,27 +455,38 @@
   2.149          triple_generator = TripleGenerator(runtime, p, random)
   2.150  
   2.151          paillier = triple_generator.paillier
   2.152 -        
   2.153 -        share_a = partial_share(random, runtime, GF(p), 6, paillier=paillier)
   2.154 -        share_b = partial_share(random, runtime, GF(p), 7, paillier=paillier)
   2.155  
   2.156 -        share_z = triple_generator._full_mul(share_a, share_b)
   2.157 -        def check(share):
   2.158 -            def test_enc(enc_shares, value):
   2.159 -                all_the_same, zi_enc = reduce(lambda x, y: (x[0] and x[1] == y, y), enc_shares, (True, enc_shares[0]))
   2.160 -                zi_enc = triple_generator.paillier.decrypt(zi_enc)
   2.161 -                self.assertEquals(value, Zp(zi_enc))
   2.162 -                return True
   2.163 +        share_as = []
   2.164 +        share_bs = []      
   2.165 +        share_as.append(partial_share(random, runtime, GF(p), 6, paillier=paillier))
   2.166 +        share_bs.append(partial_share(random, runtime, GF(p), 7, paillier=paillier))
   2.167 +        share_as.append(partial_share(random, runtime, GF(p), 5, paillier=paillier))
   2.168 +        share_bs.append(partial_share(random, runtime, GF(p), 4, paillier=paillier))
   2.169 +        share_as.append(partial_share(random, runtime, GF(p), 2, paillier=paillier))
   2.170 +        share_bs.append(partial_share(random, runtime, GF(p), 3, paillier=paillier))
   2.171 +
   2.172 +        share_zs = triple_generator._full_mul(share_as, share_bs)
   2.173 +        def check(shares):
   2.174              all_enc_shares = []
   2.175 -            for inx, enc_share in enumerate(share.enc_shares):
   2.176 -                d = _convolute(runtime, enc_share)
   2.177 -                if runtime.id == inx + 1:
   2.178 -                    d.addCallback(test_enc, share.value)
   2.179 +            for share in shares:
   2.180 +                def test_enc(enc_shares, value):
   2.181 +                    all_the_same, zi_enc = reduce(lambda x, y: (x[0] and x[1] == y, y), enc_shares, (True, enc_shares[0]))
   2.182 +                    zi_enc = triple_generator.paillier.decrypt(zi_enc)
   2.183 +                    self.assertEquals(value, Zp(zi_enc))
   2.184 +                    return True
   2.185 +                for inx, enc_share in enumerate(share.enc_shares):
   2.186 +                    d = _convolute(runtime, enc_share)
   2.187 +                    if runtime.id == inx + 1:
   2.188 +                        d.addCallback(test_enc, share.value)
   2.189                  all_enc_shares.append(d)
   2.190              return gatherResults(all_enc_shares)
   2.191 -            
   2.192 -        share_z.addCallback(check)
   2.193 -        return share_z
   2.194 +        
   2.195 +        def indirection(shares):
   2.196 +            d = gatherResults(shares)
   2.197 +            d.addCallback(check)
   2.198 +            return d
   2.199 +        share_zs.addCallback(indirection)
   2.200 +        return share_zs
   2.201  
   2.202  
   2.203  missing_package = None