changeset 1200:5da110d6e5b2

Expand the key successively instead of at once.
author Marcel Keller <mkeller@cs.au.dk>
date Thu, 16 Jul 2009 12:25:22 +0200
parents f640ea5f0920
children 0fb5d4da2f1e
files viff/aes.py
diffstat 1 files changed, 16 insertions(+), 8 deletions(-) [+]
line wrap: on
line diff
--- a/viff/aes.py	Fri Jul 10 13:27:42 2009 +0200
+++ b/viff/aes.py	Thu Jul 16 12:25:22 2009 +0200
@@ -274,17 +274,22 @@
 
         state[:] = (Matrix(state) + Matrix(zip(*round_key))).rows
 
-    def key_expansion(self, key):
+    def key_expansion(self, key, new_length=None):
         """Rijndael key expansion.
 
-        Input and output are lists of 4-byte columns (words)."""
+        Input and output are lists of 4-byte columns (words).
+        *new_length* is the round for which the key should be expanded.
+        If ommitted, the key is expanded for all rounds."""
 
-        assert len(key) == self.n_k, "Wrong key size."
+        assert len(key) >= self.n_k, "Wrong key size."
         assert len(key[0]) == 4, "Key must consist of 4-byte words."
 
-        expanded_key = list(key)
+        expanded_key = key
 
-        for i in xrange(self.n_k, self.n_b * (self.rounds + 1)):
+        if new_length == None:
+            new_length = self.rounds
+
+        for i in xrange(len(key), self.n_b * (new_length + 1)):
             temp = list(expanded_key[i - 1])
 
             if (i % self.n_k == 0):
@@ -355,8 +360,7 @@
             progress = lambda x, i, start_round: x
             prep_progress = lambda i, start_round: None
 
-        expanded_key = self.key_expansion(key)
-
+        expanded_key = self.key_expansion(key[:], 0)
         self.add_round_key(state, expanded_key[0:self.n_b])
 
         prep_progress(0, start)
@@ -366,7 +370,9 @@
 
         def round(_, state, i):
             start_round = time.time()
-            
+
+            self.key_expansion(expanded_key, i)
+
             self.byte_sub(state)
             self.shift_row(state)
             self.mix_column(state)
@@ -388,6 +394,8 @@
         def final_round(_, state):
             start_round = time.time()
 
+            self.key_expansion(expanded_key, self.rounds)
+
             self.byte_sub(state)
             self.shift_row(state)
             self.add_round_key(state, expanded_key[self.rounds*self.n_b:])