国科网

2025-06-03 09:27:29  星期二
立足国科融媒,服务先进科技
机器学习 | PyTorch简明教程下篇

点赞

0
发布时间:2023年11月02日 浏览量:212次 所属栏目:人工智能 发布者:田佳恬

接着上篇《PyTorch简明教程上篇》,继续学习多层感知机,卷积神经网络和LSTMNet。

1、多层感知机

多层感知机通过在网络中加入一个或多个隐藏层来克服线性模型的限制,是一个简单的神经网络,也是深度学习的重要基础,具体图如下:

import numpy as np
import torch
from torch.autograd import Variable
from torch import optim
from data_util import load_mnist

def build_model(input_dim, output_dim):
    return torch.nn.Sequential(
        torch.nn.Linear(input_dim, 512, bias=False),
        torch.nn.ReLU(),
        torch.nn.Dropout(0.2),
        torch.nn.Linear(512, 512, bias=False),
        torch.nn.ReLU(),
        torch.nn.Dropout(0.2),
        torch.nn.Linear(512, output_dim, bias=False),
    )

def train(model, loss, optimizer, x_val, y_val):
    model.train()
    optimizer.zero_grad()
    fx = model.forward(x_val)
    output = loss.forward(fx, y_val)
    output.backward()
    optimizer.step()
    return output.item()

def predict(model, x_val):
    model.eval()
    output = model.forward(x_val)
    return output.data.numpy().argmax(axis=1)

def main():
    torch.manual_seed(42)
    trX, teX, trY, teY = load_mnist(notallow=False)
    trX = torch.from_numpy(trX).float()
    teX = torch.from_numpy(teX).float()
    trY = torch.tensor(trY)

    n_examples, n_features = trX.size()
    n_classes = 10
    model = build_model(n_features, n_classes)
    loss = torch.nn.CrossEntropyLoss(reductinotallow='mean')
    optimizer = optim.Adam(model.parameters())
    batch_size = 100

    for i in range(100):
        cost = 0.
        num_batches = n_examples // batch_size
        for k in range(num_batches):
            start, end = k * batch_size, (k + 1) * batch_size
            cost += train(model, loss, optimizer,
                          trX[start:end], trY[start:end])
        predY = predict(model, teX)
        print("Epoch %d, cost = %f, acc = %.2f%%"
              % (i + 1, cost / num_batches, 100. * np.mean(predY == teY)))

if __name__ == "__main__":
    main()

分享说明:转发分享请注明出处。

    相关图讯
    网站简介  |   联系我们  |   广告服务  |   监督电话
    本网站由国科网运营维护 国科网讯(北京)技术有限公司版权所有  咨询电话:010-88516927
    地址:北京市海淀区阜石路甲69号院1号楼1层一单元114
    ICP备案号:京ICP备15066964号-8   违法和不良信息举报电话:010-67196565
    12300电信用户申诉受理中心   网络违法犯罪举报网站   中国互联网举报中心   12321网络不良与垃圾信息举报中心   12318全国文化市场举报网站
    代理域名注册服务机构:阿里巴巴云计算(北京)有限公司