viff

changeset 1280:79c351c812f3

Moved Benchmark classes to their own file.
author Janus Dam Nielsen <janus.nielsen@alexandra.dk>
date Fri, 16 Oct 2009 13:59:19 +0200
parents 26f7a133172a
children bced13257ba4
files apps/benchmark.py apps/benchmark_classes.py
diffstat 2 files changed, 196 insertions(+), 166 deletions(-) [+]
line diff
     1.1 --- a/apps/benchmark.py	Fri Oct 09 16:33:15 2009 +0200
     1.2 +++ b/apps/benchmark.py	Fri Oct 16 13:59:19 2009 +0200
     1.3 @@ -58,16 +58,13 @@
     1.4  from math import log
     1.5  from optparse import OptionParser
     1.6  import operator
     1.7 -from pprint import pformat
     1.8  
     1.9  import viff.reactor
    1.10  viff.reactor.install()
    1.11  from twisted.internet import reactor
    1.12 -from twisted.internet.defer import Deferred
    1.13  
    1.14 -from viff.field import GF, GF256, FakeGF
    1.15 -from viff.runtime import Runtime, create_runtime, gather_shares, \
    1.16 -    make_runtime_class
    1.17 +from viff.field import GF, FakeGF
    1.18 +from viff.runtime import Runtime, create_runtime, make_runtime_class
    1.19  from viff.passive import PassiveRuntime
    1.20  from viff.active import BasicActiveRuntime, \
    1.21      TriplesHyperinvertibleMatricesMixin, TriplesPRSSMixin
    1.22 @@ -76,8 +73,10 @@
    1.23  from viff.paillier import PaillierRuntime
    1.24  from viff.orlandi import OrlandiRuntime
    1.25  from viff.config import load_config
    1.26 -from viff.util import find_prime, rand
    1.27 +from viff.util import find_prime
    1.28  
    1.29 +from benchmark_classes import SelfcontainedBenchmarkStrategy, \
    1.30 +    NeededDataBenchmarkStrategy, ParallelBenchmark, SequentialBenchmark
    1.31  
    1.32  # Hack in order to avoid Maximum recursion depth exceeded
    1.33  # exception;
    1.34 @@ -85,23 +84,6 @@
    1.35  
    1.36  
    1.37  last_timestamp = time.time()
    1.38 -start = 0
    1.39 -
    1.40 -
    1.41 -def record_start(what):
    1.42 -    global start
    1.43 -    start = time.time()
    1.44 -    print "*" * 64
    1.45 -    print "Started", what
    1.46 -
    1.47 -
    1.48 -def record_stop(x, what):
    1.49 -    stop = time.time()
    1.50 -    print
    1.51 -    print "Total time used: %.3f sec" % (stop-start)
    1.52 -    print "Time per %s operation: %.0f ms" % (what, 1000*(stop-start) / count)
    1.53 -    print "*" * 6
    1.54 -    return x
    1.55  
    1.56  operations = {"mul": (operator.mul,[]),
    1.57                "compToft05": (operator.ge, [ComparisonToft05Mixin]),
    1.58 @@ -177,153 +159,15 @@
    1.59  else:
    1.60      Field = GF
    1.61  
    1.62 +
    1.63  Zp = Field(find_prime(options.modulus))
    1.64  print "Using field elements (%d bit modulus)" % log(Zp.modulus, 2)
    1.65  
    1.66 +
    1.67  count = options.count
    1.68  print "I am player %d, will %s %d numbers" % (id, options.operation, count)
    1.69  
    1.70  
    1.71 -class BenchmarkStrategy:
    1.72 -
    1.73 -    def benchmark(self, *args):
    1.74 -        raise NotImplemented("Override this abstract method in subclasses")
    1.75 -
    1.76 -
    1.77 -class SelfcontainedBenchmarkStrategy(BenchmarkStrategy):
    1.78 -
    1.79 -    def benchmark(self, *args):
    1.80 -        sys.stdout.flush()
    1.81 -        sync = self.rt.synchronize()
    1.82 -        self.doTest(sync, lambda x: x)
    1.83 -        self.rt.schedule_callback(sync, self.preprocess)
    1.84 -        self.doTest(sync, lambda x: self.rt.shutdown())
    1.85 -
    1.86 -
    1.87 -class NeededDataBenchmarkStrategy(BenchmarkStrategy):
    1.88 -
    1.89 -    def benchmark(self, needed_data, pc, *args):
    1.90 -        self.pc = pc
    1.91 -        sys.stdout.flush()
    1.92 -        sync = self.rt.synchronize()
    1.93 -        self.rt.schedule_callback(sync, lambda x: needed_data)
    1.94 -        self.rt.schedule_callback(sync, self.preprocess)
    1.95 -        self.doTest(sync, lambda x: self.rt.shutdown())
    1.96 -
    1.97 -
    1.98 -# Defining the protocol as a class makes it easier to write the
    1.99 -# callbacks in the order they are called. This class is a base class
   1.100 -# that executes the protocol by calling the run_test method.
   1.101 -class Benchmark:
   1.102 -
   1.103 -    def __init__(self, rt, operation):
   1.104 -        self.rt = rt
   1.105 -        self.operation = operation
   1.106 -        self.pc = None
   1.107 -        
   1.108 -    def preprocess(self, needed_data):
   1.109 -        print "Preprocess", needed_data
   1.110 -        if needed_data:
   1.111 -            print "Starting preprocessing"
   1.112 -            record_start("preprocessing")
   1.113 -            preproc = self.rt.preprocess(needed_data)
   1.114 -            preproc.addCallback(record_stop, "preprocessing")
   1.115 -            return preproc
   1.116 -        else:
   1.117 -            print "Need no preprocessing"
   1.118 -            return None
   1.119 -
   1.120 -    def doTest(self, d, termination_function):
   1.121 -        self.rt.schedule_callback(d, self.begin)
   1.122 -        self.rt.schedule_callback(d, self.sync_test)
   1.123 -        self.rt.schedule_callback(d, self.run_test)
   1.124 -        self.rt.schedule_callback(d, self.sync_test)
   1.125 -        self.rt.schedule_callback(d, self.finished, termination_function)
   1.126 -        return d
   1.127 -
   1.128 -    def begin(self, _):
   1.129 -        print "begin", self.rt.program_counter
   1.130 -        print "Runtime ready, generating shares"
   1.131 -        self.a_shares = []
   1.132 -        self.b_shares = []
   1.133 -        for i in range(count):
   1.134 -            inputter = (i % len(self.rt.players)) + 1
   1.135 -            if inputter == self.rt.id:
   1.136 -                a = rand.randint(0, Zp.modulus)
   1.137 -                b = rand.randint(0, Zp.modulus)
   1.138 -            else:
   1.139 -                a, b = None, None
   1.140 -            self.a_shares.append(self.rt.input([inputter], Zp, a))
   1.141 -            self.b_shares.append(self.rt.input([inputter], Zp, b))
   1.142 -        shares_ready = gather_shares(self.a_shares + self.b_shares)
   1.143 -        return shares_ready
   1.144 -
   1.145 -    def sync_test(self, x):
   1.146 -        print "Synchronizing test start."
   1.147 -        sys.stdout.flush()
   1.148 -        sync = self.rt.synchronize()
   1.149 -        self.rt.schedule_callback(sync, lambda y: x)
   1.150 -        return sync
   1.151 -
   1.152 -    def run_test(self, _):
   1.153 -        raise NotImplemented("Override this abstract method in a sub class.")
   1.154 -
   1.155 -    def finished(self, needed_data, termination_function):
   1.156 -        sys.stdout.flush()
   1.157 -
   1.158 -        if self.rt._needed_data:
   1.159 -            print "Missing pre-processed data:"
   1.160 -            for (func, args), pcs in needed_data.iteritems():
   1.161 -                print "* %s%s:" % (func, args)
   1.162 -                print "  " + pformat(pcs).replace("\n", "\n  ")
   1.163 -
   1.164 -        return termination_function(needed_data)
   1.165 -
   1.166 -# This class implements a benchmark where run_test executes all
   1.167 -# operations in parallel.
   1.168 -class ParallelBenchmark(Benchmark):
   1.169 -
   1.170 -    def run_test(self, shares):
   1.171 -        print "rt", self.rt.program_counter, self.pc
   1.172 -        if self.pc != None:
   1.173 -            self.rt.program_counter = self.pc
   1.174 -        else:
   1.175 -            self.pc = list(self.rt.program_counter)
   1.176 -        c_shares = []
   1.177 -        record_start("parallel test")
   1.178 -        while self.a_shares and self.b_shares:
   1.179 -            a = self.a_shares.pop()
   1.180 -            b = self.b_shares.pop()
   1.181 -            c_shares.append(self.operation(a, b))
   1.182 -            print "."
   1.183 -
   1.184 -        done = gather_shares(c_shares)
   1.185 -        done.addCallback(record_stop, "parallel test")
   1.186 -        def f(x):
   1.187 -            needed_data = self.rt._needed_data
   1.188 -            self.rt._needed_data = {}
   1.189 -            return needed_data
   1.190 -        done.addCallback(f)
   1.191 -        return done
   1.192 -
   1.193 -
   1.194 -# A benchmark where the operations are executed one after each other.
   1.195 -class SequentialBenchmark(Benchmark):
   1.196 -
   1.197 -    def run_test(self, _, termination_function, d):
   1.198 -        record_start("sequential test")
   1.199 -        self.single_operation(None, termination_function)
   1.200 -
   1.201 -    def single_operation(self, _, termination_function):
   1.202 -        if self.a_shares and self.b_shares:
   1.203 -            a = self.a_shares.pop()
   1.204 -            b = self.b_shares.pop()
   1.205 -            c = self.operation(a, b)
   1.206 -            self.rt.schedule_callback(c, self.single_operation, termination_function)
   1.207 -        else:
   1.208 -            record_stop(None, "sequential test")
   1.209 -            self.finished(None, termination_function)
   1.210 -
   1.211  # Identify the base runtime class.
   1.212  base_runtime_class = runtimes[options.runtime]
   1.213  
   1.214 @@ -379,10 +223,10 @@
   1.215  
   1.216  pre_runtime.addCallback(update_args, options)
   1.217  
   1.218 -def do_benchmark(runtime, operation, benchmark, *args):
   1.219 -    benchmark(runtime, operation).benchmark(*args)
   1.220 +def do_benchmark(runtime, operation, benchmark, field, count, *args):
   1.221 +    benchmark(runtime, operation, field, count).benchmark(*args)
   1.222  
   1.223 -pre_runtime.addCallback(do_benchmark, operation, benchmark, needed_data, options.pc)
   1.224 +pre_runtime.addCallback(do_benchmark, operation, benchmark, Zp, count, needed_data, options.pc)
   1.225  
   1.226  print "#### Starting reactor ###"
   1.227  reactor.run()
     2.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     2.2 +++ b/apps/benchmark_classes.py	Fri Oct 16 13:59:19 2009 +0200
     2.3 @@ -0,0 +1,186 @@
     2.4 +# Copyright 2009 VIFF Development Team.
     2.5 +#
     2.6 +# This file is part of VIFF, the Virtual Ideal Functionality Framework.
     2.7 +#
     2.8 +# VIFF is free software: you can redistribute it and/or modify it
     2.9 +# under the terms of the GNU Lesser General Public License (LGPL) as
    2.10 +# published by the Free Software Foundation, either version 3 of the
    2.11 +# License, or (at your option) any later version.
    2.12 +#
    2.13 +# VIFF is distributed in the hope that it will be useful, but WITHOUT
    2.14 +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
    2.15 +# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
    2.16 +# Public License for more details.
    2.17 +#
    2.18 +# You should have received a copy of the GNU Lesser General Public
    2.19 +# License along with VIFF. If not, see <http://www.gnu.org/licenses/>.
    2.20 +
    2.21 +import sys
    2.22 +import time
    2.23 +
    2.24 +from pprint import pformat
    2.25 +
    2.26 +from viff.runtime import gather_shares
    2.27 +from viff.util import rand
    2.28 +
    2.29 +start = 0
    2.30 +
    2.31 +
    2.32 +def record_start(what):
    2.33 +    global start
    2.34 +    start = time.time()
    2.35 +    print "*" * 64
    2.36 +    print "Started", what
    2.37 +
    2.38 +
    2.39 +def record_stop(x, what, count):
    2.40 +    stop = time.time()
    2.41 +    print
    2.42 +    print "Total time used: %.3f sec" % (stop-start)
    2.43 +    print "Time per %s operation: %.0f ms" % (what, 1000*(stop-start) / count)
    2.44 +    print "*" * 6
    2.45 +    return x
    2.46 +
    2.47 +
    2.48 +# Defining the protocol as a class makes it easier to write the
    2.49 +# callbacks in the order they are called. This class is a base class
    2.50 +# that executes the protocol by calling the run_test method.
    2.51 +class Benchmark:
    2.52 +
    2.53 +    def __init__(self, rt, operation, field, count):
    2.54 +        self.rt = rt
    2.55 +        self.operation = operation
    2.56 +        self.pc = None
    2.57 +        self.field = field
    2.58 +        self.count = count
    2.59 +        
    2.60 +    def preprocess(self, needed_data):
    2.61 +        print "Preprocess", needed_data
    2.62 +        if needed_data:
    2.63 +            print "Starting preprocessing"
    2.64 +            record_start("preprocessing")
    2.65 +            preproc = self.rt.preprocess(needed_data)
    2.66 +            preproc.addCallback(record_stop, "preprocessing", self.count)
    2.67 +            return preproc
    2.68 +        else:
    2.69 +            print "Need no preprocessing"
    2.70 +            return None
    2.71 +
    2.72 +    def doTest(self, d, termination_function):
    2.73 +        self.rt.schedule_callback(d, self.begin)
    2.74 +        self.rt.schedule_callback(d, self.sync_test)
    2.75 +        self.rt.schedule_callback(d, self.run_test)
    2.76 +        self.rt.schedule_callback(d, self.sync_test)
    2.77 +        self.rt.schedule_callback(d, self.finished, termination_function)
    2.78 +        return d
    2.79 +
    2.80 +    def begin(self, _):
    2.81 +        print "begin", self.rt.program_counter
    2.82 +        print "Runtime ready, generating shares"
    2.83 +        self.a_shares = []
    2.84 +        self.b_shares = []
    2.85 +        for i in range(self.count):
    2.86 +            inputter = (i % len(self.rt.players)) + 1
    2.87 +            if inputter == self.rt.id:
    2.88 +                a = rand.randint(0, self.field.modulus)
    2.89 +                b = rand.randint(0, self.field.modulus)
    2.90 +            else:
    2.91 +                a, b = None, None
    2.92 +            self.a_shares.append(self.rt.input([inputter], self.field, a))
    2.93 +            self.b_shares.append(self.rt.input([inputter], self.field, b))
    2.94 +        shares_ready = gather_shares(self.a_shares + self.b_shares)
    2.95 +        return shares_ready
    2.96 +
    2.97 +    def sync_test(self, x):
    2.98 +        print "Synchronizing test start."
    2.99 +        sys.stdout.flush()
   2.100 +        sync = self.rt.synchronize()
   2.101 +        self.rt.schedule_callback(sync, lambda y: x)
   2.102 +        return sync
   2.103 +
   2.104 +    def run_test(self, _):
   2.105 +        raise NotImplemented("Override this abstract method in a sub class.")
   2.106 +
   2.107 +    def finished(self, needed_data, termination_function):
   2.108 +        sys.stdout.flush()
   2.109 +
   2.110 +        if self.rt._needed_data:
   2.111 +            print "Missing pre-processed data:"
   2.112 +            for (func, args), pcs in needed_data.iteritems():
   2.113 +                print "* %s%s:" % (func, args)
   2.114 +                print "  " + pformat(pcs).replace("\n", "\n  ")
   2.115 +
   2.116 +        return termination_function(needed_data)
   2.117 +
   2.118 +
   2.119 +# This class implements a benchmark where run_test executes all
   2.120 +# operations in parallel.
   2.121 +class ParallelBenchmark(Benchmark):
   2.122 +
   2.123 +    def run_test(self, shares):
   2.124 +        print "rt", self.rt.program_counter, self.pc
   2.125 +        if self.pc != None:
   2.126 +            self.rt.program_counter = self.pc
   2.127 +        else:
   2.128 +            self.pc = list(self.rt.program_counter)
   2.129 +        c_shares = []
   2.130 +        record_start("parallel test")
   2.131 +        while self.a_shares and self.b_shares:
   2.132 +            a = self.a_shares.pop()
   2.133 +            b = self.b_shares.pop()
   2.134 +            c_shares.append(self.operation(a, b))
   2.135 +            print "."
   2.136 +
   2.137 +        done = gather_shares(c_shares)
   2.138 +        done.addCallback(record_stop, "parallel test", self.count)
   2.139 +        def f(x):
   2.140 +            needed_data = self.rt._needed_data
   2.141 +            self.rt._needed_data = {}
   2.142 +            return needed_data
   2.143 +        done.addCallback(f)
   2.144 +        return done
   2.145 +
   2.146 +
   2.147 +# A benchmark where the operations are executed one after each other.
   2.148 +class SequentialBenchmark(Benchmark):
   2.149 +
   2.150 +    def run_test(self, _, termination_function, d):
   2.151 +        record_start("sequential test")
   2.152 +        self.single_operation(None, termination_function)
   2.153 +
   2.154 +    def single_operation(self, _, termination_function):
   2.155 +        if self.a_shares and self.b_shares:
   2.156 +            a = self.a_shares.pop()
   2.157 +            b = self.b_shares.pop()
   2.158 +            c = self.operation(a, b)
   2.159 +            self.rt.schedule_callback(c, self.single_operation, termination_function)
   2.160 +        else:
   2.161 +            record_stop(None, "sequential test", self.count)
   2.162 +            self.finished(None, termination_function)
   2.163 +
   2.164 +
   2.165 +class BenchmarkStrategy:
   2.166 +
   2.167 +    def benchmark(self, *args):
   2.168 +        raise NotImplemented("Override this abstract method in subclasses")
   2.169 +
   2.170 +
   2.171 +class SelfcontainedBenchmarkStrategy(BenchmarkStrategy):
   2.172 +
   2.173 +    def benchmark(self, *args):
   2.174 +        sys.stdout.flush()
   2.175 +        sync = self.rt.synchronize()
   2.176 +        self.doTest(sync, lambda x: x)
   2.177 +        self.rt.schedule_callback(sync, self.preprocess)
   2.178 +        self.doTest(sync, lambda x: self.rt.shutdown())
   2.179 +
   2.180 +
   2.181 +class NeededDataBenchmarkStrategy(BenchmarkStrategy):
   2.182 +
   2.183 +    def benchmark(self, needed_data, pc, *args):
   2.184 +        self.pc = pc
   2.185 +        sys.stdout.flush()
   2.186 +        sync = self.rt.synchronize()
   2.187 +        self.rt.schedule_callback(sync, lambda x: needed_data)
   2.188 +        self.rt.schedule_callback(sync, self.preprocess)
   2.189 +        self.doTest(sync, lambda x: self.rt.shutdown())