脚本之家,脚本语言编程技术及教程分享平台!
分类导航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服务器之家 - 脚本之家 - Python - Pytorch中torch.stack()函数的深入解析

Pytorch中torch.stack()函数的深入解析

2022-08-31 10:50cv_lhp Python

在pytorch中常见的拼接函数主要是两个,分别是:stack()和cat(),下面这篇文章主要给大家介绍了关于Pytorch中torch.stack()函数的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下

一. torch.stack()函数解析

1. 函数说明:

1.1 官网:torch.stack(),函数定义及参数说明如下图所示:

Pytorch中torch.stack()函数的深入解析

1.2 函数功能

沿一个新维度对输入一系列张量进行连接,序列中所有张量应为相同形状,stack 函数返回的结果会新增一个维度。也即是把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度上面进行堆叠。

1.3 参数列表

  • tensors :为一系列输入张量,类型为turple和List
  • dim :新增维度的(下标)位置,当dim = -1时默认最后一个维度;范围必须介于 0 到输入张量的维数之间,默认是dim=0,在第0维进行连接
  • 返回值:输出新增维度后的张量

2. 代码举例

2.1 dim = 0 : 在第0维进行连接,相当于在行上进行组合(输入张量为一维,输出张量为两维)

?
1
2
3
4
5
6
7
8
import torch
#二维输入张量a,b
a = torch.tensor([1, 2, 3])
b = torch.tensor([11, 22, 33])
c = torch.stack([a, b],dim=0)#在第0维进行连接,相当于在行上进行组合(输入张量为一维,输出张量为两维)
print(a)
print(b)
print(c)

输出结果如下:
tensor([1, 2, 3])
tensor([11, 22, 33])
tensor([[ 1,  2,  3],
        [11, 22, 33]])

2.2 dim = 1 :在第1维进行连接,相当于在对应行上面对列元素进行组合(输入张量为一维,输出张量为两维)

?
1
2
3
4
5
6
7
8
import torch
#二维输入张量a,b
a = torch.tensor([1, 2, 3])
b = torch.tensor([11, 22, 33])
c = torch.stack([a, b],dim=1)#在第1维进行连接,相当于在对应行上面对列元素进行组合(输入张量为一维,输出张量为两维)
print(a)
print(b)
print(c)

输出结果如下:
tensor([1, 2, 3])
tensor([11, 22, 33])
tensor([[ 1, 11],
        [ 2, 22],
        [ 3, 33]])

2.3 dim=0:表示在第0维进行连接,相当于在通道维度上进行组合(输入张量为两维,输出张量为三维),注意:此处输入张量维度为二维,因此dim最大只能为2。

?
1
2
3
4
5
6
7
8
import torch
#二维输入张量a,b
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b],dim=0)#在第0维进行连接,相当于在通道维度上进行组合(输入张量为两维,输出张量为三维)
print(a)
print(b)
print(c)

输出结果如下所示:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
tensor([[11, 22, 33],
        [44, 55, 66],
        [77, 88, 99]])
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[11, 22, 33],
         [44, 55, 66],
         [77, 88, 99]]])

2.4 dim=1:表示在第1维进行连接,相当于对相应通道中每个行进行组合,注意:此处输入张量维度为二维,因此dim最大只能为2。

?
1
2
3
4
5
6
7
8
import torch
#二维输入张量a,b
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b], 1)#在第1维进行连接,相当于对相应通道中每个行进行组合
print(a)
print(b)
print(c)

输出结果如下所示:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
tensor([[11, 22, 33],
        [44, 55, 66],
        [77, 88, 99]])
tensor([[[ 1,  2,  3],
         [11, 22, 33]],

        [[ 4,  5,  6],
         [44, 55, 66]],

        [[ 7,  8,  9],
         [77, 88, 99]]])

2.5 dim=2:表示在第2维进行连接,相当于对相应行中每个列元素进行组合,注意:此处输入张量维度为二维,因此dim最大只能为2。

?
1
2
3
4
5
6
7
8
import torch
#二维输入张量a,b
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b], 2)#在第2维进行连接,相当于对相应行中每个列元素进行组合
print(a)
print(b)
print(c)

输出结果如下所示:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
tensor([[11, 22, 33],
        [44, 55, 66],
        [77, 88, 99]])
