changeset 1277:26f7a133172a

Merged with Janus.
author Marcel Keller <mkeller@cs.au.dk>
date Fri, 09 Oct 2009 16:33:15 +0200
parents f9ab0f24979d bf991ca28f23
children 4b686925f035 79c351c812f3
files viff/runtime.py viff/test/test_basic_runtime.py
diffstat 5 files changed, 102 insertions(+), 46 deletions(-) [+]
line wrap: on
line diff
--- a/apps/benchmark.py	Fri Oct 09 16:32:12 2009 +0200
+++ b/apps/benchmark.py	Fri Oct 09 16:33:15 2009 +0200
@@ -74,9 +74,16 @@
 from viff.comparison import ComparisonToft05Mixin, ComparisonToft07Mixin
 from viff.equality import ProbabilisticEqualityMixin
 from viff.paillier import PaillierRuntime
+from viff.orlandi import OrlandiRuntime
 from viff.config import load_config
 from viff.util import find_prime, rand
 
+
+# Hack in order to avoid Maximum recursion depth exceeded
+# exception;
+sys.setrecursionlimit(5000)
+
+
 last_timestamp = time.time()
 start = 0
 
@@ -103,7 +110,8 @@
 
 runtimes = {"PassiveRuntime": PassiveRuntime,
             "PaillierRuntime": PaillierRuntime, 
-            "BasicActiveRuntime": BasicActiveRuntime}
+            "BasicActiveRuntime": BasicActiveRuntime,
+            "OrlandiRuntime": OrlandiRuntime}
 
 mixins = {"TriplesHyperinvertibleMatricesMixin" : TriplesHyperinvertibleMatricesMixin, 
           "TriplesPRSSMixin": TriplesPRSSMixin, 
@@ -138,10 +146,17 @@
                   help="skip local computations using fake field elements")
 parser.add_option("--args", type="string",
                   help="additional arguments to the runtime, the format is a comma separated list of id=value pairs e.g. --args s=1,d=0,lambda=1")
+parser.add_option("--needed_data", type="string",
+                  help="name of a file containing already computed dictionary of needed_data. Useful for skipping generating the needed data, which usually elliminates half the execution time. Format of file: \"{('random_triple', (Zp,)): [(3, 1), (3, 4)]}\"")
+parser.add_option("--pc", type="string",
+                  help="The program counter to start from when using explicitly provided needed_data. Format: [3,0]")
 
 parser.set_defaults(modulus=2**65, threshold=1, count=10,
-                    runtime=runtimes.keys()[0], mixins="", num_players=2, prss=True,
-                    operation=operations.keys()[0], parallel=True, fake=False, args="")
+                    runtime="PassiveRuntime", mixins="", num_players=2, prss=True,
+                    operation=operations.keys()[0], parallel=True, fake=False, 
+                    args="", needed_data="")
+
+print "*" * 60
 
 # Add standard VIFF options.
 Runtime.add_options(parser)
@@ -168,28 +183,44 @@
 count = options.count
 print "I am player %d, will %s %d numbers" % (id, options.operation, count)
 
+
+class BenchmarkStrategy:
+
+    def benchmark(self, *args):
+        raise NotImplemented("Override this abstract method in subclasses")
+
+
+class SelfcontainedBenchmarkStrategy(BenchmarkStrategy):
+
+    def benchmark(self, *args):
+        sys.stdout.flush()
+        sync = self.rt.synchronize()
+        self.doTest(sync, lambda x: x)
+        self.rt.schedule_callback(sync, self.preprocess)
+        self.doTest(sync, lambda x: self.rt.shutdown())
+
+
+class NeededDataBenchmarkStrategy(BenchmarkStrategy):
+
+    def benchmark(self, needed_data, pc, *args):
+        self.pc = pc
+        sys.stdout.flush()
+        sync = self.rt.synchronize()
+        self.rt.schedule_callback(sync, lambda x: needed_data)
+        self.rt.schedule_callback(sync, self.preprocess)
+        self.doTest(sync, lambda x: self.rt.shutdown())
+
+
 # Defining the protocol as a class makes it easier to write the
 # callbacks in the order they are called. This class is a base class
 # that executes the protocol by calling the run_test method.
 class Benchmark:
 
     def __init__(self, rt, operation):
-        print "init"
         self.rt = rt
         self.operation = operation
         self.pc = None
-        sys.stdout.flush()
-        sync = self.rt.synchronize()
-        self.doTest(sync, lambda x: x)
-        self.rt.schedule_callback(sync, self.preprocess)
-        self.doTest(sync, lambda x: self.rt.shutdown())
         
-#     def sync_preprocess(self):
-#         print "Synchronizing preprocessing"
-#         sys.stdout.flush()
-#         sync = self.rt.synchronize()
-#         self.rt.schedule_callback(sync, self.preprocess)
-
     def preprocess(self, needed_data):
         print "Preprocess", needed_data
         if needed_data:
