viff

changeset 1277:26f7a133172a

Merged with Janus.
author Marcel Keller <mkeller@cs.au.dk>
date Fri, 09 Oct 2009 16:33:15 +0200
parents f9ab0f24979d bf991ca28f23
children 4b686925f035 79c351c812f3
files viff/runtime.py viff/test/test_basic_runtime.py
diffstat 5 files changed, 102 insertions(+), 46 deletions(-) [+]
line diff
     1.1 --- a/apps/benchmark.py	Fri Oct 09 16:32:12 2009 +0200
     1.2 +++ b/apps/benchmark.py	Fri Oct 09 16:33:15 2009 +0200
     1.3 @@ -74,9 +74,16 @@
     1.4  from viff.comparison import ComparisonToft05Mixin, ComparisonToft07Mixin
     1.5  from viff.equality import ProbabilisticEqualityMixin
     1.6  from viff.paillier import PaillierRuntime
     1.7 +from viff.orlandi import OrlandiRuntime
     1.8  from viff.config import load_config
     1.9  from viff.util import find_prime, rand
    1.10  
    1.11 +
    1.12 +# Hack in order to avoid Maximum recursion depth exceeded
    1.13 +# exception;
    1.14 +sys.setrecursionlimit(5000)
    1.15 +
    1.16 +
    1.17  last_timestamp = time.time()
    1.18  start = 0
    1.19  
    1.20 @@ -103,7 +110,8 @@
    1.21  
    1.22  runtimes = {"PassiveRuntime": PassiveRuntime,
    1.23              "PaillierRuntime": PaillierRuntime, 
    1.24 -            "BasicActiveRuntime": BasicActiveRuntime}
    1.25 +            "BasicActiveRuntime": BasicActiveRuntime,
    1.26 +            "OrlandiRuntime": OrlandiRuntime}
    1.27  
    1.28  mixins = {"TriplesHyperinvertibleMatricesMixin" : TriplesHyperinvertibleMatricesMixin, 
    1.29            "TriplesPRSSMixin": TriplesPRSSMixin, 
    1.30 @@ -138,10 +146,17 @@
    1.31                    help="skip local computations using fake field elements")
    1.32  parser.add_option("--args", type="string",
    1.33                    help="additional arguments to the runtime, the format is a comma separated list of id=value pairs e.g. --args s=1,d=0,lambda=1")
    1.34 +parser.add_option("--needed_data", type="string",
    1.35 +                  help="name of a file containing already computed dictionary of needed_data. Useful for skipping generating the needed data, which usually elliminates half the execution time. Format of file: \"{('random_triple', (Zp,)): [(3, 1), (3, 4)]}\"")
    1.36 +parser.add_option("--pc", type="string",
    1.37 +                  help="The program counter to start from when using explicitly provided needed_data. Format: [3,0]")
    1.38  
    1.39  parser.set_defaults(modulus=2**65, threshold=1, count=10,
    1.40 -                    runtime=runtimes.keys()[0], mixins="", num_players=2, prss=True,
    1.41 -                    operation=operations.keys()[0], parallel=True, fake=False, args="")
    1.42 +                    runtime="PassiveRuntime", mixins="", num_players=2, prss=True,
    1.43 +                    operation=operations.keys()[0], parallel=True, fake=False, 
    1.44 +                    args="", needed_data="")
    1.45 +
    1.46 +print "*" * 60
    1.47  
    1.48  # Add standard VIFF options.
    1.49  Runtime.add_options(parser)
    1.50 @@ -168,28 +183,44 @@
    1.51  count = options.count
    1.52  print "I am player %d, will %s %d numbers" % (id, options.operation, count)
    1.53  
    1.54 +
    1.55 +class BenchmarkStrategy:
    1.56 +
    1.57 +    def benchmark(self, *args):
    1.58 +        raise NotImplemented("Override this abstract method in subclasses")
    1.59 +
    1.60 +
    1.61 +class SelfcontainedBenchmarkStrategy(BenchmarkStrategy):
    1.62 +
    1.63 +    def benchmark(self, *args):
    1.64 +        sys.stdout.flush()
    1.65 +        sync = self.rt.synchronize()
    1.66 +        self.doTest(sync, lambda x: x)
    1.67 +        self.rt.schedule_callback(sync, self.preprocess)
    1.68 +        self.doTest(sync, lambda x: self.rt.shutdown())
    1.69 +
    1.70 +
    1.71 +class NeededDataBenchmarkStrategy(BenchmarkStrategy):
    1.72 +
    1.73 +    def benchmark(self, needed_data, pc, *args):
    1.74 +        self.pc = pc
    1.75 +        sys.stdout.flush()
    1.76 +        sync = self.rt.synchronize()
    1.77 +        self.rt.schedule_callback(sync, lambda x: needed_data)
    1.78 +        self.rt.schedule_callback(sync, self.preprocess)
    1.79 +        self.doTest(sync, lambda x: self.rt.shutdown())
    1.80 +
    1.81 +
    1.82  # Defining the protocol as a class makes it easier to write the
    1.83  # callbacks in the order they are called. This class is a base class
    1.84  # that executes the protocol by calling the run_test method.
    1.85  class Benchmark:
    1.86  
    1.87      def __init__(self, rt, operation):
    1.88 -        print "init"
    1.89          self.rt = rt
    1.90          self.operation = operation
    1.91          self.pc = None
    1.92 -        sys.stdout.flush()
    1.93 -        sync = self.rt.synchronize()
    1.94 -        self.doTest(sync, lambda x: x)
    1.95 -        self.rt.schedule_callback(sync, self.preprocess)
    1.96 -        self.doTest(sync, lambda x: self.rt.shutdown())
    1.97          
    1.98 -#     def sync_preprocess(self):
    1.99 -#         print "Synchronizing preprocessing"
   1.100 -#         sys.stdout.flush()
   1.101 -#         sync = self.rt.synchronize()
   1.102 -#         self.rt.schedule_callback(sync, self.preprocess)
   1.103 -
   1.104      def preprocess(self, needed_data):
   1.105          print "Preprocess", needed_data
   1.106          if needed_data:
   1.107 @@ -203,10 +234,8 @@
   1.108              return None
   1.109  
   1.110      def doTest(self, d, termination_function):
   1.111 -        print "doTest", self.rt.program_counter
   1.112          self.rt.schedule_callback(d, self.begin)
   1.113          self.rt.schedule_callback(d, self.sync_test)
   1.114 -#         self.rt.schedule_callback(d, self.countdown, 3)
   1.115          self.rt.schedule_callback(d, self.run_test)
   1.116          self.rt.schedule_callback(d, self.sync_test)
   1.117          self.rt.schedule_callback(d, self.finished, termination_function)
   1.118 @@ -236,16 +265,6 @@
   1.119          self.rt.schedule_callback(sync, lambda y: x)
   1.120          return sync
   1.121  
   1.122 -#     def countdown(self, _, seconds):
   1.123 -#         if seconds > 0:
   1.124 -#             print "Starting test in %d" % seconds
   1.125 -#             sys.stdout.flush()
   1.126 -#             reactor.callLater(1, self.countdown, None, seconds - 1)
   1.127 -#         else:
   1.128 -#             print "Starting test now"
   1.129 -#             sys.stdout.flush()
   1.130 -#             self.run_test(None)
   1.131 -
   1.132      def run_test(self, _):
   1.133          raise NotImplemented("Override this abstract method in a sub class.")
   1.134  
   1.135 @@ -276,6 +295,7 @@
   1.136              a = self.a_shares.pop()
   1.137              b = self.b_shares.pop()
   1.138              c_shares.append(self.operation(a, b))
   1.139 +            print "."
   1.140  
   1.141          done = gather_shares(c_shares)
   1.142          done.addCallback(record_stop, "parallel test")
   1.143 @@ -330,6 +350,20 @@
   1.144  else:
   1.145      benchmark = SequentialBenchmark
   1.146  
   1.147 +needed_data = ""
   1.148 +if options.needed_data != "":
   1.149 +    file = open(options.needed_data, 'r')
   1.150 +    for l in file:
   1.151 +        needed_data += l
   1.152 +    needed_data = eval(needed_data)
   1.153 +
   1.154 +if options.needed_data != "" and options.pc != "":
   1.155 +    bases = (benchmark,) + (NeededDataBenchmarkStrategy,) + (object,)
   1.156 +    options.pc = eval(options.pc)
   1.157 +else:
   1.158 +    bases = (benchmark,) + (SelfcontainedBenchmarkStrategy,) + (object,)
   1.159 +benchmark = type("ExtendedBenchmark", bases, {})
   1.160 +
   1.161  pre_runtime = create_runtime(id, players, options.threshold,
   1.162                               options, runtime_class)
   1.163  
   1.164 @@ -339,13 +373,16 @@
   1.165          for arg in options.args.split(','):
   1.166              id, value = arg.split('=')
   1.167              args[id] = long(value)
   1.168 -        runtime.setArgs(args)
   1.169 +        runtime.set_args(args)
   1.170      return runtime
   1.171  
   1.172  
   1.173  pre_runtime.addCallback(update_args, options)
   1.174  
   1.175 -pre_runtime.addCallback(benchmark, operation)
   1.176 +def do_benchmark(runtime, operation, benchmark, *args):
   1.177 +    benchmark(runtime, operation).benchmark(*args)
   1.178 +
   1.179 +pre_runtime.addCallback(do_benchmark, operation, benchmark, needed_data, options.pc)
   1.180  
   1.181  print "#### Starting reactor ###"
   1.182  reactor.run()
     2.1 --- a/viff/hash_broadcast.py	Fri Oct 09 16:32:12 2009 +0200
     2.2 +++ b/viff/hash_broadcast.py	Fri Oct 09 16:33:15 2009 +0200
     2.3 @@ -54,7 +54,7 @@
     2.4              signals[peer_id] = long(signal)
     2.5              # If all signals are received then check if they are OK or INCONSISTENTHASH.
     2.6              if num_receivers == len(signals.keys()):
     2.7 -                s = reduce(lambda x, y: OK if OK == y else INCONSISTENTHASH, signals.values())
     2.8 +                s = reduce(lambda x, y: (OK == y and OK) or INCONSISTENTHASH, signals.values())
     2.9                  if OK == s:
    2.10                      # Make the result ready.
    2.11                      result.callback(message[0])
    2.12 @@ -69,7 +69,10 @@
    2.13                  signal = OK
    2.14                  # First we check if the hashes we received are equal to the hash we computed ourselves.
    2.15                  for peer_id in receivers:
    2.16 -                    signal = signal if a_hashes[peer_id] == a_hashes[self.id] else INCONSISTENTHASH
    2.17 +                    if a_hashes[peer_id] == a_hashes[self.id]:
    2.18 +                        signal = signal
    2.19 +                    else:
    2.20 +                        signal = INCONSISTENTHASH
    2.21                  # Then we send the SAME signal to everybody. 
    2.22                  for peer_id in receivers:
    2.23                      self.protocols[peer_id].sendData(unique_pc, SIGNAL, str(signal))           
     3.1 --- a/viff/orlandi.py	Fri Oct 09 16:32:12 2009 +0200
     3.2 +++ b/viff/orlandi.py	Fri Oct 09 16:33:15 2009 +0200
     3.3 @@ -17,7 +17,7 @@
     3.4  
     3.5  from twisted.internet.defer import Deferred, DeferredList, gatherResults
     3.6  
     3.7 -from viff.runtime import Runtime, Share, ShareList, gather_shares
     3.8 +from viff.runtime import Runtime, Share, ShareList, gather_shares, preprocess
     3.9  from viff.util import rand
    3.10  from viff.constants import TEXT, PAILLIER
    3.11  from viff.field import FieldElement
    3.12 @@ -598,6 +598,7 @@
    3.13          c = OrlandiShare(self, field, field(value), (field(0), field(0)), Cc)
    3.14          return c
    3.15  
    3.16 +    @preprocess("random_triple")
    3.17      def _get_triple(self, field):
    3.18          c, d = self.random_triple(field, 1)
    3.19          def f(ls):
    3.20 @@ -1025,7 +1026,7 @@
    3.21  
    3.22          return result
    3.23  
    3.24 -    def random_triple(self, field, number_of_requested_triples):
    3.25 +    def random_triple(self, field, quantity=1):
    3.26          """Generate a list of triples ``(a, b, c)`` where ``c = a * b``.
    3.27  
    3.28          The triple ``(a, b, c)`` is secure in the Fcrs-hybrid model.
    3.29 @@ -1035,14 +1036,14 @@
    3.30  
    3.31          M = []
    3.32  
    3.33 -# print "Generating %i triples... relax, have a brak..." % ((1 + self.s_lambda) * (2 * self.d + 1) * number_of_requested_triples)
    3.34 +# print "Generating %i triples... relax, have a break..." % ((1 + self.s_lambda) * (2 * self.d + 1) * quantity)
    3.35  
    3.36 -        for x in xrange((1 + self.s_lambda) * (2 * self.d + 1) * number_of_requested_triples):
    3.37 +        for x in xrange((1 + self.s_lambda) * (2 * self.d + 1) * quantity):
    3.38              M.append(self.triple_test(field))
    3.39  
    3.40          def step3(ls):
    3.41              """Coin-flip a subset test_set of M of size lambda(2d + 1)M."""
    3.42 -            size = self.s_lambda * (2 * self.d + 1) * number_of_requested_triples
    3.43 +            size = self.s_lambda * (2 * self.d + 1) * quantity
    3.44              inx = 0
    3.45              p_half = field.modulus // 2
    3.46              def coin_flip(v, ls, test_set):
    3.47 @@ -1250,18 +1251,18 @@
    3.48              return dls_all
    3.49  
    3.50          def step6(M_without_test_set):
    3.51 -            """Partition M without test_set in number_of_requested_triples
    3.52 +            """Partition M without test_set in quantity
    3.53              random subsets M_i of size (2d + 1).
    3.54              """
    3.55              subsets = []
    3.56              size = 2 * self.d + 1
    3.57 -            for x in xrange(number_of_requested_triples):
    3.58 +            for x in xrange(quantity):
    3.59                  subsets.append([])
    3.60  
    3.61              def put_in_set(v, M_without_test_set, subsets):
    3.62                  if 0 == len(M_without_test_set):
    3.63                      return subsets
    3.64 -                v = v.value % number_of_requested_triples
    3.65 +                v = v.value % quantity
    3.66                  if size > len(subsets[v]):
    3.67                      subsets[v].append(M_without_test_set.pop(0))
    3.68                  r = self.random_share(field)
    3.69 @@ -1311,12 +1312,17 @@
    3.70          self.activate_reactor()
    3.71  
    3.72          s = Share(self, field)
    3.73 -        def f(ls, s):
    3.74 -            s.callback(ls)
    3.75 -        result.addCallbacks(f, self.error_handler, callbackArgs=(s,))
    3.76 -        return number_of_requested_triples, s
    3.77 +        # We add the result to the chains in result.
    3.78 +        result.chainDeferred(s)
    3.79 +
    3.80 +        return quantity, s
    3.81  
    3.82      def error_handler(self, ex):
    3.83          print "Error: ", ex
    3.84          return ex
    3.85  
    3.86 +    def set_args(self, args):
    3.87 +        """args is a dictionary."""
    3.88 +        self.s = args['s']
    3.89 +        self.d = args['d']
    3.90 +        self.s_lambda = args['lambda']
     4.1 --- a/viff/runtime.py	Fri Oct 09 16:32:12 2009 +0200
     4.2 +++ b/viff/runtime.py	Fri Oct 09 16:33:15 2009 +0200
     4.3 @@ -806,7 +806,14 @@
     4.4          example of a method fulfilling this interface.
     4.5          """
     4.6  
     4.7 -        def update(results, program_counters):
     4.8 +        def update(results, program_counters, start_time, count, what):
     4.9 +            stop = time.time()
    4.10 +
    4.11 +            print
    4.12 +            print "Total time used: %.3f sec" % (stop - start_time)
    4.13 +            print "Time per %s operation: %.0f ms" % (what, 1000*(stop - start_time) / count)
    4.14 +            print "*" * 6
    4.15 +
    4.16              # We concatenate the sub-lists in results.
    4.17              results = sum(results, [])
    4.18  
    4.19 @@ -831,7 +838,10 @@
    4.20              func = getattr(self, generator)
    4.21              results = []
    4.22              items = 0
    4.23 +            count = 0
    4.24 +            start_time = time.time()
    4.25              while items < len(program_counters):
    4.26 +                count += 1
    4.27                  self.increment_pc()
    4.28                  self.fork_pc()
    4.29                  item_count, result = func(quantity=len(program_counters) - items, *args)
    4.30 @@ -839,7 +849,7 @@
    4.31                  results.append(result)
    4.32                  self.unfork_pc()
    4.33              ready = gatherResults(results)
    4.34 -            ready.addCallback(update, program_counters)
    4.35 +            ready.addCallback(update, program_counters, start_time, count, generator)
    4.36              wait_list.append(ready)
    4.37              self.unfork_pc()
    4.38          return DeferredList(wait_list)
     5.1 --- a/viff/test/test_active_runtime.py	Fri Oct 09 16:32:12 2009 +0200
     5.2 +++ b/viff/test/test_active_runtime.py	Fri Oct 09 16:33:15 2009 +0200
     5.3 @@ -113,7 +113,7 @@
     5.4                  results.append(result)
     5.5              return gatherResults(results)
     5.6  
     5.7 -        count, triples = runtime.generate_triples(self.Zp)
     5.8 +        count, triples = runtime.generate_triples(self.Zp, 1)
     5.9          self.assertEquals(count, runtime.num_players - 2*runtime.threshold)
    5.10  
    5.11          runtime.schedule_callback(triples, check)