脚本之家,脚本语言编程技术及教程分享平台!
分类导航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|shell|

服务器之家 - 脚本之家 - Python - Pytorch实现常用乘法算子TensorRT的示例代码

Pytorch实现常用乘法算子TensorRT的示例代码

2023-02-24 11:47极智视界 Python

pytorch 用于训练,TensorRT用于推理是很多AI应用开发的标配。大家往往更加熟悉 pytorch 的算子,而不太熟悉TensorRT的算子。本文介绍了Pytorch中常用乘法的TensorRT实现,感兴趣的可以了解一下

本文介绍一下 Pytorch 中常用乘法的 TensorRT 实现。

pytorch 用于训练,TensorRT 用于推理是很多 AI 应用开发的标配。大家往往更加熟悉 pytorch 的算子,而不太熟悉 TensorRT 的算子,这里拿比较常用的乘法运算在两种框架下的实现做一个对比,可能会有更加直观一些的认识。

1.乘法运算总览

先把 pytorch 中的一些常用的乘法运算进行一个总览:

  • torch.mm:用于两个矩阵 (不包括向量) 的乘法,如维度 (m, n) 的矩阵乘以维度 (n, p) 的矩阵;
  • torch.bmm:用于带 batch 的三维向量的乘法,如维度 (b, m, n) 的矩阵乘以维度 (b, n, p) 的矩阵;
  • torch.mul:用于同维度矩阵的逐像素点相乘,也即点乘,如维度 (m, n) 的矩阵点乘维度 (m, n) 的矩阵。该方法支持广播,也即支持矩阵和元素点乘;
  • torch.mv:用于矩阵和向量的乘法,矩阵在前,向量在后,如维度 (m, n) 的矩阵乘以维度为 (n) 的向量,输出维度为 (m);
  • torch.matmul:用于两个张量相乘,或矩阵与向量乘法,作用包含 torch.mm、torch.bmm、torch.mv;
  • @:作用相当于 torch.matmul;
  • *:作用相当于 torch.mul;

如上进行了一些具体罗列,可以归纳出,常用的乘法无非两种:矩阵乘 和 点乘,所以下面分这两类进行介绍。

2.乘法算子实现

2.1矩阵乘算子实现

先来看看矩阵乘法的 pytorch 的实现 (以下实现在终端):

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
>>> import torch
>>> # torch.mm
>>> a = torch.randn(66, 99)
>>> b = torch.randn(99, 88)
>>> c = torch.mm(a, b)
>>> c.shape
torch.size([66, 88])
>>>
>>> # torch.bmm
>>> a = torch.randn(3, 66, 99)
>>> b = torch.randn(3, 99, 77)
>>> c = torch.bmm(a, b)
>>> c.shape
torch.size([3, 66, 77])
>>>
>>> # torch.mv
>>> a = torch.randn(66, 99)
>>> b = torch.randn(99)
>>> c = torch.mv(a, b)
>>> c.shape
torch.size([66])
>>>
>>> # torch.matmul
>>> a = torch.randn(32, 3, 66, 99)
>>> b = torch.randn(32, 3, 99, 55)
>>> c = torch.matmul(a, b)
>>> c.shape
torch.size([32, 3, 66, 55])
>>>
>>> # @
>>> d = a @ b
>>> d.shape
torch.size([32, 3, 66, 55])

来看 TensorRT 的实现,以上乘法都可使用 addMatrixMultiply 方法覆盖,对应 torch.matmul,先来看该方法的定义:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
//!
//! \brief Add a MatrixMultiply layer to the network.
//!
//! \param input0 The first input tensor (commonly A).
//! \param op0 The operation to apply to input0.
//! \param input1 The second input tensor (commonly B).
//! \param op1 The operation to apply to input1.
//!
//! \see IMatrixMultiplyLayer
//!
//! \warning Int32 tensors are not valid input tensors.
//!
//! \return The new matrix multiply layer, or nullptr if it could not be created.
//!
IMatrixMultiplyLayer* addMatrixMultiply(
  ITensor& input0, MatrixOperation op0, ITensor& input1, MatrixOperation op1) noexcept
{
  return mImpl->addMatrixMultiply(input0, op0, input1, op1);
}

可以看到这个方法有四个传参,对应两个张量和其 operation。来看这个算子在 TensorRT 中怎么添加:

?
1
2
3
4
5
6
7
8
9
10
// 构造张量 Tensor0
nvinfer1::IConstantLayer *Constant_layer0 = m_network->addConstant(tensorShape0, value0);
// 构造张量 Tensor1
nvinfer1::IConstantLayer *Constant_layer1 = m_network->addConstant(tensorShape1, value1);
 
// 添加矩阵乘法
nvinfer1::IMatrixMultiplyLayer *Matmul_layer = m_network->addMatrixMultiply(Constant_layer0->getOutput(0), matrix0Type, Constant_layer1->getOutput(0), matrix2Type);
 
// 获取输出
matmulOutput = Matmul_layer->getOputput(0);

2.2点乘算子实现

再来看看点乘的 pytorch 的实现 (以下实现在终端):

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
>>> import torch
>>> # torch.mul
>>> a = torch.randn(66, 99)
>>> b = torch.randn(66, 99)
>>> c = torch.mul(a, b)
>>> c.shape
torch.size([66, 99])
>>> d = 0.125
>>> e = torch.mul(a, d)
>>> e.shape
torch.size([66, 99])
>>> # *
>>> f = a * b
>>> f.shape
torch.size([66, 99])

