脚本之家,脚本语言编程技术及教程分享平台!
分类导航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服务器之家 - 脚本之家 - Python - 深入学习PyTorch中LSTM的输入和输出

深入学习PyTorch中LSTM的输入和输出

2022-07-26 19:39Cyril_KI Python

这篇文章主要介绍了深入学习PyTorch中LSTM的输入和输出,文章围绕主题展开学习的内容介绍,具有一定的参考价值,需要的朋友可以参考一下,希望对你的学习有所帮助

LSTM参数

官方文档给出的解释为:

深入学习PyTorch中LSTM的输入和输出

总共有七个参数,其中只有前三个是必须的。由于大家普遍使用PyTorch的DataLoader来形成批量数据,因此batch_first也比较重要。LSTM的两个常见的应用场景为文本处理和时序预测,因此下面对每个参数我都会从这两个方面来进行具体解释。

  • input_size:在文本处理中,由于一个单词没法参与运算,因此我们得通过Word2Vec来对单词进行嵌入表示,将每一个单词表示成一个向量,此时input_size=embedding_size。比如每个句子中有五个单词,每个单词用一个100维向量来表示,那么这里input_size=100;在时间序列预测中,比如需要预测负荷,每一个负荷都是一个单独的值,都可以直接参与运算,因此并不需要将每一个负荷表示成一个向量,此时input_size=1。 但如果我们使用多变量进行预测,比如我们利用前24小时每一时刻的[负荷、风速、温度、压强、湿度、天气、节假日信息]来预测下一时刻的负荷,那么此时input_size=7
  • hidden_size:隐藏层节点个数。可以随意设置。
  • num_layers:层数。nn.LSTMCell与nn.LSTM相比,num_layers默认为1。
  • batch_first:默认为False,意义见后文。

Inputs

关于LSTM的输入,官方文档给出的定义为:

深入学习PyTorch中LSTM的输入和输出

可以看到,输入由两部分组成:input、(初始的隐状态h_0,初始的单元状态c_0)

其中input:

input(seq_len, batch_size, input_size)
  • seq_len:在文本处理中,如果一个句子有7个单词,则seq_len=7;在时间序列预测中,假设我们用前24个小时的负荷来预测下一时刻负荷,则seq_len=24。
  • batch_size:一次性输入LSTM中的样本个数。在文本处理中,可以一次性输入很多个句子;在时间序列预测中,也可以一次性输入很多条数据。
  • input_size

(h_0, c_0):

h_0(num_directions * num_layers, batch_size, hidden_size)
c_0(num_directions * num_layers, batch_size, hidden_size)

h_0和c_0的shape一致。

  • num_directions:如果是双向LSTM,则num_directions=2;否则num_directions=1。num_layers:
  • batch_size:
  • hidden_size:

 Outputs

关于LSTM的输出,官方文档给出的定义为:

深入学习PyTorch中LSTM的输入和输出

可以看到,输出也由两部分组成:otput、(隐状态h_n,单元状态c_n)

其中output的shape为:

output(seq_len, batch_size, num_directions * hidden_size)

h_n和c_n的shape保持不变,参数解释见前文。

batch_first

如果在初始化LSTM时令batch_first=True,那么input和output的shape将由:

input(seq_len, batch_size, input_size)
output(seq_len, batch_size, num_directions * hidden_size)

变为:

input(batch_size, seq_len, input_size)
output(batch_size, seq_len, num_directions * hidden_size)

即batch_size提前。

案例

简单搭建一个LSTM如下所示:

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = output_size
        self.num_directions = 1 # 单向LSTM
        self.batch_size = batch_size
        self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
        self.linear = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input_seq):
        batch_size, seq_len = input_seq[0], input_seq[1]
        h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
        c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
        # output(batch_size, seq_len, num_directions * hidden_size)
        output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)
        pred = self.linear(output)  # (5, 30, 1)
        pred = pred[:, -1, :]  # (5, 1)
        return pred

其中定义模型的代码为:

self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
self.linear = nn.Linear(self.hidden_size, self.output_size)

我们加上具体的数字:

self.lstm = nn.LSTM(self.input_size=1, self.hidden_size=64, self.num_layers=5, batch_first=True)
self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)

再看前向传播:

