viff

view viff/runtime.py @ 1429:b2b8e4a74cd6

runtime: print statistics with --statistics option.
author Martin Geisler <mg@cs.au.dk>
date Sun Jan 24 17:29:26 2010 +0100 (2 years ago)
parents 2324d01c74e2
children 1772506977cc
line source
1 # -*- coding: utf-8 -*-
2 #
3 # Copyright 2007, 2008, 2009 VIFF Development Team.
4 #
5 # This file is part of VIFF, the Virtual Ideal Functionality Framework.
6 #
7 # VIFF is free software: you can redistribute it and/or modify it
8 # under the terms of the GNU Lesser General Public License (LGPL) as
9 # published by the Free Software Foundation, either version 3 of the
10 # License, or (at your option) any later version.
11 #
12 # VIFF is distributed in the hope that it will be useful, but WITHOUT
13 # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
14 # or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
15 # Public License for more details.
16 #
17 # You should have received a copy of the GNU Lesser General Public
18 # License along with VIFF. If not, see <http://www.gnu.org/licenses/>.
20 """VIFF runtime. This is where the virtual ideal functionality is
21 hiding! The runtime is responsible for sharing inputs, handling
22 communication, and running the calculations.
24 Each player participating in the protocol will instantiate a
25 :class:`Runtime` object and use it for the calculations.
27 The Runtime returns :class:`Share` objects for most operations, and
28 these can be added, subtracted, and multiplied as normal thanks to
29 overloaded arithmetic operators. The runtime will take care of
30 scheduling things correctly behind the scenes.
31 """
32 from __future__ import division
34 import time
35 import struct
36 from optparse import OptionParser, OptionGroup
37 from collections import deque
38 import os
39 import sys
41 from viff.field import GF256, FieldElement
42 from viff.util import wrapper, rand, track_memory_usage, begin, end
43 from viff.constants import SHARE
44 import viff.reactor
46 from twisted.internet import reactor
47 from twisted.internet.task import LoopingCall
48 from twisted.internet.error import ConnectionDone, CannotListenError
49 from twisted.internet.defer import Deferred, DeferredList, gatherResults
50 from twisted.internet.defer import maybeDeferred
51 from twisted.internet.protocol import ReconnectingClientFactory, ServerFactory
52 from twisted.protocols.basic import Int16StringReceiver
class Share(Deferred):
    """A shared number.

    The :class:`Runtime` operates on shares, represented by this class.
    Shares are asynchronous in the sense that they promise to attain a
    value at some point in the future.

    Shares overload the arithmetic operations so that ``x = a + b``
    will create a new share *x*, which will eventually contain the
    sum of *a* and *b*. Each share is associated with a
    :class:`Runtime` and the arithmetic and comparison operators
    simply call back to that runtime.
    """

    def __init__(self, runtime, field, value=None):
        """Initialize a share.

        The *runtime* is the object all overloaded operators delegate
        to and *field* must be a callable (it is stored unchanged in
        :attr:`field`).

        If an initial value is given, it will be passed to
        :meth:`callback` right away.
        """
        assert field is not None, "Cannot construct share without a field."
        assert callable(field), "The field is not callable, wrong argument?"

        Deferred.__init__(self)
        self.runtime = runtime
        self.field = field
        if value is not None:
            self.callback(value)

    # When the VIFF_PROFILE environment variable is set, the
    # constructor is wrapped so that the life time of every share is
    # reported through viff.util.begin/end, labelled with the program
    # counter at creation time.
    if os.environ.get("VIFF_PROFILE"):
        old_init = __init__

        def __init__(self, *a, **kw):
            self.old_init(*a, **kw)
            # Copy the list: the runtime mutates its counter in place.
            self.pc = self.runtime.program_counter[:]
            begin(None, self.label())

        def __del__(self):
            end(None, self.label())

        def label(self):
            """Return a label built from the object id and the saved PC."""
            return "share " + hex(id(self)) + " " + \
                ".".join(map(str, self.pc))

    def __add__(self, other):
        """Addition."""
        return self.runtime.add(self, other)

    def __radd__(self, other):
        """Addition (reflected argument version)."""
        return self.runtime.add(self, other)

    def __sub__(self, other):
        """Subtraction."""
        return self.runtime.sub(self, other)

    def __rsub__(self, other):
        """Subtraction (reflected argument version)."""
        # The operands are swapped: this computes ``other - self``.
        return self.runtime.sub(other, self)

    def __mul__(self, other):
        """Multiplication."""
        return self.runtime.mul(self, other)

    def __rmul__(self, other):
        """Multiplication (reflected argument version)."""
        return self.runtime.mul(self, other)

    def __pow__(self, exponent):
        """Exponentation to known integer exponents."""
        return self.runtime.pow(self, exponent)

    def __xor__(self, other):
        """Exclusive-or."""
        return self.runtime.xor(self, other)

    def __rxor__(self, other):
        """Exclusive-or (reflected argument version)."""
        return self.runtime.xor(self, other)

    def __lt__(self, other):
        """Strictly less-than comparison."""
        # self < other <=> not (self >= other)
        return 1 - self.runtime.greater_than_equal(self, other)

    def __le__(self, other):
        """Less-than or equal comparison."""
        # self <= other <=> other >= self
        return self.runtime.greater_than_equal(other, self)

    def __gt__(self, other):
        """Strictly greater-than comparison."""
        # self > other <=> not (other >= self)
        return 1 - self.runtime.greater_than_equal(other, self)

    def __ge__(self, other):
        """Greater-than or equal comparison."""
        # self >= other
        return self.runtime.greater_than_equal(self, other)

    def __eq__(self, other):
        """Equality testing."""
        return self.runtime.equal(self, other)

    def __ne__(self, other):
        """Negated equality testing."""
        # Python never derives ``!=`` from ``__eq__`` for us here, so
        # the operator must be defined under its real special name.
        return 1 - self.runtime.equal(self, other)

    # Historical (misspelled) name, kept so explicit calls keep working.
    __neq__ = __ne__

    def clone(self):
        """Clone a share.

        Works like :meth:`util.clone_deferred` except that it returns a new
        :class:`Share` instead of a :class:`Deferred`.
        """

        def split_result(result):
            # Feed the clone and pass the value on unchanged so that
            # later callbacks on the original share are unaffected.
            clone.callback(result)
            return result
        clone = Share(self.runtime, self.field)
        self.addCallback(split_result)
        return clone
