京公网安备 11010802034615号
经营许可证编号:京B2-20210330
这篇文章主要介绍了Python基于numpy灵活定义神经网络结构的方法,结合实例形式分析了神经网络结构的原理及Python具体实现方法,涉及Python使用numpy扩展进行数学运算的相关操作技巧,需要的朋友可以参考下
本文实例讲述了Python基于numpy灵活定义神经网络结构的方法。分享给大家供大家参考,具体如下:
用numpy可以灵活定义神经网络结构,还可以应用numpy强大的矩阵运算功能!
一、用法
1). 定义一个三层神经网络:

说明:
输入层节点数目:3
隐藏层节点数目:4
输出层节点数目:2
2).定义一个五层神经网络:
'''示例二'''
nn = NeuralNetworks([3,5,7,4,2]) # 定义神经网络
nn.fit(X,y) # 拟合
print(nn.predict(X)) #预测
说明:
输入层节点数目:3
隐藏层1节点数目:5
隐藏层2节点数目:7
隐藏层3节点数目:4
输出层节点数目:2
二、实现
如下实现方式为本人(@hhh5460)原创。 要点: dtype=object
import numpy as np
class NeuralNetworks(object):
''''''
def __init__(self, n_layers=None, active_type=None, n_iter=10000, error=0.05, alpha=0.5, lamda=0.4):
'''搭建神经网络框架'''
# 各层节点数目 (向量)
self.n = np.array(n_layers) # 'n_layers必须为list类型,如:[3,4,2] 或 n_layers=[3,4,2]'
self.size = self.n.size # 层的总数
# 层 (向量)
self.z = np.empty(self.size, dtype=object) # 先占位(置空),dtype=object !如下皆然
self.a = np.empty(self.size, dtype=object)
self.data_a = np.empty(self.size, dtype=object)
# 偏置 (向量)
self.b = np.empty(self.size, dtype=object)
self.delta_b = np.empty(self.size, dtype=object)
# 权 (矩阵)
self.w = np.empty(self.size, dtype=object)
self.delta_w = np.empty(self.size, dtype=object)
# 填充
for i in range(self.size):
self.a[i] = np.zeros(self.n[i]) # 全零
self.z[i] = np.zeros(self.n[i]) # 全零
self.data_a[i] = np.zeros(self.n[i]) # 全零
if i < self.size - 1:
self.b[i] = np.ones(self.n[i+1]) # 全一
self.delta_b[i] = np.zeros(self.n[i+1]) # 全零
mu, sigma = 0, 0.1 # 均值、方差
self.w[i] = np.random.normal(mu, sigma, (self.n[i], self.n[i+1])) # # 正态分布随机化
self.delta_w[i] = np.zeros((self.n[i], self.n[i+1])) # 全零
下面完整代码是我学习斯坦福机器学习教程,完全自己敲出来的:
import numpy as np
'''
参考:http://ufldl.stanford.edu/wiki/index.php/%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C
'''
class NeuralNetworks(object):
''''''
def __init__(self, n_layers=None, active_type=None, n_iter=10000, error=0.05, alpha=0.5, lamda=0.4):
'''搭建神经网络框架'''
self.n_iter = n_iter # 迭代次数
self.error = error # 允许最大误差
self.alpha = alpha # 学习速率
self.lamda = lamda # 衰减因子 # 此处故意拼写错误!
if n_layers is None:
raise '各层的节点数目必须设置!'
elif not isinstance(n_layers, list):
raise 'n_layers必须为list类型,如:[3,4,2] 或 n_layers=[3,4,2]'
# 节点数目 (向量)
self.n = np.array(n_layers)
self.size = self.n.size # 层的总数
# 层 (向量)
self.a = np.empty(self.size, dtype=object) # 先占位(置空),dtype=object !如下皆然
self.z = np.empty(self.size, dtype=object)
# 偏置 (向量)
self.b = np.empty(self.size, dtype=object)
self.delta_b = np.empty(self.size, dtype=object)
# 权 (矩阵)
self.w = np.empty(self.size, dtype=object)
self.delta_w = np.empty(self.size, dtype=object)
# 残差 (向量)
self.data_a = np.empty(self.size, dtype=object)
# 填充
for i in range(self.size):
self.a[i] = np.zeros(self.n[i]) # 全零
self.z[i] = np.zeros(self.n[i]) # 全零
self.data_a[i] = np.zeros(self.n[i]) # 全零
if i < self.size - 1:
self.b[i] = np.ones(self.n[i+1]) # 全一
self.delta_b[i] = np.zeros(self.n[i+1]) # 全零
mu, sigma = 0, 0.1 # 均值、方差
self.w[i] = np.random.normal(mu, sigma, (self.n[i], self.n[i+1])) # # 正态分布随机化
self.delta_w[i] = np.zeros((self.n[i], self.n[i+1])) # 全零
# 激活函数
self.active_functions = {
'sigmoid': self.sigmoid,
'tanh': self.tanh,
'radb': self.radb,
'line': self.line,
}
# 激活函数的导函数
self.derivative_functions = {
'sigmoid': self.sigmoid_d,
'tanh': self.tanh_d,
'radb': self.radb_d,
'line': self.line_d,
}
if active_type is None:
self.active_type = ['sigmoid'] * (self.size - 1) # 默认激活函数类型
else:
self.active_type = active_type
def sigmoid(self, z):
if np.max(z) > 600:
z[z.argmax()] = 600
return 1.0 / (1.0 + np.exp(-z))
def tanh(self, z):
return (np.exp(z) - np.exp(-z)) / (np.exp(z) + np.exp(-z))
def radb(self, z):
return np.exp(-z * z)
def line(self, z):
return z
def sigmoid_d(self, z):
return z * (1.0 - z)
def tanh_d(self, z):
return 1.0 - z * z
def radb_d(self, z):
return -2.0 * z * np.exp(-z * z)
def line_d(self, z):
return np.ones(z.size) # 全一
def forward(self, x):
'''正向传播(在线)'''
# 用样本 x 走一遍,刷新所有 z, a
self.a[0] = x
for i in range(self.size - 1):
self.z[i+1] = np.dot(self.a[i], self.w[i]) + self.b[i]
self.a[i+1] = self.active_functions[self.active_type[i]](self.z[i+1]) # 加了激活函数
def err(self, X, Y):
'''误差'''
last = self.size-1
err = 0.0
for x, y in zip(X, Y):
self.forward(x)
err += 0.5 * np.sum((self.a[last] - y)**2)
err /= X.shape[0]
err += sum([np.sum(w) for w in self.w[:last]**2])
return err
def backward(self, y):
'''反向传播(在线)'''
last = self.size - 1
# 用样本 y 走一遍,刷新所有delta_w, delta_b
self.data_a[last] = -(y - self.a[last]) * self.derivative_functions[self.active_type[last-1]](self.z[last]) # 加了激活函数的导函数
for i in range(last-1, 1, -1):
self.data_a[i] = np.dot(self.w[i], self.data_a[i+1]) * self.derivative_functions[self.active_type[i-1]](self.z[i]) # 加了激活函数的导函数
# 计算偏导
p_w = np.outer(self.a[i], self.data_a[i+1]) # 外积!感谢 numpy 的强大!
p_b = self.data_a[i+1]
# 更新 delta_w, delta_w
self.delta_w[i] = self.delta_w[i] + p_w
self.delta_b[i] = self.delta_b[i] + p_b
def update(self, n_samples):
'''更新权重参数'''
last = self.size - 1
for i in range(last):
self.w[i] -= self.alpha * ((1/n_samples) * self.delta_w[i] + self.lamda * self.w[i])
self.b[i] -= self.alpha * ((1/n_samples) * self.delta_b[i])
def fit(self, X, Y):
'''拟合'''
for i in range(self.n_iter):
# 用所有样本,依次
for x, y in zip(X, Y):
self.forward(x) # 前向,更新 a, z;
self.backward(y) # 后向,更新 delta_w, delta_b
# 然后,更新 w, b
self.update(len(X))
# 计算误差
err = self.err(X, Y)
if err < self.error:
break
# 整千次显示误差(否则太无聊!)
if i % 1000 == 0:
print('iter: {}, error: {}'.format(i, err))
def predict(self, X):
'''预测'''
last = self.size - 1
res = []
for x in X:
self.forward(x)
res.append(self.a[last])
return np.array(res)
if __name__ == '__main__':
nn = NeuralNetworks([2,3,4,3,1], n_iter=5000, alpha=0.4, lamda=0.3, error=0.06) # 定义神经网络
X = np.array([[0.,0.], # 准备数据
[0.,1.],
[1.,0.],
[1.,1.]])
y = np.array([0,1,1,0])
nn.fit(X,y) # 拟合
print(nn.predict(X)) # 预测
数据分析咨询请扫描二维码
若不方便扫码,搜微信号:CDAshujufenxi
在数据驱动运营的时代,指标是连接业务目标与实际行动的核心桥梁,是企业解读业务现状、发现问题、预判趋势的“量化标尺”。一套 ...
2026-05-08在存量竞争日趋激烈的商业时代,“以客户为中心”早已从口号落地为企业运营的核心逻辑。而客户画像作为打通“了解客户”与“服务 ...
2026-05-08 很多数据分析师每天与Excel打交道,但当被问到“什么是表格结构数据”“它和表结构数据有什么区别”“表格结构数据有哪些核 ...
2026-05-08在数据分析、计量研究等场景中,回归分析是探究变量间量化关系的核心方法,无论是简单的一元线性回归,还是复杂的多元线性回归、 ...
2026-05-07在数据分析、计量研究等场景中,回归分析是探究变量间量化关系的核心方法,无论是简单的一元线性回归,还是复杂的多元线性回归、 ...
2026-05-07 很多数据分析师画过趋势图、做过业绩预测,但当被问到“这个月销售额增长20%,到底是长期趋势自然增长,还是促销活动的短期 ...
2026-05-07在数字化时代,商业竞争的核心已从“经验驱动”转向“数据驱动”,越来越多的企业意识到,商业分析不是简单的数据统计与报表呈现 ...
2026-05-06在Excel数据透视表的实操中,“引用”是连接透视表与公式、辅助数据的核心操作,而相对引用作为最基础、最常用的引用方式,其设 ...
2026-05-06 很多数据分析师做过按月份的销售额趋势图,画过按天的流量折线图,但当被问到“时间序列和普通数据有什么本质区别”“季节性 ...
2026-05-06在Excel数据分析中,数据透视表是汇总、整理海量数据的高效工具,而公式则是实现数据二次计算、逻辑判断的核心功能。实际操作中 ...
2026-04-30Excel透视图是数据分析中不可或缺的工具,它能将透视表中的数据快速可视化,帮助我们直观捕捉数据规律、呈现分析结果。但在实际 ...
2026-04-30 很多数据分析师能熟练地计算指标、搭建标签体系,但当被问到“画像到底在解决什么问题”“画像和标签是什么关系”“画像如何 ...
2026-04-30在中介效应分析中,人口统计学变量(如年龄、性别、学历、收入、职业等)是常见的控制变量或调节变量,其处理方式直接影响分析结 ...
2026-04-29在SQL数据库实操中,日期数据的存储与显示是高频需求,而“数字日期”(如20240520、20241231、45321)是很多开发者、数据分析师 ...
2026-04-29 很多分析师在设计标签时思路清晰,但真到落地环节却面临“数据在手,不知如何转化为可用标签”的困境:或因加工方式选择不当 ...
2026-04-29在手游行业竞争日趋白热化的当下,“流量为王”早已升级为“留存为王”,而付费用户留存率更是衡量一款手游盈利能力、运营质量的 ...
2026-04-28在日常MySQL数据库运维与开发中,经常会遇到“同一台服务器上,两个不同数据库(以下简称“源库”“目标库”)的表数据需要保持 ...
2026-04-28 很多分析师每天和数据打交道,但当被问到“标签是什么”“标签和指标有什么区别”“标签体系如何设计”时,却常常答不上来。 ...
2026-04-28箱线图(Box Plot)作为一种经典的数据可视化工具,广泛应用于统计学、数据分析、科研实证等领域,核心价值在于直观呈现数据的集 ...
2026-04-27实证分析是社会科学、自然科学、经济管理等领域开展研究的核心范式,其核心逻辑是通过对多维度数据的收集、分析与解读,揭示变量 ...
2026-04-27