changeset 1159:a8206d017f43

Replace the unsafe marshal module with the safe struct module. This changes the ShareExchanger.sendData method so that data must now always be a string, integers and other types will no longer work. Part of this patch: http://lists.viff.dk/pipermail/viff-patches-viff.dk/2008-October/000053.html Committed by Martin Geisler.
author Mikkel Krøigård <mk@daimi.au.dk>
date Tue, 14 Apr 2009 23:05:41 +0200
parents ed3302fd93f1
children d5741e841ccc
files viff/runtime.py viff/test/test_basic_runtime.py
diffstat 2 files changed, 31 insertions(+), 13 deletions(-) [+]
line wrap: on
line diff
--- a/viff/runtime.py	Tue Apr 14 15:07:30 2009 +0200
+++ b/viff/runtime.py	Tue Apr 14 23:05:41 2009 +0200
@@ -34,7 +34,7 @@
 __docformat__ = "restructuredtext"
 
 import time
-import marshal
+import struct
 from optparse import OptionParser, OptionGroup
 from collections import deque
 
@@ -305,7 +305,18 @@
                     self.transport.loseConnection()
             self.factory.identify_peer(self)
         else:
-            program_counter, data_type, data = marshal.loads(string)
+            # TODO: we cannot handle the empty string
+            # also note that we cannot handle pcs longer than 256
+            pc_size = ord(string[0])
+            fmt = (pc_size + 1)*'i'
+            predata_size = struct.calcsize(fmt) + 1
+            fmt = "%s%is" % (fmt, len(string)-predata_size)
+
+            unpacked = struct.unpack(fmt, string[1:])
+
+            program_counter = unpacked[:pc_size]
+            data_type, data = unpacked[-2:]
+
             key = (program_counter, data_type)
 
             deq = self.incoming_data.setdefault(key, deque())
@@ -321,8 +332,11 @@
             # TypeError. They should be handled somehow.
 
     def sendData(self, program_counter, data_type, data):
-        send_data = (program_counter, data_type, data)
-        self.sendString(marshal.dumps(send_data))
+        pc_size = len(program_counter)
+        fmt = "%s%is" % ((pc_size + 1)*'i', len(data))
+        data_tuple = program_counter + (data_type, data)
+
+        self.sendString(chr(pc_size) + struct.pack(fmt, *data_tuple))
 
     def sendShare(self, program_counter, share):
         """Send a share.
@@ -330,7 +344,7 @@
         The program counter and the share are marshalled and sent to
         the peer.
         """
-        self.sendData(program_counter, SHARE, share.value)
+        self.sendData(program_counter, SHARE, hex(share.value))
 
     def loseConnection(self):
         """Disconnect this protocol instance."""
@@ -635,8 +649,12 @@
             return share
 
     def _expect_share(self, peer_id, field):
+
+        def unpack_share(value_string):
+            return field(long(value_string, 16))
+
         share = Share(self, field)
-        share.addCallback(lambda value: field(value))
+        share.addCallback(unpack_share)
         self._expect_data(peer_id, SHARE, share)
         return share
 
--- a/viff/test/test_basic_runtime.py	Tue Apr 14 15:07:30 2009 +0200
+++ b/viff/test/test_basic_runtime.py	Tue Apr 14 23:05:41 2009 +0200
@@ -1,4 +1,4 @@
-# Copyright 2008 VIFF Development Team.
+# Copyright 2008, 2009 VIFF Development Team.
 #
 # This file is part of VIFF, the Virtual Ideal Functionality Framework.
 #
@@ -140,17 +140,17 @@
         for peer_id in range(1, self.num_players+1):
             if peer_id != runtime.id:
                 pc = tuple(runtime.program_counter)
-                runtime.protocols[peer_id].sendData(pc, 42, 100)
-                runtime.protocols[peer_id].sendData(pc, 42, 200)
-                runtime.protocols[peer_id].sendData(pc, 42, 300)
+                runtime.protocols[peer_id].sendData(pc, 42, "100")
+                runtime.protocols[peer_id].sendData(pc, 42, "200")
+                runtime.protocols[peer_id].sendData(pc, 42, "300")
 
         # Then receive the data.
         deferreds = []
         for peer_id in range(1, self.num_players+1):
             if peer_id != runtime.id:
-                d100 = Deferred().addCallback(self.assertEquals, 100)
-                d200 = Deferred().addCallback(self.assertEquals, 200)
-                d300 = Deferred().addCallback(self.assertEquals, 300)
+                d100 = Deferred().addCallback(self.assertEquals, "100")
+                d200 = Deferred().addCallback(self.assertEquals, "200")
+                d300 = Deferred().addCallback(self.assertEquals, "300")
                 runtime._expect_data(peer_id, 42, d100)
                 runtime._expect_data(peer_id, 42, d200)
                 runtime._expect_data(peer_id, 42, d300)