changeset 1001:74bcf4955f99

Merged with preprocessing cleanup.
author Martin Geisler <mg@daimi.au.dk>
date Tue, 14 Oct 2008 16:12:05 +0200
parents 3fe2baebdb99 7546e47ce876
children 735d5e3dade2
files apps/benchmark.py viff/runtime.py
diffstat 6 files changed, 308 insertions(+), 65 deletions(-) [+]
line wrap: on
line diff
--- a/apps/benchmark.py	Fri Oct 10 16:49:49 2008 +0200
+++ b/apps/benchmark.py	Tue Oct 14 16:12:05 2008 +0200
@@ -57,7 +57,7 @@
 import time
 from optparse import OptionParser
 import operator
-from pprint import pprint
+from pprint import pformat
 
 from twisted.internet import reactor
 
@@ -147,35 +147,35 @@
         self.rt = rt
         self.operation = operation
 
+        program_desc = {}
+
         if isinstance(self.rt, BasicActiveRuntime):
             # TODO: Make this optional and maybe automatic. The
             # program descriptions below were found by carefully
             # studying the output reported when the benchmarks were
             # run with no preprocessing. So they are quite brittle.
+            if self.operation == operator.mul:
+                key = ("generate_triples", (Zp,))
+                desc = [(i, 1, 0) for i in range(3 + 2*count, 3 + 3*count)]
+                program_desc.setdefault(key, []).extend(desc)
+            elif isinstance(self.rt, ComparisonToft05Mixin):
+                key = ("generate_triples", (GF256,))
+                desc = sum([[(c, 64, i, 1, 1, 0) for i in range(2, 33)] +
+                            [(c, 64, i, 3, 1, 0) for i in range(17, 33)]
+                            for c in range(3 + 2*count, 3 + 3*count)],
+                           [])
+                program_desc.setdefault(key, []).extend(desc)
+            elif isinstance(self.rt, ComparisonToft07Mixin):
+                key = ("generate_triples", (Zp,))
+                desc = sum([[(c, 2, 4, i, 2, 1, 0) for i in range(1, 33)] +
+                            [(c, 2, 4, 99, 2, 1, 0)] +
+                            [(c, 2, 4, i, 1, 0) for i in range(65, 98)]
+                            for c in range(3 + 2*count, 3 + 3*count)],
+                           [])
+                program_desc.setdefault(key, []).extend(desc)
+
+        if program_desc:
             print "Starting preprocessing"
-            if self.operation == operator.mul:
-                program_desc = {
-                    ("generate_triples", (Zp,)):
-                        [(i, 1, 0) for i in range(3 + 2*count, 3 + 3*count)]
-                    }
-            elif isinstance(self.rt, ComparisonToft05Mixin):
-                program_desc = {
-                    ("generate_triples", (GF256,)):
-                    sum([[(c, 64, i, 1, 1, 0) for i in range(2, 33)] +
-                         [(c, 64, i, 3, 1, 0) for i in range(17, 33)]
-                         for c in range(3 + 2*count, 3 + 3*count)],
-                        [])
-                    }
-            elif isinstance(self.rt, ComparisonToft07Mixin):
-                program_desc = {
-                    ("generate_triples", (Zp,)):
-                    sum([[(c, 2, 4, i, 2, 1, 0) for i in range(1, 33)] +
-                         [(c, 2, 4, 99, 2, 1, 0)] +
-                         [(c, 2, 4, i, 1, 0) for i in range(65, 98)]
-                         for c in range(3 + 2*count, 3 + 3*count)],
-                        [])
-                    }
-
             record_start("preprocessing")
             preproc = rt.preprocess(program_desc)
             preproc.addCallback(record_stop, "preprocessing")
@@ -224,7 +224,9 @@
 
         if self.rt._needed_data:
             print "Missing pre-processed data:"
-            pprint(self.rt._needed_data)
+            for (func, args), pcs in self.rt._needed_data.iteritems():
+                print "* %s%s:" % (func, args)
+                print "  " + pformat(pcs).replace("\n", "\n  ")
 
         self.rt.shutdown()
 
@@ -290,9 +292,10 @@
         mixins.append(ProbabilisticEqualityMixin)
 
 print "Using the base runtime: %s." % base_runtime_class
-print "With the following mixins:"
-for mixin in mixins:
-    print "- %s" % mixin
+if mixins:
+    print "With the following mixins:"
+    for mixin in mixins:
+        print "- %s" % mixin
 
 runtime_class = make_runtime_class(base_runtime_class, mixins)
 
--- a/doc/util.txt	Fri Oct 10 16:49:49 2008 +0200
+++ b/doc/util.txt	Tue Oct 14 16:12:05 2008 +0200
@@ -25,3 +25,8 @@
 
       Setting this environment variable to any value will turn
       :func:`wrapper` into a no-op.
