changeset 1471:06b1b7647643

BeDOZa: Moved multiplication functionality into SimpleArithmetic.
author Janus Dam Nielsen <janus.nielsen@alexandra.dk>
date Wed, 07 Jul 2010 15:40:45 +0200
parents f27609bc4831
children 55d2c1692771
files viff/bedoza.py viff/simplearithmetic.py viff/test/test_bedoza_runtime.py
diffstat 3 files changed, 138 insertions(+), 71 deletions(-) [+]
line wrap: on
line diff
--- a/viff/bedoza.py	Wed Jul 07 14:31:30 2010 +0200
+++ b/viff/bedoza.py	Wed Jul 07 15:40:45 2010 +0200
@@ -297,11 +297,16 @@
         return (zi, zks, zms)
 
     def _minus_public_right(self, x, c, field):
+        (zi, zks, zms) = self._minus_public_right_without_share(x, c, field)
+        return BeDOZaShare(self, field, zi, zks, zms)
+
+    def _minus_public_right_without_share(self, x, c, field):
         (xi, xks, xms) = x
         if self.id == 1:
             xi = xi - c
         xks.keys[0] = xks.keys[0] + xks.alpha * c
-        return BeDOZaShare(self, field, xi, xks, xms)
+        return xi, xks, xms
+
 
     def _minus_public_left(self, x, c, field):
         y = self._constant_multiply(x, field(-1))
@@ -372,69 +377,5 @@
                 c += share_c.value
         return [triple_a, triple_b, triple_c]
 
-    def mul(self, share_x, share_y):
-        """Multiplication of shares."""
-        assert isinstance(share_x, Share) or isinstance(share_y, Share), \
-            "At least one of share_x and share_y must be a Share."
-
-        self.increment_pc()
-
-        field = getattr(share_x, "field", getattr(share_y, "field", None))
-
-        triple = self._get_triple(field)
-        return self._basic_multiplication(share_x, share_y, *triple)
-
-    def _basic_multiplication(self, share_x, share_y, triple_a, triple_b, triple_c):
-        """Multiplication of shares give a triple.
-
-        Communication cost: ???.
-
-        ``d = Open([x] - [a])``
-        ``e = Open([y] - [b])``
-        ``[z] = e[x] + d[y] - [de] + [c]``
-        """
-        assert isinstance(share_x, Share) or isinstance(share_y, Share), \
-            "At least one of share_x and share_y must be a Share."
-
-        self.increment_pc()
-
-        field = getattr(share_x, "field", getattr(share_y, "field", None))
-        n = field(0)
-
-        cmul_result = self._cmul(share_x, share_y, field)
-        if cmul_result is  not None:
-            return cmul_result
-
-        def multiply((x, y, c, d, e)):
-            # [de]
-            de = d * e
-            # e[x]
-            t1 = self._constant_multiply(x, e)
-            # d[y]
-            t2 = self._constant_multiply(y, d)
-            # d[y] - [de]
-            t3 = self._minus_public_right1(t2, de, field)
-            # d[y] - [de] + [c]
-            t4 = self._plus((t3, c), field)
-            # [z] = e[x] + d[y] - [de] + [c]
-            zi, zks, zms = self._plus((t1, t4), field)
-            return BeDOZaShare(self, field, zi, zks, zms)
-
-        #d = Open([x] - [a])
-        d = self.output(share_x - triple_a)
-        # e = Open([y] - [b])
-        e = self.output(share_y - triple_b)
-        result = gather_shares([share_x, share_y, triple_c, d, e])
-        result.addCallbacks(multiply, self.error_handler)
-
-        # do actual communication
-        self.activate_reactor()
-
-        return result
-
-    def _minus_public_right1(self, x, c, field):
-        (xi, xks, xms) = x
-        if self.id == 1:
-            xi = xi - c
-        xks.keys[0] = xks.keys[0] + xks.alpha * c
-        return xi, xks, xms
+    def _wrap_in_share(self, (zi, zks, zms), field):
+        return BeDOZaShare(self, field, zi, zks, zms)
--- a/viff/simplearithmetic.py	Wed Jul 07 14:31:30 2010 +0200
+++ b/viff/simplearithmetic.py	Wed Jul 07 15:40:45 2010 +0200
@@ -21,12 +21,19 @@
 class SimpleArithmetic:
     """Provides methods for addition and subtraction.
 
-    Provides set: {add, sub}.
-    Requires set: {self._plus((x,y), field),
-                   self._minus((x,y), field),
+    Provides set: {add, sub, mul}.
+    Requires set: {self._plus((x, y), field),
+                   self._minus((x, y), field),
                    self._plus_public(x, c, field),
                    self._minus_public_right(x, c, field),
-                   self._minus_public_left(x, c, field)}.
+                   self._minus_public_left(x, c, field),
+                   self._wrap_in_share(x, field),
+                   self._get_triple(field),
+                   self._constant_multiply(x, c),
+                   self._cmul(x, y, field),
+                   self.open(x),
+                   self.increment_pc(),
+                   self.activate_reactor()}.
     """
 
     def add(self, share_a, share_b):
