changeset 1137:b6d229859b5b

Added an inversion by exponentiation variant with minimal number of multiplications and less rounds. Changed the internal structure regarding the inversions, the choice is now made in __init__().
author Marcel Keller <mkeller@cs.au.dk>
date Tue, 17 Feb 2009 18:15:32 +0100
parents 72b7a0717627
children cecc7b3c6eb0
files viff/aes.py
diffstat 1 files changed, 89 insertions(+), 60 deletions(-) [+]
line wrap: on
line diff
--- a/viff/aes.py	Tue Feb 17 11:21:38 2009 +0100
+++ b/viff/aes.py	Tue Feb 17 18:15:32 2009 +0100
@@ -72,7 +72,7 @@
     """
 
     def __init__(self, runtime, key_size, block_size=128, 
-                 use_exponentiation=False, use_square_and_multiply=False):
+                 use_exponentiation=False):
         """Initialize Rijndael.
 
         AES(runtime, key_size, block_size), whereas key size and block
@@ -87,8 +87,93 @@
         self.n_b = block_size / 32
         self.rounds = max(self.n_k, self.n_b) + 6
         self.runtime = runtime
-        self.use_exponentiation = use_exponentiation
-        self.use_square_and_multiply = use_square_and_multiply
+
+        if (use_exponentiation is not False):
+            if (isinstance(use_exponentiation, int) and
+                use_exponentiation < len(AES.exponentiation_variants)):
+                use_exponentiation = \
+                    AES.exponentiation_variants[use_exponentiation]
+            elif (use_exponentiation not in AES.exponentation_variants):
+                use_exponentiation = "shortest_sequential_chain"
+
+            print "Use %s for inversion by exponentiation." % \
+                use_exponentiation
+
+            if (use_exponentiation == "standard_square_and_multiply"):
+                self.invert = lambda byte: byte ** 254
+            elif (use_exponentiation == "shortest_chain_with_least_rounds"):
+                self.invert = self.invert_by_exponentiation_with_less_rounds
+            else:
+                self.invert = self.invert_by_exponentiation
+        else:
+            self.invert = self.invert_by_masking
+            print "Use inversion by masking."
+
+    exponentiation_variants = ["standard_square_and_multiply",
+                               "shortest_sequential_chain",
+                               "shortest_chain_with_least_rounds"]
+
+    def invert_by_masking(self, byte):
+        bits = bit_decompose(byte)
+
+        for j in range(len(bits)):
+            bits[j].addCallback(lambda x: GF256(1) - x)
+#            bits[j] = 1 - bits[j]
+
+        while(len(bits) > 1):
+            bits.append(bits.pop(0) * bits.pop(0))
+
+        # b == 1 if byte is 0, b == 0 else
+        b = bits[0]
+
+        r = Share(self.runtime, GF256)
+        c = Share(self.runtime, GF256)
+
+        def get_masked_byte(c_opened, r_related, c, r, byte):
+            if (c_opened == 0):
+                r_trial = self.runtime.prss_share_random(GF256)
+                c_trial = self.runtime.open((byte + b) * r_trial)
+                c_trial.addCallback(get_masked_byte, r_trial,
+                                    c, r, byte)
+            else:
+                r_related.addCallback(r.callback)
+                c.callback(~c_opened)
+
+        get_masked_byte(0, None, c, r, byte)
+
+        # necessary to avoid communication in multiplication
+        # was: return c * r - b
+        result = gather_shares([c, r, b])
+        result.addCallback(lambda (c, r, b): c * r - b)
+        return result
+
+    def invert_by_exponentiation(self, byte):
+        byte_2 = byte * byte
+        byte_3 = byte_2 * byte
+        byte_6 = byte_3 * byte_3
+        byte_12 = byte_6 * byte_6
+        byte_15 = byte_12 * byte_3
+        byte_30 = byte_15 * byte_15
+        byte_60 = byte_30 * byte_30
+        byte_63 = byte_60 * byte_3
+        byte_126 = byte_63 * byte_63
+        byte_252 = byte_126 * byte_126
+        byte_254 = byte_252 * byte_2
+        return byte_254
+
+    def invert_by_exponentiation_with_less_rounds(self, byte):
+        byte_2 = byte * byte
+        byte_4 = byte_2 * byte_2
+        byte_8 = byte_4 * byte_4
+        byte_9 = byte_8 * byte
+        byte_16 = byte_8 * byte_8
+        byte_25 = byte_16 * byte_9
+        byte_50 = byte_25 * byte_25
+        byte_54 = byte_50 * byte_4
+        byte_100 = byte_50 * byte_50
+        byte_200 = byte_100 * byte_100
+        byte_254 = byte_200 * byte_54
+        return byte_254
 
     # matrix for byte_sub, the last column is the translation vector
     A = Matrix([[1,0,0,0,1,1,1,1, 1],
@@ -106,67 +191,11 @@
         The first argument should be a matrix consisting of elements
         of GF(2^8)."""
 
-        def invert_by_masking(byte):
-            bits = bit_decompose(byte)
-
-            for j in range(len(bits)):
-                bits[j].addCallback(lambda x: GF256(1) - x)
-#                bits[j] = 1 - bits[j]
-
-            while(len(bits) > 1):
-                bits.append(bits.pop(0) * bits.pop(0))
-
-            # b == 1 if byte is 0, b == 0 else
-            b = bits[0]
-
-            r = Share(self.runtime, GF256)
-            c = Share(self.runtime, GF256)
-
-            def get_masked_byte(c_opened, r_related, c, r, byte):
-                if (c_opened == 0):
-                    r_trial = self.runtime.prss_share_random(GF256)
-                    c_trial = self.runtime.open((byte + b) * r_trial)
-                    c_trial.addCallback(get_masked_byte, r_trial,
-                                        c, r, byte)
-                else:
-                    r_related.addCallback(r.callback)
-                    c.callback(~c_opened)
-
-            get_masked_byte(0, None, c, r, byte)
-
-            # necessary to avoid communication in multiplication
-            # was: return c * r - b
-            result = gather_shares([c, r, b])
-            result.addCallback(lambda (c, r, b): c * r - b)
-            return result
-
-        def invert_by_exponentiation(byte):
-            byte_2 = byte * byte
-            byte_3 = byte_2 * byte
-            byte_6 = byte_3 * byte_3
-            byte_12 = byte_6 * byte_6
-            byte_15 = byte_12 * byte_3
-            byte_30 = byte_15 * byte_15
-            byte_60 = byte_30 * byte_30
-            byte_63 = byte_60 * byte_3
-            byte_126 = byte_63 * byte_63
-            byte_252 = byte_126 * byte_126
-            byte_254 = byte_252 * byte_2
-            return byte_254
-
-        if (self.use_exponentiation):
-            if (self.use_square_and_multiply):
-                invert = lambda byte: byte ** 254
-            else:
-                invert = invert_by_exponentiation
-        else:
-            invert = invert_by_masking
-
         for h in range(len(state)):
             row = state[h]
             
             for i in range(len(row)):
-                bits = bit_decompose(invert(row[i]))
+                bits = bit_decompose(self.invert(row[i]))
 
                 # include the translation in the matrix multiplication
                 # (see definition of AES.A)