RNN 源码解析(tensorflow)

本文将基于tensorflow r1.5的源码,从组成RNN的细胞开始,到如何构建一个RNN,学习这整个的流程。

rnn_cell的接口声明

RNNCell类

RNNCell类继承了base_layer.Layer类,是RNN cell的一个抽象类的表示。

__call__()方法会调用父类base_layer.Layer类的__call__()方法,父类base_layer.Layer类的__call__()方法会调用call方法,所以所有继承了RNNCell的类,只需要重写自身的call方法,就可以实现__call__()的调用。

  • __call__()的输入主要是inputs, state, 输出是outputs, new_sate,通过调用这个方法,实现序列采样的前进。

  • 下面介绍几种常见的RNN cell的具体实现。

BasicRNNCell类

  • BasicRNNCell类是RNNCell类最基本的实现。

  • 构造方法中传入参数主要是num_units(神经元个数)。

  • 该类的call()方法的具体实现如下:

1
2
3
4
5
6
7
8
def call(self, inputs, state):
"""Most basic RNN: output = new_state = act(W * input + U * state + B)."""

gate_inputs = math_ops.matmul(
array_ops.concat([inputs, state], 1), self._kernel)
gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
output = self._activation(gate_inputs)
return output, output

通过方法的注释可以直观的看到,每次调用call()方法时,也就是序列采样时,output = new_state = act(W input + U state + B).在计算时,将input和state串联起来一起计算,节省时间。

  • 这里要注意的是,inputs的shape是(batch_size, seq_length), output和state是同样的, shape为(batch_size, num_units), 参数W和U串联在一起的self._kernel的shape是(seq_length + num_units, num_units).

BasicLSTMCell类

  • BasicLSTMCell类是LSTM cell的基本实现。

  • 构造方法中传入参数主要是num_units(神经元个数)。

  • 该类的call()方法的具体实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def call(self, inputs, state):
"""Long short-term memory cell (LSTM).

Args:
inputs: `2-D` tensor with shape `[batch_size, input_size]`.
state: An `LSTMStateTuple` of state tensors, each shaped
`[batch_size, self.state_size]`, if `state_is_tuple` has been set to
`True`. Otherwise, a `Tensor` shaped
`[batch_size, 2 * self.state_size]`.

Returns:
A pair containing the new hidden state, and the new state (either a
`LSTMStateTuple` or a concatenated state, depending on
`state_is_tuple`).
"""
sigmoid = math_ops.sigmoid
one = constant_op.constant(1, dtype=dtypes.int32)
# Parameters of gates are concatenated into one multiply for efficiency.
if self._state_is_tuple:
c, h = state
else:
c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)

gate_inputs = math_ops.matmul(
array_ops.concat([inputs, h], 1), self._kernel)
gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)

# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(
value=gate_inputs, num_or_size_splits=4, axis=one)

forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
# Note that using `add` and `multiply` instead of `+` and `*` gives a
# performance improvement. So using those at the cost of readability.
add = math_ops.add
multiply = math_ops.multiply
new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),
multiply(sigmoid(i), self._activation(j)))
new_h = multiply(self._activation(new_c), sigmoid(o))

if self._state_is_tuple:
new_state = LSTMStateTuple(new_c, new_h)
else:
new_state = array_ops.concat([new_c, new_h], 1)
return new_h, new_state
  • 可以看到,函数输入中的state是一个tuple, 这是因为LSTM网络比一般的RNN网络多出了一个cell_state.

  • 接着计算gate_inputs, 也就是4个门分别的输入,这里是把4个门的输入的参数串联到一个参数矩阵,一次计算得到一个输入矩阵,包含了4个门的分别的不同的输入,然后把矩阵切分开来,得到了i, j, f, o这4个门的输入。

  • 根据公式计算c和h,
    c = forget_gate * c + input_gate * activation(input),
    h = activation(c) * output_gate.

  • 由于BasicLSTMCell不包含投影层,所以output就是h, new_state是(c, h).

LSTMCell类

  • 与BasicLSTMCell相比,LSTMCell类提供了可选的peep-hole链接、cell clipping以及投影层。

  • 该类的call()方法的具体实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def call(self, inputs, state):
"""Run one step of LSTM.

Args:
inputs: input Tensor, 2D, `[batch, num_units].
state: if `state_is_tuple` is False, this must be a state Tensor,
`2-D, [batch, state_size]`. If `state_is_tuple` is True, this must be a
tuple of state Tensors, both `2-D`, with column sizes `c_state` and
`m_state`.

Returns:
A tuple containing:

- A `2-D, [batch, output_dim]`, Tensor representing the output of the
LSTM after reading `inputs` when previous state was `state`.
Here output_dim is:
num_proj if num_proj was set,
num_units otherwise.
- Tensor(s) representing the new state of LSTM after reading `inputs` when
the previous state was `state`. Same type and shape(s) as `state`.

Raises:
ValueError: If input size cannot be inferred from inputs via
static shape inference.
"""
num_proj = self._num_units if self._num_proj is None else self._num_proj
sigmoid = math_ops.sigmoid

if self._state_is_tuple:
(c_prev, m_prev) = state
else:
c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

input_size = inputs.get_shape().with_rank(2)[1]
if input_size.value is None:
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

