文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

PyTorch中使用回调和日志记录来监控模型训练?

2024-11-29 18:48

关注

理解回调和日志记录

回调和日志记录是PyTorch中有效管理和监控机器学习模型训练过程的基本工具。

1.回调

在编程中,回调是一个作为参数传递给另一个函数的函数。这允许回调函数在调用函数的特定点执行。在PyTorch中,回调用于在训练循环的指定阶段执行操作,例如一个时期的结束或处理一个批次之后。这些阶段可以是:

回调执行的常见操作包括:

2.回调的好处

3.日志记录

日志记录是指记录软件执行过程中发生的事件。PyTorch日志记录对于监控各种指标至关重要,以理解模型随时间的性能。存储训练指标,如:

4.为什么日志记录很重要?

日志记录提供了模型训练历程的历史记录。它允许您:

在PyTorch中实现回调和日志记录

让我们逐步了解如何在PyTorch中实现一个简单的回调和日志记录系统。

步骤1:定义一个回调类

首先,我们定义一个回调类,它将在每个时期的结束时打印一条消息。

class PrintCallback:
    def on_epoch_end(self, epoch, logs):
        print(f"Epoch {epoch}: loss = {logs['loss']:.4f}, accuracy = {logs['accuracy']:.4f}")

步骤2:修改训练循环

接下来,我们修改训练循环以接受我们的回调,并在每个时期的结束时调用它。

def train_model(model, dataloader, criterion, optimizer, epochs, callbacks):
    for epoch in range(epochs):
        for batch in dataloader:
            # Training process happens here
            pass
        logs = {'loss': 0.001, 'accuracy': 0.999}  # Example metrics after an epoch
        for callback in callbacks:
            callback.on_epoch_end(epoch, logs)

步骤3:实现日志记录

对于日志记录,我们将使用Python内置的日志模块来记录训练进度。

import logging
logging.basicConfig(level=logging.INFO)

def log_metrics(epoch, logs):
    logging.info(f"Epoch {epoch}: loss = {logs['loss']:.4f}, accuracy = {logs['accuracy']:.4f}")

步骤4:将所有内容整合在一起

最后,我们创建我们的回调实例,设置记录器,并开始训练过程。

print_callback = PrintCallback()
train_model(model, dataloader, criterion, optimizer, epochs=10, callbacks=[print_callback])

在PyTorch中实现回调和日志记录

示例1:合成数据集

让我们创建一个代表我们机器人绘画的随机数字的简单数据集。我们将使用PyTorch创建随机数据点。

import torch

# Generate random data points
data = torch.rand(100, 3)  # 100 paintings, 3 colors each
labels = torch.randint(0, 2, (100,))  # Randomly label them as good (1) or bad (0)

步骤1:定义一个简单模型

现在,我们将定义一个简单的模型,尝试学习对绘画进行分类。

from torch import nn

# A simple neural network with one layer
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.layer = nn.Linear(3, 2)
    def forward(self, x):
        return self.layer(x)
model = SimpleModel()

步骤2:设置训练

我们将准备训练模型所需的一切。


# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# DataLoader to handle our dataset
from torch.utils.data import TensorDataset, DataLoader
dataset = TensorDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=10)

步骤3:实现一个回调

我们将创建一个回调,它在每个时期后打印损失。

class PrintLossCallback:
    def on_epoch_end(self, epoch, loss):
        print(f"Epoch {epoch}: loss = {loss:.4f}")

步骤4:使用回调训练

现在,我们将训练模型并使用我们的回调。

def train(model, dataloader, criterion, optimizer, epochs, callback):
    for epoch in range(epochs):
        total_loss = 0
        for inputs, targets in dataloader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        callback.on_epoch_end(epoch, total_loss / len(dataloader))

# Create an instance of our callback
print_loss_callback = PrintLossCallback()
# Start training
train(model, dataloader, criterion, optimizer, epochs=5, callback=print_loss_callback)

输出:

Epoch 0: loss = 0.6927
Epoch 1: loss = 0.6909
Epoch 2: loss = 0.6899
Epoch 3: loss = 0.6891
Epoch 4: loss = 0.6885

步骤5:可视化训练

我们可以绘制随时间变化的损失,以可视化我们机器人的进步。

import matplotlib.pyplot as plt

losses = []  # Store the losses here
class PlotLossCallback:
    def on_epoch_end(self, epoch, loss):
        losses.append(loss)
        plt.plot(losses)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.show()
# Update our training function to use the plotting callback
plot_loss_callback = PlotLossCallback()
train(model, dataloader, criterion, optimizer, epochs=5, callback=plot_loss_callback)

输出:

示例2:公共数据集

对于第二个示例,我们将使用在线可用的真实数据集。我们将直接使用URL加载著名的鸢尾花数据集。

步骤1:加载数据集

我们将使用pandas从URL加载数据集。

import pandas as pd

# Load the Iris dataset
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
iris_data = pd.read_csv(url, header=None)

步骤2:预处理数据

我们需要将数据转换为PyTorch可以理解的格式。

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

# Encode the labels
encoder = LabelEncoder()
iris_labels = encoder.fit_transform(iris_data[4])
# Split the data
train_data, test_data, train_labels, test_labels = train_test_split(
    iris_data.iloc[:, :4].values, iris_labels, test_size=0.2, random_state=42
)
# Convert to PyTorch tensors
train_data = torch.tensor(train_data, dtype=torch.float32)
test_data = torch.tensor(test_data, dtype=torch.float32)
train_labels = torch.tensor(train_labels, dtype=torch.long)
test_labels = torch.tensor(test_labels, dtype=torch.long)
# Create DataLoaders
train_dataset = TensorDataset(train_data, train_labels)
test_dataset = TensorDataset(test_data, test_labels)
train_loader = DataLoader(train_dataset, batch_size=10)
test_loader = DataLoader(test_dataset, batch_size=10)

步骤3:为鸢尾花数据集定义一个模型

我们将为鸢尾花数据集创建一个合适的模型。

class IrisModel(nn.Module):
    def __init__(self):
        super(IrisModel, self).__init__()
        self.layer1 = nn.Linear(4, 10)
        self.layer2 = nn.Linear(10, 3)
    def forward(self, x):
        x = torch.relu(self.layer1(x))
        return self.layer2(x)
iris_model = IrisModel()

步骤4:训练模型

我们将按照之前的步骤训练这个模型。

# Assume the same training function and callbacks as before
train(iris_model, train_loader, criterion, optimizer, epochs=5, callback=plot_loss_callback)

输出:

步骤5:评估模型

最后,我们将检查我们的模型在测试数据上的表现如何。

def evaluate(model, test_loader):
    model.eval()  # Set the model to evaluation mode
    correct = 0
    with torch.no_grad():  # No need to track gradients
        for inputs, targets in test_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == targets).sum().item()
    accuracy = correct / len(test_loader.dataset)
    print(f"Accuracy: {accuracy:.4f}")

evaluate(iris_model, test_loader)

输出:

Accuracy: 0.3333

结论

您可以通过设置回调和日志记录来进行必要的调整,获得对模型训练过程的洞察,并确保其高效学习。请记住,如果您的模型提供明确反馈,您通往训练有素的机器学习模型的道路将更加顺利。本文提供了适合初学者的代码示例和解释,让您基本掌握PyTorch中的回调和日志记录。不要犹豫尝试提供的代码。记住,实践是掌握这些主题的关键。

来源:小白玩转Python内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