class ShareList(Share):
    """A share which fires once enough other shares have fired.

    This is loosely patterned on the Twisted :class:`DeferredList`,
    but the object is itself a :class:`Share` and it may be asked to
    trigger as soon as a given threshold of the watched shares are
    ready. Below, :meth:`pprint` runs as soon as *a* and *c* have
    received their values:

    >>> from pprint import pprint
    >>> from viff.field import GF256
    >>> a = Share(None, GF256)
    >>> b = Share(None, GF256)
    >>> c = Share(None, GF256)
    >>> shares = ShareList([a, b, c], threshold=2)
    >>> shares.addCallback(pprint)           # doctest: +ELLIPSIS
    <ShareList at 0x...>
    >>> a.callback(10)
    >>> c.callback(20)
    [(True, 10), None, (True, 20)]

    The callback receives a list with one slot per watched share. A
    filled slot is a pair: a boolean telling whether the callback
    (:const:`True`) or the errback (:const:`False`) fired on that
    share, followed by the value that was delivered.

    With a threshold below the number of shares, slots belonging to
    shares that have not fired yet still hold :const:`None`, as is
    the case for *b* in the example.
    """

    def __init__(self, shares, threshold=None):
        """Initialize a share list.

        *shares* must be non-empty. When *threshold* is given it must
        satisfy ``0 < threshold <= len(shares)``; when it is omitted
        every share has to fire before the list does.
        """
        total = len(shares)
        assert total > 0, "Cannot create empty ShareList"
        assert threshold is None or 0 < threshold <= total, \
            "Threshold out of range"

        # Runtime and field are borrowed from the first share.
        first = shares[0]
        Share.__init__(self, first.runtime, first.field)

        self.results = [None] * total
        needed = threshold
        if needed is None:
            needed = total
        self.missing_shares = needed

        # The same handler serves as callback and errback; the extra
        # arguments record the slot and which of the two was used.
        for position in range(total):
            shares[position].addCallbacks(self._callback_fired,
                                          self._callback_fired,
                                          callbackArgs=(position, True),
                                          errbackArgs=(position, False))

    def _callback_fired(self, result, index, success):
        """Store one outcome and fire when the last needed one is in."""
        self.results[index] = (success, result)
        self.missing_shares -= 1
        if self.missing_shares == 0 and not self.called:
            self.callback(self.results)
        # Hand the value on so the watched share keeps its result.
        return result
def gather_shares(shares):
    """Collect the values of several shares in one share.

    Patterned on the Twisted :meth:`gatherResults` function. Given a
    list of shares, a new :class:`Share` is returned which triggers
    with the list of values of the given shares:

    >>> from pprint import pprint
    >>> from viff.field import GF256
    >>> a = Share(None, GF256)
    >>> b = Share(None, GF256)
    >>> shares = gather_shares([a, b])
    >>> shares.addCallback(pprint)           # doctest: +ELLIPSIS
    <ShareList at 0x...>
    >>> a.callback(10)
    >>> b.callback(20)
    [10, 20]
    """

    def drop_flags(pairs):
        # Each entry is a (success, value) pair; keep the value only.
        values = []
        for (_, value) in pairs:
            values.append(value)
        return values

    collected = ShareList(shares)
    collected.addCallback(drop_flags)
    return collected
