viff

changeset 1230:86e0c7d54f22

Implementation of the leak tolerant multiplication command.
author Janus Dam Nielsen <janus.nielsen@alexandra.dk>
date Tue, 06 Oct 2009 10:05:24 +0200
parents eb0443115206
children db2d970885f4
files viff/orlandi.py viff/test/test_orlandi_runtime.py
diffstat 2 files changed, 240 insertions(+), 3 deletions(-) [+]
line diff
     1.1 --- a/viff/orlandi.py	Tue Oct 06 10:05:24 2009 +0200
     1.2 +++ b/viff/orlandi.py	Tue Oct 06 10:05:24 2009 +0200
     1.3 @@ -81,6 +81,21 @@
     1.4          Runtime.__init__(self, player, threshold, options)
     1.5          self.threshold = self.num_players - 1
     1.6  
     1.7 +    def compute_delta(self, d):
     1.8 +        def product(j):
     1.9 +            pt = 1
    1.10 +            pn = 1
    1.11 +            for k in xrange(1, 2 * d + 2):
    1.12 +                if k != j:
    1.13 +                    pt *= k
    1.14 +                    pn *= k - j
    1.15 +            return pt // pn
    1.16 +
    1.17 +        delta = []
    1.18 +        for j in xrange(1, 2 * d + 2):
    1.19 +            delta.append(product(j))
    1.20 +        return delta
    1.21 +
    1.22      def output(self, share, receivers=None, threshold=None):
    1.23          return self.open(share, receivers, threshold)
    1.24  
    1.25 @@ -243,8 +258,9 @@
    1.26              if Cx1 == Cx:
    1.27                  return x
    1.28              else:
    1.29 -                raise OrlandiException("Wrong commitment for value %s, found %s expected %s." % 
    1.30 -                                       (x, Cx1, Cx))
    1.31 +                #return x
    1.32 +                raise OrlandiException("Wrong commitment for value %s, %s, %s, found %s expected %s." % 
    1.33 +                                       (x, rho1, rho2, Cx1, Cx))
    1.34  
    1.35          def deserialize(ls):
    1.36              shares = [(field(long(x)), field(long(rho1)), field(long(rho2))) for x, rho1, rho2 in map(self.list_str, ls)]
    1.37 @@ -559,6 +575,11 @@
    1.38          Cz = Cx**c
    1.39          return (zi, rhoz, Cz)
    1.40  
    1.41 +    def _get_share(self, field, value):
    1.42 +        Cc = commitment.commit(value * 3, 0, 0)
    1.43 +        c = OrlandiShare(self, field, field(value), (field(0), field(0)), Cc)
    1.44 +        return c
    1.45 +
    1.46      def _get_triple(self, field):
    1.47          n = field(0)
    1.48          Ca = commitment.commit(6, 0, 0)
    1.49 @@ -617,6 +638,139 @@
    1.50  
    1.51          return result
    1.52  
    1.53 +    def sum_poly(self, j, ls):
    1.54 +        exp  = j
    1.55 +        fj, (rhoj1, rhoj2), Cfj = ls[0]
    1.56 +        x    = fj*exp
    1.57 +        rho1 = rhoj1 * exp
    1.58 +        rho2 = rhoj2 * exp
    1.59 +        Cx   = Cfj**exp
    1.60 +        exp *= j
    1.61 +
    1.62 +        for (fj, (rhoj1, rhoj2), Cfj) in ls[1:]:
    1.63 +            x += fj * exp
    1.64 +            rho1 += rhoj1 * exp
    1.65 +            rho2 += rhoj2 * exp
    1.66 +            Cx = Cx * (Cfj**exp)
    1.67 +            exp *= j
    1.68 +        return x, (rho1, rho2), Cx
    1.69 +
    1.70 +    def leak_tolerant_mul(self, share_x, share_y, M):
    1.71 +        """Leak tolerant multiplication of shares.
    1.72 +
    1.73 +        Communication cost: ???.
    1.74 +
    1.75 +        Assuming a set of multiplicative triples:
    1.76 +        ``M = ([a_i], [b_i], [c_i]) for 1 <= i <= 2d + 1``.
    1.77 +
    1.78 +        1) ``for i = 1, ..., d do [f_i] = rand(), [g_i] = rand()``
    1.79 +
    1.80 +        2) ``for j = 1, ..., 2d+1 do
    1.81 +             [F_j] = [x] + SUM_i=1^d [f_i]*j^i 
    1.82 +             and
    1.83 +             [G_j] = [y] + SUM_i=1^d [g_i]*j^i`` 
    1.84 +
    1.85 +        3) for j = 1, ..., 2d+1 do [H_j] = Mul([F_j], [G_j], [a_j], [b_j], [c_j])
    1.86 +
    1.87 +        4) compute [H_0] = SUM_j=1^2d+1 delta_j[H_j] 
    1.88 +
    1.89 +        5) output [z] = [H_0]
    1.90 +
    1.91 +        delta_j = PRODUCT_k=1, k!=j^2d+1 k/(k-j).
    1.92 +        """
    1.93 +        assert isinstance(share_x, Share) or isinstance(share_y, Share), \
    1.94 +            "At least one of share_x and share_y must be a Share."
    1.95 +
    1.96 +        self.program_counter[-1] += 1
    1.97 +
    1.98 +        field = getattr(share_x, "field", getattr(share_y, "field", None))
    1.99 +        n = field(0)
   1.100 +
   1.101 +        cmul_result = self._cmul(share_x, share_y, field)
   1.102 +        if cmul_result is not None:
   1.103 +            return cmul_result
   1.104 +
   1.105 +        # 1) for i = 1, ..., d do [f_i] = rand(), [g_i] = rand()
   1.106 +        d = (len(M) - 1) // 2
   1.107 +        deltas = self.compute_delta(d)
   1.108 +        f = []
   1.109 +        g = []
   1.110 +        for x in xrange(d):
   1.111 +            f.append(self.random_share(field))
   1.112 +            g.append(self.random_share(field))
   1.113 +
   1.114 +        def compute_polynomials(t):
   1.115 +            x, y = t[0]
   1.116 +            f = []
   1.117 +            g = []
   1.118 +            if 1 in t:
   1.119 +                f = t[1]
   1.120 +            if 2 in t:
   1.121 +                g = t[2]
   1.122 +#             print "==> poly", self.id
   1.123 +#             print "x:", x
   1.124 +#             print "y:", y
   1.125 +#             print "f:", f
   1.126 +#             print "g:", g
   1.127 +            # 2) for j = 1, ..., 2d+1 do
   1.128 +            # [F_j] = [x] + SUM_i=1^d [f_i]*j^i 
   1.129 +            # and
   1.130 +            # [G_j] = [y] + SUM_i=1^d [g_i]*j^i 
   1.131 +            h0i, rhoh0, Ch0 = self._additive_constant(field(0), n)
   1.132 +            H0 = OrlandiShare(self, field, h0i, rhoh0, Ch0)
   1.133 +            xi, (rhoxi1, rhoxi2), Cx = x
   1.134 +            yi, (rhoyi1, rhoyi2), Cy = y
   1.135 +
   1.136 +            for j in xrange(1, 2*d + 2):
   1.137 +                Fji = xi
   1.138 +                rho1_Fji = rhoxi1
   1.139 +                rho2_Fji = rhoxi2
   1.140 +                C_Fji = Cx
   1.141 +                if f != []:
   1.142 +                    # SUM_i=1^d [f_i]*j^i 
   1.143 +                    vi, (rhovi1, rhovi2), Cv = self.sum_poly(j, f)
   1.144 +                    # [F_j] = [x] + SUM_i=1^d [f_i]*j^i 
   1.145 +                    Fji += vi
   1.146 +                    rho1_Fji += rhovi1
   1.147 +                    rho2_Fji += rhovi2
   1.148 +                    C_Fji *= Cv
   1.149 +                Gji = yi
   1.150 +                rho1_Gji = rhoyi1
   1.151 +                rho2_Gji = rhoyi2
   1.152 +                C_Gji = Cy
   1.153 +                if g != []:
   1.154 +                    # SUM_i=1^d [g_i]*j^i 
   1.155 +                    wi, (rhowi1, rhowi2), Cw = self.sum_poly(j, g)
   1.156 +                    # [G_j] = [y] + SUM_i=1^d [g_i]*j^i
   1.157 +                    Gji += wi
   1.158 +                    rho1_Gji += rhowi1
   1.159 +                    rho2_Gji += rhowi2
   1.160 +                    C_Gji *= Cw
   1.161 +                Fj = OrlandiShare(self, field, Fji, (rho1_Fji, rho2_Fji), C_Fji)
   1.162 +                Gj = OrlandiShare(self, field, Gji, (rho1_Gji, rho2_Gji), C_Gji)
   1.163 +                a, b, c = M.pop(0)
   1.164 +
   1.165 +                # [H_j] = Mul([F_j], [G_j], [a_j], [b_j], [c_j])
   1.166 +                Hj = self._basic_multiplication(Fj, Gj, a, b, c)
   1.167 +                dj = self._cmul(field(deltas[j - 1]), Hj, field)
   1.168 +                H0 = H0 + dj
   1.169 +            # 5) output [z] = [H_0]
   1.170 +            return H0
   1.171 +
   1.172 +        ls = [gather_shares([share_x, share_y])]
   1.173 +        if g:
   1.174 +            ls.append(gather_shares(g))
   1.175 +        if f:
   1.176 +            ls.append(gather_shares(f))
   1.177 +        result = gather_shares(ls)
   1.178 +        self.schedule_callback(result, compute_polynomials)
   1.179 +        result.addErrback(self.error_handler)
   1.180 +
   1.181 +        # do actual communication
   1.182 +        self.activate_reactor()
   1.183 +
   1.184 +        return result
   1.185 +
   1.186      def error_handler(self, ex):
   1.187          print "Error: ", ex
   1.188          return ex
     2.1 --- a/viff/test/test_orlandi_runtime.py	Tue Oct 06 10:05:24 2009 +0200
     2.2 +++ b/viff/test/test_orlandi_runtime.py	Tue Oct 06 10:05:24 2009 +0200
     2.3 @@ -334,7 +334,7 @@
     2.4  
     2.5          self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
     2.6  
     2.7 -        count = 20
     2.8 +        count = 9
     2.9  
    2.10          a_shares = []
    2.11          b_shares = []
    2.12 @@ -481,6 +481,89 @@
    2.13          x2 = runtime.shift([1], self.Zp, x1)
    2.14          y2 = runtime.shift([1], self.Zp, y1)
    2.15  
    2.16 +    @protocol
    2.17 +    def test_sum_poly(self, runtime):
    2.18 +        """Test implementation of sum_poly."""
    2.19 +
    2.20 +        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
    2.21 +
    2.22 +        f = []
    2.23 +        f.append((self.Zp(7), (self.Zp(7), self.Zp(7)), self.Zp(7)))
    2.24 +        f.append((self.Zp(9), (self.Zp(9), self.Zp(9)), self.Zp(9)))
    2.25 +        f.append((self.Zp(13), (self.Zp(13), self.Zp(13)), self.Zp(13)))
    2.26 +        
    2.27 +        x, (rho1, rho2), Cx = runtime.sum_poly(1, f)
    2.28 +        self.assertEquals(x, 29)
    2.29 +        self.assertEquals(rho1, 29)
    2.30 +        self.assertEquals(rho2, 29)
    2.31 +        self.assertEquals(Cx, 29)
    2.32 +        return x
    2.33 + 
    2.34 +    @protocol
    2.35 +    def test_sum_poly(self, runtime):
    2.36 +        """Test implementation of sum_poly."""
    2.37 +
    2.38 +        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
    2.39 +        
    2.40 +        Cf1 = commitment.commit(21, 21, 21)
    2.41 +        Cf2 = commitment.commit(27, 27, 27)
    2.42 +        Cf3 = commitment.commit(39, 39, 39)
    2.43 +
    2.44 +        f = []
    2.45 +        f.append((self.Zp(7), (self.Zp(7), self.Zp(7)), Cf1))
    2.46 +        f.append((self.Zp(9), (self.Zp(9), self.Zp(9)), Cf2))
    2.47 +        f.append((self.Zp(13), (self.Zp(13), self.Zp(13)), Cf3))
    2.48 +        
    2.49 +        x, (rho1, rho2), Cx = runtime.sum_poly(3, f)
    2.50 +        self.assertEquals(x, 453)
    2.51 +        self.assertEquals(rho1, 453)
    2.52 +        self.assertEquals(rho2, 453)
    2.53 +        self.assertEquals(Cx, Cf1**3 * Cf2**9 * Cf3**27)
    2.54 +        return x
    2.55 +
    2.56 +    @protocol
    2.57 +    def test_delta(self, runtime):
    2.58 +        """Test implementation of compute_delta."""
    2.59 +
    2.60 +        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
    2.61 +
    2.62 +        delta = runtime.compute_delta(3)
    2.63 +        self.assertEquals(delta[0], 7)
    2.64 +        self.assertEquals(delta[1], -21)
    2.65 +        self.assertEquals(delta[2], 35)
    2.66 +        self.assertEquals(delta[3], -35)
    2.67 +        self.assertEquals(delta[4], 21)
    2.68 +        self.assertEquals(delta[5], -7)
    2.69 +        self.assertEquals(delta[6], 1)
    2.70 + 
    2.71 +        return delta
    2.72 +
    2.73 +    @protocol
    2.74 +    def test_leak_mul(self, runtime):
    2.75 +        """Test leaktolerant multiplication of two numbers."""
    2.76 +        commitment.set_reference_string(long(2), long(6))
    2.77 +
    2.78 +        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
    2.79 +       
    2.80 +        x1 = 42
    2.81 +        y1 = 7
    2.82 +
    2.83 +        d = 4
    2.84 +        
    2.85 +        def check(v):
    2.86 +            self.assertEquals(v, x1 * y1)
    2.87 + 
    2.88 +        x2 = runtime.shift([1], self.Zp, x1)
    2.89 +        y2 = runtime.shift([2], self.Zp, y1)
    2.90 +
    2.91 +        M = []
    2.92 +        for j in xrange(1, 2*d + 2):
    2.93 +            M.append(_get_triple(self, self.Zp))
    2.94 +        z2 = runtime.leak_tolerant_mul(x2, y2, M)
    2.95 +        d = runtime.open(z2)
    2.96 +        d.addCallback(check)
    2.97 +        return d
    2.98 +
    2.99          z2 = runtime._cmul(y2, x2, self.Zp)
   2.100          self.assertEquals(z2, None)
   2.101          return z2