changeset 1525:80841840990f

SimpleArithmetic, Orlandi, BeDOZa: Removed _get_triple.
author Janus Dam Nielsen <janus.nielsen@alexandra.dk>
date Mon, 26 Jul 2010 14:46:15 +0200
parents 8da98c5697b5
children a87fd09f8c38
files viff/bedoza/bedoza.py viff/orlandi.py viff/simplearithmetic.py viff/test/test_bedoza_runtime.py viff/test/test_orlandi_runtime.py
diffstat 5 files changed, 123 insertions(+), 172 deletions(-) [+]
line wrap: on
line diff
--- a/viff/bedoza/bedoza.py	Mon Jul 26 11:11:05 2010 +0200
+++ b/viff/bedoza/bedoza.py	Mon Jul 26 14:46:15 2010 +0200
@@ -34,6 +34,7 @@
 
 from viff.bedoza.share_generators import ShareGenerator
 
+
 class BeDOZaException(Exception):
     pass
 
@@ -304,9 +305,6 @@
         assert(isinstance(c, FieldElement))
         return x.cmul(c)
 
-    def _get_triple(self, field):
-        return self.triples.pop(), False
-
 
 class BeDOZaRuntime(BeDOZaMixin, SimpleArithmeticRuntime):
     """The BeDOZa runtime.
@@ -328,8 +326,12 @@
     it is ready.
     """
 
-    def __init__(self, player, threshold=None, options=None):
-        """Initialize runtime."""
+    def __init__(self, player, threshold=None, options=None, triples=[]):
+        """Initialize runtime.
+
+        *triples* is a list of multiplicative triples previously
+        generated using the bedoza_triple.TripleGenerator.
+        """
         SimpleArithmeticRuntime.__init__(self, player, threshold, options)
         self.threshold = self.num_players - 1
-
+        self.triples = triples
--- a/viff/orlandi.py	Mon Jul 26 11:11:05 2010 +0200
+++ b/viff/orlandi.py	Mon Jul 26 14:46:15 2010 +0200
@@ -35,6 +35,8 @@
 
 from viff.simplearithmetic import SimpleArithmeticRuntime
 
+from viff.triple import Triple
+
 from hash_broadcast import HashBroadcastMixin
 
 try:
@@ -631,16 +633,6 @@
     
     def _wrap_in_share(self, (zi, rhoz, Cz), field):
         return OrlandiShare(self, field, zi, rhoz, Cz)
-
-    @preprocess("random_triple")
-    def _get_triple(self, field):
-        results = [Share(self, field) for i in range(3)]
-        def chain(triple, results):
-            for i, result in zip(triple, results):
-                result.callback(i)
-        self.random_triple(field, 1)[0].addCallbacks(chain, self.error_handler,
-                                                     (results,))
-        return results
     
     def sum_poly(self, j, ls):
         exp  = j
@@ -665,7 +657,7 @@
         Communication cost: ???.
 
         Assuming a set of multiplicative triples:
-        ``M = ([a_i], [b_i], [c_i]) for 1 <= i <= 2d + 1``.
+        ``M = Triple([a_i], [b_i], [c_i]) for 1 <= i <= 2d + 1``.
 
         1. ``for i = 1, ..., d do [f_i] = rand(), [g_i] = rand()``
 
@@ -705,7 +697,7 @@
             f.append(self.random_share(field))
             g.append(self.random_share(field))
 
-        def compute_polynomials(t):
+        def compute_polynomials(t, M):
             x, y = t[0]
             f = []
             g = []
@@ -754,10 +746,13 @@
                     C_Gji *= Cw
                 Fj = OrlandiShare(self, field, Fji, (rho1_Fji, rho2_Fji), C_Fji)
                 Gj = OrlandiShare(self, field, Gji, (rho1_Gji, rho2_Gji), C_Gji)
-                a, b, c = M.pop(0)
+                triple = M.pop(0)
 
                 # [H_j] = Mul([F_j], [G_j], [a_j], [b_j], [c_j])