@@ -203,10 +234,8 @@
             return None
 
     def doTest(self, d, termination_function):
-        print "doTest", self.rt.program_counter
         self.rt.schedule_callback(d, self.begin)
         self.rt.schedule_callback(d, self.sync_test)
-#         self.rt.schedule_callback(d, self.countdown, 3)
         self.rt.schedule_callback(d, self.run_test)
         self.rt.schedule_callback(d, self.sync_test)
         self.rt.schedule_callback(d, self.finished, termination_function)
@@ -236,16 +265,6 @@
         self.rt.schedule_callback(sync, lambda y: x)
         return sync
 
-#     def countdown(self, _, seconds):
-#         if seconds > 0:
-#             print "Starting test in %d" % seconds
-#             sys.stdout.flush()
-#             reactor.callLater(1, self.countdown, None, seconds - 1)
-#         else:
-#             print "Starting test now"
-#             sys.stdout.flush()
-#             self.run_test(None)
-
     def run_test(self, _):
         raise NotImplemented("Override this abstract method in a sub class.")
 
@@ -276,6 +295,7 @@
             a = self.a_shares.pop()
             b = self.b_shares.pop()
             c_shares.append(self.operation(a, b))
+            print "."
 
         done = gather_shares(c_shares)
         done.addCallback(record_stop, "parallel test")
@@ -330,6 +350,20 @@
 else:
     benchmark = SequentialBenchmark
 
+needed_data = ""
+if options.needed_data != "":
+    file = open(options.needed_data, 'r')
+    for l in file:
+        needed_data += l
+    needed_data = eval(needed_data)
+
+if options.needed_data != "" and options.pc != "":
+    bases = (benchmark,) + (NeededDataBenchmarkStrategy,) + (object,)
+    options.pc = eval(options.pc)
+else:
+    bases = (benchmark,) + (SelfcontainedBenchmarkStrategy,) + (object,)
+benchmark = type("ExtendedBenchmark", bases, {})
+
 pre_runtime = create_runtime(id, players, options.threshold,
                              options, runtime_class)
 
@@ -339,13 +373,16 @@
         for arg in options.args.split(','):
             id, value = arg.split('=')
             args[id] = long(value)
-        runtime.setArgs(args)
+        runtime.set_args(args)
     return runtime
 
 
 pre_runtime.addCallback(update_args, options)
 
-pre_runtime.addCallback(benchmark, operation)
+def do_benchmark(runtime, operation, benchmark, *args):
+    benchmark(runtime, operation).benchmark(*args)
+
+pre_runtime.addCallback(do_benchmark, operation, benchmark, needed_data, options.pc)
 
 print "#### Starting reactor ###"
 reactor.run()
--- a/viff/hash_broadcast.py	Fri Oct 09 16:32:12 2009 +0200
+++ b/viff/hash_broadcast.py	Fri Oct 09 16:33:15 2009 +0200
@@ -54,7 +54,7 @@
             signals[peer_id] = long(signal)
             # If all signals are received then check if they are OK or INCONSISTENTHASH.
             if num_receivers == len(signals.keys()):
-                s = reduce(lambda x, y: OK if OK == y else INCONSISTENTHASH, signals.values())
+                s = reduce(lambda x, y: (OK == y and OK) or INCONSISTENTHASH, signals.values())
                 if OK == s:
                     # Make the result ready.
                     result.callback(message[0])
@@ -69,7 +69,10 @@
                 signal = OK
                 # First we check if the hashes we received are equal to the hash we computed ourselves.
                 for peer_id in receivers:
-                    signal = signal if a_hashes[peer_id] == a_hashes[self.id] else INCONSISTENTHASH
+                    if a_hashes[peer_id] == a_hashes[self.id]:
+                        signal = signal
+                    else:
+                        signal = INCONSISTENTHASH
                 # Then we send the SAME signal to everybody. 
                 for peer_id in receivers:
                     self.protocols[peer_id].sendData(unique_pc, SIGNAL, str(signal))           
--- a/viff/orlandi.py	Fri Oct 09 16:32:12 2009 +0200
+++ b/viff/orlandi.py	Fri Oct 09 16:33:15 2009 +0200
@@ -17,7 +17,7 @@
 
 from twisted.internet.defer import Deferred, DeferredList, gatherResults
 
-from viff.runtime import Runtime, Share, ShareList, gather_shares
+from viff.runtime import Runtime, Share, ShareList, gather_shares, preprocess
 from viff.util import rand
 from viff.constants import TEXT, PAILLIER
 from viff.field import FieldElement
