changeset 1083:e530dfcb40cd

Implemented inversion by exponentiation in AES ByteSub.
author Marcel Keller <mkeller@cs.au.dk>
date Thu, 15 Jan 2009 18:30:46 +0100
parents c1b3ced5bf05
children b4d9b373bbab
files viff/aes.py viff/test/test_aes.py
diffstat 2 files changed, 63 insertions(+), 34 deletions(-) [+]
line wrap: on
line diff
--- a/viff/aes.py	Wed Jan 14 18:33:33 2009 +0100
+++ b/viff/aes.py	Thu Jan 15 18:30:46 2009 +0100
@@ -51,7 +51,8 @@
 
 
 class AES:
-    def __init__(self, runtime, key_size, block_size=128):
+    def __init__(self, runtime, key_size, block_size=128, 
+                 use_exponentiation=False):
         """Initialize Rijndael.
 
         AES(runtime, key_size, block_size), whereas key size and block
@@ -66,6 +67,7 @@
         self.n_b = block_size / 32
         self.rounds = max(self.n_k, self.n_b) + 6
         self.runtime = runtime
+        self.use_exponentiation = use_exponentiation
 
     # matrix for byte_sub
     A = Matrix([[1,0,0,0,1,1,1,1],
@@ -83,39 +85,58 @@
         The first argument should be a matrix consisting of elements
         of GF(2^8)."""
 
+        def invert_by_masking(byte):
+            bits = bit_decompose(byte)
+
+            for j in range(len(bits)):
+                bits[j] = 1 - bits[j]
+
+            while(len(bits) > 1):
+                bits.append(bits.pop() * bits.pop())
+
+            # b == 1 if byte is 0, b == 0 else
+            b = bits[0]
+
+            r = Share(self.runtime, GF256)
+            c = Share(self.runtime, GF256)
+
+            def get_masked_byte(c_opened, r_related, c, r, byte):
+                if (c_opened == 0):
+                    r_trial = self.runtime.prss_share_random(GF256)
+                    c_trial = self.runtime.open((byte + b) * r_trial)
+                    c_trial.addCallback(get_masked_byte, r_trial,
+                                        c, r, byte)
+                else:
+                    r_related.addCallback(r.callback)
+                    c.callback(~c_opened)
+
+            get_masked_byte(0, None, c, r, byte)
+            return c * r - b
+
+        def invert_by_exponentiation(byte):
+            byte_2 = byte * byte
+            byte_3 = byte_2 * byte
+            byte_6 = byte_3 * byte_3
+            byte_12 = byte_6 * byte_6
+            byte_15 = byte_12 * byte_3
+            byte_30 = byte_15 * byte_15
+            byte_60 = byte_30 * byte_30
+            byte_63 = byte_60 * byte_3
+            byte_126 = byte_63 * byte_63
+            byte_252 = byte_126 * byte_126
+            byte_254 = byte_252 * byte_2
+            return byte_254
+
+        if (self.use_exponentiation):
+            invert = invert_by_exponentiation
+        else:
+            invert = invert_by_masking
+
         for h in range(len(state)):
             row = state[h]
             
             for i in range(len(row)):
-                byte = row[i]
-                bits = bit_decompose(byte)
-
-                for j in range(len(bits)):
-                    bits[j] = 1 - bits[j]
-
-                while(len(bits) > 1):
-                    bits.append(bits.pop() * bits.pop())
-
-                # b == 1 if byte is 0, b == 0 else
-                b = bits[0]
-
-                r = Share(self.runtime, GF256)
-                c = Share(self.runtime, GF256)
-
-                def get_masked_byte(c_opened, r_related, c, r, byte):
-                    if (c_opened == 0):
-                        r_trial = self.runtime.prss_share_random(GF256)
-                        c_trial = self.runtime.open((byte + b) * r_trial)
-                        c_trial.addCallback(get_masked_byte, r_trial,
-                                            c, r, byte)
-                    else:
-                        r_related.addCallback(r.callback)
-                        c.callback(~c_opened)
-
-                get_masked_byte(0, None, c, r, byte)
-                inverted_byte = c * r - b
-
-                bits = bit_decompose(inverted_byte)
+                bits = bit_decompose(invert(row[i]))
 
                 # caution: order is lsb first
                 vector = AES.A * Matrix(zip(bits)) + Matrix(zip([1,1,0,0,0,1,1,0]))
--- a/viff/test/test_aes.py	Wed Jan 14 18:33:33 2009 +0100
+++ b/viff/test/test_aes.py	Thu Jan 15 18:30:46 2009 +0100
@@ -69,9 +69,7 @@
 
         return gather_shares(opened_results)
 
-    @protocol
-    def test_byte_sub(self, runtime):
-        aes = AES(runtime, 128)
+    def _test_byte_sub(self, runtime, aes):
         results = []
         expected_results = []
 
@@ -81,13 +79,23 @@
 
             for j in range(4):
                 b = 60 * i + j
-                results[i].append(Share(runtime, GF256, b))
+                results[i].append(Share(runtime, GF256, GF256(b)))
                 expected_results[i].append(S[b])
 
         aes.byte_sub(results)
         self.verify(runtime, results, expected_results)
 
     @protocol
+    def test_byte_sub_with_masking(self, runtime):
+        self._test_byte_sub(runtime, AES(runtime, 128, 
+                                         use_exponentiation=False))
+
+    @protocol
+    def test_byte_sub_with_exponentiation(self, runtime):
+        self._test_byte_sub(runtime, AES(runtime, 128, 
+                                         use_exponentiation=True))
+
+    @protocol
     def test_key_expansion(self, runtime):
         aes = AES(runtime, 256)
         key = []