def forward(self, input_seq):
    batch_size, seq_len = input_seq[0], input_seq[1]
    h_0 = torch.randn(self.num_directions * self.num_layers, batch_size, self.hidden_size).to(device)
    c_0 = torch.randn(self.num_directions * self.num_layers, batch_size, self.hidden_size).to(device)
    # input(batch_size, seq_len, input_size)
    # output(batch_size, seq_len, num_directions * hidden_size)
    output, _ = self.lstm(input_seq, (h_0, c_0))  # output(5, 30, 64)
    pred = self.linear(output) # (5, 30, 1)
    pred = pred[:, -1, :]  # (5, 1)
    return pred

假设用前30个预测下一个,则seq_len=30,batch_size=5,由于设置了batch_first=True,因此,输入到LSTM中的input的shape应该为:

input(batch_size, seq_len, input_size) = input(5, 30, 1)

经过DataLoader处理后的input_seq为:

input_seq(batch_size, seq_len, input_size) = input_seq(5, 30, 1)

然后将input_seq送入LSTM:

output, _ = self.lstm(input_seq, (h_0, c_0))  # output(5, 30, 64)

根据前文,output的shape为:

output(batch_size, seq_len, num_directions * hidden_size) = output(5, 30, 64)

全连接层的定义为:

self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)

然后将output送入全连接层:

pred = self.linear(output)  # pred(5, 30, 1)

得到的预测值shape为(5, 30, 1),由于输出是输入右移,我们只需要取pred第二维度(time)中的最后一个数据:

pred = pred[:, -1, :]  # (5, 1)

这样,我们就得到了预测值,然后与label求loss,然后再反向更新参数即可。

到此这篇关于深入学习PyTorch中LSTM的输入和输出的文章就介绍到这了,更多相关PyTorch LSTM内容请搜索服务器之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持服务器之家!

原文地址:https://blog.csdn.net/Cyril_KI/article/details/122557880

延伸 · 阅读

精彩推荐
  • Pythonpython urllib.request模块的使用详解

    python urllib.request模块的使用详解

    这篇文章主要介绍了python urllib.request模块的使用详解,帮助大家更好的理解和学习使用python,感兴趣的朋友可以了解下...

    可爱的黑精灵13662021-09-24
  • Python详解Python sys.argv使用方法

    详解Python sys.argv使用方法

    在本文中我们给大家详细讲解了关于Python sys.argv使用方法以及注意事项,有此需要的读者们跟着学习下。...

    Python教程网9382021-06-25
  • PythonPython入门_浅谈逻辑判断与运算符

    Python入门_浅谈逻辑判断与运算符

    下面小编就为大家带来一篇Python入门_浅谈逻辑判断与运算符。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧 ...

    脚本之家2972020-11-07
  • Pythonpython用opencv批量截取图像指定区域的方法

    python用opencv批量截取图像指定区域的方法

    今天小编就为大家分享一篇python用opencv批量截取图像指定区域的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...

    大力挥拳6342021-05-22
  • PythonPycharm以root权限运行脚本的方法

    Pycharm以root权限运行脚本的方法

    今天小编就为大家分享一篇Pycharm以root权限运行脚本的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...

    ShichimiyaSatone6572021-05-18
  • Python使用Python的turtle模块画图的方法

    使用Python的turtle模块画图的方法

    这篇文章主要介绍了使用Python的turtle模块画图的方法,涉及turtle简介,运动命令,画笔控制命令的分享,以及具体操作的步骤,具有一定参考价值,需要的...

    Zoctopus·Lian10902020-12-17
  • PythonPython matplotlib底层原理解析

    Python matplotlib底层原理解析

    这篇文章主要介绍了Python matplotlib底层原理,下面文章围绕Python matplotlib底层原理的相关资料展开详细内容,具有一定的参考价值,需要的朋友可以参考下...

    盆友圈的小可爱5512022-03-10
  • PythonPyCharm 2021.2 (Professional)调试远程服务器程序的操作技巧

    PyCharm 2021.2 (Professional)调试远程服务器程序的操作技巧

    本文给大家分享用 PyCharm 2021 调试远程服务器程序的过程,通过图文并茂的形式给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需...

    微拂素罗衫10452021-12-22