270 class ShareExchanger(Int16StringReceiver):
271 """Send and receive shares.
273 All players are connected by pair-wise connections and this
274 Twisted protocol is one such connection. It is used to send and
275 receive shares from one other player.
276 """
278 def __init__(self):
279 self.peer_id = None
280 self.lost_connection = Deferred()
281 #: Data expected to be received in the future.
282 self.incoming_data = {}
283 self.waiting_deferreds = {}
284 #: Statistics
285 self.sent_packets = 0
286 self.sent_bytes = 0
288 def connectionMade(self):
289 self.sendString(str(self.factory.runtime.id))
291 def connectionLost(self, reason):
292 reason.trap(ConnectionDone)
293 self.lost_connection.callback(self)
295 def stringReceived(self, string):
296 """Called when a share is received.
298 The string received is unpacked into the program counter, and
299 a data part. The data is passed the appropriate Deferred in
300 :class:`self.incoming_data`.
301 """
302 if self.peer_id is None:
303 # TODO: Handle ValueError if the string cannot be decoded.
304 self.peer_id = int(string)
305 try:
306 cert = self.transport.getPeerCertificate()
307 except AttributeError:
308 cert = None
309 if cert:
310 # The player ID are stored in the serial number of the
311 # certificate -- this makes it easy to check that the
312 # player is who he claims to be.
313 if cert.get_serial_number() != self.peer_id:
314 print "Peer %s claims to be %d, aborting!" \
315 % (cert.get_subject(), self.peer_id)
316 self.transport.loseConnection()
317 self.factory.identify_peer(self)
318 else:
319 try:
320 pc_size, data_size, data_type = struct.unpack("!HHB", string[:5])
321 fmt = "!%dI%ds" % (pc_size, data_size)
322 unpacked = struct.unpack(fmt, string[5:])
324 program_counter = unpacked[:pc_size]
325 data = unpacked[-1]
327 key = (program_counter, data_type)
329 if key in self.waiting_deferreds:
330 deq = self.waiting_deferreds[key]
331 deferred = deq.popleft()
332 if not deq:
333 del self.waiting_deferreds[key]
334 self.factory.runtime.handle_deferred_data(deferred, data)
335 else:
336 deq = self.incoming_data.setdefault(key, deque())
337 deq.append(data)
338 except struct.error, e:
339 self.factory.runtime.abort(self, e)
341 def sendData(self, program_counter, data_type, data):
342 """Send data to the peer.
344 The *program_counter* is a tuple of unsigned integers, the
345 *data_type* is an unsigned byte and *data* is a string.
347 The data is encoded as follows::
349 +---------+-----------+-----------+--------+--------------+
350 | pc_size | data_size | data_type | pc | data |
351 +---------+-----------+-----------+--------+--------------+
352 2 bytes 2 bytes 1 byte varies varies
354 The program counter takes up ``4 * pc_size`` bytes, the data
355 takes up ``data_size`` bytes.
356 """
357 pc_size = len(program_counter)
358 data_size = len(data)
359 fmt = "!HHB%dI%ds" % (pc_size, data_size)
360 t = (pc_size, data_size, data_type) + program_counter + (data,)
361 packet = struct.pack(fmt, *t)
362 self.sendString(packet)
363 self.sent_packets += 1
364 self.sent_bytes += len(packet)
366 def sendShare(self, program_counter, share):
367 """Send a share.
369 The program counter and the share are converted to bytes and
370 sent to the peer.
371 """
372 self.sendData(program_counter, SHARE, hex(share.value))
374 def loseConnection(self):
375 """Disconnect this protocol instance."""
376 self.transport.loseConnection()
378 class SelfShareExchanger(ShareExchanger):
380 def __init__(self, id, factory):
381 ShareExchanger.__init__(self)
382 self.peer_id = id
383 self.factory = factory
385 def stringReceived(self, program_counter, data_type, data):
386 """Called when a share is received.
388 The string received is unpacked into the program counter, and
389 a data part. The data is passed the appropriate Deferred in
390 :class:`self.incoming_data`.
391 """
392 try:
393 key = (program_counter, data_type)
395 if key in self.waiting_deferreds:
396 deq = self.waiting_deferreds[key]
397 deferred = deq.popleft()
398 if not deq:
399 del self.waiting_deferreds[key]
400 self.factory.runtime.handle_deferred_data(deferred, data)
401 else:
402 deq = self.incoming_data.setdefault(key, deque())
403 deq.append(data)
404 except struct.error, e:
405 self.factory.runtime.abort(self, e)
407 def sendData(self, program_counter, data_type, data):
408 """Send data to the self.id."""
409 self.stringReceived(program_counter, data_type, data)
411 def loseConnection(self):
412 """Disconnect this protocol instance."""
413 self.lost_connection.callback(self)
414 return None
class SelfShareExchangerFactory(ReconnectingClientFactory, ServerFactory):
    """Factory for creating SelfShareExchanger protocols."""

    protocol = SelfShareExchanger
    # Reconnection tuning; the factor is about half of the Twisted
    # default.
    factor = 1.234567
    maxDelay = 3

    def __init__(self, runtime):
        """Remember the *runtime* this factory belongs to."""
        self.runtime = runtime

    def identify_peer(self, protocol):
        """Always fails; not expected to be called on this factory."""
        raise Exception("Is identify_peer necessary?")

    def clientConnectionLost(self, connector, reason):
        """Accept a clean disconnect; ``trap`` re-raises anything else."""
        reason.trap(ConnectionDone)
