changeset 1547:dae353266aa6

Merged.
author Janus Dam Nielsen <janus.nielsen@alexandra.dk>
date Tue, 21 Sep 2010 11:51:35 +0200
parents 7db2fafaab44 da966a4620f4
children 6191f86d814b
files viff/test/test_bedoza_runtime.py viff/test/test_bedoza_triple.py
diffstat 5 files changed, 136 insertions(+), 120 deletions(-) [+]
line wrap: on
line diff
--- a/viff/bedoza/bedoza_triple.py	Tue Sep 21 10:05:02 2010 +0200
+++ b/viff/bedoza/bedoza_triple.py	Tue Sep 21 11:51:35 2010 +0200
@@ -35,6 +35,8 @@
 from viff.bedoza.add_macs import add_macs
 from viff.bedoza.modified_paillier import ModifiedPaillier
 from viff.bedoza.util import fast_pow
+from viff.bedoza.util import _convolute
+from viff.bedoza.share import Share
 
 from viff.triple import Triple
 
@@ -52,7 +54,7 @@
 
 class TripleGenerator(object):
 
-    def __init__(self, runtime, p, random):
+    def __init__(self, runtime, security_parameter, p, random):
         assert p > 1
         self.random = random
         # TODO: Generate Paillier cipher with N_i sufficiently larger than p
@@ -60,7 +62,8 @@
         self.p = p
         self.Zp = GF(p)
         self.k = self._bit_length_of(p)
-        self.u_bound = 2**(4 * self.k)
+        self.security_parameter = security_parameter
+        self.u_bound = 2**(self.security_parameter + 4 * self.k)
 
         paillier_random = Random(self.random.getrandbits(128))
         alpha_random = Random(self.random.getrandbits(128))
@@ -85,6 +88,10 @@
             bit_length += 1
         return bit_length
 
+    def generate_triples(self):
+        """Generates and returns a set with *self.security_parameter* triples."""
+        return generate_triples(self, self.security_parameter)
+        
     def generate_triples(self, n):
         """Generates and returns a set of n triples.
         
@@ -172,12 +179,10 @@
         """Multiply each of the field elements in *ais* with the
         corresponding encrypted elements in *cjs*.
         
-        Returns a deferred which will yield a list of PartialShareContents.
+        Returns a deferred which will yield a list of field elements.
         """
         CKIND = 1
-        DiKIND = 2
-        DjKIND = 3
-
+ 
         """The transmission_restraint_constant is the number of
         encrypted shares we can safely transmit in one call to
         sendData. The sendData method can only transmit up to
@@ -197,33 +202,24 @@
 
         pc = tuple(self.runtime.program_counter)
 
-        deferreds = []
+        deferred = []
         zis = []
         if self.runtime.id == inx:
             Nj_square = self.paillier.get_modulus_square(jnx)
             all_cs = []
-            all_dis = []
             for iny, (ai, cj) in enumerate(zip(ais, cjs)):
                 if iny % transmission_restraint_constant == 0:
                     cs = []
                     all_cs.append(cs)
-                    dis = []
-                    all_dis.append(dis)
                 u = rand.randint(0, self.u_bound)
                 Ej_u = self.paillier.encrypt(u, jnx)
                 cs.append( (fast_pow(cj, ai.value, Nj_square) * Ej_u) % Nj_square )
                 zi = self.Zp(-u)
                 zis.append(zi)
-                dis.append(self.paillier.encrypt(zi.value, inx))
                 
             for cs in all_cs:
                 self.runtime.protocols[jnx].sendData(pc, CKIND, str(cs))
 
-            for dis in all_dis:
-                for player_id in self.runtime.players:
-                    self.runtime.protocols[player_id].sendData(pc, DiKIND,
-                                                               str(dis))
-
         if self.runtime.id == jnx:
             all_cs = []
             for _ in xrange(number_of_packets):
@@ -234,66 +230,24 @@
             def decrypt(all_cs, pc, zis):
                 zjs = []
                 cs = reduce(lambda x, y: x + eval(y), all_cs, [])