tensor([[[ 1, 11],
         [ 2, 22],
         [ 3, 33]],

        [[ 4, 44],
         [ 5, 55],
         [ 6, 66]],

        [[ 7, 77],
         [ 8, 88],
         [ 9, 99]]])

2.6 dim=3:表示在第3维进行连接,相当于对相应行中每个列元素进行组合(输入维度大小为3维,因此dim=3最后一维始终代表为列),注意:此处输入张量维度为三维,因此dim最大只能为3。

?
1
2
3
4
5
6
7
8
import torch
#三维输入张量a,b
a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]])
b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]])
c = torch.stack([a, b], 3)#表示在第3维进行连接,相当于对相应行中每个列元素进行组合(最后一维是第三维,始终代表为列)
print(a)
print(b)
print(c)

输出结果如下所示:
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[10, 20, 30],
         [40, 50, 60],
         [70, 80, 90]]])
tensor([[[ 11,  22,  33],
         [ 44,  55,  66],
         [ 77,  88,  99]],

        [[110, 220, 330],
         [440, 550, 660],
         [770, 880, 990]]])
tensor([[[[  1,  11],
          [  2,  22],
          [  3,  33]],

         [[  4,  44],
          [  5,  55],
          [  6,  66]],

         [[  7,  77],
          [  8,  88],
          [  9,  99]]],


        [[[ 10, 110],
          [ 20, 220],
          [ 30, 330]],

         [[ 40, 440],
          [ 50, 550],
          [ 60, 660]],

         [[ 70, 770],
          [ 80, 880],
          [ 90, 990]]]])

2.7 dim=4 (错误维度:因为此处输入张量维度为三维,所以dim最大只能为3,此处维度为4,因此会报错)

?
1
2
3
4
5
6
7
8
import torch
#三维输入张量a,b
a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]])
b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]])
c = torch.stack([a, b], 4)
print(a)
print(b)
print(c)

输出错误:
IndexError: Dimension out of range (expected to be in range of [-4, 3], but got 4)

总结

到此这篇关于Pytorch中torch.stack()函数的文章就介绍到这了,更多相关Pytorch torch.stack()函数内容请搜索服务器之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持服务器之家!

原文链接:https://blog.csdn.net/flyingluohaipeng/article/details/125034358

延伸 · 阅读

精彩推荐
  • Pythonpython通过pil模块将raw图片转换成png图片的方法

    python通过pil模块将raw图片转换成png图片的方法

    这篇文章主要介绍了python通过pil模块将raw图片转换成png图片的方法,实例分析了Python中pil模块的使用技巧,并Image.fromstring函数进行了较为详尽的分析说明,需要...

    疯狂一夏12162019-12-03
  • PythonPython实现八皇后问题示例代码

    Python实现八皇后问题示例代码

    这篇文章主要给大家介绍了关于利用Python实现八皇后问题的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值...

    马一特4482021-04-27
  • Pythonpython中多个装饰器的执行顺序详解

    python中多个装饰器的执行顺序详解

    装饰器是程序开发中经常会用到的一个功能,也是python语言开发的基础知识。这篇文章主要介绍了python中多个装饰器的执行顺序详解,小编觉得挺不错的,...

    wyzane8412021-04-05
  • Python详解Python GUI编程之PyQt5入门到实战

    详解Python GUI编程之PyQt5入门到实战

    这篇文章主要介绍了详解Python GUI编程之PyQt5入门到实战,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友...

    Erics-19965612021-08-12
  • Pythonpython 爬虫请求模块requests详解

    python 爬虫请求模块requests详解

    这篇文章主要介绍了python 爬虫请求模块requests详解,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下...

    码农一号已就位12592021-08-08
  • Python对python的输出和输出格式详解

    对python的输出和输出格式详解

    今天小编就为大家分享一篇对python的输出和输出格式详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...

    石头里蹦出的猴子5432021-04-26
  • PythonPython做文本按行去重的实现方法

    Python做文本按行去重的实现方法

    每行在promotion后面包含一些数字,如果这些数字是相同的,则认为是相同的行,对于相同的行,只保留一行。接下来通过本文给大家介绍Python做文本按行去...

    aaa1111sss9082020-09-10
  • Python对python调用RPC接口的实例详解

    对python调用RPC接口的实例详解

    今天小编就为大家分享一篇对python调用RPC接口的实例详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...

    天枢10742021-05-11