viff

view viff/runtime.py @ 1538:9d4f9551644c

Added computation-id option.

This changeset adds a command line option to VIFF allowing users to
specify a computation id.

Prior to this changeset, any computation involving pseudo-random
secret sharing (which for the time being boils down to computations
done with the PassiveRuntime) could only be run one time using the
same set of VIFF player configuration files. If more than one
computation was executed with the same set of configuration files, the
security of the system would be broken.

With this changeset, multiple computations can be run securely with
the same set of configuration files as long as each computation is run
with a unique computation id.
author Tomas Toft <ttoft at cs.au.dk>
date Wed Aug 11 16:09:32 2010 +0200 (21 months ago)
parents 1772506977cc
children
line source
1 # -*- coding: utf-8 -*-
2 #
3 # Copyright 2007, 2008, 2009 VIFF Development Team.
4 #
5 # This file is part of VIFF, the Virtual Ideal Functionality Framework.
6 #
7 # VIFF is free software: you can redistribute it and/or modify it
8 # under the terms of the GNU Lesser General Public License (LGPL) as
9 # published by the Free Software Foundation, either version 3 of the
10 # License, or (at your option) any later version.
11 #
12 # VIFF is distributed in the hope that it will be useful, but WITHOUT
13 # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
14 # or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
15 # Public License for more details.
16 #
17 # You should have received a copy of the GNU Lesser General Public
18 # License along with VIFF. If not, see <http://www.gnu.org/licenses/>.
20 """VIFF runtime. This is where the virtual ideal functionality is
21 hiding! The runtime is responsible for sharing inputs, handling
22 communication, and running the calculations.
24 Each player participating in the protocol will instantiate a
25 :class:`Runtime` object and use it for the calculations.
27 The Runtime returns :class:`Share` objects for most operations, and
28 these can be added, subtracted, and multiplied as normal thanks to
29 overloaded arithmetic operators. The runtime will take care of
30 scheduling things correctly behind the scenes.
31 """
32 from __future__ import division
34 import time
35 import struct
36 from optparse import OptionParser, OptionGroup
37 from collections import deque
38 import os
39 import sys
41 from viff.field import GF256, FieldElement
42 from viff.util import wrapper, rand, track_memory_usage, begin, end
43 from viff.constants import SHARE
44 import viff.reactor
46 from twisted.internet import reactor
47 from twisted.internet.task import LoopingCall
48 from twisted.internet.error import ConnectionDone, CannotListenError
49 from twisted.internet.defer import Deferred, DeferredList, gatherResults
50 from twisted.internet.defer import maybeDeferred
51 from twisted.internet.protocol import ReconnectingClientFactory, ServerFactory
52 from twisted.protocols.basic import Int16StringReceiver
class Share(Deferred):
    """A shared number.

    The :class:`Runtime` operates on shares, represented by this class.
    Shares are asynchronous in the sense that they promise to attain a
    value at some point in the future.

    Shares overload the arithmetic operations so that ``x = a + b``
    will create a new share *x*, which will eventually contain the
    sum of *a* and *b*. Each share is associated with a
    :class:`Runtime` and the arithmetic operations simply call back to
    that runtime.

    The comparison operators (``<``, ``<=``, ``>``, ``>=``, ``==``,
    ``!=``) are overloaded as well. They return results computed by
    the runtime, not plain Python booleans.
    """

    def __init__(self, runtime, field, value=None):
        """Initialize a share.

        If an initial value is given, it will be passed to
        :meth:`callback` right away.
        """
        assert field is not None, "Cannot construct share without a field."
        assert callable(field), "The field is not callable, wrong argument?"

        Deferred.__init__(self)
        self.runtime = runtime
        self.field = field
        if value is not None:
            self.callback(value)

    if os.environ.get("VIFF_PROFILE"):
        # When profiling is enabled, wrap __init__ so that every share
        # remembers the program counter it was created at and reports
        # its lifetime through begin()/end().
        old_init = __init__

        def __init__(self, *a, **kw):
            self.old_init(*a, **kw)
            self.pc = self.runtime.program_counter[:]
            begin(None, self.label())

        def __del__(self):
            end(None, self.label())

        def label(self):
            return "share " + hex(id(self)) + " " + \
                ".".join(map(str, self.pc))

    def __add__(self, other):
        """Addition."""
        return self.runtime.add(self, other)

    def __radd__(self, other):
        """Addition (reflected argument version)."""
        return self.runtime.add(self, other)

    def __sub__(self, other):
        """Subtraction."""
        return self.runtime.sub(self, other)

    def __rsub__(self, other):
        """Subtraction (reflected argument version)."""
        # Subtraction is not commutative, so the operands are swapped.
        return self.runtime.sub(other, self)

    def __mul__(self, other):
        """Multiplication."""
        return self.runtime.mul(self, other)

    def __rmul__(self, other):
        """Multiplication (reflected argument version)."""
        return self.runtime.mul(self, other)

    def __pow__(self, exponent):
        """Exponentation to known integer exponents."""
        return self.runtime.pow(self, exponent)

    def __xor__(self, other):
        """Exclusive-or."""
        return self.runtime.xor(self, other)

    def __rxor__(self, other):
        """Exclusive-or (reflected argument version)."""
        return self.runtime.xor(self, other)

    def __lt__(self, other):
        """Strictly less-than comparison."""
        # self < other <=> not (self >= other)
        return 1 - self.runtime.greater_than_equal(self, other)

    def __le__(self, other):
        """Less-than or equal comparison."""
        # self <= other <=> other >= self
        return self.runtime.greater_than_equal(other, self)

    def __gt__(self, other):
        """Strictly greater-than comparison."""
        # self > other <=> not (other >= self)
        return 1 - self.runtime.greater_than_equal(other, self)

    def __ge__(self, other):
        """Greater-than or equal comparison."""
        # self >= other
        return self.runtime.greater_than_equal(self, other)

    def __eq__(self, other):
        """Equality testing."""
        return self.runtime.equal(self, other)

    def __ne__(self, other):
        """Negated equality testing (the ``!=`` operator)."""
        # self != other <=> not (self == other)
        return 1 - self.runtime.equal(self, other)

    # Backwards-compatible alias. Earlier versions only defined the
    # misspelled ``__neq__``, which Python never invokes for ``!=``.
    __neq__ = __ne__

    def clone(self):
        """Clone a share.

        Works like :meth:`util.clone_deferred` except that it returns a new
        :class:`Share` instead of a :class:`Deferred`. The clone is fired
        with the same result as this share once that result is available.
        """

        def split_result(result):
            clone.callback(result)
            # Pass the result on so this share's own chain is unaffected.
            return result
        clone = Share(self.runtime, self.field)
        self.addCallback(split_result)
        return clone