-                all_djs = []
                 for iny, c in enumerate(cs):
-                    if iny % transmission_restraint_constant == 0:
-                        djs = []
-                        all_djs.append(djs)
                     t = self.paillier.decrypt(c)
                     zj = self.Zp(t)
                     zjs.append(zj)
-                    djs.append(self.paillier.encrypt(zj.value, jnx))
-                for djs in all_djs:
-                    for player_id in self.runtime.players:
-                        self.runtime.protocols[player_id].sendData(pc, DjKIND,
-                                                                   str(djs))
                 if not zis == []:
                     return [x + y for x, y in zip(zis, zjs)]
                 else:
                     return zjs 
             all_cs_d = gatherResults(all_cs)
             all_cs_d.addCallback(decrypt, pc, zis)
-            deferreds.append(all_cs_d)
+            deferred = all_cs_d
         else:
             zis_deferred = Deferred()
             zis_deferred.callback(zis)
-            deferreds.append(zis_deferred)
+            deferred = zis_deferred
 
-        all_dis = []
-        for _ in xrange(number_of_packets):
-            dis = Deferred()
-            self.runtime._expect_data(inx, DiKIND, dis)
-            all_dis.append(dis)
-        all_djs = []
-        for _ in xrange(number_of_packets):
-            djs = Deferred()
-            self.runtime._expect_data(jnx, DjKIND, djs)
-            all_djs.append(djs)
-
-        deferreds.append(gatherResults(all_dis))
-        deferreds.append(gatherResults(all_djs))
-        r = gatherResults(deferreds)
-        def wrap((values, dis, djs), inx, jnx):
-            dis = reduce(lambda x, y: x + eval(y), dis, [])
-            djs = reduce(lambda x, y: x + eval(y), djs, [])
-            n_square_i = self.paillier.get_modulus_square(inx)
-            n_square_j = self.paillier.get_modulus_square(jnx)
-            N_squared_list = [self.paillier.get_modulus_square(player_id)
-                              for player_id in self.runtime.players]
-            ps = []
-            
-            for v, di, dj in itertools.izip_longest(values, dis, djs,
-                                                    fillvalue=self.Zp(0)):
-                value = v
-                enc_shares = len(self.runtime.players) * [1]
-                enc_shares[inx - 1] = (enc_shares[inx - 1] * di) % n_square_i
-                enc_shares[jnx - 1] = (enc_shares[jnx - 1] * dj) % n_square_j
-                ps.append(PartialShareContents(value, enc_shares,
-                                               N_squared_list))
-            return ps
-        r.addCallback(wrap, inx, jnx)
-        return r
-
+        return deferred
+       
     def _full_mul(self, a, b):
         """Multiply each of the PartialShares in the list *a* with the
         corresponding PartialShare in the list *b*.
@@ -302,41 +256,50 @@
         """
         self.runtime.increment_pc()
         
-        def do_full_mul(shares, result_shares):
+        def do_full_mul(shareContents, result_shares):
             """Share content belonging to ai, bi are at:
-            shares[i], shares[len(shares) + i].
+            shareContents[i], shareContents[len(shareContents) + i].
             """
             deferreds = []
-            len_shares = len(shares)
-            a_values = [s.value for s in shares[0:len_shares/2]]
+            len_shares = len(shareContents)
+
+            ais = [shareContent.value for shareContent in shareContents[0:len_shares/2]]
+            bis = [shareContent.value for shareContent in shareContents[len_shares/2:]]
+            
             b_enc_shares = []
-            for inx in self.runtime.players:              
+            for inx in self.runtime.players:
                 b_enc_shares.append([s.enc_shares[inx - 1]
-                                     for s in shares[len_shares/2:]])
+                                     for s in shareContents[len_shares/2:]])
+
+            values = len(ais) * [0]
+
             for inx in xrange(0, len(self.runtime.players)):
                 for jnx in xrange(0, len(self.runtime.players)):
                     deferreds.append(self._mul(inx + 1,
                                                jnx + 1,
-                                               len(a_values),
-                                               a_values,
+                                               len(ais),
+                                               ais,
                                                b_enc_shares[jnx]))
