changeset 1212:2daaf0e7a1f6

Make increment_pc mostly obsolete. Instead of calling increment_pc on every method entry to allocate a sub-program counter, we now let schedule_callback take care of this.
author Martin Geisler <mg@cs.au.dk>
date Fri, 18 Sep 2009 14:41:44 +0200
parents b94ff92ee1d9
children 7610deb0ebab
files viff/active.py viff/aes.py viff/comparison.py viff/paillier.py viff/passive.py viff/runtime.py viff/test/test_basic_runtime.py
diffstat 7 files changed, 19 insertions(+), 94 deletions(-) [+]
line wrap: on
line diff
--- a/viff/active.py	Thu Sep 17 15:48:01 2009 +0200
+++ b/viff/active.py	Fri Sep 18 14:41:44 2009 +0200
@@ -141,7 +141,6 @@
 
         return result
 
-    @increment_pc
     def broadcast(self, senders, message=None):
         """Perform one or more Bracha broadcast(s).
 
@@ -186,7 +185,6 @@
     #: to :const:`None` here and update it as necessary.
     _hyper = None
 
-    @increment_pc
     def single_share_random(self, T, degree, field):
         """Share a random secret.
 
@@ -273,7 +271,6 @@
         self.schedule_callback(result, exchange)
         return result
 
-    @increment_pc
     def double_share_random(self, T, d1, d2, field):
         """Double-share a random secret using two polynomials.
 
@@ -376,7 +373,6 @@
         self.schedule_callback(result, exchange)
         return result
 
-    @increment_pc
     @preprocess("generate_triples")
     def get_triple(self, field):
         # This is a waste, but this function is only called if there
@@ -385,7 +381,6 @@
         result.addCallback(lambda triples: triples[0])
         return result
 
-    @increment_pc
     def generate_triples(self, field):
         """Generate multiplication triples.
 
@@ -425,14 +420,12 @@
 class TriplesPRSSMixin:
     """Mixin class for generating multiplication triples using PRSS."""
 
-    @increment_pc
     @preprocess("generate_triples")
     def get_triple(self, field):
         count, result = self.generate_triples(field)
         result.addCallback(lambda triples: triples[0])
         return result
 
-    @increment_pc
     def generate_triples(self, field):
         """Generate a multiplication triple using PRSS.
 
@@ -467,7 +460,6 @@
     :class:`ActiveRuntime` instead.
     """
 
-    @increment_pc
     def mul(self, share_x, share_y):
         """Multiplication of shares.
 
--- a/viff/aes.py	Thu Sep 17 15:48:01 2009 +0200
+++ b/viff/aes.py	Fri Sep 18 14:41:44 2009 +0200
@@ -24,7 +24,7 @@
 import operator
 
 from viff.field import GF256
-from viff.runtime import Share, gather_shares, increment_pc
+from viff.runtime import Share, gather_shares
 from viff.matrix import Matrix
 
 
@@ -319,7 +319,6 @@
                     "or of shares thereof."
             return input
 
-    @increment_pc
     def encrypt(self, cleartext, key, benchmark=False, prepare_at_once=False):
         """Rijndael encryption.
 
--- a/viff/comparison.py	Thu Sep 17 15:48:01 2009 +0200
+++ b/viff/comparison.py	Fri Sep 18 14:41:44 2009 +0200
@@ -25,7 +25,7 @@
 import math
 
 from viff.util import rand, profile
-from viff.runtime import Share, gather_shares, increment_pc
+from viff.runtime import Share, gather_shares
 from viff.passive import PassiveRuntime
 from viff.active import ActiveRuntime
 from viff.field import GF256, FieldElement
@@ -34,7 +34,6 @@
 class ComparisonToft05Mixin:
     """Comparison by Tomas Toft, 2005."""
 
-    @increment_pc
     def convert_bit_share(self, share, dst_field):
         """Convert a 0/1 share into dst_field."""
         bit = rand.randint(0, 1)
@@ -67,7 +66,6 @@
         return int_b, bit_bits
 
     @profile
-    @increment_pc
     def greater_than_equal(self, share_a, share_b):
         """Compute ``share_a >= share_b``.
 
@@ -100,7 +98,6 @@
         self.schedule_callback(result, self._finish_greater_than_equal, l)
         return result
 
-    @increment_pc
     def _finish_greater_than_equal(self, results, l):
         """Finish the calculation."""
         T = results[0]
@@ -128,7 +125,6 @@
 
         return GF256(T.bit(l)) ^ (bit_bits[l] ^ vec[0][1])
 
-    @increment_pc
     def _diamond(self, (top_a, bot_a), (top_b, bot_b)):
         """The "diamond-operator".
 
