我們現在準備好從頭開始實現 RNN。特別是,我們將把此 RNN 訓練為字符級語言模型(參見第 9.4 節),並按照第 9.2 節中概述的數據處理步驟,在由 H. G. Wells 的《時間機器》全文組成的語料庫上對其進行訓練。我們首先加載數據集。
%matplotlib inline
import math
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
%matplotlib inline
import math
import tensorflow as tf
from d2l import tensorflow as d2l
9.5.1. 循環神經網絡模型
我們首先定義一個類來實現 RNN 模型(第 9.4.2 節)。請注意,隱藏單元的數量 num_hiddens 是一個可調的超參數。
class RNNScratch(d2l.Module):  #@save
    """Recurrent layer whose weights are created by hand (PyTorch variant).

    Args:
        num_inputs: Row count of the input-to-hidden weight ``W_xh``.
        num_hiddens: Number of hidden units; sets the width of ``W_xh``,
            both sides of ``W_hh`` and the length of the bias ``b_h``.
        sigma: Factor the standard-normal weight samples are scaled by.
    """

    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        # NOTE(review): presumably records the constructor arguments on
        # ``self`` (``forward`` reads ``self.num_hiddens``, which is never
        # assigned explicitly here) -- confirm against the d2l helper.
        # Called before any extra local is introduced.
        self.save_hyperparameters()

        # Draw the two weight matrices in the same order as before so the
        # random stream is consumed identically.
        input_to_hidden = sigma * torch.randn(num_inputs, num_hiddens)
        hidden_to_hidden = sigma * torch.randn(num_hiddens, num_hiddens)

        self.W_xh = nn.Parameter(input_to_hidden)
        self.W_hh = nn.Parameter(hidden_to_hidden)
        self.b_h = nn.Parameter(torch.zeros(num_hiddens))
class RNNScratch(d2l.Module):  #@save
    """Recurrent layer whose weights are created by hand (``np`` variant).

    NOTE(review): the import that binds ``np`` is not visible in this part
    of the file; the matching ``forward`` passes ``ctx=`` to ``np.zeros``,
    so this is presumably MXNet's NumPy-compatible module -- confirm.

    Args:
        num_inputs: Row count of the input-to-hidden weight ``W_xh``.
        num_hiddens: Number of hidden units; sets the width of ``W_xh``,
            both sides of ``W_hh`` and the length of the bias ``b_h``.
        sigma: Factor the standard-normal weight samples are scaled by.
    """

    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        # Called before any extra local is introduced.
        self.save_hyperparameters()

        # Same draw order as before: input-to-hidden first, then
        # hidden-to-hidden.
        input_to_hidden = np.random.randn(num_inputs, num_hiddens)
        hidden_to_hidden = np.random.randn(num_hiddens, num_hiddens)

        self.W_xh = input_to_hidden * sigma
        self.W_hh = hidden_to_hidden * sigma
        self.b_h = np.zeros(num_hiddens)
class RNNScratch(nn.Module):  #@save
    """The RNN model implemented from scratch (declarative-field variant).

    NOTE(review): the import that binds ``nn`` for this variant is not
    visible in this part of the file; the ``setup``/``self.param`` usage
    looks like Flax's ``linen`` API -- confirm.

    Attributes are declared as class-level fields:
        num_inputs: Row count of the input-to-hidden weight ``W_xh``.
        num_hiddens: Number of hidden units.
        sigma: Value passed to the normal initializer of both weights.
    """
    num_inputs: int
    num_hiddens: int
    sigma: float = 0.01

    def setup(self):
        """Register the three parameters ``W_xh``, ``W_hh`` and ``b_h``."""
        weight_init = nn.initializers.normal(self.sigma)
        self.W_xh = self.param('W_xh', weight_init,
                               (self.num_inputs, self.num_hiddens))
        self.W_hh = self.param('W_hh', weight_init,
                               (self.num_hiddens, self.num_hiddens))
        # The shape must be a real 1-tuple: ``(n)`` is just the int ``n``,
        # which only worked because the initializer tolerates an int shape.
        self.b_h = self.param('b_h', nn.initializers.zeros,
                              (self.num_hiddens,))