-                Hj = self._basic_multiplication(Fj, Gj, a, b, c)
+                Hj = self._basic_multiplication(Fj, Gj,
+                                                triple.a,
+                                                triple.b,
+                                                triple.c)
                 dj = self._cmul(field(deltas[j - 1]), Hj, field)
                 H0 = H0 + dj
             # 5) output [z] = [H_0]
@@ -768,7 +763,7 @@
             ls.append(gather_shares(f))
             ls.append(gather_shares(g))
         result = gather_shares(ls)
-        self.schedule_callback(result, compute_polynomials)
+        self.schedule_callback(result, compute_polynomials, M)
         result.addErrback(self.error_handler)
 
         # do actual communication
@@ -777,7 +772,7 @@
         return result
 
     def triple_gen(self, field):
-        """Generate a triple ``a, b, c`` s.t. ``c = a * b``.
+        """Generate a triple ``Tripel(a, b, c)`` s.t. ``c = a * b``.
 
         1. Every party ``P_i`` chooses random values ``a_i, r_i in Z_p
            X (Z_p)^2``, compute ``alpha_i = Enc_eki(a_i)`` and ``Ai =
@@ -836,7 +831,7 @@
             a = OrlandiShare(self, field, ai, r, A)
             b = OrlandiShare(self, field, bi, s, B)
             c = OrlandiShare(self, field, ci, t, C)
-            return (a, b, c, (alphas, alpha_randomness, gammas, dijs))
+            return (Triple(a, b, c), (alphas, alpha_randomness, gammas, dijs))
 
         def decrypt_gammas(ls):
             """Decrypt all the elements of the list *ls*."""
@@ -969,18 +964,18 @@
         triple2 = self.triple_gen(field)
         r = self.open(self.random_share(field))
 
-        def check(v, a, b, c, ec):
+        def check(v, triple, ec):
             if v.value != 0:
                 raise OrlandiException("TripleTest failed - The two triples were inconsistent.")
-            return (a, b, c, ec)
+            return (triple, ec)
 
-        def compute_value(((a, b, c, ec), (x, y, z, _), r)):
-            l = self._cmul(r, x, field)
-            m = self._cmul(r, y, field)
-            n = self._cmul(r*r, z, field)
-            d = c - self._basic_multiplication(a, b, l, m, n)
+        def compute_value(((t1, ec), (t2, _), r)):
+            l = self._cmul(r, t2.a, field)
+            m = self._cmul(r, t2.b, field)
+            n = self._cmul(r*r, t2.c, field)
+            d = t1.c - self._basic_multiplication(t1.a, t1.b, l, m, n)
             r = self.open(d)
-            r.addCallbacks(check, self.error_handler, callbackArgs=(a, b, c, ec))
+            r.addCallbacks(check, self.error_handler, callbackArgs=(t1, ec))
             return r
 
         result = gatherResults([triple1, triple2, r])
@@ -993,9 +988,9 @@
         return result
 
     def random_triple(self, field, quantity=1):
-        """Generate a list of triples ``(a, b, c)`` where ``c = a * b``.
+        """Generate a list of triples ``Triple(a, b, c)`` where ``c = a * b``.
 
