changeset 1230:86e0c7d54f22

Implementation of the leak tolerant multiplication command.
author Janus Dam Nielsen <janus.nielsen@alexandra.dk>
date Tue, 06 Oct 2009 10:05:24 +0200
parents eb0443115206
children db2d970885f4
files viff/orlandi.py viff/test/test_orlandi_runtime.py
diffstat 2 files changed, 240 insertions(+), 3 deletions(-) [+]
line wrap: on
line diff
--- a/viff/orlandi.py	Tue Oct 06 10:05:24 2009 +0200
+++ b/viff/orlandi.py	Tue Oct 06 10:05:24 2009 +0200
@@ -81,6 +81,21 @@
         Runtime.__init__(self, player, threshold, options)
         self.threshold = self.num_players - 1
 
+    def compute_delta(self, d):
+        def product(j):
+            pt = 1
+            pn = 1
+            for k in xrange(1, 2 * d + 2):
+                if k != j:
+                    pt *= k
+                    pn *= k - j
+            return pt // pn
+
+        delta = []
+        for j in xrange(1, 2 * d + 2):
+            delta.append(product(j))
+        return delta
+
     def output(self, share, receivers=None, threshold=None):
         return self.open(share, receivers, threshold)
 
@@ -243,8 +258,9 @@
             if Cx1 == Cx:
                 return x
             else:
-                raise OrlandiException("Wrong commitment for value %s, found %s expected %s." % 
-                                       (x, Cx1, Cx))
+                #return x
+                raise OrlandiException("Wrong commitment for value %s, %s, %s, found %s expected %s." % 
+                                       (x, rho1, rho2, Cx1, Cx))
 
         def deserialize(ls):
             shares = [(field(long(x)), field(long(rho1)), field(long(rho2))) for x, rho1, rho2 in map(self.list_str, ls)]
@@ -559,6 +575,11 @@
         Cz = Cx**c
         return (zi, rhoz, Cz)
 
+    def _get_share(self, field, value):
+        Cc = commitment.commit(value * 3, 0, 0)
+        c = OrlandiShare(self, field, field(value), (field(0), field(0)), Cc)
+        return c
+
     def _get_triple(self, field):
         n = field(0)
         Ca = commitment.commit(6, 0, 0)
@@ -617,6 +638,139 @@
 
         return result
 
+    def sum_poly(self, j, ls):
+        exp  = j
+        fj, (rhoj1, rhoj2), Cfj = ls[0]
+        x    = fj*exp
+        rho1 = rhoj1 * exp
+        rho2 = rhoj2 * exp
+        Cx   = Cfj**exp
+        exp *= j
+
+        for (fj, (rhoj1, rhoj2), Cfj) in ls[1:]:
+            x += fj * exp
+            rho1 += rhoj1 * exp
+            rho2 += rhoj2 * exp
+            Cx = Cx * (Cfj**exp)
+            exp *= j
+        return x, (rho1, rho2), Cx
+
+    def leak_tolerant_mul(self, share_x, share_y, M):
+        """Leak tolerant multiplication of shares.
+
+        Communication cost: ???.
+
+        Assuming a set of multiplicative triples:
+        ``M = ([a_i], [b_i], [c_i]) for 1 <= i <= 2d + 1``.
+
+        1) ``for i = 1, ..., d do [f_i] = rand(), [g_i] = rand()``
+
+        2) ``for j = 1, ..., 2d+1 do
+             [F_j] = [x] + SUM_i=1^d [f_i]*j^i 
+             and
+             [G_j] = [y] + SUM_i=1^d [g_i]*j^i`` 
+
+        3) for j = 1, ..., 2d+1 do [H_j] = Mul([F_j], [G_j], [a_j], [b_j], [c_j])
+
+        4) compute [H_0] = SUM_j=1^2d+1 delta_j[H_j] 
+
+        5) output [z] = [H_0]
+
+        delta_j = PRODUCT_k=1, k!=j^2d+1 k/(k-j).
+        """
+        assert isinstance(share_x, Share) or isinstance(share_y, Share), \
+            "At least one of share_x and share_y must be a Share."
+
+        self.program_counter[-1] += 1
+
+        field = getattr(share_x, "field", getattr(share_y, "field", None))
+        n = field(0)
+
+        cmul_result = self._cmul(share_x, share_y, field)
+        if cmul_result is not None:
+            return cmul_result
+
+        # 1) for i = 1, ..., d do [f_i] = rand(), [g_i] = rand()
+        d = (len(M) - 1) // 2
+        deltas = self.compute_delta(d)
+        f = []
+        g = []
+        for x in xrange(d):
+            f.append(self.random_share(field))
+            g.append(self.random_share(field))
+
+        def compute_polynomials(t):
+            x, y = t[0]
+            f = []
+            g = []
+            if 1 in t:
+                f = t[1]
+            if 2 in t:
+                g = t[2]
+#             print "==> poly", self.id
+#             print "x:", x
+#             print "y:", y
+#             print "f:", f
+#             print "g:", g
+            # 2) for j = 1, ..., 2d+1 do
+            # [F_j] = [x] + SUM_i=1^d [f_i]*j^i 
+            # and
+            # [G_j] = [y] + SUM_i=1^d [g_i]*j^i 
+            h0i, rhoh0, Ch0 = self._additive_constant(field(0), n)
+            H0 = OrlandiShare(self, field, h0i, rhoh0, Ch0)
+            xi, (rhoxi1, rhoxi2), Cx = x
+            yi, (rhoyi1, rhoyi2), Cy = y
+
+            for j in xrange(1, 2*d + 2):
+                Fji = xi
+                rho1_Fji = rhoxi1
+                rho2_Fji = rhoxi2
+                C_Fji = Cx
+                if f != []:
+                    # SUM_i=1^d [f_i]*j^i 
+                    vi, (rhovi1, rhovi2), Cv = self.sum_poly(j, f)
+                    # [F_j] = [x] + SUM_i=1^d [f_i]*j^i 
+                    Fji += vi
+                    rho1_Fji += rhovi1
+                    rho2_Fji += rhovi2
+                    C_Fji *= Cv
+                Gji = yi
+                rho1_Gji = rhoyi1
+                rho2_Gji = rhoyi2
+                C_Gji = Cy
+                if g != []:
+                    # SUM_i=1^d [g_i]*j^i 
+                    wi, (rhowi1, rhowi2), Cw = self.sum_poly(j, g)
+                    # [G_j] = [y] + SUM_i=1^d [g_i]*j^i
+                    Gji += wi
+                    rho1_Gji += rhowi1
+                    rho2_Gji += rhowi2
+                    C_Gji *= Cw
+                Fj = OrlandiShare(self, field, Fji, (rho1_Fji, rho2_Fji), C_Fji)
+                Gj = OrlandiShare(self, field, Gji, (rho1_Gji, rho2_Gji), C_Gji)
+                a, b, c = M.pop(0)
+
+                # [H_j] = Mul([F_j], [G_j], [a_j], [b_j], [c_j])
+                Hj = self._basic_multiplication(Fj, Gj, a, b, c)
+                dj = self._cmul(field(deltas[j - 1]), Hj, field)
+                H0 = H0 + dj
+            # 5) output [z] = [H_0]
+            return H0
+
+        ls = [gather_shares([share_x, share_y])]
+        if g:
+            ls.append(gather_shares(g))
+        if f:
+            ls.append(gather_shares(f))
+        result = gather_shares(ls)
+        self.schedule_callback(result, compute_polynomials)
+        result.addErrback(self.error_handler)
+
+        # do actual communication
+        self.activate_reactor()
+
+        return result
+
     def error_handler(self, ex):
         print "Error: ", ex
         return ex
