Now that we know the basics of AES, it's time to implement the concepts in Python. By the end of this, we should be able to take any arbitrary file as input, along with a key, and produce an encrypted file that can be decrypted back to the original.
This is part 2 of the series. Read [part 1 here]({% link Programming/_posts/2023-09-02-AES Encryption 1 - Concepts.md %}).
The full code is available on GitHub.
Python Implementation
The approach is simple. I used NumPy instead of Python lists in the hope that it would be fast (but fast is relative). The AES algorithm from part 1 works on a single 16-byte chunk, so let's wrap those steps in a function called `encrypt_chunk` for now.
The base code then becomes:

```python
import numpy as np

# INPUT_FILE, OUTPUT_FILE and KEY are defined elsewhere (e.g. from the CLI).
# encrypt file
with open(INPUT_FILE, "rb") as f:
    FILE_DATA = bytearray(f.read())
FILE_SIZE = len(FILE_DATA)

# zero-pad so the data splits evenly into 16 byte chunks
PADDING_REQUIRED = (16 - FILE_SIZE % 16) % 16
FILE_DATA.extend(bytes(PADDING_REQUIRED))

OUTPUT = bytearray()
N_ITER = len(FILE_DATA) // 16
for i in range(N_ITER):
    chunk = np.frombuffer(FILE_DATA[i * 16 : i * 16 + 16], dtype=np.uint8).reshape(4, 4)
    ITER_OUT = encrypt_chunk(chunk, KEY)
    OUTPUT.extend(ITER_OUT.astype(np.uint8).tobytes())

# the last byte records how much padding to strip after decryption
OUTPUT.append(PADDING_REQUIRED)

with open(OUTPUT_FILE, "wb") as w:
    w.write(OUTPUT)
```
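Decryption has to undo this file layout. It decrypts every 16-byte block and then uses the final byte to drop the zero padding. Below is a minimal sketch of just that bookkeeping. It assumes the layout written above: ciphertext blocks followed by one padding-length byte. `pad_and_tag` and `strip_and_untag` are hypothetical helper names. The block cipher is passed in as a plain function so the layout logic can be checked on its own.

```python
from typing import Callable

BLOCK = 16


def pad_and_tag(data: bytes, encrypt_block: Callable[[bytes], bytes]) -> bytes:
    """Zero-pad to a multiple of 16, encrypt each block, append the pad length."""
    padding = (BLOCK - len(data) % BLOCK) % BLOCK
    padded = data + bytes(padding)
    blocks = (encrypt_block(padded[i : i + BLOCK]) for i in range(0, len(padded), BLOCK))
    return b"".join(blocks) + bytes([padding])


def strip_and_untag(blob: bytes, decrypt_block: Callable[[bytes], bytes]) -> bytes:
    """Inverse of pad_and_tag: decrypt every block, then drop the zero padding."""
    padding = blob[-1]  # last byte = number of zero bytes that were added
    body = blob[:-1]
    assert len(body) % BLOCK == 0, "ciphertext must be whole 16 byte blocks"
    plain = b"".join(decrypt_block(body[i : i + BLOCK]) for i in range(0, len(body), BLOCK))
    return plain[: len(plain) - padding]


def identity(block: bytes) -> bytes:
    """Stand-in 'cipher' so the padding logic can be tested without AES."""
    return block


blob = pad_and_tag(b"hello world", identity)
print(len(blob))  # 17: one 16 byte block plus the padding-length byte
print(strip_and_untag(blob, identity))  # b'hello world'
```

Storing the length keeps the zero padding unambiguous, even when the original file itself ends in zero bytes. The usual standardized alternative is PKCS#7, which appends n bytes that each have the value n.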
Note that we treat each chunk as a 4x4 numpy array where each element is a byte (or 8 bits).
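One detail is worth spelling out here. AES fills its 4x4 state column by column, so input byte k goes to row k % 4, column k // 4. NumPy's default `reshape`, on the other hand, fills row by row. So with `reshape(4, 4)`, each row of the array holds one AES column, and each AES row is a column of the array. A quick check, with the byte indices 0 to 15 standing in for real data:

```python
import numpy as np

block = np.arange(16, dtype=np.uint8)  # stand-in for 16 input bytes, numbered 0..15
state = block.reshape(4, 4)            # default C order: fills the array row by row

print(state[1].tolist())     # [4, 5, 6, 7]  -> bytes 4..7, i.e. AES column 1
print(state[:, 1].tolist())  # [1, 5, 9, 13] -> AES row 1
```

This is the layout that the MixColumns and key expansion code further down rely on.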
Encrypt Chunk
For details on each step, refer to the previous entry in the series.
```python
def encrypt_chunk(input_block, key):
    state = input_block
    # key: numpy array of shape (N, 4), one 4 byte word per row.
    # AES uses N + 7 round keys (11, 13 or 15), each a 4x4 block of words.
    n_words = key.shape[0]
    round_keys = get_round_keys(key, n_words, n_words + 7).reshape(-1, 4, 4)
    # initial round key addition
    state = add_round_key(state, round_keys[0])
    for round_key in round_keys[1:-1]:
        state = sub_bytes(state)
        state = shift_rows(state)
        state = mix_columns(state)
        state = add_round_key(state, round_key)
    # final round (no MixColumns)
    state = sub_bytes(state)
    state = shift_rows(state)
    state = add_round_key(state, round_keys[-1])
    return state


def decrypt_chunk(input_block, key):
    state = input_block
    n_words = key.shape[0]
    round_keys = get_round_keys(key, n_words, n_words + 7).reshape(-1, 4, 4)
    # first round: undo the last round key
    state = add_round_key(state, round_keys[-1])
    # middle rounds, walking the round keys backwards
    for round_key in round_keys[1:-1][::-1]:
        state = inv_shift_rows(state)
        state = inv_sub_bytes(state)
        state = add_round_key(state, round_key)
        state = inv_mix_columns(state)
    # last round
    state = inv_shift_rows(state)
    state = inv_sub_bytes(state)
    state = add_round_key(state, round_keys[0])
    return state
```
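The loops above never hard-code the number of rounds; they just walk the list of round keys. For reference, the counts follow directly from the key length. A key of N 32-bit words gets N + 6 rounds and therefore N + 7 round keys. Here is a small sketch; the helper name `aes_round_counts` is hypothetical.

```python
def aes_round_counts(key: bytes) -> tuple[int, int, int]:
    """Return (N, rounds, round_keys) for a 16, 24 or 32 byte AES key."""
    if len(key) not in (16, 24, 32):
        raise ValueError("AES keys must be 16, 24 or 32 bytes long")
    n_words = len(key) // 4  # N: key length in 32 bit words
    rounds = n_words + 6     # 10, 12 or 14 rounds
    return n_words, rounds, rounds + 1  # one extra key for the initial AddRoundKey


for size in (16, 24, 32):
    print(size * 8, aes_round_counts(bytes(size)))
# 128 (4, 10, 11)
# 192 (6, 12, 13)
# 256 (8, 14, 15)
```

So for AES-128, `round_keys[1:-1]` covers the 9 full rounds, and the final round uses `round_keys[-1]`.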
Since `KeyExpansion` (called `get_round_keys` in the code) depends on `SubBytes`, we'll get to it later. Each `inv_*` function is the inverse of its counterpart.
Add Round Key
Just a bitwise XOR, so easy enough.
```python
def add_round_key(state, key):
    return state ^ key
```
Since `a ^ b = c` implies `c ^ b = a`, we don't need a separate inverse function.
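A quick way to see this on actual arrays is to XOR a random state with a random round key, then XOR again with the same key. The values below are random stand-ins, not real AES data.

```python
import numpy as np

rng = np.random.default_rng(0)
state = rng.integers(0, 256, size=(4, 4), dtype=np.uint8)
round_key = rng.integers(0, 256, size=(4, 4), dtype=np.uint8)

encrypted = state ^ round_key     # AddRoundKey
restored = encrypted ^ round_key  # applying the same key again undoes it
print(np.array_equal(restored, state))  # True
```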
SubBytes
Essentially, you replace each byte in the state with its entry in a fixed lookup table (the S-box). For this we use a simple dictionary like `{0: 99, 1: 124, ..., 255: 22}`.
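Since the input is a NumPy array, we don't need to write for loops to replace values. We can wrap the dictionary's `get` method with `np.vectorize` and apply it to the whole array at once. Strictly speaking, NumPy's documentation notes that `np.vectorize` is provided for convenience rather than performance, since it is essentially a for loop under the hood.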
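```python
import numpy as np
from .constants import S_BOX_MAPPING  # dict where mapping is stored

# otypes keeps the result as uint8 instead of letting numpy pick int64
S_BOX_MAPPING_NUMPY_FN = np.vectorize(S_BOX_MAPPING.get, otypes=[np.uint8])


def sub_bytes(state):
    return S_BOX_MAPPING_NUMPY_FN(state)
```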
For the inverse of this function, just invert the dictionary (swap keys and values). For convenience, we have also declared it as a constant.
```python
import numpy as np
from .constants import INV_S_BOX_MAPPING  # dict where mapping is stored

INV_S_BOX_MAPPING_NUMPY_FN = np.vectorize(INV_S_BOX_MAPPING.get, otypes=[np.uint8])


def inv_sub_bytes(state):
    return INV_S_BOX_MAPPING_NUMPY_FN(state)
```
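A possible speed-up, which is not what the repo does: `np.vectorize` still calls `dict.get` once per byte from Python. If the S-box is stored as a 256-entry `uint8` array instead, indexing it with the state lets NumPy do all the lookups in one go. To keep this sketch self-contained, it rebuilds the S-box from its definition (the multiplicative inverse in GF(2^8) followed by the affine transform) instead of importing `S_BOX_MAPPING`. The names `build_s_box`, `S_BOX` and `INV_S_BOX` are hypothetical.

```python
import numpy as np


def gf_mul(a: int, b: int) -> int:
    """Multiply two bytes in GF(2^8) modulo x^8 + x^4 + x^3 + x + 1."""
    p = 0
    for _ in range(8):
        if b & 1:
            p ^= a
        carry = a & 0x80
        a = (a << 1) & 0xFF
        if carry:
            a ^= 0x1B
        b >>= 1
    return p


def rotl8(x: int, shift: int) -> int:
    """Rotate a byte left by `shift` bits."""
    return ((x << shift) | (x >> (8 - shift))) & 0xFF


def build_s_box() -> np.ndarray:
    # multiplicative inverses by brute force; 0 has no inverse and maps to 0
    inverse = [0] * 256
    for a in range(1, 256):
        inverse[a] = next(x for x in range(1, 256) if gf_mul(a, x) == 1)
    box = np.zeros(256, dtype=np.uint8)
    for a in range(256):
        b = inverse[a]
        # affine transform: b ^ rotl(b,1) ^ rotl(b,2) ^ rotl(b,3) ^ rotl(b,4) ^ 0x63
        box[a] = b ^ rotl8(b, 1) ^ rotl8(b, 2) ^ rotl8(b, 3) ^ rotl8(b, 4) ^ 0x63
    return box


S_BOX = build_s_box()
INV_S_BOX = np.zeros(256, dtype=np.uint8)
INV_S_BOX[S_BOX] = np.arange(256, dtype=np.uint8)  # invert the permutation


def sub_bytes(state: np.ndarray) -> np.ndarray:
    return S_BOX[state]  # fancy indexing: one table lookup per byte


def inv_sub_bytes(state: np.ndarray) -> np.ndarray:
    return INV_S_BOX[state]


print(hex(S_BOX[0]), hex(S_BOX[1]), hex(S_BOX[255]))  # 0x63 0x7c 0x16
```

The inverse table comes almost for free. Because the S-box is a permutation of 0 to 255, `INV_S_BOX[S_BOX] = arange(256)` inverts it, which is the array version of inverting the dictionary. This sketch has not been benchmarked against the repo's code.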
ShiftRows
Seems easy enough. I added a list comprehension to make it cooler (and maybe faster).
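```python
import numpy as np

# Each row of `state` is one AES column (see the reshape in the base code),
# so AES row i is column i of the array: roll the array's columns.


def shift_rows(state):
    return np.array(
        [np.roll(row, -i) for i, row in enumerate(state.T)],
        dtype=state.dtype,
    ).T


def inv_shift_rows(state):
    return np.array(
        [np.roll(row, i) for i, row in enumerate(state.T)], dtype=state.dtype
    ).T
```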
We use `-i` because we shift to the left when encrypting, and `+i` to shift to the right when decrypting.
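To make the direction of `np.roll` concrete: a negative shift rotates elements towards the start of the array (left), and a positive shift rotates them towards the end (right). A quick check on a single row:

```python
import numpy as np

row = np.array([0, 1, 2, 3])
left = np.roll(row, -1)   # negative shift moves elements left: [1, 2, 3, 0]
right = np.roll(left, 1)  # positive shift moves them back right: [0, 1, 2, 3]
print(left.tolist(), right.tolist())

# Row i of the AES state is rotated by i positions, so rolling by -i and then +i
# restores every row.
for i in range(4):
    assert np.array_equal(np.roll(np.roll(row, -i), i), row)
```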
MixColumns
This is the most complex bit of the entire thing. Before I realized that lookup tables exist for Galois multiplication (the finite field maths I was talking about in the previous article), I tried to code up that multiplication myself. After failing, I just took the code given on Wikipedia and converted it to Python. Don't ask me how or why it works. Essentially, for each column $j$ of the state, it computes:

$$ \begin{bmatrix} b_{0,j} \\ b_{1,j} \\ b_{2,j} \\ b_{3,j} \end{bmatrix} = \begin{bmatrix} 2 & 3 & 1 & 1 \\ 1 & 2 & 3 & 1 \\ 1 & 1 & 2 & 3 \\ 3 & 1 & 1 & 2 \end{bmatrix} \begin{bmatrix} a_{0,j} \\ a_{1,j} \\ a_{2,j} \\ a_{3,j} \end{bmatrix} $$

except that addition is replaced with XOR and multiplication with Galois multiplication.
```python
import numpy as np


def galois_multiplication(a, b):
    """Galois multiplication of two 8 bit numbers."""
    p = 0
    for i in range(8):
        if b & 1:
            p ^= a
        hi_bit_set = a & 0x80
        a = (a << 1) & 0xFF  # keep a within 8 bits
        if hi_bit_set:
            a ^= 0x1b
        b >>= 1
    return p


gmul = galois_multiplication
```
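Here is roughly why the loop works. Multiplying by 2 (`x`, in the polynomial view) is a left shift. If that shift pushes a bit out of the byte, XORing with `0x1b` reduces the result modulo $x^8 + x^4 + x^3 + x + 1$. Any other multiplier is a sum of powers of two, and addition is XOR. So `galois_multiplication` walks through the bits of `b` and XORs in the matching doubled copies of `a`. The sketch below does this by hand for the worked example in FIPS-197, $\{57\} \cdot \{13\} = \{fe\}$. It uses a helper named `xtime` after the FIPS-197 name for multiplication by 2; the function itself is just an illustration.

```python
def xtime(a: int) -> int:
    """Multiply a byte by 2 in GF(2^8) modulo x^8 + x^4 + x^3 + x + 1."""
    a <<= 1
    if a & 0x100:   # the shift pushed a bit out of the byte
        a ^= 0x11B  # reduce: clears bit 8 and XORs in 0x1b
    return a


# 0x13 = 0x10 ^ 0x02 ^ 0x01, so {57} * {13} = {57}*{10} ^ {57}*{02} ^ {57}
m1 = 0x57
m2 = xtime(m1)   # 0xae
m4 = xtime(m2)   # 0x47
m8 = xtime(m4)   # 0x8e
m16 = xtime(m8)  # 0x07
print(hex(m16 ^ m2 ^ m1))  # 0xfe
```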
```python
def mix_columns(state):
    out = np.zeros_like(state)
    # each row of `state` holds one AES column (see the reshape in the base code)
    for i in range(len(out)):
        out[i, 0] = np.array(
            gmul(2, state[i, 0]) ^ gmul(3, state[i, 1]) ^ state[i, 2] ^ state[i, 3]
        ).astype(np.uint8)
        out[i, 1] = np.array(
            state[i, 0] ^ gmul(2, state[i, 1]) ^ gmul(3, state[i, 2]) ^ state[i, 3]
        ).astype(np.uint8)
        out[i, 2] = np.array(
            state[i, 0] ^ state[i, 1] ^ gmul(2, state[i, 2]) ^ gmul(3, state[i, 3])
        ).astype(np.uint8)
        out[i, 3] = np.array(
            gmul(3, state[i, 0]) ^ state[i, 1] ^ state[i, 2] ^ gmul(2, state[i, 3])
        ).astype(np.uint8)
    return out
```
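Each row of the state holds one AES column, so the whole MixColumns step can also be written with array operations instead of a Python loop over `gmul` calls. This is a sketch of that idea, not the repo's code. It expresses multiplication by 3 as `xtime(x) ^ x`. It also checks a few commonly quoted test columns, such as `db 13 53 45 → 8e 4d a1 bc`. `xtime_np` and `mix_columns_vectorized` are hypothetical names.

```python
import numpy as np


def xtime_np(a: np.ndarray) -> np.ndarray:
    """Multiply every byte by 2 in GF(2^8); `a` must be a uint8 array."""
    shifted = (a << 1).astype(np.uint8)  # uint8 shift drops the overflow bit
    return shifted ^ np.where(a & 0x80, 0x1B, 0).astype(np.uint8)


def mix_columns_vectorized(state: np.ndarray) -> np.ndarray:
    # row i of `state` is AES column i, so s0..s3 hold bytes 0..3 of every column
    s = state.astype(np.uint8)
    s0, s1, s2, s3 = s[:, 0], s[:, 1], s[:, 2], s[:, 3]
    d0, d1, d2, d3 = (xtime_np(c) for c in (s0, s1, s2, s3))  # 2 * s_k
    b0 = d0 ^ (d1 ^ s1) ^ s2 ^ s3  # 2*s0 + 3*s1 + s2 + s3
    b1 = s0 ^ d1 ^ (d2 ^ s2) ^ s3  # s0 + 2*s1 + 3*s2 + s3
    b2 = s0 ^ s1 ^ d2 ^ (d3 ^ s3)  # s0 + s1 + 2*s2 + 3*s3
    b3 = (d0 ^ s0) ^ s1 ^ s2 ^ d3  # 3*s0 + s1 + s2 + 2*s3
    return np.stack([b0, b1, b2, b3], axis=1)


state = np.array(
    [[0xDB, 0x13, 0x53, 0x45],
     [0xF2, 0x0A, 0x22, 0x5C],
     [0x01, 0x01, 0x01, 0x01],
     [0xC6, 0xC6, 0xC6, 0xC6]],
    dtype=np.uint8,
)
expected = np.array(
    [[0x8E, 0x4D, 0xA1, 0xBC],
     [0x9F, 0xDC, 0x58, 0x9D],
     [0x01, 0x01, 0x01, 0x01],
     [0xC6, 0xC6, 0xC6, 0xC6]],
    dtype=np.uint8,
)
print(np.array_equal(mix_columns_vectorized(state), expected))  # True
```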
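For the inverse, we multiply by the inverse of that matrix (computed with the same GF(2^8) arithmetic, not ordinary matrix inversion) and get:

$$ \begin{bmatrix} a_{0,j} \\ a_{1,j} \\ a_{2,j} \\ a_{3,j} \end{bmatrix} = \begin{bmatrix} 14 & 11 & 13 & 9 \\ 9 & 14 & 11 & 13 \\ 13 & 9 & 14 & 11 \\ 11 & 13 & 9 & 14 \end{bmatrix} \begin{bmatrix} b_{0,j} \\ b_{1,j} \\ b_{2,j} \\ b_{3,j} \end{bmatrix} $$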
```python
def inv_mix_columns(state):
    out = np.zeros_like(state)
    for i in range(len(out)):
        out[i, 0] = np.array(
            gmul(0x0E, state[i, 0])
            ^ gmul(0x0B, state[i, 1])
            ^ gmul(0x0D, state[i, 2])
            ^ gmul(0x09, state[i, 3])
        ).astype(np.uint8)
        out[i, 1] = np.array(
            gmul(0x09, state[i, 0])
            ^ gmul(0x0E, state[i, 1])
            ^ gmul(0x0B, state[i, 2])
            ^ gmul(0x0D, state[i, 3])
        ).astype(np.uint8)
        out[i, 2] = np.array(
            gmul(0x0D, state[i, 0])
            ^ gmul(0x09, state[i, 1])
            ^ gmul(0x0E, state[i, 2])
            ^ gmul(0x0B, state[i, 3])
        ).astype(np.uint8)
        out[i, 3] = np.array(
            gmul(0x0B, state[i, 0])
            ^ gmul(0x0D, state[i, 1])
            ^ gmul(0x09, state[i, 2])
            ^ gmul(0x0E, state[i, 3])
        ).astype(np.uint8)
    return out
```
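The second matrix is the inverse of the first only under the GF(2^8) rules, not under ordinary arithmetic. One way to convince yourself is to multiply the two matrices, using XOR as addition and Galois multiplication as multiplication, and check that the result is the identity. The sketch below repeats a masked copy of `galois_multiplication` so it runs on its own. `gf_matmul` is a hypothetical helper.

```python
def gmul(a: int, b: int) -> int:
    """Galois multiplication of two bytes (same algorithm as above)."""
    p = 0
    for _ in range(8):
        if b & 1:
            p ^= a
        hi_bit_set = a & 0x80
        a = (a << 1) & 0xFF
        if hi_bit_set:
            a ^= 0x1B
        b >>= 1
    return p


MIX = [[2, 3, 1, 1], [1, 2, 3, 1], [1, 1, 2, 3], [3, 1, 1, 2]]
INV_MIX = [[14, 11, 13, 9], [9, 14, 11, 13], [13, 9, 14, 11], [11, 13, 9, 14]]


def gf_matmul(A, B):
    """Matrix product where + is XOR and * is gmul."""
    n = len(A)
    out = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            for k in range(n):
                out[i][j] ^= gmul(A[i][k], B[k][j])
    return out


print(gf_matmul(MIX, INV_MIX))  # [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
```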
Key Expansion
Should be pretty simple compared to MixColumns.
```python
import numpy as np
from .constants import RCON  # assumption: round constants stored next to the S-box

# sub_bytes is the function from the SubBytes section.
# RCON[k] is the 4 byte word [rc_k, 0, 0, 0]; it is indexed from 1.


def _rot_word(word):
    return np.roll(word, -1)


def get_round_keys(input_key, N, R):
    # N = size of key in 32 bit words (4, 6 or 8).
    # R = number of round keys (11, 13 or 15), i.e. number of rounds + 1.
    # check if key is in shape N, 4 i.e. an array of 32 bit words where
    # each word is a 4 byte array also
    assert input_key.shape == (N, 4)
    W = np.zeros((4 * R, 4), dtype=np.uint8)
    W[:N] = input_key
    for i in range(N, 4 * R):
        if i % N == 0:
            W[i] = W[i - N] ^ sub_bytes(_rot_word(W[i - 1])) ^ RCON[i // N]
        elif N > 6 and i % N == 4:
            W[i] = W[i - N] ^ sub_bytes(W[i - 1])
        else:
            W[i] = W[i - N] ^ W[i - 1]
    return W
```
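`RCON` isn't shown above. Its values are successive powers of 2 in GF(2^8) (01, 02, 04, ..., 80, 1b, 36), each placed in the first byte of a 4-byte word. The code indexes it as `RCON[i // N]`, and `i // N` starts at 1. So a sketch that builds it as an `(11, 4)` array with an unused row 0 might look like this. `make_rcon` is a hypothetical name, and the repo may store the table differently.

```python
import numpy as np


def make_rcon(count: int = 11) -> np.ndarray:
    """Round constant words; row 0 is unused so that RCON[i // N] lines up."""
    rcon = np.zeros((count, 4), dtype=np.uint8)
    value = 0x01
    for r in range(1, count):
        rcon[r, 0] = value
        # multiply by 2 in GF(2^8): shift, then reduce if a bit fell off the byte
        value = ((value << 1) ^ 0x1B) & 0xFF if value & 0x80 else value << 1
    return rcon


RCON = make_rcon()
print([hex(int(v)) for v in RCON[1:, 0]])
# ['0x1', '0x2', '0x4', '0x8', '0x10', '0x20', '0x40', '0x80', '0x1b', '0x36']
```

Eleven rows are enough for every key size. The largest index used is 43 // 4 = 10 for AES-128, while AES-192 and AES-256 need fewer.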
Cleanup for the final code
The code is up on GitHub. I made a few changes to it:
- Added docstrings.
- Renamed a few variables to be more reader-friendly.
- Added the option to modify the state in place (might be faster in some cases).
- Improved the encrypt/decrypt file functionality.
- Used an Enum to represent AES-128/192/256 and pick the relevant constants.
- Added `typer` for a fancy CLI.
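The Enum change isn't shown here, but a minimal sketch of the idea might look like the following. The class name `AESVariant` and its members are hypothetical. They bundle N (the key length in words) with the number of rounds, so the rest of the code can derive everything else.

```python
from enum import Enum


class AESVariant(Enum):
    # value = (N: key length in 32 bit words, number of rounds)
    AES_128 = (4, 10)
    AES_192 = (6, 12)
    AES_256 = (8, 14)

    @property
    def key_words(self) -> int:
        return self.value[0]

    @property
    def rounds(self) -> int:
        return self.value[1]

    @property
    def key_bytes(self) -> int:
        return 4 * self.key_words

    @property
    def round_keys(self) -> int:
        return self.rounds + 1  # one extra key for the initial AddRoundKey

    @classmethod
    def from_key_length(cls, n_bytes: int) -> "AESVariant":
        for variant in cls:
            if variant.key_bytes == n_bytes:
                return variant
        raise ValueError(f"no AES variant uses a {n_bytes} byte key")


variant = AESVariant.from_key_length(24)
print(variant.name, variant.rounds, variant.round_keys)
```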
How to use the code
RTFM. (Please google this if you don't know what it means.)
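Also note that you can simply save your key as a .txt file with the appropriate number of characters. Just remember that each ASCII character is one byte, so for AES-128 the file must contain exactly 16 characters (24 for AES-192, 32 for AES-256). Watch out for editors that add a trailing newline, since it counts as a byte too.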
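Reading such a key file into the `(N, 4)` shape that `get_round_keys` expects could be as simple as the sketch below. `load_key` is a hypothetical helper, not the repo's function. It reads raw bytes, strips trailing newline characters in case the editor added one, and reshapes the result so that each row is one 4-byte word.

```python
from pathlib import Path

import numpy as np


def load_key(path: str) -> np.ndarray:
    raw = Path(path).read_bytes()
    # many editors add a final newline (\n or \r\n); it is not part of the key
    raw = raw.rstrip(b"\r\n")
    if len(raw) not in (16, 24, 32):
        raise ValueError(f"key must be 16, 24 or 32 bytes, got {len(raw)}")
    # each row is one 32 bit word, so the shape is (N, 4) with N = 4, 6 or 8
    return np.frombuffer(raw, dtype=np.uint8).reshape(-1, 4)
```

For example, a file containing `Thats my Kung Fu` (16 characters) loads as a `(4, 4)` array whose second row is the bytes of `s my`.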
My takeaways from Python code
TESTS
WRITE TESTS. A LOT OF DEV PAIN COULD HAVE BEEN AVOIDED HERE IF I WROTE TESTS.
Encryption is pointless if you can't decrypt. So after writing all of the encryption code, how do I know that it encrypts correctly? I wrote decryption, and much to my surprise (but probably no one else's), the decrypted output was a messy blob and not the clean original input I was expecting.
So I spent a lot of time debugging which function was messing up. If I had just written tests for each function while writing it, I wouldn't have gone through this pain.
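In hindsight, even a couple of tiny tests per function, checked against known values, would have caught the bugs early. Below is a sketch in pytest style; it uses plain asserts, so it also runs without pytest. It uses the worked multiplication examples from FIPS-197 and a known answer for ShiftRows on the bytes 0 to 15. It includes its own copies of `galois_multiplication` and `shift_rows` so it runs standalone. The `shift_rows` here is written for the base code's layout, where each row of the array is one AES column.

```python
import numpy as np


def galois_multiplication(a, b):
    p = 0
    for _ in range(8):
        if b & 1:
            p ^= a
        hi_bit_set = a & 0x80
        a = (a << 1) & 0xFF
        if hi_bit_set:
            a ^= 0x1B
        b >>= 1
    return p


def shift_rows(state):
    # rows of `state` are AES columns, so AES row i is column i of the array
    return np.array([np.roll(col, -i) for i, col in enumerate(state.T)]).T


def inv_shift_rows(state):
    return np.array([np.roll(col, i) for i, col in enumerate(state.T)]).T


def test_gmul_known_values():
    # worked examples from FIPS-197, section 4.2
    assert galois_multiplication(0x57, 0x83) == 0xC1
    assert galois_multiplication(0x57, 0x13) == 0xFE


def test_shift_rows_known_answer():
    state = np.arange(16).reshape(4, 4)  # bytes 0..15, one AES column per row
    expected = [0, 5, 10, 15, 4, 9, 14, 3, 8, 13, 2, 7, 12, 1, 6, 11]
    assert shift_rows(state).flatten().tolist() == expected


def test_shift_rows_round_trip():
    state = np.arange(16).reshape(4, 4)
    assert np.array_equal(inv_shift_rows(shift_rows(state)), state)


if __name__ == "__main__":
    test_gmul_known_values()
    test_shift_rows_known_answer()
    test_shift_rows_round_trip()
    print("all tests passed")
```

The known-answer test matters more than the round trip. A ShiftRows that rotates the wrong axis still inverts cleanly, so only a comparison against expected bytes catches it.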
SLOWWWW
This code took 12 seconds to encrypt a 120 KB file. Considering AES powers the web (HTTPS runs on TLS, which typically uses AES), AES should be pretty fast, yet I'm getting 10 KB/s, which is dial-up levels of speed. So there's definitely room for improvement.
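Before rewriting anything, it helps to measure where the time actually goes. Here is a minimal timing sketch that takes whatever per-chunk function you want to measure. `measure_throughput` is a hypothetical helper, and the random chunks are stand-ins for file data.

```python
import time

import numpy as np


def measure_throughput(encrypt_fn, n_chunks: int = 1000) -> float:
    """Return bytes per second for encrypt_fn applied to n_chunks random 4x4 chunks."""
    rng = np.random.default_rng(0)
    chunks = rng.integers(0, 256, size=(n_chunks, 4, 4), dtype=np.uint8)
    start = time.perf_counter()
    for chunk in chunks:
        encrypt_fn(chunk)
    elapsed = max(time.perf_counter() - start, 1e-9)  # guard against a zero reading
    return n_chunks * 16 / elapsed


# e.g. measure_throughput(lambda c: encrypt_chunk(c, KEY)) / 1024 gives KB/s
```

To see which step dominates, Python's built-in profiler can be run on the whole script with `python -m cProfile -s cumtime your_script.py`.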
Next Steps
I could try to optimise it in Python alone, using multithreading (though CPython's GIL limits how much threads help with CPU-bound code like this) or more optimised functions and data storage techniques. But considering how low level this algorithm really is, I doubt Python is the right tool for this.
I am learning Rust on the side, and this seems like the perfect project to practice and improve my skills. So I decided to use Rust to speed up this program as I go along.
Continue reading this in part 3. Coming out soon.