changeset 1094:bdab6f30511a

Merged with Marcel's AES code.
author Martin Geisler <mg@daimi.au.dk>
date Thu, 29 Jan 2009 16:19:49 +0100
parents aff12eb0e28c af545b802fd8
children e5bb773fb1fe
files NEWS apps/aes.py viff/passive.py
diffstat 6 files changed, 982 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/NEWS	Tue Jan 27 15:52:23 2009 +0100
+++ b/NEWS	Thu Jan 29 16:19:49 2009 +0100
@@ -44,6 +44,8 @@
 * Exponentiation of shares by square-and-multiply for public
   exponents. This means that if x is a Share, x**7 now works.
 
+* Added multi-party AES encryption.
+
 
 Version 0.7.1, released on 2008-10-09
 -------------------------------------
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/apps/aes.py	Thu Jan 29 16:19:49 2009 +0100
@@ -0,0 +1,88 @@
+#!/usr/bin/python
+
+# Copyright 2009 VIFF Development Team.
+#
+# This file is part of VIFF, the Virtual Ideal Functionality Framework.
+#
+# VIFF is free software: you can redistribute it and/or modify it
+# under the terms of the GNU Lesser General Public License (LGPL) as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# VIFF is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
+# Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with VIFF. If not, see <http://www.gnu.org/licenses/>.
+
+# This example shows how to use multi-party AES encryption.
+
+
+import sys
+import time
+from optparse import OptionParser
+
+from twisted.internet import reactor
+
+from viff.field import GF256
+from viff.runtime import Runtime, create_runtime, gather_shares
+from viff.config import load_config
+
+from viff.aes import bit_decompose,AES
+
+
+parser = OptionParser(usage="Usage: %prog [options] config_file")
+parser.add_option("-e", "--exponentiation", action="store_true",
+                  help="Use exponentiation to invert bytes (default).")
+parser.add_option("-m", "--masking", action="store_false", 
+                  dest="exponentiation", 
+                  help="Use masking to invert bytes.")
+parser.set_defaults(exponentiation=True)
+
+# Add standard VIFF options.
+Runtime.add_options(parser)
+
+(options, args) = parser.parse_args()
+
+if len(args) == 0:
+    parser.error("You must specify a config file.")
+
+id, players = load_config(args[0])
+
+def encrypt(_, rt, key):
+    start = time.time()
+    print "Started at %f." % start
+
+    aes = AES(rt, 192, use_exponentiation=options.exponentiation)
+    ciphertext = aes.encrypt("a" * 16, key, True)
+
+    opened_ciphertext = [rt.open(c) for c in ciphertext]
+
+    def fin(ciphertext):
+        print "Finished after %f sec." % (time.time() - start)
+        print "Ciphertext:", [hex(c.value) for c in ciphertext]
+        rt.shutdown()
+
+    g = gather_shares(opened_ciphertext)
+    g.addCallback(fin)
+
+def share_key(rt):
+    key =  []
+
+    for i in range(24):
+        inputter = i % 3 + 1
+        
+        if (inputter == id):
+            key.append(rt.input([inputter], GF256, ord("b")))
+        else:
+            key.append(rt.input([inputter], GF256))
+
+    s = rt.synchronize()
+    s.addCallback(encrypt, rt, key)
+
+rt = create_runtime(id, players, 1, options)
+rt.addCallback(share_key)
+
+reactor.run()
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/viff/aes.py	Thu Jan 29 16:19:49 2009 +0100
@@ -0,0 +1,349 @@
+# Copyright 2009 VIFF Development Team.
+#
+# This file is part of VIFF, the Virtual Ideal Functionality Framework.
+#
+# VIFF is free software: you can redistribute it and/or modify it
+# under the terms of the GNU Lesser General Public License (LGPL) as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# VIFF is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
+# Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with VIFF. If not, see <http://www.gnu.org/licenses/>.
+
+"""MPC implementation of AES (Rijndael)."""
+
+__docformat__ = "restructuredtext"
+
+
+import time
+
+from viff.field import GF256
+from viff.runtime import Share
+from viff.matrix import Matrix
+
+
+def bit_decompose(share, use_lin_comb=True):
+    """Bit decomposition for GF256 shares."""
+
+    assert isinstance(share, Share) and share.field == GF256, \
+        "Parameter must be GF256 share."
+
+    r_bits = [share.runtime.prss_share_random(GF256, binary=True) \
+                  for i in range(8)]
+    
+    if (use_lin_comb):
+        r = share.runtime.lin_comb([2 ** i for i in range(8)], r_bits)
+    else:
+        r = reduce(lambda x,y: x + y, 
+                   [r_bits[i] * 2 ** i for i in range(8)])
+
+    c = share.runtime.open(share + r)
+    c_bits = [Share(share.runtime, GF256) for i in range(8)]
+    
+    def decompose(byte, bits):
+        value = byte.value
+
+        for i in range(8):
+            c_bits[i].callback(GF256(value & 1))
+            value >>= 1
+
+    c.addCallback(decompose, c_bits)
+
+    return [c_bits[i] + r_bits[i] for i in range(8)]
+
+
+class AES:
+    def __init__(self, runtime, key_size, block_size=128, 
+                 use_exponentiation=False):
+        """Initialize Rijndael.
+
+        AES(runtime, key_size, block_size), whereas key size and block
+        size must be given in bits. Block size defaults to 128."""
+
+        assert key_size in [128, 192, 256], \
+            "Key size must be 128, 192 or 256"
+        assert block_size in [128, 192, 256], \
+            "Block size be 128, 192 or 256"
+
+        self.n_k = key_size / 32
+        self.n_b = block_size / 32
+        self.rounds = max(self.n_k, self.n_b) + 6
+        self.runtime = runtime
+        self.use_exponentiation = use_exponentiation
+
+    # matrix for byte_sub
+    A = Matrix([[1,0,0,0,1,1,1,1],
+                [1,1,0,0,0,1,1,1],
+                [1,1,1,0,0,0,1,1],
+                [1,1,1,1,0,0,0,1],
+                [1,1,1,1,1,0,0,0],
+                [0,1,1,1,1,1,0,0],
+                [0,0,1,1,1,1,1,0],
+                [0,0,0,1,1,1,1,1]])
+
+    def byte_sub(self, state, use_lin_comb=True):
+        """ByteSub operation of Rijndael.
+
+        The first argument should be a matrix consisting of elements
+        of GF(2^8)."""
+
+        def invert_by_masking(byte):
+            bits = bit_decompose(byte)
+
+            for j in range(len(bits)):
+                bits[j].addCallback(lambda x: GF256(1) - x)
+#                bits[j] = 1 - bits[j]
+
+            while(len(bits) > 1):
+                bits.append(bits.pop() * bits.pop())
+
+            # b == 1 if byte is 0, b == 0 else
+            b = bits[0]
+
+            r = Share(self.runtime, GF256)
+            c = Share(self.runtime, GF256)
+
+            def get_masked_byte(c_opened, r_related, c, r, byte):
+                if (c_opened == 0):
+                    r_trial = self.runtime.prss_share_random(GF256)
+                    c_trial = self.runtime.open((byte + b) * r_trial)
+                    c_trial.addCallback(get_masked_byte, r_trial,
+                                        c, r, byte)
+                else:
+                    r_related.addCallback(r.callback)
+                    c.callback(~c_opened)
+
+            get_masked_byte(0, None, c, r, byte)
+            return c * r - b
+
+        def invert_by_exponentiation(byte):
+            byte_2 = byte * byte
+            byte_3 = byte_2 * byte
+            byte_6 = byte_3 * byte_3
+            byte_12 = byte_6 * byte_6
+            byte_15 = byte_12 * byte_3
+            byte_30 = byte_15 * byte_15
+            byte_60 = byte_30 * byte_30
+            byte_63 = byte_60 * byte_3
+            byte_126 = byte_63 * byte_63
+            byte_252 = byte_126 * byte_126
+            byte_254 = byte_252 * byte_2
+            return byte_254
+
+        if (self.use_exponentiation):
+            invert = invert_by_exponentiation
+        else:
+            invert = invert_by_masking
+
+        for h in range(len(state)):
+            row = state[h]
+            
+            for i in range(len(row)):
+                bits = bit_decompose(invert(row[i]))
+
+                # caution: order is lsb first
+                vector = AES.A * Matrix(zip(bits)) + Matrix(zip([1,1,0,0,0,1,1,0]))
+                bits = zip(*vector.rows)[0]
+
+                if (use_lin_comb):
+                    row[i] = self.runtime.lin_comb(
+                        [2**j for j in range(len(bits))], bits)
+                else:
+                    row[i] = reduce(lambda x,y: x + y, 
+                                    [bits[j] * 2**j for j in range(len(bits))])
+
+    def shift_row(self, state):
+        """AES ShiftRow.
+
+        State should be a list of 4 rows."""
+
+        assert len(state) == 4, "Wrong state size."
+
+        if self.n_b in [4,6]:
+            offsets = [0, 1, 2, 3]
+        else:
+            offsets = [0, 1, 3, 4]
+
+        for i, row in enumerate(state):
+            for j in range(offsets[i]):
+                row.append(row.pop(0))
+
+    # matrix for mix_column
+    C = [[2, 3, 1, 1],
+         [1, 2, 3, 1],
+         [1, 1, 2, 3],
+         [3, 1, 1, 2]]
+
+    for row in C:
+        for i in xrange(len(row)):
+            row[i] = GF256(row[i])
+
+    C = Matrix(C)
+
+    def mix_column(self, state):
+        """Rijndael MixColumn.
+
+        Input should be a list of 4 rows."""
+
+        assert len(state) == 4, "Wrong state size."
+
+        state[:] = (AES.C * Matrix(state)).rows
+
+    def add_round_key(self, state, round_key):
+        """Rijndael AddRoundKey.
+
+        State should be a list of 4 rows and round_key a list of
+        4-byte columns (words)."""
+
+        assert len(round_key) == self.n_b, "Wrong key size."
+        assert len(round_key[0]) == 4, "Key must consist of 4-byte words."
+
+        state[:] = (Matrix(state) + Matrix(zip(*round_key))).rows
+
+    def key_expansion(self, key):
+        """Rijndael key expansion.
+
+        Input and output are lists of 4-byte columns (words)."""
+
+        assert len(key) == self.n_k, "Wrong key size."
+        assert len(key[0]) == 4, "Key must consist of 4-byte words."
+
+        expanded_key = list(key)
+
+        for i in xrange(self.n_k, self.n_b * (self.rounds + 1)):
+            temp = list(expanded_key[i - 1])
+
+            if (i % self.n_k == 0):
+                temp.append(temp.pop(0))
+                self.byte_sub([temp])
+                temp[0] += GF256(2) ** (i / self.n_k - 1)
+            elif (self.n_k > 6 and i % self.n_k == 4):
+                self.byte_sub([temp])
+
+            new_word = []
+
+            for j in xrange(4):
+                new_word.append(expanded_key[i - self.n_k][j] + temp[j])
+
+            expanded_key.append(new_word)
+
+        return expanded_key
+
+    def preprocess(self, input):
+        if (isinstance(input, str)):
+            return [Share(self.runtime, GF256, GF256(ord(c))) 
+                    for c in input]
+        else:
+            for byte in input:
+                assert byte.field == GF256, \
+                    "Input must be a list of GF256 elements " \
+                    "or of shares thereof."
+            return input
+
+    def encrypt(self, cleartext, key, benchmark=False):
+        """Rijndael encryption.
+
+        Cleartext and key should be either a string or a list of bytes 
+        (possibly shared as elements of GF256)."""
+
+        start = time.time()
+
+        assert len(cleartext) == 4 * self.n_b, "Wrong length of cleartext."
+        assert len(key) == 4 * self.n_k, "Wrong length of key."
+
+        cleartext = self.preprocess(cleartext)
+        key = self.preprocess(key)
+
+        state = [cleartext[i::4] for i in xrange(4)]
+        key = [key[4*i:4*i+4] for i in xrange(self.n_k)]
+
+        if (benchmark):
+            global preparation, communication
+            preparation = 0
+            communication = 0
+
+            def progress(x, i, start_round):
+                time_diff = time.time() - start_round
+                global communication
+                communication += time_diff
+                print "Round %2d: %f, %f" % \
+                    (i, time_diff, time.time() - start)
+                return x
+
+            def prep_progress(i, start_round):
+                time_diff = time.time() - start_round
+                global preparation
+                preparation += time_diff
+                print "Round %2d preparation: %f, %f" % \
+                    (i, time_diff, time.time() - start)
+        else:
+            progress = lambda x, i, start_round: x
+            prep_progress = lambda i, start_round: None
+
+        expanded_key = self.key_expansion(key)
+
+        self.add_round_key(state, expanded_key[0:self.n_b])
+
+        prep_progress(0, start)
+
+        def get_trigger(state):
+            return state[3][self.n_b-1]
+
+        def get_last(state):
+            return state[3][self.n_b-1]
+
+        def round(_, state, i):
+            start_round = time.time()
+            
+            self.byte_sub(state)
+            self.shift_row(state)
+            self.mix_column(state)
+            self.add_round_key(state, expanded_key[i*self.n_b:(i+1)*self.n_b])
+
+            get_last(state).addCallback(progress, i, time.time())
+
+            if (i < self.rounds - 1):
+                get_trigger(state).addCallback(round, state, i + 1)
+            else:
+                get_trigger(state).addCallback(final_round, state)
+
+            prep_progress(i, start_round)
+
+            return _
+
+        def final_round(_, state):
+            start_round = time.time()
+
+            self.byte_sub(state)
+            self.shift_row(state)
+            self.add_round_key(state, expanded_key[self.rounds*self.n_b:])
+
+            get_last(state).addCallback(progress, self.rounds, time.time())
+
+            get_trigger(state).addCallback(finish, state)
+
+            prep_progress(self.rounds, start_round)
+
+            return _
+
+        def finish(_, state):
+            actual_result = [byte for word in zip(*state) for byte in word]
+
+            for a, b in zip(actual_result, result):
+                a.addCallback(b.callback)
+
+            if (benchmark):
+                print "Total preparation time: %f" % preparation
+                print "Total communication time: %f" % communication
+
+            return _
+
+        round(None, state, 1)
+
+        result = [Share(self.runtime, GF256) for i in xrange(4 * self.n_b)]
+        return result
--- a/viff/passive.py	Tue Jan 27 15:52:23 2009 +0100
+++ b/viff/passive.py	Thu Jan 29 16:19:49 2009 +0100
@@ -138,6 +138,34 @@
         result.addCallback(lambda (a, b): a - b)
         return result
 
