脚本之家,脚本语言编程技术及教程分享平台!
分类导航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|shell|

服务器之家 - 脚本之家 - Python - 一个超强 Pytorch 操作!!!

一个超强 Pytorch 操作!!!

2024-01-02 14:28DOWHAT小壮 Python

Pytorch 同样提供了许多用于数据处理和转换的函数。今儿来看下,最重要的几个必会函数。

哈喽,我是小壮!

这几天关于深度学习的内容,已经分享了一些。

另外,类似于numpy、pandas常用数据处理函数,在Pytorch中也是同样的重要,同样的有趣!!

Pytorch同样提供了许多用于数据处理和转换的函数。

今儿来看下,最重要的几个必会函数。

一个超强 Pytorch 操作!!!

torch.Tensor

torch.Tensor 是PyTorch中最基本的数据结构,用于表示张量(tensor)。张量是多维数组,可以包含数字、布尔值等。你可以使用torch.Tensor的构造函数创建张量,也可以通过其他函数创建。

import torch

# 创建一个空的张量
empty_tensor = torch.Tensor()

# 从列表创建张量
data = [1, 2, 3, 4]
tensor_from_list = torch.Tensor(data)

torch.from_numpy

用于将NumPy数组转换为PyTorch张量。

import numpy as np

numpy_array = np.array([1, 2, 3, 4])
torch_tensor = torch.from_numpy(numpy_array)

torch.Tensor.item

用于从只包含一个元素的张量中提取Python数值。适用于标量张量。

scalar_tensor = torch.tensor(5)
scalar_value = scalar_tensor.item()

torch.Tensor.view

用于改变张量的形状。

original_tensor = torch.randn(2, 3)  # 2x3的随机张量
reshaped_tensor = original_tensor.view(3, 2)  # 将形状改变为3x2

torch.Tensor.to

用于将张量转换到指定的设备(如CPU或GPU)。

cpu_tensor = torch.randn(3)
gpu_tensor = cpu_tensor.to("cuda")  # 将张量移动到GPU

torch.Tensor.numpy

将张量转换为NumPy数组。

pytorch_tensor = torch.tensor([1, 2, 3])
numpy_array = pytorch_tensor.numpy()

torch.nn.functional.one_hot

用于对整数张量进行独热编码。

import torch.nn.functional as F

integer_tensor = torch.tensor([0, 2, 1])
one_hot_encoded = F.one_hot(integer_tensor)

torch.utils.data.Dataset和torch.utils.data.DataLoader

用于加载和处理数据集。这两个类通常与自定义的数据集类一起使用。

from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index]

dataset = CustomDataset([1, 2, 3, 4, 5])
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

以上这些是PyTorch中一些重要的数据转换函数,进行了简单的使用。

它们对于处理和准备深度学习任务中的数据非常非常有帮助。

一个案例

接下来,我们制作一个图像分割的案例。

在这个案例中,我们将使用PyTorch和torchvision库进行图像分割,使用预训练的DeepLabV3模型和PASCAL VOC数据集。

在整个的代码中,涉及到上面所学到的内容,调整大小、裁剪、标准化等。

import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import matplotlib.pyplot as plt

# 下载示例图像
!wget -O example_image.jpg https://pytorch.org/assets/deeplab/deeplab1.jpg

# 定义图像转换
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # 调整大小
    transforms.ToTensor(),           # 转换为张量
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 标准化
])

# 加载并转换图像
image_path = 'example_image.jpg'
image = Image.open(image_path).convert("RGB")
input_tensor = transform(image).unsqueeze(0)  # 添加批次维度

# 加载预训练的DeepLabV3模型
model = models.segmentation.deeplabv3_resnet101(pretrained=True)
model.eval()

# 进行图像分割
with torch.no_grad():
    output = model(input_tensor)['out'][0]
    output_predictions = output.argmax(0)