来看 TensorRT 的实现,以上乘法都可使用 addScale 方法覆盖,这在图像预处理中十分常用,先来看该方法的定义:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
//!
//! \brief Add a Scale layer to the network.
//!
//! \param input The input tensor to the layer.
//!              This tensor is required to have a minimum of 3 dimensions in implicit batch mode
//!              and a minimum of 4 dimensions in explicit batch mode.
//! \param mode The scaling mode.
//! \param shift The shift value.
//! \param scale The scale value.
//! \param power The power value.
//!
//! If the weights are available, then the size of weights are dependent on the ScaleMode.
//! For ::kUNIFORM, the number of weights equals 1.
//! For ::kCHANNEL, the number of weights equals the channel dimension.
//! For ::kELEMENTWISE, the number of weights equals the product of the last three dimensions of the input.
//!
//! \see addScaleNd
//! \see IScaleLayer
//! \warning Int32 tensors are not valid input tensors.
//!
//! \return The new Scale layer, or nullptr if it could not be created.
//!
IScaleLayer* addScale(ITensor& input, ScaleMode mode, Weights shift, Weights scale, Weights power) noexcept
{
  return mImpl->addScale(input, mode, shift, scale, power);
}

 可以看到有三个模式:

  • kUNIFORM:weights 为一个值,对应张量乘一个元素;
  • kCHANNEL:weights 维度和输入张量通道的 c 维度对应,可以做一些以通道为基准的预处理;
  • kELEMENTWISE:weights 维度和输入张量的 c、h、w 对应,不考虑 batch,所以是输入的后三维;

再来看这个算子在 TensorRT 中怎么添加:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// 构造张量 input
nvinfer1::IConstantLayer *Constant_layer = m_network->addConstant(tensorShape, value);
 
// scalemode选择,kUNIFORM、kCHANNEL、kELEMENTWISE
scalemode = kUNIFORM;
 
// 构建 Weights 类型的 shift、scale、power,其中 volume 为元素数量
nvinfer1::Weights scaleShift{nvinfer1::DataType::kFLOAT, nullptr, volume };
nvinfer1::Weights scaleScale{nvinfer1::DataType::kFLOAT, nullptr, volume };
nvinfer1::Weights scalePower{nvinfer1::DataType::kFLOAT, nullptr, volume };
 
// !! 注意这里还需要对 shift、scale、power 的 values 进行赋值,若只是乘法只需要对 scale 进行赋值就行
 
// 添加张量乘法
nvinfer1::IScaleLayer *Scale_layer = m_network->addScale(Constant_layer->getOutput(0), scalemode, scaleShift, scaleScale, scalePower);
 
// 获取输出
scaleOutput = Scale_layer->getOputput(0);

有一点你可能会比较疑惑,既然是点乘,那么输入只需要两个张量就可以了,为啥这里有 input、shift、scale、power 四个张量这么多呢。解释一下,input 不用说,就是输入张量,而 shift 表示加法参数、scale 表示乘法参数、power 表示指数参数,说到这里,你应该能发现,这个函数除了我们上面讲的点乘外还有其他更加丰富的运算功能。

到此这篇关于Pytorch实现常用乘法算子TensorRT的示例代码的文章就介绍到这了,更多相关Pytorch乘法算子TensorRT内容请搜索服务器之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持服务器之家!

原文链接:https://blog.csdn.net/weixin_42405819/article/details/125070931

延伸 · 阅读

精彩推荐
  • Pythonpython实现删除列表中某个元素的3种方法

    python实现删除列表中某个元素的3种方法

    这篇文章主要介绍了python实现删除列表中某个元素的3种方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的...

    小呆丶9652020-04-21
  • Python教大家玩转Python字符串处理的七种技巧

    教大家玩转Python字符串处理的七种技巧

    这篇文章主要给大家介绍了关于学会Python字符串处理的七种技巧,其中包括字符串的连接和合并、字符串的切片和相乘、字符串的分割、字符串的开头和结...

    菜鸟学Python4812020-09-27
  • PythonPython实现比较两个列表(list)范围

    Python实现比较两个列表(list)范围

    这篇文章主要介绍了Python实现比较两个列表(list)范围,本文根据一道题目实现解决代码,本文分别给出题目和解答源码,需要的朋友可以参考下...

    脚本之家3392020-07-15
  • PythonPython-typing: 类型标注与支持 Any类型详解

    Python-typing: 类型标注与支持 Any类型详解

    这篇文章主要介绍了Python-typing: 类型标注与支持 Any类型详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...

    G____G10842021-10-27
  • PythonPython可视化学习之seaborn绘制线型回归曲线

    Python可视化学习之seaborn绘制线型回归曲线

    这篇文章主要为大家介绍了如何利用seaborn绘制变量之间线型回归(linear regression)曲线,2文中涉及如下两个重要函数:seaborn.regplot和seaborn.lmplot,感兴趣的小...

    qq_214782615502022-10-13
  • Pythonpython模拟登陆网站的示例

    python模拟登陆网站的示例

    这篇文章主要介绍了python模拟登陆网站的示例,帮助大家更好的理解和学习使用python,感兴趣的朋友可以了解下...

    可爱的黑精灵11242021-09-23
  • PythonPython中函数eval和ast.literal_eval的区别详解

    Python中函数eval和ast.literal_eval的区别详解

    eval函数在Python中做数据类型的转换还是很有用的。它的作用就是把数据还原成它本身或者是能够转化成的数据类型。那么eval和ast.literal_val()的区别是什么呢...

    杰瑞268272020-12-01
  • PythonKeras搭建M2Det目标检测平台示例

    Keras搭建M2Det目标检测平台示例

    这篇文章主要为大家介绍了Keras搭建M2Det目标检测平台实现的源码示例,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪...

    Bubbliiiing3912022-12-24