登录
首页大数据时代深度学习中如何实现Keras模型加载和保存?
深度学习中如何实现Keras模型加载和保存?
2020-07-22
收藏

Keras 是源于 Theano 或 者TensorFlow 的一个深度学习框架,它的设计来源于Torch,编程语言使用的是 Python ,是一个拥有强大功能、内容抽象,而且高度模块化的神经网络库。

今天小编给大家分享的就是Keras 模型的保存与加载,希望对大家学习和使用Keras 有所帮助。

一、Keras模型保存和加载的基础介绍

Keras模型保存和加载一般是保存成hdf5格式。Keras模型主要有两种,序贯模型即Sequential、以及函数式模型Model,相对来说函数模型Model使用范围更广,序贯模型Sequential可看作是函数模型的一种特殊情况。

两类模型有一些方法是相同的:

model.summary():打印模型概况

model.get_config():返回包含模型配置信息的Python字典。

model.get_layer():依据层名或下标获得层对象

model.get_weights():返回模型权重张量的列表,类型为numpy array

model.set_weights():从numpy array里将权重载入给模型,要求数组具有与model.get_weights()相同的形状。

model.to_json:返回代表模型的JSON字符串,仅包含网络结构,不包含权值。

model.to_yaml:与model.to_json类似,同样可以从产生的YAML字符串中重构模型

model.save_weights(filepath):将模型权重保存到指定路径,文件类型是HDF5(后缀是.h5)

model.load_weights(filepath, by_name=False):从HDF5文件中加载权重到当前模型中, 默认情况下模型的结构将保持不变。如果想将权重载入不同的模型(有些层相同)中,则设置by_name=True,只有名字匹配的层才会载入权重

二、Keras模型保存和加载方式

1.保存所有状态

(1)保存模型和模型图


# 保存模型 model.save(file_path)
model_name = '{}/{}_{}_{}_v2.h5'.format(params['model_dir'],params['filters'],params['pool_size_1'],params['pool_size_2'])
model.save(model_name)

# 保存模型图
from keras.utils import plot_model
# 需要安装pip install pydot
model_plot = '{}/{}_{}_{}_v2.png'.format(params['model_dir'],params['filters'],params['pool_size_1'],params['pool_size_2'])
plot_model(model, to_file=model_plot)


(2)加载模型


from keras.models import load_model

model_path = '../docs/keras/100_2_3_v2.h5'
model = load_model(model_path)


利弊分析:

a.模型保存和加载就只需一行代码,写起来简单快捷

b.既能保存模型的结构和参数,又能保存训练配置等信息。方便我们从上次训练中断的地方再次进行训练优化。

c.占用空间过大,上传或者同步费时。

2.只保存模型结构和模型参数

(1)保存模型


import yaml
import json

# 保存模型结构到yaml文件或者json文件
yaml_string = model.to_yaml()
open('../docs/keras/model_architecture.yaml', 'w').write(yaml_string)
# json_string = model.to_json()
# open('../docs/keras/model_architecture.json', 'w').write(json_string)

# 保存模型参数到h5文件
model.save_weights('../docs/keras/model_weights.h5')


(2)加载模型


import yaml
import json
from keras.models import model_from_json
from keras.models import model_from_yaml

# 加载模型结构
model = model_from_yaml(open('../docs/keras/model_architecture.yaml').read())
# model = model_from_json(open('../docs/keras/model_architecture.json').read())

# 加载模型参数
model.load_weights('../docs/keras/model_weights.h5')


利弊分析:

a.能够节省硬盘空间,便于同步和协作

b.会丢失训练的一部分配置信息

数据分析咨询请扫描二维码

客服在线
立即咨询