changeset 657:a4f9a86a6f75

Make generate_triples, get_triple return a single Deferred. This is needed to propagate errors when we begin verifying shares shares in double_share_random.
author Martin Geisler <mg@daimi.au.dk>
date Sat, 12 Apr 2008 23:42:25 +0200
parents bc15dbc60073
children 6e716c77190b
files viff/runtime.py viff/test/test_active_runtime.py
diffstat 2 files changed, 46 insertions(+), 30 deletions(-) [+]
line wrap: on
line diff
--- a/viff/runtime.py	Sat Apr 12 23:13:05 2008 +0200
+++ b/viff/runtime.py	Sat Apr 12 23:42:25 2008 +0200
@@ -44,7 +44,7 @@
 from viff.util import wrapper, rand
 
 from twisted.internet import reactor
-from twisted.internet.defer import Deferred, DeferredList
+from twisted.internet.defer import Deferred, DeferredList, succeed
 from twisted.internet.protocol import ClientFactory, ServerFactory
 from twisted.protocols.basic import Int16StringReceiver
 
@@ -988,21 +988,31 @@
 
         # At this point both share_x and share_y must be Share
         # objects. We multiply them via a multiplication triple.
-        a, b, c = self.get_triple(share_x.field)
-        d = self.open(share_x - a)
-        e = self.open(share_y - b)
+        def finish_mul(triple):
+            a, b, c = triple
+            d = self.open(share_x - a)
+            e = self.open(share_y - b)
 
-        # TODO: We ought to be able to simple do
-        #
-        #   return d*e + d*y + e*x + c
-        #
-        # but that leads to infinite recursion since d and e are
-        # Shares, not FieldElements. So we have to do a bit more
-        # work... The following callback also leads to recursion, but
-        # only one level since d and e are FieldElements now, which
-        # means that we return in the above if statements.
-        result = gather_shares([d, e])
-        result.addCallback(lambda (d,e): d*e + d*b + e*a + c)
+            # TODO: We ought to be able to simple do
+            #
+            #   return d*e + d*y + e*x + c
+            #
+            # but that leads to infinite recursion since d and e are
+            # Shares, not FieldElements. So we have to do a bit more
+            # work... The following callback also leads to recursion, but
+            # only one level since d and e are FieldElements now, which
+            # means that we return in the above if statements.
+            result = gather_shares([d, e])
+            result.addCallback(lambda (d,e): d*e + d*b + e*a + c)
+            return result
+
+        # This will be the result, a Share object.
+        result = Share(self, share_x.field)
+        # This is the Deferred we will do processing on.
+        triple = self.get_triple(share_x.field)
+        triple.addCallback(finish_mul)
+        # We add the result to the chains in triple.
+        triple.chainDeferred(result)
         return result
 
     @increment_pc
@@ -1050,7 +1060,9 @@
         # generate_triples to a preprocessing step and draw the
         # triples from a pool instead. Also, using only the first
         # triple is quite wasteful...
-        return self.generate_triples(field)[0]
+        result = self.generate_triples(field)
+        result.addCallback(lambda triples: triples[0])
+        return result
 
     @increment_pc
     def generate_triples(self, field):
@@ -1081,7 +1093,7 @@
         d = [self.open(d_2t[i], threshold=2*t) for i in range(T)]
         c_t = [r_t[i] + d[i] for i in range(T)]
 
-        return zip(a_t, b_t, c_t)
+        return succeed(zip(a_t, b_t, c_t))
 
     @increment_pc
     def _broadcast(self, sender, message=None):
--- a/viff/test/test_active_runtime.py	Sat Apr 12 23:13:05 2008 +0200
+++ b/viff/test/test_active_runtime.py	Sat Apr 12 23:42:25 2008 +0200
@@ -95,21 +95,25 @@
     @protocol
     def test_generate_triples(self, runtime):
         """Test generation of multiplication triples."""
-        triples = runtime.generate_triples(self.Zp)
 
         def verify(triple):
             """Verify a multiplication triple."""
             self.assertEquals(triple[0] * triple[1], triple[2])
 
-        results = []
-        for a, b, c in triples:
-            self.assert_type(a, Share)
-            self.assert_type(b, Share)
-            self.assert_type(c, Share)
-            open_a = runtime.open(a)
-            open_b = runtime.open(b)
-            open_c = runtime.open(c)
-            result = gatherResults([open_a, open_b, open_c])
-            result.addCallback(verify)
-            results.append(result)
-        return gatherResults(results)
+        def check(triples):
+            results = []
+            for a, b, c in triples:
+                self.assert_type(a, Share)
+                self.assert_type(b, Share)
+                self.assert_type(c, Share)
+                open_a = runtime.open(a)
+                open_b = runtime.open(b)
+                open_c = runtime.open(c)
+                result = gatherResults([open_a, open_b, open_c])
+                result.addCallback(verify)
+                results.append(result)
+            return gatherResults(results)
+
+        triples = runtime.generate_triples(self.Zp)
+        triples.addCallback(check)
+        return triples