changeset 1179:cbc39e56b402

Reduce number of isinstance() calls. This is done by exploiting that in operator overloading functions like Share.__add__ the first parameter always is a Share.
author Marcel Keller <mkeller@cs.au.dk>
date Wed, 13 May 2009 15:32:00 +0200
parents e1e0c107c40b
children b44882c6d4f6
files viff/aes.py viff/passive.py viff/runtime.py
diffstat 3 files changed, 21 insertions(+), 30 deletions(-) [+]
line wrap: on
line diff
--- a/viff/aes.py	Wed May 13 11:12:20 2009 +0200
+++ b/viff/aes.py	Wed May 13 15:32:00 2009 +0200
@@ -205,7 +205,7 @@
 
                 # include the translation in the matrix multiplication
                 # (see definition of AES.A)
-                bits.append(GF256(1))
+                bits.append(Share(self.runtime, GF256, GF256(1)))
 
                 if (use_lin_comb):
                     bits = [self.runtime.lin_comb(AES.A.rows[j], bits) 
--- a/viff/passive.py	Wed May 13 11:12:20 2009 +0200
+++ b/viff/passive.py	Wed May 13 15:32:00 2009 +0200
@@ -113,12 +113,14 @@
 
         Communication cost: none.
         """
-        field = getattr(share_a, "field", getattr(share_b, "field", None))
-        if not isinstance(share_a, Share):
-            share_a = Share(self, field, share_a)
         if not isinstance(share_b, Share):
-            share_b = Share(self, field, share_b)
-
+            # Addition with constant. share_a always is a Share by
+            # operator overloading in Share. Clone share_a to avoid
+            # changing it.
+            result = share_a.clone()
+            result.addCallback(lambda a, b: a + b, share_b)
+            return result
+        
         result = gather_shares([share_a, share_b])
         result.addCallback(lambda (a, b): a + b)
         return result
@@ -149,16 +151,13 @@
             assert not isinstance(coeff, Share), \
                 "Coefficients should not be shares."
 
+        for share in shares:
+            assert isinstance(share, Share), \
+                "Shares should be shares."
+
         assert len(coefficients) == len(shares), \
             "Number of coefficients and shares should be equal."
 
-        field = None
-        for share in shares:
-            field = getattr(share, "field", field)
-        for i, share in enumerate(shares):
-            if not isinstance(share, Share):
-                shares[i] = Share(self, field, share)
-
         def computation(shares, coefficients):
             summands = [shares[i] * coefficients[i] for i in range(len(shares))]
             return reduce(lambda x, y: x + y, summands)
@@ -174,17 +173,13 @@
 
         Communication cost: 1 Shamir sharing.
         """
-        assert isinstance(share_a, Share) or isinstance(share_b, Share), \
-            "At least one of share_a and share_b must be a Share."
+        assert isinstance(share_a, Share), \
+            "share_a must be a Share."
 
-        if not isinstance(share_a, Share):
-            # Then share_b must be a Share => local multiplication. We
-            # clone first to avoid changing share_b.
-            result = share_b.clone()
-            result.addCallback(lambda b: share_a * b)
-            return result
         if not isinstance(share_b, Share):
-            # Likewise when share_b is a constant.
+            # Local multiplication. share_a always is a Share by
+            # operator overloading in Share. We clone share_a first
+            # to avoid changing it.
             result = share_a.clone()
             result.addCallback(lambda a: a * share_b)
             return result
@@ -227,11 +222,7 @@
 
     @increment_pc
     def xor(self, share_a, share_b):
-        field = getattr(share_a, "field", getattr(share_b, "field", None))
-        if not isinstance(share_a, Share):
-            if not isinstance(share_a, FieldElement):
-                share_a = field(share_a)
-            share_a = Share(self, field, share_a)
+        field = share_a.field
         if not isinstance(share_b, Share):
             if not isinstance(share_b, FieldElement):
                 share_b = field(share_b)
--- a/viff/runtime.py	Wed May 13 11:12:20 2009 +0200
+++ b/viff/runtime.py	Wed May 13 15:32:00 2009 +0200
@@ -92,7 +92,7 @@
 
     def __radd__(self, other):
         """Addition (reflected argument version)."""
-        return self.runtime.add(other, self)
+        return self.runtime.add(self, other)
 
     def __sub__(self, other):
         """Subtraction."""
@@ -108,7 +108,7 @@
 
     def __rmul__(self, other):
         """Multiplication (reflected argument version)."""
-        return self.runtime.mul(other, self)
+        return self.runtime.mul(self, other)
 
     def __pow__(self, exponent):
         """Exponentation to known integer exponents."""
@@ -120,7 +120,7 @@
 
     def __rxor__(self, other):
         """Exclusive-or (reflected argument version)."""
-        return self.runtime.xor(other, self)
+        return self.runtime.xor(self, other)
 
     def __lt__(self, other):
         """Strictly less-than comparison."""