viff

changeset 1330:6f02ecfa60e1

orlandi: Adapted to new preprocessing interface.
author Marcel Keller <mkeller@cs.au.dk>
date Sat, 24 Oct 2009 18:39:19 +0200
parents c4803511dbf8
children 7856dceaf7b5
files viff/orlandi.py viff/test/test_orlandi_runtime.py
diffstat 2 files changed, 32 insertions(+), 45 deletions(-) [+]
line diff
     1.1 --- a/viff/orlandi.py	Fri Oct 23 15:03:30 2009 +0200
     1.2 +++ b/viff/orlandi.py	Sat Oct 24 18:39:19 2009 +0200
     1.3 @@ -520,17 +520,11 @@
     1.4  
     1.5          field = getattr(share_x, "field", getattr(share_y, "field", None))
     1.6  
     1.7 -        def finish_mul((a, b, c)):
     1.8 -            return self._basic_multiplication(share_x, share_y, a, b, c)
     1.9 -
    1.10 -        # This will be the result, a Share object.
    1.11 -        result = Share(self, share_x.field)
    1.12 -        # This is the Deferred we will do processing on.
    1.13 -        triple = self._get_triple(field)
    1.14 -        triple = self.schedule_complex_callback(triple, finish_mul)
    1.15 -        # We add the result to the chains in triple.
    1.16 -        triple.chainDeferred(result)
    1.17 -        return result
    1.18 +        triple, prep = self._get_triple(field)
    1.19 +        if prep:
    1.20 +            # The data from the pool must be wrapped in Shares.
    1.21 +            triple = [Share(self, field, i) for i in triple]
    1.22 +        return self._basic_multiplication(share_x, share_y, *triple)
    1.23  
    1.24      def _additive_constant(self, zero, field_element):
    1.25          """Greate an additive constant.
    1.26 @@ -623,11 +617,13 @@
    1.27  
    1.28      @preprocess("random_triple")
    1.29      def _get_triple(self, field):
    1.30 -        c, d = self.random_triple(field, 1)
    1.31 -        def f(ls):
    1.32 -            return ls[0]
    1.33 -        d.addCallbacks(f, self.error_handler)
    1.34 -        return d
    1.35 +        results = [Share(self, field) for i in range(3)]
    1.36 +        def chain(triple, results):
    1.37 +            for i, result in zip(triple, results):
    1.38 +                result.callback(i)
    1.39 +        self.random_triple(field, 1)[0].addCallbacks(chain, self.error_handler,
    1.40 +                                                     (results,))
    1.41 +        return results
    1.42  
    1.43      def _basic_multiplication(self, share_x, share_y, triple_a, triple_b, triple_c):
    1.44          """Multiplication of shares give a triple.
    1.45 @@ -1314,6 +1310,7 @@
    1.46              r.addErrback(self.error_handler)
    1.47              return r
    1.48  
    1.49 +        results = [Deferred() for i in xrange(quantity)]
    1.50  
    1.51          def step7(Msets):
    1.52              """For i = 1,...,M do:
    1.53 @@ -1323,21 +1320,15 @@
    1.54              d) Open([c] + [r])
    1.55              """
    1.56              ds = []
    1.57 -            for Mi in Msets:
    1.58 +            for Mi, result in zip(Msets, results):
    1.59                  a = self.random_share(field)
    1.60                  b = self.random_share(field)
    1.61                  r = self.random_share(field)
    1.62                  c = self.leak_tolerant_mul(a, b, Mi)
    1.63                  d = self.open(c + r)
    1.64 -                def return_abc(x, a, b, c):
    1.65 -                    return a, b, c
    1.66 -                d.addCallbacks(return_abc, self.error_handler, callbackArgs=(a, b, c))
    1.67 -                ds.append(d)
    1.68 -            result = gather_shares(ds)
    1.69 -            def triples(ls):
    1.70 -                return ls
    1.71 -            result.addCallbacks(triples, self.error_handler)
    1.72 -            return result
    1.73 +                def return_abc(x, a, b, c, result):
    1.74 +                    gatherResults([a, b, c]).chainDeferred(result)
    1.75 +                d.addCallbacks(return_abc, self.error_handler, callbackArgs=(a, b, c, result))
    1.76  
    1.77          result = gatherResults(M)
    1.78          self.schedule_callback(result, step3)
    1.79 @@ -1348,12 +1339,7 @@
    1.80  
    1.81          # do actual communication
    1.82          self.activate_reactor()
    1.83 -
    1.84 -        s = Share(self, field)
    1.85 -        # We add the result to the chains in result.
    1.86 -        result.chainDeferred(s)
    1.87 -
    1.88 -        return quantity, s
    1.89 +        return results
    1.90  
    1.91      def error_handler(self, ex):
    1.92          print "Error: ", ex
     2.1 --- a/viff/test/test_orlandi_runtime.py	Fri Oct 23 15:03:30 2009 +0200
     2.2 +++ b/viff/test/test_orlandi_runtime.py	Sat Oct 24 18:39:19 2009 +0200
     2.3 @@ -18,7 +18,7 @@
     2.4  from twisted.internet.defer import gatherResults, DeferredList
     2.5  
     2.6  from viff.test.util import RuntimeTestCase, protocol
     2.7 -from viff.runtime import gather_shares
     2.8 +from viff.runtime import gather_shares, Share
     2.9  try:
    2.10      from viff.orlandi import OrlandiRuntime, OrlandiShare
    2.11      import commitment
    2.12 @@ -560,9 +560,10 @@
    2.13          x2 = runtime.shift([1], self.Zp, x1)
    2.14          y2 = runtime.shift([2], self.Zp, y1)
    2.15  
    2.16 -        c, sls = runtime.random_triple(self.Zp, 2*runtime.d + 1)
    2.17 +        sls = gatherResults(runtime.random_triple(self.Zp, 2*runtime.d + 1))
    2.18  
    2.19          def cont(M):
    2.20 +            M = [[Share(self, self.Zp, j) for j in i] for i in M]
    2.21              z2 = runtime.leak_tolerant_mul(x2, y2, M)
    2.22              d = runtime.open(z2)
    2.23              d.addCallback(check)
    2.24 @@ -666,9 +667,9 @@
    2.25          def open(ls):
    2.26              ds = []
    2.27              for (a, b, c) in ls:
    2.28 -                d1 = runtime.open(a)
    2.29 -                d2 = runtime.open(b)
    2.30 -                d3 = runtime.open(c)
    2.31 +                d1 = runtime.open(Share(self, self.Zp, a))
    2.32 +                d2 = runtime.open(Share(self, self.Zp, b))
    2.33 +                d3 = runtime.open(Share(self, self.Zp, c))
    2.34                  ds.append(d1)
    2.35                  ds.append(d2)
    2.36                  ds.append(d3)
    2.37 @@ -676,7 +677,7 @@
    2.38              d = gatherResults(ds)
    2.39              d.addCallback(check)
    2.40              return d
    2.41 -        c, d = runtime.random_triple(self.Zp, 1)
    2.42 +        d = gatherResults(runtime.random_triple(self.Zp, 1))
    2.43          d.addCallbacks(open, runtime.error_handler)
    2.44          return d
    2.45  
    2.46 @@ -696,9 +697,9 @@
    2.47          def open(ls):
    2.48              ds = []
    2.49              for [(a, b, c)] in ls:
    2.50 -                d1 = runtime.open(a)
    2.51 -                d2 = runtime.open(b)
    2.52 -                d3 = runtime.open(c)
    2.53 +                d1 = runtime.open(Share(self, self.Zp, a))
    2.54 +                d2 = runtime.open(Share(self, self.Zp, b))
    2.55 +                d3 = runtime.open(Share(self, self.Zp, c))
    2.56                  ds.append(d1)
    2.57                  ds.append(d2)
    2.58                  ds.append(d3)
    2.59 @@ -706,10 +707,10 @@
    2.60              d = gatherResults(ds)
    2.61              d.addCallback(check)
    2.62              return d
    2.63 -        ac, a = runtime.random_triple(self.Zp, 1)
    2.64 -        bc, b = runtime.random_triple(self.Zp, 1)
    2.65 -        cc, c = runtime.random_triple(self.Zp, 1)
    2.66 -        d = gather_shares([a, b, c])
    2.67 +        a = gatherResults(runtime.random_triple(self.Zp, 1))
    2.68 +        b = gatherResults(runtime.random_triple(self.Zp, 1))
    2.69 +        c = gatherResults(runtime.random_triple(self.Zp, 1))
    2.70 +        d = gatherResults([a, b, c])
    2.71          d.addCallbacks(open, runtime.error_handler)
    2.72          return d
    2.73