class ShareList(Share):
    """Create a share that waits on a number of other shares.

    Roughly modelled after the Twisted :class:`DeferredList`
    class. The advantage of this class is that it is a :class:`Share`
    (not just a :class:`Deferred`) and that it can be made to trigger
    when a certain threshold of the shares are ready. This example
    shows how the :meth:`pprint` callback is triggered when *a* and
    *c* are ready:

    >>> from pprint import pprint
    >>> from viff.field import GF256
    >>> a = Share(None, GF256)
    >>> b = Share(None, GF256)
    >>> c = Share(None, GF256)
    >>> shares = ShareList([a, b, c], threshold=2)
    >>> shares.addCallback(pprint) # doctest: +ELLIPSIS
    <ShareList at 0x...>
    >>> a.callback(10)
    >>> c.callback(20)
    [(True, 10), None, (True, 20)]

    The :meth:`pprint` function is called with a list of pairs. The first
    component of each pair is a boolean indicating if the callback or
    errback method was called on the corresponding :class:`Share`, and
    the second component is the value given to the callback/errback.

    If a threshold less than the full number of shares is used, some
    of the pairs may be missing and :const:`None` is used instead. In
    the example above the *b* share arrived later than *a* and *c*,
    and so the list contains a :const:`None` on its place.
    """

    def __init__(self, shares, threshold=None):
        """Initialize a share list.

        The list of shares must be non-empty and if a threshold is
        given, it must hold that ``0 < threshold <= len(shares)``. The
        default threshold is ``len(shares)``. The new share takes its
        runtime and field from the first share in the list.
        """
        assert len(shares) > 0, "Cannot create empty ShareList"
        assert threshold is None or 0 < threshold <= len(shares), \
            "Threshold out of range"

        first = shares[0]
        Share.__init__(self, first.runtime, first.field)

        total = len(shares)
        # One slot per input share; None until that share has fired.
        self.results = [None] * total
        # Number of shares that still have to fire before we trigger.
        if threshold is not None:
            self.missing_shares = threshold
        else:
            self.missing_shares = total

        for position, share in enumerate(shares):
            # The same handler serves both outcomes; the extra arguments
            # tell it which slot to fill and whether it was a success.
            share.addCallbacks(self._callback_fired, self._callback_fired,
                               callbackArgs=(position, True),
                               errbackArgs=(position, False))

    def _callback_fired(self, result, index, success):
        """Record the outcome of share *index* and fire when enough are in.

        *result* is stored as ``(success, result)`` and then returned
        unchanged, so the original share's callback chain is unaffected.
        """
        self.results[index] = (success, result)
        self.missing_shares -= 1
        threshold_reached = self.missing_shares == 0
        if threshold_reached and not self.called:
            self.callback(self.results)
        return result
def gather_shares(shares):
    """Gather shares.

    Roughly modelled after the Twisted :meth:`gatherResults`
    function. It takes a list of shares and returns a new
    :class:`Share` which will be triggered with a list of values,
    namely the values from the initial shares:

    >>> from pprint import pprint
    >>> from viff.field import GF256
    >>> a = Share(None, GF256)
    >>> b = Share(None, GF256)
    >>> shares = gather_shares([a, b])
    >>> shares.addCallback(pprint) # doctest: +ELLIPSIS
    <ShareList at 0x...>
    >>> a.callback(10)
    >>> b.callback(20)
    [10, 20]
    """
    gathered = ShareList(shares)

    def drop_success_flags(pairs):
        # ShareList delivers (success, value) pairs; keep only the values.
        return [value for (_, value) in pairs]

    gathered.addCallback(drop_success_flags)
    return gathered
