詹惠儿

2018-12-19   阅读量: 1673

数据分析师 Python数据分析

如何保存机器学习模型?

扫码加入数据分析学习群

在机器学习中,在使用scikit学习库时,我们需要将训练好的模型保存在文件中并恢复它们以便重复使用它来将模型与其他模型进行比较,以便在新数据上测试模型。保存数据称为Serializaion,而恢复数据称为反序列化
此外,我们处理不同类型和大小的数据。一些数据集很容易训练,即它们花费的时间较少,但是大小很大(超过1GB)的数据集即使使用GPU也可能需要很长时间才能在本地机器上进行训练。当我们在某个不同的项目中或稍后需要相同的训练数据时,为了避免浪费训练时间,存储经过训练的模型,以便将来可以随时使用。
我们可以通过以下方式在scikit中保存模型:

  1. Pickle string:pickle模块实现了一个基本但强大的算法,用于序列化和反序列化Python对象结构。Pickle型号提供以下功能 -

pickle.dump要序列化对象层次结构,只需使用dump()。

pickle.load要反序列化数据流,可以调用loads()函数。

示例:让我们在虹膜数据集上应用K最近邻,然后保存模型。

import numpy as np

# Load dataset

from sklearn.datasets import load_iris

iris = load_iris()

X = iris.data

y = iris.target

# Split dataset into train and test

X_train, X_test, y_train, y_test = \

train_test_split(X, y, test_size = 0.3,

random_state = 2018)

# import KNeighborsClassifier model

from sklearn.neighbors import KNeighborsClassifier as KNN

knn = KNN(n_neighbors = 3)

# train model

knn.fit(X_train, y_train)

使用pickle将模型保存为字符串 -

import pickle

# Save the trained model as a pickle string.

saved_model = pickle.dumps(knn)

# Load the pickled model

knn_from_pickle = pickle.loads(saved_model)

# Use the loaded pickled model to make predictions

knn_from_pickle.predict(X_test)

添加CDA认证专家【维克多阿涛】,微信号:【cdashijiazhuang】,提供数据分析指导及CDA考试秘籍。已助千人通过CDA数字化人才认证。欢迎交流,共同成长!
0.0000 0 1 关注作者 收藏

评论(0)


暂无数据

推荐课程

推荐帖子