class FakeTransport(object):
    """Stand-in transport object with a no-op :meth:`close`."""

    def close(self):
        """Do nothing and report success."""
        closed = True
        return closed
class ShareExchangerFactory(ReconnectingClientFactory, ServerFactory):
    """Factory for creating ShareExchanger protocols."""

    protocol = ShareExchanger
    # Reconnection tuning; the factor is about half of the Twisted
    # default.
    factor = 1.234567
    maxDelay = 3

    def __init__(self, runtime, players, protocols_ready):
        """Initialize the factory.

        *protocols_ready* is a Deferred which is fired with *runtime*
        once every player except one (``len(players) - 1``) has been
        identified.
        """
        self.runtime = runtime
        self.players = players
        self.protocols_ready = protocols_ready
        self.needed_protocols = len(players) - 1

    def identify_peer(self, protocol):
        """Register an identified *protocol* with the runtime."""
        player = self.players[protocol.peer_id]
        self.runtime.add_player(player, protocol)

        remaining = self.needed_protocols - 1
        self.needed_protocols = remaining
        if remaining == 0:
            # That was the last connection we were waiting for.
            self.protocols_ready.callback(self.runtime)

    def clientConnectionLost(self, connector, reason):
        """Accept a clean disconnect; ``trap`` re-raises anything else."""
        reason.trap(ConnectionDone)
def preprocess(generator):
    """Track calls to this method.

    The decorated method is replaced by a proxy. The proxy first
    looks in :attr:`Runtime._pool` for data stored under the current
    program counter and returns it when found. Otherwise the call is
    recorded in :attr:`Runtime._needed_data` and the original method
    is executed. In both cases a pair is returned whose second
    component tells whether the data came from the pool.

    The *generator* is only recorded as the place the data should be
    generated from; it is never called here. It must be the name of
    the method (a string) and not the method itself.
    """

    def preprocess_decorator(method):

        @wrapper(method)
        def preprocess_wrapper(self, *args, **kwargs):
            self.increment_pc()
            pc = tuple(self.program_counter)

            if pc in self._pool:
                return self._pool.pop(pc), True

            # Nothing was prepared for this program counter: note
            # what would have been needed and compute it directly.
            needed = self._needed_data.setdefault((generator, args), [])
            needed.append(pc)
            self.fork_pc()
            try:
                outcome = method(self, *args, **kwargs)
            finally:
                self.unfork_pc()
            return outcome, False

        return preprocess_wrapper
    return preprocess_decorator
