viff

changeset 658:6e716c77190b

Make double_share_random return a single Deferred.
author Martin Geisler <mg@daimi.au.dk>
date Sat, 12 Apr 2008 23:51:21 +0200
parents a4f9a86a6f75
children dfdfd16a376a
files viff/runtime.py viff/test/test_active_runtime.py
diffstat 2 files changed, 31 insertions(+), 25 deletions(-) [+]
line diff
     1.1 --- a/viff/runtime.py	Sat Apr 12 23:42:25 2008 +0200
     1.2 +++ b/viff/runtime.py	Sat Apr 12 23:51:21 2008 +0200
     1.3 @@ -1052,7 +1052,7 @@
     1.4  
     1.5          # Return the first T shares (the ones that was not opened in
     1.6          # the verifying step.
     1.7 -        return rvec1.rows[0][:T], rvec2.rows[0][:T]
     1.8 +        return succeed((rvec1.rows[0][:T], rvec2.rows[0][:T]))
     1.9  
    1.10      @increment_pc
    1.11      def get_triple(self, field):
    1.12 @@ -1088,12 +1088,15 @@
    1.13              ci.addCallback(lambda (ai, bi): ai * bi)
    1.14              c_2t.append(ci)
    1.15  
    1.16 -        r_t, r_2t = self.double_share_random(T, t, 2*t, field)
    1.17 -        d_2t = [c_2t[i] - r_2t[i] for i in range(T)]
    1.18 -        d = [self.open(d_2t[i], threshold=2*t) for i in range(T)]
    1.19 -        c_t = [r_t[i] + d[i] for i in range(T)]
    1.20 +        def make_triple((r_t, r_2t)):
    1.21 +            d_2t = [c_2t[i] - r_2t[i] for i in range(T)]
    1.22 +            d = [self.open(d_2t[i], threshold=2*t) for i in range(T)]
    1.23 +            c_t = [r_t[i] + d[i] for i in range(T)]
    1.24 +            return zip(a_t, b_t, c_t)
    1.25  
    1.26 -        return succeed(zip(a_t, b_t, c_t))
    1.27 +        double = self.double_share_random(T, t, 2*t, field)
    1.28 +        double.addCallback(make_triple)
    1.29 +        return double
    1.30  
    1.31      @increment_pc
    1.32      def _broadcast(self, sender, message=None):
     2.1 --- a/viff/test/test_active_runtime.py	Sat Apr 12 23:42:25 2008 +0200
     2.2 +++ b/viff/test/test_active_runtime.py	Sat Apr 12 23:51:21 2008 +0200
     2.3 @@ -68,29 +68,32 @@
     2.4          from viff.field import GF
     2.5          self.Zp = GF(11)
     2.6  
     2.7 -        r_t, r_2t = runtime.double_share_random(T,
     2.8 -                                                runtime.threshold,
     2.9 -                                                2*runtime.threshold,
    2.10 -                                                self.Zp)
    2.11 -
    2.12 -        # Check that we got the expected number of shares.
    2.13 -        self.assertEquals(len(r_t), T)
    2.14 -        self.assertEquals(len(r_2t), T)
    2.15 -
    2.16          def verify(shares):
    2.17              """Verify that the list contains two equal shares."""
    2.18              self.assertEquals(shares[0], shares[1])
    2.19  
    2.20 -        results = []
    2.21 -        for a, b in zip(r_t, r_2t):
    2.22 -            self.assert_type(a, Share)
    2.23 -            self.assert_type(b, Share)
    2.24 -            open_a = runtime.open(a)
    2.25 -            open_b = runtime.open(b, threshold=2*runtime.threshold)
    2.26 -            result = gatherResults([open_a, open_b])
    2.27 -            result.addCallback(verify)
    2.28 -            results.append(result)
    2.29 -        return gatherResults(results)
    2.30 +        def check(double):
    2.31 +            r_t, r_2t = double
    2.32 +
    2.33 +            # Check that we got the expected number of shares.
    2.34 +            self.assertEquals(len(r_t), T)
    2.35 +            self.assertEquals(len(r_2t), T)
    2.36 +
    2.37 +            results = []
    2.38 +            for a, b in zip(r_t, r_2t):
    2.39 +                self.assert_type(a, Share)
    2.40 +                self.assert_type(b, Share)
    2.41 +                open_a = runtime.open(a)
    2.42 +                open_b = runtime.open(b, threshold=2*runtime.threshold)
    2.43 +                result = gatherResults([open_a, open_b])
    2.44 +                result.addCallback(verify)
    2.45 +                results.append(result)
    2.46 +            return gatherResults(results)
    2.47 +
    2.48 +        double = runtime.double_share_random(T, runtime.threshold,
    2.49 +                                             2*runtime.threshold, self.Zp)
    2.50 +        double.addCallback(check)
    2.51 +        return double
    2.52  
    2.53      @protocol
    2.54      def test_generate_triples(self, runtime):