270 class ShareExchanger(Int16StringReceiver):
271 """Send and receive shares.
273 All players are connected by pair-wise connections and this
274 Twisted protocol is one such connection. It is used to send and
275 receive shares from one other player.
276 """
278 def __init__(self):
279 self.peer_id = None
280 self.lost_connection = Deferred()
281 #: Data expected to be received in the future.
282 self.incoming_data = {}
283 self.waiting_deferreds = {}
284 #: Statistics
285 self.sent_packets = 0
286 self.sent_bytes = 0
288 def connectionMade(self):
289 self.sendString(str(self.factory.runtime.id))
291 def connectionLost(self, reason):
292 reason.trap(ConnectionDone)
293 self.lost_connection.callback(self)
295 def stringReceived(self, string):
296 """Called when a share is received.
298 The string received is unpacked into the program counter, and
299 a data part. The data is passed the appropriate Deferred in
300 :class:`self.incoming_data`.
301 """
302 if self.peer_id is None:
303 # TODO: Handle ValueError if the string cannot be decoded.
304 self.peer_id = int(string)
305 try:
306 cert = self.transport.getPeerCertificate()
307 except AttributeError:
308 cert = None
309 if cert:
310 # The player ID are stored in the serial number of the
311 # certificate -- this makes it easy to check that the
312 # player is who he claims to be.
313 if cert.get_serial_number() != self.peer_id:
314 print "Peer %s claims to be %d, aborting!" \
315 % (cert.get_subject(), self.peer_id)
316 self.transport.loseConnection()
317 self.factory.identify_peer(self)
318 else:
319 try:
320 pc_size, data_size, data_type = struct.unpack("!HHB", string[:5])
321 fmt = "!%dI%ds" % (pc_size, data_size)
322 unpacked = struct.unpack(fmt, string[5:])
324 program_counter = unpacked[:pc_size]
325 data = unpacked[-1]
327 key = (program_counter, data_type)
329 if key in self.waiting_deferreds:
330 deq = self.waiting_deferreds[key]
331 deferred = deq.popleft()
332 if not deq:
333 del self.waiting_deferreds[key]
334 self.factory.runtime.handle_deferred_data(deferred, data)
335 else:
336 deq = self.incoming_data.setdefault(key, deque())
337 deq.append(data)
338 except struct.error, e:
339 self.factory.runtime.abort(self, e)
341 def sendData(self, program_counter, data_type, data):
342 """Send data to the peer.
344 The *program_counter* is a tuple of unsigned integers, the
345 *data_type* is an unsigned byte and *data* is a string.
347 The data is encoded as follows::
349 +---------+-----------+-----------+--------+--------------+
350 | pc_size | data_size | data_type | pc | data |
351 +---------+-----------+-----------+--------+--------------+
352 2 bytes 2 bytes 1 byte varies varies
354 The program counter takes up ``4 * pc_size`` bytes, the data
355 takes up ``data_size`` bytes.
356 """
357 pc_size = len(program_counter)
358 data_size = len(data)
359 fmt = "!HHB%dI%ds" % (pc_size, data_size)
360 t = (pc_size, data_size, data_type) + program_counter + (data,)
361 packet = struct.pack(fmt, *t)
362 self.sendString(packet)
363 self.sent_packets += 1
364 self.sent_bytes += len(packet)
366 def sendShare(self, program_counter, share):
367 """Send a share.
369 The program counter and the share are converted to bytes and
370 sent to the peer.
371 """
372 self.sendData(program_counter, SHARE, hex(share.value))
374 def loseConnection(self):
375 """Disconnect this protocol instance."""
376 self.transport.loseConnection()
class SelfShareExchanger(ShareExchanger):
    """Loop-back exchanger used by a player to send data to itself.

    Data given to :meth:`sendData` is delivered straight to
    :meth:`stringReceived`. It is never packed into bytes and never
    touches the network.
    """

    def __init__(self, id, factory):
        ShareExchanger.__init__(self)
        self.peer_id = id
        self.factory = factory

    def stringReceived(self, program_counter, data_type, data):
        """Called when this player sends data to itself.

        Unlike :meth:`ShareExchanger.stringReceived`, nothing has to be
        unpacked: the program counter, data type and data arrive exactly
        as they were passed to :meth:`sendData`. The data is passed to
        the appropriate Deferred in :attr:`waiting_deferreds`, or stored
        in :attr:`incoming_data` until somebody asks for it.
        """
        # No struct unpacking happens here, so (unlike the network
        # exchanger) there is no struct.error to handle.
        key = (program_counter, data_type)

        if key in self.waiting_deferreds:
            deq = self.waiting_deferreds[key]
            deferred = deq.popleft()
            if not deq:
                del self.waiting_deferreds[key]
            self.factory.runtime.handle_deferred_data(deferred, data)
        else:
            deq = self.incoming_data.setdefault(key, deque())
            deq.append(data)

    def sendData(self, program_counter, data_type, data):
        """Send data to ourselves by delivering it directly."""
        self.stringReceived(program_counter, data_type, data)

    def loseConnection(self):
        """Disconnect this protocol instance.

        There is no real connection, so this only fires
        :attr:`lost_connection`.
        """
        self.lost_connection.callback(self)
        return None