+    def lin_comb(self, coefficients, shares):
+        """Linear combination of shares.
+
+        Communication cost: none. Saves the construction of unnecessary shares
+        compared to using add() and mul()."""
+
+        for coeff in coefficients:
+            assert not isinstance(coeff, Share), \
+                "Coefficients should not be shares."
+
+        assert len(coefficients) == len(shares), \
+            "Number of coefficients and shares should be equal."
+
+        field = None
+        for share in shares:
+            field = getattr(share, "field", field)
+        for i, share in enumerate(shares):
+            if not isinstance(share, Share):
+                shares[i] = Share(self, field, share)
+
+        def computation(shares, coefficients):
+            summands = [shares[i] * coefficients[i] for i in range(len(shares))]
+            return reduce(lambda x, y: x + y, summands)
+
+        result = gather_shares(shares)
+        result.addCallback(computation, coefficients)
+        return result
+
     @profile
     @increment_pc
     def mul(self, share_a, share_b):
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/viff/test/rijndael.py	Thu Jan 29 16:19:49 2009 +0100
@@ -0,0 +1,376 @@
+"""
+A pure python (slow) implementation of rijndael with a decent interface
+
+To include -
+
+from rijndael import rijndael
+
+To do a key setup -
+
+r = rijndael(key, block_size = 16)
+
+key must be a string of length 16, 24, or 32
+blocksize must be 16, 24, or 32. Default is 16
+
+To use -
+
+ciphertext = r.encrypt(plaintext)
+plaintext = r.decrypt(ciphertext)
+
+If any strings are of the wrong length a ValueError is thrown
+"""
+
+# ported from the Java reference code by Bram Cohen, April 2001
+# this code is public domain, unless someone makes 
+# an intellectual property claim against the reference 
+# code, in which case it can be made public domain by 
+# deleting all the comments and renaming all the variables
+
+import copy
+import string
+
+shifts = [[[0, 0], [1, 3], [2, 2], [3, 1]],
+          [[0, 0], [1, 5], [2, 4], [3, 3]],
+          [[0, 0], [1, 7], [3, 5], [4, 4]]]
+
+# [keysize][block_size]
+num_rounds = {16: {16: 10, 24: 12, 32: 14}, 24: {16: 12, 24: 12, 32: 14}, 32: {16: 14, 24: 14, 32: 14}}
+
+A = [[1, 1, 1, 1, 1, 0, 0, 0],
+     [0, 1, 1, 1, 1, 1, 0, 0],
+     [0, 0, 1, 1, 1, 1, 1, 0],
+     [0, 0, 0, 1, 1, 1, 1, 1],
+     [1, 0, 0, 0, 1, 1, 1, 1],
+     [1, 1, 0, 0, 0, 1, 1, 1],
+     [1, 1, 1, 0, 0, 0, 1, 1],
+     [1, 1, 1, 1, 0, 0, 0, 1]]
+
+# produce log and alog tables, needed for multiplying in the
+# field GF(2^m) (generator = 3)
+alog = [1]
+for i in xrange(255):
+    j = (alog[-1] << 1) ^ alog[-1]
+    if j & 0x100 != 0:
+        j ^= 0x11B
+    alog.append(j)
+
+log = [0] * 256
+for i in xrange(1, 255):
+    log[alog[i]] = i
+
+# multiply two elements of GF(2^m)
+def mul(a, b):
+    if a == 0 or b == 0:
+        return 0
+    return alog[(log[a & 0xFF] + log[b & 0xFF]) % 255]
+
+# substitution box based on F^{-1}(x)
+box = [[0] * 8 for i in xrange(256)]
+box[1][7] = 1
+for i in xrange(2, 256):
+    j = alog[255 - log[i]]
+    for t in xrange(8):
+        box[i][t] = (j >> (7 - t)) & 0x01
+
+B = [0, 1, 1, 0, 0, 0, 1, 1]
+
+# affine transform:  box[i] <- B + A*box[i]
+cox = [[0] * 8 for i in xrange(256)]
+for i in xrange(256):
+    for t in xrange(8):
+        cox[i][t] = B[t]
+        for j in xrange(8):
+            cox[i][t] ^= A[t][j] * box[i][j]
+
+# S-boxes and inverse S-boxes
+S =  [0] * 256
+Si = [0] * 256
+for i in xrange(256):
+    S[i] = cox[i][0] << 7
+    for t in xrange(1, 8):
+        S[i] ^= cox[i][t] << (7-t)
+    Si[S[i] & 0xFF] = i
+
+# T-boxes
+G = [[2, 1, 1, 3],
+    [3, 2, 1, 1],
+    [1, 3, 2, 1],
+    [1, 1, 3, 2]]
+
+AA = [[0] * 8 for i in xrange(4)]
+
+for i in xrange(4):
+    for j in xrange(4):
+        AA[i][j] = G[i][j]
+        AA[i][i+4] = 1
+
+for i in xrange(4):
+    pivot = AA[i][i]
+    if pivot == 0:
+        t = i + 1
+        while AA[t][i] == 0 and t < 4:
+            t += 1
+            assert t != 4, 'G matrix must be invertible'
+            for j in xrange(8):
+                AA[i][j], AA[t][j] = AA[t][j], AA[i][j]
+            pivot = AA[i][i]
+    for j in xrange(8):
+        if AA[i][j] != 0:
+            AA[i][j] = alog[(255 + log[AA[i][j] & 0xFF] - log[pivot & 0xFF]) % 255]
+    for t in xrange(4):
+        if i != t:
+            for j in xrange(i+1, 8):
+                AA[t][j] ^= mul(AA[i][j], AA[t][i])
+            AA[t][i] = 0
+
+iG = [[0] * 4 for i in xrange(4)]
+
+for i in xrange(4):
+    for j in xrange(4):
+        iG[i][j] = AA[i][j + 4]
+
+def mul4(a, bs):
+    if a == 0:
+        return 0
+    r = 0
+    for b in bs:
+        r <<= 8
+        if b != 0:
+            r = r | mul(a, b)
+    return r
+
+T1 = []
+T2 = []
+T3 = []
+T4 = []
+T5 = []
+T6 = []
+T7 = []
+T8 = []
+U1 = []
+U2 = []
+U3 = []
+U4 = []
+
+for t in xrange(256):
+    s = S[t]
+    T1.append(mul4(s, G[0]))
+    T2.append(mul4(s, G[1]))
+    T3.append(mul4(s, G[2]))
+    T4.append(mul4(s, G[3]))
+
+    s = Si[t]
+    T5.append(mul4(s, iG[0]))
+    T6.append(mul4(s, iG[1]))
+    T7.append(mul4(s, iG[2]))
+    T8.append(mul4(s, iG[3]))
+
+    U1.append(mul4(t, iG[0]))
+    U2.append(mul4(t, iG[1]))
+    U3.append(mul4(t, iG[2]))
+    U4.append(mul4(t, iG[3]))
+
+# round constants
+rcon = [1]
+r = 1
+for t in xrange(1, 30):
+    r = mul(2, r)
+    rcon.append(r)
+
+del A
+del AA
+del pivot
+del B
+del G
+del box
+del log
+del alog
+del i
+del j
+del r
+del s
+del t
+del mul
+del mul4
+del cox
+del iG
+
+class rijndael:
+    def __init__(self, key, block_size = 16):
+        if block_size != 16 and block_size != 24 and block_size != 32:
+            raise ValueError('Invalid block size: ' + str(block_size))
+        if len(key) != 16 and len(key) != 24 and len(key) != 32:
+            raise ValueError('Invalid key size: ' + str(len(key)))
+        self.block_size = block_size
+
+        ROUNDS = num_rounds[len(key)][block_size]
+        BC = block_size / 4
+        # encryption round keys
+        Ke = [[0] * BC for i in xrange(ROUNDS + 1)]
+        # decryption round keys
+        Kd = [[0] * BC for i in xrange(ROUNDS + 1)]
+        ROUND_KEY_COUNT = (ROUNDS + 1) * BC
+        KC = len(key) / 4
+
+        # copy user material bytes into temporary ints
+        tk = []
+        for i in xrange(0, KC):
+            tk.append((ord(key[i * 4]) << 24) | (ord(key[i * 4 + 1]) << 16) |
+                (ord(key[i * 4 + 2]) << 8) | ord(key[i * 4 + 3]))
+
+        # copy values into round key arrays
+        t = 0
+        j = 0
+        while j < KC and t < ROUND_KEY_COUNT:
+            Ke[t / BC][t % BC] = tk[j]
+            Kd[ROUNDS - (t / BC)][t % BC] = tk[j]
+            j += 1
+            t += 1
+        tt = 0
+        rconpointer = 0
+        while t < ROUND_KEY_COUNT:
+            # extrapolate using phi (the round key evolution function)
+            tt = tk[KC - 1]
+            tk[0] ^= (S[(tt >> 16) & 0xFF] & 0xFF) << 24 ^  \
+                     (S[(tt >>  8) & 0xFF] & 0xFF) << 16 ^  \
+                     (S[ tt        & 0xFF] & 0xFF) <<  8 ^  \
+                     (S[(tt >> 24) & 0xFF] & 0xFF)       ^  \
+                     (rcon[rconpointer]    & 0xFF) << 24
+            rconpointer += 1
+            if KC != 8:
+                for i in xrange(1, KC):
+                    tk[i] ^= tk[i-1]
+            else:
+                for i in xrange(1, KC / 2):
+                    tk[i] ^= tk[i-1]
+                tt = tk[KC / 2 - 1]
+                tk[KC / 2] ^= (S[ tt        & 0xFF] & 0xFF)       ^ \
+                              (S[(tt >>  8) & 0xFF] & 0xFF) <<  8 ^ \
+                              (S[(tt >> 16) & 0xFF] & 0xFF) << 16 ^ \
+                              (S[(tt >> 24) & 0xFF] & 0xFF) << 24
+                for i in xrange(KC / 2 + 1, KC):
+                    tk[i] ^= tk[i-1]
+            # copy values into round key arrays
+            j = 0
+            while j < KC and t < ROUND_KEY_COUNT:
+                Ke[t / BC][t % BC] = tk[j]
+                Kd[ROUNDS - (t / BC)][t % BC] = tk[j]
+                j += 1
+                t += 1
+        # inverse MixColumn where needed
+        for r in xrange(1, ROUNDS):
+            for j in xrange(BC):
+                tt = Kd[r][j]
+                Kd[r][j] = U1[(tt >> 24) & 0xFF] ^ \
+                           U2[(tt >> 16) & 0xFF] ^ \
+                           U3[(tt >>  8) & 0xFF] ^ \
+                           U4[ tt        & 0xFF]
+        self.Ke = Ke
+        self.Kd = Kd
+
+    def encrypt(self, plaintext):
+        if len(plaintext) != self.block_size:
+            raise ValueError('wrong block length, expected ' + str(self.block_size) + ' got ' + str(len(plaintext)))
+        Ke = self.Ke
+
+        BC = self.block_size / 4
+        ROUNDS = len(Ke) - 1
+        if BC == 4:
+            SC = 0
+        elif BC == 6:
+            SC = 1
+        else:
+            SC = 2
+        s1 = shifts[SC][1][0]
+        s2 = shifts[SC][2][0]
+        s3 = shifts[SC][3][0]
+        a = [0] * BC
+        # temporary work array
+        t = []
+        # plaintext to ints + key
+        for i in xrange(BC):
+            t.append((ord(plaintext[i * 4    ]) << 24 |
+                      ord(plaintext[i * 4 + 1]) << 16 |
+                      ord(plaintext[i * 4 + 2]) <<  8 |
+                      ord(plaintext[i * 4 + 3])        ) ^ Ke[0][i])
+        # apply round transforms
+        for r in xrange(1, ROUNDS):
+            for i in xrange(BC):
+                a[i] = (T1[(t[ i           ] >> 24) & 0xFF] ^
+                        T2[(t[(i + s1) % BC] >> 16) & 0xFF] ^
+                        T3[(t[(i + s2) % BC] >>  8) & 0xFF] ^
+                        T4[ t[(i + s3) % BC]        & 0xFF]  ) ^ Ke[r][i]
+            t = copy.copy(a)
+        # last round is special
+        result = []
+        for i in xrange(BC):
+            tt = Ke[ROUNDS][i]
+            result.append((S[(t[ i           ] >> 24) & 0xFF] ^ (tt >> 24)) & 0xFF)
+            result.append((S[(t[(i + s1) % BC] >> 16) & 0xFF] ^ (tt >> 16)) & 0xFF)
+            result.append((S[(t[(i + s2) % BC] >>  8) & 0xFF] ^ (tt >>  8)) & 0xFF)
+            result.append((S[ t[(i + s3) % BC]        & 0xFF] ^  tt       ) & 0xFF)
+        return string.join(map(chr, result), '')
+
+    def decrypt(self, ciphertext):
+        if len(ciphertext) != self.block_size:
+            raise ValueError('wrong block length, expected ' + str(self.block_size) + ' got ' + str(len(ciphertext)))
+        Kd = self.Kd
+
+        BC = self.block_size / 4
+        ROUNDS = len(Kd) - 1
+        if BC == 4:
+            SC = 0
+        elif BC == 6:
+            SC = 1
+        else:
+            SC = 2
+        s1 = shifts[SC][1][1]
+        s2 = shifts[SC][2][1]
+        s3 = shifts[SC][3][1]
+        a = [0] * BC
+        # temporary work array
+        t = [0] * BC
+        # ciphertext to ints + key
+        for i in xrange(BC):
+            t[i] = (ord(ciphertext[i * 4    ]) << 24 |
+                    ord(ciphertext[i * 4 + 1]) << 16 |
+                    ord(ciphertext[i * 4 + 2]) <<  8 |
+                    ord(ciphertext[i * 4 + 3])        ) ^ Kd[0][i]
+        # apply round transforms
+        for r in xrange(1, ROUNDS):
+            for i in xrange(BC):
+                a[i] = (T5[(t[ i           ] >> 24) & 0xFF] ^
+                        T6[(t[(i + s1) % BC] >> 16) & 0xFF] ^
+                        T7[(t[(i + s2) % BC] >>  8) & 0xFF] ^
+                        T8[ t[(i + s3) % BC]        & 0xFF]  ) ^ Kd[r][i]
+            t = copy.copy(a)
+        # last round is special
+        result = []
+        for i in xrange(BC):
+            tt = Kd[ROUNDS][i]
+            result.append((Si[(t[ i           ] >> 24) & 0xFF] ^ (tt >> 24)) & 0xFF)
+            result.append((Si[(t[(i + s1) % BC] >> 16) & 0xFF] ^ (tt >> 16)) & 0xFF)
+            result.append((Si[(t[(i + s2) % BC] >>  8) & 0xFF] ^ (tt >>  8)) & 0xFF)
+            result.append((Si[ t[(i + s3) % BC]        & 0xFF] ^  tt       ) & 0xFF)
+        return string.join(map(chr, result), '')
+
+def encrypt(key, block):
+    return rijndael(key, len(block)).encrypt(block)
+
+def decrypt(key, block):
+    return rijndael(key, len(block)).decrypt(block)
+
+def test():
+    def t(kl, bl):
+        b = 'b' * bl
+        r = rijndael('a' * kl, bl)
+        assert r.decrypt(r.encrypt(b)) == b
+    t(16, 16)
+    t(16, 24)
+    t(16, 32)
+    t(24, 16)
+    t(24, 24)
+    t(24, 32)
+    t(32, 16)
+    t(32, 24)
+    t(32, 32)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/viff/test/test_aes.py	Thu Jan 29 16:19:49 2009 +0100
@@ -0,0 +1,139 @@
+# Copyright 2009 VIFF Development Team.
+#
+# This file is part of VIFF, the Virtual Ideal Functionality Framework.
+#
+# VIFF is free software: you can redistribute it and/or modify it
+# under the terms of the GNU Lesser General Public License (LGPL) as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# VIFF is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
+# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
+# Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with VIFF. If not, see <http://www.gnu.org/licenses/>.
+
+"""Tests for viff.aes."""
+
+
+from viff.test.util import RuntimeTestCase, protocol
+
+from viff.field import GF256
+from viff.runtime import gather_shares, Share
+from viff.aes import bit_decompose, AES
+
+from viff.test.rijndael import S, rijndael
+
+
+__doctest__ = ["viff.aes"]
+
+
+class BitDecompositionTestCase(RuntimeTestCase):
+    """Test GF256 bit decomposition."""
+
+    def verify(self, runtime, results, expected_results):
+        self.assert_type(results, list)
+        opened_results = []
+
+        for result, expected in zip(results, expected_results):
+            self.assert_type(result, Share)
+            opened = runtime.open(result)
+            opened.addCallback(self.assertEquals, expected)
+            opened_results.append(opened)
+        
+        return gather_shares(opened_results)
+
+    @protocol
+    def test_bit_decomposition(self, runtime):
+        share = Share(runtime, GF256, GF256(99))
+        return self.verify(runtime, bit_decompose(share),
+                           [1,1,0,0,0,1,1,0])
+
+
+class AESTestCase(RuntimeTestCase):
+    def verify(self, runtime, results, expected_results):
+        self.assert_type(results, list)
+        opened_results = []
+        
+        for result_row, expected_row in zip(results, expected_results):
+            self.assert_type(result_row, list)
+            self.assertEquals(len(result_row), len(expected_row))
+
+            for result, expected in zip(result_row, expected_row):
+                self.assert_type(result, Share)
+                opened = runtime.open(result)
+                opened.addCallback(self.assertEquals, expected)
+                opened_results.append(opened)
+
+        return gather_shares(opened_results)
+
+    def _test_byte_sub(self, runtime, aes):
+        results = []
+        expected_results = []
+
+        for i in range(4):
+            results.append([])
+            expected_results.append([])
+
+            for j in range(4):
+                b = 60 * i + j
+                results[i].append(Share(runtime, GF256, GF256(b)))
+                expected_results[i].append(S[b])
+
+        aes.byte_sub(results)
+        self.verify(runtime, results, expected_results)
+
+    @protocol
+    def test_byte_sub_with_masking(self, runtime):
+        self._test_byte_sub(runtime, AES(runtime, 128, 
+                                         use_exponentiation=False))
+
+    @protocol
+    def test_byte_sub_with_exponentiation(self, runtime):
+        self._test_byte_sub(runtime, AES(runtime, 128, 
+                                         use_exponentiation=True))
+
+    @protocol
+    def test_key_expansion(self, runtime):
+        aes = AES(runtime, 256)
+        key = []
+        ascii_key = []
+
+        for i in xrange(8):
+            key.append([])
+
+            for j in xrange(4):
+                b = 15 * i + j
+                key[i].append(Share(runtime, GF256, GF256(b)))
+                ascii_key.append(chr(b))
+
+        result = aes.key_expansion(key)
+
+        r = rijndael(ascii_key)
+        expected_result = []
+
+        for round_key in r.Ke:
+            for word in round_key:
+                split_word = []
+                expected_result.append(split_word)
+
+                for j in xrange(4):
+                    split_word.insert(0, word % 256)
+                    word /= 256
+
+        self.verify(runtime, result, expected_result)
+
+    @protocol
+    def test_encrypt(self, runtime):
+        cleartext = "Encrypt this!!!!"
+        key = "Supposed to be secret!?!"
+
+        aes = AES(runtime, 192)
+        r = rijndael(key)
+
+        result = aes.encrypt(cleartext, key)
+        expected = [ord(c) for c in r.encrypt(cleartext)]
+
+        return self.verify(runtime, [result], [expected])