viff

changeset 1072:c503c9b40df0

AES KeyExpansion and AddRoundKey implemented.
author Marcel Keller <mkeller@cs.au.dk>
date Fri, 09 Jan 2009 14:22:58 +0100
parents f9fb8c387f8f
children 936ce049980f
files viff/aes.py viff/test/test_aes.py
diffstat 2 files changed, 71 insertions(+), 6 deletions(-) [+]
line diff
     1.1 --- a/viff/aes.py	Tue Dec 23 16:28:37 2008 +0100
     1.2 +++ b/viff/aes.py	Fri Jan 09 14:22:58 2009 +0100
     1.3 @@ -71,11 +71,7 @@
     1.4          """ByteSub operation of Rijndael.
     1.5  
     1.6          The first argument should be a matrix consisting of elements
     1.7 -        of GF(2^8) or shares thereof with 4 rows and block_size / 32
     1.8 -        elements."""
     1.9 -
    1.10 -        assert len(state) == 4, "State must have 4 rows."
    1.11 -        assert len(state[0]) == self.n_b, "State must have block_size / 32 columns"
    1.12 +        of GF(2^8)."""
    1.13  
    1.14          for h in range(len(state)):
    1.15              row = state[h]
    1.16 @@ -142,3 +138,42 @@
    1.17      def mix_column(self, state):
    1.18          state[:] = (AES.C * Matrix(state)).rows
    1.19  
    1.20 +    def add_round_key(self, state, round_key):
    1.21 +        """Rijndael AddRoundKey.
    1.22 +
    1.23 +        State should be a list of 4 rows and round_key a list of
    1.24 +        4-byte columns (words)."""
    1.25 +
    1.26 +        assert len(round_key) == self.n_k, "Wrong key size."
    1.27 +        assert len(round_key[0]) == 4, "Key must consist of 4-byte words."
    1.28 +
    1.29 +        state[:] = (Matrix(state) + Matrix(zip(*round_key))).rows
    1.30 +
    1.31 +    def key_expansion(self, key):
    1.32 +        """Rijndael key expansion.
    1.33 +
    1.34 +        Input and output are lists of 4-byte columns (words)."""
    1.35 +
    1.36 +        assert len(key) == self.n_k, "Wrong key size."
    1.37 +        assert len(key[0]) == 4, "Key must consist of 4-byte words."
    1.38 +
    1.39 +        expanded_key = list(key)
    1.40 +
    1.41 +        for i in xrange(self.n_k, self.n_b * (self.rounds + 1)):
    1.42 +            temp = list(expanded_key[i - 1])
    1.43 +
    1.44 +            if (i % self.n_k == 0):
    1.45 +                temp.append(temp.pop(0))
    1.46 +                self.byte_sub([temp])
    1.47 +                temp[0] += GF256(2) ** (i / self.n_k - 1)
    1.48 +            elif (self.n_k > 6 and i % self.n_k == 4):
    1.49 +                self.byte_sub([temp])
    1.50 +
    1.51 +            new_word = []
    1.52 +
    1.53 +            for j in xrange(4):
    1.54 +                new_word.append(expanded_key[i - self.n_k][j] + temp[j])
    1.55 +
    1.56 +            expanded_key.append(new_word)
    1.57 +
    1.58 +        return expanded_key
     2.1 --- a/viff/test/test_aes.py	Tue Dec 23 16:28:37 2008 +0100
     2.2 +++ b/viff/test/test_aes.py	Fri Jan 09 14:22:58 2009 +0100
     2.3 @@ -24,7 +24,7 @@
     2.4  from viff.runtime import gather_shares, Share
     2.5  from viff.aes import bit_decompose, AES
     2.6  
     2.7 -from viff.test.rijndael import S
     2.8 +from viff.test.rijndael import S, rijndael
     2.9  
    2.10  
    2.11  __doctest__ = ["viff.aes"]
    2.12 @@ -86,3 +86,33 @@
    2.13  
    2.14          aes.byte_sub(results)
    2.15          self.verify(runtime, results, expected_results)
    2.16 +
    2.17 +    @protocol
    2.18 +    def test_key_expansion(self, runtime):
    2.19 +        aes = AES(runtime, 256)
    2.20 +        key = []
    2.21 +        ascii_key = []
    2.22 +
    2.23 +        for i in xrange(8):
    2.24 +            key.append([])
    2.25 +
    2.26 +            for j in xrange(4):
    2.27 +                b = 15 * i + j
    2.28 +                key[i].append(Share(runtime, GF256, GF256(b)))
    2.29 +                ascii_key.append(chr(b))
    2.30 +
    2.31 +        result = aes.key_expansion(key)
    2.32 +
    2.33 +        r = rijndael(ascii_key)
    2.34 +        expected_result = []
    2.35 +
    2.36 +        for round_key in r.Ke:
    2.37 +            for word in round_key:
    2.38 +                split_word = []
    2.39 +                expected_result.append(split_word)
    2.40 +
    2.41 +                for j in xrange(4):
    2.42 +                    split_word.insert(0, word % 256)
    2.43 +                    word /= 256
    2.44 +
    2.45 +        self.verify(runtime, result, expected_result)