@@ -598,6 +598,7 @@
         c = OrlandiShare(self, field, field(value), (field(0), field(0)), Cc)
         return c
 
+    @preprocess("random_triple")
     def _get_triple(self, field):
         c, d = self.random_triple(field, 1)
         def f(ls):
@@ -1025,7 +1026,7 @@
 
         return result
 
-    def random_triple(self, field, number_of_requested_triples):
+    def random_triple(self, field, quantity=1):
         """Generate a list of triples ``(a, b, c)`` where ``c = a * b``.
 
         The triple ``(a, b, c)`` is secure in the Fcrs-hybrid model.
@@ -1035,14 +1036,14 @@
 
         M = []
 
-# print "Generating %i triples... relax, have a brak..." % ((1 + self.s_lambda) * (2 * self.d + 1) * number_of_requested_triples)
+# print "Generating %i triples... relax, have a break..." % ((1 + self.s_lambda) * (2 * self.d + 1) * quantity)
 
-        for x in xrange((1 + self.s_lambda) * (2 * self.d + 1) * number_of_requested_triples):
+        for x in xrange((1 + self.s_lambda) * (2 * self.d + 1) * quantity):
             M.append(self.triple_test(field))
 
         def step3(ls):
             """Coin-flip a subset test_set of M of size lambda(2d + 1)M."""
-            size = self.s_lambda * (2 * self.d + 1) * number_of_requested_triples
+            size = self.s_lambda * (2 * self.d + 1) * quantity
             inx = 0
             p_half = field.modulus // 2
             def coin_flip(v, ls, test_set):
@@ -1250,18 +1251,18 @@
             return dls_all
 
         def step6(M_without_test_set):
-            """Partition M without test_set in number_of_requested_triples
+            """Partition M without test_set in quantity
             random subsets M_i of size (2d + 1).
             """
             subsets = []
             size = 2 * self.d + 1
-            for x in xrange(number_of_requested_triples):
+            for x in xrange(quantity):
                 subsets.append([])
 
             def put_in_set(v, M_without_test_set, subsets):
                 if 0 == len(M_without_test_set):
                     return subsets
-                v = v.value % number_of_requested_triples
+                v = v.value % quantity
                 if size > len(subsets[v]):
                     subsets[v].append(M_without_test_set.pop(0))
                 r = self.random_share(field)
@@ -1311,12 +1312,17 @@
         self.activate_reactor()
 
         s = Share(self, field)
-        def f(ls, s):
-            s.callback(ls)
-        result.addCallbacks(f, self.error_handler, callbackArgs=(s,))
-        return number_of_requested_triples, s
+        # We add the result to the chains in result.
+        result.chainDeferred(s)
+
+        return quantity, s
 
     def error_handler(self, ex):
         print "Error: ", ex
         return ex
 
+    def set_args(self, args):
+        """args is a dictionary."""
+        self.s = args['s']
+        self.d = args['d']
+        self.s_lambda = args['lambda']
--- a/viff/runtime.py	Fri Oct 09 16:32:12 2009 +0200
+++ b/viff/runtime.py	Fri Oct 09 16:33:15 2009 +0200
@@ -806,7 +806,14 @@
         example of a method fulfilling this interface.
         """
 
-        def update(results, program_counters):
+        def update(results, program_counters, start_time, count, what):
+            stop = time.time()
+
+            print
+            print "Total time used: %.3f sec" % (stop - start_time)
+            print "Time per %s operation: %.0f ms" % (what, 1000*(stop - start_time) / count)
+            print "*" * 6
+
             # We concatenate the sub-lists in results.
             results = sum(results, [])
 
@@ -831,7 +838,10 @@
             func = getattr(self, generator)
             results = []
             items = 0
+            count = 0
+            start_time = time.time()
             while items < len(program_counters):
+                count += 1
                 self.increment_pc()
                 self.fork_pc()
                 item_count, result = func(quantity=len(program_counters) - items, *args)
@@ -839,7 +849,7 @@
                 results.append(result)
                 self.unfork_pc()
             ready = gatherResults(results)
-            ready.addCallback(update, program_counters)
+            ready.addCallback(update, program_counters, start_time, count, generator)
             wait_list.append(ready)
             self.unfork_pc()
         return DeferredList(wait_list)
--- a/viff/test/test_active_runtime.py	Fri Oct 09 16:32:12 2009 +0200
+++ b/viff/test/test_active_runtime.py	Fri Oct 09 16:33:15 2009 +0200
@@ -113,7 +113,7 @@
                 results.append(result)
             return gatherResults(results)
 
-        count, triples = runtime.generate_triples(self.Zp)
+        count, triples = runtime.generate_triples(self.Zp, 1)
         self.assertEquals(count, runtime.num_players - 2*runtime.threshold)
 
         runtime.schedule_callback(triples, check)