class SelfShareExchangerFactory(ReconnectingClientFactory, ServerFactory):
    """Factory owning the loop-back :class:`SelfShareExchanger`.

    It gives the loop-back exchanger a ``factory.runtime`` to deliver
    received data to, just like the factory of ordinary exchangers.
    """

    factor = 1.234567  # About half of the Twisted default
    maxDelay = 3
    protocol = SelfShareExchanger

    def __init__(self, runtime):
        """Remember *runtime*, which received data is handed to."""
        self.runtime = runtime

    def clientConnectionLost(self, connector, reason):
        # Ignore clean disconnects; trap() re-raises any other failure.
        reason.trap(ConnectionDone)

    def identify_peer(self, protocol):
        # SelfShareExchanger overrides stringReceived, so the
        # identification step that would call this is never reached.
        raise Exception("Is identify_peer necessary?")
class FakeTransport(object):
    """Stand-in transport for the loop-back :class:`SelfShareExchanger`."""

    def close(self):
        """Pretend to close the non-existent connection; always succeeds."""
        closed = True
        return closed
class ShareExchangerFactory(ReconnectingClientFactory, ServerFactory):
    """Factory for creating ShareExchanger protocols.

    It counts how many peers have identified themselves and fires
    *protocols_ready* with the runtime once every other player has a
    protocol.
    """

    factor = 1.234567  # About half of the Twisted default
    maxDelay = 3
    protocol = ShareExchanger

    def __init__(self, runtime, players, protocols_ready):
        """Initialize the factory.

        *players* is indexed by player ID (see :meth:`identify_peer`).
        *protocols_ready* is a Deferred that is fired with *runtime*
        once ``len(players) - 1`` peers have been identified.
        """
        self.runtime = runtime
        self.players = players
        self.protocols_ready = protocols_ready
        # Every player except ourselves has to connect.
        self.needed_protocols = len(players) - 1

    def identify_peer(self, protocol):
        """Register *protocol* with the runtime under its peer's ID."""
        peer = self.players[protocol.peer_id]
        self.runtime.add_player(peer, protocol)
        self.needed_protocols -= 1
        if self.needed_protocols != 0:
            return
        self.protocols_ready.callback(self.runtime)

    def clientConnectionLost(self, connector, reason):
        # Ignore clean disconnects; trap() re-raises any other failure.
        reason.trap(ConnectionDone)
def preprocess(generator):
    """Track calls to this method.

    The decorated method will be replaced with a proxy method which
    first tries to get the data needed from
    :attr:`Runtime._pool`, and if that fails it falls back to the
    original method. It also returns a flag to indicate whether the
    data is from the pool.

    The *generator* method is only used to record where the data
    should be generated from, the method is not actually called. This
    must be the name of the method (a string) and not the method
    itself.
    """

    def preprocess_decorator(method):

        @wrapper(method)
        def preprocess_wrapper(self, *args, **kwargs):
            self.increment_pc()
            pc = tuple(self.program_counter)

            if pc in self._pool:
                # Preprocessed data exists for this program counter.
                return self._pool.pop(pc), True

            # No pooled data: record that *generator* (called with these
            # positional arguments) should produce data for this program
            # counter. The (generator, args) -> [pc, ...] layout matches
            # what Runtime.preprocess iterates over.
            self._needed_data.setdefault((generator, args), []).append(pc)

            # Compute the data online, in a forked program counter.
            self.fork_pc()
            try:
                return method(self, *args, **kwargs), False
            finally:
                self.unfork_pc()

        return preprocess_wrapper
    return preprocess_decorator
