viff

changeset 1094:bdab6f30511a

Merged with Marcel's AES code.
author Martin Geisler <mg@daimi.au.dk>
date Thu, 29 Jan 2009 16:19:49 +0100
parents aff12eb0e28c af545b802fd8
children e5bb773fb1fe
files NEWS apps/aes.py viff/passive.py
diffstat 6 files changed, 982 insertions(+), 0 deletions(-) [+]
line diff
     1.1 --- a/NEWS	Tue Jan 27 15:52:23 2009 +0100
     1.2 +++ b/NEWS	Thu Jan 29 16:19:49 2009 +0100
     1.3 @@ -44,6 +44,8 @@
     1.4  * Exponentiation of shares by square-and-multiply for public
     1.5    exponents. This means that if x is a Share, x**7 now works.
     1.6  
     1.7 +* Added multi-party AES encryption.
     1.8 +
     1.9  
    1.10  Version 0.7.1, released on 2008-10-09
    1.11  -------------------------------------
     2.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     2.2 +++ b/apps/aes.py	Thu Jan 29 16:19:49 2009 +0100
     2.3 @@ -0,0 +1,88 @@
     2.4 +#!/usr/bin/python
     2.5 +
     2.6 +# Copyright 2009 VIFF Development Team.
     2.7 +#
     2.8 +# This file is part of VIFF, the Virtual Ideal Functionality Framework.
     2.9 +#
    2.10 +# VIFF is free software: you can redistribute it and/or modify it
    2.11 +# under the terms of the GNU Lesser General Public License (LGPL) as
    2.12 +# published by the Free Software Foundation, either version 3 of the
    2.13 +# License, or (at your option) any later version.
    2.14 +#
    2.15 +# VIFF is distributed in the hope that it will be useful, but WITHOUT
    2.16 +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
    2.17 +# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
    2.18 +# Public License for more details.
    2.19 +#
    2.20 +# You should have received a copy of the GNU Lesser General Public
    2.21 +# License along with VIFF. If not, see <http://www.gnu.org/licenses/>.
    2.22 +
    2.23 +# This example shows how to use multi-party AES encryption.
    2.24 +
    2.25 +
    2.26 +import sys
    2.27 +import time
    2.28 +from optparse import OptionParser
    2.29 +
    2.30 +from twisted.internet import reactor
    2.31 +
    2.32 +from viff.field import GF256
    2.33 +from viff.runtime import Runtime, create_runtime, gather_shares
    2.34 +from viff.config import load_config
    2.35 +
    2.36 +from viff.aes import bit_decompose,AES
    2.37 +
    2.38 +
    2.39 +parser = OptionParser(usage="Usage: %prog [options] config_file")
    2.40 +parser.add_option("-e", "--exponentiation", action="store_true",
    2.41 +                  help="Use exponentiation to invert bytes (default).")
    2.42 +parser.add_option("-m", "--masking", action="store_false", 
    2.43 +                  dest="exponentiation", 
    2.44 +                  help="Use masking to invert bytes.")
    2.45 +parser.set_defaults(exponentiation=True)
    2.46 +
    2.47 +# Add standard VIFF options.
    2.48 +Runtime.add_options(parser)
    2.49 +
    2.50 +(options, args) = parser.parse_args()
    2.51 +
    2.52 +if len(args) == 0:
    2.53 +    parser.error("You must specify a config file.")
    2.54 +
    2.55 +id, players = load_config(args[0])
    2.56 +
    2.57 +def encrypt(_, rt, key):
    2.58 +    start = time.time()
    2.59 +    print "Started at %f." % start
    2.60 +
    2.61 +    aes = AES(rt, 192, use_exponentiation=options.exponentiation)
    2.62 +    ciphertext = aes.encrypt("a" * 16, key, True)
    2.63 +
    2.64 +    opened_ciphertext = [rt.open(c) for c in ciphertext]
    2.65 +
    2.66 +    def fin(ciphertext):
    2.67 +        print "Finished after %f sec." % (time.time() - start)
    2.68 +        print "Ciphertext:", [hex(c.value) for c in ciphertext]
    2.69 +        rt.shutdown()
    2.70 +
    2.71 +    g = gather_shares(opened_ciphertext)
    2.72 +    g.addCallback(fin)
    2.73 +
    2.74 +def share_key(rt):
    2.75 +    key =  []
    2.76 +
    2.77 +    for i in range(24):
    2.78 +        inputter = i % 3 + 1
    2.79 +        
    2.80 +        if (inputter == id):
    2.81 +            key.append(rt.input([inputter], GF256, ord("b")))
    2.82 +        else:
    2.83 +            key.append(rt.input([inputter], GF256))
    2.84 +
    2.85 +    s = rt.synchronize()
    2.86 +    s.addCallback(encrypt, rt, key)
    2.87 +
    2.88 +rt = create_runtime(id, players, 1, options)
    2.89 +rt.addCallback(share_key)
    2.90 +
    2.91 +reactor.run()
     3.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     3.2 +++ b/viff/aes.py	Thu Jan 29 16:19:49 2009 +0100
     3.3 @@ -0,0 +1,349 @@
     3.4 +# Copyright 2009 VIFF Development Team.
     3.5 +#
     3.6 +# This file is part of VIFF, the Virtual Ideal Functionality Framework.
     3.7 +#
     3.8 +# VIFF is free software: you can redistribute it and/or modify it
     3.9 +# under the terms of the GNU Lesser General Public License (LGPL) as
    3.10 +# published by the Free Software Foundation, either version 3 of the
    3.11 +# License, or (at your option) any later version.
    3.12 +#
    3.13 +# VIFF is distributed in the hope that it will be useful, but WITHOUT
    3.14 +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
    3.15 +# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
    3.16 +# Public License for more details.
    3.17 +#
    3.18 +# You should have received a copy of the GNU Lesser General Public
    3.19 +# License along with VIFF. If not, see <http://www.gnu.org/licenses/>.
    3.20 +
    3.21 +"""MPC implementation of AES (Rijndael)."""
    3.22 +
    3.23 +__docformat__ = "restructuredtext"
    3.24 +
    3.25 +
    3.26 +import time
    3.27 +
    3.28 +from viff.field import GF256
    3.29 +from viff.runtime import Share
    3.30 +from viff.matrix import Matrix
    3.31 +
    3.32 +
    3.33 +def bit_decompose(share, use_lin_comb=True):
    3.34 +    """Bit decomposition for GF256 shares."""
    3.35 +
    3.36 +    assert isinstance(share, Share) and share.field == GF256, \
    3.37 +        "Parameter must be GF256 share."
    3.38 +
    3.39 +    r_bits = [share.runtime.prss_share_random(GF256, binary=True) \
    3.40 +                  for i in range(8)]
    3.41 +    
    3.42 +    if (use_lin_comb):
    3.43 +        r = share.runtime.lin_comb([2 ** i for i in range(8)], r_bits)
    3.44 +    else:
    3.45 +        r = reduce(lambda x,y: x + y, 
    3.46 +                   [r_bits[i] * 2 ** i for i in range(8)])
    3.47 +
    3.48 +    c = share.runtime.open(share + r)
    3.49 +    c_bits = [Share(share.runtime, GF256) for i in range(8)]
    3.50 +    
    3.51 +    def decompose(byte, bits):
    3.52 +        value = byte.value
    3.53 +
    3.54 +        for i in range(8):
    3.55 +            c_bits[i].callback(GF256(value & 1))
    3.56 +            value >>= 1
    3.57 +
    3.58 +    c.addCallback(decompose, c_bits)
    3.59 +
    3.60 +    return [c_bits[i] + r_bits[i] for i in range(8)]
    3.61 +
    3.62 +
    3.63 +class AES:
    3.64 +    def __init__(self, runtime, key_size, block_size=128, 
    3.65 +                 use_exponentiation=False):
    3.66 +        """Initialize Rijndael.
    3.67 +
    3.68 +        AES(runtime, key_size, block_size), whereas key size and block
    3.69 +        size must be given in bits. Block size defaults to 128."""
    3.70 +
    3.71 +        assert key_size in [128, 192, 256], \
    3.72 +            "Key size must be 128, 192 or 256"
    3.73 +        assert block_size in [128, 192, 256], \
    3.74 +            "Block size be 128, 192 or 256"
    3.75 +
    3.76 +        self.n_k = key_size / 32
    3.77 +        self.n_b = block_size / 32
    3.78 +        self.rounds = max(self.n_k, self.n_b) + 6
    3.79 +        self.runtime = runtime
    3.80 +        self.use_exponentiation = use_exponentiation
    3.81 +
    3.82 +    # matrix for byte_sub
    3.83 +    A = Matrix([[1,0,0,0,1,1,1,1],
    3.84 +                [1,1,0,0,0,1,1,1],
    3.85 +                [1,1,1,0,0,0,1,1],
    3.86 +                [1,1,1,1,0,0,0,1],
    3.87 +                [1,1,1,1,1,0,0,0],
    3.88 +                [0,1,1,1,1,1,0,0],
    3.89 +                [0,0,1,1,1,1,1,0],
    3.90 +                [0,0,0,1,1,1,1,1]])
    3.91 +
    3.92 +    def byte_sub(self, state, use_lin_comb=True):
    3.93 +        """ByteSub operation of Rijndael.
    3.94 +
    3.95 +        The first argument should be a matrix consisting of elements
    3.96 +        of GF(2^8)."""
    3.97 +
    3.98 +        def invert_by_masking(byte):
    3.99 +            bits = bit_decompose(byte)
   3.100 +
   3.101 +            for j in range(len(bits)):
   3.102 +                bits[j].addCallback(lambda x: GF256(1) - x)
   3.103 +#                bits[j] = 1 - bits[j]
   3.104 +
   3.105 +            while(len(bits) > 1):
   3.106 +                bits.append(bits.pop() * bits.pop())
   3.107 +
   3.108 +            # b == 1 if byte is 0, b == 0 else
   3.109 +            b = bits[0]
   3.110 +
   3.111 +            r = Share(self.runtime, GF256)
   3.112 +            c = Share(self.runtime, GF256)
   3.113 +
   3.114 +            def get_masked_byte(c_opened, r_related, c, r, byte):
   3.115 +                if (c_opened == 0):
   3.116 +                    r_trial = self.runtime.prss_share_random(GF256)
   3.117 +                    c_trial = self.runtime.open((byte + b) * r_trial)
   3.118 +                    c_trial.addCallback(get_masked_byte, r_trial,
   3.119 +                                        c, r, byte)
   3.120 +                else:
   3.121 +                    r_related.addCallback(r.callback)
   3.122 +                    c.callback(~c_opened)
   3.123 +
   3.124 +            get_masked_byte(0, None, c, r, byte)
   3.125 +            return c * r - b
   3.126 +
   3.127 +        def invert_by_exponentiation(byte):
   3.128 +            byte_2 = byte * byte
   3.129 +            byte_3 = byte_2 * byte
   3.130 +            byte_6 = byte_3 * byte_3
   3.131 +            byte_12 = byte_6 * byte_6
   3.132 +            byte_15 = byte_12 * byte_3
   3.133 +            byte_30 = byte_15 * byte_15
   3.134 +            byte_60 = byte_30 * byte_30
   3.135 +            byte_63 = byte_60 * byte_3
   3.136 +            byte_126 = byte_63 * byte_63
   3.137 +            byte_252 = byte_126 * byte_126
   3.138 +            byte_254 = byte_252 * byte_2
   3.139 +            return byte_254
   3.140 +
   3.141 +        if (self.use_exponentiation):
   3.142 +            invert = invert_by_exponentiation
   3.143 +        else:
   3.144 +            invert = invert_by_masking
   3.145 +
   3.146 +        for h in range(len(state)):
   3.147 +            row = state[h]
   3.148 +            
   3.149 +            for i in range(len(row)):
   3.150 +                bits = bit_decompose(invert(row[i]))
   3.151 +
   3.152 +                # caution: order is lsb first
   3.153 +                vector = AES.A * Matrix(zip(bits)) + Matrix(zip([1,1,0,0,0,1,1,0]))
   3.154 +                bits = zip(*vector.rows)[0]
   3.155 +
   3.156 +                if (use_lin_comb):
   3.157 +                    row[i] = self.runtime.lin_comb(
   3.158 +                        [2**j for j in range(len(bits))], bits)
   3.159 +                else:
   3.160 +                    row[i] = reduce(lambda x,y: x + y, 
   3.161 +                                    [bits[j] * 2**j for j in range(len(bits))])
   3.162 +
   3.163 +    def shift_row(self, state):
   3.164 +        """AES ShiftRow.
   3.165 +
   3.166 +        State should be a list of 4 rows."""
   3.167 +
   3.168 +        assert len(state) == 4, "Wrong state size."
   3.169 +
   3.170 +        if self.n_b in [4,6]:
   3.171 +            offsets = [0, 1, 2, 3]
   3.172 +        else:
   3.173 +            offsets = [0, 1, 3, 4]
   3.174 +
   3.175 +        for i, row in enumerate(state):
   3.176 +            for j in range(offsets[i]):
   3.177 +                row.append(row.pop(0))
   3.178 +
   3.179 +    # matrix for mix_column
   3.180 +    C = [[2, 3, 1, 1],
   3.181 +         [1, 2, 3, 1],
   3.182 +         [1, 1, 2, 3],
   3.183 +         [3, 1, 1, 2]]
   3.184 +
   3.185 +    for row in C:
   3.186 +        for i in xrange(len(row)):
   3.187 +            row[i] = GF256(row[i])
   3.188 +
   3.189 +    C = Matrix(C)
   3.190 +
   3.191 +    def mix_column(self, state):
   3.192 +        """Rijndael MixColumn.
   3.193 +
   3.194 +        Input should be a list of 4 rows."""
   3.195 +
   3.196 +        assert len(state) == 4, "Wrong state size."
   3.197 +
   3.198 +        state[:] = (AES.C * Matrix(state)).rows
   3.199 +
   3.200 +    def add_round_key(self, state, round_key):
   3.201 +        """Rijndael AddRoundKey.
   3.202 +
   3.203 +        State should be a list of 4 rows and round_key a list of
   3.204 +        4-byte columns (words)."""
   3.205 +
   3.206 +        assert len(round_key) == self.n_b, "Wrong key size."
   3.207 +        assert len(round_key[0]) == 4, "Key must consist of 4-byte words."
   3.208 +
   3.209 +        state[:] = (Matrix(state) + Matrix(zip(*round_key))).rows
   3.210 +
   3.211 +    def key_expansion(self, key):
   3.212 +        """Rijndael key expansion.
   3.213 +
   3.214 +        Input and output are lists of 4-byte columns (words)."""
   3.215 +
   3.216 +        assert len(key) == self.n_k, "Wrong key size."
   3.217 +        assert len(key[0]) == 4, "Key must consist of 4-byte words."
   3.218 +
   3.219 +        expanded_key = list(key)
   3.220 +
   3.221 +        for i in xrange(self.n_k, self.n_b * (self.rounds + 1)):
   3.222 +            temp = list(expanded_key[i - 1])
   3.223 +
   3.224 +            if (i % self.n_k == 0):
   3.225 +                temp.append(temp.pop(0))
   3.226 +                self.byte_sub([temp])
   3.227 +                temp[0] += GF256(2) ** (i / self.n_k - 1)
   3.228 +            elif (self.n_k > 6 and i % self.n_k == 4):
   3.229 +                self.byte_sub([temp])
   3.230 +
   3.231 +            new_word = []
   3.232 +
   3.233 +            for j in xrange(4):
   3.234 +                new_word.append(expanded_key[i - self.n_k][j] + temp[j])
   3.235 +
   3.236 +            expanded_key.append(new_word)
   3.237 +
   3.238 +        return expanded_key
   3.239 +
   3.240 +    def preprocess(self, input):
   3.241 +        if (isinstance(input, str)):
   3.242 +            return [Share(self.runtime, GF256, GF256(ord(c))) 
   3.243 +                    for c in input]
   3.244 +        else:
   3.245 +            for byte in input:
   3.246 +                assert byte.field == GF256, \
   3.247 +                    "Input must be a list of GF256 elements " \
   3.248 +                    "or of shares thereof."
   3.249 +            return input
   3.250 +
   3.251 +    def encrypt(self, cleartext, key, benchmark=False):
   3.252 +        """Rijndael encryption.
   3.253 +
   3.254 +        Cleartext and key should be either a string or a list of bytes 
   3.255 +        (possibly shared as elements of GF256)."""
   3.256 +
   3.257 +        start = time.time()
   3.258 +
   3.259 +        assert len(cleartext) == 4 * self.n_b, "Wrong length of cleartext."
   3.260 +        assert len(key) == 4 * self.n_k, "Wrong length of key."
   3.261 +
   3.262 +        cleartext = self.preprocess(cleartext)
   3.263 +        key = self.preprocess(key)
   3.264 +
   3.265 +        state = [cleartext[i::4] for i in xrange(4)]
   3.266 +        key = [key[4*i:4*i+4] for i in xrange(self.n_k)]
   3.267 +
   3.268 +        if (benchmark):
   3.269 +            global preparation, communication
   3.270 +            preparation = 0
   3.271 +            communication = 0
   3.272 +
   3.273 +            def progress(x, i, start_round):
   3.274 +                time_diff = time.time() - start_round
   3.275 +                global communication
   3.276 +                communication += time_diff
   3.277 +                print "Round %2d: %f, %f" % \
   3.278 +                    (i, time_diff, time.time() - start)
   3.279 +                return x
   3.280 +
   3.281 +            def prep_progress(i, start_round):
   3.282 +                time_diff = time.time() - start_round
   3.283 +                global preparation
   3.284 +                preparation += time_diff
   3.285 +                print "Round %2d preparation: %f, %f" % \
   3.286 +                    (i, time_diff, time.time() - start)
   3.287 +        else:
   3.288 +            progress = lambda x, i, start_round: x
   3.289 +            prep_progress = lambda i, start_round: None
   3.290 +
   3.291 +        expanded_key = self.key_expansion(key)
   3.292 +
   3.293 +        self.add_round_key(state, expanded_key[0:self.n_b])
   3.294 +
   3.295 +        prep_progress(0, start)
   3.296 +
   3.297 +        def get_trigger(state):
   3.298 +            return state[3][self.n_b-1]
   3.299 +
   3.300 +        def get_last(state):
   3.301 +            return state[3][self.n_b-1]
   3.302 +
   3.303 +        def round(_, state, i):
   3.304 +            start_round = time.time()
   3.305 +            
   3.306 +            self.byte_sub(state)
   3.307 +            self.shift_row(state)
   3.308 +            self.mix_column(state)
   3.309 +            self.add_round_key(state, expanded_key[i*self.n_b:(i+1)*self.n_b])
   3.310 +
   3.311 +            get_last(state).addCallback(progress, i, time.time())
   3.312 +
   3.313 +            if (i < self.rounds - 1):
   3.314 +                get_trigger(state).addCallback(round, state, i + 1)
   3.315 +            else:
   3.316 +                get_trigger(state).addCallback(final_round, state)
   3.317 +
   3.318 +            prep_progress(i, start_round)
   3.319 +
   3.320 +            return _
   3.321 +
   3.322 +        def final_round(_, state):
   3.323 +            start_round = time.time()
   3.324 +
   3.325 +            self.byte_sub(state)
   3.326 +            self.shift_row(state)
   3.327 +            self.add_round_key(state, expanded_key[self.rounds*self.n_b:])
   3.328 +
   3.329 +            get_last(state).addCallback(progress, self.rounds, time.time())
   3.330 +
   3.331 +            get_trigger(state).addCallback(finish, state)
   3.332 +
   3.333 +            prep_progress(self.rounds, start_round)
   3.334 +
   3.335 +            return _
   3.336 +
   3.337 +        def finish(_, state):
   3.338 +            actual_result = [byte for word in zip(*state) for byte in word]
   3.339 +
   3.340 +            for a, b in zip(actual_result, result):
   3.341 +                a.addCallback(b.callback)
   3.342 +
   3.343 +            if (benchmark):
   3.344 +                print "Total preparation time: %f" % preparation
   3.345 +                print "Total communication time: %f" % communication
   3.346 +
   3.347 +            return _
   3.348 +
   3.349 +        round(None, state, 1)
   3.350 +
   3.351 +        result = [Share(self.runtime, GF256) for i in xrange(4 * self.n_b)]
   3.352 +        return result
     4.1 --- a/viff/passive.py	Tue Jan 27 15:52:23 2009 +0100
     4.2 +++ b/viff/passive.py	Thu Jan 29 16:19:49 2009 +0100
     4.3 @@ -138,6 +138,34 @@
     4.4          result.addCallback(lambda (a, b): a - b)
     4.5          return result
     4.6  
     4.7 +    def lin_comb(self, coefficients, shares):
     4.8 +        """Linear combination of shares.
     4.9 +
    4.10 +        Communication cost: none. Saves the construction of unnecessary shares
    4.11 +        compared to using add() and mul()."""
    4.12 +
    4.13 +        for coeff in coefficients:
    4.14 +            assert not isinstance(coeff, Share), \
    4.15 +                "Coefficients should not be shares."
    4.16 +
    4.17 +        assert len(coefficients) == len(shares), \
    4.18 +            "Number of coefficients and shares should be equal."
    4.19 +
    4.20 +        field = None
    4.21 +        for share in shares:
    4.22 +            field = getattr(share, "field", field)
    4.23 +        for i, share in enumerate(shares):
    4.24 +            if not isinstance(share, Share):
    4.25 +                shares[i] = Share(self, field, share)
    4.26 +
    4.27 +        def computation(shares, coefficients):
    4.28 +            summands = [shares[i] * coefficients[i] for i in range(len(shares))]
    4.29 +            return reduce(lambda x, y: x + y, summands)
    4.30 +
    4.31 +        result = gather_shares(shares)
    4.32 +        result.addCallback(computation, coefficients)
    4.33 +        return result
    4.34 +
    4.35      @profile
    4.36      @increment_pc
    4.37      def mul(self, share_a, share_b):
     5.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     5.2 +++ b/viff/test/rijndael.py	Thu Jan 29 16:19:49 2009 +0100
     5.3 @@ -0,0 +1,376 @@
     5.4 +"""
     5.5 +A pure python (slow) implementation of rijndael with a decent interface
     5.6 +
     5.7 +To include -
     5.8 +
     5.9 +from rijndael import rijndael
    5.10 +
    5.11 +To do a key setup -
    5.12 +
    5.13 +r = rijndael(key, block_size = 16)
    5.14 +
    5.15 +key must be a string of length 16, 24, or 32
    5.16 +blocksize must be 16, 24, or 32. Default is 16
    5.17 +
    5.18 +To use -
    5.19 +
    5.20 +ciphertext = r.encrypt(plaintext)
    5.21 +plaintext = r.decrypt(ciphertext)
    5.22 +
    5.23 +If any strings are of the wrong length a ValueError is thrown
    5.24 +"""
    5.25 +
    5.26 +# ported from the Java reference code by Bram Cohen, April 2001
    5.27 +# this code is public domain, unless someone makes 
    5.28 +# an intellectual property claim against the reference 
    5.29 +# code, in which case it can be made public domain by 
    5.30 +# deleting all the comments and renaming all the variables
    5.31 +
    5.32 +import copy
    5.33 +import string
    5.34 +
    5.35 +shifts = [[[0, 0], [1, 3], [2, 2], [3, 1]],
    5.36 +          [[0, 0], [1, 5], [2, 4], [3, 3]],
    5.37 +          [[0, 0], [1, 7], [3, 5], [4, 4]]]
    5.38 +
    5.39 +# [keysize][block_size]
    5.40 +num_rounds = {16: {16: 10, 24: 12, 32: 14}, 24: {16: 12, 24: 12, 32: 14}, 32: {16: 14, 24: 14, 32: 14}}
    5.41 +
    5.42 +A = [[1, 1, 1, 1, 1, 0, 0, 0],
    5.43 +     [0, 1, 1, 1, 1, 1, 0, 0],
    5.44 +     [0, 0, 1, 1, 1, 1, 1, 0],
    5.45 +     [0, 0, 0, 1, 1, 1, 1, 1],
    5.46 +     [1, 0, 0, 0, 1, 1, 1, 1],
    5.47 +     [1, 1, 0, 0, 0, 1, 1, 1],
    5.48 +     [1, 1, 1, 0, 0, 0, 1, 1],
    5.49 +     [1, 1, 1, 1, 0, 0, 0, 1]]
    5.50 +
    5.51 +# produce log and alog tables, needed for multiplying in the
    5.52 +# field GF(2^m) (generator = 3)
    5.53 +alog = [1]
    5.54 +for i in xrange(255):
    5.55 +    j = (alog[-1] << 1) ^ alog[-1]
    5.56 +    if j & 0x100 != 0:
    5.57 +        j ^= 0x11B
    5.58 +    alog.append(j)
    5.59 +
    5.60 +log = [0] * 256
    5.61 +for i in xrange(1, 255):
    5.62 +    log[alog[i]] = i
    5.63 +
    5.64 +# multiply two elements of GF(2^m)
    5.65 +def mul(a, b):
    5.66 +    if a == 0 or b == 0:
    5.67 +        return 0
    5.68 +    return alog[(log[a & 0xFF] + log[b & 0xFF]) % 255]
    5.69 +
    5.70 +# substitution box based on F^{-1}(x)
    5.71 +box = [[0] * 8 for i in xrange(256)]
    5.72 +box[1][7] = 1
    5.73 +for i in xrange(2, 256):
    5.74 +    j = alog[255 - log[i]]
    5.75 +    for t in xrange(8):
    5.76 +        box[i][t] = (j >> (7 - t)) & 0x01
    5.77 +
    5.78 +B = [0, 1, 1, 0, 0, 0, 1, 1]
    5.79 +
    5.80 +# affine transform:  box[i] <- B + A*box[i]
    5.81 +cox = [[0] * 8 for i in xrange(256)]
    5.82 +for i in xrange(256):
    5.83 +    for t in xrange(8):
    5.84 +        cox[i][t] = B[t]
    5.85 +        for j in xrange(8):
    5.86 +            cox[i][t] ^= A[t][j] * box[i][j]
    5.87 +
    5.88 +# S-boxes and inverse S-boxes
    5.89 +S =  [0] * 256
    5.90 +Si = [0] * 256
    5.91 +for i in xrange(256):
    5.92 +    S[i] = cox[i][0] << 7
    5.93 +    for t in xrange(1, 8):
    5.94 +        S[i] ^= cox[i][t] << (7-t)
    5.95 +    Si[S[i] & 0xFF] = i
    5.96 +
    5.97 +# T-boxes
    5.98 +G = [[2, 1, 1, 3],
    5.99 +    [3, 2, 1, 1],
   5.100 +    [1, 3, 2, 1],
   5.101 +    [1, 1, 3, 2]]
   5.102 +
   5.103 +AA = [[0] * 8 for i in xrange(4)]
   5.104 +
   5.105 +for i in xrange(4):
   5.106 +    for j in xrange(4):
   5.107 +        AA[i][j] = G[i][j]
   5.108 +        AA[i][i+4] = 1
   5.109 +
   5.110 +for i in xrange(4):
   5.111 +    pivot = AA[i][i]
   5.112 +    if pivot == 0:
   5.113 +        t = i + 1
   5.114 +        while AA[t][i] == 0 and t < 4:
   5.115 +            t += 1
   5.116 +            assert t != 4, 'G matrix must be invertible'
   5.117 +            for j in xrange(8):
   5.118 +                AA[i][j], AA[t][j] = AA[t][j], AA[i][j]
   5.119 +            pivot = AA[i][i]
   5.120 +    for j in xrange(8):
   5.121 +        if AA[i][j] != 0:
   5.122 +            AA[i][j] = alog[(255 + log[AA[i][j] & 0xFF] - log[pivot & 0xFF]) % 255]
   5.123 +    for t in xrange(4):
   5.124 +        if i != t:
   5.125 +            for j in xrange(i+1, 8):
   5.126 +                AA[t][j] ^= mul(AA[i][j], AA[t][i])
   5.127 +            AA[t][i] = 0
   5.128 +
   5.129 +iG = [[0] * 4 for i in xrange(4)]
   5.130 +
   5.131 +for i in xrange(4):
   5.132 +    for j in xrange(4):
   5.133 +        iG[i][j] = AA[i][j + 4]
   5.134 +
   5.135 +def mul4(a, bs):
   5.136 +    if a == 0:
   5.137 +        return 0
   5.138 +    r = 0
   5.139 +    for b in bs:
   5.140 +        r <<= 8
   5.141 +        if b != 0:
   5.142 +            r = r | mul(a, b)
   5.143 +    return r
   5.144 +
   5.145 +T1 = []
   5.146 +T2 = []
   5.147 +T3 = []
   5.148 +T4 = []
   5.149 +T5 = []
   5.150 +T6 = []
   5.151 +T7 = []
   5.152 +T8 = []
   5.153 +U1 = []
   5.154 +U2 = []
   5.155 +U3 = []
   5.156 +U4 = []
   5.157 +
   5.158 +for t in xrange(256):
   5.159 +    s = S[t]
   5.160 +    T1.append(mul4(s, G[0]))
   5.161 +    T2.append(mul4(s, G[1]))
   5.162 +    T3.append(mul4(s, G[2]))
   5.163 +    T4.append(mul4(s, G[3]))
   5.164 +
   5.165 +    s = Si[t]
   5.166 +    T5.append(mul4(s, iG[0]))
   5.167 +    T6.append(mul4(s, iG[1]))
   5.168 +    T7.append(mul4(s, iG[2]))
   5.169 +    T8.append(mul4(s, iG[3]))
   5.170 +
   5.171 +    U1.append(mul4(t, iG[0]))
   5.172 +    U2.append(mul4(t, iG[1]))
   5.173 +    U3.append(mul4(t, iG[2]))
   5.174 +    U4.append(mul4(t, iG[3]))
   5.175 +
   5.176 +# round constants
   5.177 +rcon = [1]
   5.178 +r = 1
   5.179 +for t in xrange(1, 30):
   5.180 +    r = mul(2, r)
   5.181 +    rcon.append(r)
   5.182 +
   5.183 +del A
   5.184 +del AA
   5.185 +del pivot
   5.186 +del B
   5.187 +del G
   5.188 +del box
   5.189 +del log
   5.190 +del alog
   5.191 +del i
   5.192 +del j
   5.193 +del r
   5.194 +del s
   5.195 +del t
   5.196 +del mul
   5.197 +del mul4
   5.198 +del cox
   5.199 +del iG
   5.200 +
   5.201 +class rijndael:
   5.202 +    def __init__(self, key, block_size = 16):
   5.203 +        if block_size != 16 and block_size != 24 and block_size != 32:
   5.204 +            raise ValueError('Invalid block size: ' + str(block_size))
   5.205 +        if len(key) != 16 and len(key) != 24 and len(key) != 32:
   5.206 +            raise ValueError('Invalid key size: ' + str(len(key)))
   5.207 +        self.block_size = block_size
   5.208 +
   5.209 +        ROUNDS = num_rounds[len(key)][block_size]
   5.210 +        BC = block_size / 4
   5.211 +        # encryption round keys
   5.212 +        Ke = [[0] * BC for i in xrange(ROUNDS + 1)]
   5.213 +        # decryption round keys
   5.214 +        Kd = [[0] * BC for i in xrange(ROUNDS + 1)]
   5.215 +        ROUND_KEY_COUNT = (ROUNDS + 1) * BC
   5.216 +        KC = len(key) / 4
   5.217 +
   5.218 +        # copy user material bytes into temporary ints
   5.219 +        tk = []
   5.220 +        for i in xrange(0, KC):
   5.221 +            tk.append((ord(key[i * 4]) << 24) | (ord(key[i * 4 + 1]) << 16) |
   5.222 +                (ord(key[i * 4 + 2]) << 8) | ord(key[i * 4 + 3]))
   5.223 +
   5.224 +        # copy values into round key arrays
   5.225 +        t = 0
   5.226 +        j = 0
   5.227 +        while j < KC and t < ROUND_KEY_COUNT:
   5.228 +            Ke[t / BC][t % BC] = tk[j]
   5.229 +            Kd[ROUNDS - (t / BC)][t % BC] = tk[j]
   5.230 +            j += 1
   5.231 +            t += 1
   5.232 +        tt = 0
   5.233 +        rconpointer = 0
   5.234 +        while t < ROUND_KEY_COUNT:
   5.235 +            # extrapolate using phi (the round key evolution function)
   5.236 +            tt = tk[KC - 1]
   5.237 +            tk[0] ^= (S[(tt >> 16) & 0xFF] & 0xFF) << 24 ^  \
   5.238 +                     (S[(tt >>  8) & 0xFF] & 0xFF) << 16 ^  \
   5.239 +                     (S[ tt        & 0xFF] & 0xFF) <<  8 ^  \
   5.240 +                     (S[(tt >> 24) & 0xFF] & 0xFF)       ^  \
   5.241 +                     (rcon[rconpointer]    & 0xFF) << 24
   5.242 +            rconpointer += 1
   5.243 +            if KC != 8:
   5.244 +                for i in xrange(1, KC):
   5.245 +                    tk[i] ^= tk[i-1]
   5.246 +            else:
   5.247 +                for i in xrange(1, KC / 2):
   5.248 +                    tk[i] ^= tk[i-1]
   5.249 +                tt = tk[KC / 2 - 1]
   5.250 +                tk[KC / 2] ^= (S[ tt        & 0xFF] & 0xFF)       ^ \
   5.251 +                              (S[(tt >>  8) & 0xFF] & 0xFF) <<  8 ^ \
   5.252 +                              (S[(tt >> 16) & 0xFF] & 0xFF) << 16 ^ \
   5.253 +                              (S[(tt >> 24) & 0xFF] & 0xFF) << 24
   5.254 +                for i in xrange(KC / 2 + 1, KC):
   5.255 +                    tk[i] ^= tk[i-1]
   5.256 +            # copy values into round key arrays
   5.257 +            j = 0
   5.258 +            while j < KC and t < ROUND_KEY_COUNT:
   5.259 +                Ke[t / BC][t % BC] = tk[j]
   5.260 +                Kd[ROUNDS - (t / BC)][t % BC] = tk[j]
   5.261 +                j += 1
   5.262 +                t += 1
   5.263 +        # inverse MixColumn where needed
   5.264 +        for r in xrange(1, ROUNDS):
   5.265 +            for j in xrange(BC):
   5.266 +                tt = Kd[r][j]
   5.267 +                Kd[r][j] = U1[(tt >> 24) & 0xFF] ^ \
   5.268 +                           U2[(tt >> 16) & 0xFF] ^ \
   5.269 +                           U3[(tt >>  8) & 0xFF] ^ \
   5.270 +                           U4[ tt        & 0xFF]
   5.271 +        self.Ke = Ke
   5.272 +        self.Kd = Kd
   5.273 +
   5.274 +    def encrypt(self, plaintext):
   5.275 +        if len(plaintext) != self.block_size:
   5.276 +            raise ValueError('wrong block length, expected ' + str(self.block_size) + ' got ' + str(len(plaintext)))
   5.277 +        Ke = self.Ke
   5.278 +
   5.279 +        BC = self.block_size / 4
   5.280 +        ROUNDS = len(Ke) - 1
   5.281 +        if BC == 4:
   5.282 +            SC = 0
   5.283 +        elif BC == 6:
   5.284 +            SC = 1
   5.285 +        else:
   5.286 +            SC = 2
   5.287 +        s1 = shifts[SC][1][0]
   5.288 +        s2 = shifts[SC][2][0]
   5.289 +        s3 = shifts[SC][3][0]
   5.290 +        a = [0] * BC
   5.291 +        # temporary work array
   5.292 +        t = []
   5.293 +        # plaintext to ints + key
   5.294 +        for i in xrange(BC):
   5.295 +            t.append((ord(plaintext[i * 4    ]) << 24 |
   5.296 +                      ord(plaintext[i * 4 + 1]) << 16 |
   5.297 +                      ord(plaintext[i * 4 + 2]) <<  8 |
   5.298 +                      ord(plaintext[i * 4 + 3])        ) ^ Ke[0][i])
   5.299 +        # apply round transforms
   5.300 +        for r in xrange(1, ROUNDS):
   5.301 +            for i in xrange(BC):
   5.302 +                a[i] = (T1[(t[ i           ] >> 24) & 0xFF] ^
   5.303 +                        T2[(t[(i + s1) % BC] >> 16) & 0xFF] ^
   5.304 +                        T3[(t[(i + s2) % BC] >>  8) & 0xFF] ^
   5.305 +                        T4[ t[(i + s3) % BC]        & 0xFF]  ) ^ Ke[r][i]
   5.306 +            t = copy.copy(a)
   5.307 +        # last round is special
   5.308 +        result = []
   5.309 +        for i in xrange(BC):
   5.310 +            tt = Ke[ROUNDS][i]
   5.311 +            result.append((S[(t[ i           ] >> 24) & 0xFF] ^ (tt >> 24)) & 0xFF)
   5.312 +            result.append((S[(t[(i + s1) % BC] >> 16) & 0xFF] ^ (tt >> 16)) & 0xFF)
   5.313 +            result.append((S[(t[(i + s2) % BC] >>  8) & 0xFF] ^ (tt >>  8)) & 0xFF)
   5.314 +            result.append((S[ t[(i + s3) % BC]        & 0xFF] ^  tt       ) & 0xFF)
   5.315 +        return string.join(map(chr, result), '')
   5.316 +
   5.317 +    def decrypt(self, ciphertext):
   5.318 +        if len(ciphertext) != self.block_size:
   5.319 +            raise ValueError('wrong block length, expected ' + str(self.block_size) + ' got ' + str(len(ciphertext)))
   5.320 +        Kd = self.Kd
   5.321 +
   5.322 +        BC = self.block_size / 4
   5.323 +        ROUNDS = len(Kd) - 1
   5.324 +        if BC == 4:
   5.325 +            SC = 0
   5.326 +        elif BC == 6:
   5.327 +            SC = 1
   5.328 +        else:
   5.329 +            SC = 2
   5.330 +        s1 = shifts[SC][1][1]
   5.331 +        s2 = shifts[SC][2][1]
   5.332 +        s3 = shifts[SC][3][1]
   5.333 +        a = [0] * BC
   5.334 +        # temporary work array
   5.335 +        t = [0] * BC
   5.336 +        # ciphertext to ints + key
   5.337 +        for i in xrange(BC):
   5.338 +            t[i] = (ord(ciphertext[i * 4    ]) << 24 |
   5.339 +                    ord(ciphertext[i * 4 + 1]) << 16 |
   5.340 +                    ord(ciphertext[i * 4 + 2]) <<  8 |
   5.341 +                    ord(ciphertext[i * 4 + 3])        ) ^ Kd[0][i]
   5.342 +        # apply round transforms
   5.343 +        for r in xrange(1, ROUNDS):
   5.344 +            for i in xrange(BC):
   5.345 +                a[i] = (T5[(t[ i           ] >> 24) & 0xFF] ^
   5.346 +                        T6[(t[(i + s1) % BC] >> 16) & 0xFF] ^
   5.347 +                        T7[(t[(i + s2) % BC] >>  8) & 0xFF] ^
   5.348 +                        T8[ t[(i + s3) % BC]        & 0xFF]  ) ^ Kd[r][i]
   5.349 +            t = copy.copy(a)
   5.350 +        # last round is special
   5.351 +        result = []
   5.352 +        for i in xrange(BC):
   5.353 +            tt = Kd[ROUNDS][i]
   5.354 +            result.append((Si[(t[ i           ] >> 24) & 0xFF] ^ (tt >> 24)) & 0xFF)
   5.355 +            result.append((Si[(t[(i + s1) % BC] >> 16) & 0xFF] ^ (tt >> 16)) & 0xFF)
   5.356 +            result.append((Si[(t[(i + s2) % BC] >>  8) & 0xFF] ^ (tt >>  8)) & 0xFF)
   5.357 +            result.append((Si[ t[(i + s3) % BC]        & 0xFF] ^  tt       ) & 0xFF)
   5.358 +        return string.join(map(chr, result), '')
   5.359 +
   5.360 +def encrypt(key, block):
   5.361 +    return rijndael(key, len(block)).encrypt(block)
   5.362 +
   5.363 +def decrypt(key, block):
   5.364 +    return rijndael(key, len(block)).decrypt(block)
   5.365 +
   5.366 +def test():
   5.367 +    def t(kl, bl):
   5.368 +        b = 'b' * bl
   5.369 +        r = rijndael('a' * kl, bl)
   5.370 +        assert r.decrypt(r.encrypt(b)) == b
   5.371 +    t(16, 16)
   5.372 +    t(16, 24)
   5.373 +    t(16, 32)
   5.374 +    t(24, 16)
   5.375 +    t(24, 24)
   5.376 +    t(24, 32)
   5.377 +    t(32, 16)
   5.378 +    t(32, 24)
   5.379 +    t(32, 32)
     6.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     6.2 +++ b/viff/test/test_aes.py	Thu Jan 29 16:19:49 2009 +0100
     6.3 @@ -0,0 +1,139 @@
     6.4 +# Copyright 2009 VIFF Development Team.
     6.5 +#
     6.6 +# This file is part of VIFF, the Virtual Ideal Functionality Framework.
     6.7 +#
     6.8 +# VIFF is free software: you can redistribute it and/or modify it
     6.9 +# under the terms of the GNU Lesser General Public License (LGPL) as
    6.10 +# published by the Free Software Foundation, either version 3 of the
    6.11 +# License, or (at your option) any later version.
    6.12 +#
    6.13 +# VIFF is distributed in the hope that it will be useful, but WITHOUT
    6.14 +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
    6.15 +# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
    6.16 +# Public License for more details.
    6.17 +#
    6.18 +# You should have received a copy of the GNU Lesser General Public
    6.19 +# License along with VIFF. If not, see <http://www.gnu.org/licenses/>.
    6.20 +
    6.21 +"""Tests for viff.aes."""
    6.22 +
    6.23 +
    6.24 +from viff.test.util import RuntimeTestCase, protocol
    6.25 +
    6.26 +from viff.field import GF256
    6.27 +from viff.runtime import gather_shares, Share
    6.28 +from viff.aes import bit_decompose, AES
    6.29 +
    6.30 +from viff.test.rijndael import S, rijndael
    6.31 +
    6.32 +
    6.33 +__doctest__ = ["viff.aes"]
    6.34 +
    6.35 +
    6.36 +class BitDecompositionTestCase(RuntimeTestCase):
    6.37 +    """Test GF256 bit decomposition."""
    6.38 +
    6.39 +    def verify(self, runtime, results, expected_results):
    6.40 +        self.assert_type(results, list)
    6.41 +        opened_results = []
    6.42 +
    6.43 +        for result, expected in zip(results, expected_results):
    6.44 +            self.assert_type(result, Share)
    6.45 +            opened = runtime.open(result)
    6.46 +            opened.addCallback(self.assertEquals, expected)
    6.47 +            opened_results.append(opened)
    6.48 +        
    6.49 +        return gather_shares(opened_results)
    6.50 +
    6.51 +    @protocol
    6.52 +    def test_bit_decomposition(self, runtime):
    6.53 +        share = Share(runtime, GF256, GF256(99))
    6.54 +        return self.verify(runtime, bit_decompose(share),
    6.55 +                           [1,1,0,0,0,1,1,0])
    6.56 +
    6.57 +
    6.58 +class AESTestCase(RuntimeTestCase):
    6.59 +    def verify(self, runtime, results, expected_results):
    6.60 +        self.assert_type(results, list)
    6.61 +        opened_results = []
    6.62 +        
    6.63 +        for result_row, expected_row in zip(results, expected_results):
    6.64 +            self.assert_type(result_row, list)
    6.65 +            self.assertEquals(len(result_row), len(expected_row))
    6.66 +
    6.67 +            for result, expected in zip(result_row, expected_row):
    6.68 +                self.assert_type(result, Share)
    6.69 +                opened = runtime.open(result)
    6.70 +                opened.addCallback(self.assertEquals, expected)
    6.71 +                opened_results.append(opened)
    6.72 +
    6.73 +        return gather_shares(opened_results)
    6.74 +
    6.75 +    def _test_byte_sub(self, runtime, aes):
    6.76 +        results = []
    6.77 +        expected_results = []
    6.78 +
    6.79 +        for i in range(4):
    6.80 +            results.append([])
    6.81 +            expected_results.append([])
    6.82 +
    6.83 +            for j in range(4):
    6.84 +                b = 60 * i + j
    6.85 +                results[i].append(Share(runtime, GF256, GF256(b)))
    6.86 +                expected_results[i].append(S[b])
    6.87 +
    6.88 +        aes.byte_sub(results)
    6.89 +        self.verify(runtime, results, expected_results)
    6.90 +
    6.91 +    @protocol
    6.92 +    def test_byte_sub_with_masking(self, runtime):
    6.93 +        self._test_byte_sub(runtime, AES(runtime, 128, 
    6.94 +                                         use_exponentiation=False))
    6.95 +
    6.96 +    @protocol
    6.97 +    def test_byte_sub_with_exponentiation(self, runtime):
    6.98 +        self._test_byte_sub(runtime, AES(runtime, 128, 
    6.99 +                                         use_exponentiation=True))
   6.100 +
   6.101 +    @protocol
   6.102 +    def test_key_expansion(self, runtime):
   6.103 +        aes = AES(runtime, 256)
   6.104 +        key = []
   6.105 +        ascii_key = []
   6.106 +
   6.107 +        for i in xrange(8):
   6.108 +            key.append([])
   6.109 +
   6.110 +            for j in xrange(4):
   6.111 +                b = 15 * i + j
   6.112 +                key[i].append(Share(runtime, GF256, GF256(b)))
   6.113 +                ascii_key.append(chr(b))
   6.114 +
   6.115 +        result = aes.key_expansion(key)
   6.116 +
   6.117 +        r = rijndael(ascii_key)
   6.118 +        expected_result = []
   6.119 +
   6.120 +        for round_key in r.Ke:
   6.121 +            for word in round_key:
   6.122 +                split_word = []
   6.123 +                expected_result.append(split_word)
   6.124 +
   6.125 +                for j in xrange(4):
   6.126 +                    split_word.insert(0, word % 256)
   6.127 +                    word /= 256
   6.128 +
   6.129 +        self.verify(runtime, result, expected_result)
   6.130 +
   6.131 +    @protocol
   6.132 +    def test_encrypt(self, runtime):
   6.133 +        cleartext = "Encrypt this!!!!"
   6.134 +        key = "Supposed to be secret!?!"
   6.135 +
   6.136 +        aes = AES(runtime, 192)
   6.137 +        r = rijndael(key)
   6.138 +
   6.139 +        result = aes.encrypt(cleartext, key)
   6.140 +        expected = [ord(c) for c in r.encrypt(cleartext)]
   6.141 +
   6.142 +        return self.verify(runtime, [result], [expected])