changeset 1509:a6be6cce7046

BeDOZa: fullmul now returns a list of PartialShares.
author Janus Dam Nielsen <janus.nielsen@alexandra.dk>
date Tue, 20 Jul 2010 14:15:11 +0200
parents 891df84eb779
children e1aacaf19a54
files viff/bedoza_triple.py viff/test/test_bedoza_triple.py
diffstat 2 files changed, 21 insertions(+), 25 deletions(-) [+]
line wrap: on
line diff
--- a/viff/bedoza_triple.py	Tue Jul 20 14:01:16 2010 +0200
+++ b/viff/bedoza_triple.py	Tue Jul 20 14:15:11 2010 +0200
@@ -386,7 +386,7 @@
         r.addCallback(wrap, inx, jnx)
         return r
 
-    def _full_mul(self, a, b):
+    def _full_mul(self, a, b, field):
         """Multiply each of the PartialShares in the list *a* with the
         corresponding PartialShare in the list *b*.
         
@@ -394,7 +394,7 @@
         """
         self.runtime.increment_pc()
         
-        def do_full_mul(shares):
+        def do_full_mul(shares, result_shares):
             """Share content belonging to ai, bi are at:
             shares[i], shares[len(shares) + i].
             """
@@ -411,7 +411,7 @@
                                                a_values,
                                                b_enc_shares[jnx]))
                         
-            def compute_shares(partialShareContents, len_shares):
+            def compute_shares(partialShareContents, len_shares, result_shares):
                 num_players = len(self.runtime.players)
                 pcs = len(partialShareContents[0]) * [None]
                 for ps in partialShareContents:
@@ -420,16 +420,17 @@
                             pcs[inx] = ps[inx]
                         else:
                             pcs[inx] += ps[inx]
-                partialShares = [PartialShare(self.runtime,
-                                              p.value,
-                                              p.enc_shares) for p in pcs]
-                return partialShares
+                for p, s in zip(pcs, result_shares):
+                    s.callback(p)
+                return None
             d = gatherResults(deferreds)
-            d.addCallback(compute_shares, len_shares)
+            d.addCallback(compute_shares, len_shares, result_shares)
             return d
-        s = gatherResults(a + b)
-        self.runtime.schedule_callback(s, do_full_mul)
-        return s
+        result_shares = [Share(self.runtime, field) for x in a]
+        self.runtime.schedule_callback(gatherResults(a + b),
+                                       do_full_mul,
+                                       result_shares)
+        return result_shares
 
 
 # TODO: Represent all numbers by GF objects, Zp, Zn, etc.
--- a/viff/test/test_bedoza_triple.py	Tue Jul 20 14:01:16 2010 +0200
+++ b/viff/test/test_bedoza_triple.py	Tue Jul 20 14:15:11 2010 +0200
@@ -430,7 +430,7 @@
         share_bs.append(partial_share(random, runtime, GF(p), 3, paillier=paillier))
 
 
-        share_zs = triple_generator._full_mul(share_as, share_bs)
+        share_zs = triple_generator._full_mul(share_as, share_bs, Zp)
         def check(shares):
             def test_sum(ls):
                 self.assertEquals(8, Zp(sum(ls[0])))
@@ -444,12 +444,9 @@
             runtime.schedule_callback(d, test_sum)
             return d
             
-        def indirection(shares):
-            d = gatherResults(shares)
-            d.addCallback(check)
-            return d
-        share_zs.addCallback(indirection)
-        return share_zs
+        d = gatherResults(share_zs)
+        d.addCallback(check)
+        return d
 
     @protocol
     def test_fullmul_encrypted_values_are_the_same_as_the_share(self, runtime):
@@ -471,7 +468,7 @@
         share_as.append(partial_share(random, runtime, GF(p), 2, paillier=paillier))
         share_bs.append(partial_share(random, runtime, GF(p), 3, paillier=paillier))
 
-        share_zs = triple_generator._full_mul(share_as, share_bs)
+        share_zs = triple_generator._full_mul(share_as, share_bs, Zp)
         def check(shares):
             all_enc_shares = []
             for share in shares:
@@ -487,12 +484,10 @@
                 all_enc_shares.append(d)
             return gatherResults(all_enc_shares)
         
-        def indirection(shares):
-            d = gatherResults(shares)
-            d.addCallback(check)
-            return d
-        share_zs.addCallback(indirection)
-        return share_zs
+        d = gatherResults(share_zs)
+        d.addCallback(check)
+        return d
+        
 
 
 missing_package = None