viff

changeset 1001:74bcf4955f99

Merged with preprocessing cleanup.
author Martin Geisler <mg@daimi.au.dk>
date Tue, 14 Oct 2008 16:12:05 +0200
parents 3fe2baebdb99 7546e47ce876
children 735d5e3dade2
files apps/benchmark.py viff/runtime.py
diffstat 6 files changed, 308 insertions(+), 65 deletions(-) [+]
line diff
     1.1 --- a/apps/benchmark.py	Fri Oct 10 16:49:49 2008 +0200
     1.2 +++ b/apps/benchmark.py	Tue Oct 14 16:12:05 2008 +0200
     1.3 @@ -57,7 +57,7 @@
     1.4  import time
     1.5  from optparse import OptionParser
     1.6  import operator
     1.7 -from pprint import pprint
     1.8 +from pprint import pformat
     1.9  
    1.10  from twisted.internet import reactor
    1.11  
    1.12 @@ -147,35 +147,35 @@
    1.13          self.rt = rt
    1.14          self.operation = operation
    1.15  
    1.16 +        program_desc = {}
    1.17 +
    1.18          if isinstance(self.rt, BasicActiveRuntime):
    1.19              # TODO: Make this optional and maybe automatic. The
    1.20              # program descriptions below were found by carefully
    1.21              # studying the output reported when the benchmarks were
    1.22              # run with no preprocessing. So they are quite brittle.
    1.23 +            if self.operation == operator.mul:
    1.24 +                key = ("generate_triples", (Zp,))
    1.25 +                desc = [(i, 1, 0) for i in range(3 + 2*count, 3 + 3*count)]
    1.26 +                program_desc.setdefault(key, []).extend(desc)
    1.27 +            elif isinstance(self.rt, ComparisonToft05Mixin):
    1.28 +                key = ("generate_triples", (GF256,))
    1.29 +                desc = sum([[(c, 64, i, 1, 1, 0) for i in range(2, 33)] +
    1.30 +                            [(c, 64, i, 3, 1, 0) for i in range(17, 33)]
    1.31 +                            for c in range(3 + 2*count, 3 + 3*count)],
    1.32 +                           [])
    1.33 +                program_desc.setdefault(key, []).extend(desc)
    1.34 +            elif isinstance(self.rt, ComparisonToft07Mixin):
    1.35 +                key = ("generate_triples", (Zp,))
    1.36 +                desc = sum([[(c, 2, 4, i, 2, 1, 0) for i in range(1, 33)] +
    1.37 +                            [(c, 2, 4, 99, 2, 1, 0)] +
    1.38 +                            [(c, 2, 4, i, 1, 0) for i in range(65, 98)]
    1.39 +                            for c in range(3 + 2*count, 3 + 3*count)],
    1.40 +                           [])
    1.41 +                program_desc.setdefault(key, []).extend(desc)
    1.42 +
    1.43 +        if program_desc:
    1.44              print "Starting preprocessing"
    1.45 -            if self.operation == operator.mul:
    1.46 -                program_desc = {
    1.47 -                    ("generate_triples", (Zp,)):
    1.48 -                        [(i, 1, 0) for i in range(3 + 2*count, 3 + 3*count)]
    1.49 -                    }
    1.50 -            elif isinstance(self.rt, ComparisonToft05Mixin):
    1.51 -                program_desc = {
    1.52 -                    ("generate_triples", (GF256,)):
    1.53 -                    sum([[(c, 64, i, 1, 1, 0) for i in range(2, 33)] +
    1.54 -                         [(c, 64, i, 3, 1, 0) for i in range(17, 33)]
    1.55 -                         for c in range(3 + 2*count, 3 + 3*count)],
    1.56 -                        [])
    1.57 -                    }
    1.58 -            elif isinstance(self.rt, ComparisonToft07Mixin):
    1.59 -                program_desc = {
    1.60 -                    ("generate_triples", (Zp,)):
    1.61 -                    sum([[(c, 2, 4, i, 2, 1, 0) for i in range(1, 33)] +
    1.62 -                         [(c, 2, 4, 99, 2, 1, 0)] +
    1.63 -                         [(c, 2, 4, i, 1, 0) for i in range(65, 98)]
    1.64 -                         for c in range(3 + 2*count, 3 + 3*count)],
    1.65 -                        [])
    1.66 -                    }
    1.67 -
    1.68              record_start("preprocessing")
    1.69              preproc = rt.preprocess(program_desc)
    1.70              preproc.addCallback(record_stop, "preprocessing")
    1.71 @@ -224,7 +224,9 @@
    1.72  
    1.73          if self.rt._needed_data:
    1.74              print "Missing pre-processed data:"
    1.75 -            pprint(self.rt._needed_data)
    1.76 +            for (func, args), pcs in self.rt._needed_data.iteritems():
    1.77 +                print "* %s%s:" % (func, args)
    1.78 +                print "  " + pformat(pcs).replace("\n", "\n  ")
    1.79  
    1.80          self.rt.shutdown()
    1.81  
    1.82 @@ -290,9 +292,10 @@
    1.83          mixins.append(ProbabilisticEqualityMixin)
    1.84  
    1.85  print "Using the base runtime: %s." % base_runtime_class
    1.86 -print "With the following mixins:"
    1.87 -for mixin in mixins:
    1.88 -    print "- %s" % mixin
    1.89 +if mixins:
    1.90 +    print "With the following mixins:"
    1.91 +    for mixin in mixins:
    1.92 +        print "- %s" % mixin
    1.93  
    1.94  runtime_class = make_runtime_class(base_runtime_class, mixins)
    1.95  
     2.1 --- a/doc/util.txt	Fri Oct 10 16:49:49 2008 +0200
     2.2 +++ b/doc/util.txt	Tue Oct 14 16:12:05 2008 +0200
     2.3 @@ -25,3 +25,8 @@
     2.4  
     2.5        Setting this environment variable to any value will turn
     2.6        :func:`wrapper` into a no-op.
     2.7 +
     2.8 +   .. envvar:: VIFF_PROFILE
     2.9 +
    2.10 +      Defining this variable will change :func:`profile` from a no-op
    2.11 +      to real decorator.
     3.1 --- a/viff/comparison.py	Fri Oct 10 16:49:49 2008 +0200
     3.2 +++ b/viff/comparison.py	Tue Oct 14 16:12:05 2008 +0200
     3.3 @@ -24,7 +24,7 @@
     3.4  
     3.5  import math
     3.6  
     3.7 -from viff.util import rand
     3.8 +from viff.util import rand, profile
     3.9  from viff.runtime import Runtime, Share, gather_shares, increment_pc
    3.10  from viff.active import ActiveRuntime
    3.11  from viff.field import GF256, FieldElement
    3.12 @@ -52,6 +52,20 @@
    3.13          tmp.field = dst_field
    3.14          return reduce(self.xor, dst_shares, tmp)
    3.15  
    3.16 +    def decomposed_random_sharing(self, field, bits):
    3.17 +        bits = [self.prss_share_bit_double(field) for _ in range(bits)]
    3.18 +        int_bits, bit_bits = zip(*bits)
    3.19 +
    3.20 +        def bits_to_int(bits):
    3.21 +            """Converts a list of bits to an integer."""
    3.22 +            return sum([2**i * b for i, b in enumerate(bits)])
    3.23 +
    3.24 +        int_b = gather_shares(int_bits)
    3.25 +        int_b.addCallback(bits_to_int)
    3.26 +
    3.27 +        return int_b, bit_bits
    3.28 +
    3.29 +    @profile
    3.30      @increment_pc
    3.31      def greater_than_equal(self, share_a, share_b):
    3.32          """Compute ``share_a >= share_b``.
    3.33 @@ -74,25 +88,14 @@
    3.34          m = l + self.options.security_parameter
    3.35          t = m + 1
    3.36  
    3.37 -        # Preprocessing begin
    3.38          assert 2**(l+1) + 2**t < field.modulus, "2^(l+1) + 2^t < p must hold"
    3.39          assert self.num_players + 2 < 2**l
    3.40  
    3.41 -        bits = [self.prss_share_bit_double(field) for _ in range(m)]
    3.42 -        int_bits, bit_bits = zip(*bits)
    3.43 +        a = share_a - share_b + 2**l
    3.44 +        b, bits = self.decomposed_random_sharing(field, m)
    3.45 +        T = self.open(2**t - b + a)
    3.46  
    3.47 -        def bits_to_int(bits):
    3.48 -            """Converts a list of bits to an integer."""
    3.49 -            return sum([2**i * b for i, b in enumerate(bits)])
    3.50 -
    3.51 -        int_b = gather_shares(int_bits)
    3.52 -        int_b.addCallback(bits_to_int)
    3.53 -        # Preprocessing done
    3.54 -
    3.55 -        a = share_a - share_b + 2**l
    3.56 -        T = self.open(2**t - int_b + a)
    3.57 -
    3.58 -        result = gather_shares((T,) + bit_bits)
    3.59 +        result = gather_shares((T,) + bits)
    3.60          self.schedule_callback(result, self._finish_greater_than_equal, l)
    3.61          return result
    3.62  
    3.63 @@ -183,6 +186,7 @@
    3.64          full_mask = reduce(self.add, dst_shares)
    3.65          return tmp - full_mask
    3.66  
    3.67 +    @profile
    3.68      @increment_pc
    3.69      def greater_than_equal_preproc(self, field, smallField=None):
    3.70          """Preprocessing for :meth:`greater_than_equal`."""
    3.71 @@ -237,6 +241,7 @@
    3.72          # Preprocessing done
    3.73          ##################################################
    3.74  
    3.75 +    @profile
    3.76      @increment_pc
    3.77      def greater_than_equal_online(self, share_a, share_b, preproc, field):
    3.78          """Compute ``share_a >= share_b``. Result is secret shared."""
     4.1 --- a/viff/runtime.py	Fri Oct 10 16:49:49 2008 +0200
     4.2 +++ b/viff/runtime.py	Tue Oct 14 16:12:05 2008 +0200
     4.3 @@ -41,7 +41,7 @@
     4.4  from viff import shamir
     4.5  from viff.prss import prss, prss_lsb, prss_zero
     4.6  from viff.field import GF256, FieldElement
     4.7 -from viff.util import wrapper, rand
     4.8 +from viff.util import wrapper, rand, profile, deep_wait
     4.9  
    4.10  from twisted.internet import reactor
    4.11  from twisted.internet.error import ConnectionDone, CannotListenError
    4.12 @@ -643,17 +643,6 @@
    4.13              # We concatenate the sub-lists in results.
    4.14              results = sum(results, [])
    4.15  
    4.16 -            wait_list = []
    4.17 -            for result in results:
    4.18 -                # We allow pre-processing methods to return tuples of
    4.19 -                # shares or individual shares as their result. Here we
    4.20 -                # deconstruct result (if possible) and wait on its
    4.21 -                # individual parts.
    4.22 -                if isinstance(result, tuple):
    4.23 -                    wait_list.extend(result)
    4.24 -                else:
    4.25 -                    wait_list.append(result)
    4.26 -
    4.27              # The pool must map program counters to Deferreds to
    4.28              # present a uniform interface for the functions we
    4.29              # pre-process.
    4.30 @@ -661,10 +650,11 @@
    4.31  
    4.32              # Update the pool with pairs of program counter and data.
    4.33              self._pool.update(zip(program_counters, results))
    4.34 +
    4.35              # Return a Deferred that waits on the individual results.
    4.36              # This is important to make it possible for the players to
    4.37              # avoid starting before the pre-processing is complete.
    4.38 -            return gatherResults(wait_list)
    4.39 +            return deep_wait(results)
    4.40  
    4.41          wait_list = []
    4.42          for ((generator, args), program_counters) in program.iteritems():
    4.43 @@ -759,6 +749,7 @@
    4.44          if self.id in receivers:
    4.45              return result
    4.46  
    4.47 +    @profile
    4.48      def add(self, share_a, share_b):
    4.49          """Addition of shares.
    4.50  
    4.51 @@ -789,6 +780,7 @@
    4.52          result.addCallback(lambda (a, b): a - b)
    4.53          return result
    4.54  
    4.55 +    @profile
    4.56      @increment_pc
    4.57      def mul(self, share_a, share_b):
    4.58          """Multiplication of shares.
    4.59 @@ -1158,14 +1150,19 @@
    4.60                  self.id = id
    4.61                  ctx = SSL.Context(SSL.SSLv3_METHOD)
    4.62                  # TODO: Make the file names configurable.
    4.63 -                ctx.use_certificate_file('player-%d.cert' % id)
    4.64 -                ctx.use_privatekey_file('player-%d.key' % id)
    4.65 -                ctx.check_privatekey()
    4.66 -
    4.67 -                ctx.load_verify_locations('ca.cert')
    4.68 -                ctx.set_verify(SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
    4.69 -                               lambda conn, cert, errnum, depth, ok: ok)
    4.70 -                self.ctx = ctx
    4.71 +                try:
    4.72 +                    ctx.use_certificate_file('player-%d.cert' % id)
    4.73 +                    ctx.use_privatekey_file('player-%d.key' % id)
    4.74 +                    ctx.check_privatekey()
    4.75 +                    ctx.load_verify_locations('ca.cert')
    4.76 +                    ctx.set_verify(SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
    4.77 +                                   lambda conn, cert, errnum, depth, ok: ok)
    4.78 +                    self.ctx = ctx
    4.79 +                except SSL.Error, e:
    4.80 +                    print "SSL errors - did you forget to generate certificates?"
    4.81 +                    for (lib, func, reason) in e.args[0]:
    4.82 +                        print "* %s in %s: %s" % (func, lib, reason)
    4.83 +                    raise SystemExit("Stopping program")
    4.84  
    4.85              def getContext(self):
    4.86                  return self.ctx
     5.1 --- a/viff/test/test_util.py	Fri Oct 10 16:49:49 2008 +0200
     5.2 +++ b/viff/test/test_util.py	Tue Oct 14 16:12:05 2008 +0200
     5.3 @@ -15,5 +15,153 @@
     5.4  # You should have received a copy of the GNU Lesser General Public
     5.5  # License along with VIFF. If not, see <http://www.gnu.org/licenses/>.
     5.6  
     5.7 +"""Tests for viff.util."""
     5.8 +
     5.9 +from viff.util import deep_wait
    5.10 +
    5.11 +from twisted.trial.unittest import TestCase
    5.12 +from twisted.internet.defer import Deferred, gatherResults, succeed
    5.13 +
    5.14  #: Declare doctests for Trial.
    5.15  __doctests__ = ['viff.util']
    5.16 +
    5.17 +
    5.18 +class DeepWaitTest(TestCase):
    5.19 +    """Tests for :func:`viff.util.deep_wait`."""
    5.20 +
    5.21 +    def setUp(self):
    5.22 +        self.calls = []
    5.23 +
    5.24 +    def test_trivial_wait(self):
    5.25 +        w = deep_wait("not a Deferred")
    5.26 +        w.addCallback(lambda _: self.calls.append("w"))
    5.27 +        self.assertIn("w", self.calls)
    5.28 +
    5.29 +    def test_simple_wait(self):
    5.30 +        a = Deferred()
    5.31 +        a.addCallback(self.calls.append)
    5.32 +
    5.33 +        w = deep_wait(a)
    5.34 +        w.addCallback(lambda _: self.calls.append("w"))
    5.35 +
    5.36 +        self.assertNotIn("w", self.calls)
    5.37 +        a.callback("a")
    5.38 +        self.assertIn("w", self.calls)
    5.39 +
    5.40 +    def test_tuple_wait(self):
    5.41 +        a = Deferred()
    5.42 +        b = Deferred()
    5.43 +
    5.44 +        a.addCallback(self.calls.append)
    5.45 +        b.addCallback(self.calls.append)
    5.46 +
    5.47 +        w = deep_wait((a, 123, b))
    5.48 +        w.addCallback(lambda _: self.calls.append("w"))
    5.49 +
    5.50 +        self.assertNotIn("w", self.calls)
    5.51 +        a.callback("a")
    5.52 +        self.assertNotIn("w", self.calls)
    5.53 +        b.callback("b")
    5.54 +        self.assertIn("w", self.calls)
    5.55 +
    5.56 +    def test_list_wait(self):
    5.57 +        a = Deferred()
    5.58 +        b = Deferred()
    5.59 +
    5.60 +        a.addCallback(self.calls.append)
    5.61 +        b.addCallback(self.calls.append)
    5.62 +
    5.63 +        w = deep_wait([a, 123, b])
    5.64 +        w.addCallback(lambda _: self.calls.append("w"))
    5.65 +
    5.66 +        self.assertNotIn("w", self.calls)
    5.67 +        a.callback("a")
    5.68 +        self.assertNotIn("w", self.calls)
    5.69 +        b.callback("b")
    5.70 +        self.assertIn("w", self.calls)
    5.71 +
    5.72 +    def test_deep_wait(self):
    5.73 +        a = Deferred()
    5.74 +        b = Deferred()
    5.75 +
    5.76 +        def return_b(_):
    5.77 +            """Callbacks which return a Deferred."""
    5.78 +            self.calls.append("return_b")
    5.79 +            return b
    5.80 +        
    5.81 +        a.addCallback(self.calls.append)
    5.82 +        a.addCallback(return_b)
    5.83 +
    5.84 +        w = deep_wait(a)
    5.85 +        w.addCallback(lambda _: self.calls.append("w"))
    5.86 +
    5.87 +        self.assertNotIn("a", self.calls)
    5.88 +        a.callback("a")
    5.89 +        self.assertIn("a", self.calls)
    5.90 +        self.assertIn("return_b", self.calls)
    5.91 +        self.assertNotIn("w", self.calls)
    5.92 +        self.assertNotIn("b", self.calls)
    5.93 +
    5.94 +        b.callback("b")
    5.95 +        self.assertIn("w", self.calls)
    5.96 +
    5.97 +    def test_mixed_deep_wait(self):
    5.98 +        a = Deferred()
    5.99 +        b = Deferred()
   5.100 +
   5.101 +        def return_mix(_):
   5.102 +            """Callbacks which return a Deferred and an integer."""
   5.103 +            self.calls.append("return_mix")
   5.104 +            return (b, 42)
   5.105 +        
   5.106 +        a.addCallback(self.calls.append)
   5.107 +        a.addCallback(return_mix)
   5.108 +
   5.109 +        w = deep_wait(a)
   5.110 +        w.addCallback(lambda _: self.calls.append("w"))
   5.111 +
   5.112 +        self.assertNotIn("a", self.calls)
   5.113 +        a.callback("a")
   5.114 +        self.assertIn("a", self.calls)
   5.115 +        self.assertIn("return_mix", self.calls)
   5.116 +        self.assertNotIn("w", self.calls)
   5.117 +
   5.118 +        b.callback("b")
   5.119 +        self.assertIn("w", self.calls)
   5.120 +
   5.121 +    def test_complex_deep_wait(self):
   5.122 +        a = Deferred()
   5.123 +        b = Deferred()
   5.124 +        c = Deferred()
   5.125 +        d = Deferred()
   5.126 +
   5.127 +        a.addCallback(self.calls.append)
   5.128 +        b.addCallback(self.calls.append)
   5.129 +        c.addCallback(self.calls.append)
   5.130 +        d.addCallback(self.calls.append)
   5.131 +
   5.132 +        def return_b(_):
   5.133 +            self.calls.append("return_b")
   5.134 +            return (b, 42)
   5.135 +
   5.136 +        def return_c_d(_):
   5.137 +            self.calls.append("return_c")
   5.138 +            return [(1, 2), "testing", [c, True], (d, 10)]
   5.139 +
   5.140 +        a.addCallback(return_b)
   5.141 +        b.addCallback(return_c_d)
   5.142 +
   5.143 +        w = deep_wait(a)
   5.144 +        w.addCallback(lambda _: self.calls.append("w"))
   5.145 +
   5.146 +        a.callback("a")
   5.147 +        self.assertNotIn("w", self.calls)
   5.148 +
   5.149 +        c.callback("c")
   5.150 +        self.assertNotIn("w", self.calls)
   5.151 +
   5.152 +        b.callback("b")
   5.153 +        self.assertNotIn("w", self.calls)
   5.154 +
   5.155 +        d.callback("d")
   5.156 +        self.assertIn("w", self.calls)
     6.1 --- a/viff/util.py	Fri Oct 10 16:49:49 2008 +0200
     6.2 +++ b/viff/util.py	Tue Oct 14 16:12:05 2008 +0200
     6.3 @@ -25,6 +25,7 @@
     6.4  __docformat__ = "restructuredtext"
     6.5  
     6.6  import os
     6.7 +import time
     6.8  import random
     6.9  import warnings
    6.10  from twisted.internet.defer import Deferred, succeed, gatherResults
    6.11 @@ -194,6 +195,31 @@
    6.12      return clone
    6.13  
    6.14  
    6.15 +class deep_wait(Deferred):
    6.16 +
    6.17 +    def __init__(self, result):
    6.18 +        Deferred.__init__(self)
    6.19 +        self._wait(result)
    6.20 +
    6.21 +    def _wait(self, value):
    6.22 +        deferreds = []
    6.23 +
    6.24 +        def collect(value):
    6.25 +            if isinstance(value, Deferred):
    6.26 +                deferreds.append(value)
    6.27 +            if isinstance(value, (tuple, list)):
    6.28 +                map(collect, value)
    6.29 +
    6.30 +        collect(value)
    6.31 +
    6.32 +        if deferreds:
    6.33 +            # There are one or more Deferreds to wait on.
    6.34 +            gatherResults(deferreds).addCallback(self._wait)
    6.35 +        else:
    6.36 +            # Found no Deferreds -- there is nothing to wait on and so
    6.37 +            # we are done!
    6.38 +            self.callback(None)
    6.39 +
    6.40  def find_prime(lower_bound, blum=False):
    6.41      """Find a prime above a lower bound.
    6.42  
    6.43 @@ -249,6 +275,65 @@
    6.44      return long(p)
    6.45  
    6.46  
    6.47 +PHASES = {}
    6.48 +
    6.49 +def begin(result, phase):
    6.50 +    """Begin a phase.
    6.51 +
    6.52 +    You can define program phases for the purpose of profiling a
    6.53 +    program execution. Use :func:`end` with a matching *phase* to
    6.54 +    record the ending of a phase. The :func:`profile` decorator makes
    6.55 +    it easy to wrap a :class:`Runtime <viff.runtime.Runtime>` method
    6.56 +    in matching :func:`begin`/:func:`end` calls.
    6.57 +
    6.58 +    The *result* argument is passed through, which makes it possible
    6.59 +    to add this function as a callback for a :class:`Deferred`.
    6.60 +    """
    6.61 +    PHASES[phase] = time.time()
    6.62 +    return result
    6.63 +
    6.64 +def end(result, phase):
    6.65 +    """End a phase.
    6.66 +
    6.67 +    This is the counter-part for :func:`begin`. It prints the name and
    6.68 +    the duration of the phase.
    6.69 +
    6.70 +    The *result* argument is passed through, which makes it possible
    6.71 +    to add this function as a callback for a :class:`Deferred`.
    6.72 +    """
    6.73 +    stop = time.time()
    6.74 +    start = PHASES.pop(phase, stop)
    6.75 +    print "%s from %f to %f (%f sec)" % (phase, start, stop, stop - start)
    6.76 +    return result
    6.77 +
    6.78 +def profile(method):
    6.79 +    """Profiling decorator.
    6.80 +
    6.81 +    Add this decorator to a method in order to trace method entry and
    6.82 +    exit. If the method returns a :class:`Deferred`, the method exit
    6.83 +    is recorded when the :class:`Deferred` fires.
    6.84 +
    6.85 +    In addition to adding this decorator, you must run the programs in
    6.86 +    an environment with :envvar:`VIFF_PROFILE` defined. Otherwise the
    6.87 +    decorator is a no-op and has no runtime overhead.
    6.88 +    """
    6.89 +    if not os.environ.get('VIFF_PROFILE'):
    6.90 +        return method
    6.91 +
    6.92 +    @wrapper(method)
    6.93 +    def profile_wrapper(self, *args, **kwargs):
    6.94 +        label = "%s %s" % (method.__name__,
    6.95 +                           ".".join(map(str, self.program_counter)))
    6.96 +        begin(None, label)
    6.97 +        result = method(self, *args, **kwargs)
    6.98 +        if isinstance(result, Deferred):
    6.99 +            result.addCallback(end, label)
   6.100 +        else:
   6.101 +            end(None, label)
   6.102 +        return result
   6.103 +
   6.104 +    return profile_wrapper
   6.105 +
   6.106  if __name__ == "__main__":
   6.107      import doctest    #pragma NO COVER
   6.108      doctest.testmod() #pragma NO COVER