@@ -160,7 +156,6 @@
     elements and gives a secret result shared over Zp.
     """
 
-    @increment_pc
     def convert_bit_share(self, share, dst_field):
         """Convert a 0/1 share into *dst_field*."""
         l = self.options.security_parameter + math.log(dst_field.modulus, 2)
@@ -188,7 +183,6 @@
         return tmp - full_mask
 
     @profile
-    @increment_pc
     def greater_than_equal_preproc(self, field, smallField=None):
         """Preprocessing for :meth:`greater_than_equal`."""
         if smallField is None:
@@ -243,7 +237,6 @@
         ##################################################
 
     @profile
-    @increment_pc
     def greater_than_equal_online(self, share_a, share_b, preproc, field):
         """Compute ``share_a >= share_b``. Result is secret shared."""
         # increment l as a, b are increased
@@ -272,7 +265,6 @@
                                r_modl, r_bits, z)
         return c
 
-    @increment_pc
     def _finish_greater_than_equal(self, c, field, smallField, s_bit, s_sign,
                                mask, r_modl, r_bits, z):
         """Finish the calculation."""
@@ -316,7 +308,6 @@
         return (z - result) * ~field(2**l)
     # END _finish_greater_than
 
-    @increment_pc
     def greater_than_equal(self, share_a, share_b):
         """Compute ``share_a >= share_b``.
 
--- a/viff/paillier.py	Thu Sep 17 15:48:01 2009 +0200
+++ b/viff/paillier.py	Fri Sep 18 14:41:44 2009 +0200
@@ -27,7 +27,7 @@
 from twisted.internet.defer import Deferred, gatherResults
 import gmpy
 
-from viff.runtime import Runtime, increment_pc, Share, gather_shares
+from viff.runtime import Runtime, Share, gather_shares
 from viff.runtime import PAILLIER
 from viff.util import rand, find_random_prime
 
@@ -78,7 +78,6 @@
         else:
             self.peer = player
 
-    @increment_pc
     def prss_share_random(self, field):
         """Generate a share of a uniformly random element."""
         prfs = self.players[self.id].prfs(field.modulus)
@@ -94,7 +93,6 @@
         """
         return self.share(inputters, field, number)
 
-    @increment_pc
     def share(self, inputters, field, number=None):
         """Share *number* additively."""
         assert number is None or self.id in inputters
@@ -121,7 +119,6 @@
     def output(self, share, receivers=None):
         return self.open(share, receivers)
 
-    @increment_pc
     def open(self, share, receivers=None):
         """Open *share* to *receivers* (defaults to both players)."""
 
--- a/viff/passive.py	Thu Sep 17 15:48:01 2009 +0200
+++ b/viff/passive.py	Fri Sep 18 14:41:44 2009 +0200
@@ -54,7 +54,6 @@
     def output(self, share, receivers=None, threshold=None):
         return self.open(share, receivers, threshold)
 
-    @increment_pc
     def open(self, share, receivers=None, threshold=None):
         """Open a secret sharing.
 
@@ -172,7 +171,6 @@
         return result
 
     @profile
-    @increment_pc
     def mul(self, share_a, share_b):
         """Multiplication of shares.
 
@@ -229,7 +227,6 @@
         else:
             return share * (share ** (exponent-1))
 
-    @increment_pc
     def xor(self, share_a, share_b):
         field = share_a.field
         if not isinstance(share_b, Share):
@@ -242,7 +239,6 @@
         else:
             return share_a + share_b - 2 * share_a * share_b
 
-    @increment_pc
     def prss_share(self, inputters, field, element=None):
         """Creates pseudo-random secret sharings.
 
@@ -351,7 +347,6 @@
         self.schedule_callback(result, finish, share, binary)
         return result
 
-    @increment_pc
     def prss_share_random_multi(self, field, quantity, binary=False):
         """Does the same as calling *quantity* times :meth:`prss_share_random`,
         but with less calls to the PRF. Sampling of a binary element is only
@@ -374,7 +369,6 @@
                             modulus, quantity)
         return [Share(self, field, share) for share in shares]
 
-    @increment_pc
     def prss_share_zero(self, field):
         """Generate shares of the zero element from the field given.
 
@@ -387,7 +381,6 @@
                                field, prfs, prss_key)
         return Share(self, field, zero_share)
 
-    @increment_pc
     def prss_double_share(self, field):
         """Make a double-sharing using PRSS.
 
@@ -398,7 +391,6 @@
         z_2t = self.prss_share_zero(field)
         return (r_t, r_t + z_2t)
 
-    @increment_pc
     def prss_share_bit_double(self, field):
         """Share a random bit over *field* and GF256.
 
@@ -423,7 +415,6 @@
         # Use r_lsb to flip b as needed.
         return (b_p, b ^ r_lsb)
 
-    @increment_pc
     def prss_shamir_share_bit_double(self, field):
         """Shamir share a random bit over *field* and GF256."""
         n = self.num_players
@@ -455,7 +446,6 @@
         """
         return self.shamir_share(inputters, field, number, threshold)
 
-    @increment_pc
     def shamir_share(self, inputters, field, number=None, threshold=None):
         """Secret share *number* over *field* using Shamir's method.
 
--- a/viff/runtime.py	Thu Sep 17 15:48:01 2009 +0200
+++ b/viff/runtime.py	Fri Sep 18 14:41:44 2009 +0200
@@ -617,7 +617,6 @@
         dl = DeferredList(vars)
         self.schedule_callback(dl, lambda _: self.shutdown())
 