+
+   .. envvar:: VIFF_PROFILE
+
+      Defining this variable will change :func:`profile` from a no-op
+      to real decorator.
--- a/viff/comparison.py	Fri Oct 10 16:49:49 2008 +0200
+++ b/viff/comparison.py	Tue Oct 14 16:12:05 2008 +0200
@@ -24,7 +24,7 @@
 
 import math
 
-from viff.util import rand
+from viff.util import rand, profile
 from viff.runtime import Runtime, Share, gather_shares, increment_pc
 from viff.active import ActiveRuntime
 from viff.field import GF256, FieldElement
@@ -52,6 +52,20 @@
         tmp.field = dst_field
         return reduce(self.xor, dst_shares, tmp)
 
+    def decomposed_random_sharing(self, field, bits):
+        bits = [self.prss_share_bit_double(field) for _ in range(bits)]
+        int_bits, bit_bits = zip(*bits)
+
+        def bits_to_int(bits):
+            """Converts a list of bits to an integer."""
+            return sum([2**i * b for i, b in enumerate(bits)])
+
+        int_b = gather_shares(int_bits)
+        int_b.addCallback(bits_to_int)
+
+        return int_b, bit_bits
+
+    @profile
     @increment_pc
     def greater_than_equal(self, share_a, share_b):
         """Compute ``share_a >= share_b``.
@@ -74,25 +88,14 @@
         m = l + self.options.security_parameter
         t = m + 1
 
-        # Preprocessing begin
         assert 2**(l+1) + 2**t < field.modulus, "2^(l+1) + 2^t < p must hold"
         assert self.num_players + 2 < 2**l
 
-        bits = [self.prss_share_bit_double(field) for _ in range(m)]
-        int_bits, bit_bits = zip(*bits)
+        a = share_a - share_b + 2**l
+        b, bits = self.decomposed_random_sharing(field, m)
+        T = self.open(2**t - b + a)
 
-        def bits_to_int(bits):
-            """Converts a list of bits to an integer."""
-            return sum([2**i * b for i, b in enumerate(bits)])
-
-        int_b = gather_shares(int_bits)
-        int_b.addCallback(bits_to_int)
-        # Preprocessing done
-
-        a = share_a - share_b + 2**l
-        T = self.open(2**t - int_b + a)
-
-        result = gather_shares((T,) + bit_bits)
+        result = gather_shares((T,) + bits)
         self.schedule_callback(result, self._finish_greater_than_equal, l)
         return result
 
@@ -183,6 +186,7 @@
         full_mask = reduce(self.add, dst_shares)
         return tmp - full_mask
 
+    @profile
     @increment_pc
     def greater_than_equal_preproc(self, field, smallField=None):
         """Preprocessing for :meth:`greater_than_equal`."""
@@ -237,6 +241,7 @@
         # Preprocessing done
         ##################################################
 
+    @profile
     @increment_pc
     def greater_than_equal_online(self, share_a, share_b, preproc, field):
         """Compute ``share_a >= share_b``. Result is secret shared."""
--- a/viff/runtime.py	Fri Oct 10 16:49:49 2008 +0200
+++ b/viff/runtime.py	Tue Oct 14 16:12:05 2008 +0200
@@ -41,7 +41,7 @@
 from viff import shamir
 from viff.prss import prss, prss_lsb, prss_zero
 from viff.field import GF256, FieldElement
-from viff.util import wrapper, rand
+from viff.util import wrapper, rand, profile, deep_wait
 
 from twisted.internet import reactor
 from twisted.internet.error import ConnectionDone, CannotListenError
@@ -643,17 +643,6 @@
             # We concatenate the sub-lists in results.
             results = sum(results, [])
 
-            wait_list = []
-            for result in results:
-                # We allow pre-processing methods to return tuples of
-                # shares or individual shares as their result. Here we
-                # deconstruct result (if possible) and wait on its
-                # individual parts.
-                if isinstance(result, tuple):
-                    wait_list.extend(result)
-                else:
-                    wait_list.append(result)
-
             # The pool must map program counters to Deferreds to
             # present a uniform interface for the functions we
             # pre-process.
@@ -661,10 +650,11 @@
 
             # Update the pool with pairs of program counter and data.
             self._pool.update(zip(program_counters, results))
+
             # Return a Deferred that waits on the individual results.
             # This is important to make it possible for the players to
             # avoid starting before the pre-processing is complete.
-            return gatherResults(wait_list)
+            return deep_wait(results)
 
         wait_list = []
         for ((generator, args), program_counters) in program.iteritems():
@@ -759,6 +749,7 @@
         if self.id in receivers:
             return result
 
+    @profile
     def add(self, share_a, share_b):
         """Addition of shares.
 
