RNN 源码解析(tensorflow)
本文将基于tensorflow r1.5的源码,从组成RNN的细胞开始,到如何构建一个RNN,学习这整个的流程。
rnn_cell的接口声明
- tensorflow在tensorflow/python/ops/rnn_cell.py中声明了一个RNNCell基类和几个基本的RNN cell类,这些类的实现在tensorflow/python/ops/rnn_cell_impl.py中。接下来,我们具体来看看这些类的实现。
RNNCell类
RNNCell类继承了base_layer.Layer类,是RNN cell的一个抽象类的表示。
__call__()方法会调用父类base_layer.Layer类的__call__()方法,父类base_layer.Layer类的__call__()方法会调用call方法,所以所有继承了RNNCell的类,只需要重写自身的call方法,就可以实现__call__()的调用。
__call__()的输入主要是inputs, state, 输出是outputs, new_sate,通过调用这个方法,实现序列采样的前进。
下面介绍几种常见的RNN cell的具体实现。
BasicRNNCell类
BasicRNNCell类是RNNCell类最基本的实现。
构造方法中传入参数主要是num_units(神经元个数)。
该类的call()方法的具体实现如下:
1 | def call(self, inputs, state): |
通过方法的注释可以直观的看到,每次调用call()方法时,也就是序列采样时,output = new_state = act(W input + U state + B).在计算时,将input和state串联起来一起计算,节省时间。
- 这里要注意的是,inputs的shape是(batch_size, seq_length), output和state是同样的, shape为(batch_size, num_units), 参数W和U串联在一起的self._kernel的shape是(seq_length + num_units, num_units).
BasicLSTMCell类
BasicLSTMCell类是LSTM cell的基本实现。
构造方法中传入参数主要是num_units(神经元个数)。
该类的call()方法的具体实现如下:
1 | def call(self, inputs, state): |
可以看到,函数输入中的state是一个tuple, 这是因为LSTM网络比一般的RNN网络多出了一个cell_state.
接着计算gate_inputs, 也就是4个门分别的输入,这里是把4个门的输入的参数串联到一个参数矩阵,一次计算得到一个输入矩阵,包含了4个门的分别的不同的输入,然后把矩阵切分开来,得到了i, j, f, o这4个门的输入。
根据公式计算c和h,
c = forget_gate * c + input_gate * activation(input),
h = activation(c) * output_gate.由于BasicLSTMCell不包含投影层,所以output就是h, new_state是(c, h).
LSTMCell类
与BasicLSTMCell相比,LSTMCell类提供了可选的peep-hole链接、cell clipping以及投影层。
该类的call()方法的具体实现如下:
1 | def call(self, inputs, state): |
和BasicLSTMCell类似,这里同样是获取state,其中包含了c_prev和m_prev, 分别表示了上一步中的cell和hidden的状态。
接着计算4个门的输入,和BasicLSTMCell类似,不同点是,在这个地方,为4个门的输入张量加上了一个bias。
下面是计算c和h,如果配置了使用peepholes connections,那么会在计算forget, input和output这三个门的输入上,分别串联前一个时刻的c. 如果配置了num_proj,那么会将m(hidden state)投影到配置的num_proj的维度上。
构建RNN的几种方法
现在我们已经对RNN的cell有所了解了,那么该如何创建一个完整的RNN呢?
我们已经知道,对RNN的cell的call()函数的循环调用,可以实现序列的采样,那么最直观的想法就是我们构建一个循环,不停的调用call()函数, 这样就得到了一个RNN。基本思想如下:
1 | state = cell.zero_state(...) |
在tensorflow中,已经有一些帮助方法,可以帮我们快速的构建一个RNN,比如:
tf.nn.static_rnn()
tf.nn.dynamic_rnn()
tf.nn.bidirectional_dynamic_rnn()
tf.nn.raw_rnn()
下面我们来一一学习这些帮助方法。
tf.nn.static_rnn()
tf.nn.static_rnn()的主要输入参数有5个:
cell是RNN指定的细胞;
inputs是输入,类型必须是collections.Sequence(除了string), shape为(B, T, D), 如果只有一个序列,那么可以为1维,即(D);
initial_state是给定的RNN的其实状态,如不指定,则以RNNCell.zero_state(…)来初始化,且必须要指定dtype;
sequence_length是序列的长度,如果指定了序列的长度,那么会启动动态计算,即在一个batch内,超过sequence_length的部分将不予以计算。
函数内部,首先取出inputs中第一个输入,然后检查所有的batch的size是否一致,最后得到batch_size。
tf.nn.dynamic_rnn()
tf.nn.dynamic_rnn()与tf.nn.static_rnn()有些区别,首先在图的定义上,dynamic_rnn()构造的是一个可以循环执行的图,序列的长度表示为循环的次数;而static_rnn()构造的是一个RNN的展开图,所以对于不同的batch,dynamic_rnn()可以允许不同的batch内有不同的序列长度,而statci_rnn()由于展开图的长度就是序列的长度,所以在不同的batch内必须要有相同的长度。后续会写一篇文章来详细的说明tf.nn.static_rnn()和tf.nn.dynamic_rnn()的异同及性能。
一般推荐使用dynamic_rnn()取代static_rnn().
tf.nn.bidirectional_dynamic_rnn()
bidirectional_dynamic_rnn()是dynamic_rnn()的双向版本。
为什么要有双向的RNN呢?本人认为这是RNN的结构造成的,RNN与一般的全连接神经网络的一个主要的不同是,RNN可以在不同的输入中共享参数,但是RNN是序列结构,所以只能是后面的输入共享前面的输入的信息,前面的输入不能得到后面的信息,所以的在实际的应用中(如machine translation),序列的采样仅仅从左到右或者从右到左都是会丢失一些信息,这时候就需要双向的RNN,tensorflow提供了这个帮助方法tf.nn.bidirectional_dynamic_rnn()。
序列前向输入RNN的结果为(output_fw, output_state_fw),序列后向输入RNN的结果为(output_fw, output_state_fw), 总的输出为((output_fw, output_state_fw), (output_fw, output_state_fw)).
tf.nn.raw_rnn()
tf.nn.raw_rnn()是tf.nn.dynamic_rnn()更底层的函数。dynamic_rnn()有一些限制,tf.nn.raw_rnn()可以提供更底层的控制,以便于实现一些特殊的需求,如seq2seq模型的解码等等。
函数的输入中,比较重要的是loop_fn, loop_fn也是一个函数,输入为(time, cell_output, cell_state, loop_state), 输出为(finished, next_input, next_cell_state, emit_output, next_loop_state).
loop_fn会在循环每一次结束之后调用,使我们更多的控制循环的进行,如更改循环的输出,下一个循环的输入,细胞的状态。
结束
至此,你已经大体的了解了tensorflow中RNN的具体实现。