--- a/viff/test/test_orlandi_runtime.py	Tue Oct 06 10:05:24 2009 +0200
+++ b/viff/test/test_orlandi_runtime.py	Tue Oct 06 10:05:24 2009 +0200
@@ -334,7 +334,7 @@
 
         self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
 
-        count = 20
+        count = 9
 
         a_shares = []
         b_shares = []
@@ -481,6 +481,89 @@
         x2 = runtime.shift([1], self.Zp, x1)
         y2 = runtime.shift([1], self.Zp, y1)
 
+    @protocol
+    def test_sum_poly(self, runtime):
+        """Test implementation of sum_poly."""
+
+        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
+
+        f = []
+        f.append((self.Zp(7), (self.Zp(7), self.Zp(7)), self.Zp(7)))
+        f.append((self.Zp(9), (self.Zp(9), self.Zp(9)), self.Zp(9)))
+        f.append((self.Zp(13), (self.Zp(13), self.Zp(13)), self.Zp(13)))
+        
+        x, (rho1, rho2), Cx = runtime.sum_poly(1, f)
+        self.assertEquals(x, 29)
+        self.assertEquals(rho1, 29)
+        self.assertEquals(rho2, 29)
+        self.assertEquals(Cx, 29)
+        return x
+ 
+    @protocol
+    def test_sum_poly(self, runtime):
+        """Test implementation of sum_poly."""
+
+        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
+        
+        Cf1 = commitment.commit(21, 21, 21)
+        Cf2 = commitment.commit(27, 27, 27)
+        Cf3 = commitment.commit(39, 39, 39)
+
+        f = []
+        f.append((self.Zp(7), (self.Zp(7), self.Zp(7)), Cf1))
+        f.append((self.Zp(9), (self.Zp(9), self.Zp(9)), Cf2))
+        f.append((self.Zp(13), (self.Zp(13), self.Zp(13)), Cf3))
+        
+        x, (rho1, rho2), Cx = runtime.sum_poly(3, f)
+        self.assertEquals(x, 453)
+        self.assertEquals(rho1, 453)
+        self.assertEquals(rho2, 453)
+        self.assertEquals(Cx, Cf1**3 * Cf2**9 * Cf3**27)
+        return x
+
+    @protocol
+    def test_delta(self, runtime):
+        """Test implementation of compute_delta."""
+
+        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
+
+        delta = runtime.compute_delta(3)
+        self.assertEquals(delta[0], 7)
+        self.assertEquals(delta[1], -21)
+        self.assertEquals(delta[2], 35)
+        self.assertEquals(delta[3], -35)
+        self.assertEquals(delta[4], 21)
+        self.assertEquals(delta[5], -7)
+        self.assertEquals(delta[6], 1)
+ 
+        return delta
+
+    @protocol
+    def test_leak_mul(self, runtime):
+        """Test leaktolerant multiplication of two numbers."""
+        commitment.set_reference_string(long(2), long(6))
+
+        self.Zp = GF(6277101735386680763835789423176059013767194773182842284081)
+       
+        x1 = 42
+        y1 = 7
+
+        d = 4
+        
+        def check(v):
+            self.assertEquals(v, x1 * y1)
+ 
+        x2 = runtime.shift([1], self.Zp, x1)
+        y2 = runtime.shift([2], self.Zp, y1)
+
+        M = []
+        for j in xrange(1, 2*d + 2):
+            M.append(_get_triple(self, self.Zp))
+        z2 = runtime.leak_tolerant_mul(x2, y2, M)
+        d = runtime.open(z2)
+        d.addCallback(check)
+        return d
+
         z2 = runtime._cmul(y2, x2, self.Zp)
         self.assertEquals(z2, None)
         return z2