viff

changeset 1069:53e67a17c67d

Test for AES ByteSub.
author Marcel Keller <mkeller@cs.au.dk>
date Mon, 22 Dec 2008 15:39:41 +0100
parents 33129a6f532a
children d2d9d638364b
files viff/test/rijndael.py viff/test/test_aes.py
diffstat 2 files changed, 464 insertions(+), 0 deletions(-) [+]
line diff
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/viff/test/rijndael.py	Mon Dec 22 15:39:41 2008 +0100
     1.3 @@ -0,0 +1,376 @@
     1.4 +"""
     1.5 +A pure python (slow) implementation of rijndael with a decent interface
     1.6 +
     1.7 +To include -
     1.8 +
     1.9 +from rijndael import rijndael
    1.10 +
    1.11 +To do a key setup -
    1.12 +
    1.13 +r = rijndael(key, block_size = 16)
    1.14 +
    1.15 +key must be a string of length 16, 24, or 32
    1.16 +blocksize must be 16, 24, or 32. Default is 16
    1.17 +
    1.18 +To use -
    1.19 +
    1.20 +ciphertext = r.encrypt(plaintext)
    1.21 +plaintext = r.decrypt(ciphertext)
    1.22 +
    1.23 +If any strings are of the wrong length a ValueError is thrown
    1.24 +"""
    1.25 +
    1.26 +# ported from the Java reference code by Bram Cohen, April 2001
    1.27 +# this code is public domain, unless someone makes 
    1.28 +# an intellectual property claim against the reference 
    1.29 +# code, in which case it can be made public domain by 
    1.30 +# deleting all the comments and renaming all the variables
    1.31 +
    1.32 +import copy
    1.33 +import string
    1.34 +
    1.35 +shifts = [[[0, 0], [1, 3], [2, 2], [3, 1]],
    1.36 +          [[0, 0], [1, 5], [2, 4], [3, 3]],
    1.37 +          [[0, 0], [1, 7], [3, 5], [4, 4]]]
    1.38 +
    1.39 +# [keysize][block_size]
    1.40 +num_rounds = {16: {16: 10, 24: 12, 32: 14}, 24: {16: 12, 24: 12, 32: 14}, 32: {16: 14, 24: 14, 32: 14}}
    1.41 +
    1.42 +A = [[1, 1, 1, 1, 1, 0, 0, 0],
    1.43 +     [0, 1, 1, 1, 1, 1, 0, 0],
    1.44 +     [0, 0, 1, 1, 1, 1, 1, 0],
    1.45 +     [0, 0, 0, 1, 1, 1, 1, 1],
    1.46 +     [1, 0, 0, 0, 1, 1, 1, 1],
    1.47 +     [1, 1, 0, 0, 0, 1, 1, 1],
    1.48 +     [1, 1, 1, 0, 0, 0, 1, 1],
    1.49 +     [1, 1, 1, 1, 0, 0, 0, 1]]
    1.50 +
    1.51 +# produce log and alog tables, needed for multiplying in the
    1.52 +# field GF(2^m) (generator = 3)
    1.53 +alog = [1]
    1.54 +for i in xrange(255):
    1.55 +    j = (alog[-1] << 1) ^ alog[-1]
    1.56 +    if j & 0x100 != 0:
    1.57 +        j ^= 0x11B
    1.58 +    alog.append(j)
    1.59 +
    1.60 +log = [0] * 256
    1.61 +for i in xrange(1, 255):
    1.62 +    log[alog[i]] = i
    1.63 +
    1.64 +# multiply two elements of GF(2^m)
    1.65 +def mul(a, b):
    1.66 +    if a == 0 or b == 0:
    1.67 +        return 0
    1.68 +    return alog[(log[a & 0xFF] + log[b & 0xFF]) % 255]
    1.69 +
    1.70 +# substitution box based on F^{-1}(x)
    1.71 +box = [[0] * 8 for i in xrange(256)]
    1.72 +box[1][7] = 1
    1.73 +for i in xrange(2, 256):
    1.74 +    j = alog[255 - log[i]]
    1.75 +    for t in xrange(8):
    1.76 +        box[i][t] = (j >> (7 - t)) & 0x01
    1.77 +
    1.78 +B = [0, 1, 1, 0, 0, 0, 1, 1]
    1.79 +
    1.80 +# affine transform:  box[i] <- B + A*box[i]
    1.81 +cox = [[0] * 8 for i in xrange(256)]
    1.82 +for i in xrange(256):
    1.83 +    for t in xrange(8):
    1.84 +        cox[i][t] = B[t]
    1.85 +        for j in xrange(8):
    1.86 +            cox[i][t] ^= A[t][j] * box[i][j]
    1.87 +
    1.88 +# S-boxes and inverse S-boxes
    1.89 +S =  [0] * 256
    1.90 +Si = [0] * 256
    1.91 +for i in xrange(256):
    1.92 +    S[i] = cox[i][0] << 7
    1.93 +    for t in xrange(1, 8):
    1.94 +        S[i] ^= cox[i][t] << (7-t)
    1.95 +    Si[S[i] & 0xFF] = i
    1.96 +
    1.97 +# T-boxes
    1.98 +G = [[2, 1, 1, 3],
    1.99 +    [3, 2, 1, 1],
   1.100 +    [1, 3, 2, 1],
   1.101 +    [1, 1, 3, 2]]
   1.102 +
   1.103 +AA = [[0] * 8 for i in xrange(4)]
   1.104 +
   1.105 +for i in xrange(4):
   1.106 +    for j in xrange(4):
   1.107 +        AA[i][j] = G[i][j]
   1.108 +        AA[i][i+4] = 1
   1.109 +
   1.110 +for i in xrange(4):
   1.111 +    pivot = AA[i][i]
   1.112 +    if pivot == 0:
   1.113 +        t = i + 1
   1.114 +        while AA[t][i] == 0 and t < 4:
   1.115 +            t += 1
   1.116 +            assert t != 4, 'G matrix must be invertible'
   1.117 +            for j in xrange(8):
   1.118 +                AA[i][j], AA[t][j] = AA[t][j], AA[i][j]
   1.119 +            pivot = AA[i][i]
   1.120 +    for j in xrange(8):
   1.121 +        if AA[i][j] != 0:
   1.122 +            AA[i][j] = alog[(255 + log[AA[i][j] & 0xFF] - log[pivot & 0xFF]) % 255]
   1.123 +    for t in xrange(4):
   1.124 +        if i != t:
   1.125 +            for j in xrange(i+1, 8):
   1.126 +                AA[t][j] ^= mul(AA[i][j], AA[t][i])
   1.127 +            AA[t][i] = 0
   1.128 +
   1.129 +iG = [[0] * 4 for i in xrange(4)]
   1.130 +
   1.131 +for i in xrange(4):
   1.132 +    for j in xrange(4):
   1.133 +        iG[i][j] = AA[i][j + 4]
   1.134 +
   1.135 +def mul4(a, bs):
   1.136 +    if a == 0:
   1.137 +        return 0
   1.138 +    r = 0
   1.139 +    for b in bs:
   1.140 +        r <<= 8
   1.141 +        if b != 0:
   1.142 +            r = r | mul(a, b)
   1.143 +    return r
   1.144 +
   1.145 +T1 = []
   1.146 +T2 = []
   1.147 +T3 = []
   1.148 +T4 = []
   1.149 +T5 = []
   1.150 +T6 = []
   1.151 +T7 = []
   1.152 +T8 = []
   1.153 +U1 = []
   1.154 +U2 = []
   1.155 +U3 = []
   1.156 +U4 = []
   1.157 +
   1.158 +for t in xrange(256):
   1.159 +    s = S[t]
   1.160 +    T1.append(mul4(s, G[0]))
   1.161 +    T2.append(mul4(s, G[1]))
   1.162 +    T3.append(mul4(s, G[2]))
   1.163 +    T4.append(mul4(s, G[3]))
   1.164 +
   1.165 +    s = Si[t]
   1.166 +    T5.append(mul4(s, iG[0]))
   1.167 +    T6.append(mul4(s, iG[1]))
   1.168 +    T7.append(mul4(s, iG[2]))
   1.169 +    T8.append(mul4(s, iG[3]))
   1.170 +
   1.171 +    U1.append(mul4(t, iG[0]))
   1.172 +    U2.append(mul4(t, iG[1]))
   1.173 +    U3.append(mul4(t, iG[2]))
   1.174 +    U4.append(mul4(t, iG[3]))
   1.175 +
   1.176 +# round constants
   1.177 +rcon = [1]
   1.178 +r = 1
   1.179 +for t in xrange(1, 30):
   1.180 +    r = mul(2, r)
   1.181 +    rcon.append(r)
   1.182 +
   1.183 +del A
   1.184 +del AA
   1.185 +del pivot
   1.186 +del B
   1.187 +del G
   1.188 +del box
   1.189 +del log
   1.190 +del alog
   1.191 +del i
   1.192 +del j
   1.193 +del r
   1.194 +del s
   1.195 +del t
   1.196 +del mul
   1.197 +del mul4
   1.198 +del cox
   1.199 +del iG
   1.200 +
   1.201 +class rijndael:
   1.202 +    def __init__(self, key, block_size = 16):
   1.203 +        if block_size != 16 and block_size != 24 and block_size != 32:
   1.204 +            raise ValueError('Invalid block size: ' + str(block_size))
   1.205 +        if len(key) != 16 and len(key) != 24 and len(key) != 32:
   1.206 +            raise ValueError('Invalid key size: ' + str(len(key)))
   1.207 +        self.block_size = block_size
   1.208 +
   1.209 +        ROUNDS = num_rounds[len(key)][block_size]
   1.210 +        BC = block_size / 4
   1.211 +        # encryption round keys
   1.212 +        Ke = [[0] * BC for i in xrange(ROUNDS + 1)]
   1.213 +        # decryption round keys
   1.214 +        Kd = [[0] * BC for i in xrange(ROUNDS + 1)]
   1.215 +        ROUND_KEY_COUNT = (ROUNDS + 1) * BC
   1.216 +        KC = len(key) / 4
   1.217 +
   1.218 +        # copy user material bytes into temporary ints
   1.219 +        tk = []
   1.220 +        for i in xrange(0, KC):
   1.221 +            tk.append((ord(key[i * 4]) << 24) | (ord(key[i * 4 + 1]) << 16) |
   1.222 +                (ord(key[i * 4 + 2]) << 8) | ord(key[i * 4 + 3]))
   1.223 +
   1.224 +        # copy values into round key arrays
   1.225 +        t = 0
   1.226 +        j = 0
   1.227 +        while j < KC and t < ROUND_KEY_COUNT:
   1.228 +            Ke[t / BC][t % BC] = tk[j]
   1.229 +            Kd[ROUNDS - (t / BC)][t % BC] = tk[j]
   1.230 +            j += 1
   1.231 +            t += 1
   1.232 +        tt = 0
   1.233 +        rconpointer = 0
   1.234 +        while t < ROUND_KEY_COUNT:
   1.235 +            # extrapolate using phi (the round key evolution function)
   1.236 +            tt = tk[KC - 1]
   1.237 +            tk[0] ^= (S[(tt >> 16) & 0xFF] & 0xFF) << 24 ^  \
   1.238 +                     (S[(tt >>  8) & 0xFF] & 0xFF) << 16 ^  \
   1.239 +                     (S[ tt        & 0xFF] & 0xFF) <<  8 ^  \
   1.240 +                     (S[(tt >> 24) & 0xFF] & 0xFF)       ^  \
   1.241 +                     (rcon[rconpointer]    & 0xFF) << 24
   1.242 +            rconpointer += 1
   1.243 +            if KC != 8:
   1.244 +                for i in xrange(1, KC):
   1.245 +                    tk[i] ^= tk[i-1]
   1.246 +            else:
   1.247 +                for i in xrange(1, KC / 2):
   1.248 +                    tk[i] ^= tk[i-1]
   1.249 +                tt = tk[KC / 2 - 1]
   1.250 +                tk[KC / 2] ^= (S[ tt        & 0xFF] & 0xFF)       ^ \
   1.251 +                              (S[(tt >>  8) & 0xFF] & 0xFF) <<  8 ^ \
   1.252 +                              (S[(tt >> 16) & 0xFF] & 0xFF) << 16 ^ \
   1.253 +                              (S[(tt >> 24) & 0xFF] & 0xFF) << 24
   1.254 +                for i in xrange(KC / 2 + 1, KC):
   1.255 +                    tk[i] ^= tk[i-1]
   1.256 +            # copy values into round key arrays
   1.257 +            j = 0
   1.258 +            while j < KC and t < ROUND_KEY_COUNT:
   1.259 +                Ke[t / BC][t % BC] = tk[j]
   1.260 +                Kd[ROUNDS - (t / BC)][t % BC] = tk[j]
   1.261 +                j += 1
   1.262 +                t += 1
   1.263 +        # inverse MixColumn where needed
   1.264 +        for r in xrange(1, ROUNDS):
   1.265 +            for j in xrange(BC):
   1.266 +                tt = Kd[r][j]
   1.267 +                Kd[r][j] = U1[(tt >> 24) & 0xFF] ^ \
   1.268 +                           U2[(tt >> 16) & 0xFF] ^ \
   1.269 +                           U3[(tt >>  8) & 0xFF] ^ \
   1.270 +                           U4[ tt        & 0xFF]
   1.271 +        self.Ke = Ke
   1.272 +        self.Kd = Kd
   1.273 +
   1.274 +    def encrypt(self, plaintext):
   1.275 +        if len(plaintext) != self.block_size:
   1.276 +            raise ValueError('wrong block length, expected ' + str(self.block_size) + ' got ' + str(len(plaintext)))
   1.277 +        Ke = self.Ke
   1.278 +
   1.279 +        BC = self.block_size / 4
   1.280 +        ROUNDS = len(Ke) - 1
   1.281 +        if BC == 4:
   1.282 +            SC = 0
   1.283 +        elif BC == 6:
   1.284 +            SC = 1
   1.285 +        else:
   1.286 +            SC = 2
   1.287 +        s1 = shifts[SC][1][0]
   1.288 +        s2 = shifts[SC][2][0]
   1.289 +        s3 = shifts[SC][3][0]
   1.290 +        a = [0] * BC
   1.291 +        # temporary work array
   1.292 +        t = []
   1.293 +        # plaintext to ints + key
   1.294 +        for i in xrange(BC):
   1.295 +            t.append((ord(plaintext[i * 4    ]) << 24 |
   1.296 +                      ord(plaintext[i * 4 + 1]) << 16 |
   1.297 +                      ord(plaintext[i * 4 + 2]) <<  8 |
   1.298 +                      ord(plaintext[i * 4 + 3])        ) ^ Ke[0][i])
   1.299 +        # apply round transforms
   1.300 +        for r in xrange(1, ROUNDS):
   1.301 +            for i in xrange(BC):
   1.302 +                a[i] = (T1[(t[ i           ] >> 24) & 0xFF] ^
   1.303 +                        T2[(t[(i + s1) % BC] >> 16) & 0xFF] ^
   1.304 +                        T3[(t[(i + s2) % BC] >>  8) & 0xFF] ^
   1.305 +                        T4[ t[(i + s3) % BC]        & 0xFF]  ) ^ Ke[r][i]
   1.306 +            t = copy.copy(a)
   1.307 +        # last round is special
   1.308 +        result = []
   1.309 +        for i in xrange(BC):
   1.310 +            tt = Ke[ROUNDS][i]
   1.311 +            result.append((S[(t[ i           ] >> 24) & 0xFF] ^ (tt >> 24)) & 0xFF)
   1.312 +            result.append((S[(t[(i + s1) % BC] >> 16) & 0xFF] ^ (tt >> 16)) & 0xFF)
   1.313 +            result.append((S[(t[(i + s2) % BC] >>  8) & 0xFF] ^ (tt >>  8)) & 0xFF)
   1.314 +            result.append((S[ t[(i + s3) % BC]        & 0xFF] ^  tt       ) & 0xFF)
   1.315 +        return string.join(map(chr, result), '')
   1.316 +
   1.317 +    def decrypt(self, ciphertext):
   1.318 +        if len(ciphertext) != self.block_size:
   1.319 +            raise ValueError('wrong block length, expected ' + str(self.block_size) + ' got ' + str(len(ciphertext)))
   1.320 +        Kd = self.Kd
   1.321 +
   1.322 +        BC = self.block_size / 4
   1.323 +        ROUNDS = len(Kd) - 1
   1.324 +        if BC == 4:
   1.325 +            SC = 0
   1.326 +        elif BC == 6:
   1.327 +            SC = 1
   1.328 +        else:
   1.329 +            SC = 2
   1.330 +        s1 = shifts[SC][1][1]
   1.331 +        s2 = shifts[SC][2][1]
   1.332 +        s3 = shifts[SC][3][1]
   1.333 +        a = [0] * BC
   1.334 +        # temporary work array
   1.335 +        t = [0] * BC
   1.336 +        # ciphertext to ints + key
   1.337 +        for i in xrange(BC):
   1.338 +            t[i] = (ord(ciphertext[i * 4    ]) << 24 |
   1.339 +                    ord(ciphertext[i * 4 + 1]) << 16 |
   1.340 +                    ord(ciphertext[i * 4 + 2]) <<  8 |
   1.341 +                    ord(ciphertext[i * 4 + 3])        ) ^ Kd[0][i]
   1.342 +        # apply round transforms
   1.343 +        for r in xrange(1, ROUNDS):
   1.344 +            for i in xrange(BC):
   1.345 +                a[i] = (T5[(t[ i           ] >> 24) & 0xFF] ^
   1.346 +                        T6[(t[(i + s1) % BC] >> 16) & 0xFF] ^
   1.347 +                        T7[(t[(i + s2) % BC] >>  8) & 0xFF] ^
   1.348 +                        T8[ t[(i + s3) % BC]        & 0xFF]  ) ^ Kd[r][i]
   1.349 +            t = copy.copy(a)
   1.350 +        # last round is special
   1.351 +        result = []
   1.352 +        for i in xrange(BC):
   1.353 +            tt = Kd[ROUNDS][i]
   1.354 +            result.append((Si[(t[ i           ] >> 24) & 0xFF] ^ (tt >> 24)) & 0xFF)
   1.355 +            result.append((Si[(t[(i + s1) % BC] >> 16) & 0xFF] ^ (tt >> 16)) & 0xFF)
   1.356 +            result.append((Si[(t[(i + s2) % BC] >>  8) & 0xFF] ^ (tt >>  8)) & 0xFF)
   1.357 +            result.append((Si[ t[(i + s3) % BC]        & 0xFF] ^  tt       ) & 0xFF)
   1.358 +        return string.join(map(chr, result), '')
   1.359 +
   1.360 +def encrypt(key, block):
   1.361 +    return rijndael(key, len(block)).encrypt(block)
   1.362 +
   1.363 +def decrypt(key, block):
   1.364 +    return rijndael(key, len(block)).decrypt(block)
   1.365 +
   1.366 +def test():
   1.367 +    def t(kl, bl):
   1.368 +        b = 'b' * bl
   1.369 +        r = rijndael('a' * kl, bl)
   1.370 +        assert r.decrypt(r.encrypt(b)) == b
   1.371 +    t(16, 16)
   1.372 +    t(16, 24)
   1.373 +    t(16, 32)
   1.374 +    t(24, 16)
   1.375 +    t(24, 24)
   1.376 +    t(24, 32)
   1.377 +    t(32, 16)
   1.378 +    t(32, 24)
   1.379 +    t(32, 32)
     2.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     2.2 +++ b/viff/test/test_aes.py	Mon Dec 22 15:39:41 2008 +0100
     2.3 @@ -0,0 +1,88 @@
     2.4 +# Copyright 2007, 2008 VIFF Development Team.
     2.5 +#
     2.6 +# This file is part of VIFF, the Virtual Ideal Functionality Framework.
     2.7 +#
     2.8 +# VIFF is free software: you can redistribute it and/or modify it
     2.9 +# under the terms of the GNU Lesser General Public License (LGPL) as
    2.10 +# published by the Free Software Foundation, either version 3 of the
    2.11 +# License, or (at your option) any later version.
    2.12 +#
    2.13 +# VIFF is distributed in the hope that it will be useful, but WITHOUT
    2.14 +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
    2.15 +# or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
    2.16 +# Public License for more details.
    2.17 +#
    2.18 +# You should have received a copy of the GNU Lesser General Public
    2.19 +# License along with VIFF. If not, see <http://www.gnu.org/licenses/>.
    2.20 +
    2.21 +"""Tests for viff.aes."""
    2.22 +
    2.23 +
    2.24 +from viff.test.util import RuntimeTestCase, protocol
    2.25 +
    2.26 +from viff.field import GF256
    2.27 +from viff.runtime import gather_shares, Share
    2.28 +from viff.aes import bit_decompose, AES
    2.29 +
    2.30 +from viff.test.rijndael import S
    2.31 +
    2.32 +
    2.33 +__doctest__ = ["viff.aes"]
    2.34 +
    2.35 +
    2.36 +class BitDecompositionTestCase(RuntimeTestCase):
    2.37 +    """Test GF256 bit decomposition."""
    2.38 +
    2.39 +    def verify(self, runtime, results, expected_results):
    2.40 +        self.assert_type(results, list)
    2.41 +        opened_results = []
    2.42 +
    2.43 +        for result, expected in zip(results, expected_results):
    2.44 +            self.assert_type(result, Share)
    2.45 +            opened = runtime.open(result)
    2.46 +            opened.addCallback(self.assertEquals, expected)
    2.47 +            opened_results.append(opened)
    2.48 +        
    2.49 +        return gather_shares(opened_results)
    2.50 +
    2.51 +    @protocol
    2.52 +    def test_bit_decomposition(self, runtime):
    2.53 +        share = Share(runtime, GF256, GF256(99))
    2.54 +        return self.verify(runtime, bit_decompose(share),
    2.55 +                           [1,1,0,0,0,1,1,0])
    2.56 +
    2.57 +
    2.58 +class AESTestCase(RuntimeTestCase):
    2.59 +    def verify(self, runtime, results, expected_results):
    2.60 +        self.assert_type(results, list)
    2.61 +        opened_results = []
    2.62 +        
    2.63 +        for result_row, expected_row in zip(results, expected_results):
    2.64 +            self.assert_type(result_row, list)
    2.65 +            self.assertEquals(len(result_row), len(expected_row))
    2.66 +
    2.67 +            for result, expected in zip(result_row, expected_row):
    2.68 +                self.assert_type(result, Share)
    2.69 +                opened = runtime.open(result)
    2.70 +                opened.addCallback(self.assertEquals, expected)
    2.71 +                opened_results.append(opened)
    2.72 +
    2.73 +        return gather_shares(opened_results)
    2.74 +
    2.75 +    @protocol
    2.76 +    def test_byte_sub(self, runtime):
    2.77 +        aes = AES(runtime, 128)
    2.78 +        results = []
    2.79 +        expected_results = []
    2.80 +
    2.81 +        for i in range(4):
    2.82 +            results.append([])
    2.83 +            expected_results.append([])
    2.84 +
    2.85 +            for j in range(4):
    2.86 +                b = 60 * i + j
    2.87 +                results[i].append(Share(runtime, GF256, b))
    2.88 +                expected_results[i].append(S[b])
    2.89 +
    2.90 +        aes.byte_sub(results)
    2.91 +        self.verify(runtime, results, expected_results)