-        The triple ``(a, b, c)`` is secure in the Fcrs-hybrid model.
+        The triple ``Triple(a, b, c)`` is secure in the Fcrs-hybrid model.
 
         """
         self.increment_pc()
@@ -1160,7 +1155,10 @@
                 return True
 
             dls_all = []
-            for (a, b, c, (alphas, alpha_randomness, gammas, dijs)) in T:
+            for (triple, (alphas, alpha_randomness, gammas, dijs)) in T:
+                a = triple.a
+                b = triple.b
+                c = triple.c
                 ds_a = [None] * len(self.players)
                 ds_b = [None] * len(self.players)
                 ds_c = [None] * len(self.players)
@@ -1201,8 +1199,8 @@
 
             def result(x):
                 ls = []
-                for a, b, c, _ in M_without_test_set:
-                    ls.append((a, b, c))
+                for triple, _ in M_without_test_set:
+                    ls.append(triple)
                 return ls
 
             dls_all = gatherResults(dls_all)
@@ -1252,7 +1250,7 @@
                 c = self.leak_tolerant_mul(a, b, Mi)
                 d = self.open(c + r)
                 def return_abc(x, a, b, c, result):
-                    gatherResults([a, b, c]).chainDeferred(result)
+                    result.callback(Triple(a, b, c))
                 d.addCallbacks(return_abc, self.error_handler, callbackArgs=(a, b, c, result))
 
         result = gatherResults(M)
@@ -1304,3 +1302,4 @@
         self.s = 1
         self.d = 0
         self.s_lambda = 1
+        self.triples = []
--- a/viff/simplearithmetic.py	Mon Jul 26 11:11:05 2010 +0200
+++ b/viff/simplearithmetic.py	Mon Jul 26 14:46:15 2010 +0200
@@ -92,11 +92,12 @@
 
         field = getattr(share_x, "field", getattr(share_y, "field", None))
 
-        triple, prep = self._get_triple(field)
-        if prep:
-            # The data from the pool must be wrapped in Shares.
-            triple = [Share(self, field, i) for i in triple]
-        return self._basic_multiplication(share_x, share_y, *triple)
+        triple = self.triples.pop()
+        return self._basic_multiplication(share_x,
+                                          share_y,
+                                          triple.a,
+                                          triple.b,
+                                          triple.c)
 
     def _cmul(self, share_x, share_y, field):
         """Multiplication of a share with a constant.
--- a/viff/test/test_bedoza_runtime.py	Mon Jul 26 11:11:05 2010 +0200
+++ b/viff/test/test_bedoza_runtime.py	Mon Jul 26 14:46:15 2010 +0200
@@ -382,34 +382,38 @@
         runtime.schedule_callback(triple, do_stuff, alpha)
         return triple
 
-    # @protocol
-    # def test_mul_mul(self, runtime):
-    #     """Test multiplication of two numbers."""
+    @protocol
+    def test_mul_mul(self, runtime):
+        """Test multiplication of two numbers."""
 
-    #     x1 = 6
-    #     y1 = 6
+        x1 = 6
+        y1 = 6
 
-    #     def check(v):
-    #         self.assertEquals(v, self.Zp(x1 * y1))
+        def check(v):
+            self.assertEquals(v, self.Zp(x1 * y1))
 
-    #     gen = TripleGenerator(runtime, self.Zp.modulus, Random(3423993))
-    #     alpha = gen.alpha
-    #     runtime.triples = gen.generate_triples(1)
+        gen = TripleGenerator(runtime, self.Zp.modulus, Random(3423993))
+        alpha = gen.alpha
+        triples = gen.generate_triples(1)
         
-
-    #     random = Random(3423993)
-    #     share_random = Random(random.getrandbits(128))
-    #     paillier = ModifiedPaillier(runtime, Random(random.getrandbits(128)))          
-    #     gen = ShareGenerator(self.Zp, runtime, share_random,
-    #                          paillier, self.u_bound, self.alpha)
+        def do_mult(triples, alpha):
+            runtime.triples = triples
+            random = Random(3423993)
+            share_random = Random(random.getrandbits(128))
+            paillier = ModifiedPaillier(runtime, Random(random.getrandbits(128)))          
+            gen = ShareGenerator(self.Zp, runtime, share_random,
+                                 paillier, self.u_bound, alpha)
         
-    #     x2 = gen.generate_share(x1)
-    #     y2 = gen.generate_share(y1)
+            x2 = gen.generate_share(x1)
+            y2 = gen.generate_share(y1)
         
-    #     z2 = x2 * y2
-    #     d = runtime.open(z2)
-    #     d.addCallback(check)
-    #     return d
+            z2 = x2 * y2
+            d = runtime.open(z2)
+            d.addCallback(check)
+            return d
+        r = gatherResults(triples)
+        runtime.schedule_callback(r, do_mult, alpha)
+        return r
     
     @protocol
     def test_basic_multiply_constant_right(self, runtime):
--- a/viff/test/test_orlandi_runtime.py	Mon Jul 26 11:11:05 2010 +0200
+++ b/viff/test/test_orlandi_runtime.py	Mon Jul 26 14:46:15 2010 +0200
@@ -492,13 +492,20 @@
         def check(v):
             self.assertEquals(v, x1 * y1)
 
