詹惠儿

2018-12-19   阅读量: 708

数据分析师 Python数据分析

如何用Logistic回归识别手写数字?(1)

扫码加入数据分析学习群

Logistic回归是一种非常常用的统计方法,它允许我们从一组自变量中预测二进制输出。现在,我们将了解如何在PyTorch中实现这一点,PyTorch是一个非常流行的深度学习库,由Facebook开发。
现在,我们将看到如何使用PyTorch中的Logistic回归对MNIST数据集中的手写数字进行分类。首先,您需要将PyTorch安装到Python环境中。最简单的方法是使用pipconda工具。访问pytorch.org并安装您想要使用的Python解释器版本和包管理器。
安装PyTorch后,现在让我们看一下代码。写下面给出的三行来导入重新安装的库函数和对象。

import torch

import torch.nn as nn

import torchvision.datasets as dsets

import torchvision.transforms as transforms

from torch.autograd import Variable

这里,torch.nn模块包含模型所需的代码,torchvision.datasets包含MNIST数据集。它包含我们将在这里使用的手写数字的数据集。该torchvision.transforms模块包含的对象转换成其他各种方法。在这里,我们将使用它从图像转换为PyTorch张量。此外,torch.autograd模块包含Variable类以及其他类,我们将在定义我们的张量时使用它。

接下来,我们将下载数据集并将其加载到内存中。

# MNIST Dataset (Images and Labels)

train_dataset = dsets.MNIST(root ='./data',

train = True,

transform = transforms.ToTensor(),

download = True)

test_dataset = dsets.MNIST(root ='./data',

train = False,

transform = transforms.ToTensor())

# Dataset Loader (Input Pipline)

train_loader = torch.utils.data.DataLoader(dataset = train_dataset,

batch_size = batch_size,

shuffle = True)

test_loader = torch.utils.data.DataLoader(dataset = test_dataset,

batch_size = batch_size,

shuffle = False)

添加CDA认证专家【维克多阿涛】,微信号:【cdashijiazhuang】,提供数据分析指导及CDA考试秘籍。已助千人通过CDA数字化人才认证。欢迎交流,共同成长!
0.0000 0 4 关注作者 收藏

评论(0)


暂无数据

推荐课程

推荐帖子