
23. September, 2022

Differential Cryptanalysis of a Single-Round SPN

Part two of the differential cryptanalysis series


The design of block ciphers is usually seen as a specialist topic. Consequently, knowledge is mostly preserved in academic papers, and there are only a few introductory tutorials. We aim to fill this gap between the IT security practitioner and the block cipher designer. In this blog series we introduce differential cryptanalysis with a hands-on approach. Our target audience is IT security practitioners and programmers without a deep background in mathematics.

The following list shows the topics of all scheduled blog posts. It will be updated with the corresponding links as new posts are released.


Introduction

Differential cryptanalysis is a type of cryptographic attack on the design of a cipher. It was publicly introduced by Biham and Shamir in the late 1980s. However, it was later revealed that differential cryptanalysis was already known to IBM when they designed DES, the first publicly standardized algorithm for encrypting data.

Differential cryptanalysis is a chosen-plaintext attack, i.e., the attacker can obtain the encryptions of plaintexts of their choice and has to recover the key. This is formally modeled by querying a so-called chosen-plaintext oracle. In practice it does not matter how exactly the oracle works, as long as it provides the ciphertext for a given plaintext; this setting is part of the attack scenario. The main idea of differential cryptanalysis is to trace differences of plaintexts through the encryption algorithm. Ideally, for a given input difference, each output difference would occur with the same probability. However, for mathematical reasons this ideal can never be achieved exactly. If the deviations from the ideal probability are large enough, they can be exploited to recover the key.

In this blog post we introduce a toy cipher on which we demonstrate the principles of differential cryptanalysis. We implemented the attack in Python to make these principles concrete.

Our Toy Cipher

In this section we describe the design of our toy cipher, which we attack below using differential cryptanalysis, and give a short security discussion.

Design

The toy cipher analyzed in this post follows the Even-Mansour scheme, which we already discussed in the first post of this series. In some sense, this is the simplest cipher that can be proven secure in a formal model.

The block size of our toy cipher is 4 bits, and the key consists of 8 bits. The key is split into two halves, k_0 and k_1, of 4 bits each. Additionally, the cipher contains an S-box, which is essentially a lookup table that maps a 4-bit input to a 4-bit output.

Encrypting a 4-bit input block requires the following steps:

  1. Xor the plaintext with the first key k_0.
  2. Feed the output of the previous step into the S-box.
  3. Xor the output of the previous step with k_1.

The output of the last step is the ciphertext.

The image below shows the steps of the encryption routine as a circuit. As usual, data flows from top to bottom.

The schematic design of our toy cipher

The concrete values of the S-box were chosen by rolling dice. This generally leads to an insecure S-box, which is perfect for illustrating differential cryptanalysis. For modern ciphers, in contrast, the S-box values are chosen very carefully in order to impede attacks such as the differential cryptanalysis we show in this post.

Below is the Python code which implements encryption with our toy cipher.

sbox = [12, 2, 13, 14, 3, 10, 0, 9, 5, 8, 15, 11, 4, 7, 1, 6]


def round_function(input, key):
    return sbox[key ^ input]


def encrypt(input, key0, key1):
    return round_function(input, key0) ^ key1
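The security discussion traces a ciphertext backwards through the cipher, which requires the inverse of the S-box. The original code only encrypts, so the following is a minimal decryption sketch, assuming (as holds for this S-box) that the S-box is a permutation of the values 0 to 15. The names inv_sbox and decrypt are our own, hypothetical additions.

```python
# Minimal decryption sketch for the toy cipher (hypothetical addition).
# Assumes the S-box is a permutation of 0..15, so it can be inverted.
sbox = [12, 2, 13, 14, 3, 10, 0, 9, 5, 8, 15, 11, 4, 7, 1, 6]

# Build the inverse S-box: inv_sbox[sbox[x]] == x for every 4-bit x.
inv_sbox = [0] * 16
for x, y in enumerate(sbox):
    inv_sbox[y] = x


def encrypt(plaintext, key0, key1):
    # xor k0, substitute, xor k1 (same as the original code)
    return sbox[plaintext ^ key0] ^ key1


def decrypt(ciphertext, key0, key1):
    # Undo the steps in reverse order: xor k1, inverse S-box, xor k0.
    return inv_sbox[ciphertext ^ key1] ^ key0


# Round trip over all 256 keys and all 16 possible blocks.
for k0 in range(16):
    for k1 in range(16):
        for p in range(16):
            assert decrypt(encrypt(p, k0, k1), k0, k1) == p

print(encrypt(5, 3, 9), decrypt(9, 3, 9))  # 9 5
```

For example, with k_0 = 3 and k_1 = 9 the plaintext 5 becomes 5⊕3 = 6, then sbox[6] = 0, then 0⊕9 = 9.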

Security Discussion

A naive security analysis shows that it is possible to brute-force the key of the cipher. As the key consists of only 8 bits there are 256 possible keys in total.

This attack can be improved. Given a few known plaintext/ciphertext pairs, the attack can be reduced to only 16 guesses. If we guess k_1, we can compute the value of k_0 from one plaintext/ciphertext pair: the ciphertext is traced back through the cipher by xoring it with the guessed k_1 and applying the inverse S-box. Xoring the result with the plaintext yields a guess for k_0. This guess for the full key (k_0, k_1) can then be validated using an additional plaintext/ciphertext pair. Thus, we only need to test each of the 16 possible 4-bit partial keys k_1.
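The 16-guess attack described above can be sketched as follows. The function brute_force_k1 and the demo key are hypothetical. The sketch returns every key that is consistent with both known pairs, because a single validation pair is not guaranteed to rule out every wrong guess.

```python
sbox = [12, 2, 13, 14, 3, 10, 0, 9, 5, 8, 15, 11, 4, 7, 1, 6]
inv_sbox = [sbox.index(y) for y in range(16)]  # inverse S-box


def encrypt(plaintext, key0, key1):
    return sbox[plaintext ^ key0] ^ key1


def brute_force_k1(pair, check_pair):
    """Guess k1, derive k0 from `pair`, validate against `check_pair`.

    Makes exactly 16 guesses and returns all consistent (k0, k1) keys.
    """
    p, c = pair
    p_check, c_check = check_pair
    candidates = []
    for k1 in range(16):
        # Trace c back through the cipher: remove k1, invert the S-box,
        # then xor with the plaintext to obtain the matching k0.
        k0 = inv_sbox[c ^ k1] ^ p
        if encrypt(p_check, k0, k1) == c_check:
            candidates.append((k0, k1))
    return candidates


secret = (6, 10)  # hypothetical secret key for the demo
pair = (1, encrypt(1, *secret))
check_pair = (2, encrypt(2, *secret))
print(secret in brute_force_k1(pair, check_pair))  # True
```

For the correct guess of k_1, the derived k_0 is exactly the real one, so the real key always survives validation; any remaining wrong candidates can be eliminated with further pairs.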

Consequently, for a differential attack to be considered effective, it has to require fewer than 16 guesses. Otherwise, this improved brute-force attack is more efficient.

Differential Attack

The main idea of differential cryptanalysis is to trace not single values but differences through each stage of the cipher.

In our case, we define the difference of two plaintexts p and p' as p⊕p', where ⊕ denotes the bitwise xor operation. Other, more exotic definitions of differences are possible and are better suited to other ciphers. For our cipher, however, the classical definition as bitwise xor suffices for a successful attack.

Notice that encrypting a value p under a fixed key yields some ciphertext c. As the encryption algorithm is deterministic, encrypting p again under the same key yields the same c. Consequently, two inputs with a difference of zero yield two outputs with a difference of zero, independently of the key.

Let us now take a look at two inputs p and p' with a fixed input difference p⊕p' for a fixed key k. What are the output differences? Encrypting p yields some ciphertext c. Encrypting p' yields some other ciphertext c'. As c and c' should have no relationship, each output difference c⊕c' should have the same probability of occurrence.

Next, we trace differences through each operation of the cipher in order to check whether, for a fixed input difference, each output difference is equally likely. When encrypting p, we first xor the key k_0 into the plaintext. An important observation is that this xor does not affect differences at all. If p⊕p' is the input difference, then after xoring with k_0 the difference is (p⊕k_0)⊕(p'⊕k_0) = p⊕p'⊕k_0⊕k_0 = p⊕p', because xor is associative and commutative and k_0⊕k_0 = 0. The same holds for xoring with k_1. Consequently, the relation between input differences and output differences is independent of the key; it depends only on the S-box.

Formally, a differential characteristic is usually defined as an input difference, an output difference, and the conditional probability that this output difference occurs, given the input difference. For our cipher these probabilities are determined entirely by the S-box, so we speak of the differential characteristics of the S-box.
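The key independence can be checked directly. The sketch below (output_difference_counts is a hypothetical helper name) counts the output differences for the input difference 13 under every one of the 256 keys and compares them with the counts for the S-box alone, i.e., the all-zero key.

```python
from collections import Counter

sbox = [12, 2, 13, 14, 3, 10, 0, 9, 5, 8, 15, 11, 4, 7, 1, 6]


def encrypt(plaintext, key0, key1):
    return sbox[plaintext ^ key0] ^ key1


def output_difference_counts(in_diff, key0, key1):
    """Count c xor c' over all 16 plaintexts p, where p' = p xor in_diff."""
    return Counter(
        encrypt(p, key0, key1) ^ encrypt(p ^ in_diff, key0, key1)
        for p in range(16)
    )


# Distribution for the S-box alone (equivalently, the all-zero key).
reference = output_difference_counts(13, 0, 0)

# Every one of the 256 keys yields exactly the same distribution,
# because k1 cancels in c xor c' and k0 only permutes the S-box inputs.
same_for_all_keys = all(
    output_difference_counts(13, k0, k1) == reference
    for k0 in range(16)
    for k1 in range(16)
)
print(same_for_all_keys)  # True
print(sorted(reference.items()))  # [(6, 4), (11, 8), (15, 4)]
```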

Next, for each fixed input difference, we iterate over all 16 possible inputs x to the S-box, pair each x with the input that differs from it by that input difference, and count how often each output difference occurs. As there are 16 input differences and 16 inputs, this results in 2^4 · 2^4 = 256 operations.

Even for realistic ciphers, this can usually be done in a preprocessing phase, since S-boxes are small. The resulting table is called the difference distribution table, as it shows the distribution of output differences for each input difference.

def get_difference_distribution_table():
    print("[*] Computing difference distribution table.")
    diff_dist_table = [[0 for x in range(16)] for y in range(16)]
    for in_diff in range(16):
        for input0 in range(16):
            input1 = input0 ^ in_diff
            out_diff = sbox[input0] ^ sbox[input1]
            diff_dist_table[in_diff][out_diff] = diff_dist_table[in_diff][out_diff] + 1
    return diff_dist_table


def matrix_pretty_print(matrix):
    # see https://stackoverflow.com/questions/13214809/pretty-print-2d-python-list
    s = [[str(e) for e in row] for row in matrix]
    lens = [max(map(len, col)) for col in zip(*s)]
    fmt = '  '.join('{{:{}}}'.format(x) for x in lens)
    table = [fmt.format(*row) for row in s]
    print('\n'.join(table))


diff_dist_table = get_difference_distribution_table()
matrix_pretty_print(diff_dist_table)

The resulting difference distribution table is shown below. Rows correspond to input differences and columns to output differences, both numbered from 0 to 15. The entry in row a and column b counts the inputs x for which the input pair x and x⊕a leads to S-box outputs with difference b.

The entry 16 in row 0 and column 0 means that an input difference of 0 leads to an output difference of 0 for all 16 possible inputs.

As a concrete example, the entry 4 in row 1 and column 3 means that an input difference of 1 leads to an output difference of 3 for 4 of the 16 inputs.

Remember from above that for a perfect cipher, each output difference should be equally likely for a fixed input difference. If that were the case, all entries in the table would be 1, even those in the first row and column. This is impossible: an input difference of 0 always leads to an output difference of 0. Moreover, every entry is even, because x and x⊕a are always counted in the same cell. Hence a perfect cipher in this sense cannot exist.

in\out  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15
     0 16  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
     1  0  0  0  4  2  0  0  2  0  4  0  0  0  2  2  0
     2  0  4  0  6  0  2  0  0  0  0  2  0  2  0  0  0
     3  0  0  4  0  0  0  2  2  0  0  4  0  0  0  2  2
     4  0  2  0  0  0  0  0  2  2  0  0  0  0  4  2  4
     5  0  2  2  0  2  0  2  0  0  2  2  0  2  0  2  0
     6  0  0  0  0  4  0  0  0  0  0  0  4  4  0  4  0
     7  0  0  2  2  0  2  0  2  2  2  0  0  0  2  0  2
     8  0  2  2  0  0  2  0  2  0  2  2  0  0  2  0  2
     9  0  2  0  0  4  0  4  2  2  0  0  0  0  0  2  0
    10  0  0  2  2  2  0  2  0  2  2  0  0  2  0  2  0
    11  0  0  0  0  0  4  0  4  0  0  0  4  0  4  0  0
    12  0  0  4  0  0  2  2  0  4  0  0  0  2  0  0  2
    13  0  0  0  0  0  0  4  0  0  0  0  8  0  0  0  4
    14  0  4  0  0  2  2  0  0  0  4  0  0  2  2  0  0
    15  0  0  0  2  0  2  0  0  4  0  6  0  2  0  0  0

Some entries in the table are especially nasty. For example, an input difference of 13 leads to an output difference of 11 in 8 of the 16 possible cases, i.e., with probability 1/2. This is the differential characteristic we use in the remainder of this blog post.
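These observations can be checked programmatically. The following sketch recomputes the table with a parametrized variant of get_difference_distribution_table (the name difference_distribution_table is our own), checks that each row sums to 16 and that every entry is even, and picks the largest entry outside the trivial row 0.

```python
sbox = [12, 2, 13, 14, 3, 10, 0, 9, 5, 8, 15, 11, 4, 7, 1, 6]


def difference_distribution_table(sbox):
    n = len(sbox)
    table = [[0] * n for _ in range(n)]
    for in_diff in range(n):
        for x in range(n):
            table[in_diff][sbox[x] ^ sbox[x ^ in_diff]] += 1
    return table


ddt = difference_distribution_table(sbox)

# Each row sums to 16: every input x is counted once per input difference.
assert all(sum(row) == 16 for row in ddt)
# Every entry is even: x and x ^ in_diff always land in the same cell.
assert all(entry % 2 == 0 for row in ddt for entry in row)

# Largest entry outside the trivial row for input difference 0.
count, in_diff, out_diff = max(
    (ddt[a][b], a, b) for a in range(1, 16) for b in range(16)
)
print(in_diff, out_diff, count, count / 16)  # 13 11 8 0.5
```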

Note that this knowledge already gives us a distinguisher for the cipher, i.e., an algorithm that can tell our toy cipher apart from an ideal cipher. The distinguisher queries a chosen-plaintext oracle with two plaintexts with a difference of 13. If the output difference is 11, it is likely dealing with the toy cipher rather than an ideal cipher: for the toy cipher this happens with probability 1/2, whereas for an ideal cipher (a random permutation of the 16 possible blocks) it happens with probability only 1/15. Repeating the test with several pairs makes the decision more reliable. Designing a distinguisher is often a first step in attacking an algorithm. Even though distinguishing attacks are rarely a practical threat on their own, they can sometimes be extended to key recovery.
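A sketch of this distinguisher follows. Instead of a single query, it uses all 16 plaintexts p together with their partners p⊕13 and compares the fraction of pairs with output difference 11 against a threshold. The ideal cipher is modeled as a random permutation of the 16 blocks; the oracle helpers, the threshold of 0.25 and the seeds are our own assumptions.

```python
import random

sbox = [12, 2, 13, 14, 3, 10, 0, 9, 5, 8, 15, 11, 4, 7, 1, 6]


def make_toy_oracle(rng):
    """Encryption oracle for the toy cipher under a secret random key."""
    k0, k1 = rng.randrange(16), rng.randrange(16)
    return lambda p: sbox[p ^ k0] ^ k1


def make_ideal_oracle(rng):
    """Encryption oracle for an 'ideal' 4-bit cipher: a random permutation."""
    perm = list(range(16))
    rng.shuffle(perm)
    return lambda p: perm[p]


def hit_rate(oracle, in_diff=13, out_diff=11):
    """Fraction of plaintexts p for which (p, p ^ in_diff) yields out_diff."""
    hits = sum(oracle(p) ^ oracle(p ^ in_diff) == out_diff for p in range(16))
    return hits / 16


def looks_like_toy_cipher(oracle, threshold=0.25):
    # Toy cipher: rate is exactly 1/2. Random permutation: 1/15 on average.
    return hit_rate(oracle) > threshold


rng = random.Random(1)
print(all(looks_like_toy_cipher(make_toy_oracle(rng)) for _ in range(1000)))  # True
false_alarms = sum(looks_like_toy_cipher(make_ideal_oracle(rng)) for _ in range(1000))
print(false_alarms, "false alarms out of 1000 random permutations")
```

For the toy cipher the rate is exactly 1/2 under every key, so it is always recognized. A random permutation exceeds the threshold only occasionally, so a small number of false alarms remains.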

Next, we extend our attack beyond distinguishing to key recovery.

Key Recovery through Differential Cryptanalysis

In this section we expand our attack to recover the whole key.

Recall from the security discussion above that the key is 8 bits long, so naive brute force needs 2^8 = 256 encryption operations. However, we can improve on that by brute-forcing only one half of the key, k_0, and computing the other half, k_1, using basic algebra from one known plaintext/ciphertext pair. Key guesses can then be validated using a second known plaintext/ciphertext pair. Consequently, this kind of brute force needs only 2^4 = 16 encryption operations.

For our differential attack to be worthwhile, it must need fewer than 16 encryption operations. The main idea is to use a differential characteristic to narrow down the key space. As differential cryptanalysis is a chosen-plaintext attack, we can query an encryption oracle for the ciphertexts of plaintext pairs with a chosen difference. The remaining key space is then searched for the correct key. In the following, we walk through the attack step by step.

As we have seen above, the differential characteristic (13, 11) holds with probability 1/2, so this is the characteristic we choose for our attack. Still as part of the preprocessing phase, we compute all inputs to the S-box for which the characteristic (13, 11) holds. There are 8 of these values, as the difference distribution table shows at the entry for the chosen characteristic.

Thus, for this characteristic there are 8 possible intermediate values in the encryption algorithm, i.e., values that enter the S-box after the xor with the key k_0. At the same time, an input difference of 13 leads to an output difference of 11 with probability 1/2. Had we chosen a characteristic with a low probability instead, there would be only a few intermediate values (and thus a small space of possible keys), but the probability of finding pairs for which the characteristic holds would be low.

def gen_possible_intermediate_values(input_diff, output_diff):
    good_pairs = []
    for input0 in range(16):
        input1 = input0 ^ input_diff
        if sbox[input0] ^ sbox[input1] == output_diff:
            good_pairs.append([input0, input1])
    return good_pairs


intermediate_values = gen_possible_intermediate_values(13, 11)
print("[*] Possible intermediate values: " + str(intermediate_values))

Running the code yields the possible intermediate values [0, 13], [2, 15], [4, 9], [6, 11], [9, 4], [11, 6], [13, 0], [15, 2]. Due to symmetry, each pair appears twice, once in each order. Without these duplicates we have the input pairs [0, 13], [2, 15], [4, 9], [6, 11]. Each of these pairs has an xor difference of 13 (i.e., 0⊕13 = 2⊕15 = 4⊕9 = 6⊕11 = 13) and leads to an output difference of 11 when processed by the S-box. No other input pairs have this property. Consequently, if a pair of plaintexts p and p' with p⊕p' = 13 yields ciphertexts c and c' with c⊕c' = 11, then the pair of values directly before the S-box is one of these 8 ordered pairs.
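This claim can be verified exhaustively. The sketch below checks, for all 256 keys and all plaintext pairs with difference 13, that whenever the ciphertext difference is 11, the pair of values entering the S-box is one of the 8 candidate pairs. The set name candidates is our own.

```python
sbox = [12, 2, 13, 14, 3, 10, 0, 9, 5, 8, 15, 11, 4, 7, 1, 6]


def encrypt(plaintext, key0, key1):
    return sbox[plaintext ^ key0] ^ key1


# The 8 ordered S-box input pairs for which 13 -> 11 holds.
candidates = {(x, x ^ 13) for x in range(16) if sbox[x] ^ sbox[x ^ 13] == 11}
assert len(candidates) == 8

for k0 in range(16):
    for k1 in range(16):
        for p in range(16):
            p_prime = p ^ 13
            c, c_prime = encrypt(p, k0, k1), encrypt(p_prime, k0, k1)
            if c ^ c_prime == 11:
                # Good pair: the values entering the S-box are p ^ k0 and
                # p_prime ^ k0, and they must form one of the candidate pairs.
                assert (p ^ k0, p_prime ^ k0) in candidates
print("verified for all 256 keys")
```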

As differential cryptanalysis is a chosen-plaintext attack, the attacker can query an oracle with pairs of plaintexts with a fixed difference and obtain the corresponding ciphertexts. In the code below, the oracle is simulated by encrypting under a random secret key, and three pairs of plaintexts with a difference of 13 are encrypted. As only half of all pairs with difference 13 lead to a ciphertext difference of 11, we ask the oracle for multiple pairs. The number three was chosen experimentally: with three pairs, the probability that none of them is good is 8/16 · 7/15 · 6/14 = 1/10, in which case find_good_pair below raises an exception. More plaintext pairs increase the chance of obtaining a ciphertext pair with the sought difference of 11.

import random


def gen_plain_cipher_pairs(input_diff, num):
    # Generate num plaintext, ciphertext pairs with fixed input difference.
    # Remember, this is a chosen plaintext attack
    # random key which we want to recover
    key = (random.randint(0, 15), random.randint(0, 15))
    print("[*] Real key: %s %s" % (key[0], key[1]))
    pairs = []
    for input0 in random.sample(range(16), num):
        input1 = input0 ^ input_diff
        output0 = encrypt(input0, key[0], key[1])
        output1 = encrypt(input1, key[0], key[1])
        pairs.append(((input0, input1), (output0, output1)))
    return pairs


plain_cipher_pairs = gen_plain_cipher_pairs(13, 3)
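The probability that the queried pairs contain no good pair can be computed exactly. The following sketch assumes that the first plaintexts are drawn without replacement, as random.sample does, and that for a fixed key exactly 8 of the 16 possible first plaintexts give a good pair. The function prob_no_good_pair is a hypothetical helper.

```python
from fractions import Fraction
from math import comb


def prob_no_good_pair(num, good=8, total=16):
    """Probability that `num` plaintexts drawn without replacement contain
    no good one, when `good` of `total` possible plaintexts are good."""
    return Fraction(comb(total - good, num), comb(total, num))


for num in range(1, 6):
    print(num, prob_no_good_pair(num))
# 1 1/2
# 2 7/30
# 3 1/10
# 4 1/26
# 5 1/78
```

Each additional pair reduces the failure probability considerably, at the cost of two more oracle queries.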

Next, the pairs are sieved for a good pair, i.e., a pair for which the differential characteristic holds: the plaintext difference is 13 and the ciphertext difference is 11. The function below returns the first good pair it finds.

def find_good_pair(plain_cipher_pairs, output_diff):
    print("[*] Searching for good pairs.")
    for ((input0, input1), (output0, output1)) in plain_cipher_pairs:
        if output0 ^ output1 == output_diff:
            return ((input0, input1), (output0, output1))
    raise Exception("No good pair found.")


((good_p0, good_p1), (good_c0, good_c1)) = find_good_pair(plain_cipher_pairs, 11)
print("[*] Found a good pair: " + str(((good_p0, good_p1), (good_c0, good_c1))))

If there are x plaintext/ciphertext pairs with an input difference of 13, then approximately x/2 of them are good pairs, i.e., they have the output difference 11. For a good pair, the values entering the S-box must be one of the 8 ordered pairs computed previously. As we know the plaintext, xoring it with each of these 8 possible S-box inputs yields a guess for the key k_0. Given k_0 and the plaintext/ciphertext pair, k_1 follows by xoring the corresponding S-box output with the ciphertext, which gives a guess for the full key.

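A guessed key can be validated using further plaintext/ciphertext pairs, which need not be good pairs. Note, however, that every candidate derived from a good pair is, by construction, consistent with that pair. Moreover, if (k_0, k_1) is the real key, the candidate (k_0⊕13, k_1⊕11) encrypts every good pair correctly as well, so at least one pair for which the characteristic does not hold is required to rule it out. If all queried pairs happen to be good, the code below reports more than one key.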
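One subtlety of validation is that good pairs cannot distinguish the real key (k_0, k_1) from the twin key (k_0⊕13, k_1⊕11): both produce the same ciphertexts exactly on the plaintext pairs for which the characteristic (13, 11) holds. The following sketch checks this exhaustively for all 256 keys; agrees_on_pair and is_good_pair are hypothetical helper names.

```python
sbox = [12, 2, 13, 14, 3, 10, 0, 9, 5, 8, 15, 11, 4, 7, 1, 6]


def encrypt(plaintext, key0, key1):
    return sbox[plaintext ^ key0] ^ key1


def agrees_on_pair(key_a, key_b, p):
    """True if both keys encrypt p and p ^ 13 to the same ciphertexts."""
    return all(encrypt(q, *key_a) == encrypt(q, *key_b) for q in (p, p ^ 13))


def is_good_pair(key, p):
    """True if the pair (p, p ^ 13) has ciphertext difference 11 under key."""
    return encrypt(p, *key) ^ encrypt(p ^ 13, *key) == 11


for k0 in range(16):
    for k1 in range(16):
        key = (k0, k1)
        twin = (k0 ^ 13, k1 ^ 11)
        for p in range(16):
            # The twin key matches the real key exactly on the good pairs.
            assert agrees_on_pair(key, twin, p) == is_good_pair(key, p)
print("twin key agrees with the real key exactly on the good pairs")
```

Hence the twin key can only be eliminated by a pair for which the characteristic does not hold.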

def validate_key(guessed_k0, guessed_k1):
    """Checks a key against known plaintext/ciphertext pair and returns True if the key is correct."""
    for ((input0, input1), (output0, output1)) in plain_cipher_pairs:
        if encrypt(input0, guessed_k0, guessed_k1) != output0:
            return False
        if encrypt(input1, guessed_k0, guessed_k1) != output1:
            return False
    return True

All that is left to do is to compute the possible keys from the possible intermediate values before the S-box and to check each guess. The attacker still has to brute-force several keys, but only 8 candidates instead of 16.

def recover_key():
    print("[*] Brute-Forcing remaining key space")
    for (p0, p1) in intermediate_values:
        guessed_k0 = p0 ^ good_p0
        guessed_k1 = sbox[p0] ^ good_c0
        if validate_key(guessed_k0, guessed_k1):
            print("Recovered key --> %s %s" % (guessed_k0, guessed_k1))
        else:
            print("                  %s %s" % (guessed_k0, guessed_k1))


recover_key()
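For experimentation, the pieces above can be combined into one function that returns the list of keys consistent with all queried pairs instead of printing them. The sketch below is a hypothetical refactoring (attack and INTERMEDIATE are our own names), not the original code. It runs the attack against all 256 keys and checks that the real key is among the returned candidates whenever a good pair was found.

```python
import random

sbox = [12, 2, 13, 14, 3, 10, 0, 9, 5, 8, 15, 11, 4, 7, 1, 6]


def encrypt(plaintext, key0, key1):
    return sbox[plaintext ^ key0] ^ key1


# Ordered S-box input pairs for the characteristic 13 -> 11.
INTERMEDIATE = [(x, x ^ 13) for x in range(16) if sbox[x] ^ sbox[x ^ 13] == 11]


def attack(oracle, num_pairs=3, rng=random):
    """Differential key recovery against the toy cipher.

    Returns the list of keys consistent with all queried pairs, or None
    if no good pair was found among the num_pairs queried pairs.
    """
    pairs = []
    for p in rng.sample(range(16), num_pairs):
        pairs.append(((p, p ^ 13), (oracle(p), oracle(p ^ 13))))
    good = [pair for pair in pairs if pair[1][0] ^ pair[1][1] == 11]
    if not good:
        return None
    (p0, _), (c0, _) = good[0]
    keys = []
    for x, _ in INTERMEDIATE:
        k0, k1 = x ^ p0, sbox[x] ^ c0  # candidate derived from the good pair
        if all(encrypt(a, k0, k1) == ca and encrypt(b, k0, k1) == cb
               for (a, b), (ca, cb) in pairs):
            keys.append((k0, k1))
    return keys


rng = random.Random(0)
for k0 in range(16):
    for k1 in range(16):
        result = attack(lambda p: encrypt(p, k0, k1), rng=rng)
        assert result is None or (k0, k1) in result

# Querying all 16 plaintexts leaves only the real key.
print(attack(lambda p: encrypt(p, 6, 10), num_pairs=16, rng=random.Random(0)))  # [(6, 10)]
```

With all 16 plaintexts queried, both good and bad pairs are available, so only the real key survives; with three pairs the returned list occasionally contains more than one key.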


Full code

The full code is shown below and can be downloaded here.

# Differential Cryptanalysis Toy Implementation

# Encryption is as follows: xor key0, then substitute, then xor key1
# So we have the simplest type of an SP-network (without the permutation layer)
# key sizes: key0=4bit, key1=4 bit, so key is 8 bit
# block length is 4 bit
# sbox width is 4 bit

import random

sbox = [12, 2, 13, 14, 3, 10, 0, 9, 5, 8, 15, 11, 4, 7, 1, 6]  # chosen by fair dice roll

# Note: fixed point of sbox[11]=11

def round_function(input, key):
    return sbox[key ^ input]


def encrypt(input, key0, key1):
    return round_function(input, key0) ^ key1


def get_difference_distribution_table():
    print("[*] Computing difference distribution table.")
    diff_dist_table = [[0 for x in range(16)] for y in range(16)]
    for in_diff in range(16):
        for input0 in range(16):
            input1 = input0 ^ in_diff
            out_diff = sbox[input0] ^ sbox[input1]
            diff_dist_table[in_diff][out_diff] = diff_dist_table[in_diff][out_diff] + 1
    return diff_dist_table


def matrix_pretty_print(matrix):
    # https://stackoverflow.com/questions/13214809/pretty-print-2d-python-list
    s = [[str(e) for e in row] for row in matrix]
    lens = [max(map(len, col)) for col in zip(*s)]
    fmt = '  '.join('{{:{}}}'.format(x) for x in lens)
    table = [fmt.format(*row) for row in s]
    print('\n'.join(table))


diff_dist_table = get_difference_distribution_table()
matrix_pretty_print(diff_dist_table)

# 16  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
# 0   0  0  4  2  0  0  2  0  4  0  0  0  2  2  0
# 0   4  0  6  0  2  0  0  0  0  2  0  2  0  0  0
# 0   0  4  0  0  0  2  2  0  0  4  0  0  0  2  2
# 0   2  0  0  0  0  0  2  2  0  0  0  0  4  2  4
# 0   2  2  0  2  0  2  0  0  2  2  0  2  0  2  0
# 0   0  0  0  4  0  0  0  0  0  0  4  4  0  4  0
# 0   0  2  2  0  2  0  2  2  2  0  0  0  2  0  2
# 0   2  2  0  0  2  0  2  0  2  2  0  0  2  0  2
# 0   2  0  0  4  0  4  2  2  0  0  0  0  0  2  0
# 0   0  2  2  2  0  2  0  2  2  0  0  2  0  2  0
# 0   0  0  0  0  4  0  4  0  0  0  4  0  4  0  0
# 0   0  4  0  0  2  2  0  4  0  0  0  2  0  0  2
# 0   0  0  0  0  0  4  0  0  0  0  8  0  0  0  4
# 0   4  0  0  2  2  0  0  0  4  0  0  2  2  0  0
# 0   0  0  2  0  2  0  0  4  0  6  0  2  0  0  0

# We see that an input difference of 13 leads to an output difference of 11
# with probability 1/2 (8/16)
# So we already built a distinguisher for the cipher.
print("[*] Choosing differential characteristic 13 -> 11")
# How? Well, we query a chosen-plaintext oracle with two plaintexts with
# difference 13. If the output difference is 11, then we probably deal
# with the cipher, instead of a random oracle.

# Next, we want to recover the key.
# Note that the key length is 8 bits, thus brute-forcing naively needs
# 2^8 steps. However, we brute-force only the first half of the key and
# compute the remaining half using basic algebra. Key guesses can then
# be validated using some known plaintext-ciphertext pair.
# Consequently, brute forcing needs 2^4=16 steps.

# Now, we use differential cryptanalysis and need fewer than 16
# steps. As differential cryptanalysis is a chosen-plaintext attack, we
# can access an encryption oracle.

# Now, let us compute all possible intermediate values for which the
# differential characteristic 13 -> 11 holds. This can be done in a
# pre-processing phase. Note that there are 8 intermediate values, as
# the characteristic holds for 8 of the 16 inputs. Thus, we
# have many intermediate values, but it is easy to find a
# plaintext-ciphertext pair for which the characteristic holds. On the
# other hand, if the probability of the differential is low, then there
# are only few possible intermediate values, but it is hard to find a
# plaintext-ciphertext pair for which the differential characteristic
# holds.

def gen_possible_intermediate_values(input_diff, output_diff):
    good_pairs = []
    for input0 in range(16):
        input1 = input0 ^ input_diff
        if sbox[input0] ^ sbox[input1] == output_diff:
            good_pairs.append([input0, input1])
    return good_pairs


intermediate_values = gen_possible_intermediate_values(13, 11)
print("[*] Possible intermediate values: " + str(intermediate_values))


def gen_plain_cipher_pairs(input_diff, num):
    # Generate num plaintext, ciphertext pairs with fixed input difference.
    # Remember, this is a chosen plaintext attack
    # random key which we want to recover
    key = (random.randint(0, 15), random.randint(0, 15))
    print("[*] Real key: %s %s" % (key[0], key[1]))
    pairs = []
    for input0 in random.sample(range(16), num):
        input1 = input0 ^ input_diff
        output0 = encrypt(input0, key[0], key[1])
        output1 = encrypt(input1, key[0], key[1])
        pairs.append(((input0, input1), (output0, output1)))
    return pairs


plain_cipher_pairs = gen_plain_cipher_pairs(13, 3)
# We are using three pairs. This should be enough, but of course more is better.

# Next, we want to only take a look at the good plaintext-ciphertext
# pairs. These are those pairs, where the differential characteristic
# holds.


def find_good_pair(plain_cipher_pairs, output_diff):
    print("[*] Searching for good pairs.")
    for ((input0, input1), (output0, output1)) in plain_cipher_pairs:
        if output0 ^ output1 == output_diff:
            return ((input0, input1), (output0, output1))
    raise Exception("No good pair found.")

# If we have num plaintext-ciphertext pairs with the input difference
# 13, then approximately num/2 of these are good pairs, i.e., they
# have the output difference 11.


((good_p0, good_p1), (good_c0, good_c1)) = find_good_pair(plain_cipher_pairs, 11)

print("[*] Found a good pair: " + str(((good_p0, good_p1), (good_c0, good_c1))))

# For such a good pair, we know the 8 possible intermediate values
# before and after the sbox. Each of these intermediate values gives us
# a guess for the key.

# If we have guessed a key, we can validate it using the other (even
# bad) plaintext-ciphertext pair or some other known
# plaintext-ciphertext pair.


def validate_key(guessed_k0, guessed_k1):
    """Checks a key against known plaintext-ciphertext pair and returns True if the key is correct."""
    for ((input0, input1), (output0, output1)) in plain_cipher_pairs:
        if encrypt(input0, guessed_k0, guessed_k1) != output0:
            return False
        if encrypt(input1, guessed_k0, guessed_k1) != output1:
            return False
    return True


# All that is left is to compute the possible keys, given the possible
# intermediate values before the sbox and check the keys. Note that
# we are still bruteforcing, but we are only bruteforcing 8 values,
# instead of 16.
def recover_key():
    print("[*] Brute-Forcing remaining key space")
    for (p0, p1) in intermediate_values:
        guessed_k0 = p0 ^ good_p0
        guessed_k1 = sbox[p0] ^ good_c0
        if validate_key(guessed_k0, guessed_k1):
            print("Recovered key --> %s %s" % (guessed_k0, guessed_k1))
        else:
            print("                  %s %s" % (guessed_k0, guessed_k1))


recover_key()

~ Dr. Henning Kopp