-                        
-            def compute_shares(partialShareContents, len_shares, result_shares):
-                num_players = len(self.runtime.players)
-                pcs = len(partialShareContents[0]) * [None]
-                for ps in partialShareContents:
-                    for inx in xrange(0, len(ps)):
-                        if pcs[inx] == None:
-                            pcs[inx] = ps[inx]
-                        else:
-                            pcs[inx] += ps[inx]
-                for p, s in zip(pcs, result_shares):
-                    s.callback(p)
+            
+            def compute_shares(zils, values, result_shares):
+                for zil in zils:
+                    for inx, zi in enumerate(zil):
+                        values[inx] += zi
+
+                return values
+            
+            d = gatherResults(deferreds)
+            d.addCallback(compute_shares, values, result_shares)
+            
+            def callBackPartialShareContents(partialShareContents, result_shares):
+                for v, s in zip(partialShareContents, result_shares):
+                    s.callback(v)
                 return None
-            d = gatherResults(deferreds)
-            d.addCallback(compute_shares, len_shares, result_shares)
+            
+            d.addCallback(lambda values: Share(values, self.runtime, self.paillier))
+            d.addCallback(callBackPartialShareContents, result_shares)
             return d
-        result_shares = [Share(self.runtime, self.Zp) for x in a]
+        result_shares = [PartialShare(self.runtime, self.Zp) for _ in a]
         self.runtime.schedule_callback(gatherResults(a + b),
                                        do_full_mul,
                                        result_shares)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/viff/bedoza/share.py	Tue Sep 21 11:51:35 2010 +0200
@@ -0,0 +1,53 @@
+# Copyright 2010 VIFF Development Team.
+#
+# This file is part of VIFF, the Virtual Ideal Functionality Framework.
+#
+# VIFF is free software: you can redistribute it and/or modify it
+# under the terms of the GNU Lesser General Public License (LGPL) as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# VIFF is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
+# Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with VIFF. If not, see <http://www.gnu.org/licenses/>.
+
+from viff.bedoza.shares import PartialShareContents
+from viff.bedoza.util import _convolute
+
+def Share(field_elements, runtime, paillier):
+    """Each party input a list of field elements *field_elements*.
+    The value of the field elements are encrypted and the encrypted
+    values are exchanged.
+
+    Returns a deferred, which yields a list of PartialShareContents.  
+    """
+    
+    runtime.increment_pc()
+
+    N_squared_list = [paillier.get_modulus_square(player_id)
+                      for player_id in runtime.players]
+
+    list_of_enc_shares = []
+    for field_element in field_elements:
+        list_of_enc_shares.append(paillier.encrypt(field_element.value))
+        
+    list_of_enc_shares = _convolute(runtime, list_of_enc_shares, deserialize=eval)
+    def create_partial_share(list_of_enc_shares, field_elements):
+
+        reordered_encrypted_shares = [[] for _ in list_of_enc_shares[0]]
+        for enc_shares in list_of_enc_shares:
+            for inx, enc_share in enumerate(enc_shares):
+                reordered_encrypted_shares[inx].append(enc_share)
+
+        partialShareContents = []
+        for enc_shares, field_element in zip(reordered_encrypted_shares, field_elements):
+            partialShareContents.append(PartialShareContents(field_element, enc_shares, N_squared_list))
+        return partialShareContents
+    
+    runtime.schedule_callback(list_of_enc_shares, create_partial_share, field_elements)
+    return list_of_enc_shares
+        
--- a/viff/bedoza/share_generators.py	Tue Sep 21 10:05:02 2010 +0200
+++ b/viff/bedoza/share_generators.py	Tue Sep 21 11:51:35 2010 +0200
@@ -31,7 +31,7 @@
     def generate_share(self, value):
         self.runtime.increment_pc()
         
