from pwn import *
def findall(bytes: list[int], byte: int):
    """Return the positions of every occurrence of *byte* in *bytes*.

    Positions are returned in ascending order; an empty list means *byte*
    does not occur. The parameter name ``bytes`` shadows the builtin inside
    this function and is kept only for caller compatibility.
    """
    return [position for position, value in enumerate(bytes) if value == byte]
# 256-entry byte table (16 rows of 16). The check-table pass later in this
# script indexes it by (address read from the trace - checkTableAddr) and
# asserts that the entry at that offset equals the byte the trace shows
# being loaded.
# NOTE(review): presumably a dump of the target's memory starting at
# checkTableAddr (0x8191) — confirm against the binary.
checkTable = [
0x00, 0x00, 0x00, 0x00, 0x00, 0xa8, 0x00, 0x64, 0x00, 0x00, 0xcd, 0x00, 0x00, 0x00, 0x80, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00,
0xf1, 0x00, 0x00, 0x00, 0x63, 0x00, 0x11, 0x20, 0xa2, 0x29, 0xd7, 0x00, 0x00, 0x49, 0x00, 0xf1,
0x00, 0x46, 0x00, 0x00, 0x00, 0x00, 0x00, 0x8b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x87, 0x00, 0x00, 0x00, 0x00, 0xf8, 0x49, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x00, 0x4f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xa8, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x20, 0x00, 0x07, 0x6e, 0x00, 0x00, 0xa7, 0x00, 0x00, 0x00, 0x00, 0xc2, 0x3c, 0x00, 0x00,
0x57, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xc4, 0x00, 0x00, 0x00, 0x00, 0x00,
0x94, 0x00, 0x00, 0x00, 0x00, 0x00, 0x25, 0x00, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x9c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x95, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x22, 0x00, 0x29, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x5a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x41,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x9f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x86, 0x64,
0x00, 0x48, 0x00, 0x00, 0x00, 0xeb, 0x00, 0x00, 0x00, 0x1f, 0x00, 0xed, 0x00, 0x00, 0x00, 0x00,
]
# Load the execution trace, one entry per line. The context manager closes
# the file handle deterministically instead of leaving it to the GC.
with open("clockstep.txt") as trace_file:
    trace = trace_file.read().splitlines()

# Addresses matched against the trace text below.
shufLoop = 0x801b
shufSkip = 0x8064

# Recover one bit per visit to shufLoop: after a line mentioning shufLoop's
# address (4 lowercase hex digits), find the next line containing "D:f0",
# then parse the 4 hex digits at offset 2 of the line two after it. If that
# value equals shufSkip the bit is 0, otherwise 1.
# NOTE(review): offset 2 presumably holds the next PC in this trace format
# — confirm against clockstep.txt.
i = 0
bits = ""
while i < len(trace):
    line = trace[i]
    i += 1
    if f"{shufLoop:04x}" not in line:
        continue
    while i < len(trace):
        line = trace[i]
        i += 1
        if "D:f0" not in line:
            continue
        # Debug output: position plus the two lines after the "D:f0" line.
        print(i, trace[i], trace[i + 1])
        dest = int(trace[i + 1][2:6], 16)
        i += 1
        bits += "0" if dest == shufSkip else "1"
        break

# Exactly 0x30 bytes' worth of bits is expected. Check before decoding so a
# truncated or over-long trace fails with a clear message.
assert len(bits) == 0x30 * 8, f"expected {0x30 * 8} bits, got {len(bits)}"
# Pack each 8-bit group (MSB first) into one byte with pwntools' p8.
secure = b"".join(p8(int(bits[k:k + 8], 2)) for k in range(0, len(bits), 8))
print(secure)
# Collect the bytes seen at the checkXor site. For every trace line that
# mentions checkXor's address (as 4 lowercase hex digits), the two hex
# digits after "D:" on the following line are taken as one byte.
# NOTE(review): `input` shadows the builtin. The name is kept because the
# check-table pass later in this script reads it under that name.
checkXor = 0x8073
xor_marker = f"{checkXor:04x}"
i = 0
input = b""
while i < len(trace):
    matched = xor_marker in trace[i]
    i += 1
    if matched:
        data_line = trace[i]
        input += p8(int(data_line.split("D:")[1][:2], 16))
        # Step past the data line so it is not rescanned for the marker.
        i += 1
print(input)
# Recover the secure bytes from the check-table lookups. For every trace line
# mentioning checkLoad's address (4 lowercase hex digits):
#   - the line 3 after it gives, at offset 2, the address read; minus
#     checkTableAddr this is the table offset;
#   - the line 7 after it gives the loaded value after "D:".
# The offset must agree with checkTable. The output byte is the offset XORed
# with the next byte of `input` (collected at the checkXor site earlier).
secure = b""
checkLoad = 0x8075
checkTableAddr = 0x8191
i = 0
j = 0
while i < len(trace):
    line = trace[i]
    i += 1
    if f"{checkLoad:04x}" not in line:
        continue
    index = int(trace[i + 2][2:6], 16) - checkTableAddr
    assert 0 <= index < len(checkTable), f"table offset {index} out of range"
    # Debug output: the address line and the loaded-value line.
    print(trace[i + 2], trace[i + 6])
    check = int(trace[i + 6].split("D:")[1][:2], 16)
    i += 2
    # A direct lookup is equivalent to `index in findall(...)` once index is
    # in range. The full candidate list is built only for the failure message.
    assert checkTable[index] == check, (
        f"checkTable[{index:#04x}] != {check:#04x}; "
        f"value found at offsets {findall(checkTable, check)}"
    )
    secure += p8(index ^ input[j])
    j += 1
print(secure)