class RNNScratch(d2l.Module): #@save
"""The RNN model implemented from scratch."""
def __init__(self, num_inputs, num_hiddens, sigma=0.01):
super().__init__()
self.save_hyperparameters()
self.W_xh = tf.Variable(tf.random.normal(
(num_inputs, num_hiddens)) * sigma)
self.W_hh = tf.Variable(tf.random.normal(
(num_hiddens, num_hiddens)) * sigma)
self.b_h = tf.Variable(tf.zeros(num_hiddens))
下面的 forward 方法定義了在給定當前輸入和模型前一個時間步狀態的情況下,如何計算任意時間步的輸出和隱藏狀態。請注意,RNN 模型沿 inputs 的最外層維度循環遍歷,每次更新一個時間步的隱藏狀態。這裡的模型使用了 tanh 激活函數(第 5.1.2.3 節)。
@d2l.add_to_class(RNNScratch)  #@save
def forward(self, inputs, state=None):
    """Run the recurrence over ``inputs``, one slice of axis 0 at a time.

    Args:
        inputs: Tensor iterated along its first axis; documented by the
            original code as (num_steps, batch_size, num_inputs).
        state: ``None`` to start from an all-zero hidden state, or a
            one-element iterable holding the previous hidden state.

    Returns:
        A pair ``(outputs, state)``: the list of hidden states produced
        after each step, and the last hidden state.
    """
    if state is not None:
        # Exactly one element is expected; anything else raises.
        (hidden,) = state
    else:
        # Zero state of shape (batch_size, num_hiddens) on the input's device.
        batch_size = inputs.shape[1]
        hidden = torch.zeros((batch_size, self.num_hiddens),
                             device=inputs.device)

    outputs = []
    for X in inputs:
        from_input = torch.matmul(X, self.W_xh)
        from_hidden = torch.matmul(hidden, self.W_hh)
        hidden = torch.tanh(from_input + from_hidden + self.b_h)
        outputs.append(hidden)
    return outputs, hidden
@d2l.add_to_class(RNNScratch)  #@save
def forward(self, inputs, state=None):
    """Run the recurrence over ``inputs``, one slice of axis 0 at a time.

    Args:
        inputs: Array iterated along its first axis; documented by the
            original code as (num_steps, batch_size, num_inputs).
        state: ``None`` to start from an all-zero hidden state, or a
            one-element iterable holding the previous hidden state.

    Returns:
        A pair ``(outputs, state)``: the list of hidden states produced
        after each step, and the last hidden state.
    """
    if state is not None:
        # Exactly one element is expected; anything else raises.
        (hidden,) = state
    else:
        # Zero state of shape (batch_size, num_hiddens), created on the
        # same context as the inputs.
        batch_size = inputs.shape[1]
        hidden = np.zeros((batch_size, self.num_hiddens), ctx=inputs.ctx)

    outputs = []
    for X in inputs:
        from_input = np.dot(X, self.W_xh)
        from_hidden = np.dot(hidden, self.W_hh)
        hidden = np.tanh(from_input + from_hidden + self.b_h)
        outputs.append(hidden)
    return outputs, hidden
@d2l.add_to_class(RNNScratch)  #@save
def __call__(self, inputs, state=None):
    """Run the recurrence over ``inputs``, one slice of axis 0 at a time.

    Args:
        inputs: Array iterated along its first axis; documented by the
            original code as (num_steps, batch_size, num_inputs).
        state: ``None`` for no carried state, or a one-element iterable
            holding the previous hidden state.

    Returns:
        A pair ``(outputs, state)``: the list of hidden states produced
        after each step, and the last hidden state (still ``None`` when
        no state was given and ``inputs`` yields nothing).
    """
    if state is None:
        hidden = None
    else:
        # Exactly one element is expected; anything else raises.
        (hidden,) = state

    outputs = []
    for X in inputs:
        from_input = jnp.matmul(X, self.W_xh)
        # Without a carried state the recurrent term is the scalar 0
        # instead of a matrix product.
        if hidden is None:
            from_hidden = 0
        else:
            from_hidden = jnp.matmul(hidden, self.W_hh)
        hidden = jnp.tanh(from_input + from_hidden + self.b_h)
        outputs.append(hidden)
    return outputs, hidden
@d2l.add_to_class(RNNScratch)  #@save
def forward(self, inputs, state=None):
    """Run the recurrence over ``inputs``, one slice of axis 0 at a time.

    Args:
        inputs: Tensor iterated along its first axis; documented by the
            original code as (num_steps, batch_size, num_inputs).
        state: ``None`` to start from an all-zero hidden state, or a
            one-element iterable holding the previous hidden state.

    Returns:
        A pair ``(outputs, state)``: the list of hidden states produced
        after each step, and the last hidden state.
    """
    if state is not None:
        # Exactly one element is expected; anything else raises.
        (hidden,) = state
        # A supplied state is flattened to two axes with num_hiddens columns.
        hidden = tf.reshape(hidden, (-1, self.num_hiddens))
    else:
        # Zero state of shape (batch_size, num_hiddens).
        batch_size = inputs.shape[1]
        hidden = tf.zeros((batch_size, self.num_hiddens))

    outputs = []
    for X in inputs:
        from_input = tf.matmul(X, self.W_xh)
        from_hidden = tf.matmul(hidden, self.W_hh)
        hidden = tf.tanh(from_input + from_hidden + self.b_h)
        outputs.append(hidden)
    return outputs, hidden
我們可以將一小批輸入序列輸入 RNN 模型,如下所示。
# Small dimensions for the demonstration minibatch.
batch_size = 2
num_inputs = 16
num_hiddens = 32
num_steps = 100

rnn = RNNScratch(num_inputs, num_hiddens)
# All-ones dummy input laid out as (num_steps, batch_size, num_inputs).
X = jnp.ones((num_steps, batch_size, num_inputs))
(output
評論
查看更多