changeset 1330:6f02ecfa60e1

orlandi: Adapted to new preprocessing interface.
author Marcel Keller <mkeller@cs.au.dk>
date Sat, 24 Oct 2009 18:39:19 +0200
parents c4803511dbf8
children 7856dceaf7b5
files viff/orlandi.py viff/test/test_orlandi_runtime.py
diffstat 2 files changed, 32 insertions(+), 45 deletions(-) [+]
line wrap: on
line diff
--- a/viff/orlandi.py	Fri Oct 23 15:03:30 2009 +0200
+++ b/viff/orlandi.py	Sat Oct 24 18:39:19 2009 +0200
@@ -520,17 +520,11 @@
 
         field = getattr(share_x, "field", getattr(share_y, "field", None))
 
-        def finish_mul((a, b, c)):
-            return self._basic_multiplication(share_x, share_y, a, b, c)
-
-        # This will be the result, a Share object.
-        result = Share(self, share_x.field)
-        # This is the Deferred we will do processing on.
-        triple = self._get_triple(field)
-        triple = self.schedule_complex_callback(triple, finish_mul)
-        # We add the result to the chains in triple.
-        triple.chainDeferred(result)
-        return result
+        triple, prep = self._get_triple(field)
+        if prep:
+            # The data from the pool must be wrapped in Shares.
+            triple = [Share(self, field, i) for i in triple]
+        return self._basic_multiplication(share_x, share_y, *triple)
 
     def _additive_constant(self, zero, field_element):
         """Greate an additive constant.
@@ -623,11 +617,13 @@
 
     @preprocess("random_triple")
     def _get_triple(self, field):
-        c, d = self.random_triple(field, 1)
-        def f(ls):
-            return ls[0]
-        d.addCallbacks(f, self.error_handler)
-        return d
+        results = [Share(self, field) for i in range(3)]
+        def chain(triple, results):
+            for i, result in zip(triple, results):
+                result.callback(i)
+        self.random_triple(field, 1)[0].addCallbacks(chain, self.error_handler,
+                                                     (results,))
+        return results
 
     def _basic_multiplication(self, share_x, share_y, triple_a, triple_b, triple_c):
         """Multiplication of shares give a triple.
@@ -1314,6 +1310,7 @@
             r.addErrback(self.error_handler)
             return r
 
+        results = [Deferred() for i in xrange(quantity)]
 
         def step7(Msets):
             """For i = 1,...,M do:
@@ -1323,21 +1320,15 @@
             d) Open([c] + [r])
             """
             ds = []
-            for Mi in Msets:
+            for Mi, result in zip(Msets, results):
                 a = self.random_share(field)
                 b = self.random_share(field)
                 r = self.random_share(field)
                 c = self.leak_tolerant_mul(a, b, Mi)
                 d = self.open(c + r)
-                def return_abc(x, a, b, c):
-                    return a, b, c
-                d.addCallbacks(return_abc, self.error_handler, callbackArgs=(a, b, c))
-                ds.append(d)
-            result = gather_shares(ds)
-            def triples(ls):
-                return ls
-            result.addCallbacks(triples, self.error_handler)
-            return result
+                def return_abc(x, a, b, c, result):
+                    gatherResults([a, b, c]).chainDeferred(result)
+                d.addCallbacks(return_abc, self.error_handler, callbackArgs=(a, b, c, result))
 
         result = gatherResults(M)
         self.schedule_callback(result, step3)
@@ -1348,12 +1339,7 @@
 
         # do actual communication
         self.activate_reactor()
-
-        s = Share(self, field)
-        # We add the result to the chains in result.
-        result.chainDeferred(s)
-
-        return quantity, s
+        return results
 
     def error_handler(self, ex):
         print "Error: ", ex
--- a/viff/test/test_orlandi_runtime.py	Fri Oct 23 15:03:30 2009 +0200
+++ b/viff/test/test_orlandi_runtime.py	Sat Oct 24 18:39:19 2009 +0200
@@ -18,7 +18,7 @@
 from twisted.internet.defer import gatherResults, DeferredList
 
 from viff.test.util import RuntimeTestCase, protocol
-from viff.runtime import gather_shares
+from viff.runtime import gather_shares, Share
 try:
     from viff.orlandi import OrlandiRuntime, OrlandiShare
     import commitment
@@ -560,9 +560,10 @@
         x2 = runtime.shift([1], self.Zp, x1)
         y2 = runtime.shift([2], self.Zp, y1)
 
-        c, sls = runtime.random_triple(self.Zp, 2*runtime.d + 1)
+        sls = gatherResults(runtime.random_triple(self.Zp, 2*runtime.d + 1))
 
         def cont(M):
+            M = [[Share(self, self.Zp, j) for j in i] for i in M]
             z2 = runtime.leak_tolerant_mul(x2, y2, M)
             d = runtime.open(z2)
             d.addCallback(check)
@@ -666,9 +667,9 @@
         def open(ls):
             ds = []
             for (a, b, c) in ls:
-                d1 = runtime.open(a)
-                d2 = runtime.open(b)
-                d3 = runtime.open(c)
+                d1 = runtime.open(Share(self, self.Zp, a))
+                d2 = runtime.open(Share(self, self.Zp, b))
+                d3 = runtime.open(Share(self, self.Zp, c))
                 ds.append(d1)
                 ds.append(d2)
                 ds.append(d3)
@@ -676,7 +677,7 @@
             d = gatherResults(ds)
             d.addCallback(check)
             return d
-        c, d = runtime.random_triple(self.Zp, 1)
+        d = gatherResults(runtime.random_triple(self.Zp, 1))
         d.addCallbacks(open, runtime.error_handler)
         return d
 
@@ -696,9 +697,9 @@
         def open(ls):
             ds = []
             for [(a, b, c)] in ls:
-                d1 = runtime.open(a)
-                d2 = runtime.open(b)
-                d3 = runtime.open(c)
+                d1 = runtime.open(Share(self, self.Zp, a))
+                d2 = runtime.open(Share(self, self.Zp, b))
+                d3 = runtime.open(Share(self, self.Zp, c))
                 ds.append(d1)
                 ds.append(d2)
                 ds.append(d3)
@@ -706,10 +707,10 @@
             d = gatherResults(ds)
             d.addCallback(check)
             return d
-        ac, a = runtime.random_triple(self.Zp, 1)
-        bc, b = runtime.random_triple(self.Zp, 1)
-        cc, c = runtime.random_triple(self.Zp, 1)
-        d = gather_shares([a, b, c])
+        a = gatherResults(runtime.random_triple(self.Zp, 1))
+        b = gatherResults(runtime.random_triple(self.Zp, 1))
+        c = gatherResults(runtime.random_triple(self.Zp, 1))
+        d = gatherResults([a, b, c])
         d.addCallbacks(open, runtime.error_handler)
         return d