# 将预测结果转换为彩色图像
def decode_segmap(image, nc=21):
    label_colors = np.array([(0, 0, 0),  # 0: 背景
                             (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),  # 1-5: 物体
                             (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0),  # 6-9: 道路
                             (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128),  # 10-13: 面部
                             (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0),  # 14-17: 植物
                             (0, 192, 0), (128, 192, 0), (0, 64, 128)])  # 18-20: 建筑

    r = np.zeros_like(image).astype(np.uint8)
    g = np.zeros_like(image).astype(np.uint8)
    b = np.zeros_like(image).astype(np.uint8)

    for l in range(0, nc):
        idx = image == l
        r[idx] = label_colors[l, 0]
        g[idx] = label_colors[l, 1]
        b[idx] = label_colors[l, 2]

    rgb = np.stack([r, g, b], axis=2)
    return rgb

# 将预测结果转换为彩色图像
output_rgb = decode_segmap(output_predictions.numpy())

# 可视化原始图像和分割结果
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title('Original Image')

plt.subplot(1, 2, 2)
plt.imshow(output_rgb)
plt.title('Segmentation Result')

plt.show()

在这个案例中,我们首先定义了一系列图像转换函数,包括调整大小、转换为张量和标准化。这些转换确保输入图像满足模型的需求。

然后,加载了一个示例图像并应用了这些转换。

接下来,我们使用了torchvision中预训练的DeepLabV3模型来进行图像分割。对于输出,我们提取了预测结果的最大值索引,以获得每个像素的预测类别。

最后,我们将预测结果转换为彩色图像,并可视化原始图像和分割结果。

一个超强 Pytorch 操作!!!

这个案例强调了图像转换函数在图像分割任务中的重要作用,确保输入图像符合模型的输入要求,并且输出结果易于可视化。

原文地址:https://mp.weixin.qq.com/s?__biz=MzkyNzM4NzE0OA==&mid=2247483944&idx=1&sn=98b0c870b1a7bc18b1bad2ffa18e7287

延伸 · 阅读

精彩推荐
  • Python利用django-suit模板添加自定义的菜单、页面及设置访问权限

    利用django-suit模板添加自定义的菜单、页面及设置访问权限

    这篇文章主要给大家介绍了关于利用django-suit模板添加自定义的菜单、页面及设置访问权限的相关资料,文中通过示例代码介绍的非常详细,对大家的学习或...

    喂-不吃素的熊宝宝10882021-03-16
  • PythonDjango drf请求模块源码解析

    Django drf请求模块源码解析

    APIView中的dispatch是整个请求生命过程的核心方法,包含了请求模块,权限验证,异常模块和响应模块,我们先来介绍请求模块,对Django drf请求模块源码相关...

    Silent丿丶黑羽11812021-11-26
  • Python让你分分钟学会python条件语句

    让你分分钟学会python条件语句

    学好Python和条件语句,将方便有效提高工作效率,这篇文章主要给大家介绍了关于python条件语句的相关资料,文中通过示例代码介绍的非常详细,需要的朋友可以...

    振华OPPO8992021-12-24
  • Python成功解决python.exe无法定位程序输入点

    成功解决python.exe无法定位程序输入点

    成功解决python.exe无法找到程序入口 无法定位程序输入点_model_builder_test.py 无法定位输入点...

    丶七年先生5352023-08-28
  • Python对python numpy数组中冒号的使用方法详解

    对python numpy数组中冒号的使用方法详解

    下面小编就为大家分享一篇对python numpy数组中冒号的使用方法详解,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...

    pnnngchg9652021-02-01
  • Python解析django的csrf跨站请求伪造

    解析django的csrf跨站请求伪造

    本文主要介绍了解析django的csrf跨站请求伪造,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着...

    等日落3852022-08-18
  • Pythonpython实现多个视频文件合成画中画效果

    python实现多个视频文件合成画中画效果

    这篇文章主要为大家详细介绍了python实现多个视频文件合成画中画效果,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考...

    ~狼6022021-12-25
  • PythonNumpy的各种下标操作的示例代码

    Numpy的各种下标操作的示例代码

    本文主要介绍了Numpy的各种下标操作的示例代码,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下...

    DechinPhy7122022-10-12