• There are two other popular recurrent layers: LSTM and GRU.
  • SimpleRNN has difficulty learning long-term dependencies.
    • This is due to the vanishing gradient problem.
  • Long Short-Term Memory (LSTM)
    • proposed by Hochreiter and Schmidhuber in 1997
    • It adds a way to carry information across many timesteps.
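To make the vanishing gradient problem concrete, here is a minimal sketch using a scalar tanh RNN, h_t = tanh(w * h_{t-1} + u * x_t). The function name `gradient_through_time` and the values of `w`, `u`, and the inputs are illustrative assumptions, not taken from the original post. It tracks |d h_T / d h_0|, which is a product of one factor per timestep.

```python
import math


def gradient_through_time(w, u, inputs, h0=0.0):
    """Return |d h_t / d h_0| after each step of a scalar tanh RNN (toy example)."""
    h, grad, history = h0, 1.0, []
    for x in inputs:
        h = math.tanh(w * h + u * x)
        # Chain rule: d h_t / d h_{t-1} = w * (1 - h_t^2)
        grad *= w * (1.0 - h * h)
        history.append(abs(grad))
    return history


grads = gradient_through_time(w=0.9, u=1.0, inputs=[0.5] * 50)
print(grads[0], grads[9], grads[49])
```

Because every factor has magnitude below 1 here, the product shrinks with each timestep, so early inputs have almost no influence on the gradient at the end of the sequence.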

Details of LSTM

 

LSTM diagram
Computations involved in LSTM
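For reference, the usual LSTM update can be written as below. This is the standard formulation, not a transcription of the original figure. The symbols are assumptions: W and U are the input and recurrent weight matrices, b is a bias (torch.nn.LSTM keeps two bias vectors per gate, merged here into one), σ is the sigmoid, and ⊙ is element-wise multiplication.

```latex
\begin{aligned}
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
g_t &= \tanh(W_g x_t + U_g h_{t-1} + b_g) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```

The cell state c_t is updated additively, which is what lets information be carried across many timesteps.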


import torch
import torch.nn as nn


class LSTM(nn.Module):

    def __init__(self, input_size, hidden_size, sequence_length, num_layers, device):
        super(LSTM, self).__init__()
        self.device = device
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # The linear layer takes the hidden states of all timesteps, flattened.
        self.fc = nn.Linear(hidden_size*sequence_length, 1)

    def forward(self, x):
        # Initial states have shape (num_layers, batch, hidden_size).
        h0 = torch.zeros(self.num_layers, x.size()[0], self.hidden_size).to(self.device)
        c0 = torch.zeros(self.num_layers, x.size()[0], self.hidden_size).to(self.device) # a cell state is added compared to the plain RNN
        out, _ = self.lstm(x, (h0, c0)) # returns output, (hn, cn): the final hidden and cell states come back as a tuple
        out = out.reshape(out.shape[0], -1) # flatten to (batch, sequence_length * hidden_size)
        out = self.fc(out)
        return out
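The shapes involved in the forward pass can be checked with a small standalone sketch. The sizes below (batch of 4, 5 timesteps, 3 input features, hidden size 8, 2 layers) are arbitrary values chosen for illustration, not the ones used in the original post.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumptions).
batch_size, sequence_length, input_size, hidden_size, num_layers = 4, 5, 3, 8, 2

lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
x = torch.randn(batch_size, sequence_length, input_size)
h0 = torch.zeros(num_layers, batch_size, hidden_size)
c0 = torch.zeros(num_layers, batch_size, hidden_size)

out, (hn, cn) = lstm(x, (h0, c0))
flat = out.reshape(out.shape[0], -1)  # what the model feeds to the linear layer

print(out.shape)   # torch.Size([4, 5, 8])  -> (batch, seq_len, hidden_size)
print(hn.shape)    # torch.Size([2, 4, 8])  -> (num_layers, batch, hidden_size)
print(cn.shape)    # torch.Size([2, 4, 8])
print(flat.shape)  # torch.Size([4, 40])    -> (batch, seq_len * hidden_size)
```

This is why the linear layer is declared with `hidden_size*sequence_length` input features: it consumes the hidden states of every timestep, not just the last one.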

GRU diagram
Computations involved in GRU
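For reference, the GRU update in the convention used by torch.nn.GRU can be written as below. This is not a transcription of the original figure. The symbols are assumptions: W and U are the input and recurrent weight matrices, b is a bias, σ is the sigmoid, and ⊙ is element-wise multiplication.

```latex
\begin{aligned}
r_t &= \sigma(W_r x_t + U_r h_{t-1} + b_r) \\
z_t &= \sigma(W_z x_t + U_z h_{t-1} + b_z) \\
n_t &= \tanh\bigl(W_n x_t + b_{in} + r_t \odot (U_n h_{t-1} + b_{hn})\bigr) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
```

Unlike the LSTM, the GRU has no separate cell state: the hidden state h_t is the only state passed between timesteps.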


class GRU(nn.Module):

    def __init__(self, input_size, hidden_size, sequence_length, num_layers, device):
        super(GRU, self).__init__()
        self.device = device
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        # The linear layer takes the hidden states of all timesteps, flattened.
        self.fc = nn.Linear(hidden_size*sequence_length, 1)

    def forward(self, x):
        # A GRU has no cell state, so only the initial hidden state is needed.
        h0 = torch.zeros(self.num_layers, x.size()[0], self.hidden_size).to(self.device)
        out, _ = self.gru(x, h0) # returns output, hn
        out = out.reshape(out.shape[0], -1) # flatten to (batch, sequence_length * hidden_size)
        out = self.fc(out)
        return out
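One practical difference between the two layers is size: an LSTM has four gate blocks per layer and a GRU has three. The sketch below counts parameters in plain Python. The helper name `recurrent_param_count` and the sizes are illustrative assumptions. It assumes the torch.nn.LSTM / torch.nn.GRU layout for a unidirectional layer: one input-to-hidden matrix, one hidden-to-hidden matrix, and two bias vectors per layer.

```python
def recurrent_param_count(input_size, hidden_size, num_layers, gates):
    """Count weights and biases of a stacked, unidirectional recurrent layer."""
    total = 0
    for layer in range(num_layers):
        # Layers after the first take the previous layer's hidden state as input.
        layer_input = input_size if layer == 0 else hidden_size
        total += gates * hidden_size * layer_input   # input-to-hidden weights
        total += gates * hidden_size * hidden_size   # hidden-to-hidden weights
        total += 2 * gates * hidden_size             # two bias vectors
    return total


lstm_params = recurrent_param_count(input_size=1, hidden_size=8, num_layers=2, gates=4)
gru_params = recurrent_param_count(input_size=1, hidden_size=8, num_layers=2, gates=3)
print(lstm_params, gru_params)  # 928 696
```

With the same sizes, the GRU has three quarters of the LSTM's recurrent parameters. The final linear layer is the same in both models above.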
 

Prediction graph using the GRU model