-        # TODO: Exclusve?
+        # TODO: Exclusive?
         r = [self.Zp(self.random.randint(0, self.Zp.modulus - 1))
              for _ in range(self.runtime.num_players - 1)]
         if self.runtime.id == 1:
--- a/viff/test/test_bedoza_runtime.py	Tue Sep 21 10:05:02 2010 +0200
+++ b/viff/test/test_bedoza_runtime.py	Tue Sep 21 11:51:35 2010 +0200
@@ -79,7 +79,8 @@
         RuntimeTestCase.setUp(self)
         self.Zp = GF(17)
         bits_in_p = 5
-        self.u_bound = 2**(4 * bits_in_p)
+        self.security_parameter = 32
+        self.u_bound = 2**(self.security_parameter + 4 * bits_in_p)
         self.alpha = 15
 
     @protocol
@@ -344,7 +345,7 @@
             return d
 
         random = Random(3423993)
-        gen = TripleGenerator(runtime, self.Zp.modulus, random)
+        gen = TripleGenerator(runtime, self.security_parameter, self.Zp.modulus, random)
         [triple] = gen.generate_triples(1)
         triple.addCallback(open)
         return triple
@@ -377,7 +378,7 @@
             d.addCallback(check)
             return d
 
-        gen = TripleGenerator(runtime, self.Zp.modulus, Random(3423993))
+        gen = TripleGenerator(runtime, self.security_parameter, self.Zp.modulus, Random(3423993))
         alpha = gen.alpha
         [triple] = gen.generate_triples(1)
         runtime.schedule_callback(triple, do_stuff, alpha)
@@ -393,7 +394,7 @@
         def check(v):
             self.assertEquals(v, self.Zp(x1 * y1))
 
-        gen = TripleGenerator(runtime, self.Zp.modulus, Random(3423993))
+        gen = TripleGenerator(runtime, self.security_parameter, self.Zp.modulus, Random(3423993))
         alpha = gen.alpha
         triples = gen.generate_triples(1)
         
@@ -444,7 +445,7 @@
             d.addCallback(check)
             return d
 
-        gen = TripleGenerator(runtime, self.Zp.modulus, Random(3423993))
+        gen = TripleGenerator(runtime, self.security_parameter, self.Zp.modulus, Random(3423993))
         alpha = gen.alpha
         [triple] = gen.generate_triples(1)
         runtime.schedule_callback(triple, do_stuff, alpha)
@@ -478,7 +479,7 @@
             d.addCallback(check)
             return d
 
-        gen = TripleGenerator(runtime, self.Zp.modulus, Random(3423993))
+        gen = TripleGenerator(runtime, self.security_parameter, self.Zp.modulus, Random(3423993))
         alpha = gen.alpha
         [triple] = gen.generate_triples(1)
         runtime.schedule_callback(triple, do_stuff, alpha)
--- a/viff/test/test_bedoza_triple.py	Tue Sep 21 10:05:02 2010 +0200
+++ b/viff/test/test_bedoza_triple.py	Tue Sep 21 11:51:35 2010 +0200
@@ -76,6 +76,10 @@
 
     runtime_class = BeDOZaRuntime
 
+    def setUp(self):
+        RuntimeTestCase.setUp(self)
+        self.security_parameter = 32
+
     # TODO: During test, we would like generation of Paillier keys to
     # be deterministic. How do we obtain that?
     def generate_configs(self, *args):
@@ -407,9 +411,9 @@
         p = 17
 
         Zp = GF(p)
-        
+      
         random = Random(283883)        
-        triple_generator = TripleGenerator(runtime, p, random)
+        triple_generator = TripleGenerator(runtime, self.security_parameter, p, random)
 
         triples = triple_generator.generate_triples(10)
 
@@ -433,9 +437,9 @@
         p = 17
 
         Zp = GF(p)
-        
+       
         random = Random(283883)        
-        triple_generator = TripleGenerator(runtime, p, random)
+        triple_generator = TripleGenerator(runtime, self.security_parameter, p, random)
 
         triples = triple_generator._generate_passive_triples(5)
         def verify(triples):
