Mersenne twister 算法介绍 概述 $Mersenne\ twister$ 算法是一种周期非常长的伪随机数算法,它和传统的 $LFSR$ 算法有一些相似之处,因而在泄露足够多的伪随机数后也很容易受到攻击。该算法的主体分为初始序列生成、旋转、输出三步。
初始序列生成 和 $LFSR$ 一样,我们需要初始化一个序列 $state$ ,这个序列的长度直接关系到生成器的周期。与之不同的是,这里我们只需要一个初始种子 $seed$ 和一个固定算法 $F_1$。
在 $MT19937$ 中,该算法被描述成如下:
1 2 3 4 5 6 7 8 9 def gen_state (seed ): state=[seed]+[0 ]*(n-1 ) for i in range (1 ,n): state[i] = a * (state[i-1 ] ^ (state[i-1 ] >> 30 )) + i state[i] &= 0xffffffff return state
通过这个算法,我们可以得到 $n$ 个32比特的数。
当然了,我们也可以人为地修改里面的参数来达到某些特定目的。
旋转算法 梅森旋转算法能够构造大周期伪随机数生成器的原因就在这里,不过此处并不想分析缘由(其实是自己也看不懂仍然只是给出代码形式。这一步类似于 $LFSR$ 中由前 $x$ 比特生成下一比特的过程,只不过这里的“比特”不是 0 或 1 ,而是一个数。
1 2 3 4 5 6 7 8 9 10 def twist (state ): for i in range (0 , n): y = (state[i] & highbit) | (state[(i + 1 ) % n] & lowbit) state[i] = (y>>1 ) ^ state[(i + o) % n] if y % 2 != 0 : state[i] = state[i] ^ mask
稍微解释一下这一步的操作。首先是 $state[i]$ 的最高位以及 $state[i+1]$ 的低31位合并为一个新的数 $y$ ,根据 $y$ 的奇偶性得到两种操作,若 $y$ 是偶数,则 $state[i]=(y>>1)\bigoplus state[i+o]$ 否则 $state[i]=(y>>1)\bigoplus state[i+o] \bigoplus mask$ 。
通过矩阵(以下默认所有矩阵是在GF(2)上进行的) ,我们可以使得格式更加统一 $$ Y=[y_{31},y_{30}…,y_0] ,\ y_i\in{0,1}\ \ y二进制数组成的向量 $$
$$ STATE_i\ \ state_i的二进制数向量 $$
$$ mask=[a_{31},a_{30}…,a_0] ,\ a_i\in{0,1} \ \ mask二进制数组成的向量 $$
$$ Y=YA $$
$$ STATE_i=Y+STATE_{i+o} $$
其中 $$ A=\left[ \begin{matrix} 0 & 1 & 0 & \cdots& 0\\ 0 & 0 & 1 &\cdots & 0\\ \vdots & \vdots & \vdots & \ddots & \vdots\\ 0 & 0 & 0 & \cdots & 1\\ a_{31} & a_{30} & a_{29} & \cdots & a_0\\ \end{matrix} \right]\\ $$ 可以看出,两种模式得到的结果是一样的。
输出算法 这很有趣,那我们来看看具体操作吧
1 2 3 4 5 6 7 8 9 10 11 def next (state,i ): x = state[i] x ^= x >> 11 x ^= (x << 7 ) & b x ^= (x << 15 ) & c x ^= x >> 18 if i==n-1 : twist(state) return x
这个也很简单,看得出来,输出结果是 $state[i]$ 映射。
当我们遍历完一组 $state$ 后,需要再次旋转获得一组新的 $state$ ,否则的话后面输出的会和之前的重复。
完整代码 下面我们通过定义一个 $MT19937$ 类来结束这一部分
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 class MT19937 : def __init__ (self,seed ): self.n = 624 self.o = 397 self.a = 0x6c078965 self.b = 2636928640 self.c = 4022730752 self.highbit = 0x80000000 self.lowbit = 0x7fffffff self.mask = 0x9908b0df self.flag = 0 self.state=[seed]+[0 ]*(self.n-1 ) self.gen_state() def gen_state (self ): for i in range (1 ,self.n): self.state[i] = self.a * (self.state[i-1 ] ^ (self.state[i-1 ] >> 30 )) + i self.state[i] &= 0xffffffff self.twist() def twist (self ): for i in range (0 , self.n): y = (self.state[i] & self.highbit) + (self.state[(i + 1 ) % (self.n)] & self.lowbit) self.state[i] = (y>>1 ) ^ self.state[(i + self.o) % (self.n)] if y % 2 != 0 : self.state[i] = self.state[i] ^ self.mask def Next (self ): tmp=self.state[self.flag] tmp ^= (tmp >> 11 ) tmp ^= (tmp << 7 ) & self.b tmp ^= (tmp << 15 ) & self.c tmp ^= (tmp >> 18 ) self.flag+=1 if self.flag==self.n: self.twist() self.flag=0 return tmp def getrandomint (self ): return self .Next()
很好,现在你已经获得了 $python$ 中 $random$ 库的秘密了
破解算法 破解状态 由输出的随机数破解当前状态其实就是求解 $Next$ 的逆函数。那么我们来看, $Next$ 函数的基本形式是$tmp=tmp\bigoplus(tmp >> x)\ and\ y$,如果我们把 $tmp$ 看成是它的二进制数向量,那么容易得到, $tmp>>x$ 实际上就是 $tmp$ 右乘矩阵 $X$ $$ X=\left[ \begin{matrix} 0 & \cdots & 0 & 1 & \cdots & 0\\ \vdots & \ddots & \vdots & \vdots & \ddots & \vdots\\ 0 & \cdots & 0 & 0 & \cdots & 1\\ \vdots & \ddots & \vdots & \vdots & \ddots & \vdots\\ 0 & \dots & 0 & 0 & \dots & 0\\ \end{matrix} \right]\\ $$ 同样地,对于左移(<<),我们可以构造 $X^{‘}$ $$ X^{‘}=\left[ \begin{matrix} 0 & \cdots & 0 & 0 & \cdots & 0\\ \vdots & \ddots & \vdots & \vdots & \ddots & \vdots\\ 1 & \cdots & 0 & 0 & \cdots & 0\\ \vdots & \ddots & \vdots & \vdots & \ddots & \vdots\\ 0 & \dots & 1 & 0 & \dots & 0\\ \end{matrix} \right]\\ $$ 再进一步, $tmp\bigoplus(tmp >> x)$ 就是 $tmp\times(I+X)$ 。另一方面,运算 $a\ and\ b$ 相当于 $a\times(b\times I)$ ,因此,整一步操作相当于 $tmp$ 右乘一个矩阵 $L=(I+X)\times(b\times I )$ ,而整个 $Next$ 函数相当于执行了这样的操作3次,可以理解为 $tmp$ 右乘一个更大的矩阵 $F$ 。那么我们能够从输出的随机数逆推得到 $state$ 的充要条件就是 $F$ 可逆。事实上,这是成立的。
最后,我们需要注意我们得到的随机数在输出时是否进行了旋转。换句话说,我们得到的 $n$ 个连续输出未必就是一组 $state$ 产生的,但是如果我们希望通过 $n$ 个连续输出得到 $seed$ ,那么它们最好是由同一组 $state$ 输出的。
下面是恢复 $state$ 的代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 d= b= c= prns=[] def get_I (d ): I=[] for i in range (d): temp=[0 for _ in range (d)] temp[i]=1 I.append(temp) return matrix(GF(2 ),I) def get_right (l ): temp=[] for i in range (l): temp.append([0 for _ in range (d)]) for i in range (d-l): k=[0 for _ in range (d)] k[i]=1 temp.append(k) return matrix(GF(2 ),temp).transpose() def get_left (l ): temp=[] for i in range (l): temp.append([0 for _ in range (d)]) for i in range (d-l): k=[0 for _ in range (d)] k[i]=1 temp.append(k) return matrix(GF(2 ),temp) def get_and (a ): a=bin (a)[2 :].zfill(d) m=[] for i in range (d): k=[0 for _ in range (d)] k[i]=int (a[i]) m.append(k) return matrix(GF(2 ),m) def bin_state (state ): new_state=[] for i in state: tmp=bin (i)[2 :].zfill(32 ) new_state.append([int (tmp[_]) for _ in range (len (tmp))]) return new_state def dec_state (state ): new_state=[] for i in state: tmp='0b' for j in i: tmp+=str (j) new_state.append((int (tmp,2 ))) return new_state U=get_right(11 ) S=get_left(7 ) T=get_left(15 ) L=get_right(18 ) B=get_and(b) C=get_and(c) I=get_I(d) F=(I+U)*(I+S*B)*(I+T*C)*(I+L) prns=bin_state(prns) state=[] for i in range (len (prns)): X=matrix(GF(2 ),prns[i]) O=F.solve_left(X) state.append(O[0 ]) cur_state=dec_state(state) print (cur_state)
恢复seed 恢复 $seed$ 相当于恢复初始 $state$ ,那么我们需要求解 $twist$ 的逆函数。前文已经提到过,由 $state$ 生成 $state^{\ast}$ 的公式为 $$ state_i^{\ast}=(y\bigoplus state_{i+o)}\times A\ \ 上标\ast表示这是覆盖以后的值 $$ 矩阵 $A$ 是可逆的,于是我们求出 $(y\bigoplus state_{i+o})$ ,现在如果我们已知 $state_{i+o}$ ,那么 $y$ 就是可以知道的,也就是 $state_i$ 的最高位和 $state_{i+1}$ 的低位。此时只需要相邻两个 $state^{\ast}$ 我们就可以推知完整的一个 $state$ 。
那么问题来了,$state_{i+o}$ 在执行完 $twist$ 函数后被覆盖,原来的 $state_{i+o}$ 也是需要我们复原的,而我们复原又需要用到它,这就陷入循环了。事实并非如此。因为存在一部分的 $state$ ,它们执行 $twist$ 函数时用到的 $state_{(i+o)}$ 是被覆盖后的值(它们已经通过上一步被我们知道了),对于这些 $state$ 我们是可以利用上面的方法复原的,而后我们再利用这些复原值来反推剩下的。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 n=233 o=130 mask=0x9908f23f cur_state=[] def get_A (mask ): mask=bin (mask)[2 :].zfill(32 ) M=[int (mask[_]) for _ in range (len (mask))] A=[[0 for _ in range (32 )] for __ in range (31 )] for i in range (31 ): A[i][i+1 ]=1 A.append(M) return matrix(GF(2 ),A) A=get_A(mask) cur_state=bin_state(cur_state) for i in range (n-1 ,-1 ,-1 ): tmp=matrix(GF(2 ),cur_state[i])+matrix(GF(2 ),cur_state[(i+o)%n]) tmp=tmp*A.inverse() s=[] s.append(tmp[0 ][0 ]) tmp=matrix(GF(2 ),cur_state[i-1 ])+matrix(GF(2 ),cur_state[(i-1 +o)%n]) tmp=tmp*A.inverse() for j in range (31 ): s.append(tmp[0 ][j+1 ]) cur_state[i]=s print (dec_state(cur_state))
对比我们的计算结果和标准答案,发现只有第一个( $seed$ )是不一样的。这就很有趣,后面所有数都一样唯独第一个不同,说明我们的算法没什么太大的问题,而且,根据这两个序列所生成的新序列以及之后的所有序列都是一样的(可以自己试一下)。也即是说,我们现在已经可以预测后面的结果,但是不能复原之前的随机数。
那么接下来我们看到由 $seed$ 生成第一序列的函数。
1 2 state[i] = a * (state[i-1 ] ^ (state[i-1 ] >> 30 )) + i state[i] &= 0xffffffff
这是一个模运算。$and\ 0xffffffff$ 实际上就是取后32位,即取模 $2^{32}$ 的余数。然后我们已知 $state[1]$ ,要求 $state[0]$ ,带入上式 $$ (state[0] \bigoplus (state[0] >> 30))=state[1]*a^{-1}\ mod\ 2^{32} $$ 而 $(state[0] \bigoplus (state[0] >> 30))$ 我们之前讨论过是 $state[0]\times L$ 的形式,矩阵 $L$ 也是容易求的,因此恢复 $seed$ 只需要在之前的基础上再进一步即可。
由此我们还得到了一种攻击方法,如果我们已知初时序列中某个 $state_i$ 的值和 $i$ 的值,那么我们可以推算出整个初始序列,只需要把下面的脚本改成循环就彳亍了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 d= a= state=[] a=a.inverse_mod(2 **32 ) state[0 ]=(state[1 ]*a)%(2 **32 ) state[0 ]=bin_state(state[0 ]) I=get_I(d) R=get_right(30 ) L=(I+R) s=matrix(GF(2 ),state[0 ]) s=s*L.inverse() state[0 ]=dec_state(state[0 ]) print (state)print (standard)
对比发现二者一致。说明我们已经完全复原了 $MT19937$ 。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 d=32 n=624 o=397 mask=0x9908b0df b=2636928640 c=4022730752 prns= def get_I (d ): I=[] for i in range (d): temp=[0 for _ in range (d)] temp[i]=1 I.append(temp) return matrix(GF(2 ),I) def get_right (l ): temp=[] for i in range (l): temp.append([0 for _ in range (d)]) for i in range (d-l): k=[0 for _ in range (d)] k[i]=1 temp.append(k) return matrix(GF(2 ),temp).transpose() def get_left (l ): temp=[] for i in range (l): temp.append([0 for _ in range (d)]) for i in range (d-l): k=[0 for _ in range (d)] k[i]=1 temp.append(k) return matrix(GF(2 ),temp) def get_and (a ): a=bin (a)[2 :].zfill(d) m=[] for i in range (d): k=[0 for _ in range (d)] k[i]=int (a[i]) m.append(k) return matrix(GF(2 ),m) def get_A (mask ): mask=bin (mask)[2 :].zfill(32 ) M=[int (mask[_]) for _ in range (len (mask))] A=[[0 for _ in range (32 )] for __ in range (31 )] for i in range (31 ): A[i][i+1 ]=1 A.append(M) return matrix(GF(2 ),A) def bin_state (state ): new_state=[] for i in state: tmp=bin (i)[2 :].zfill(32 ) new_state.append([int (tmp[_]) for _ in range (len (tmp))]) return new_state def dec_state (state ): new_state=[] for i in state: tmp='0b' for j in i: tmp+=str (j) new_state.append((int (tmp,2 ))) return new_state def last (tmp,i ): a=1812433253 n=2 **32 tmp-=i a=inverse_mod(a,n) tmp=(tmp*a)%n return (tmp>>30 )^^tmp U=get_right(11 ) S=get_left(7 ) T=get_left(15 ) L=get_right(18 ) B=get_and(b) C=get_and(c) I=get_I(d) A=get_A(mask) F=(I+U)*(I+S*B)*(I+T*C)*(I+L) prns=bin_state(prns) for i in range (len (prns)): X=matrix(GF(2 ),prns[i]) O=F.solve_left(X) rand[i]=O[0 ] cur_state=dec_state(prns) print ('current state : \n' ,cur_state)cur_state=bin_state(cur_state) for i in range (n-1 ,-1 ,-1 ): tmp=matrix(GF(2 ),cur_state[i])+matrix(GF(2 ),cur_state[(i+o)%n]) tmp=tmp*A.inverse() s=[] s.append(tmp[0 ][0 ]) tmp=matrix(GF(2 ),cur_state[i-1 ])+matrix(GF(2 ),cur_state[(i-1 +o)%n]) tmp=tmp*A.inverse() for j in range (31 ): s.append(tmp[0 ][j+1 ]) cur_state[i]=s state=dec_state(cur_state) print ('state before recover the seed : \n' ,state)for i in range (len (state)-1 ,0 ,-1 ): state[i-1 ]=last(state[i],i) print ('state with the seed recovered : \n' ,state)
$python$ 脚本复原 上面这么多都是用 $sagemath$ 写的,如果没有环境用起来比较难受,下面给出一种 $python$ 脚本。
$python$ 解决这个问题的思路没有用到矩阵,可能会复杂一点,这也是为什么没有第一时间写 $python$ 脚本处理,当然直接在 $python$ 里面引入矩阵运算也可以直接套用。
首先我们看到,从随机数恢复种子这一过程中我们要反复用到异或和位移的复合运算,所以我们需要解决的首要问题就是它。
例如我们复原 $Next$ 函数时用到了 $tmp\ast=tmp\bigoplus(tmp << x)\ and\ y$ ,除了使用代数方法,我们也可以尝试理解这一过程。
显然如果我们知道 $k=(tmp << x)\ and\ y$ 就会好做很多。于是我们看到 $k$ , $k$ 的低 $x$ 位显然为 0,因此 $tmp$ 和 $tmp\ast$ 的低 $x$ 位是相同的,我们不费吹灰之力复原了 $tmp$ 的低 $x$ 位 $l$ 。
进一步, $k$ 的低 $2x$ 位是 $l<<x\ and\ y$ ,里面的所有参数都是已知的。接下来我们只需要反复迭代这一过程即可。对于左移运算也是同理。
综合来看,每次操作能为我们恢复 $tmp$ 的 $x$ 位,要想完全恢复 $tmp$ ,我们至少需要操作 $(32/x)+1$ 次。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 def inverse_right (prn, shift, bits=32 ): tmp = prn for i in range (bits // shift+1 ): tmp = prn ^ tmp >> shift return tmp def inverse_right_mask (prn, shift, mask, bits=32 ): tmp = prn for i in range (bits // shift+1 ): tmp = prn ^ tmp >> shift & mask return tmp def inverse_left (prn, shift, bits=32 ): tmp = prn for i in range (bits // shift+1 ): tmp = prn ^ tmp << shift return tmp def inverse_left_mask (prn, shift, mask, bits=32 ): tmp = prn for i in range (bits // shift+1 ): tmp = prn ^ tmp << shift & mask return tmp def inverse_Next (prns ): for i in prns: i = inverse_right(i,18 ) i = inverse_left_mask(i,15c) i = inverse_left_mask(i,7 ,b) i = inverse_right(i,11 ) i=i&0xffffffff return prns
$ps$ :不知道为啥循环次数少一点也可以恢复。
然后是根据 $cur_state$ 恢复 $state_0$ 。
这一步和 $sagemath$ 的方法类似,唯一不同的是此处没有矩阵 $A$ 了。这里的问题在于如何确定我们的结果异或了 $mask$ ,而$mask$ 的比特长度是我们的突破口。
$twist$ 函数在之前被描述成 “首先是 $state[i]$ 的最高位以及 $state[i+1]$ 的低31位合并为一个新的数 $y$ ,根据 $y$ 的奇偶性得到两种操作,若 $y$ 是偶数,则 $state[i]=(y>>1)\bigoplus state[i+o]$ 否则 $state[i]=(y>>1)\bigoplus state[i+o] \bigoplus mask$ 。”
我们先对 $state[i]$ 异或上 $state[i+o]$ ,得到 $(y>>1)$ 或者 $(y>>1)\bigoplus mask$ ,而 $(y>>1)$ 的比特长度肯定是31。如果我们有 $mask$ 的比特长度为32,那么异或了 $mask$ 的结果一定也是32位,而另一种没有异或的一定只有31位。那么我们根据这个差别可以判断出旋转时是否用到了 $mask$ 。这样我们就做了矩阵 $A$ 的工作,只是没有它那么简洁。(如果 $mask$ 不是32位,那就阿巴阿巴了)
剩下的步骤就完全照搬之前的就行。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 def inverse_twist (cur_state ): high = 0x80000000 low = 0x7fffffff mask = for i in range (623 ,-1 ,-1 ): tmp = cur_state[i]^cur_state[(i+o)%n] if tmp & high == high: tmp ^= mask tmp <<= 1 tmp |= 1 else : tmp <<=1 res = tmp&high tmp = cur_state[i-1 ]^cur_state[(i+o)%n] if tmp & high == high: tmp ^= mask tmp <<= 1 tmp |= 1 else : tmp <<=1 res |= (tmp)&low cur_state[i] = res return cur_state
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 def inverse_right (prn, shift, bits=32 ): tmp = prn for i in range (bits // shift+1 ): tmp = prn ^ tmp >> shift return tmp def inverse_right_mask (prn, shift, mask, bits=32 ): tmp = prn for i in range (bits // shift+1 ): tmp = prn ^ tmp >> shift & mask return tmp def inverse_left (prn, shift, bits=32 ): tmp = prn for i in range (bits // shift+1 ): tmp = prn ^ tmp << shift return tmp def inverse_left_mask (prn, shift, mask, bits=32 ): tmp = prn for i in range (bits // shift+1 ): tmp = prn ^ tmp << shift & mask return tmp def inverse_Next (prns, b=2636928640 , c=4022730752 ): for i in range (len (prns)): prns[i] = inverse_right(prns[i], 18 ) prns[i] = inverse_left_mask(prns[i], 15 , c) prns[i] = inverse_left_mask(prns[i], 7 , b) prns[i] = inverse_right(prns[i], 11 ) prns[i] = prns[i] & 0xffffffff return prns def inverse_twist (cur_state,mask=0x9908b0df ,n=624 ,o=397 ): high = 0x80000000 low = 0x7fffffff for i in range (623 ,-1 ,-1 ): tmp = cur_state[i]^cur_state[(i+o)%n] if tmp & high == high: tmp ^= mask tmp <<= 1 tmp |= 1 else : tmp <<=1 res = tmp&high tmp = cur_state[i-1 ]^cur_state[(i-1 +o)%n] if tmp & high == high: tmp ^= mask tmp <<= 1 tmp |= 1 else : tmp <<=1 res |= (tmp)&low cur_state[i] = res return cur_state def recover_last (state,i,a=0x6c078965 ): n=2 **32 state-=i a=inverse(a,n) state=(state*a)%n return (state>>30 )^state def main (): prns=[] cur_state=inverse_Next(prns) print ("cur_state:\n" ,cur_state) state=inverse_twist(cur_state) print ("state without seed recovery:\n" ,state) seed=recover_last(state[1 ],1 ) print ("seed:\n" ,seed) main()