#!/usr/bin/env python

# Copyright (C) 2021 Dr. Henning Kopp, SCHUTZWERK GmbH
#
# Permission is hereby granted, free of charge, to any person
# obtaining a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Differential Cryptanalysis Toy Implementation
# This code contains the differential cryptanalysis of a two round SPN.

# An accompanying blog post with explanation of the code can be found
# at https://www.schutzwerk.com/en/43/posts/differential_cryptanalysis_3/

# "First, we want to establish the idea that a computer language is
# not just a way of getting a computer to perform operations but
# rather that it is a novel formal medium for expressing ideas about
# methodology. Thus, programs must be written for people to read, and
# only incidentally for machines to execute."
# - Harold Abelson, Structure and Interpretation of Computer Programs

# Two Round SPN
# block size = 9 bit, sbox size = 3 bits, keylength = 27 bits
# see 02_differential.jpg for the specification

# xor k0
# sbox, sbox, sbox
# permutation: 0, 3, 6, 1, 4, 7, 2, 5, 8
# xor k1
# sbox, sbox, sbox
# xor k2

import random

s = [3, 6, 4, 5, 1, 7, 2, 0]  # chosen by fair dice roll

# Inverse lookup table: s_rev[y] is the x with s[x] == y.
s_rev = [s.index(y) for y in range(len(s))]

# not the bit length of the sbox, but its range of possible input values
SBOX_RANGE = len(s)


def sbox(x):
    """Substitute x using the S-box table ``s`` (8 entries, so x is 0..7)."""
    substituted = s[x]
    return substituted


def inv_sbox(x):
    """Undo ``sbox``: map an S-box output back to the input that produced it."""
    original = s_rev[x]
    return original


# Bit permutation: bit i of the input moves to bit p[i] of the output.
# p transposes a 3x3 bit matrix, so applying it twice is the identity.
p = [0, 3, 6, 1, 4, 7, 2, 5, 8]


def pbox(x):
    """
    Permute the bits of x according to ``p``.
    Only bit positions 0..len(p)-1 are considered; higher bits are dropped.
    """
    result = 0
    for src, dst in enumerate(p):
        if (x >> src) & 1:
            result ^= 1 << dst
    return result


def demux(x):
    """
    Split a 9 Bit value into three 3 Bit values, least significant first.
    Bits above bit 8 are ignored.
    demux(1) = [1,0,0]
    demux(8) = [0,1,0]
    """
    return [(x >> shift) & 0x7 for shift in (0, 3, 6)]


def mux(x):
    """
    Combine three 3 Bit values (least significant first) into a 9 Bit value.
    The inverse of demux. Only x[0], x[1] and x[2] are read.
    mux(demux(13)) = 13
    """
    result = 0
    for position, shift in enumerate((0, 3, 6)):
        result ^= x[position] << shift
    return result


def round_function(input, round_key):
    """XOR ``round_key`` into ``input`` and send each 3 Bit chunk through the S-box."""
    key_chunks = demux(round_key)
    data_chunks = demux(input)
    substituted = [sbox(kc ^ dc) for kc, dc in zip(key_chunks, data_chunks)]
    return mux(substituted)


def round_keys(key):
    """
    Split the main key into its three 9 Bit round keys.
    Note: rounds are enumerated starting with zero, so the result is
    [bits 0-8, bits 9-17, bits 18-26] of ``key``; higher bits are ignored.
    """
    return [(key >> offset) & 0x1ff for offset in range(0, 27, 9)]


def encrypt(input, key):
    """
    Encrypt one 9 Bit block with the two round SPN.

    Steps: xor k0, S-boxes, permutation, xor k1, S-boxes, xor k2
    (k0, k1, k2 are the round keys from round_keys).

    :param input: plaintext block, an integer in [0, 511]
    :param key: main key, an integer in [0, 2**27 - 1]
    :returns: the 9 Bit ciphertext block
    :raises ValueError: if ``input`` or ``key`` is out of range
    """
    # Negative values must be rejected too: the masking in round_keys and
    # demux would otherwise silently alias them onto valid keys/blocks.
    if key < 0:
        raise ValueError("Key must be non-negative")
    if key > 2**27-1:
        raise ValueError("Key too long")
    if input < 0:
        raise ValueError("Plaintext block must be non-negative")
    if input > 511:
        raise ValueError("Plaintext block too long")
    roundkeys = round_keys(key)
    output = round_function(input, roundkeys[0])
    output = pbox(output)
    output = round_function(output, roundkeys[1])
    output = mux([(k ^ b) for k, b in zip(demux(roundkeys[2]), demux(output))])
    # the last line is similar to
    # output = output ^ roundkeys[2]
    # but we avoid endianness errors
    return output


def decrypt(input, key):
    """
    Decrypt one 9 Bit block; the inverse of encrypt.

    Undoes the encryption steps in reverse order: xor k2, inverse S-boxes,
    xor k1, inverse permutation, inverse S-boxes, xor k0.

    :param input: ciphertext block, an integer in [0, 511]
    :param key: main key, an integer in [0, 2**27 - 1]
    :returns: the 9 Bit plaintext block
    :raises ValueError: if ``input`` or ``key`` is out of range
    """
    # Negative values must be rejected too: the masking in round_keys and
    # demux would otherwise silently alias them onto valid keys/blocks.
    if key < 0:
        raise ValueError("Key must be non-negative")
    if key > 2**27-1:
        raise ValueError("Key too long")
    if input < 0:
        raise ValueError("Ciphertext block must be non-negative")
    if input > 511:
        raise ValueError("Ciphertext block too long")
    roundkeys = round_keys(key)
    output = mux([(k ^ b) for k, b in zip(demux(roundkeys[2]), demux(input))])
    output = mux([inv_sbox(x) for x in demux(output)])
    output = mux([(k ^ b) for k, b in zip(demux(roundkeys[1]), demux(output))])
    output = pbox(output)  # in our case pbox = inv_pbox
    output = mux([inv_sbox(x) for x in demux(output)])
    output = mux([(k ^ b) for k, b in zip(demux(roundkeys[0]), demux(output))])
    return output


def get_difference_distribution_table():
    """
    Return the difference distribution table (DDT) of the S-box.

    Entry [a][b] counts the inputs x with sbox(x) ^ sbox(x ^ a) == b.
    """
    print("[*] Computing difference distribution table.")
    table = [[0] * SBOX_RANGE for _ in range(SBOX_RANGE)]
    for x0 in range(SBOX_RANGE):
        for delta_in in range(SBOX_RANGE):
            delta_out = sbox(x0) ^ sbox(x0 ^ delta_in)
            table[delta_in][delta_out] += 1
    return table


def matrix_pretty_print(matrix):
    """Print a 2D list with every column left-aligned to its widest entry."""
    # Adapted from
    # https://stackoverflow.com/questions/13214809/pretty-print-2d-python-list
    cells = [list(map(str, row)) for row in matrix]
    widths = [max(len(cell) for cell in column) for column in zip(*cells)]
    row_format = '  '.join(f'{{:{w}}}' for w in widths)
    print('\n'.join(row_format.format(*row) for row in cells))


# Compute and print the S-box's difference distribution table;
# row = input difference, column = output difference.
diff_dist_table = get_difference_distribution_table()
matrix_pretty_print(diff_dist_table)

# 8  0  0  0  0  0  0  0
# 0  2  2  0  0  2  2  0
# 0  0  0  4  0  0  0  4
# 0  2  2  0  0  2  2  0
# 0  2  2  0  0  2  2  0
# 0  0  0  0  4  0  0  4
# 0  2  2  0  0  2  2  0
# 0  0  0  4  4  0  0  0

# So, the nice differentials per round are as follows:
# input diff of 2 (010) leads to output diff of 3 (011) or 7 (111), each with 50% probability (4/8)
# input diff of 5 (101) leads to output diff of 4 (100) or 7 (111), each with 50% probability (4/8)
# input diff of 7 (111) leads to output diff of 3 (011) or 4 (100), each with 50% probability (4/8)

# We stitch the differentials together as in 02_differential_char.jpg
# and get a differential for all rounds.
# An input difference of 16 (000 010 000) leads to an output difference
# of 511 (111 111 111) with a probability of 6.25% (first round 50%;
# second round 12.5%, since all three sboxes must map 010 -> 111 at 50% each)
# This is under the (wrong) assumption that the probabilities are
# independent. Even though this assumption is wrong it is standard in
# differential cryptanalysis.
# For the attack we only need the characteristic up to the second-to-last
# round (the input of the last sbox layer), and have
# 000 010 000 -> 010 010 010 with 50% probability.


def gen_plain_cipher_pairs(input_diff, num, key=None):
    """
    Generate num plaintext/ciphertext pairs with a fixed input difference.

    Remember, this is a chosen plaintext attack: we pick the plaintexts,
    and the key only enters through encrypt. The key (and its round keys)
    is printed so the outcome of the attack can be checked.

    :param input_diff: XOR difference between the two plaintexts of each
        pair; should be in [0, 511], otherwise encrypt rejects the second
        plaintext
    :param num: number of pairs; the first plaintexts are drawn without
        replacement, so at most 2**9
    :param key: key to encrypt with; if None, a random key in
        [0, 2**27 - 1] (the key we want to recover) is drawn
    :returns: list of ((input0, input1), (output0, output1)) tuples
    """
    if key is None:
        # random key which we want to recover
        key = random.randint(0, 2**27-1)
    print(f"[*] Real key: {key}")
    print(f"[*] Corresponding round keys: {round_keys(key)}")

    pairs = []
    # Draw from all 2**9 plaintexts, 0..511 inclusive.
    for input0 in random.sample(range(2**9), num):
        input1 = input0 ^ input_diff
        output0 = encrypt(input0, key)
        output1 = encrypt(input1, key)
        pairs.append(((input0, input1), (output0, output1)))
    return pairs


plain_cipher_pairs = gen_plain_cipher_pairs(input_diff=16, num=500)
# We use the characteristic for the whole cipher.
# Thus the input difference is 16.
# As the characteristic 16 -> 511 holds with a probability of 6.25%
# we take 500 plain_cipher_pairs. There should be some good pairs in them.
# I found that it works without specially selecting good pairs.
# This may be due to the fact that the last round again lowers the
# probability of the characteristic holding.

# We know that for a plaintext difference of 000 010 000, the
# difference before the last sbox is 010 010 010 with 50%
# probability. That means we can now try all 2**9 values for the
# last roundkey and check if the difference before the last sbox is
# 010 010 010. If it is, it is probably the correct round key.

# In my tests I had around 64 keys for the last round which had the same
# probability. Why is this the case? There is also a differential
# characteristic for a single sbox 5 -> 7 (101 -> 111) with 50% probability.

# The effort for key recovery is thus
#  2**9 * number of pairs  for recovery of 64 possible values for k3
#  + 2**9 * 64 for bruteforcing the remaining key space.

# Why is bruteforcing the remaining keyspace only 2**9 and not 2**18?
# Well, if we have a k3 and guess a k2, then we can compute (not
# guess!) a k1. Thus, we only have to try all 2**9 values of k2,
# instead of all 2**18 values for k2 and k1.

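# Counting a single pair in the first term, this is an effort of around
# 2**9 * (1 + 64) = 33,280 ~ 2**15 instead of 2**27 = 134,217,728 for a
# complete bruteforce. With all 500 pairs the first term grows to
# 2**9 * 500 and the total to about 2**18, which is still far below 2**27.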


def recover_probable_last_roundkey(pairs=None):
    """
    Score every 9 Bit guess for the last round key (k3) and return the best.

    For each guess, both ciphertexts of every pair are partially decrypted
    (xor with the guess, then inverse sboxes). A guess scores a point when
    the two results differ by 010 010 010, which is the difference that the
    characteristic predicts for the plaintext difference 000 010 000.

    :param pairs: list of ((input0, input1), (output0, output1)) tuples;
        defaults to the module-level ``plain_cipher_pairs``
    :returns: list of all guesses sharing the highest score
    """
    if pairs is None:
        pairs = plain_cipher_pairs
    print("[*] Brute-Forcing key of last round.")
    expected_diff = int('010010010', 2)
    # count for each key, how often we had the correct difference
    count = [0] * 2**9
    for k3 in range(2**9):  # bruteforce nine bits.
        k3_parts = demux(k3)  # invariant across all pairs for this guess
        for _, (out1, out2) in pairs:
            mid1 = mux([inv_sbox(k ^ o) for k, o in zip(k3_parts, demux(out1))])
            mid2 = mux([inv_sbox(k ^ o) for k, o in zip(k3_parts, demux(out2))])
            if mid1 ^ mid2 == expected_diff:
                count[k3] += 1
    # Multiple keys can tie; in the author's tests 64 keys occurred
    # 250 times each.
    best = max(count)  # computed once instead of once per candidate key
    most_probable_keys = [key for key, hits in enumerate(count) if hits == best]
    print(f"[*] Most probable keys in last round: {most_probable_keys}")
    print(f"[*] How often these keys yielded the correct difference: {best}")
    return most_probable_keys


# Candidates for the last round key: all guesses with the top score.
last_round_keys = recover_probable_last_roundkey()

# After we have the key(s) of the last round, we can brute-force the remainder of the key.

def validate_key(guessed_key):
    """
    Return True if guessed_key reproduces every known ciphertext.

    Both plaintexts of every pair in plain_cipher_pairs are encrypted with
    guessed_key; checking stops at the first mismatch.
    """
    return all(
        encrypt(plain0, guessed_key) == cipher0
        and encrypt(plain1, guessed_key) == cipher1
        for (plain0, plain1), (cipher0, cipher1) in plain_cipher_pairs
    )


def bruteforce_remaining_keyspace(last_round_keys):
    """
    Complete each candidate last round key to a full key.

    Key layout (see round_keys): bits 18-26 hold the last round key (k3),
    bits 9-17 the middle round key (k2) and bits 0-8 the first round key
    (k1).

    :param last_round_keys: iterable of candidate values for k3
    :returns: list of every full key that passes validate_key (each one is
        also printed)
    """
    print("[*] Brute-Forcing remaining key space.")
    # Note that we need to check with multiple plaintext-ciphertext
    # pairs. Checking with only one pair yields false positives for
    # the key.

    # Additionally, note that we do not need to bruteforce k1 and
    # k2. If we know k2, we can compute k1 by using a plaintext
    # ciphertext pair. Thereby, the effort of
    # bruteforcing the remaining keyspace is 2**9 instead of 2**18.

    (input0, _), (output0, _) = plain_cipher_pairs[0]

    recovered = []
    for last_round_key in last_round_keys:
        print(f"Trying last round key of {last_round_key}")
        for k2 in range(2**9):
            # Compute k1: decrypting with k1 = 0 skips only the final
            # xor with k1. When k3 and k2 are correct, this yields input0 ^ k1.
            partial_key = (last_round_key << 18) | (k2 << 9)
            k1 = decrypt(output0, partial_key) ^ input0

            # Verify the key with other plaintext-ciphertext pairs
            key = partial_key | k1
            if validate_key(key):
                print(f"Recovered key --> {key}")
                recovered.append(key)

        # Code for a full bruteforce of k1 and k2 for comparison purposes
        # for k1 in range(2**9):
        #     for k2 in range(2**9):
        #         key = (last_round_key << 18) | (k2 << 9) | k1
        #         if validate_key(key):
        #             print(f"Recovered key --> {key}")
    return recovered


# Extend each last-round-key candidate to a full key and print every
# key that is consistent with all collected pairs.
bruteforce_remaining_keyspace(last_round_keys)
