changeset 1317:d41aca4d7b6e

Made preprocessing more efficient.
author Marcel Keller <mkeller@cs.au.dk>
date Fri, 25 Sep 2009 20:29:18 +0200
parents 0f35ae3f503b
children 70c995d31d1a
files viff/active.py viff/aes.py viff/passive.py viff/runtime.py viff/test/test_active_runtime.py
diffstat 5 files changed, 41 insertions(+), 51 deletions(-) [+]
line wrap: on
line diff
--- a/viff/active.py	Thu Sep 17 17:59:08 2009 +0200
+++ b/viff/active.py	Fri Sep 25 20:29:18 2009 +0200
@@ -381,9 +381,7 @@
     def get_triple(self, field):
         # This is a waste, but this function is only called if there
         # are no pre-processed triples left.
-        count, result = self.generate_triples(field)
-        result.addCallback(lambda triples: triples[0])
-        return result
+        return self.generate_triples(field)[0]
 
     @increment_pc
     def generate_triples(self, field):
@@ -399,7 +397,7 @@
         t = self.threshold
         T = n - 2*t
 
-        def make_triple(shares):
+        def make_triple(shares, results):
             a_t, b_t, (r_t, r_2t) = shares
 
             c_2t = []
@@ -412,15 +410,17 @@
             d_2t = [c_2t[i] - r_2t[i] for i in range(T)]
             d = [self.open(d_2t[i], threshold=2*t) for i in range(T)]
             c_t = [r_t[i] + d[i] for i in range(T)]
-            return zip(a_t, b_t, c_t)
+
+            for triple, result in zip(zip(a_t, b_t, c_t), results):
+                gatherResults(triple).chainDeferred(result)
 
         single_a = self.single_share_random(T, t, field)
         single_b = self.single_share_random(T, t, field)
         double_c = self.double_share_random(T, t, 2*t, field)
 
-        result = gatherResults([single_a, single_b, double_c])
-        self.schedule_callback(result, make_triple)
-        return T, result
+        results = [Deferred() for i in range(T)]
+        self.schedule_callback(gatherResults([single_a, single_b, double_c]), make_triple, results)
+        return results
 
 class TriplesPRSSMixin:
     """Mixin class for generating multiplication triples using PRSS."""
@@ -428,9 +428,8 @@
     @increment_pc
     @preprocess("generate_triples")
     def get_triple(self, field):
-        count, result = self.generate_triples(field, quantity=1)
-        result.addCallback(lambda triples: triples[0])
-        return result
+        result = self.generate_triples(field, quantity=1)
+        return result[0]
 
     @increment_pc
     def generate_triples(self, field, quantity=20):
@@ -456,7 +455,7 @@
             d = self.open(d_2t, threshold=2*self.threshold)
             c_t[i] = r_t[i] + d
 
-        return quantity, succeed(zip(a_t, b_t, c_t))
+        return [gatherResults(triple) for triple in zip(a_t, b_t, c_t)]
 
 
 class BasicActiveRuntime(PassiveRuntime):
@@ -516,7 +515,6 @@
         result = Share(self, share_x.field)
         # This is the Deferred we will do processing on.
         triple = self.get_triple(share_x.field)
-        triple.addCallback(gather_shares)
         triple = self.schedule_complex_callback(triple, finish_mul)
         # We add the result to the chains in triple.
         triple.chainDeferred(result)
--- a/viff/aes.py	Thu Sep 17 17:59:08 2009 +0200
+++ b/viff/aes.py	Fri Sep 25 20:29:18 2009 +0200
@@ -161,7 +161,8 @@
 
     def invert_by_masked_exponentiation(self, byte):
         def add_and_multiply(masked_powers, random_powers):
-            byte_powers = map(operator.add, masked_powers, random_powers)[1:]
+            byte_powers = [Share(self.runtime, GF256, value) for value in
+                           map(operator.add, masked_powers, random_powers)[1:]]
             while len(byte_powers) > 1:
                 byte_powers.append(byte_powers.pop(0) * byte_powers.pop(0))
             return byte_powers[0]
--- a/viff/passive.py	Thu Sep 17 17:59:08 2009 +0200
+++ b/viff/passive.py	Fri Sep 25 20:29:18 2009 +0200
@@ -28,7 +28,7 @@
 from viff.field import GF256, FieldElement
 from viff.util import rand, profile
 
-from twisted.internet.defer import succeed
+from twisted.internet.defer import succeed, gatherResults
 
 
 class PassiveRuntime(Runtime):