499 class Runtime:
500 """Basic VIFF runtime with no crypto.
502 This runtime contains only the most basic operations needed such
503 as the program counter, the list of other players, etc.
504 """
506 @staticmethod
507 def add_options(parser):
508 group = OptionGroup(parser, "VIFF Runtime Options")
509 parser.add_option_group(group)
511 group.add_option("-l", "--bit-length", type="int", metavar="L",
512 help=("Maximum bit length of input numbers for "
513 "comparisons."))
514 group.add_option("-k", "--security-parameter", type="int", metavar="K",
515 help=("Security parameter. Comparisons will leak "
516 "information with probability 2**-K."))
517 group.add_option("--no-ssl", action="store_false", dest="ssl",
518 help="Disable the use of secure SSL connections.")
519 group.add_option("--ssl", action="store_true",
520 help=("Enable the use of secure SSL connections "
521 "(if the OpenSSL bindings are available)."))
522 group.add_option("--deferred-debug", action="store_true",
523 help="Enable extra debug output for deferreds.")
524 group.add_option("--profile", action="store_true",
525 help="Collect and print profiling information.")
526 group.add_option("--track-memory", action="store_true",
527 help="Track memory usage over time.")
528 group.add_option("--statistics", action="store_true",
529 help="Print statistics on shutdown.")
530 group.add_option("--no-socket-retry", action="store_true",
531 default=False, help="Fail rather than keep retrying "
532 "to connect if port is already in use.")
534 try:
535 # Using __import__ since we do not use the module, we are
536 # only interested in the side-effect.
537 __import__('OpenSSL')
538 have_openssl = True
539 except ImportError:
540 have_openssl = False
542 parser.set_defaults(bit_length=32,
543 security_parameter=30,
544 ssl=have_openssl,
545 deferred_debug=False,
546 profile=False,
547 track_memory=False,
548 statistics=False)
550 def __init__(self, player, threshold, options=None):
551 """Initialize runtime.
553 Initialized a runtime owned by the given, the threshold, and
554 optionally a set of options. The runtime has no network
555 connections and knows of no other players -- the
556 :func:`create_runtime` function should be used instead to
557 create a usable runtime.
558 """
559 assert threshold > 0, "Must use a positive threshold."
560 #: ID of this player.
561 self.id = player.id
562 #: Shamir secret sharing threshold.
563 self.threshold = threshold
565 if options is None:
566 parser = OptionParser()
567 self.add_options(parser)
568 self.options = parser.get_default_values()
569 else:
570 self.options = options
572 if self.options.deferred_debug:
573 from twisted.internet import defer
574 defer.setDebugging(True)
576 #: Pool of preprocessed data.
577 self._pool = {}
578 #: Description of needed preprocessed data.
579 self._needed_data = {}
581 #: Current program counter.
582 self.program_counter = [0]
584 #: Connections to the other players.
585 #:
586 #: Mapping from from Player ID to :class:`ShareExchanger`
587 #: objects.
588 self.protocols = {}
590 #: Number of known players.
591 #:
592 #: Equal to ``len(self.players)``, but storing it here is more
593 #: direct.
594 self.num_players = 0
596 #: Information on players.
597 #:
598 #: Mapping from Player ID to :class:`Player` objects.
599 self.players = {}
600 # Add ourselves, but with no protocol since we wont be
601 # communicating with ourselves.
602 protocol = SelfShareExchanger(self.id, SelfShareExchangerFactory(self))
603 protocol.transport = FakeTransport()
604 self.add_player(player, protocol)
606 #: Queue of deferreds and data.
607 self.deferred_queue = deque()
608 self.complex_deferred_queue = deque()
609 #: Counter for calls of activate_reactor().
610 self.activation_counter = 0
611 #: Record the recursion depth.
612 self.depth_counter = 0
613 self.max_depth = 0
614 #: Recursion depth limit by experiment, including security margin.
615 self.depth_limit = int(sys.getrecursionlimit() / 50)
616 #: Use deferred queues only if the ViffReactor is running.
617 self.using_viff_reactor = isinstance(reactor, viff.reactor.ViffReactor)
619 def add_player(self, player, protocol):
620 self.players[player.id] = player
621 self.num_players = len(self.players)
622 self.protocols[player.id] = protocol
624 def shutdown(self):
625 """Shutdown the runtime.
627 All connections are closed and the runtime cannot be used
628 again after this has been called.
629 """
630 print "Synchronizing shutdown...",
632 def close_connections(_):
633 print "done."
634 print "Closing connections...",
635 results = [maybeDeferred(self.port.stopListening)]
636 for protocol in self.protocols.itervalues():
637 results.append(protocol.lost_connection)
638 protocol.loseConnection()
639 return DeferredList(results)
641 def stop_reactor(_):
642 print "done."
643 print "Stopping reactor...",
644 reactor.stop()
645 print "done."
647 sync = self.synchronize()
648 sync.addCallback(close_connections)
649 sync.addCallback(stop_reactor)
650 return sync
652 def abort(self, protocol, exc):
653 """Abort the execution due to an exception.
655 The *protocol* received bad data which resulted in *exc* being
656 raised when unpacking.
657 """
658 print "*** bad data from Player %d: %s" % (protocol.peer_id, exc)
659 print "*** aborting!"
660 for p in self.protocols.itervalues():
661 p.loseConnection()
662 reactor.stop()
663 print "*** all protocols disconnected"
665 def wait_for(self, *vars):
666 """Make the runtime wait for the variables given.
668 The runtime is shut down when all variables are calculated.
669 """
670 dl = DeferredList(vars)
671 self.schedule_callback(dl, lambda _: self.shutdown())
673 def increment_pc(self):
674 """Increment the program counter."""
675 self.program_counter[-1] += 1
677 def fork_pc(self):
678 """Fork the program counter."""
679 self.program_counter.append(0)
681 def unfork_pc(self):
682 """Leave a fork of the program counter."""
683 self.program_counter.pop()
685 def schedule_callback(self, deferred, func, *args, **kwargs):
686 """Schedule a callback on a deferred with the correct program
687 counter.
689 If a callback depends on the current program counter, then use
690 this method to schedule it instead of simply calling
691 addCallback directly. Simple callbacks that are independent of
692 the program counter can still be added directly to the
693 Deferred as usual.
695 Any extra arguments are passed to the callback as with
696 :meth:`addCallback`.
697 """
698 self.increment_pc()
699 saved_pc = self.program_counter[:]
701 @wrapper(func)
702 def callback_wrapper(*args, **kwargs):
703 """Wrapper for a callback which ensures a correct PC."""
704 try:
705 current_pc = self.program_counter[:]
706 self.program_counter[:] = saved_pc
707 self.fork_pc()
708 return func(*args, **kwargs)
709 finally:
710 self.program_counter[:] = current_pc
712 return deferred.addCallback(callback_wrapper, *args, **kwargs)
714 def schedule_complex_callback(self, deferred, func, *args, **kwargs):
715 """Schedule a complex callback, i.e. a callback which blocks a
716 long time.
718 Consider that the deferred is forked, i.e. if the callback returns
719 something to be used afterwards, add further callbacks to the returned
720 deferred."""
722 if not self.using_viff_reactor:
723 return self.schedule_callback(deferred, func, *args, **kwargs)
725 if isinstance(deferred, Share):
726 fork = Share(deferred.runtime, deferred.field)
727 else:
728 fork = Deferred()
730 def queue_callback(result, runtime, fork):
731 runtime.complex_deferred_queue.append((fork, result))
733 deferred.addCallback(queue_callback, self, fork)
734 return self.schedule_callback(fork, func, *args, **kwargs)
736 def synchronize(self):
737 """Introduce a synchronization point.
739 Returns a :class:`Deferred` which will trigger if and when all
740 other players have made their calls to :meth:`synchronize`. By
741 adding callbacks to the returned :class:`Deferred`, one can
742 divide a protocol execution into disjoint phases.
743 """
744 self.increment_pc()
745 shares = [self._exchange_shares(player, GF256(0))
746 for player in self.players]
747 result = gather_shares(shares)
748 result.addCallback(lambda _: None)
749 return result
751 def _expect_data(self, peer_id, data_type, deferred):
752 # Convert self.program_counter to a hashable value in order to
753 # use it as a key in self.protocols[peer_id].incoming_data.
754 pc = tuple(self.program_counter)
755 return self._expect_data_with_pc(pc, peer_id, data_type, deferred)
757 def _expect_data_with_pc(self, pc, peer_id, data_type, deferred):
758 key = (pc, data_type)
760 if key in self.protocols[peer_id].incoming_data:
761 # We have already received some data from the other side.
762 deq = self.protocols[peer_id].incoming_data[key]
763 data = deq.popleft()
764 if not deq:
765 del self.protocols[peer_id].incoming_data[key]
766 deferred.callback(data)
767 else:
768 # We have not yet received anything from the other side.
769 deq = self.protocols[peer_id].waiting_deferreds.setdefault(key, deque())
770 deq.append(deferred)
772 def _exchange_shares(self, peer_id, field_element):
773 """Exchange shares with another player.
775 We send the player our share and record a Deferred which will
776 trigger when the share from the other side arrives.
777 """
778 assert isinstance(field_element, FieldElement)
780 if peer_id == self.id:
781 return Share(self, field_element.field, field_element)
782 else:
783 share = self._expect_share(peer_id, field_element.field)
784 pc = tuple(self.program_counter)
785 self.protocols[peer_id].sendShare(pc, field_element)
786 return share
788 def _expect_share(self, peer_id, field):
789 share = Share(self, field)
790 share.addCallback(lambda value: field(long(value, 16)))
791 self._expect_data(peer_id, SHARE, share)
792 return share
794 def preprocess(self, program):
795 """Generate preprocess material.
797 The *program* specifies which methods to call and with which
798 arguments. The generator methods called must adhere to the
799 following interface:
801 - They must return a list of :class:`Deferred` instances.
803 - Every Deferred must yield an item of pre-processed data.
804 This can be value, a list or tuple of values, or a Deferred
805 (which will be converted to a value by Twisted), but NOT a
806 list of Deferreds. Use :meth:`gatherResults` to avoid the
807 latter.
809 The :meth:`~viff.active.TriplesPRSSMixin.generate_triples` method
810 is an example of a method fulfilling this interface.
811 """
813 def update(results, program_counters):
814 # Update the pool with pairs of program counter and data.
815 self._pool.update(zip(program_counters, results))
817 wait_list = []
818 for ((generator, args), program_counters) in program.iteritems():
819 print "Preprocessing %s (%d items)" % (generator, len(program_counters))
820 self.increment_pc()
821 self.fork_pc()
822 func = getattr(self, generator)
823 count = 0
824 start_time = time.time()
826 while program_counters:
827 count += 1
828 self.increment_pc()
829 self.fork_pc()
830 results = func(quantity=len(program_counters), *args)
831 self.unfork_pc()
832 ready = gatherResults(results)
833 ready.addCallback(update, program_counters[:len(results)])
834 del program_counters[:len(results)]
835 wait_list.append(ready)
836 self.unfork_pc()
837 return gatherResults(wait_list)
839 def input(self, inputters, field, number=None):
840 """Input *number* to the computation.
842 The players listed in *inputters* must provide an input
843 number, everybody will receive a list with :class:`Share`
844 objects, one from each *inputter*. If only a single player is
845 listed in *inputters*, then a :class:`Share` is given back
846 directly.
847 """
848 raise NotImplementedError
850 def output(self, share, receivers=None):
851 """Open *share* to *receivers* (defaults to all players).
853 Returns a :class:`Share` to players with IDs in *receivers*
854 and :const:`None` to the remaining players.
855 """
856 raise NotImplementedError
858 def add(self, share_a, share_b):
859 """Secure addition.
861 At least one of the arguments must be a :class:`Share`, the
862 other can be a :class:`~viff.field.FieldElement` or a
863 (possible long) Python integer."""
864 raise NotImplementedError
866 def mul(self, share_a, share_b):
867 """Secure multiplication.
869 At least one of the arguments must be a :class:`Share`, the
870 other can be a :class:`~viff.field.FieldElement` or a
871 (possible long) Python integer."""
872 raise NotImplementedError
874 def handle_deferred_data(self, deferred, data):
875 """Put deferred and data into the queue if the ViffReactor is running.
876 Otherwise, just execute the callback."""
878 if self.using_viff_reactor:
879 self.deferred_queue.append((deferred, data))
880 else:
881 deferred.callback(data)
883 def process_deferred_queue(self):
884 """Execute the callbacks of the deferreds in the queue.
886 If this function is not called via activate_reactor(), also
887 complex callbacks are executed."""
889 self.process_queue(self.deferred_queue)
891 if self.depth_counter == 0:
892 self.process_queue(self.complex_deferred_queue)
894 def process_queue(self, queue):
895 """Execute the callbacks of the deferreds in *queue*."""
897 while queue:
898 deferred, data = queue.popleft()
899 deferred.callback(data)
901 def activate_reactor(self):
902 """Activate the reactor to do actual communcation.
904 This is where the recursion happens."""
906 if not self.using_viff_reactor:
907 return
909 self.activation_counter += 1
911 # setting the number to n makes the reactor called
912 # only every n-th time
913 if self.activation_counter >= 2:
914 self.depth_counter += 1
916 if self.depth_counter > self.max_depth:
917 # Record the maximal depth reached.
918 self.max_depth = self.depth_counter
919 if self.depth_counter >= self.depth_limit:
920 print "Recursion depth limit reached."
922 if self.depth_counter < self.depth_limit:
923 reactor.doIteration(0)
925 self.depth_counter -= 1
926 self.activation_counter = 0
928 def print_transferred_data(self):
929 """Print the amount of transferred data for all connections."""
931 for protocol in self.protocols.itervalues():
932 print "Transfer to peer %d: %d bytes in %d packets" % \
933 (protocol.peer_id, protocol.sent_bytes, protocol.sent_packets)
def make_runtime_class(runtime_class=None, mixins=None):
    """Build a runtime class from a base class and optional mixins.

    The result has *runtime_class* as a base class with the *mixins*
    mixed in. :class:`viff.passive.PassiveRuntime` is the base class
    when none is given, and the base class itself is returned when
    *mixins* is :const:`None`.
    """
    base = runtime_class
    if base is None:
        # Imported here and not at module level because viff.runtime
        # and viff.passive depend on each other.
        from viff.passive import PassiveRuntime
        base = PassiveRuntime

    if mixins is not None:
        # Most specific classes first, so that they can override
        # methods of the classes after them. At least one new-style
        # class is required among the bases; object is put last so it
        # does not override __init__ from the other base classes.
        return type("ExtendedRuntime", tuple(mixins) + (base, object), {})
    return base
