脚本之家,脚本语言编程技术及教程分享平台!
分类导航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服务器之家 - 脚本之家 - Python - PyTorch中的squeeze()和unsqueeze()解析与应用案例

PyTorch中的squeeze()和unsqueeze()解析与应用案例

2022-11-03 10:54易烊千蝈 Python

这篇文章主要介绍了PyTorch中的squeeze()和unsqueeze()解析与应用案例,文章内容介绍详细,需要的小伙伴可以参考一下,希望对你有所帮助

附上官网地址:

https://pytorch.org/docs/stable/index.html

1.torch.squeeze

PyTorch中的squeeze()和unsqueeze()解析与应用案例

squeeze的用法主要就是对数据的维度进行压缩或者解压。

先看torch.squeeze() 这个函数主要对数据的维度进行压缩,去掉维数为1的的维度,比如是一行或者一列这种,一个一行三列(1,3)的数去掉第一个维数为一的维度之后就变成(3)行。squeeze(a)就是将a中所有为1的维度删掉。不为1的维度没有影响。a.squeeze(N) 就是去掉a中指定的维数为一的维度。还有一种形式就是b=torch.squeeze(a,N) a中去掉指定的定的维数为一的维度。

换言之:

表示若第arg维的维度值为1,则去掉该维度,否则tensor不变。(即若tensor.shape()[arg] == 1,则去掉该维度)

例如:

一个维度为2x1x2x1x2的tensor,不用去想它长什么样儿,squeeze(0)就是不变,squeeze(1)就是变成2x2x1x2。(0是从最左边的维度算起的)

>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])

 

2.torch.unsqueeze

PyTorch中的squeeze()和unsqueeze()解析与应用案例

torch.unsqueeze()这个函数主要是对数据维度进行扩充。给指定位置加上维数为一的维度,比如原本有个三行的数据(3),在0的位置加了一维就变成一行三列(1,3)。a.squeeze(N) 就是在a中指定位置N加上一个维数为1的维度。还有一种形式就是b=torch.squeeze(a,N) a就是在a中指定位置N加上一个维数为1的维度。

>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1,  2,  3,  4]])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],
        [ 2],
        [ 3],
        [ 4]])

 

3.例子

给一个使用上述两个函数,并进行一次卷积的例子:

from torchvision.transforms import  ToTensor
import torch as t
from torch import nnimport cv2
import numpy as np
import cv2
to_tensor = ToTensor()
# 加载图像
lena = cv2.imread('lena.jpg', cv2.IMREAD_GRAYSCALE)
cv2.imshow('lena', lena)
# input = to_tensor(lena) 将ndarray转换为tensor,自动将[0,255]归一化至[0,1]。
input = to_tensor(lena).unsqueeze(0)
# 初始化卷积参数
kernel = t.ones(1, 1, 3, 3)/-9
kernel[:, :, 1, 1] = 1
conv = nn.Conv2d(1, 1, 3, 1, padding=1, bias=False)
conv.weight.data = kernel.view(1, 1, 3, 3)
# 输出
out = conv(input)
out = out.squeeze(0)
print(out.shape)
out = out.unsqueeze(3)
print(out.shape)
out = out.squeeze(0)
print(out.shape)
out = out.detach().numpy()# 缩放到0~最大值
cv2.normalize(out, out, 1.0, 0, cv2.NORM_INF)
cv2.imshow("lena-result", out)
cv2.waitKey()

结果图如下:

PyTorch中的squeeze()和unsqueeze()解析与应用案例

到此这篇关于PyTorch中的squeeze()和unsqueeze()解析与应用案例的文章就介绍到这了,更多相关squeeze()和unsqueeze()解析内容请搜索服务器之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持服务器之家!

references:
[1] 陈云.深度学习框架之PyTorch入门与实践.北京:电子工业出版社,2018.

原文链接:https://blog.csdn.net/weixin_39490300/article/details/123464027

延伸 · 阅读

精彩推荐
  • Pythonpython基于SMTP发送QQ邮件

    python基于SMTP发送QQ邮件

    这篇文章主要为大家详细介绍了python基于SMTP发送QQ邮件,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下...

    我的名字三个字11162021-09-25
  • PythonPython实现感知器模型、两层神经网络

    Python实现感知器模型、两层神经网络

    这篇文章主要为大家详细介绍了Python实现感知器模型、两层神经网络,具有一定的参考价值,感兴趣的小伙伴们可以参考一下 ...

    O天涯海阁O4532020-12-26
  • PythonPython实现利用163邮箱远程关电脑脚本

    Python实现利用163邮箱远程关电脑脚本

    这篇文章主要为大家详细介绍了Python实现利用163邮箱远程关电脑脚本,具有一定的参考价值,感兴趣的小伙伴们可以参考一下...

    阿有耳9732021-01-16
  • Pythonpython spilt()分隔字符串的实现示例

    python spilt()分隔字符串的实现示例

    split() 方法可以实现将一个字符串按照指定的分隔符切分成多个子串,本文介绍了spilt的具体使用,感兴趣的可以了解一下...

    胡小牧5302021-11-09
  • Python分享4个方便且好用的Python自动化脚本

    分享4个方便且好用的Python自动化脚本

    自动化测试是把以人为驱动的测试行为转化为机器执行的一种过程,直白的就是为了节省人力、时间或硬件资源,提高测试效率,这篇文章主要给大家分享介绍...

    shunshunss7152022-09-19
  • Pythonpython实现单向链表详解

    python实现单向链表详解

    这篇文章主要介绍了python实现单向链表详解,分享了相关代码示例,每一步操作前都有简单分析,小编觉得还是挺不错的,具有一定借鉴价值,需要的朋友...

    过分了4602021-01-13
  • PythonPython爬虫基础之selenium库的用法总结

    Python爬虫基础之selenium库的用法总结

    今天带大家来学习selenium库的使用方法及相关知识总结,文中非常详细的介绍了selenium库,对正在学习python的小伙伴很有帮助,需要的朋友可以参考下...

    一腔诗意醉了酒7042021-11-12
  • Pythonpython实现扫描日志关键字的示例

    python实现扫描日志关键字的示例

    下面小编就为大家分享一篇python实现扫描日志关键字的示例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...

    大繁至简8272021-02-07