@@ -467,14 +467,13 @@
         """Generate a random secret share in GF256 and returns
         [*share*, *share*^2, *share*^4, ..., *share*^(i^max)]."""
         share = self.prss_share_random(GF256)
-        return succeed(self.powerchain(share, max))
+        return gatherResults(self.powerchain(share, max))
 
     def prss_powerchains(self, max=7, quantity=20):
         """Does *quantity* times the same as :meth:`prss_powerchain`.
         Used for preprocessing."""
         shares = self.prss_share_random_multi(GF256, quantity)
-        return quantity, succeed([self.powerchain(share, max)
-                                  for share in shares])
+        return [gatherResults(self.powerchain(share, max)) for share in shares]
 
     def input(self, inputters, field, number=None, threshold=None):
         """Input *number* to the computation.
--- a/viff/runtime.py	Thu Sep 17 17:59:08 2009 +0200
+++ b/viff/runtime.py	Fri Sep 25 20:29:18 2009 +0200
@@ -737,14 +737,13 @@
         arguments. The generator methods called must adhere to the
         following interface:
 
-        - They must return a ``(int, Deferred)`` tuple where the
-          ``int`` tells us how many items of pre-processed data the
-          :class:`Deferred` will yield.
+        - They must return a list of :class:`Deferred` instances.
 
-        - The Deferred must yield a list of the promised length.
-
-        - The list contains the actual data. This data can be either a
-          Deferred or a tuple of Deferreds.
+        - Every Deferred must yield an item of pre-processed data.
+          This can be value, a list or tuple of values, or a Deferred
+          (which will be converted to a value by Twisted), but NOT a
+          list of Deferreds. Use :meth:`gatherResults` to avoid the
+          latter.
 
         The :meth:`ActiveRuntime.generate_triples` method is an
         example of a method fulfilling this interface.
@@ -767,20 +766,14 @@
             # avoid starting before the pre-processing is complete.
             return deep_wait(results)
 
-        wait_list = []
         for ((generator, args), program_counters) in program.iteritems():
             print "Preprocessing %s (%d items)" % (generator, len(program_counters))
             func = getattr(self, generator)
             results = []
-            items = 0
-            while items < len(program_counters):
-                item_count, result = func(*args)
-                items += item_count
-                results.append(result)
-            ready = gatherResults(results)
-            ready.addCallback(update, program_counters)
-            wait_list.append(ready)
-        return DeferredList(wait_list)
+            while len(results) < len(program_counters):
+                results += func(*args)
+            self._pool.update(zip(program_counters, results))
+        return DeferredList(results).addCallback(lambda _: None)
 
     def input(self, inputters, field, number=None):
         """Input *number* to the computation.
--- a/viff/test/test_active_runtime.py	Thu Sep 17 17:59:08 2009 +0200
+++ b/viff/test/test_active_runtime.py	Fri Sep 25 20:29:18 2009 +0200
@@ -99,24 +99,23 @@
             """Verify a multiplication triple."""
             self.assertEquals(triple[0] * triple[1], triple[2])
 
-        def check(triples):
-            results = []
-            for a, b, c in triples:
-                self.assert_type(a, Share)
-                self.assert_type(b, Share)
-                self.assert_type(c, Share)
-                open_a = runtime.open(a)
-                open_b = runtime.open(b)
-                open_c = runtime.open(c)
-                result = gatherResults([open_a, open_b, open_c])
-                result.addCallback(verify)
-                results.append(result)
-            return gatherResults(results)
+        def check(triple):
+            a, b, c = triple
+            self.assert_type(a, self.Zp)
+            self.assert_type(b, self.Zp)
+            self.assert_type(c, self.Zp)
+            open_a = runtime.open(Share(self, self.Zp, a))
+            open_b = runtime.open(Share(self, self.Zp, b))
+            open_c = runtime.open(Share(self, self.Zp, c))
+            result = gatherResults([open_a, open_b, open_c])
+            result.addCallback(verify)
+            return result
 
-        count, triples = runtime.generate_triples(self.Zp)
-        self.assertEquals(count, runtime.num_players - 2*runtime.threshold)
+        triples = runtime.generate_triples(self.Zp)
+        self.assertEquals(len(triples), runtime.num_players - 2*runtime.threshold)
 
-        runtime.schedule_callback(triples, check)
+        for triple in triples:
+            runtime.schedule_callback(triple, check)
         return triples