-        x2 = runtime.shift([2], self.Zp, x1)
-        y2 = runtime.shift([3], self.Zp, y1)
+        triples = runtime.random_triple(self.Zp, 1)
+        
+        def do_mult(triples):
+            runtime.triples = triples
+            x2 = runtime.shift([2], self.Zp, x1)
+            y2 = runtime.shift([3], self.Zp, y1)
 
-        z2 = x2 * y2
-        d = runtime.open(z2)
-        d.addCallback(check)
-        return d
+            z2 = x2 * y2
+            d = runtime.open(z2)
+            d.addCallback(check)
+            return d
+        r = gatherResults(triples)
+        runtime.schedule_callback(r, do_mult)
+        return r
 
     @protocol
     def test_basic_multiply_constant_right(self, runtime):
@@ -667,20 +674,16 @@
         x2 = runtime.shift([1], self.Zp, x1)
         y2 = runtime.shift([2], self.Zp, y1)
 
-        sls = gatherResults(runtime.random_triple(self.Zp, 2*runtime.d + 1))
+        triples = runtime.random_triple(self.Zp, 2*runtime.d + 1)
 
-        def cont(M):
-            M = [[Share(self, self.Zp, j) for j in i] for i in M]
-            z2 = runtime.leak_tolerant_mul(x2, y2, M)
+        def cont(triples):
+            z2 = runtime.leak_tolerant_mul(x2, y2, triples)
             d = runtime.open(z2)
             d.addCallback(check)
             return d
-        sls.addCallbacks(cont, runtime.error_handler)
-        return sls
-
-        z2 = runtime._cmul(y2, x2, self.Zp)
-        self.assertEquals(z2, None)
-        return z2
+        r = gatherResults(triples)
+        runtime.schedule_callback(r, cont)
+        return r
 
     @protocol
     def test_leak_mul1(self, runtime):
@@ -705,20 +708,17 @@
         x2 = runtime.shift([1], self.Zp, x1)
         y2 = runtime.shift([2], self.Zp, y1)
 
-        sls = gatherResults(runtime.random_triple(self.Zp, 2*runtime.d + 1))
+        triples = runtime.random_triple(self.Zp, 2*runtime.d + 1)
 
-        def cont(M):
-            M = [[Share(self, self.Zp, j) for j in i] for i in M]
-            z2 = runtime.leak_tolerant_mul(x2, y2, M)
+        def cont(triples):
+            z2 = runtime.leak_tolerant_mul(x2, y2, triples)
             d = runtime.open(z2)
             d.addCallback(check)
             return d
-        sls.addCallbacks(cont, runtime.error_handler)
-        return sls
+        r = gatherResults(triples)
+        runtime.schedule_callback(r, cont)
+        return r
 
-        z2 = runtime._cmul(y2, x2, self.Zp)
-        self.assertEquals(z2, None)
-        return z2
 
 class TripleGenTest(RuntimeTestCase):
     """Test for generation of triples."""
@@ -728,7 +728,7 @@
 
     runtime_class = OrlandiRuntime
 
-    timeout = 1600
+    timeout = 10
 
     def generate_configs(self, *args):
         global keys
@@ -749,10 +749,10 @@
         def check((a, b, c)):
             self.assertEquals(c, a * b)
 
-        def open((a, b, c, _)):
-            d1 = runtime.open(a)
-            d2 = runtime.open(b)
-            d3 = runtime.open(c)
+        def open((triple, _)):
+            d1 = runtime.open(triple.a)
+            d2 = runtime.open(triple.b)
+            d3 = runtime.open(triple.c)
             d = gatherResults([d1, d2, d3])
             d.addCallback(check)
             return d
@@ -770,13 +770,13 @@
             self.assertEquals(c, a * b)
             self.assertEquals(dz, dx * dy)
 
