changeset 1182:e89fb02c5e3d

Merged with Marcel.
author Martin Geisler <mg@cs.au.dk>
date Thu, 14 May 2009 12:01:54 +0200
parents a1304b0072d4 3171ea9886cb
children b5eea8738968 6cd5ceb87542 e918af983f75
files
diffstat 4 files changed, 46 insertions(+), 46 deletions(-) [+]
line wrap: on
line diff
--- a/viff/aes.py	Thu May 14 10:47:12 2009 +0200
+++ b/viff/aes.py	Thu May 14 12:01:54 2009 +0200
@@ -205,7 +205,7 @@
 
                 # include the translation in the matrix multiplication
                 # (see definition of AES.A)
-                bits.append(GF256(1))
+                bits.append(Share(self.runtime, GF256, GF256(1)))
 
                 if (use_lin_comb):
                     bits = [self.runtime.lin_comb(AES.A.rows[j], bits) 
--- a/viff/field.py	Thu May 14 10:47:12 2009 +0200
+++ b/viff/field.py	Thu May 14 12:01:54 2009 +0200
@@ -168,16 +168,21 @@
             other = other.value
         return GF256(self.value ^ other)
 
-    #: Add this and another GF256 element (reflected argument version).
-    __radd__ = __add__
+    def __radd__(self, other):
+        """Add this and another number (reflected argument version).
 
+        other is not Share, otherwise Share.__add__() would have been
+        called, and other is not a GF256, otherwise GF256.__add__()
+        would have been called."""
+        return GF256(self.value ^ other)
+    
     #: Subtract this and another GF256 element.
     #:
     #: Addition is its own inverse in GF(2^8) and so this is the same
     #: as `__add__`.
     __sub__ = __add__
     #: Subtract this and another GF256 element (reflected argument version).
-    __rsub__ = __sub__
+    __rsub__ = __radd__
 
     #: Exclusive-or.
     #:
@@ -185,7 +190,7 @@
     __xor__ = __add__
 
     #: Exclusive-or (reflected argument version).
-    __rxor__ = __xor__
+    __rxor__ = __radd__
 
     def __mul__(self, other):
         """Multiply this and another GF256.
@@ -204,8 +209,14 @@
         return _mul_table[(self.value, other)]
 
 
-    #: Multiply this and another GF256 element (reflected argument version).
-    __rmul__ = __mul__
+    def __rmul__(self, other):
+        """Multiply this and another number (reflected argument
+        version).
+
+        other is not Share, otherwise Share.__mul__() would have been
+        called, and other is not a GF256, otherwise GF256.__mul__()
+        would have been called."""
+        return _mul_table[(self.value, other)]
 
     def __pow__(self, exponent):
         """Exponentiation."""
--- a/viff/passive.py	Thu May 14 10:47:12 2009 +0200
+++ b/viff/passive.py	Thu May 14 12:01:54 2009 +0200
@@ -113,12 +113,14 @@
 
         Communication cost: none.
         """
-        field = getattr(share_a, "field", getattr(share_b, "field", None))
-        if not isinstance(share_a, Share):
-            share_a = Share(self, field, share_a)
         if not isinstance(share_b, Share):
-            share_b = Share(self, field, share_b)
-
+            # Addition with constant. share_a always is a Share by
+            # operator overloading in Share. Clone share_a to avoid
+            # changing it.
+            result = share_a.clone()
+            result.addCallback(lambda a, b: a + b, share_b)
+            return result
+        
         result = gather_shares([share_a, share_b])
         result.addCallback(lambda (a, b): a + b)
         return result
@@ -149,16 +151,13 @@
             assert not isinstance(coeff, Share), \
                 "Coefficients should not be shares."
 
+        for share in shares:
+            assert isinstance(share, Share), \
+                "Shares should be shares."
+
         assert len(coefficients) == len(shares), \
             "Number of coefficients and shares should be equal."
 
-        field = None
-        for share in shares:
-            field = getattr(share, "field", field)
-        for i, share in enumerate(shares):
-            if not isinstance(share, Share):
-                shares[i] = Share(self, field, share)
-
         def computation(shares, coefficients):
             summands = [shares[i] * coefficients[i] for i in range(len(shares))]
             return reduce(lambda x, y: x + y, summands)
@@ -174,17 +173,13 @@
 
         Communication cost: 1 Shamir sharing.
         """
-        assert isinstance(share_a, Share) or isinstance(share_b, Share), \
-            "At least one of share_a and share_b must be a Share."
+        assert isinstance(share_a, Share), \
+            "share_a must be a Share."
 
-        if not isinstance(share_a, Share):
-            # Then share_b must be a Share => local multiplication. We
-            # clone first to avoid changing share_b.
-            result = share_b.clone()
-            result.addCallback(lambda b: share_a * b)
-            return result
         if not isinstance(share_b, Share):
-            # Likewise when share_b is a constant.
+            # Local multiplication. share_a always is a Share by
+            # operator overloading in Share. We clone share_a first
+            # to avoid changing it.
             result = share_a.clone()
             result.addCallback(lambda a: a * share_b)
             return result
@@ -227,11 +222,7 @@
 
     @increment_pc
     def xor(self, share_a, share_b):
-        field = getattr(share_a, "field", getattr(share_b, "field", None))
-        if not isinstance(share_a, Share):
-            if not isinstance(share_a, FieldElement):
-                share_a = field(share_a)
-            share_a = Share(self, field, share_a)
+        field = share_a.field
         if not isinstance(share_b, Share):
             if not isinstance(share_b, FieldElement):
                 share_b = field(share_b)
--- a/viff/runtime.py	Thu May 14 10:47:12 2009 +0200
+++ b/viff/runtime.py	Thu May 14 12:01:54 2009 +0200
@@ -92,7 +92,7 @@
 
     def __radd__(self, other):
         """Addition (reflected argument version)."""
-        return self.runtime.add(other, self)
+        return self.runtime.add(self, other)
 
     def __sub__(self, other):
         """Subtraction."""
@@ -108,7 +108,7 @@
 
     def __rmul__(self, other):
         """Multiplication (reflected argument version)."""
-        return self.runtime.mul(other, self)
+        return self.runtime.mul(self, other)
 
     def __pow__(self, exponent):
         """Exponentation to known integer exponents."""
@@ -120,7 +120,7 @@
 
     def __rxor__(self, other):
         """Exclusive-or (reflected argument version)."""
-        return self.runtime.xor(other, self)
+        return self.runtime.xor(self, other)
 
     def __lt__(self, other):
         """Strictly less-than comparison."""
@@ -270,6 +270,7 @@
         self.lost_connection = Deferred()
         #: Data expected to be received in the future.
         self.incoming_data = {}
+        self.waiting_deferreds = {}
 
     def connectionMade(self):
         self.sendString(str(self.factory.runtime.id))
@@ -312,13 +313,14 @@
 
                 key = (program_counter, data_type)
 
-                deq = self.incoming_data.setdefault(key, deque())
-                if deq and isinstance(deq[0], Deferred):
+                if key in self.waiting_deferreds:
+                    deq = self.waiting_deferreds[key]
                     deferred = deq.popleft()
                     if not deq:
-                        del self.incoming_data[key]
+                        del self.waiting_deferreds[key]
                     deferred.callback(data)
                 else:
+                    deq = self.incoming_data.setdefault(key, deque())
                     deq.append(data)
             except struct.error, e:
                 self.factory.runtime.abort(self, e)
@@ -601,11 +603,6 @@
         Any extra arguments are passed to the callback as with
         :meth:`addCallback`.
         """
-        # TODO, http://tracker.viff.dk/issue22: When several callbacks
-        # are scheduled from the same method, they all save the same
-        # program counter. Simply decorating callback with increase_pc
-        # does not seem to work (the multiplication benchmark hangs).
-        # This should be fixed.
         saved_pc = self.program_counter[:]
 
         @wrapper(func)
@@ -642,15 +639,16 @@
         pc = tuple(self.program_counter)
         key = (pc, data_type)
 
-        deq = self.protocols[peer_id].incoming_data.setdefault(key, deque())
-        if deq and not isinstance(deq[0], Deferred):
+        if key in self.protocols[peer_id].incoming_data:
             # We have already received some data from the other side.
+            deq = self.protocols[peer_id].incoming_data[key]
             data = deq.popleft()
             if not deq:
                 del self.protocols[peer_id].incoming_data[key]
             deferred.callback(data)
         else:
             # We have not yet received anything from the other side.
+            deq = self.protocols[peer_id].waiting_deferreds.setdefault(key, deque())
             deq.append(deferred)
 
     def _exchange_shares(self, peer_id, field_element):