viff

changeset 1137:b6d229859b5b

Added an inversion by exponentiation variant with minimal number of multiplications and less rounds. Changed the internal structure regarding the inversions, the choice is now made in __init__().
author Marcel Keller <mkeller@cs.au.dk>
date Tue, 17 Feb 2009 18:15:32 +0100
parents 72b7a0717627
children cecc7b3c6eb0
files viff/aes.py
diffstat 1 files changed, 89 insertions(+), 60 deletions(-) [+]
line diff
     1.1 --- a/viff/aes.py	Tue Feb 17 11:21:38 2009 +0100
     1.2 +++ b/viff/aes.py	Tue Feb 17 18:15:32 2009 +0100
     1.3 @@ -72,7 +72,7 @@
     1.4      """
     1.5  
     1.6      def __init__(self, runtime, key_size, block_size=128, 
     1.7 -                 use_exponentiation=False, use_square_and_multiply=False):
     1.8 +                 use_exponentiation=False):
     1.9          """Initialize Rijndael.
    1.10  
    1.11          AES(runtime, key_size, block_size), whereas key size and block
    1.12 @@ -87,8 +87,93 @@
    1.13          self.n_b = block_size / 32
    1.14          self.rounds = max(self.n_k, self.n_b) + 6
    1.15          self.runtime = runtime
    1.16 -        self.use_exponentiation = use_exponentiation
    1.17 -        self.use_square_and_multiply = use_square_and_multiply
    1.18 +
    1.19 +        if (use_exponentiation is not False):
    1.20 +            if (isinstance(use_exponentiation, int) and
    1.21 +                use_exponentiation < len(AES.exponentiation_variants)):
    1.22 +                use_exponentiation = \
    1.23 +                    AES.exponentiation_variants[use_exponentiation]
    1.24 +            elif (use_exponentiation not in AES.exponentation_variants):
    1.25 +                use_exponentiation = "shortest_sequential_chain"
    1.26 +
    1.27 +            print "Use %s for inversion by exponentiation." % \
    1.28 +                use_exponentiation
    1.29 +
    1.30 +            if (use_exponentiation == "standard_square_and_multiply"):
    1.31 +                self.invert = lambda byte: byte ** 254
    1.32 +            elif (use_exponentiation == "shortest_chain_with_least_rounds"):
    1.33 +                self.invert = self.invert_by_exponentiation_with_less_rounds
    1.34 +            else:
    1.35 +                self.invert = self.invert_by_exponentiation
    1.36 +        else:
    1.37 +            self.invert = self.invert_by_masking
    1.38 +            print "Use inversion by masking."
    1.39 +
    1.40 +    exponentiation_variants = ["standard_square_and_multiply",
    1.41 +                               "shortest_sequential_chain",
    1.42 +                               "shortest_chain_with_least_rounds"]
    1.43 +
    1.44 +    def invert_by_masking(self, byte):
    1.45 +        bits = bit_decompose(byte)
    1.46 +
    1.47 +        for j in range(len(bits)):
    1.48 +            bits[j].addCallback(lambda x: GF256(1) - x)
    1.49 +#            bits[j] = 1 - bits[j]
    1.50 +
    1.51 +        while(len(bits) > 1):
    1.52 +            bits.append(bits.pop(0) * bits.pop(0))
    1.53 +
    1.54 +        # b == 1 if byte is 0, b == 0 else
    1.55 +        b = bits[0]
    1.56 +
    1.57 +        r = Share(self.runtime, GF256)
    1.58 +        c = Share(self.runtime, GF256)
    1.59 +
    1.60 +        def get_masked_byte(c_opened, r_related, c, r, byte):
    1.61 +            if (c_opened == 0):
    1.62 +                r_trial = self.runtime.prss_share_random(GF256)
    1.63 +                c_trial = self.runtime.open((byte + b) * r_trial)
    1.64 +                c_trial.addCallback(get_masked_byte, r_trial,
    1.65 +                                    c, r, byte)
    1.66 +            else:
    1.67 +                r_related.addCallback(r.callback)
    1.68 +                c.callback(~c_opened)
    1.69 +
    1.70 +        get_masked_byte(0, None, c, r, byte)
    1.71 +
    1.72 +        # necessary to avoid communication in multiplication
    1.73 +        # was: return c * r - b
    1.74 +        result = gather_shares([c, r, b])
    1.75 +        result.addCallback(lambda (c, r, b): c * r - b)
    1.76 +        return result
    1.77 +
    1.78 +    def invert_by_exponentiation(self, byte):
    1.79 +        byte_2 = byte * byte
    1.80 +        byte_3 = byte_2 * byte
    1.81 +        byte_6 = byte_3 * byte_3
    1.82 +        byte_12 = byte_6 * byte_6
    1.83 +        byte_15 = byte_12 * byte_3
    1.84 +        byte_30 = byte_15 * byte_15
    1.85 +        byte_60 = byte_30 * byte_30
    1.86 +        byte_63 = byte_60 * byte_3
    1.87 +        byte_126 = byte_63 * byte_63
    1.88 +        byte_252 = byte_126 * byte_126
    1.89 +        byte_254 = byte_252 * byte_2
    1.90 +        return byte_254
    1.91 +
    1.92 +    def invert_by_exponentiation_with_less_rounds(self, byte):
    1.93 +        byte_2 = byte * byte
    1.94 +        byte_4 = byte_2 * byte_2
    1.95 +        byte_8 = byte_4 * byte_4
    1.96 +        byte_9 = byte_8 * byte
    1.97 +        byte_16 = byte_8 * byte_8
    1.98 +        byte_25 = byte_16 * byte_9
    1.99 +        byte_50 = byte_25 * byte_25
   1.100 +        byte_54 = byte_50 * byte_4
   1.101 +        byte_100 = byte_50 * byte_50
   1.102 +        byte_200 = byte_100 * byte_100
   1.103 +        byte_254 = byte_200 * byte_54
   1.104 +        return byte_254
   1.105  
   1.106      # matrix for byte_sub, the last column is the translation vector
   1.107      A = Matrix([[1,0,0,0,1,1,1,1, 1],
   1.108 @@ -106,67 +191,11 @@
   1.109          The first argument should be a matrix consisting of elements
   1.110          of GF(2^8)."""
   1.111  
   1.112 -        def invert_by_masking(byte):
   1.113 -            bits = bit_decompose(byte)
   1.114 -
   1.115 -            for j in range(len(bits)):
   1.116 -                bits[j].addCallback(lambda x: GF256(1) - x)
   1.117 -#                bits[j] = 1 - bits[j]
   1.118 -
   1.119 -            while(len(bits) > 1):
   1.120 -                bits.append(bits.pop(0) * bits.pop(0))
   1.121 -
   1.122 -            # b == 1 if byte is 0, b == 0 else
   1.123 -            b = bits[0]
   1.124 -
   1.125 -            r = Share(self.runtime, GF256)
   1.126 -            c = Share(self.runtime, GF256)
   1.127 -
   1.128 -            def get_masked_byte(c_opened, r_related, c, r, byte):
   1.129 -                if (c_opened == 0):
   1.130 -                    r_trial = self.runtime.prss_share_random(GF256)
   1.131 -                    c_trial = self.runtime.open((byte + b) * r_trial)
   1.132 -                    c_trial.addCallback(get_masked_byte, r_trial,
   1.133 -                                        c, r, byte)
   1.134 -                else:
   1.135 -                    r_related.addCallback(r.callback)
   1.136 -                    c.callback(~c_opened)
   1.137 -
   1.138 -            get_masked_byte(0, None, c, r, byte)
   1.139 -
   1.140 -            # necessary to avoid communication in multiplication
   1.141 -            # was: return c * r - b
   1.142 -            result = gather_shares([c, r, b])
   1.143 -            result.addCallback(lambda (c, r, b): c * r - b)
   1.144 -            return result
   1.145 -
   1.146 -        def invert_by_exponentiation(byte):
   1.147 -            byte_2 = byte * byte
   1.148 -            byte_3 = byte_2 * byte
   1.149 -            byte_6 = byte_3 * byte_3
   1.150 -            byte_12 = byte_6 * byte_6
   1.151 -            byte_15 = byte_12 * byte_3
   1.152 -            byte_30 = byte_15 * byte_15
   1.153 -            byte_60 = byte_30 * byte_30
   1.154 -            byte_63 = byte_60 * byte_3
   1.155 -            byte_126 = byte_63 * byte_63
   1.156 -            byte_252 = byte_126 * byte_126
   1.157 -            byte_254 = byte_252 * byte_2
   1.158 -            return byte_254
   1.159 -
   1.160 -        if (self.use_exponentiation):
   1.161 -            if (self.use_square_and_multiply):
   1.162 -                invert = lambda byte: byte ** 254
   1.163 -            else:
   1.164 -                invert = invert_by_exponentiation
   1.165 -        else:
   1.166 -            invert = invert_by_masking
   1.167 -
   1.168          for h in range(len(state)):
   1.169              row = state[h]
   1.170              
   1.171              for i in range(len(row)):
   1.172 -                bits = bit_decompose(invert(row[i]))
   1.173 +                bits = bit_decompose(self.invert(row[i]))
   1.174  
   1.175                  # include the translation in the matrix multiplication
   1.176                  # (see definition of AES.A)