-        def open(((a, b, c, control), (x, y, z, _))):
-            d1 = runtime.open(a)
-            d2 = runtime.open(b)
-            d3 = runtime.open(c)
-            dx = runtime.open(x)
-            dy = runtime.open(y)
-            dz = runtime.open(z)
+        def open(((t1, control), (t2, _))):
+            d1 = runtime.open(t1.a)
+            d2 = runtime.open(t1.b)
+            d3 = runtime.open(t1.c)
+            dx = runtime.open(t2.a)
+            dy = runtime.open(t2.b)
+            dz = runtime.open(t2.c)
             d = gatherResults([d1, d2, d3, dx, dy, dz])
             d.addCallback(check)
             return d
@@ -795,10 +795,10 @@
         def check((a, b, c)):
             self.assertEquals(c, a * b)
 
-        def open((a, b, c, _)):
-            d1 = runtime.open(a)
-            d2 = runtime.open(b)
-            d3 = runtime.open(c)
+        def open((triple, _)):
+            d1 = runtime.open(triple.a)
+            d2 = runtime.open(triple.b)
+            d3 = runtime.open(triple.c)
             d = gatherResults([d1, d2, d3])
             d.addCallback(check)
             return d
@@ -819,15 +819,15 @@
                 c = ls[x * 3 + 2]
                 self.assertEquals(c, a * b)
 
-        def open(ls):
+        def open(triples):
+            triple = triples[0]
             ds = []
-            for (a, b, c) in ls:
-                d1 = runtime.open(Share(self, self.Zp, a))
-                d2 = runtime.open(Share(self, self.Zp, b))
-                d3 = runtime.open(Share(self, self.Zp, c))
-                ds.append(d1)
-                ds.append(d2)
-                ds.append(d3)
+            d1 = runtime.open(triple.a)
+            d2 = runtime.open(triple.b)
+            d3 = runtime.open(triple.c)
+            ds.append(d1)
+            ds.append(d2)
+            ds.append(d3)
 
             d = gatherResults(ds)
             d.addCallback(check)
@@ -851,10 +851,10 @@
 
         def open(ls):
             ds = []
-            for [(a, b, c)] in ls:
-                d1 = runtime.open(Share(self, self.Zp, a))
-                d2 = runtime.open(Share(self, self.Zp, b))
-                d3 = runtime.open(Share(self, self.Zp, c))
+            for [triple] in ls:
+                d1 = runtime.open(triple.a)
+                d2 = runtime.open(triple.b)
+                d3 = runtime.open(triple.c)
                 ds.append(d1)
                 ds.append(d2)
                 ds.append(d3)
@@ -869,61 +869,6 @@
         d.addCallbacks(open, runtime.error_handler)
         return d
 
-    @protocol
-    def test_random_triple_parallel(self, runtime):
-        """Test the triple_combiner command."""
-
-        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
-
-        def check(ls):
-            for x in xrange(len(ls) // 3):
-                a = ls[x * 3]
-                b = ls[x * 3 + 1]
-                c = ls[x * 3 + 2]
-                self.assertEquals(c, a * b)
-
-        def open(ls):
-            ds = []
-            for [(a, b, c)] in ls:
-                d1 = runtime.open(a)
-                d2 = runtime.open(b)
-                d3 = runtime.open(c)
-                ds.append(d1)
-                ds.append(d2)
-                ds.append(d3)
-
-            d = gatherResults(ds)
-            d.addCallback(check)
-            return d
-
-        a_shares = []
-        b_shares = []
-        c_shares = []
-
-        def cont(x):
-            while a_shares and b_shares:
-                a = a_shares.pop()
-                b = b_shares.pop()
-                c_shares.append(runtime.mul(a, b))
-            done = gather_shares(c_shares)
-            return done
-
-        count = 5
-
-        for i in range(count):
-            inputter = (i % len(runtime.players)) + 1
-            if inputter == runtime.id:
-                a = rand.randint(0, self.Zp.modulus)
-                b = rand.randint(0, self.Zp.modulus)
-            else:
-                a, b = None, None
-            a_shares.append(runtime.input([inputter], self.Zp, a))
-            b_shares.append(runtime.input([inputter], self.Zp, b))
-        shares_ready = gather_shares(a_shares + b_shares)
-
-        runtime.schedule_callback(shares_ready, cont)
-        return shares_ready
-
 
 def skip_tests(module_name):
     OrlandiAdvancedCommandsTest.skip = "Skipped due to missing " + module_name + " module."