@@ -789,6 +780,7 @@
         result.addCallback(lambda (a, b): a - b)
         return result
 
+    @profile
     @increment_pc
     def mul(self, share_a, share_b):
         """Multiplication of shares.
@@ -1158,14 +1150,19 @@
                 self.id = id
                 ctx = SSL.Context(SSL.SSLv3_METHOD)
                 # TODO: Make the file names configurable.
-                ctx.use_certificate_file('player-%d.cert' % id)
-                ctx.use_privatekey_file('player-%d.key' % id)
-                ctx.check_privatekey()
-
-                ctx.load_verify_locations('ca.cert')
-                ctx.set_verify(SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
-                               lambda conn, cert, errnum, depth, ok: ok)
-                self.ctx = ctx
+                try:
+                    ctx.use_certificate_file('player-%d.cert' % id)
+                    ctx.use_privatekey_file('player-%d.key' % id)
+                    ctx.check_privatekey()
+                    ctx.load_verify_locations('ca.cert')
+                    ctx.set_verify(SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
+                                   lambda conn, cert, errnum, depth, ok: ok)
+                    self.ctx = ctx
+                except SSL.Error, e:
+                    print "SSL errors - did you forget to generate certificates?"
+                    for (lib, func, reason) in e.args[0]:
+                        print "* %s in %s: %s" % (func, lib, reason)
+                    raise SystemExit("Stopping program")
 
             def getContext(self):
                 return self.ctx
--- a/viff/test/test_util.py	Fri Oct 10 16:49:49 2008 +0200
+++ b/viff/test/test_util.py	Tue Oct 14 16:12:05 2008 +0200
@@ -15,5 +15,153 @@
 # You should have received a copy of the GNU Lesser General Public
 # License along with VIFF. If not, see <http://www.gnu.org/licenses/>.
 
+"""Tests for viff.util."""
+
+from viff.util import deep_wait
+
+from twisted.trial.unittest import TestCase
+from twisted.internet.defer import Deferred, gatherResults, succeed
+
 #: Declare doctests for Trial.
 __doctests__ = ['viff.util']
+
+
+class DeepWaitTest(TestCase):
+    """Tests for :func:`viff.util.deep_wait`."""
+
+    def setUp(self):
+        self.calls = []
+
+    def test_trivial_wait(self):
+        w = deep_wait("not a Deferred")
+        w.addCallback(lambda _: self.calls.append("w"))
+        self.assertIn("w", self.calls)
+
+    def test_simple_wait(self):
+        a = Deferred()
+        a.addCallback(self.calls.append)
+
+        w = deep_wait(a)
+        w.addCallback(lambda _: self.calls.append("w"))
+
+        self.assertNotIn("w", self.calls)
+        a.callback("a")
+        self.assertIn("w", self.calls)
+
+    def test_tuple_wait(self):
+        a = Deferred()
+        b = Deferred()
+
+        a.addCallback(self.calls.append)
+        b.addCallback(self.calls.append)
+
+        w = deep_wait((a, 123, b))
+        w.addCallback(lambda _: self.calls.append("w"))
+
+        self.assertNotIn("w", self.calls)
+        a.callback("a")
+        self.assertNotIn("w", self.calls)
+        b.callback("b")
+        self.assertIn("w", self.calls)
+
+    def test_list_wait(self):
+        a = Deferred()
+        b = Deferred()
+
+        a.addCallback(self.calls.append)
+        b.addCallback(self.calls.append)
+
+        w = deep_wait([a, 123, b])
+        w.addCallback(lambda _: self.calls.append("w"))
+
+        self.assertNotIn("w", self.calls)
+        a.callback("a")
+        self.assertNotIn("w", self.calls)
+        b.callback("b")
+        self.assertIn("w", self.calls)
+
+    def test_deep_wait(self):
+        a = Deferred()
+        b = Deferred()
+
+        def return_b(_):
+            """Callbacks which return a Deferred."""
+            self.calls.append("return_b")
+            return b
+        
+        a.addCallback(self.calls.append)
+        a.addCallback(return_b)
+
+        w = deep_wait(a)
+        w.addCallback(lambda _: self.calls.append("w"))
+
+        self.assertNotIn("a", self.calls)
+        a.callback("a")
+        self.assertIn("a", self.calls)
+        self.assertIn("return_b", self.calls)
+        self.assertNotIn("w", self.calls)
+        self.assertNotIn("b", self.calls)
+
+        b.callback("b")
+        self.assertIn("w", self.calls)
+
+    def test_mixed_deep_wait(self):
+        a = Deferred()
+        b = Deferred()
+
+        def return_mix(_):
+            """Callbacks which return a Deferred and an integer."""
+            self.calls.append("return_mix")
+            return (b, 42)
+        
+        a.addCallback(self.calls.append)
+        a.addCallback(return_mix)
+
+        w = deep_wait(a)
+        w.addCallback(lambda _: self.calls.append("w"))
+
+        self.assertNotIn("a", self.calls)
+        a.callback("a")
+        self.assertIn("a", self.calls)
+        self.assertIn("return_mix", self.calls)
+        self.assertNotIn("w", self.calls)
+
+        b.callback("b")
+        self.assertIn("w", self.calls)
+
+    def test_complex_deep_wait(self):
+        a = Deferred()
+        b = Deferred()
+        c = Deferred()
+        d = Deferred()
+
+        a.addCallback(self.calls.append)
+        b.addCallback(self.calls.append)
+        c.addCallback(self.calls.append)
+        d.addCallback(self.calls.append)
+
+        def return_b(_):
+            self.calls.append("return_b")
+            return (b, 42)
+
+        def return_c_d(_):
+            self.calls.append("return_c")
+            return [(1, 2), "testing", [c, True], (d, 10)]
+
+        a.addCallback(return_b)
+        b.addCallback(return_c_d)
+
+        w = deep_wait(a)
+        w.addCallback(lambda _: self.calls.append("w"))
+
+        a.callback("a")
+        self.assertNotIn("w", self.calls)
+
+        c.callback("c")
+        self.assertNotIn("w", self.calls)
+
+        b.callback("b")
+        self.assertNotIn("w", self.calls)
+
+        d.callback("d")
+        self.assertIn("w", self.calls)
--- a/viff/util.py	Fri Oct 10 16:49:49 2008 +0200
+++ b/viff/util.py	Tue Oct 14 16:12:05 2008 +0200
@@ -25,6 +25,7 @@
 __docformat__ = "restructuredtext"
 
 import os
+import time
 import random
 import warnings
 from twisted.internet.defer import Deferred, succeed, gatherResults
@@ -194,6 +195,31 @@
     return clone
 
 
+class deep_wait(Deferred):
+
+    def __init__(self, result):
+        Deferred.__init__(self)
+        self._wait(result)
+
+    def _wait(self, value):
+        deferreds = []
+
+        def collect(value):
+            if isinstance(value, Deferred):
+                deferreds.append(value)
+            if isinstance(value, (tuple, list)):
+                map(collect, value)
+
+        collect(value)
+
+        if deferreds:
+            # There are one or more Deferreds to wait on.
+            gatherResults(deferreds).addCallback(self._wait)
+        else:
+            # Found no Deferreds -- there is nothing to wait on and so
+            # we are done!
+            self.callback(None)
+
 def find_prime(lower_bound, blum=False):
     """Find a prime above a lower bound.
 
