viff

changeset 1229:eb0443115206

Implementation of the basic multiplication command.
author Janus Dam Nielsen <janus.nielsen@alexandra.dk>
date Tue, 06 Oct 2009 10:05:24 +0200
parents a62e12c9947a
children 86e0c7d54f22
files viff/orlandi.py viff/test/test_orlandi_runtime.py
diffstat 2 files changed, 259 insertions(+), 2 deletions(-) [+]
line diff
     1.1 --- a/viff/orlandi.py	Tue Oct 06 10:05:24 2009 +0200
     1.2 +++ b/viff/orlandi.py	Tue Oct 06 10:05:24 2009 +0200
     1.3 @@ -459,6 +459,22 @@
     1.4               return results[0]
     1.5          return results       
     1.6  
     1.7 +    def mul(self, share_x, share_y):
     1.8 +        """Multiplication of shares.
     1.9 +
    1.10 +        Communication cost: ???.
    1.11 +        """
    1.12 +        # TODO: Communication cost?
    1.13 +        assert isinstance(share_x, Share) or isinstance(share_y, Share), \
    1.14 +            "At least one of share_x and share_y must be a Share."
    1.15 +
    1.16 +        self.program_counter[-1] += 1
    1.17 +
    1.18 +        field = getattr(share_x, "field", getattr(share_y, "field", None))
    1.19 +
    1.20 +        a, b, c = self._get_triple(field)
    1.21 +        return self._basic_multiplication(share_x, share_y, a, b, c)
    1.22 +
    1.23      def _additive_constant(self, zero, field_element):
    1.24          """Greate an additive constant.
    1.25  
    1.26 @@ -506,6 +522,102 @@
    1.27          Cz = Cx / Cy
    1.28          return (zi, (rhozi1, rhozi2), Cz)
    1.29  
    1.30 +    def _cmul(self, share_x, share_y, field):
    1.31 +        """Multiplication of a share with a constant.
    1.32 +
    1.33 +        Either share_x or share_y must be an OrlandiShare but not both.
    1.34 +        Returns None if both share_x and share_y are OrlandiShares.
    1.35 +
    1.36 +        """
    1.37 +        def constant_multiply(x, c):
    1.38 +            assert(isinstance(c, FieldElement))
    1.39 +            zi, rhoz, Cx = self._const_mul(c.value, x)
    1.40 +            return OrlandiShare(self, field, zi, rhoz, Cx)
    1.41 +        if not isinstance(share_x, Share):
    1.42 +            # Then share_y must be a Share => local multiplication. We
    1.43 +            # clone first to avoid changing share_y.
    1.44 +            assert isinstance(share_y, Share), \
    1.45 +                "At least one of the arguments must be a share."
    1.46 +            result = share_y.clone()
    1.47 +            result.addCallback(constant_multiply, share_x)
    1.48 +            return result
    1.49 +        if not isinstance(share_y, Share):
    1.50 +            # Likewise when share_y is a constant.
    1.51 +            assert isinstance(share_x, Share), \
    1.52 +                "At least one of the arguments must be a share."
    1.53 +            result = share_x.clone()
    1.54 +            result.addCallback(constant_multiply, share_y)
    1.55 +            return result
    1.56 +        return None
    1.57 +
    1.58 +    def _const_mul(self, c, x):
    1.59 +        """Multiplication of a share-tuple with a constant c."""
    1.60 +        assert(isinstance(c, long) or isinstance(c, int))
    1.61 +        xi, (rhoi1, rhoi2), Cx = x
    1.62 +        zi = xi * c
    1.63 +        rhoz = (rhoi1 * c, rhoi2 * c)
    1.64 +        Cz = Cx**c
    1.65 +        return (zi, rhoz, Cz)
    1.66 +
    1.67 +    def _get_triple(self, field):
    1.68 +        n = field(0)
    1.69 +        Ca = commitment.commit(6, 0, 0)
    1.70 +        a = OrlandiShare(self, field, field(2), (n, n), Ca)
    1.71 +        Cb = commitment.commit(12, 0, 0)
    1.72 +        b = OrlandiShare(self, field, field(4), (n, n), Cb)
    1.73 +        Cc = commitment.commit(72, 0, 0)
    1.74 +        c = OrlandiShare(self, field, field(24), (n, n), Cc)
    1.75 +        return (a, b, c)
    1.76 +
    1.77 +    def _basic_multiplication(self, share_x, share_y, triple_a, triple_b, triple_c):
    1.78 +        """Multiplication of shares give a triple.
    1.79 +
    1.80 +        Communication cost: ???.
    1.81 +        
    1.82 +        ``d = Open([x] - [a])``
    1.83 +        ``e = Open([y] - [b])``
    1.84 +        ``[z] = e[x] + d[y] - [de] + [c]``
    1.85 +        """
    1.86 +        assert isinstance(share_x, Share) or isinstance(share_y, Share), \
    1.87 +            "At least one of share_x and share_y must be a Share."
    1.88 +
    1.89 +        self.program_counter[-1] += 1
    1.90 +
    1.91 +        field = getattr(share_x, "field", getattr(share_y, "field", None))
    1.92 +        n = field(0)
    1.93 +
    1.94 +        cmul_result = self._cmul(share_x, share_y, field)
    1.95 +        if cmul_result is  not None:
    1.96 +            return cmul_result
    1.97 +
    1.98 +        def multiply((x, y, d, e, c)):
    1.99 +            # [de]
   1.100 +            de = self._additive_constant(field(0), d * e)
   1.101 +            # e[x]
   1.102 +            t1 = self._const_mul(e.value, x)
   1.103 +            # d[y]
   1.104 +            t2 = self._const_mul(d.value, y)
   1.105 +            # d[y] - [de]
   1.106 +            t3 = self._minus(t2, de)
   1.107 +            # d[y] - [de] + [c]
   1.108 +            t4 = self._plus(t3, c)
   1.109 +            # [z] = e[x] + d[y] - [de] + [c]
   1.110 +            zi, rhoz, Cz = self._plus(t1, t4)
   1.111 +            return OrlandiShare(self, field, zi, rhoz, Cz)
   1.112 +
   1.113 +        # d = Open([x] - [a])
   1.114 +        d = self.open(share_x - triple_a)
   1.115 +        # e = Open([y] - [b])
   1.116 +        e = self.open(share_y - triple_b)
   1.117 +        result = gather_shares([share_x, share_y, d, e, triple_c])
   1.118 +        result.addCallbacks(multiply, self.error_handler)
   1.119 +
   1.120 +        # do actual communication
   1.121 +        self.activate_reactor()
   1.122 +
   1.123 +        return result
   1.124 +
   1.125      def error_handler(self, ex):
   1.126          print "Error: ", ex
   1.127          return ex
   1.128 +
     2.1 --- a/viff/test/test_orlandi_runtime.py	Tue Oct 06 10:05:24 2009 +0200
     2.2 +++ b/viff/test/test_orlandi_runtime.py	Tue Oct 06 10:05:24 2009 +0200
     2.3 @@ -19,7 +19,7 @@
     2.4  
     2.5  from viff.test.util import RuntimeTestCase, protocol, BinaryOperatorTestCase
     2.6  from viff.runtime import Share, gather_shares
     2.7 -from viff.orlandi import OrlandiRuntime
     2.8 +from viff.orlandi import OrlandiRuntime, OrlandiShare
     2.9  
    2.10  from viff.field import FieldElement, GF
    2.11  from viff.passive import PassiveRuntime
    2.12 @@ -28,6 +28,18 @@
    2.13  
    2.14  import commitment
    2.15  
    2.16 +
    2.17 +def _get_triple(runtime, field):
    2.18 +    n = field(0)
    2.19 +    Ca = commitment.commit(6, 0, 0)
    2.20 +    a = OrlandiShare(runtime, field, field(2), (n, n), Ca)
    2.21 +    Cb = commitment.commit(12, 0, 0)
    2.22 +    b = OrlandiShare(runtime, field, field(4), (n, n), Cb)
    2.23 +    Cc = commitment.commit(72, 0, 0)
    2.24 +    c = OrlandiShare(runtime, field, field(24), (n, n), Cc)
    2.25 +    return (a, b, c)
    2.26 +
    2.27 +
    2.28  class OrlandiBasicCommandsTest(RuntimeTestCase):
    2.29      """Test for basic commands."""
    2.30  
    2.31 @@ -281,7 +293,6 @@
    2.32          d2.addCallback(check)
    2.33          return DeferredList([d1, d2])
    2.34  
    2.35 -
    2.36      @protocol
    2.37      def test_shift_two_consequtive_inputters(self, runtime):
    2.38          """Test addition of the shift command."""
    2.39 @@ -339,3 +350,137 @@
    2.40          shares_ready = gather_shares(a_shares + b_shares)
    2.41          return shares_ready
    2.42  
    2.43 +    @protocol
    2.44 +    def test_basic_multiply(self, runtime):
    2.45 +        """Test multiplication of two numbers."""
    2.46 +
    2.47 +        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
    2.48 +
    2.49 +        x1 = 42
    2.50 +        y1 = 7
    2.51 + 
    2.52 +        def check(v):
    2.53 +            self.assertEquals(v, x1 * y1)
    2.54 + 
    2.55 +        x2 = runtime.shift([2], self.Zp, x1)
    2.56 +        y2 = runtime.shift([3], self.Zp, y1)
    2.57 +
    2.58 +        a, b, c = _get_triple(self, self.Zp)
    2.59 +        z2 = runtime._basic_multiplication(x2, y2, a, b, c)
    2.60 +        d = runtime.open(z2)
    2.61 +        d.addCallback(check)
    2.62 +        return d
    2.63 +
    2.64 +    @protocol
    2.65 +    def test_mul_mul(self, runtime):
    2.66 +        """Test multiplication of two numbers."""
    2.67 +
    2.68 +        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
    2.69 +
    2.70 +        x1 = 42
    2.71 +        y1 = 7
    2.72 + 
    2.73 +        def check(v):
    2.74 +            self.assertEquals(v, x1 * y1)
    2.75 + 
    2.76 +        x2 = runtime.shift([2], self.Zp, x1)
    2.77 +        y2 = runtime.shift([3], self.Zp, y1)
    2.78 +
    2.79 +        z2 = x2 * y2
    2.80 +        d = runtime.open(z2)
    2.81 +        d.addCallback(check)
    2.82 +        return d
    2.83 +
    2.84 +    @protocol
    2.85 +    def test_basic_multiply_constant_right(self, runtime):
    2.86 +        """Test multiplication of two numbers."""
    2.87 +
    2.88 +        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
    2.89 +
    2.90 +        x1 = 42
    2.91 +        y1 = 7
    2.92 +
    2.93 +        def check(v):
    2.94 +            self.assertEquals(v, x1 * y1)
    2.95 +
    2.96 +        x2 = runtime.shift([1], self.Zp, x1)
    2.97 +
    2.98 +        a, b, c = _get_triple(self, self.Zp)
    2.99 +        z2 = runtime._basic_multiplication(x2, self.Zp(y1), a, b, c)
   2.100 +        d = runtime.open(z2)
   2.101 +        d.addCallback(check)
   2.102 +        return d
   2.103 +
   2.104 +    @protocol
   2.105 +    def test_basic_multiply_constant_left(self, runtime):
   2.106 +        """Test multiplication of two numbers."""
   2.107 +
   2.108 +        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
   2.109 +
   2.110 +        x1 = 42
   2.111 +        y1 = 7
   2.112 +
   2.113 +        def check(v):
   2.114 +            self.assertEquals(v, x1 * y1)
   2.115 +
   2.116 +        x2 = runtime.shift([1], self.Zp, x1)
   2.117 +
   2.118 +        a, b, c = _get_triple(self, self.Zp)
   2.119 +        z2 = runtime._basic_multiplication(self.Zp(y1), x2, a, b, c)
   2.120 +        d = runtime.open(z2)
   2.121 +        d.addCallback(check)
   2.122 +        return d
   2.123 +
   2.124 +    @protocol
   2.125 +    def test_constant_multiplication_constant_left(self, runtime):
   2.126 +        """Test multiplication of two numbers."""
   2.127 +
   2.128 +        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
   2.129 +
   2.130 +        x1 = 42
   2.131 +        y1 = 7
   2.132 +
   2.133 +        def check(v):
   2.134 +            self.assertEquals(v, x1 * y1)
   2.135 +
   2.136 +        x2 = runtime.shift([1], self.Zp, x1)
   2.137 +
   2.138 +        z2 = runtime._cmul(self.Zp(y1), x2, self.Zp)
   2.139 +        d = runtime.open(z2)
   2.140 +        d.addCallback(check)
   2.141 +        return d
   2.142 +
   2.143 +    @protocol
   2.144 +    def test_constant_multiplication_constant_right(self, runtime):
   2.145 +        """Test multiplication of two numbers."""
   2.146 +
   2.147 +        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
   2.148 +
   2.149 +        x1 = 42
   2.150 +        y1 = 7
   2.151 +
   2.152 +        def check(v):
   2.153 +            self.assertEquals(v, x1 * y1)
   2.154 +
   2.155 +        x2 = runtime.shift([1], self.Zp, x1)
   2.156 +
   2.157 +        z2 = runtime._cmul(x2, self.Zp(y1), self.Zp)
   2.158 +        d = runtime.open(z2)
   2.159 +        d.addCallback(check)
   2.160 +        return d
   2.161 +
   2.162 +    @protocol
   2.163 +    def test_constant_multiplication_constant_None(self, runtime):
   2.164 +        """Test multiplication of two numbers."""
   2.165 +
   2.166 +        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
   2.167 +
   2.168 +        x1 = 42
   2.169 +        y1 = 7
   2.170 +
   2.171 +        x2 = runtime.shift([1], self.Zp, x1)
   2.172 +        y2 = runtime.shift([1], self.Zp, y1)
   2.173 +
   2.174 +        z2 = runtime._cmul(y2, x2, self.Zp)
   2.175 +        self.assertEquals(z2, None)
   2.176 +        return z2