499 class Runtime:
500 """Basic VIFF runtime with no crypto.
502 This runtime contains only the most basic operations needed such
503 as the program counter, the list of other players, etc.
504 """
506 @staticmethod
507 def add_options(parser):
508 group = OptionGroup(parser, "VIFF Runtime Options")
509 parser.add_option_group(group)
511 group.add_option("-l", "--bit-length", type="int", metavar="L",
512 help=("Maximum bit length of input numbers for "
513 "comparisons."))
514 group.add_option("-k", "--security-parameter", type="int", metavar="K",
515 help=("Security parameter. Comparisons will leak "
516 "information with probability 2**-K."))
517 group.add_option("--no-ssl", action="store_false", dest="ssl",
518 help="Disable the use of secure SSL connections.")
519 group.add_option("--ssl", action="store_true",
520 help=("Enable the use of secure SSL connections "
521 "(if the OpenSSL bindings are available)."))
522 group.add_option("--deferred-debug", action="store_true",
523 help="Enable extra debug output for deferreds.")
524 group.add_option("--profile", action="store_true",
525 help="Collect and print profiling information.")
526 group.add_option("--track-memory", action="store_true",
527 help="Track memory usage over time.")
528 group.add_option("--statistics", action="store_true",
529 help="Print statistics on shutdown.")
530 group.add_option("--no-socket-retry", action="store_true",
531 default=False, help="Fail rather than keep retrying "
532 "to connect if port is already in use.")
533 group.add_option("--host", metavar="HOST:PORT", action="append",
534 help="Override host and port of players as specified "
535 "in the configuration file. You can use this option "
536 "multiple times on the command line; the first will "
537 "override host and port of player 1, the second that "
538 "of player 2, and so forth.")
539 group.add_option("--computation-id", type="int", metavar="ID",
540 help="Set the (positive, integer) ID for this "
541 "computation. All IDs for runs using the same set "
542 "of player configuration files must be unique "
543 "to ensure security.")
545 try:
546 # Using __import__ since we do not use the module, we are
547 # only interested in the side-effect.
548 __import__('OpenSSL')
549 have_openssl = True
550 except ImportError:
551 have_openssl = False
553 parser.set_defaults(bit_length=32,
554 security_parameter=30,
555 ssl=have_openssl,
556 deferred_debug=False,
557 profile=False,
558 track_memory=False,
559 statistics=False,
560 computation_id=None)
562 def __init__(self, player, threshold, options=None):
563 """Initialize runtime.
565 Initialized a runtime owned by the given, the threshold, and
566 optionally a set of options. The runtime has no network
567 connections and knows of no other players -- the
568 :func:`create_runtime` function should be used instead to
569 create a usable runtime.
570 """
571 assert threshold > 0, "Must use a positive threshold."
572 #: ID of this player.
573 self.id = player.id
574 #: Shamir secret sharing threshold.
575 self.threshold = threshold
577 if options is None:
578 parser = OptionParser()
579 self.add_options(parser)
580 self.options = parser.get_default_values()
581 else:
582 self.options = options
584 if self.options.deferred_debug:
585 from twisted.internet import defer
586 defer.setDebugging(True)
588 #: Pool of preprocessed data.
589 self._pool = {}
590 #: Description of needed preprocessed data.
591 self._needed_data = {}
593 #: Current program counter.
594 __comp_id = self.options.computation_id
595 if __comp_id is None:
596 __comp_id = 0
597 else:
598 assert __comp_id > 0, "Non-positive ID: %d." % __comp_id
599 self.program_counter = [__comp_id, 0]
601 #: Connections to the other players.
602 #:
603 #: Mapping from from Player ID to :class:`ShareExchanger`
604 #: objects.
605 self.protocols = {}
607 #: Number of known players.
608 #:
609 #: Equal to ``len(self.players)``, but storing it here is more
610 #: direct.
611 self.num_players = 0
613 #: Information on players.
614 #:
615 #: Mapping from Player ID to :class:`Player` objects.
616 self.players = {}
617 # Add ourselves, but with no protocol since we wont be
618 # communicating with ourselves.
619 protocol = SelfShareExchanger(self.id, SelfShareExchangerFactory(self))
620 protocol.transport = FakeTransport()
621 self.add_player(player, protocol)
623 #: Queue of deferreds and data.
624 self.deferred_queue = deque()
625 self.complex_deferred_queue = deque()
626 #: Counter for calls of activate_reactor().
627 self.activation_counter = 0
628 #: Record the recursion depth.
629 self.depth_counter = 0
630 self.max_depth = 0
631 #: Recursion depth limit by experiment, including security margin.
632 self.depth_limit = int(sys.getrecursionlimit() / 50)
633 #: Use deferred queues only if the ViffReactor is running.
634 self.using_viff_reactor = isinstance(reactor, viff.reactor.ViffReactor)
636 def add_player(self, player, protocol):
637 self.players[player.id] = player
638 self.num_players = len(self.players)
639 self.protocols[player.id] = protocol
641 def shutdown(self):
642 """Shutdown the runtime.
644 All connections are closed and the runtime cannot be used
645 again after this has been called.
646 """
647 print "Synchronizing shutdown...",
649 def close_connections(_):
650 print "done."
651 print "Closing connections...",
652 results = [maybeDeferred(self.port.stopListening)]
653 for protocol in self.protocols.itervalues():
654 results.append(protocol.lost_connection)
655 protocol.loseConnection()
656 return DeferredList(results)
658 def stop_reactor(_):
659 print "done."
660 print "Stopping reactor...",
661 reactor.stop()
662 print "done."
664 sync = self.synchronize()
665 sync.addCallback(close_connections)
666 sync.addCallback(stop_reactor)
667 return sync
669 def abort(self, protocol, exc):
670 """Abort the execution due to an exception.
672 The *protocol* received bad data which resulted in *exc* being
673 raised when unpacking.
674 """
675 print "*** bad data from Player %d: %s" % (protocol.peer_id, exc)
676 print "*** aborting!"
677 for p in self.protocols.itervalues():
678 p.loseConnection()
679 reactor.stop()
680 print "*** all protocols disconnected"
682 def wait_for(self, *vars):
683 """Make the runtime wait for the variables given.
685 The runtime is shut down when all variables are calculated.
686 """
687 dl = DeferredList(vars)
688 self.schedule_callback(dl, lambda _: self.shutdown())
690 def increment_pc(self):
691 """Increment the program counter."""
692 self.program_counter[-1] += 1
694 def fork_pc(self):
695 """Fork the program counter."""
696 self.program_counter.append(0)
698 def unfork_pc(self):
699 """Leave a fork of the program counter."""
700 self.program_counter.pop()
702 def schedule_callback(self, deferred, func, *args, **kwargs):
703 """Schedule a callback on a deferred with the correct program
704 counter.
706 If a callback depends on the current program counter, then use
707 this method to schedule it instead of simply calling
708 addCallback directly. Simple callbacks that are independent of
709 the program counter can still be added directly to the
710 Deferred as usual.
712 Any extra arguments are passed to the callback as with
713 :meth:`addCallback`.
714 """
715 self.increment_pc()
716 saved_pc = self.program_counter[:]
718 @wrapper(func)
719 def callback_wrapper(*args, **kwargs):
720 """Wrapper for a callback which ensures a correct PC."""
721 try:
722 current_pc = self.program_counter[:]
723 self.program_counter[:] = saved_pc
724 self.fork_pc()
725 return func(*args, **kwargs)
726 finally:
727 self.program_counter[:] = current_pc
729 return deferred.addCallback(callback_wrapper, *args, **kwargs)
731 def schedule_complex_callback(self, deferred, func, *args, **kwargs):
732 """Schedule a complex callback, i.e. a callback which blocks a
733 long time.
735 Consider that the deferred is forked, i.e. if the callback returns
736 something to be used afterwards, add further callbacks to the returned
737 deferred."""
739 if not self.using_viff_reactor:
740 return self.schedule_callback(deferred, func, *args, **kwargs)
742 if isinstance(deferred, Share):
743 fork = Share(deferred.runtime, deferred.field)
744 else:
745 fork = Deferred()
747 def queue_callback(result, runtime, fork):
748 runtime.complex_deferred_queue.append((fork, result))
750 deferred.addCallback(queue_callback, self, fork)
751 return self.schedule_callback(fork, func, *args, **kwargs)
753 def synchronize(self):
754 """Introduce a synchronization point.
756 Returns a :class:`Deferred` which will trigger if and when all
757 other players have made their calls to :meth:`synchronize`. By
758 adding callbacks to the returned :class:`Deferred`, one can
759 divide a protocol execution into disjoint phases.
760 """
761 self.increment_pc()
762 shares = [self._exchange_shares(player, GF256(0))
763 for player in self.players]
764 result = gather_shares(shares)
765 result.addCallback(lambda _: None)
766 return result
768 def _expect_data(self, peer_id, data_type, deferred):
769 # Convert self.program_counter to a hashable value in order to
770 # use it as a key in self.protocols[peer_id].incoming_data.
771 pc = tuple(self.program_counter)
772 return self._expect_data_with_pc(pc, peer_id, data_type, deferred)
774 def _expect_data_with_pc(self, pc, peer_id, data_type, deferred):
775 key = (pc, data_type)
777 if key in self.protocols[peer_id].incoming_data:
778 # We have already received some data from the other side.
779 deq = self.protocols[peer_id].incoming_data[key]
780 data = deq.popleft()
781 if not deq:
782 del self.protocols[peer_id].incoming_data[key]
783 deferred.callback(data)
784 else:
785 # We have not yet received anything from the other side.
786 deq = self.protocols[peer_id].waiting_deferreds.setdefault(key, deque())
787 deq.append(deferred)
789 def _exchange_shares(self, peer_id, field_element):
790 """Exchange shares with another player.
792 We send the player our share and record a Deferred which will
793 trigger when the share from the other side arrives.
794 """
795 assert isinstance(field_element, FieldElement)
797 if peer_id == self.id:
798 return Share(self, field_element.field, field_element)
799 else:
800 share = self._expect_share(peer_id, field_element.field)
801 pc = tuple(self.program_counter)
802 self.protocols[peer_id].sendShare(pc, field_element)
803 return share
805 def _expect_share(self, peer_id, field):
806 share = Share(self, field)
807 share.addCallback(lambda value: field(long(value, 16)))
808 self._expect_data(peer_id, SHARE, share)
809 return share
811 def preprocess(self, program):
812 """Generate preprocess material.
814 The *program* specifies which methods to call and with which
815 arguments. The generator methods called must adhere to the
816 following interface:
818 - They must return a list of :class:`Deferred` instances.
820 - Every Deferred must yield an item of pre-processed data.
821 This can be value, a list or tuple of values, or a Deferred
822 (which will be converted to a value by Twisted), but NOT a
823 list of Deferreds. Use :meth:`gatherResults` to avoid the
824 latter.
826 The :meth:`~viff.active.TriplesPRSSMixin.generate_triples` method
827 is an example of a method fulfilling this interface.
828 """
830 def update(results, program_counters):
831 # Update the pool with pairs of program counter and data.
832 self._pool.update(zip(program_counters, results))
834 wait_list = []
835 for ((generator, args), program_counters) in program.iteritems():
836 print "Preprocessing %s (%d items)" % (generator, len(program_counters))
837 self.increment_pc()
838 self.fork_pc()
839 func = getattr(self, generator)
840 count = 0
841 start_time = time.time()
843 while program_counters:
844 count += 1
845 self.increment_pc()
846 self.fork_pc()
847 results = func(quantity=len(program_counters), *args)
848 self.unfork_pc()
849 ready = gatherResults(results)
850 ready.addCallback(update, program_counters[:len(results)])
851 del program_counters[:len(results)]
852 wait_list.append(ready)
853 self.unfork_pc()
854 return gatherResults(wait_list)
856 def input(self, inputters, field, number=None):
857 """Input *number* to the computation.
859 The players listed in *inputters* must provide an input
860 number, everybody will receive a list with :class:`Share`
861 objects, one from each *inputter*. If only a single player is
862 listed in *inputters*, then a :class:`Share` is given back
863 directly.
864 """
865 raise NotImplementedError
867 def output(self, share, receivers=None):
868 """Open *share* to *receivers* (defaults to all players).
870 Returns a :class:`Share` to players with IDs in *receivers*
871 and :const:`None` to the remaining players.
872 """
873 raise NotImplementedError
875 def add(self, share_a, share_b):
876 """Secure addition.
878 At least one of the arguments must be a :class:`Share`, the
879 other can be a :class:`~viff.field.FieldElement` or a
880 (possible long) Python integer."""
881 raise NotImplementedError
883 def mul(self, share_a, share_b):
884 """Secure multiplication.
886 At least one of the arguments must be a :class:`Share`, the
887 other can be a :class:`~viff.field.FieldElement` or a
888 (possible long) Python integer."""
889 raise NotImplementedError
891 def handle_deferred_data(self, deferred, data):
892 """Put deferred and data into the queue if the ViffReactor is running.
893 Otherwise, just execute the callback."""
895 if self.using_viff_reactor:
896 self.deferred_queue.append((deferred, data))
897 else:
898 deferred.callback(data)
900 def process_deferred_queue(self):
901 """Execute the callbacks of the deferreds in the queue.
903 If this function is not called via activate_reactor(), also
904 complex callbacks are executed."""
906 self.process_queue(self.deferred_queue)
908 if self.depth_counter == 0:
909 self.process_queue(self.complex_deferred_queue)
911 def process_queue(self, queue):
912 """Execute the callbacks of the deferreds in *queue*."""
914 while queue:
915 deferred, data = queue.popleft()
916 deferred.callback(data)
918 def activate_reactor(self):
919 """Activate the reactor to do actual communcation.
921 This is where the recursion happens."""
923 if not self.using_viff_reactor:
924 return
926 self.activation_counter += 1
928 # setting the number to n makes the reactor called
929 # only every n-th time
930 if self.activation_counter >= 2:
931 self.depth_counter += 1
933 if self.depth_counter > self.max_depth:
934 # Record the maximal depth reached.
935 self.max_depth = self.depth_counter
936 if self.depth_counter >= self.depth_limit:
937 print "Recursion depth limit reached."
939 if self.depth_counter < self.depth_limit:
940 reactor.doIteration(0)
942 self.depth_counter -= 1
943 self.activation_counter = 0
945 def print_transferred_data(self):
946 """Print the amount of transferred data for all connections."""
948 for protocol in self.protocols.itervalues():
949 print "Transfer to peer %d: %d bytes in %d packets" % \
950 (protocol.peer_id, protocol.sent_bytes, protocol.sent_packets)
def make_runtime_class(runtime_class=None, mixins=None):
    """Return a runtime class combining *runtime_class* and *mixins*.

    When *runtime_class* is omitted, :class:`viff.passive.PassiveRuntime`
    is used. Without *mixins* the runtime class itself is returned.
    Otherwise a new class named ``ExtendedRuntime`` is created whose
    bases are the mixins (in the given order), then the runtime class,
    then ``object``.
    """
    if runtime_class is None:
        # Imported here rather than at module level because
        # viff.runtime and viff.passive depend on each other.
        from viff.passive import PassiveRuntime
        runtime_class = PassiveRuntime

    if mixins is None:
        return runtime_class

    # Mixins come first so they can override methods of the runtime
    # class. ``object`` guarantees at least one new-style base, and
    # putting it last keeps it from overriding the other bases' __init__.
    bases = tuple(mixins)
    bases += (runtime_class, object)
    return type("ExtendedRuntime", bases, {})
