在机器学习中,在使用scikit学习库时,我们需要将训练好的模型保存在文件中并恢复它们以便重复使用它来将模型与其他模型进行比较,以便在新数据上测试模型。保存数据称为Serializaion,而恢复数据称为反序列化。
此外,我们处理不同类型和大小的数据。一些数据集很容易训练,即它们花费的时间较少,但是大小很大(超过1GB)的数据集即使使用GPU也可能需要很长时间才能在本地机器上进行训练。当我们在某个不同的项目中或稍后需要相同的训练数据时,为了避免浪费训练时间,存储经过训练的模型,以便将来可以随时使用。
我们可以通过以下方式在scikit中保存模型:
- Pickle string:pickle模块实现了一个基本但强大的算法,用于序列化和反序列化Python对象结构。Pickle型号提供以下功能 -
pickle.dump
要序列化对象层次结构,只需使用dump()。
pickle.load
要反序列化数据流,可以调用loads()函数。
示例:让我们在虹膜数据集上应用K最近邻,然后保存模型。
import numpy as np
# Load dataset
from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target
# Split dataset into train and test
X_train, X_test, y_train, y_test = \
train_test_split(X, y, test_size = 0.3,
random_state = 2018)
# import KNeighborsClassifier model
from sklearn.neighbors import KNeighborsClassifier as KNN
knn = KNN(n_neighbors = 3)
# train model
knn.fit(X_train, y_train)
使用pickle将模型保存为字符串 -
import pickle
# Save the trained model as a pickle string.
saved_model = pickle.dumps(knn)
# Load the pickled model
knn_from_pickle = pickle.loads(saved_model)
# Use the loaded pickled model to make predictions
knn_from_pickle.predict(X_test)








暂无数据