AES Encryption 2 - Python

Published on 2023-09-04 by Kartikay Bagla


Now that we know the basics of AES, it's time to implement the concepts in Python. By the end of this, we should be able to take any arbitrary file as input along with a key, and produce an encrypted file that can be decrypted back to the original.

This is part 2 of the series. Read [part 1 here]({% link Programming/_posts/2023-09-02-AES Encryption 1 - Concepts.md %}).

The full code is available on GitHub.

Python Implementation

The implementation is fairly simple. I used NumPy instead of Python lists in the hope that it would be fast (but fast is relative). The AES algorithm from part 1 works on a single 16-byte chunk, so let's call that function encrypt_chunk for now. (We'll write encrypt_chunk itself in the Encrypt Chunk section.)

The base code then becomes:

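import numpy as np

# encrypt file
# INPUT_FILE, OUTPUT_FILE and KEY are defined elsewhere; KEY is a uint8 array of shape (N, 4).
with open(INPUT_FILE, "rb") as f:
    FILE_DATA = f.read()
FILE_SIZE = len(FILE_DATA)

# number of zero bytes needed to reach a multiple of 16 (0 to 15)
PADDING_REQUIRED = (16 - FILE_SIZE % 16) % 16
FILE_DATA += bytes(PADDING_REQUIRED)

OUTPUT = bytearray()
N_ITER = len(FILE_DATA) // 16
for i in range(N_ITER):
    # 16 bytes -> 4x4 array, filled row by row
    chunk = np.array(list(FILE_DATA[i * 16 : i * 16 + 16]), dtype=np.uint8).reshape(4, 4)
    ITER_OUT = encrypt_chunk(chunk, KEY)
    OUTPUT.extend(ITER_OUT.astype(np.uint8).tobytes())

# store the padding length as the last byte so decryption can strip it
OUTPUT.append(PADDING_REQUIRED)

with open(OUTPUT_FILE, "wb") as w:
    w.write(OUTPUT)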

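Note that we treat each chunk as a 4x4 numpy array where each element is a byte (8 bits). Because reshape fills the array row by row, each row holds four consecutive input bytes, which is one column (a 32-bit word) of the AES state as FIPS-197 defines it. The round functions below have to respect that layout.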
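The zero-padding scheme in the base code needs a matching undo step when decrypting. A minimal sketch, assuming the data is zero-padded to a multiple of 16 bytes and the padding count is kept separately; `split_into_chunks` and `join_chunks` are hypothetical helpers, and the cipher itself is left out so only the chunk framing is shown:

```python
import numpy as np


def split_into_chunks(data: bytes):
    """Zero-pad `data` to a multiple of 16 and return (chunks, padding)."""
    padding = (16 - len(data) % 16) % 16
    padded = data + bytes(padding)
    # Each 16-byte chunk becomes a 4x4 uint8 array, filled row by row.
    chunks = [
        np.frombuffer(padded[i : i + 16], dtype=np.uint8).reshape(4, 4)
        for i in range(0, len(padded), 16)
    ]
    return chunks, padding


def join_chunks(chunks, padding: int) -> bytes:
    """Inverse of split_into_chunks: flatten the chunks and drop the padding."""
    joined = b"".join(chunk.astype(np.uint8).tobytes() for chunk in chunks)
    return joined[: len(joined) - padding]


data = b"hello, AES!"  # 11 bytes, so 5 bytes of padding are needed
chunks, padding = split_into_chunks(data)
print(len(chunks), padding)  # 1 5
assert join_chunks(chunks, padding) == data
```

On the decryption side, the trailing padding byte would be read first, the remaining bytes decrypted 16 at a time, and the count passed to something like join_chunks.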

Encrypt Chunk

For details on each step, refer to the previous entry in the series.

def encrypt_chunk(input_block, key):
    state = input_block
    round_keys = get_round_keys(key)

    # initial round key addition
    state = add_round_key(state, round_keys[0])

    for round_key in round_keys[1:-1]:
        state = sub_bytes(state)
        state = shift_rows(state)
        state = mix_columns(state)
        state = add_round_key(state, round_key)
    
    # final round
    state = sub_bytes(state)
    state = shift_rows(state)
    state = add_round_key(state, round_keys[-1])

    return state


def decrypt_chunk(input_block, key):
    state = input_block
    round_keys = get_round_keys(key)

    # first round: undo the final AddRoundKey
    state = add_round_key(state, round_keys[-1])

    # middle rounds, using the second-to-last round key down to round_keys[1]
    for round_key in round_keys[-2:0:-1]:
        state = inv_shift_rows(state)
        state = inv_sub_bytes(state)
        state = add_round_key(state, round_key)
        state = inv_mix_columns(state)

    # last round
    state = inv_shift_rows(state)
    state = inv_sub_bytes(state)
    state = add_round_key(state, round_keys[0])

    return state

Since key expansion (get_round_keys) depends on SubBytes, we'll get to it later. Each inv_* function is the inverse of the corresponding encryption step.

Add Round Key

Just a bitwise XOR, so easy enough.

def add_round_key(state, key):
    return state ^ key

Since a ^ b = c => c ^ b = a, we don't need a separate inverse function.

SubBytes

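Essentially you need to replace each value in the state using a fixed lookup table (the S-box). For this we use a simple dictionary like dict = {0: 99, 1: 124, ..., 255: 22}. Since the input is a numpy array, instead of writing for loops to replace values, we can wrap the dictionary's get method with np.vectorize so it is applied to every element. (NumPy's documentation notes that np.vectorize exists for convenience rather than performance; under the hood it is still essentially a Python loop.)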

import numpy as np
from .constants import S_BOX_MAPPING  # dict where mapping is stored

# otypes keeps the output as uint8 instead of a wider inferred integer type
S_BOX_MAPPING_NUMPY_FN = np.vectorize(S_BOX_MAPPING.get, otypes=[np.uint8])

def sub_bytes(state):
    return S_BOX_MAPPING_NUMPY_FN(state)

For the inverse of this function, just invert the dictionary. For convenience, we have also declared it as a constant.

import numpy as np
from .constants import INV_S_BOX_MAPPING  # dict where mapping is stored

# otypes keeps the output as uint8 instead of a wider inferred integer type
INV_S_BOX_MAPPING_NUMPY_FN = np.vectorize(INV_S_BOX_MAPPING.get, otypes=[np.uint8])

def inv_sub_bytes(state):
    return INV_S_BOX_MAPPING_NUMPY_FN(state)
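If speed matters, one alternative to np.vectorize is to store the S-box as a 256-entry uint8 array and index it with the state array directly (NumPy integer-array indexing), which looks up every byte in one vectorized operation. A minimal sketch: to stay self-contained it generates the table from the S-box definition (multiplicative inverse in GF(2^8) followed by the affine transformation) instead of importing the author's constants module, and `SBOX`, `INV_SBOX` and `build_sbox` are hypothetical names:

```python
import numpy as np


def gmul(a: int, b: int) -> int:
    """Multiply two bytes in GF(2^8) modulo x^8 + x^4 + x^3 + x + 1."""
    p = 0
    for _ in range(8):
        if b & 1:
            p ^= a
        hi_bit_set = a & 0x80
        a = (a << 1) & 0xFF
        if hi_bit_set:
            a ^= 0x1B
        b >>= 1
    return p


def rotl8(x: int, shift: int) -> int:
    """Rotate an 8-bit value left by `shift` bits."""
    return ((x << shift) | (x >> (8 - shift))) & 0xFF


def build_sbox() -> np.ndarray:
    sbox = np.zeros(256, dtype=np.uint8)
    for x in range(256):
        # Multiplicative inverse (0 maps to 0), found by brute force.
        inv = next((y for y in range(1, 256) if gmul(x, y) == 1), 0)
        # Affine transformation from the AES specification.
        sbox[x] = inv ^ rotl8(inv, 1) ^ rotl8(inv, 2) ^ rotl8(inv, 3) ^ rotl8(inv, 4) ^ 0x63
    return sbox


SBOX = build_sbox()
INV_SBOX = np.zeros(256, dtype=np.uint8)
INV_SBOX[SBOX] = np.arange(256, dtype=np.uint8)  # invert the permutation


def sub_bytes(state: np.ndarray) -> np.ndarray:
    return SBOX[state]  # integer-array indexing looks up all bytes at once


def inv_sub_bytes(state: np.ndarray) -> np.ndarray:
    return INV_SBOX[state]


state = np.arange(16, dtype=np.uint8).reshape(4, 4)
assert (inv_sub_bytes(sub_bytes(state)) == state).all()
print(SBOX[0], SBOX[1], SBOX[255])  # 99 124 22
```

The inverse table falls out for free: because the S-box is a permutation of 0..255, the single assignment to INV_SBOX inverts it, which is the array equivalent of inverting the dictionary.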

ShiftRows

Seems easy enough. I added a list comprehension to make it cooler (and maybe faster).

import numpy as np

# Each row of `state` holds one AES column (the base code reshapes each chunk
# row by row), so AES row i is column i of the array: state[:, i].
def shift_rows(state):
    return np.stack([np.roll(state[:, i], -i) for i in range(4)], axis=1)

def inv_shift_rows(state):
    return np.stack([np.roll(state[:, i], i) for i in range(4)], axis=1)

We use -i because we shift to the left when encrypting, and +i to shift to the right when decrypting.
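One way to check both the direction and the axis is to push the bytes 0..15 through ShiftRows and compare with the byte order FIPS-197 prescribes. A minimal sketch, assuming the base code's layout in which the 16 input bytes are reshaped row by row, so each array row holds one AES column and AES row `i` is `state[:, i]`:

```python
import numpy as np


def shift_rows(state: np.ndarray) -> np.ndarray:
    # AES row i is column i of this array; rotate it left by i positions.
    return np.stack([np.roll(state[:, i], -i) for i in range(4)], axis=1)


def inv_shift_rows(state: np.ndarray) -> np.ndarray:
    # Rotate AES row i right by i positions to undo shift_rows.
    return np.stack([np.roll(state[:, i], i) for i in range(4)], axis=1)


state = np.arange(16, dtype=np.uint8).reshape(4, 4)  # bytes 0..15, row by row
shifted = shift_rows(state)
print(shifted.flatten().tolist())
# [0, 5, 10, 15, 4, 9, 14, 3, 8, 13, 2, 7, 12, 1, 6, 11]
assert (inv_shift_rows(shifted) == state).all()
```

Rotating the array's rows instead of its columns would give a different byte order while still round-tripping cleanly, so a known-answer check like this catches a mistake that encrypt-then-decrypt tests miss.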

MixColumns

This is the most complex bit of the entire thing. Before I realized lookup tables exist for Galois multiplication (the finite field maths I talked about in the previous article), I tried to code up that multiplication myself. After failing, I took Wikipedia's code and converted it to Python. Don't ask me how or why it works. Essentially, for each column j of the state, it computes:

$$ \begin{bmatrix} b_{0,j} \\ b_{1,j} \\ b_{2,j} \\ b_{3,j} \end{bmatrix} = \begin{bmatrix} 2 & 3 & 1 & 1 \\ 1 & 2 & 3 & 1 \\ 1 & 1 & 2 & 3 \\ 3 & 1 & 1 & 2 \end{bmatrix} \begin{bmatrix} a_{0,j} \\ a_{1,j} \\ a_{2,j} \\ a_{3,j} \end{bmatrix} $$

but with addition replaced by XOR and multiplication replaced by Galois multiplication (multiplication in GF(2^8)).

import numpy as np

def galois_multiplication(a, b):
    """Multiply two bytes in GF(2^8), reducing modulo x^8 + x^4 + x^3 + x + 1."""
    a, b = int(a), int(b)  # plain Python ints avoid NumPy overflow surprises
    p = 0
    for _ in range(8):
        if b & 1:
            p ^= a
        hi_bit_set = a & 0x80
        a = (a << 1) & 0xFF  # keep a to 8 bits, like a uint8_t in C
        if hi_bit_set:
            a ^= 0x1B  # reduce by the AES polynomial
        b >>= 1
    return p

gmul = galois_multiplication

def mix_columns(state):
    # Each row of `state` holds one AES column, so each row is mixed on its own.
    out = np.zeros_like(state)
    for i in range(len(out)):
        a0, a1, a2, a3 = (int(x) for x in state[i])
        out[i, 0] = gmul(2, a0) ^ gmul(3, a1) ^ a2 ^ a3
        out[i, 1] = a0 ^ gmul(2, a1) ^ gmul(3, a2) ^ a3
        out[i, 2] = a0 ^ a1 ^ gmul(2, a2) ^ gmul(3, a3)
        out[i, 3] = gmul(3, a0) ^ a1 ^ a2 ^ gmul(2, a3)
    return out

For the inverse, we just calculate the inverse of the matrix we multiplied with (again over GF(2^8)) and get:

$$ \begin{bmatrix} a_{0,j} \\ a_{1,j} \\ a_{2,j} \\ a_{3,j} \end{bmatrix} = \begin{bmatrix} 14 & 11 & 13 & 9 \\ 9 & 14 & 11 & 13 \\ 13 & 9 & 14 & 11 \\ 11 & 13 & 9 & 14 \end{bmatrix} \begin{bmatrix} b_{0,j} \\ b_{1,j} \\ b_{2,j} \\ b_{3,j} \end{bmatrix} $$

def inv_mix_columns(state):
    out = np.zeros_like(state)
    for i in range(len(out)):
        out[i, 0] = np.array(
            gmul(0x0E, state[i, 0])
            ^ gmul(0x0B, state[i, 1])
            ^ gmul(0x0D, state[i, 2])
            ^ gmul(0x09, state[i, 3])
        ).astype(np.uint8)
        out[i, 1] = np.array(
            gmul(0x09, state[i, 0])
            ^ gmul(0x0E, state[i, 1])
            ^ gmul(0x0B, state[i, 2])
            ^ gmul(0x0D, state[i, 3])
        ).astype(np.uint8)
        out[i, 2] = np.array(
            gmul(0x0D, state[i, 0])
            ^ gmul(0x09, state[i, 1])
            ^ gmul(0x0E, state[i, 2])
            ^ gmul(0x0B, state[i, 3])
        ).astype(np.uint8)
        out[i, 3] = np.array(
            gmul(0x0B, state[i, 0])
            ^ gmul(0x0D, state[i, 1])
            ^ gmul(0x09, state[i, 2])
            ^ gmul(0x0E, state[i, 3])
        ).astype(np.uint8)
    return out
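A few quick checks for the Galois-field arithmetic, as a self-contained sketch: the multiplication example from FIPS-197 ({57} • {83} = {c1}), a MixColumns test column (db 13 53 45 → 8e 4d a1 bc, which can be verified by hand), and a confirmation that the two matrices above really are inverses when addition is XOR and multiplication is gmul. `mix_column`, `inv_mix_column` and `mat_vec` are hypothetical helpers that work on a single 4-byte column:

```python
def gmul(a: int, b: int) -> int:
    """Multiply two bytes in GF(2^8) modulo x^8 + x^4 + x^3 + x + 1."""
    p = 0
    for _ in range(8):
        if b & 1:
            p ^= a
        hi_bit_set = a & 0x80
        a = (a << 1) & 0xFF
        if hi_bit_set:
            a ^= 0x1B
        b >>= 1
    return p


MIX = [[2, 3, 1, 1], [1, 2, 3, 1], [1, 1, 2, 3], [3, 1, 1, 2]]
INV_MIX = [[14, 11, 13, 9], [9, 14, 11, 13], [13, 9, 14, 11], [11, 13, 9, 14]]


def mat_vec(matrix, column):
    """Matrix-vector product where + is XOR and * is gmul."""
    out = []
    for row in matrix:
        acc = 0
        for coeff, byte in zip(row, column):
            acc ^= gmul(coeff, byte)
        out.append(acc)
    return out


def mix_column(column):
    return mat_vec(MIX, column)


def inv_mix_column(column):
    return mat_vec(INV_MIX, column)


assert gmul(0x57, 0x83) == 0xC1
assert mix_column([0xDB, 0x13, 0x53, 0x45]) == [0x8E, 0x4D, 0xA1, 0xBC]
assert inv_mix_column([0x8E, 0x4D, 0xA1, 0xBC]) == [0xDB, 0x13, 0x53, 0x45]

# MIX x INV_MIX should be the identity matrix over GF(2^8).
product = [
    [
        gmul(MIX[i][0], INV_MIX[0][j]) ^ gmul(MIX[i][1], INV_MIX[1][j])
        ^ gmul(MIX[i][2], INV_MIX[2][j]) ^ gmul(MIX[i][3], INV_MIX[3][j])
        for j in range(4)
    ]
    for i in range(4)
]
assert product == [[1 if i == j else 0 for j in range(4)] for i in range(4)]
```

Checking the matrix product once is cheaper than debugging a decrypt that never returns the original bytes.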

Key Expansion

Should be pretty simple compared to MixColumns.

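import numpy as np

# Round constants from the AES specification: only the first byte of each
# round-constant word is non-zero. Index 0 is unused so RCON[i // N] lines up.
RCON = [0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36]

def _rot_word(word):
    return np.roll(word, -1)

def get_round_keys(input_key):
    # input_key has shape (N, 4): N 32 bit words, each stored as a 4 byte array.
    # N = 4, 6 or 8 for AES-128/192/256.
    N = input_key.shape[0]
    assert input_key.shape == (N, 4) and N in (4, 6, 8)
    # R = number of round keys = number of rounds + 1 (11, 13 or 15).
    R = N + 7
    W = np.zeros((4 * R, 4), dtype=np.uint8)
    W[:N] = input_key
    for i in range(N, 4 * R):
        if i % N == 0:
            temp = sub_bytes(_rot_word(W[i - 1])).astype(np.uint8)
            temp[0] ^= RCON[i // N]  # the round constant only touches the first byte
            W[i] = W[i - N] ^ temp
        elif N > 6 and i % N == 4:
            W[i] = W[i - N] ^ sub_bytes(W[i - 1])
        else:
            W[i] = W[i - N] ^ W[i - 1]
    # group the 4R words into R round keys of shape (4, 4), one word per row
    return W.reshape(R, 4, 4)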
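To check key expansion against a known answer, here is a self-contained sketch using the AES-128 example key from FIPS-197 Appendix A.1. It repeats the same loop with a table-based S-box generated from its definition, so `SBOX`, `RCON` and `build_sbox` here are assumptions standing in for the author's constants module:

```python
import numpy as np


def gmul(a: int, b: int) -> int:
    """Multiply two bytes in GF(2^8) modulo x^8 + x^4 + x^3 + x + 1."""
    p = 0
    for _ in range(8):
        if b & 1:
            p ^= a
        hi_bit_set = a & 0x80
        a = (a << 1) & 0xFF
        if hi_bit_set:
            a ^= 0x1B
        b >>= 1
    return p


def build_sbox() -> np.ndarray:
    """AES S-box: GF(2^8) inverse followed by the affine transformation."""
    def rotl8(x: int, s: int) -> int:
        return ((x << s) | (x >> (8 - s))) & 0xFF

    sbox = np.zeros(256, dtype=np.uint8)
    for x in range(256):
        inv = next((y for y in range(1, 256) if gmul(x, y) == 1), 0)
        sbox[x] = inv ^ rotl8(inv, 1) ^ rotl8(inv, 2) ^ rotl8(inv, 3) ^ rotl8(inv, 4) ^ 0x63
    return sbox


SBOX = build_sbox()
RCON = [0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36]


def get_round_keys(key: np.ndarray) -> np.ndarray:
    """Expand an (N, 4) key into N + 7 round keys of shape (4, 4), one word per row."""
    N = key.shape[0]
    R = N + 7
    W = np.zeros((4 * R, 4), dtype=np.uint8)
    W[:N] = key
    for i in range(N, 4 * R):
        temp = W[i - 1].copy()
        if i % N == 0:
            temp = SBOX[np.roll(temp, -1)]  # SubWord(RotWord(temp))
            temp[0] ^= RCON[i // N]  # round constant hits the first byte only
        elif N > 6 and i % N == 4:
            temp = SBOX[temp]  # extra SubWord for AES-256
        W[i] = W[i - N] ^ temp
    return W.reshape(R, 4, 4)


# AES-128 example key from FIPS-197 Appendix A.1
key = np.frombuffer(bytes.fromhex("2b7e151628aed2a6abf7158809cf4f3c"), dtype=np.uint8).reshape(4, 4)
round_keys = get_round_keys(key)
print(round_keys[1].tobytes().hex())  # a0fafe1788542cb123a339392a6c7605
print(round_keys[-1].tobytes().hex())  # d014f9a8c9ee2589e13f0cc8b6630ca6
```

The first line is words w4 to w7 and the second is the final round key, both as listed in the FIPS-197 example.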

Cleanup for the final code

The code is up on GitHub. I made a few changes to it:

  • Added docstrings.
  • Renamed a few variables to be more reader friendly.
  • Added the option to modify state in place (might be faster in some cases).
  • Improved the encrypt/decrypt file functionality.
  • Used an Enum to store AES-128/192/256 and their relevant constants.
  • Added Typer for a fancy CLI.

How to use the code

RTFM (please google it if you don't know what that means).

Also note that you can simply save your key as a .txt file with the appropriate number of characters. Just remember that each ASCII character is one byte, so for AES-128 the file must contain exactly 16 characters (24 for AES-192 and 32 for AES-256).
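A minimal sketch of loading such a key file into the (N, 4) word layout that key expansion expects. `load_key` is a hypothetical helper, not the CLI's actual loader; it rejects files of the wrong length, which also catches the trailing newline that some text editors add silently:

```python
import tempfile
from pathlib import Path

import numpy as np

VALID_KEY_LENGTHS = (16, 24, 32)  # bytes for AES-128, AES-192, AES-256


def load_key(path: str) -> np.ndarray:
    """Read a key file and return it as an (N, 4) uint8 array of 32-bit words."""
    raw = Path(path).read_bytes()
    if len(raw) not in VALID_KEY_LENGTHS:
        raise ValueError(
            f"key file has {len(raw)} bytes, expected 16, 24 or 32 "
            "(a trailing newline counts too)"
        )
    # Four consecutive bytes form one word, matching the (N, 4) key layout.
    return np.frombuffer(raw, dtype=np.uint8).reshape(-1, 4).copy()


with tempfile.TemporaryDirectory() as tmp:
    key_path = Path(tmp) / "key.txt"
    key_path.write_text("sixteen byte key", encoding="ascii")  # exactly 16 ASCII characters
    key = load_key(str(key_path))
    print(key.shape)  # (4, 4)
```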

My takeaways from Python code

TESTS

WRITE TESTS. A LOT OF DEV PAIN COULD HAVE BEEN AVOIDED HERE IF I WROTE TESTS.

Encryption is pointless if you can't decrypt. So after writing all of the encryption code, how do I know it encrypts correctly? I wrote decryption too, and much to my surprise (but probably no one else's), the decrypted output was a messy blob and not the clean original input I was expecting.

So I spent a lot of time debugging to find which function was messing up. If I had written tests for each function while writing it, I wouldn't have gone through this pain.
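In that spirit, a minimal pytest-style sketch of round-trip tests written alongside each step. The function bodies are repeated here only so the file runs on its own, assuming the row-per-AES-column layout from the base code. A round trip only proves that each inv_* undoes its partner; matching standard AES still needs known-answer vectors such as those in the FIPS-197 appendices.

```python
import numpy as np


def add_round_key(state, key):
    return state ^ key


def shift_rows(state):
    # Same layout as the base code: AES row i is column i of the array.
    return np.stack([np.roll(state[:, i], -i) for i in range(4)], axis=1)


def inv_shift_rows(state):
    return np.stack([np.roll(state[:, i], i) for i in range(4)], axis=1)


def random_state(rng):
    return rng.integers(0, 256, size=(4, 4), dtype=np.uint8)


def test_add_round_key_is_its_own_inverse():
    rng = np.random.default_rng(0)
    for _ in range(100):
        state, key = random_state(rng), random_state(rng)
        assert np.array_equal(add_round_key(add_round_key(state, key), key), state)


def test_shift_rows_roundtrip():
    rng = np.random.default_rng(1)
    for _ in range(100):
        state = random_state(rng)
        assert np.array_equal(inv_shift_rows(shift_rows(state)), state)


if __name__ == "__main__":
    test_add_round_key_is_its_own_inverse()
    test_shift_rows_roundtrip()
    print("all tests passed")
```

Running pytest on this file picks up both test_* functions automatically; the same pattern extends to sub_bytes and mix_columns.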

SLOWWWW

This code took 12 seconds to encrypt a 120 KB file. Considering AES powers the web (HTTPS uses TLS, which typically uses AES), it should be pretty fast, yet I'm getting about 10 KB/s, which is roughly dial-up speed. So there's definitely room for improvement.

Next Steps

I could try to optimise it in pure Python using multithreading or more optimised functions and data structures. But considering how low level this algorithm really is, I doubt Python is the right tool for this.

I am learning Rust on the side, and this seems like the perfect project to practice and improve my skills. So I decided to use Rust to speed up this program as I go along.

Continue reading this in part 3. Coming out soon.