viff

changeset 1317:d41aca4d7b6e

Made preprocessing more efficient.
author Marcel Keller <mkeller@cs.au.dk>
date Fri, 25 Sep 2009 20:29:18 +0200
parents 0f35ae3f503b
children 70c995d31d1a
files viff/active.py viff/aes.py viff/passive.py viff/runtime.py viff/test/test_active_runtime.py
diffstat 5 files changed, 41 insertions(+), 51 deletions(-) [+]
line diff
     1.1 --- a/viff/active.py	Thu Sep 17 17:59:08 2009 +0200
     1.2 +++ b/viff/active.py	Fri Sep 25 20:29:18 2009 +0200
     1.3 @@ -381,9 +381,7 @@
     1.4      def get_triple(self, field):
     1.5          # This is a waste, but this function is only called if there
     1.6          # are no pre-processed triples left.
     1.7 -        count, result = self.generate_triples(field)
     1.8 -        result.addCallback(lambda triples: triples[0])
     1.9 -        return result
    1.10 +        return self.generate_triples(field)[0]
    1.11  
    1.12      @increment_pc
    1.13      def generate_triples(self, field):
    1.14 @@ -399,7 +397,7 @@
    1.15          t = self.threshold
    1.16          T = n - 2*t
    1.17  
    1.18 -        def make_triple(shares):
    1.19 +        def make_triple(shares, results):
    1.20              a_t, b_t, (r_t, r_2t) = shares
    1.21  
    1.22              c_2t = []
    1.23 @@ -412,15 +410,17 @@
    1.24              d_2t = [c_2t[i] - r_2t[i] for i in range(T)]
    1.25              d = [self.open(d_2t[i], threshold=2*t) for i in range(T)]
    1.26              c_t = [r_t[i] + d[i] for i in range(T)]
    1.27 -            return zip(a_t, b_t, c_t)
    1.28 +
    1.29 +            for triple, result in zip(zip(a_t, b_t, c_t), results):
    1.30 +                gatherResults(triple).chainDeferred(result)
    1.31  
    1.32          single_a = self.single_share_random(T, t, field)
    1.33          single_b = self.single_share_random(T, t, field)
    1.34          double_c = self.double_share_random(T, t, 2*t, field)
    1.35  
    1.36 -        result = gatherResults([single_a, single_b, double_c])
    1.37 -        self.schedule_callback(result, make_triple)
    1.38 -        return T, result
    1.39 +        results = [Deferred() for i in range(T)]
    1.40 +        self.schedule_callback(gatherResults([single_a, single_b, double_c]), make_triple, results)
    1.41 +        return results
    1.42  
    1.43  class TriplesPRSSMixin:
    1.44      """Mixin class for generating multiplication triples using PRSS."""
    1.45 @@ -428,9 +428,8 @@
    1.46      @increment_pc
    1.47      @preprocess("generate_triples")
    1.48      def get_triple(self, field):
    1.49 -        count, result = self.generate_triples(field, quantity=1)
    1.50 -        result.addCallback(lambda triples: triples[0])
    1.51 -        return result
    1.52 +        result = self.generate_triples(field, quantity=1)
    1.53 +        return result[0]
    1.54  
    1.55      @increment_pc
    1.56      def generate_triples(self, field, quantity=20):
    1.57 @@ -456,7 +455,7 @@
    1.58              d = self.open(d_2t, threshold=2*self.threshold)
    1.59              c_t[i] = r_t[i] + d
    1.60  
    1.61 -        return quantity, succeed(zip(a_t, b_t, c_t))
    1.62 +        return [gatherResults(triple) for triple in zip(a_t, b_t, c_t)]
    1.63  
    1.64  
    1.65  class BasicActiveRuntime(PassiveRuntime):
    1.66 @@ -516,7 +515,6 @@
    1.67          result = Share(self, share_x.field)
    1.68          # This is the Deferred we will do processing on.
    1.69          triple = self.get_triple(share_x.field)
    1.70 -        triple.addCallback(gather_shares)
    1.71          triple = self.schedule_complex_callback(triple, finish_mul)
    1.72          # We add the result to the chains in triple.
    1.73          triple.chainDeferred(result)
     2.1 --- a/viff/aes.py	Thu Sep 17 17:59:08 2009 +0200
     2.2 +++ b/viff/aes.py	Fri Sep 25 20:29:18 2009 +0200
     2.3 @@ -161,7 +161,8 @@
     2.4  
     2.5      def invert_by_masked_exponentiation(self, byte):
     2.6          def add_and_multiply(masked_powers, random_powers):
     2.7 -            byte_powers = map(operator.add, masked_powers, random_powers)[1:]
     2.8 +            byte_powers = [Share(self.runtime, GF256, value) for value in
     2.9 +                           map(operator.add, masked_powers, random_powers)[1:]]
    2.10              while len(byte_powers) > 1:
    2.11                  byte_powers.append(byte_powers.pop(0) * byte_powers.pop(0))
    2.12              return byte_powers[0]
     3.1 --- a/viff/passive.py	Thu Sep 17 17:59:08 2009 +0200
     3.2 +++ b/viff/passive.py	Fri Sep 25 20:29:18 2009 +0200
     3.3 @@ -28,7 +28,7 @@
     3.4  from viff.field import GF256, FieldElement
     3.5  from viff.util import rand, profile
     3.6  
     3.7 -from twisted.internet.defer import succeed
     3.8 +from twisted.internet.defer import succeed, gatherResults
     3.9  
    3.10  
    3.11  class PassiveRuntime(Runtime):
    3.12 @@ -467,14 +467,13 @@
    3.13          """Generate a random secret share in GF256 and returns
    3.14          [*share*, *share*^2, *share*^4, ..., *share*^(i^max)]."""
    3.15          share = self.prss_share_random(GF256)
    3.16 -        return succeed(self.powerchain(share, max))
    3.17 +        return gatherResults(self.powerchain(share, max))
    3.18  
    3.19      def prss_powerchains(self, max=7, quantity=20):
    3.20          """Does *quantity* times the same as :meth:`prss_powerchain`.
    3.21          Used for preprocessing."""
    3.22          shares = self.prss_share_random_multi(GF256, quantity)
    3.23 -        return quantity, succeed([self.powerchain(share, max)
    3.24 -                                  for share in shares])
    3.25 +        return [gatherResults(self.powerchain(share, max)) for share in shares]
    3.26  
    3.27      def input(self, inputters, field, number=None, threshold=None):
    3.28          """Input *number* to the computation.
     4.1 --- a/viff/runtime.py	Thu Sep 17 17:59:08 2009 +0200
     4.2 +++ b/viff/runtime.py	Fri Sep 25 20:29:18 2009 +0200
     4.3 @@ -737,14 +737,13 @@
     4.4          arguments. The generator methods called must adhere to the
     4.5          following interface:
     4.6  
     4.7 -        - They must return a ``(int, Deferred)`` tuple where the
     4.8 -          ``int`` tells us how many items of pre-processed data the
     4.9 -          :class:`Deferred` will yield.
    4.10 +        - They must return a list of :class:`Deferred` instances.
    4.11  
    4.12 -        - The Deferred must yield a list of the promised length.
    4.13 -
    4.14 -        - The list contains the actual data. This data can be either a
    4.15 -          Deferred or a tuple of Deferreds.
    4.16 +        - Every Deferred must yield an item of pre-processed data.
    4.17 +          This can be value, a list or tuple of values, or a Deferred
    4.18 +          (which will be converted to a value by Twisted), but NOT a
    4.19 +          list of Deferreds. Use :meth:`gatherResults` to avoid the
    4.20 +          latter.
    4.21  
    4.22          The :meth:`ActiveRuntime.generate_triples` method is an
    4.23          example of a method fulfilling this interface.
    4.24 @@ -767,20 +766,14 @@
    4.25              # avoid starting before the pre-processing is complete.
    4.26              return deep_wait(results)
    4.27  
    4.28 -        wait_list = []
    4.29          for ((generator, args), program_counters) in program.iteritems():
    4.30              print "Preprocessing %s (%d items)" % (generator, len(program_counters))
    4.31              func = getattr(self, generator)
    4.32              results = []
    4.33 -            items = 0
    4.34 -            while items < len(program_counters):
    4.35 -                item_count, result = func(*args)
    4.36 -                items += item_count
    4.37 -                results.append(result)
    4.38 -            ready = gatherResults(results)
    4.39 -            ready.addCallback(update, program_counters)
    4.40 -            wait_list.append(ready)
    4.41 -        return DeferredList(wait_list)
    4.42 +            while len(results) < len(program_counters):
    4.43 +                results += func(*args)
    4.44 +            self._pool.update(zip(program_counters, results))
    4.45 +        return DeferredList(results).addCallback(lambda _: None)
    4.46  
    4.47      def input(self, inputters, field, number=None):
    4.48          """Input *number* to the computation.
     5.1 --- a/viff/test/test_active_runtime.py	Thu Sep 17 17:59:08 2009 +0200
     5.2 +++ b/viff/test/test_active_runtime.py	Fri Sep 25 20:29:18 2009 +0200
     5.3 @@ -99,24 +99,23 @@
     5.4              """Verify a multiplication triple."""
     5.5              self.assertEquals(triple[0] * triple[1], triple[2])
     5.6  
     5.7 -        def check(triples):
     5.8 -            results = []
     5.9 -            for a, b, c in triples:
    5.10 -                self.assert_type(a, Share)
    5.11 -                self.assert_type(b, Share)
    5.12 -                self.assert_type(c, Share)
    5.13 -                open_a = runtime.open(a)
    5.14 -                open_b = runtime.open(b)
    5.15 -                open_c = runtime.open(c)
    5.16 -                result = gatherResults([open_a, open_b, open_c])
    5.17 -                result.addCallback(verify)
    5.18 -                results.append(result)
    5.19 -            return gatherResults(results)
    5.20 +        def check(triple):
    5.21 +            a, b, c = triple
    5.22 +            self.assert_type(a, self.Zp)
    5.23 +            self.assert_type(b, self.Zp)
    5.24 +            self.assert_type(c, self.Zp)
    5.25 +            open_a = runtime.open(Share(self, self.Zp, a))
    5.26 +            open_b = runtime.open(Share(self, self.Zp, b))
    5.27 +            open_c = runtime.open(Share(self, self.Zp, c))
    5.28 +            result = gatherResults([open_a, open_b, open_c])
    5.29 +            result.addCallback(verify)
    5.30 +            return result
    5.31  
    5.32 -        count, triples = runtime.generate_triples(self.Zp)
    5.33 -        self.assertEquals(count, runtime.num_players - 2*runtime.threshold)
    5.34 +        triples = runtime.generate_triples(self.Zp)
    5.35 +        self.assertEquals(len(triples), runtime.num_players - 2*runtime.threshold)
    5.36  
    5.37 -        runtime.schedule_callback(triples, check)
    5.38 +        for triple in triples:
    5.39 +            runtime.schedule_callback(triple, check)
    5.40          return triples
    5.41  
    5.42