@@ -249,6 +275,65 @@
     return long(p)
 
 
+PHASES = {}
+
+def begin(result, phase):
+    """Begin a phase.
+
+    You can define program phases for the purpose of profiling a
+    program execution. Use :func:`end` with a matching *phase* to
+    record the ending of a phase. The :func:`profile` decorator makes
+    it easy to wrap a :class:`Runtime <viff.runtime.Runtime>` method
+    in matching :func:`begin`/:func:`end` calls.
+
+    The *result* argument is passed through, which makes it possible
+    to add this function as a callback for a :class:`Deferred`.
+    """
+    PHASES[phase] = time.time()
+    return result
+
+def end(result, phase):
+    """End a phase.
+
+    This is the counter-part for :func:`begin`. It prints the name and
+    the duration of the phase.
+
+    The *result* argument is passed through, which makes it possible
+    to add this function as a callback for a :class:`Deferred`.
+    """
+    stop = time.time()
+    start = PHASES.pop(phase, stop)
+    print "%s from %f to %f (%f sec)" % (phase, start, stop, stop - start)
+    return result
+
+def profile(method):
+    """Profiling decorator.
+
+    Add this decorator to a method in order to trace method entry and
+    exit. If the method returns a :class:`Deferred`, the method exit
+    is recorded when the :class:`Deferred` fires.
+
+    In addition to adding this decorator, you must run the programs in
+    an environment with :envvar:`VIFF_PROFILE` defined. Otherwise the
+    decorator is a no-op and has no runtime overhead.
+    """
+    if not os.environ.get('VIFF_PROFILE'):
+        return method
+
+    @wrapper(method)
+    def profile_wrapper(self, *args, **kwargs):
+        label = "%s %s" % (method.__name__,
+                           ".".join(map(str, self.program_counter)))
+        begin(None, label)
+        result = method(self, *args, **kwargs)
+        if isinstance(result, Deferred):
+            result.addCallback(end, label)
+        else:
+            end(None, label)
+        return result
+
+    return profile_wrapper
+
 if __name__ == "__main__":
     import doctest    #pragma NO COVER
     doctest.testmod() #pragma NO COVER