京公网安备 11010802034615号
经营许可证编号:京B2-20210330
文章来源:DeepHub IMBA
作者: P**nHub兄弟网站
学习如何通过剪枝来使你的模型变得更小
剪枝是一种模型优化技术,这种技术可以消除权重张量中不必要的值。这将会得到更小的模型,并且模型精度非常接近标准模型。
在本文中,我们将通过一个例子来观察剪枝技术对最终模型大小和预测误差的影响。
我们的第一步导入一些工具、包:
最后,初始化TensorBoard,这样就可以将模型可视化:
import os import zipfile import tensorflow as tf import tensorflow_model_optimization as tfmot from tensorflow.keras.models import load_model from tensorflow import keras %load_ext tensorboard
在这个实验中,我们将使用scikit-learn生成一个回归数据集。之后,我们将数据集分解为训练集和测试集:
from sklearn.datasets import make_friedman1 X, y = make_friedman1(n_samples=10000, n_features=10, random_state=0) from sklearn.model_selection import train_test_split X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
我们将创建一个简单的神经网络来预测目标变量y,然后检查均值平方误差。在此之后,我们将把它与修剪过的整个模型进行比较,然后只与修剪过的Dense层进行比较。
接下来,在30个训练轮次之后,一旦模型停止改进,我们就使用回调来停止训练它。
early_stop = keras.callbacks.EarlyStopping(monitor=’val_loss’, patience=30)
我们打印出模型概述,以便与运用剪枝技术的模型概述进行比较。
model = setup_model() model.summary()
让我们编译模型并训练它。
tf.keras.utils.plot_model( model, to_file=”model.png”, show_shapes=True, show_layer_names=True, rankdir=”TB”, expand_nested=True, dpi=96, )
现在检查一下均方误差。我们可以继续到下一节,看看当我们修剪整个模型时,这个误差是如何变化的。
from sklearn.metrics import mean_squared_error predictions = model.predict(X_test) print(‘Without Pruning MSE %.4f’ % mean_squared_error(y_test,predictions.reshape(3300,))) Without Pruning MSE 0.0201
当把模型部署到资源受限的边缘设备(如手机)时,剪枝等优化模型技术尤其重要。
我们将上面的MSE与修剪整个模型得到的MSE进行比较。第一步是定义剪枝参数。权重剪枝是基于数量级的。这意味着在训练过程中一些权重被转换为零。模型变得稀疏,这样就更容易压缩。由于可以跳过零,稀疏模型还可以加快推理速度。
预期的参数是剪枝计划、块大小和块池类型。
from tensorflow_model_optimization.sparsity.keras import ConstantSparsity
pruning_params = {
'pruning_schedule': ConstantSparsity(0.5, 0),
'block_size': (1, 1),
'block_pooling_type': 'AVG'
}
现在,我们可以应用我们的剪枝参数来修剪整个模型。
from tensorflow_model_optimization.sparsity.keras import prune_low_magnitude model_to_prune = prune_low_magnitude( keras.Sequential([ tf.keras.layers.Dense(128, activation='relu', input_shape=(X_train.shape[1],)), tf.keras.layers.Dense(1, activation='relu') ]), **pruning_params)
我们检查模型概述。将其与未剪枝模型的模型进行比较。从下图中我们可以看到整个模型已经被剪枝 —— 我们将很快看到剪枝一个稠密层后模型概述的区别。
model_to_prune.summary()
在TF中,我们必须先编译模型,然后才能将其用于训练集和测试集。
model_to_prune.compile(optimizer=’adam’, loss=tf.keras.losses.mean_squared_error, metrics=[‘mae’, ‘mse’])
由于我们正在使用剪枝技术,所以除了早期停止回调函数之外,我们还必须定义两个剪枝回调函数。我们定义一个记录模型的文件夹,然后创建一个带有回调函数的列表。
tfmot.sparsity.keras.UpdatePruningStep()
使用优化器步骤更新剪枝包装器。如果未能指定剪枝包装器,将会导致错误。
tfmot.sparsity.keras.PruningSummaries()
将剪枝概述添加到Tensorboard。
log_dir = ‘.models’ callbacks = [ tfmot.sparsity.keras.UpdatePruningStep(), # Log sparsity and other metrics in Tensorboard. tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir), keras.callbacks.EarlyStopping(monitor=’val_loss’, patience=10) ]
有了这些,我们现在就可以将模型与训练集相匹配了。
model_to_prune.fit(X_train,y_train,epochs=100,validation_split=0.2,callbacks=callbacks,verbose=0)
在检查这个模型的均方误差时,我们注意到它比未剪枝模型的均方误差略高。
prune_predictions = model_to_prune.predict(X_test) print(‘Whole Model Pruned MSE %.4f’ % mean_squared_error(y_test,prune_predictions.reshape(3300,))) Whole Model Pruned MSE 0.1830
现在让我们实现相同的模型,但这一次,我们将只剪枝稠密层。请注意在剪枝计划中使用多项式衰退函数。
from tensorflow_model_optimization.sparsity.keras import PolynomialDecay
layer_pruning_params = {
'pruning_schedule': PolynomialDecay(initial_sparsity=0.2,
final_sparsity=0.8, begin_step=1000, end_step=2000),
'block_size': (2, 3),
'block_pooling_type': 'MAX'
}
model_layer_prunning = keras.Sequential([
prune_low_magnitude(tf.keras.layers.Dense(128, activation='relu',input_shape=(X_train.shape[1],)),
**layer_pruning_params),
tf.keras.layers.Dense(1, activation='relu')
])
从概述中我们可以看到只有第一个稠密层将被剪枝。
model_layer_prunning.summary()
然后我们编译并拟合模型。
model_layer_prunning.compile(optimizer=’adam’, loss=tf.keras.losses.mean_squared_error, metrics=[‘mae’, ‘mse’]) model_layer_prunning.fit(X_train,y_train,epochs=300,validation_split=0.1,callbacks=callbacks,verbose=0)
现在,让我们检查均方误差。
layer_prune_predictions = model_layer_prunning.predict(X_test) print(‘Layer Prunned MSE %.4f’ % mean_squared_error(y_test,layer_prune_predictions.reshape(3300,))) Layer Prunned MSE 0.1388
由于我们使用了不同的剪枝参数,所以我们无法将这里获得的MSE与之前的MSE进行比较。如果您想比较它们,那么请确保剪枝参数是相同的。在测试时,对于这个特定情况,layer_pruning_params给出的错误比pruning_params要低。比较从不同的剪枝参数获得的MSE是有用的,这样你就可以选择一个不会使模型性能变差的MSE。
现在让我们比较一下有剪枝和没有剪枝模型的大小。我们从训练和保存模型权重开始,以便以后使用。
def train_save_weights():
model = setup_model()
model.compile(optimizer='adam',
loss=tf.keras.losses.mean_squared_error,
metrics=['mae', 'mse'])
model.fit(X_train,y_train,epochs=300,validation_split=0.2,callbacks=callbacks,verbose=0)
model.save_weights('.models/friedman_model_weights.h5')
train_save_weights()
我们将建立我们的基础模型,并加载保存的权重。然后我们对整个模型进行剪枝。我们编译、拟合模型,并在Tensorboard上将结果可视化。
base_model = setup_model()
base_model.load_weights('.models/friedman_model_weights.h5') # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)
model_for_pruning.compile(
loss=tf.keras.losses.mean_squared_error,
optimizer='adam',
metrics=['mae', 'mse']
)
model_for_pruning.fit(
X_train,
y_train,
callbacks=callbacks,
epochs=300,
validation_split = 0.2,
verbose=0
)
%tensorboard --logdir={log_dir}
以下是TensorBoard的剪枝概述的快照。
在TensorBoard上也可以看到其它剪枝模型概述
现在让我们定义一个计算模型大小函数
def get_gzipped_model_size(model,mode_name,zip_name): # Returns size of gzipped model, in bytes. model.save(mode_name, include_optimizer=False) with zipfile.ZipFile(zip_name, 'w', compression=zipfile.ZIP_DEFLATED) as f: f.write(mode_name) return os.path.getsize(zip_name)
现在我们定义导出模型,然后计算大小。
对于剪枝过的模型,tfmot.sparsity.keras.strip_pruning()用来恢复带有稀疏权重的原始模型。请注意剥离模型和未剥离模型在尺寸上的差异。
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
print("Size of gzipped pruned model without stripping: %.2f bytes" % (get_gzipped_model_size(model_for_pruning,'.models/model_for_pruning.h5','.models/model_for_pruning.zip')))
print("Size of gzipped pruned model with stripping: %.2f bytes" % (get_gzipped_model_size(model_for_export,'.models/model_for_export.h5','.models/model_for_export.zip')))
Size of gzipped pruned model without stripping: 6101.00 bytes Size of gzipped pruned model with stripping: 5140.00 bytes
对这两个模型进行预测,我们发现它们具有相同的均方误差。
model_for_prunning_predictions = model_for_pruning.predict(X_test)
print('Model for Prunning Error %.4f' % mean_squared_error(y_test,model_for_prunning_predictions.reshape(3300,)))
model_for_export_predictions = model_for_export.predict(X_test)
print('Model for Export Error %.4f' % mean_squared_error(y_test,model_for_export_predictions.reshape(3300,)))
Model for Prunning Error 0.0264 Model for Export Error 0.0264
您可以继续测试不同的剪枝计划如何影响模型的大小。显然这里的观察结果不具有普遍性。也可以尝试不同的剪枝参数,并了解它们如何影响您的模型大小、预测误差/精度,这将取决于您要解决的问题。
为了进一步优化模型,您可以将其量化。如果您想了解更多,请查看下面的回购和参考资料。
作者:Derrick Mwiti
deephub翻译组:钱三一
数据分析咨询请扫描二维码
若不方便扫码,搜微信号:CDAshujufenxi
在数字化时代,数据分析已成为企业决策、业务优化、增长突破的核心支撑,从数据仓库搭建(如维度表与事实表的设计)、数据采集清 ...
2026-03-16在数据仓库建设、数据分析(尤其是用户行为分析、业务指标分析)的实践中,维度表与事实表是两大核心组件,二者相互依存、缺一不 ...
2026-03-16数据是CDA(Certified Data Analyst)数据分析师开展一切工作的核心载体,而数据读取作为数据生命周期的关键环节,是连接原始数 ...
2026-03-16在用户行为分析实践中,很多从业者会陷入一个核心误区:过度关注“当前数据的分析结果”,却忽视了结果的“泛化能力”——即分析 ...
2026-03-13在数字经济时代,用户的每一次点击、浏览、停留、转化,都在传递着真实的需求信号。用户行为分析,本质上是通过收集、整理、挖掘 ...
2026-03-13在金融、零售、互联网等数据密集型行业,量化策略已成为企业挖掘商业价值、提升决策效率、控制经营风险的核心工具。而CDA(Certi ...
2026-03-13在机器学习建模体系中,随机森林作为集成学习的经典算法,凭借高精度、抗过拟合、适配多场景、可解释性强的核心优势,成为分类、 ...
2026-03-12在机器学习建模过程中,“哪些特征对预测结果影响最大?”“如何筛选核心特征、剔除冗余信息?”是从业者最常面临的核心问题。随 ...
2026-03-12在数字化转型深度渗透的今天,企业管理已从“经验驱动”全面转向“数据驱动”,数据思维成为企业高质量发展的核心竞争力,而CDA ...
2026-03-12在数字经济飞速发展的今天,数据分析已从“辅助工具”升级为“核心竞争力”,渗透到商业、科技、民生、金融等各个领域。无论是全 ...
2026-03-11上市公司财务报表是反映企业经营状况、盈利能力、偿债能力的核心数据载体,是投资者决策、研究者分析、从业者复盘的重要依据。16 ...
2026-03-11数字化浪潮下,数据已成为企业生存发展的核心资产,而数据思维,正是CDA(Certified Data Analyst)数据分析师解锁数据价值、赋 ...
2026-03-11线性回归是数据分析中最常用的预测与关联分析方法,广泛应用于销售额预测、风险评估、趋势分析等场景(如前文销售额预测中的多元 ...
2026-03-10在SQL Server安装与配置的实操中,“服务名无效”是最令初学者头疼的高频问题之一。无论是在命令行执行net start启动服务、通过S ...
2026-03-10在数据驱动业务的当下,CDA(Certified Data Analyst)数据分析师的核心价值,不仅在于解读数据,更在于搭建一套科学、可落地的 ...
2026-03-10在企业经营决策中,销售额预测是核心环节之一——无论是库存备货、营销预算制定、产能规划,还是战略布局,都需要基于精准的销售 ...
2026-03-09金融数据分析的核心价值,是通过挖掘数据规律、识别风险、捕捉机会,为投资决策、风险控制、业务优化提供精准支撑——而这一切的 ...
2026-03-09在数据驱动决策的时代,CDA(Certified Data Analyst)数据分析师的核心工作,是通过数据解读业务、支撑决策,而指标与指标体系 ...
2026-03-09在数据处理的全流程中,数据呈现与数据分析是两个紧密关联却截然不同的核心环节。无论是科研数据整理、企业业务复盘,还是日常数 ...
2026-03-06在数据分析、数据预处理场景中,dat文件是一种常见的二进制或文本格式数据文件,广泛应用于科研数据、工程数据、传感器数据等领 ...
2026-03-06