"""Pure-Python X25519 (RFC 7748, section 5) with the RFC's iterated self-test.

The conditional swap uses arithmetic masking in the style of the RFC, but
Python integer arithmetic is not constant-time, so treat this as a reference /
study implementation rather than something to protect real secrets with.
"""

p = 2**255 - 19  # field prime for Curve25519
_MASK255 = (1 << 255) - 1  # all-ones mask covering every value < 2**255
_A24 = 121665  # ladder constant a24 from RFC 7748 for curve25519


def leu(b):
    """Decode the little-endian byte sequence *b* into a non-negative int."""
    return int.from_bytes(b, "little")


def unleu(n, l):
    """Encode *n* as little-endian bytes covering *l* bits (ceil(l / 8) bytes).

    Bits of *n* above the last emitted byte are discarded.
    """
    return bytes((n >> i) & 0xFF for i in range(0, l, 8))


def decU(b):
    """Decode a 32-byte u-coordinate (RFC 7748 decodeUCoordinate).

    The most significant bit of the final byte (bit 255 of the integer) is
    ignored, as X25519 requires. The result is < 2**255 but may still be
    >= p; the ladder arithmetic reduces it modulo p.
    """
    b = bytes(b)
    return leu(b[:-1] + bytes([b[-1] & 0x7F]))


def encU(u):
    """Encode field element *u* (reduced mod p) as 32 little-endian bytes."""
    return unleu(u % p, 255)


def decScalar(k):
    """Decode and clamp a 32-byte X25519 scalar (RFC 7748 decodeScalar25519).

    Clears the low 3 bits and bit 255, and sets bit 254.
    """
    k = bytearray(k)
    k[0] &= 248
    k[31] &= 127
    k[31] |= 64
    return leu(k)


def mask(x):
    """Return a 255-bit all-ones mask if *x* is non-zero, otherwise 0.

    Computed arithmetically (-1 & mask) rather than by repeatedly smearing
    bits with shifts, which would build enormous intermediate integers.
    """
    return -int(x != 0) & _MASK255


def cswap(swap, x2, x3):
    """Return (x3, x2) if *swap* is non-zero, else (x2, x3), via XOR masking.

    Both values must be < 2**255 so the 255-bit mask covers them.
    """
    dummy = mask(swap) & (x2 ^ x3)
    return x2 ^ dummy, x3 ^ dummy


def x25519(k, u):
    """Compute X25519(k, u) with the Montgomery ladder from RFC 7748 section 5.

    Args:
        k: 32-byte scalar; clamped internally by decScalar.
        u: 32-byte little-endian u-coordinate.

    Returns:
        The 32-byte encoding of the resulting u-coordinate. For low-order
        inputs z2 ends up 0 and the output is all zeros, as in the RFC.
    """
    k = decScalar(k)
    x1 = decU(u)
    x2, z2 = 1, 0
    x3, z3 = x1, 1
    swap = 0
    # Bits 254 .. 0; bit 255 is always cleared by clamping.
    for t in range(254, -1, -1):
        k_t = (k >> t) & 1
        swap ^= k_t
        x2, x3 = cswap(swap, x2, x3)
        z2, z3 = cswap(swap, z2, z3)
        swap = k_t

        A = (x2 + z2) % p
        AA = A * A % p
        B = (x2 - z2) % p
        BB = B * B % p
        E = (AA - BB) % p
        C = (x3 + z3) % p
        D = (x3 - z3) % p
        DA = D * A % p
        CB = C * B % p
        x3 = (DA + CB) ** 2 % p
        z3 = x1 * (DA - CB) ** 2 % p
        x2 = AA * BB % p
        z2 = E * (AA + _A24 * E) % p

    x2, x3 = cswap(swap, x2, x3)
    z2, z3 = cswap(swap, z2, z3)
    # Fermat inversion: z2 ** (p - 2) is z2's inverse mod p, and it yields 0
    # for z2 == 0 (low-order input) instead of raising as pow(z2, -1, p) would.
    return encU(x2 * pow(z2, p - 2, p) % p)


if __name__ == "__main__":
    # RFC 7748 section 5.2 iterated test: start with k = u = 9, then repeatedly
    # set k to X25519(k, u) and u to the previous value of k.
    scalar = bytes.fromhex("09" + "00" * 31)
    ucoord = scalar
    for _ in range(1000):
        scalar, ucoord = x25519(scalar, ucoord), scalar
    print(scalar.hex() == "684cf59ba83309552800ef566f2f4d3c1c3887c49360e3875f2eb94d99532c51")