-    @increment_pc
     def schedule_callback(self, deferred, func, *args, **kwargs):
         """Schedule a callback on a deferred with the correct program
         counter.
@@ -631,7 +630,9 @@
         Any extra arguments are passed to the callback as with
         :meth:`addCallback`.
         """
+        self.program_counter[-1] += 1
         saved_pc = self.program_counter[:]
+        saved_pc.append(0)
 
         @wrapper(func)
         def callback_wrapper(*args, **kwargs):
@@ -667,7 +668,6 @@
         deferred.addCallback(queue_callback, self, fork)
         return self.schedule_callback(fork, func, *args, **kwargs)
 
-    @increment_pc
     def synchronize(self):
         """Introduce a synchronization point.
 
@@ -723,7 +723,6 @@
         self._expect_data(peer_id, SHARE, share)
         return share
 
-    @increment_pc
     def preprocess(self, program):
         """Generate preprocess material.
 
--- a/viff/test/test_basic_runtime.py	Thu Sep 17 15:48:01 2009 +0200
+++ b/viff/test/test_basic_runtime.py	Fri Sep 18 14:41:44 2009 +0200
@@ -18,7 +18,6 @@
 from twisted.internet.defer import Deferred, gatherResults
 
 from viff.test.util import RuntimeTestCase, protocol
-from viff.runtime import increment_pc
 
 
 class ProgramCounterTest(RuntimeTestCase):
@@ -32,26 +31,14 @@
     def test_simple_operation(self, runtime):
         """Test an operation which makes no further calls.
 
-        Each call should increment the program counter by one.
+        No callbacks are scheduled, and so the program counter is not
+        increased.
         """
+        self.assertEquals(runtime.program_counter, [0])
         runtime.synchronize()
-        self.assertEquals(runtime.program_counter, [1])
+        self.assertEquals(runtime.program_counter, [0])
         runtime.synchronize()
-        self.assertEquals(runtime.program_counter, [2])
-
-    @protocol
-    def test_complex_operation(self, runtime):
-        """Test an operation which makes nested calls.
-
-        This verifies that the program counter is only incremented by
-        one, even for a complex operation.
-        """
-        # Exclusive-or is calculated as x + y - 2 * x * y, so add,
-        # sub, and mul are called.
-        runtime.xor(self.Zp(0), self.Zp(1))
-        self.assertEquals(runtime.program_counter, [1])
-        runtime.xor(self.Zp(0), self.Zp(1))
-        self.assertEquals(runtime.program_counter, [2])
+        self.assertEquals(runtime.program_counter, [0])
 
     @protocol
     def test_callback(self, runtime):
@@ -62,62 +49,32 @@
         """
 
         def verify_program_counter(_):
+            # The callback is run with its own sub-program counter.
             self.assertEquals(runtime.program_counter, [1, 0])
 
         d = Deferred()
+
+        self.assertEquals(runtime.program_counter, [0])
+
+        # Scheduling a callback increases the program counter.
         runtime.schedule_callback(d, verify_program_counter)
-
-        runtime.synchronize()
-        self.assertEquals(runtime.program_counter, [2])
+        self.assertEquals(runtime.program_counter, [1])
 
         # Now trigger verify_program_counter.
         d.callback(None)
 
     @protocol
-    def test_nested_calls(self, runtime):
-        """Test Runtime methods that call other methods.
-
-        We create a couple of functions that are used as fake methods.
-        """
-
-        @increment_pc
-        def method_a(runtime):
-            # First top-level call, so first entry is 1. No calls to
-            # other methods decorated with increment_pc has been made,
-            # so the second entry is 0.
-            self.assertEquals(runtime.program_counter, [1, 0])
-            method_b(runtime, 1)
-
-            self.assertEquals(runtime.program_counter, [1, 1])
-            method_b(runtime, 2)
-
-            # At this point two sub-calls has been made:
-            self.assertEquals(runtime.program_counter, [1, 2])
-
-        @increment_pc
-        def method_b(runtime, count):
-            # This method is called twice from method_a:
-            self.assertEquals(runtime.program_counter, [1, count, 0])
-
-        # Zero top-level calls:
-        self.assertEquals(runtime.program_counter, [0])
-        method_a(runtime)
-
-        # One top-level call:
-        self.assertEquals(runtime.program_counter, [1])
-
-    @protocol
     def test_multiple_callbacks(self, runtime):
 
         d1 = Deferred()
         d2 = Deferred()
 
         def verify_program_counter(_, count):
-            self.assertEquals(runtime.program_counter, [1, count, 0])
+            self.assertEquals(runtime.program_counter, [count, 0])
 
-        @increment_pc
         def method_a(runtime):
-            self.assertEquals(runtime.program_counter, [1, 0])
+            # No calls to schedule_callback yet.
+            self.assertEquals(runtime.program_counter, [0])
 
             runtime.schedule_callback(d1, verify_program_counter, 1)
             runtime.schedule_callback(d2, verify_program_counter, 2)