詹惠儿

2018-12-19   阅读量: 620

大数据 Python编程

如何用Logistic回归识别手写数字?(2)

扫码加入数据分析学习群

现在,我们将定义我们的超参数

# Hyper Parameters

input_size = 784

num_classes = 10

num_epochs = 5

batch_size = 100

learning_rate = 0.001

在我们的数据集中,图像大小为28 * 28。因此,我们的输入大小是784.此外,这里有10位数字,因此,我们可以有10个不同的输出。因此,我们将num_classes设置为10.此外,我们将在整个数据集上训练五次。最后,我们将分别训练小批量的100张图像,以防止因内存溢出而导致程序崩溃。

在此之后,我们将定义我们的模型如下。在这里,我们将我们的模型初始化为torch.nn.Module的子类,然后定义前向传递。在我们编写的代码中,softmax在每次正向传递期间内部计算,因此我们不需要在forward()函数内指定它。

class LogisticRegression(nn.Module):

def __init__(self, input_size, num_classes):

super(LogisticRegression, self).__init__()

self.linear = nn.Linear(input_size, num_classes)

def forward(self, x):

out = self.linear(x)

return out

定义了我们的类之后,现在我们实例化了一个对象

model = LogisticRegression(input_size, num_classes)

接下来,我们设置损失函数和优化器。在这里,我们将使用交叉熵损失,对于优化器,我们将使用随机梯度下降算法,其学习率为0.001,如上面的超参数中所定义。

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)

添加CDA认证专家【维克多阿涛】,微信号:【cdashijiazhuang】,提供数据分析指导及CDA考试秘籍。已助千人通过CDA数字化人才认证。欢迎交流,共同成长!
0.0000 0 1 关注作者 收藏

评论(0)


暂无数据

推荐课程

推荐帖子