脚本之家,脚本语言编程技术及教程分享平台!
分类导航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|shell|

服务器之家 - 脚本之家 - Python - 突破Pytorch核心点,CNN !!!

突破Pytorch核心点,CNN !!!

2024-01-03 13:59DOWHAT小壮 Python

创建卷积神经网络(CNN),很多初学者不太熟悉,今儿咱们来大概说说,给一个完整的案例进行说明。CNN 用于图像分类、目标检测、图像生成等任务。它的关键思想是通过卷积层和池化层来自动提取图像的特征,并通过全连接层进

哈喽,我是小壮!

创建卷积神经网络(CNN),很多初学者不太熟悉,今儿咱们来大概说说,给一个完整的案例进行说明。

CNN 用于图像分类、目标检测、图像生成等任务。它的关键思想是通过卷积层和池化层来自动提取图像的特征,并通过全连接层进行分类。

原理

1.卷积层(Convolutional Layer):

卷积层使用卷积操作从输入图像中提取特征。卷积操作涉及一个可学习的卷积核(filter/kernel),该核在输入图像上滑动,并计算滑动窗口下的点积。这有助于提取局部特征,使网络对平移不变性更强。

公式:

突破Pytorch核心点,CNN !!!

其中,x是输入,w是卷积核,b是偏置。

2.池化层(Pooling Layer):

池化层用于减小数据的空间维度,减少计算量,并提取最显著的特征。最大池化是常用的一种方式,在每个窗口中选择最大的值。

公式(最大池化):

突破Pytorch核心点,CNN !!!

3.全连接层(Fully Connected Layer):

全连接层用于将卷积和池化层提取的特征映射到输出类别。它连接到前一层的所有神经元。

实战步骤和详解

1.步骤

  • 导入必要的库和模块。
  • 定义网络结构:使用nn.Module定义一个继承自它的自定义神经网络类,定义卷积层、激活函数、池化层和全连接层。
  • 定义损失函数和优化器。
  • 加载和预处理数据。
  • 训练网络:使用训练数据迭代训练网络参数。
  • 测试网络:使用测试数据评估模型性能。

2.代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# 定义卷积神经网络类
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        # 卷积层1
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 卷积层2
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        # 全连接层
        self.fc1 = nn.Linear(32 * 7 * 7, 10)  # 输入大小根据数据调整

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(-1, 32 * 7 * 7)
        x = self.fc1(x)
        return x

# 定义损失函数和优化器
net = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

# 加载和预处理数据
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

# 训练网络
num_epochs = 5
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = net(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')

# 测试网络
net.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    accuracy = correct / total
    print('Accuracy on the test set: {}%'.format(100 * accuracy))

这个示例展示了一个简单的CNN模型,使用MNIST数据集进行训练和测试。

接下来,咱们添加可视化步骤,更直观地了解模型的性能和训练过程。

可视化

1.导入matplotlib

import matplotlib.pyplot as plt

2.在训练过程中记录损失和准确率:

在训练循环中,记录每个epoch的损失和准确率。

# 在训练循环中添加以下代码
train_loss_list = []
accuracy_list = []

for epoch in range(num_epochs):
    running_loss = 0.0
    correct = 0
    total = 0

    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = net(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')

    epoch_loss = running_loss / len(train_loader)
    accuracy = correct / total

    train_loss_list.append(epoch_loss)
    accuracy_list.append(accuracy)

3.可视化损失和准确率:

# 在训练循环后,添加以下代码
plt.figure(figsize=(12, 4))

# 可视化损失
plt.subplot(1, 2, 1)
plt.plot(range(1, num_epochs + 1), train_loss_list, label='Training Loss')
plt.title('Training Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

# 可视化准确率
plt.subplot(1, 2, 2)
plt.plot(range(1, num_epochs + 1), accuracy_list, label='Accuracy')
plt.title('Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()

plt.tight_layout()
plt.show()

这样,咱们就可以在训练过程结束后看到训练损失和准确率的变化。

导入代码后,大家可以根据需要调整可视化的内容和格式。

原文地址:https://mp.weixin.qq.com/s?__biz=MzkyNzM4NzE0OA==&mid=2247483954&idx=1&sn=d4448703baa7b133db44b4b7c3b5e51b

延伸 · 阅读

精彩推荐
  • PythonPython实现i人事自动打卡的示例代码

    Python实现i人事自动打卡的示例代码

    这篇文章主要介绍了Python实现i人事自动打卡的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友...

    Alliot14622020-05-01
  • Pythonpython基于右递归解决八皇后问题的方法

    python基于右递归解决八皇后问题的方法

    这篇文章主要介绍了python基于右递归解决八皇后问题的方法,实例分析了右递归算法的相关使用技巧,需要的朋友可以参考下 ...

    小萝莉6152020-07-07
  • Pythonpython DataFrame的shift()方法的使用

    python DataFrame的shift()方法的使用

    在python数据分析中,可以使用shift()方法对DataFrame对象的数据进行位置的前滞、后滞移动,本文主要介绍了python DataFrame的shift()方法的使用,感兴趣的可以了...

    侯小啾4002022-10-29
  • PythonPython Merge函数原理及用法解析

    Python Merge函数原理及用法解析

    这篇文章主要介绍了Python Merge函数原理及用法解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参...

    秋天中的一片叶22402020-09-17
  • PythonDjango REST framework 限流功能的使用

    Django REST framework 限流功能的使用

    DRF常用功能的案例基本用法都有讲解,关于限流(Throttling)这个功能其实在真实的业务场景中能真正用到的其实不算多。今天说这个话题其实一方面是讨论...

    火腿蛋炒饭4212021-12-08
  • Python在Python编程过程中用单元测试法调试代码的介绍

    在Python编程过程中用单元测试法调试代码的介绍

    这篇文章主要介绍了在Python编程过程中用单元测试法调试代码的介绍,包括使用断言等,有助于debug时的效率提升,需要的朋友可以参考下 ...

    Jeff Knupp2352020-05-28
  • Pythonpython实现查询IP地址所在地

    python实现查询IP地址所在地

    本文给大家分享的是使用Python实现根据ip138的API查询IP的地理位置的代码,非常的实用,推荐给大家,有需要的小伙伴可以参考下。 ...

    老徐的私房菜24652020-05-25
  • Python超详细,教你用python语言实现QQ机器人制作教程

    超详细,教你用python语言实现QQ机器人制作教程

    这篇文章主要介绍了如何python语言实现QQ机器人,用图文详细的描述了其中的操作步骤,非常的简单易上手,有需要的朋友可以参考下...

    ……快乐的√45602021-12-26