957 def create_runtime(id, players, threshold, options=None, runtime_class=None):
958 """Create a :class:`Runtime` and connect to the other players.
960 This function should be used in normal programs instead of
961 instantiating the Runtime directly. This function makes sure that
962 the Runtime is correctly connected to the other players.
964 The return value is a Deferred which will trigger when the runtime
965 is ready. Add your protocol as a callback on this Deferred using
966 code like this::
968 def protocol(runtime):
969 a, b, c = runtime.shamir_share([1, 2, 3], Zp, input)
971 a = runtime.open(a)
972 b = runtime.open(b)
973 c = runtime.open(c)
975 dprint("Opened a: %s", a)
976 dprint("Opened b: %s", b)
977 dprint("Opened c: %s", c)
979 runtime.wait_for(a,b,c)
981 pre_runtime = create_runtime(id, players, 1)
982 pre_runtime.addCallback(protocol)
984 This is the general template which VIFF programs should follow.
985 Please see the example applications for more examples.
987 """
988 if options and options.track_memory:
989 lc = LoopingCall(track_memory_usage)
990 # Five times per second seems like a fair value. Besides, the
991 # kernel will track the peak memory usage for us anyway.
992 lc.start(0.2)
993 reactor.addSystemEventTrigger("after", "shutdown", track_memory_usage)
995 if runtime_class is None:
996 # The import is put here because of circular depencencies
997 # between viff.runtime and viff.passive.
998 from viff.passive import PassiveRuntime
999 runtime_class = PassiveRuntime
1001 if options and options.profile:
1002 # To collect profiling information we monkey patch reactor.run
1003 # to do the collecting. It would be nicer to simply start the
1004 # profiler here and stop it upon shutdown, but this triggers
1005 # http://bugs.python.org/issue1375 since the start and stop
1006 # calls are in different stack frames.
1007 import cProfile
1008 prof = cProfile.Profile()
1009 old_run = reactor.run
1010 def new_run(*args, **kwargs):
1011 print "Starting reactor with profiling"
1012 prof.runcall(old_run, *args, **kwargs)
1014 import pstats
1015 stats = pstats.Stats(prof)
1016 print
1017 stats.strip_dirs()
1018 stats.sort_stats("time", "calls")
1019 stats.print_stats(40)
1020 stats.dump_stats("player-%d.pstats" % id)
1021 reactor.run = new_run
1023 # This will yield a Runtime when all protocols are connected.
1024 result = Deferred()
1026 # Create a runtime that knows about no other players than itself.
1027 # It will eventually be returned in result when the factory has
1028 # determined that all needed protocols are ready.
1029 runtime = runtime_class(players[id], threshold, options)
1030 factory = ShareExchangerFactory(runtime, players, result)
1032 if options and options.statistics:
1033 reactor.addSystemEventTrigger("after", "shutdown",
1034 runtime.print_transferred_data)
1036 if options and options.ssl:
1037 print "Using SSL"
1038 from twisted.internet.ssl import ContextFactory
1039 from OpenSSL import SSL
1041 class SSLContextFactory(ContextFactory):
1042 def __init__(self, id):
1043 """Create new SSL context factory for *id*."""
1044 self.id = id
1045 ctx = SSL.Context(SSL.SSLv3_METHOD)
1046 # TODO: Make the file names configurable.
1047 try:
1048 ctx.use_certificate_file('player-%d.cert' % id)
1049 ctx.use_privatekey_file('player-%d.key' % id)
1050 ctx.check_privatekey()
1051 ctx.load_verify_locations('ca.cert')
1052 ctx.set_verify(SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
1053 lambda conn, cert, errnum, depth, ok: ok)
1054 self.ctx = ctx
1055 except SSL.Error, e:
1056 print "SSL errors - did you forget to generate certificates?"
1057 for (lib, func, reason) in e.args[0]:
1058 print "* %s in %s: %s" % (func, lib, reason)
1059 raise SystemExit("Stopping program")
1061 def getContext(self):
1062 return self.ctx
1064 ctx_factory = SSLContextFactory(id)
1065 listen = lambda port: reactor.listenSSL(port, factory, ctx_factory)
1066 connect = lambda host, port: reactor.connectSSL(host, port, factory, ctx_factory)
1067 else:
1068 print "Not using SSL"
1069 listen = lambda port: reactor.listenTCP(port, factory)
1070 connect = lambda host, port: reactor.connectTCP(host, port, factory)
1072 port = players[id].port
1073 runtime.port = None
1074 delay = 2
1075 while runtime.port is None:
1076 # We keep trying to listen on the port, but with an
1077 # exponentially increasing delay between each attempt.
1078 try:
1079 runtime.port = listen(port)
1080 except CannotListenError, e:
1081 if options and options.no_socket_retry:
1082 raise
1083 delay *= 1 + rand.random()
1084 print "Error listening on port %d: %s" % (port, e.socketError[1])
1085 print "Will try again in %d seconds" % delay
1086 time.sleep(delay)
1087 print "Listening on port %d" % port
1089 for peer_id, player in players.iteritems():
1090 if peer_id > id:
1091 print "Will connect to %s" % player
1092 connect(player.host, player.port)
1094 if runtime.using_viff_reactor:
1095 # Process the deferred queue after every reactor iteration.
1096 reactor.setLoopCall(runtime.process_deferred_queue)
1098 return result
if __name__ == "__main__":
    # Executed as a script: run the doctests found in this module.
    from doctest import testmod  #pragma NO COVER
    testmod()  #pragma NO COVER