viff

changeset 657:a4f9a86a6f75

Make generate_triples, get_triple return a single Deferred. This is needed to propagate errors when we begin verifying shares shares in double_share_random.
author Martin Geisler <mg@daimi.au.dk>
date Sat, 12 Apr 2008 23:42:25 +0200
parents bc15dbc60073
children 6e716c77190b
files viff/runtime.py viff/test/test_active_runtime.py
diffstat 2 files changed, 46 insertions(+), 30 deletions(-) [+]
line diff
     1.1 --- a/viff/runtime.py	Sat Apr 12 23:13:05 2008 +0200
     1.2 +++ b/viff/runtime.py	Sat Apr 12 23:42:25 2008 +0200
     1.3 @@ -44,7 +44,7 @@
     1.4  from viff.util import wrapper, rand
     1.5  
     1.6  from twisted.internet import reactor
     1.7 -from twisted.internet.defer import Deferred, DeferredList
     1.8 +from twisted.internet.defer import Deferred, DeferredList, succeed
     1.9  from twisted.internet.protocol import ClientFactory, ServerFactory
    1.10  from twisted.protocols.basic import Int16StringReceiver
    1.11  
    1.12 @@ -988,21 +988,31 @@
    1.13  
    1.14          # At this point both share_x and share_y must be Share
    1.15          # objects. We multiply them via a multiplication triple.
    1.16 -        a, b, c = self.get_triple(share_x.field)
    1.17 -        d = self.open(share_x - a)
    1.18 -        e = self.open(share_y - b)
    1.19 +        def finish_mul(triple):
    1.20 +            a, b, c = triple
    1.21 +            d = self.open(share_x - a)
    1.22 +            e = self.open(share_y - b)
    1.23  
    1.24 -        # TODO: We ought to be able to simple do
    1.25 -        #
    1.26 -        #   return d*e + d*y + e*x + c
    1.27 -        #
    1.28 -        # but that leads to infinite recursion since d and e are
    1.29 -        # Shares, not FieldElements. So we have to do a bit more
    1.30 -        # work... The following callback also leads to recursion, but
    1.31 -        # only one level since d and e are FieldElements now, which
    1.32 -        # means that we return in the above if statements.
    1.33 -        result = gather_shares([d, e])
    1.34 -        result.addCallback(lambda (d,e): d*e + d*b + e*a + c)
    1.35 +            # TODO: We ought to be able to simple do
    1.36 +            #
    1.37 +            #   return d*e + d*y + e*x + c
    1.38 +            #
    1.39 +            # but that leads to infinite recursion since d and e are
    1.40 +            # Shares, not FieldElements. So we have to do a bit more
    1.41 +            # work... The following callback also leads to recursion, but
    1.42 +            # only one level since d and e are FieldElements now, which
    1.43 +            # means that we return in the above if statements.
    1.44 +            result = gather_shares([d, e])
    1.45 +            result.addCallback(lambda (d,e): d*e + d*b + e*a + c)
    1.46 +            return result
    1.47 +
    1.48 +        # This will be the result, a Share object.
    1.49 +        result = Share(self, share_x.field)
    1.50 +        # This is the Deferred we will do processing on.
    1.51 +        triple = self.get_triple(share_x.field)
    1.52 +        triple.addCallback(finish_mul)
    1.53 +        # We add the result to the chains in triple.
    1.54 +        triple.chainDeferred(result)
    1.55          return result
    1.56  
    1.57      @increment_pc
    1.58 @@ -1050,7 +1060,9 @@
    1.59          # generate_triples to a preprocessing step and draw the
    1.60          # triples from a pool instead. Also, using only the first
    1.61          # triple is quite wasteful...
    1.62 -        return self.generate_triples(field)[0]
    1.63 +        result = self.generate_triples(field)
    1.64 +        result.addCallback(lambda triples: triples[0])
    1.65 +        return result
    1.66  
    1.67      @increment_pc
    1.68      def generate_triples(self, field):
    1.69 @@ -1081,7 +1093,7 @@
    1.70          d = [self.open(d_2t[i], threshold=2*t) for i in range(T)]
    1.71          c_t = [r_t[i] + d[i] for i in range(T)]
    1.72  
    1.73 -        return zip(a_t, b_t, c_t)
    1.74 +        return succeed(zip(a_t, b_t, c_t))
    1.75  
    1.76      @increment_pc
    1.77      def _broadcast(self, sender, message=None):
     2.1 --- a/viff/test/test_active_runtime.py	Sat Apr 12 23:13:05 2008 +0200
     2.2 +++ b/viff/test/test_active_runtime.py	Sat Apr 12 23:42:25 2008 +0200
     2.3 @@ -95,21 +95,25 @@
     2.4      @protocol
     2.5      def test_generate_triples(self, runtime):
     2.6          """Test generation of multiplication triples."""
     2.7 -        triples = runtime.generate_triples(self.Zp)
     2.8  
     2.9          def verify(triple):
    2.10              """Verify a multiplication triple."""
    2.11              self.assertEquals(triple[0] * triple[1], triple[2])
    2.12  
    2.13 -        results = []
    2.14 -        for a, b, c in triples:
    2.15 -            self.assert_type(a, Share)
    2.16 -            self.assert_type(b, Share)
    2.17 -            self.assert_type(c, Share)
    2.18 -            open_a = runtime.open(a)
    2.19 -            open_b = runtime.open(b)
    2.20 -            open_c = runtime.open(c)
    2.21 -            result = gatherResults([open_a, open_b, open_c])
    2.22 -            result.addCallback(verify)
    2.23 -            results.append(result)
    2.24 -        return gatherResults(results)
    2.25 +        def check(triples):
    2.26 +            results = []
    2.27 +            for a, b, c in triples:
    2.28 +                self.assert_type(a, Share)
    2.29 +                self.assert_type(b, Share)
    2.30 +                self.assert_type(c, Share)
    2.31 +                open_a = runtime.open(a)
    2.32 +                open_b = runtime.open(b)
    2.33 +                open_c = runtime.open(c)
    2.34 +                result = gatherResults([open_a, open_b, open_c])
    2.35 +                result.addCallback(verify)
    2.36 +                results.append(result)
    2.37 +            return gatherResults(results)
    2.38 +
    2.39 +        triples = runtime.generate_triples(self.Zp)
    2.40 +        triples.addCallback(check)
    2.41 +        return triples