@@ -74,3 +81,63 @@
             result = gather_shares([share_a, share_b])
             result.addCallbacks(self._minus, self.error_handler, callbackArgs=(field,))
             return result
+
+    def mul(self, share_x, share_y):
+        """Multiplication of shares."""
+        assert isinstance(share_x, Share) or isinstance(share_y, Share), \
+            "At least one of share_x and share_y must be a Share."
+
+        self.increment_pc()
+
+        field = getattr(share_x, "field", getattr(share_y, "field", None))
+
+        triple = self._get_triple(field)
+        return self._basic_multiplication(share_x, share_y, *triple)
+
+    def _basic_multiplication(self, share_x, share_y, triple_a, triple_b, triple_c):
+        """Multiplication of shares give a triple.
+
+        Communication cost: ???.
+
+        ``d = Open([x] - [a])``
+        ``e = Open([y] - [b])``
+        ``[z] = e[x] + d[y] - [de] + [c]``
+        """
+        assert isinstance(share_x, Share) or isinstance(share_y, Share), \
+            "At least one of share_x and share_y must be a Share."
+
+        self.increment_pc()
+
+        field = getattr(share_x, "field", getattr(share_y, "field", None))
+        n = field(0)
+
+        cmul_result = self._cmul(share_x, share_y, field)
+        if cmul_result is  not None:
+            return cmul_result
+
+        def multiply((x, y, c, d, e)):
+            # [de]
+            de = d * e
+            # e[x]
+            t1 = self._constant_multiply(x, e)
+            # d[y]
+            t2 = self._constant_multiply(y, d)
+            # d[y] - [de]
+            t3 = self._minus_public_right_without_share(t2, de, field)
+            # d[y] - [de] + [c]
+            t4 = self._plus((t3, c), field)
+            # [z] = e[x] + d[y] - [de] + [c]
+            z = self._plus((t1, t4), field)
+            return self._wrap_in_share(z, field)
+
+        #d = Open([x] - [a])
+        d = self.output(share_x - triple_a)
+        # e = Open([y] - [b])
+        e = self.output(share_y - triple_b)
+        result = gather_shares([share_x, share_y, triple_c, d, e])
+        result.addCallbacks(multiply, self.error_handler)
+
+        # do actual communication
+        self.activate_reactor()
+
+        return result
--- a/viff/test/test_bedoza_runtime.py	Wed Jul 07 14:31:30 2010 +0200
+++ b/viff/test/test_bedoza_runtime.py	Wed Jul 07 15:40:45 2010 +0200
@@ -371,3 +371,62 @@
         d = runtime.open(z2)
         d.addCallback(check)
         return d
+
+    @protocol
+    def test_mul_mul(self, runtime):
+        """Test multiplication of two numbers."""
+
+        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
+
+        x1 = 6
+        y1 = 6
+
+        def check(v):
+            self.assertEquals(v, x1 * y1)
+
+        x2 = runtime.random_share(self.Zp)
+        y2 = runtime.random_share(self.Zp)
+
+        z2 = x2 * y2
+        d = runtime.open(z2)
+        d.addCallback(check)
+        return d
+    @protocol
+    def test_basic_multiply_constant_right(self, runtime):
+        """Test multiplication of two numbers."""
+
+        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
+
+        x1 = 6
+        y1 = 6
+
+        def check(v):
+            self.assertEquals(v, x1 * y1)
+
+        x2 = runtime.random_share(self.Zp)
+
+        a, b, c = runtime._get_triple(self.Zp)
+        z2 = runtime._basic_multiplication(x2, self.Zp(y1), a, b, c)
+        d = runtime.open(z2)
+        d.addCallback(check)
+        return d
+
+    @protocol
+    def test_basic_multiply_constant_left(self, runtime):
+        """Test multiplication of two numbers."""
+
+        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
+
+        x1 = 6
+        y1 = 6
+
+        def check(v):
+            self.assertEquals(v, x1 * y1)
+
+        x2 = runtime.random_share(self.Zp)
+
+        a, b, c = runtime._get_triple(self.Zp)
+        z2 = runtime._basic_multiplication(self.Zp(y1), x2, a, b, c)
+        d = runtime.open(z2)
+        d.addCallback(check)
+        return d