974 def create_runtime(id, players, threshold, options=None, runtime_class=None):
975 """Create a :class:`Runtime` and connect to the other players.
977 This function should be used in normal programs instead of
978 instantiating the Runtime directly. This function makes sure that
979 the Runtime is correctly connected to the other players.
981 The return value is a Deferred which will trigger when the runtime
982 is ready. Add your protocol as a callback on this Deferred using
983 code like this::
985 def protocol(runtime):
986 a, b, c = runtime.shamir_share([1, 2, 3], Zp, input)
988 a = runtime.open(a)
989 b = runtime.open(b)
990 c = runtime.open(c)
992 dprint("Opened a: %s", a)
993 dprint("Opened b: %s", b)
994 dprint("Opened c: %s", c)
996 runtime.wait_for(a,b,c)
998 pre_runtime = create_runtime(id, players, 1)
999 pre_runtime.addCallback(protocol)
1001 This is the general template which VIFF programs should follow.
1002 Please see the example applications for more examples.
1004 """
1005 if options and options.track_memory:
1006 lc = LoopingCall(track_memory_usage)
1007 # Five times per second seems like a fair value. Besides, the
1008 # kernel will track the peak memory usage for us anyway.
1009 lc.start(0.2)
1010 reactor.addSystemEventTrigger("after", "shutdown", track_memory_usage)
1012 if runtime_class is None:
1013 # The import is put here because of circular depencencies
1014 # between viff.runtime and viff.passive.
1015 from viff.passive import PassiveRuntime
1016 runtime_class = PassiveRuntime
1018 if options and options.host:
1019 for i in range(len(options.host)):
1020 players[i + 1].host, port_str = options.host[i].rsplit(":")
1021 players[i + 1].port = int(port_str)
1023 if options and options.profile:
1024 # To collect profiling information we monkey patch reactor.run
1025 # to do the collecting. It would be nicer to simply start the
1026 # profiler here and stop it upon shutdown, but this triggers
1027 # http://bugs.python.org/issue1375 since the start and stop
1028 # calls are in different stack frames.
1029 import cProfile
1030 prof = cProfile.Profile()
1031 old_run = reactor.run
1032 def new_run(*args, **kwargs):
1033 print "Starting reactor with profiling"
1034 prof.runcall(old_run, *args, **kwargs)
1036 import pstats
1037 stats = pstats.Stats(prof)
1038 print
1039 stats.strip_dirs()
1040 stats.sort_stats("time", "calls")
1041 stats.print_stats(40)
1042 stats.dump_stats("player-%d.pstats" % id)
1043 reactor.run = new_run
1045 # This will yield a Runtime when all protocols are connected.
1046 result = Deferred()
1048 # Create a runtime that knows about no other players than itself.
1049 # It will eventually be returned in result when the factory has
1050 # determined that all needed protocols are ready.
1051 runtime = runtime_class(players[id], threshold, options)
1052 factory = ShareExchangerFactory(runtime, players, result)
1054 if options and options.statistics:
1055 reactor.addSystemEventTrigger("after", "shutdown",
1056 runtime.print_transferred_data)
1058 if options and options.ssl:
1059 print "Using SSL"
1060 from twisted.internet.ssl import ContextFactory
1061 from OpenSSL import SSL
1063 class SSLContextFactory(ContextFactory):
1064 def __init__(self, id):
1065 """Create new SSL context factory for *id*."""
1066 self.id = id
1067 ctx = SSL.Context(SSL.SSLv3_METHOD)
1068 # TODO: Make the file names configurable.
1069 try:
1070 ctx.use_certificate_file('player-%d.cert' % id)
1071 ctx.use_privatekey_file('player-%d.key' % id)
1072 ctx.check_privatekey()
1073 ctx.load_verify_locations('ca.cert')
1074 ctx.set_verify(SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
1075 lambda conn, cert, errnum, depth, ok: ok)
1076 self.ctx = ctx
1077 except SSL.Error, e:
1078 print "SSL errors - did you forget to generate certificates?"
1079 for (lib, func, reason) in e.args[0]:
1080 print "* %s in %s: %s" % (func, lib, reason)
1081 raise SystemExit("Stopping program")
1083 def getContext(self):
1084 return self.ctx
1086 ctx_factory = SSLContextFactory(id)
1087 listen = lambda port: reactor.listenSSL(port, factory, ctx_factory)
1088 connect = lambda host, port: reactor.connectSSL(host, port, factory, ctx_factory)
1089 else:
1090 print "Not using SSL"
1091 listen = lambda port: reactor.listenTCP(port, factory)
1092 connect = lambda host, port: reactor.connectTCP(host, port, factory)
1094 port = players[id].port
1095 runtime.port = None
1096 delay = 2
1097 while runtime.port is None:
1098 # We keep trying to listen on the port, but with an
1099 # exponentially increasing delay between each attempt.
1100 try:
1101 runtime.port = listen(port)
1102 except CannotListenError, e:
1103 if options and options.no_socket_retry:
1104 raise
1105 delay *= 1 + rand.random()
1106 print "Error listening on port %d: %s" % (port, e.socketError[1])
1107 print "Will try again in %d seconds" % delay
1108 time.sleep(delay)
1109 print "Listening on port %d" % port
1111 for peer_id, player in players.iteritems():
1112 if peer_id > id:
1113 print "Will connect to %s" % player
1114 connect(player.host, player.port)
1116 if runtime.using_viff_reactor:
1117 # Process the deferred queue after every reactor iteration.
1118 reactor.setLoopCall(runtime.process_deferred_queue)
1120 return result
if __name__ == "__main__":
    # Executing this module directly runs the doctests embedded in it.
    from doctest import testmod #pragma NO COVER
    testmod() #pragma NO COVER