changeset 1229:eb0443115206

Implementation of the basic multiplication command.
author Janus Dam Nielsen <janus.nielsen@alexandra.dk>
date Tue, 06 Oct 2009 10:05:24 +0200
parents a62e12c9947a
children 86e0c7d54f22
files viff/orlandi.py viff/test/test_orlandi_runtime.py
diffstat 2 files changed, 259 insertions(+), 2 deletions(-) [+]
line wrap: on
line diff
--- a/viff/orlandi.py	Tue Oct 06 10:05:24 2009 +0200
+++ b/viff/orlandi.py	Tue Oct 06 10:05:24 2009 +0200
@@ -459,6 +459,22 @@
              return results[0]
         return results       
 
+    def mul(self, share_x, share_y):
+        """Multiplication of shares.
+
+        Communication cost: ???.
+        """
+        # TODO: Communication cost?
+        assert isinstance(share_x, Share) or isinstance(share_y, Share), \
+            "At least one of share_x and share_y must be a Share."
+
+        self.program_counter[-1] += 1
+
+        field = getattr(share_x, "field", getattr(share_y, "field", None))
+
+        a, b, c = self._get_triple(field)
+        return self._basic_multiplication(share_x, share_y, a, b, c)
+
     def _additive_constant(self, zero, field_element):
         """Greate an additive constant.
 
@@ -506,6 +522,102 @@
         Cz = Cx / Cy
         return (zi, (rhozi1, rhozi2), Cz)
 
+    def _cmul(self, share_x, share_y, field):
+        """Multiplication of a share with a constant.
+
+        Either share_x or share_y must be an OrlandiShare but not both.
+        Returns None if both share_x and share_y are OrlandiShares.
+
+        """
+        def constant_multiply(x, c):
+            assert(isinstance(c, FieldElement))
+            zi, rhoz, Cx = self._const_mul(c.value, x)
+            return OrlandiShare(self, field, zi, rhoz, Cx)
+        if not isinstance(share_x, Share):
+            # Then share_y must be a Share => local multiplication. We
+            # clone first to avoid changing share_y.
+            assert isinstance(share_y, Share), \
+                "At least one of the arguments must be a share."
+            result = share_y.clone()
+            result.addCallback(constant_multiply, share_x)
+            return result
+        if not isinstance(share_y, Share):
+            # Likewise when share_y is a constant.
+            assert isinstance(share_x, Share), \
+                "At least one of the arguments must be a share."
+            result = share_x.clone()
+            result.addCallback(constant_multiply, share_y)
+            return result
+        return None
+
+    def _const_mul(self, c, x):
+        """Multiplication of a share-tuple with a constant c."""
+        assert(isinstance(c, long) or isinstance(c, int))
+        xi, (rhoi1, rhoi2), Cx = x
+        zi = xi * c
+        rhoz = (rhoi1 * c, rhoi2 * c)
+        Cz = Cx**c
+        return (zi, rhoz, Cz)
+
+    def _get_triple(self, field):
+        n = field(0)
+        Ca = commitment.commit(6, 0, 0)
+        a = OrlandiShare(self, field, field(2), (n, n), Ca)
+        Cb = commitment.commit(12, 0, 0)
+        b = OrlandiShare(self, field, field(4), (n, n), Cb)
+        Cc = commitment.commit(72, 0, 0)
+        c = OrlandiShare(self, field, field(24), (n, n), Cc)
+        return (a, b, c)
+
+    def _basic_multiplication(self, share_x, share_y, triple_a, triple_b, triple_c):
+        """Multiplication of shares give a triple.
+
+        Communication cost: ???.
+        
+        ``d = Open([x] - [a])``
+        ``e = Open([y] - [b])``
+        ``[z] = e[x] + d[y] - [de] + [c]``
+        """
+        assert isinstance(share_x, Share) or isinstance(share_y, Share), \
+            "At least one of share_x and share_y must be a Share."
+
+        self.program_counter[-1] += 1
+
+        field = getattr(share_x, "field", getattr(share_y, "field", None))
+        n = field(0)
+
+        cmul_result = self._cmul(share_x, share_y, field)
+        if cmul_result is  not None:
+            return cmul_result
+
+        def multiply((x, y, d, e, c)):
+            # [de]
+            de = self._additive_constant(field(0), d * e)
+            # e[x]
+            t1 = self._const_mul(e.value, x)
+            # d[y]
+            t2 = self._const_mul(d.value, y)
+            # d[y] - [de]
+            t3 = self._minus(t2, de)
+            # d[y] - [de] + [c]
+            t4 = self._plus(t3, c)
+            # [z] = e[x] + d[y] - [de] + [c]
+            zi, rhoz, Cz = self._plus(t1, t4)
+            return OrlandiShare(self, field, zi, rhoz, Cz)
+
+        # d = Open([x] - [a])
+        d = self.open(share_x - triple_a)
+        # e = Open([y] - [b])
+        e = self.open(share_y - triple_b)
+        result = gather_shares([share_x, share_y, d, e, triple_c])
+        result.addCallbacks(multiply, self.error_handler)
+
+        # do actual communication
+        self.activate_reactor()
+
+        return result
+
     def error_handler(self, ex):
         print "Error: ", ex
         return ex
+
--- a/viff/test/test_orlandi_runtime.py	Tue Oct 06 10:05:24 2009 +0200
+++ b/viff/test/test_orlandi_runtime.py	Tue Oct 06 10:05:24 2009 +0200
@@ -19,7 +19,7 @@
 
 from viff.test.util import RuntimeTestCase, protocol, BinaryOperatorTestCase
 from viff.runtime import Share, gather_shares
-from viff.orlandi import OrlandiRuntime
+from viff.orlandi import OrlandiRuntime, OrlandiShare
 
 from viff.field import FieldElement, GF
 from viff.passive import PassiveRuntime
@@ -28,6 +28,18 @@
 
 import commitment
 
+
+def _get_triple(runtime, field):
+    n = field(0)
+    Ca = commitment.commit(6, 0, 0)
+    a = OrlandiShare(runtime, field, field(2), (n, n), Ca)
+    Cb = commitment.commit(12, 0, 0)
+    b = OrlandiShare(runtime, field, field(4), (n, n), Cb)
+    Cc = commitment.commit(72, 0, 0)
+    c = OrlandiShare(runtime, field, field(24), (n, n), Cc)
+    return (a, b, c)
+
+
 class OrlandiBasicCommandsTest(RuntimeTestCase):
     """Test for basic commands."""
 
@@ -281,7 +293,6 @@
         d2.addCallback(check)
         return DeferredList([d1, d2])
 
-
     @protocol
     def test_shift_two_consequtive_inputters(self, runtime):
         """Test addition of the shift command."""
@@ -339,3 +350,137 @@
         shares_ready = gather_shares(a_shares + b_shares)
         return shares_ready
 
+    @protocol
+    def test_basic_multiply(self, runtime):
+        """Test multiplication of two numbers."""
+
+        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
+
+        x1 = 42
+        y1 = 7
+ 
+        def check(v):
+            self.assertEquals(v, x1 * y1)
+ 
+        x2 = runtime.shift([2], self.Zp, x1)
+        y2 = runtime.shift([3], self.Zp, y1)
+
+        a, b, c = _get_triple(self, self.Zp)
+        z2 = runtime._basic_multiplication(x2, y2, a, b, c)
+        d = runtime.open(z2)
+        d.addCallback(check)
+        return d
+
+    @protocol
+    def test_mul_mul(self, runtime):
+        """Test multiplication of two numbers."""
+
+        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
+
+        x1 = 42
+        y1 = 7
+ 
+        def check(v):
+            self.assertEquals(v, x1 * y1)
+ 
+        x2 = runtime.shift([2], self.Zp, x1)
+        y2 = runtime.shift([3], self.Zp, y1)
+
+        z2 = x2 * y2
+        d = runtime.open(z2)
+        d.addCallback(check)
+        return d
+
+    @protocol
+    def test_basic_multiply_constant_right(self, runtime):
+        """Test multiplication of two numbers."""
+
+        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
+
+        x1 = 42
+        y1 = 7
+
+        def check(v):
+            self.assertEquals(v, x1 * y1)
+
+        x2 = runtime.shift([1], self.Zp, x1)
+
+        a, b, c = _get_triple(self, self.Zp)
+        z2 = runtime._basic_multiplication(x2, self.Zp(y1), a, b, c)
+        d = runtime.open(z2)
+        d.addCallback(check)
+        return d
+
+    @protocol
+    def test_basic_multiply_constant_left(self, runtime):
+        """Test multiplication of two numbers."""
+
+        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
+
+        x1 = 42
+        y1 = 7
+
+        def check(v):
+            self.assertEquals(v, x1 * y1)
+
+        x2 = runtime.shift([1], self.Zp, x1)
+
+        a, b, c = _get_triple(self, self.Zp)
+        z2 = runtime._basic_multiplication(self.Zp(y1), x2, a, b, c)
+        d = runtime.open(z2)
+        d.addCallback(check)
+        return d
+
+    @protocol
+    def test_constant_multiplication_constant_left(self, runtime):
+        """Test multiplication of two numbers."""
+
+        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
+
+        x1 = 42
+        y1 = 7
+
+        def check(v):
+            self.assertEquals(v, x1 * y1)
+
+        x2 = runtime.shift([1], self.Zp, x1)
+
+        z2 = runtime._cmul(self.Zp(y1), x2, self.Zp)
+        d = runtime.open(z2)
+        d.addCallback(check)
+        return d
+
+    @protocol
+    def test_constant_multiplication_constant_right(self, runtime):
+        """Test multiplication of two numbers."""
+
+        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
+
+        x1 = 42
+        y1 = 7
+
+        def check(v):
+            self.assertEquals(v, x1 * y1)
+
+        x2 = runtime.shift([1], self.Zp, x1)
+
+        z2 = runtime._cmul(x2, self.Zp(y1), self.Zp)
+        d = runtime.open(z2)
+        d.addCallback(check)
+        return d
+
+    @protocol
+    def test_constant_multiplication_constant_None(self, runtime):
+        """Test multiplication of two numbers."""
+
+        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
+
+        x1 = 42
+        y1 = 7
+
+        x2 = runtime.shift([1], self.Zp, x1)
+        y2 = runtime.shift([1], self.Zp, y1)
+
+        z2 = runtime._cmul(y2, x2, self.Zp)
+        self.assertEquals(z2, None)
+        return z2