@@ -456,8 +460,9 @@
     @protocol
     def test_mul_computes_correct_result(self, runtime):
         p = 17
+       
         random = Random(283883)        
-        triple_generator = TripleGenerator(runtime, p, random)
+        triple_generator = TripleGenerator(runtime, 32, p, random)
 
         Zp = GF(p)
 
@@ -471,35 +476,28 @@
         
         if runtime.id == 1:
             r1 = triple_generator._mul(1, 2, n, ais, cs)
-            def check1(partialShares):
-                for partialShare in partialShares:
-                    zi = triple_generator.paillier.decrypt(partialShare.enc_shares[0])
-                    self.assertEquals(partialShare.value.value, zi)
+            def check1(shares):
+                for share in shares:
                     pc = tuple(runtime.program_counter)
-                    runtime.protocols[2].sendData(pc, TEXT, str(zi))
+                    runtime.protocols[2].sendData(pc, TEXT, str(share.value))
                 return True
             r1.addCallback(check1)
             return r1
         else:
             r1 = triple_generator._mul(1, 2, n)
-            def check(partialShares):
+            def check(shares):
                 deferreds = []
-                for partialShare in partialShares:
+                for share in shares:
                     if runtime.id == 2:
-                        zj = triple_generator.paillier.decrypt(partialShare.enc_shares[1])
-                        self.assertEquals(partialShare.value.value, zj)
                         def check_additivity(zi, zj):
                             self.assertEquals((Zp(long(zi)) + zj).value, 8)
                             return None
                         d = Deferred()
-                        d.addCallback(check_additivity, partialShare.value)
+                        d.addCallback(check_additivity, share.value)
                         runtime._expect_data(1, TEXT, d)
                         deferreds.append(d)
                     else:
-                        self.assertEquals(partialShare.value, 0)
-                        self.assertNotEquals(partialShare.enc_shares[0], 0)
-                        self.assertNotEquals(partialShare.enc_shares[1], 0)
-                        self.assertEquals(partialShare.enc_shares[2], 1)
+                        self.assertEquals(share.value, 0)
                 return gatherResults(deferreds)
             r1.addCallback(check)
             return r1
@@ -507,8 +505,9 @@
     @protocol
     def test_mul_same_player_inputs_and_receives(self, runtime):
         p = 17
+      
         random = Random(283883)        
-        triple_generator = TripleGenerator(runtime, p, random)
+        triple_generator = TripleGenerator(runtime, self.security_parameter, p, random)
 
         Zp = GF(p)
 
@@ -521,12 +520,10 @@
         n = len(ais)
         
         r1 = triple_generator._mul(2, 2, n, ais, cs)
-        def check(partialShareContents):
-            for partialShareContent in partialShareContents:
+        def check(shares):
+            for share in shares:
                 if runtime.id == 2:
-                    zi_enc = Zp(triple_generator.paillier.decrypt(partialShareContent.enc_shares[1]))
-                    self.assertEquals(zi_enc, partialShareContent.value)
-                    self.assertEquals(partialShareContent.value, 8)
+                    self.assertEquals(share.value, 8)
             return True
             
         r1.addCallback(check)
@@ -535,7 +532,9 @@
 
 class FullMulTest(BeDOZaTestCase): 
     num_players = 3
-    
+
+    timeout = 10
+        
     @protocol
     def test_fullmul_computes_the_correct_result(self, runtime):
         p = 17
@@ -543,7 +542,7 @@
         Zp = GF(p)
         
         random = Random(283883)        
-        triple_generator = TripleGenerator(runtime, p, random)
+        triple_generator = TripleGenerator(runtime, self.security_parameter, p, random)
 
         paillier = triple_generator.paillier
         
@@ -582,7 +581,7 @@
         Zp = GF(p)
         
         random = Random(283883)        
-        triple_generator = TripleGenerator(runtime, p, random)
+        triple_generator = TripleGenerator(runtime, self.security_parameter, p, random)
 
         paillier = triple_generator.paillier