# i = input_gate, j = new_input, f = forget_gate, o = output_gate
lstm_matrix = math_ops.matmul(
array_ops.concat([inputs, m_prev], 1), self._kernel)
lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias)

i, j, f, o = array_ops.split(
value=lstm_matrix, num_or_size_splits=4, axis=1)
# Diagonal connections
if self._use_peepholes:
c = (sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
else:
c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
self._activation(j))

if self._cell_clip is not None:
# pylint: disable=invalid-unary-operand-type
c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
# pylint: enable=invalid-unary-operand-type
if self._use_peepholes:
m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
else:
m = sigmoid(o) * self._activation(c)

if self._num_proj is not None:
m = math_ops.matmul(m, self._proj_kernel)

if self._proj_clip is not None:
# pylint: disable=invalid-unary-operand-type
m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
# pylint: enable=invalid-unary-operand-type

new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else
array_ops.concat([c, m], 1))
return m, new_state
  • 和BasicLSTMCell类似,这里同样是获取state,其中包含了c_prev和m_prev, 分别表示了上一步中的cell和hidden的状态。

  • 接着计算4个门的输入,和BasicLSTMCell类似,不同点是,在这个地方,为4个门的输入张量加上了一个bias。

  • 下面是计算c和h,如果配置了使用peepholes connections,那么会在计算forget, input和output这三个门的输入上,分别串联前一个时刻的c. 如果配置了num_proj,那么会将m(hidden state)投影到配置的num_proj的维度上。

构建RNN的几种方法

  • 现在我们已经对RNN的cell有所了解了,那么该如何创建一个完整的RNN呢?

  • 我们已经知道,对RNN的cell的call()函数的循环调用,可以实现序列的采样,那么最直观的想法就是我们构建一个循环,不停的调用call()函数, 这样就得到了一个RNN。基本思想如下:

1
2
3
4
5
6
state = cell.zero_state(...)
outputs = []
for input_ in inputs:
output, state = cell(input_, state)
outputs.append(output)
return (outputs, state)
  • 在tensorflow中,已经有一些帮助方法,可以帮我们快速的构建一个RNN,比如:

  • tf.nn.static_rnn()

  • tf.nn.dynamic_rnn()

  • tf.nn.bidirectional_dynamic_rnn()

  • tf.nn.raw_rnn()

下面我们来一一学习这些帮助方法。

tf.nn.static_rnn()

tf.nn.static_rnn()的主要输入参数有5个:
cell是RNN指定的细胞;

inputs是输入,类型必须是collections.Sequence(除了string), shape为(B, T, D), 如果只有一个序列,那么可以为1维,即(D);

initial_state是给定的RNN的其实状态,如不指定,则以RNNCell.zero_state(…)来初始化,且必须要指定dtype;

sequence_length是序列的长度,如果指定了序列的长度,那么会启动动态计算,即在一个batch内,超过sequence_length的部分将不予以计算。

函数内部,首先取出inputs中第一个输入,然后检查所有的batch的size是否一致,最后得到batch_size。

tf.nn.dynamic_rnn()

tf.nn.dynamic_rnn()与tf.nn.static_rnn()有些区别,首先在图的定义上,dynamic_rnn()构造的是一个可以循环执行的图,序列的长度表示为循环的次数;而static_rnn()构造的是一个RNN的展开图,所以对于不同的batch,dynamic_rnn()可以允许不同的batch内有不同的序列长度,而statci_rnn()由于展开图的长度就是序列的长度,所以在不同的batch内必须要有相同的长度。后续会写一篇文章来详细的说明tf.nn.static_rnn()和tf.nn.dynamic_rnn()的异同及性能。

一般推荐使用dynamic_rnn()取代static_rnn().

tf.nn.bidirectional_dynamic_rnn()

bidirectional_dynamic_rnn()是dynamic_rnn()的双向版本。

为什么要有双向的RNN呢?本人认为这是RNN的结构造成的,RNN与一般的全连接神经网络的一个主要的不同是,RNN可以在不同的输入中共享参数,但是RNN是序列结构,所以只能是后面的输入共享前面的输入的信息,前面的输入不能得到后面的信息,所以的在实际的应用中(如machine translation),序列的采样仅仅从左到右或者从右到左都是会丢失一些信息,这时候就需要双向的RNN,tensorflow提供了这个帮助方法tf.nn.bidirectional_dynamic_rnn()。

序列前向输入RNN的结果为(output_fw, output_state_fw),序列后向输入RNN的结果为(output_fw, output_state_fw), 总的输出为((output_fw, output_state_fw), (output_fw, output_state_fw)).

tf.nn.raw_rnn()

tf.nn.raw_rnn()是tf.nn.dynamic_rnn()更底层的函数。dynamic_rnn()有一些限制,tf.nn.raw_rnn()可以提供更底层的控制,以便于实现一些特殊的需求,如seq2seq模型的解码等等。

函数的输入中,比较重要的是loop_fn, loop_fn也是一个函数,输入为(time, cell_output, cell_state, loop_state), 输出为(finished, next_input, next_cell_state, emit_output, next_loop_state).

loop_fn会在循环每一次结束之后调用,使我们更多的控制循环的进行,如更改循环的输出,下一个循环的输入,细胞的状态。

结束

至此,你已经大体的了解了tensorflow中RNN的具体实现。