Category: Cryptography
Difficulty: Medium (469 points)
Author: Ellen Dasnois
Description
Hi! I came up with a new way to securely generate random numbers. Can you help me pick the constants for my generator?
Challenge files
Observations
This challenge asks us to choose the values $a$ and $b$ of the PRNG class in such a way that we can recover its internal state. The server puts the following constraints on $a$ and $b$: both must be prime and lie in $[2, p-2]$.
The PRNG reduces $a^{state}$ and $b^{state}$ modulo a fixed value $p$. A comment in the source file states that $p$ is a safe prime, which means that $p = 2q + 1$ with $q$ also prime.
The following method computes the next state and the output value:
def __next__(self):
    self.state = pow(self.a, self.state, self.p)
    return pow(self.b, self.state, self.p) & ((1 << 1024) - 1)
So written mathematically (with $s$ as state and $o$ as output):
$s_{n+1} = a^{s_n} \bmod p$
$o = b^{s_{n+1}} \bmod p$, keeping only the lowest 1024 bits. Since $p$ is 1040 bits long, the top 16 bits of the output are lost.
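To make the recurrence concrete, here is a minimal toy model of the generator. The class name `ToyPRNG`, its constructor, and the tiny parameters ($p = 23$, seed 5, 3 kept bits) are assumptions chosen for illustration; only the body of `__next__` mirrors the challenge code.

```python
class ToyPRNG:
    """Toy model of the generator (constructor and parameters are assumed)."""

    def __init__(self, a, b, p, seed, keep_bits):
        self.a, self.b, self.p = a, b, p
        self.state = seed
        # The challenge keeps 1024 bits; the toy keeps `keep_bits` bits.
        self.mask = (1 << keep_bits) - 1

    def __iter__(self):
        return self

    def __next__(self):
        # s_{n+1} = a^{s_n} mod p
        self.state = pow(self.a, self.state, self.p)
        # o = b^{s_{n+1}} mod p, truncated to the low bits
        return pow(self.b, self.state, self.p) & self.mask


# p = 23 = 2 * 11 + 1 is a small safe prime (5 bits); we keep the low 3 bits.
rng = ToyPRNG(a=2, b=11, p=23, seed=5, keep_bits=3)
first = next(rng)
print(rng.state, first)  # state 9; full output 19 is truncated to 3
```

With these toy values the full output $19$ has its two high bits dropped, which is the small-scale analogue of the 16 bits lost in the challenge.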
Solution
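We will choose $a = 2$ and $b = q = (p - 1) / 2$. Both are prime ($q$ because $p$ is a safe prime) and lie in $[2, p-2]$, so the server accepts them.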
We choose these values so that $a$ is the modular inverse of $-b$:
$$a \times (-b)= 2 \times \dfrac{1 - p}{2} = 1 - p \equiv 1 \pmod p$$
Hence $-b \equiv a^{-1} \pmod p$, i.e. $b \equiv -a^{-1} \pmod p$. Substituting this into the output equation (before truncation) gives $o \equiv (-a^{-1})^{s_{n+1}} \equiv (-1)^{s_{n+1}} (a^{s_{n+1}})^{-1} \pmod p$. Since $a^{s_{n+1}} \bmod p$ is exactly the next state $s_{n+2}$, we get $o \equiv \pm (s_{n+2})^{-1} \pmod p$, or equivalently $s_{n+2} \equiv \pm o^{-1} \pmod p$. So once the full output is known, the next internal state follows from it directly, and every later output can be predicted.
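The identity is easy to sanity-check numerically. The sketch below uses the toy safe prime $p = 23$ (an assumption for illustration, not the challenge modulus) and the helper name `full_output_matches` is hypothetical. It needs Python 3.8+ for `pow(x, -1, p)`.

```python
p, q = 23, 11  # toy safe prime: p = 2q + 1
a, b = 2, q

# a * (-b) = 1 - p, which is 1 modulo p
assert (a * -b) % p == 1


def full_output_matches(s):
    """Check b^s == (-1)^s * (a^s)^(-1) (mod p) for one exponent s."""
    lhs = pow(b, s, p)
    sign = -1 if s % 2 else 1
    rhs = (sign * pow(pow(a, s, p), -1, p)) % p
    return lhs == rhs


# The relation holds for every exponent, not just the ones the PRNG visits.
all_ok = all(full_output_matches(s) for s in range(1, p - 1))
print(all_ok)
```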
Finally, we still need to recover the 16 truncated high bits of the output. That leaves only $2^{16} = 65536$ possibilities, so we can brute-force them and check each candidate state against the second output.
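The brute-force step can be sketched on a small instance. Everything below is a toy under stated assumptions: the 12-bit safe prime $p = 2879$, the seed 1234, the 8 kept bits, and the helper name `recover_state` are all chosen for illustration. With such a small modulus more than one candidate may survive the check, whereas a 1024-bit comparison makes false positives negligible.

```python
def recover_state(out1, out2, b, p, keep_bits):
    """Return every candidate for s_{n+2} consistent with two truncated outputs."""
    mask = (1 << keep_bits) - 1
    missing = p.bit_length() - keep_bits  # number of truncated high bits
    found = set()
    for guess in range(1 << missing):
        full = (guess << keep_bits) | out1
        if full == 0 or full >= p:
            continue  # not an invertible residue modulo p
        inv = pow(full, -1, p)
        # The sign is unknown, so try both +o^{-1} and -o^{-1}.
        for cand in (inv, (-inv) % p):
            if pow(b, cand, p) & mask == out2:
                found.add(cand)
    return found


# Simulate two outputs of the generator with toy parameters.
p, q = 2879, 1439  # toy safe prime (12 bits)
a, b = 2, q
keep_bits = 8
mask = (1 << keep_bits) - 1

s1 = pow(a, 1234, p)          # state after the first step (seed is hypothetical)
out1 = pow(b, s1, p) & mask   # first truncated output
s2 = pow(a, s1, p)            # state after the second step
out2 = pow(b, s2, p) & mask   # second truncated output

candidates = recover_state(out1, out2, b, p, keep_bits)
print(s2 in candidates)
```

The solution script does the same thing at full size, spreading the $2^{16}$ guesses over several processes.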
Solution script
from pwn import *
import multiprocessing
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad
# p from random_powers.py
p = 10850103099166047071520367708533912921292365462873655355082196283822802916279324461777317757320676515143805886859143391447139191952634561495685600717746893237201688818135469940378899370206984224326957101503205179052452733794736957891466434300200845471180922874825803848268956280610617869537754920453304412776538419
q = (p - 1) // 2
a = 2
b = q
r = process(['python3', 'random_powers.py'])
r.sendlineafter(b"a = ", str(a).encode())
r.sendlineafter(b"b = ", str(b).encode())
def get_random_number():
    """ Function to request a random number """
    r.sendlineafter(b"> ", b"1")
    return int(r.recvline().decode().strip())
# Get 2 random numbers we can use for the attack
out1 = get_random_number()
out2 = get_random_number()
# Some bits are removed, so we need to brute-force them, only 65536 possibilities
mask = (1 << 1024) - 1
def check_guess(args):
    """
    Worker function to test a guess for the missing top 16 bits.
    """
    guess, out1, out2 = args
    out1_full = (guess << 1024) | out1
    if out1_full >= p:
        return None
    candidate1 = pow(out1_full, -1, p)
    candidate2 = (-candidate1) % p
    # Verify by checking if it produces the correct second output
    for candidate in (candidate1, candidate2):
        if (pow(b, candidate, p) & mask) == out2:
            return candidate
    return None
# Use multiple processes to speed this up
guesses = [(i, out1, out2) for i in range(1 << 16)]
info(f"Starting brute-force with {len(guesses)} guesses using {multiprocessing.cpu_count()} processes...")
found_s2 = None
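pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
for result in pool.imap_unordered(check_guess, guesses, chunksize=2048):
    if result is not None:
        found_s2 = result
        pool.terminate()
        info(f"Found s2: {found_s2}")
        break
# Abort if no guess reproduced the second output
if found_s2 is None:
    raise SystemExit("No candidate state matched the second output")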
# Get the ciphertext
r.sendlineafter(b"> ", b"2")
ciphertext = bytes.fromhex(r.recvline().strip().decode())
info(f"Ciphertext: {ciphertext.hex()}")
# Predict the 3rd state and its output
s3 = pow(a, found_s2, p)
out3 = pow(b, s3, p) & mask
# Reconstruct the AES key
key = (out3 & ((1 << 128) - 1)).to_bytes(16, "big")
cipher = AES.new(key, AES.MODE_ECB)
# Decrypt the flag
flag = unpad(cipher.decrypt(ciphertext), 16)
info(f"Flag: {flag.decode()}")