如何更好地理解神经网络的正向传播?我们需要从「矩阵乘法」入手
moboyou 2025-05-22 12:11 22 浏览
图:pixabay
原文来源:medium
作者:Matt Ross
「机器人圈」编译:嗯~阿童木呀、多啦A亮
介绍
我为什么要写这篇文章呢?主要是因为我在构建神经网络的过程中遇到了一个令人沮丧的bug,最终迫使我进入该系统,并且真正了解了神经网络的核心的线性代数。我发现我已经做得很好了,而这只需要确保两个相乘矩阵的内部维度相匹配,而当发生bug时,我只是将各自将矩阵转置于不同的位置,直到事情解决。但是其中有一个隐藏的事实,那就是我并没有真正了解矩阵乘法运算的每一步。
我们将通过遍历正向传播的每一个步骤来计算一个相当简单的神经网络的成本函数。当然,如果你想知道我的矩阵乘法那些由于疏忽所引起的错误,那就是我把偏差单元(1的向量)加在一个列上,而它的正确位置应该是在一个行上。我这样做是因为在这一步骤之前,我并没有真正了解矩阵乘法的全部输出,所以没有意识到我必须做出改变。首先,我将介绍一个在神经网络正向传播中所发生事情的高级背景原因,然后我们将仔细研究一个特定的范例,并使用指数和代码让事情更清晰。
因此,神经网络对复杂关系进行建模是不可思议的。我们即将谈论的只是网络的前馈传播部分。现在,神经网络的输入单元可以是任何东西。例如,它们可以是表示一堆手写数字的20像素×20像素图像的灰度强度(介于0和1之间)。在这种情况下,你将拥有400个输入单元。现在我们有2个输入单元,加上我们的+1偏差单元(为什么要有偏差单元?答案在这里
https://www.quora.com/What-is-bias-in-artificial-neural-network)。正向传播本质上是从一个例子(例如手写数字的那些图像中的一个)中获取每个输入,然后将输入值乘以单元/节点之间的每个连接的权重(见图5),然后将所有连接的所有乘积相加到你正在计算激活的节点中,然后将获取总和(z),这一点可通过sigmoid函数实现(见下文)。
图1 Sigmoid函数
因此你就会获得隐藏层的每一个单元的激活。然后,你可以使用相同的方法来计算下一层,但是这次你将使用隐藏层的激活作为“输入”值。你将所有a^2激活(即隐藏层)单元乘以第二组权重Theta2,将连接到单个最终输出单元的每个乘积相加,并将该乘积通过Sigmoid函数加以计算,以获得最终的输出激活a^3。g(z)是Sigmoid函数,z是x输入(或隐藏层中的激活)和权重θ(由图5中的正常神经网络图中的单个箭头表示)的乘积。
图2 应用Sigmoid函数的假设函数
一旦你拥有这一切,你想计算网络的成本(图4)。你的成本函数本质上将计算出给出范例的输出假设h(x)和实际y值之间的成本/差异。所以,在我继续使用的范例中,y是由输入所表示的实际数字的值。如果在网络中有一个“4”馈送图像,则y就是值“4”。由于存在多个输出单元,所以成本函数将h(x)与输出相比较,相对于列向量,其中第4行为1,其余的都为0。意思是表示“4”输出的输出单元为真,其余为假。对于1,2或n的输出,请参见下文。
图3 我们的样本数据y值表示为逻辑真/假列向量
图4 多级Logistic回归成本函数
上述成本函数J(theta)中的两个Sigma将总结出你通过网络(m)和每个单个输出级(K)提供的每个示例的成本。现在,你可以通过单独进行每个计算来实现这一点,但是事实证明,人类已经定义的矩阵乘法的方法使得它能够完美地同时执行所有这些正向传播计算。我们那些专攻数值计算的朋友已经优化了矩阵乘法函数,使得神经网络可以极大地输出假设。要编写我们的代码,以便我们能够同时进行所有的计算,而不是说在所有输入示例中的for循环中运行所有内容,这是一个称为向量化代码的过程。这在神经网络中是非常重要的,因为它们的计算成本已经足够高了,我们不需要任何循环来减慢我们的运算速度。
我们的网络范例
在我们的网络中,我们将有四级,即1,2,3,4,并将遍历这个计算的每个步骤。我们假设拥有的是一个已训练的网络,并已通过反向传播训练了Theta参数/权重。它将是一个3层网络(2个输入单元,2个隐藏层单元和4个输出单元)。网络和参数(又名权重)可以表示如下:
图5 具有权重的神经网络
在进一步深入之前,如果你不知道矩阵乘法是如何工作的,那就花费7分钟来看看可汗学院(Khan Academy)(
https://www.khanacademy.org/math/precalculus/precalc-matrices/multiplying-matrices-by-matrices/v/matrix-multiplication-intro),然后再看一两个范例,确保你对它的工作原理有一个直观的认识。再次强调,在进一步深入研究之前了解这一点很重要。
那我们就从所有的数据开始。我们的3个示例数据和相应的y输出值。这些数据不代表任何东西,它们只是数字,用来显示我们将要做的计算:
图6 数据
当然,如上所述,由于有4个输出单元,我们的数据必须表示为一个逻辑向量矩阵,三个示例输出中的每一个都要如此。我使用的是MATLAB,以便将我们的y向量转换成一个逻辑向量矩阵:
yv=[1:4] == y; %creating logical vectors of y values
图7 样本矩阵输出y数据变为逻辑向量
另外,请注意,我们的X数据没有足够的特征。在图5的神经网络中,当我们计算权重/参数和输入值的乘积时,我们就有了那个虚线偏差单元x(0)。这意味着我们需要将偏差单元添加到数据中,也就是说我们在矩阵的开头添加一个列:
X = [ones(m,1),X];
图8 附有偏差的数据,偏差即由图5中的神经网络虚线单位/节点所指代
数据X被定义为第一个输入层a^1的第一个激活值,所以如果你在代码(第3行)中看到了一个a^1,它只是指初始输入数据。网络中每个连接/箭头的权值或参数如下所示:
图9 我们的神经网络的第一组权重/参数,其指数与图5神经网络图的箭头匹配。
下面是我们将用于计算逻辑成本函数的完整代码,我们已经解决了第2行和第9行,但是我们将在本代码的其余部分慢慢地分解矩阵乘法和重要的矩阵操作:
1: m = size(X, 1);
2: X = [ones(m,1),X];
3: a1 = X;
4: z2 = Theta1*a1';
5: a2 = sigmoid(z2);
6: a2 = [ones(1,m);a2];
7: z3 = Theta2*a2;
8: a3 = sigmoid(z3);
9: yv=[1:4] == y;
10: J = (1/m) * (sum(-yv’ .* log(a3) — ((1 — yv’) .* log(1 — a3))));
11: J = sum(J);
首先,我们来进行正向传播的第一步,第4行代码。将每个范例的输入值乘以与其相应的权重。我总是想象输入值在图5的网络中沿着箭头流动,乘以权重,然后在激活单元/节点等待其他箭头进行乘法运算。然后,特定单元的整个激活值首先由这些箭头(输入x权重)的总和组成,然后该和通过Sigmoid函数进行操作(参见上面的图1)。
所以在这里很容易犯你的第一个矩阵乘法错误。由于我们的附有偏差的单元被添加到X(在这里也称为a^1)是一个3x3矩阵,而我们的Theta1是一个2x3矩阵。由于Theta:2x3和X:3x3的两个内部维度是相同的,因此把两个矩阵相乘就变得很简单了,其结果应该是正确的,且会给出我们的2x3合成矩阵?对不起,错了!
z2 = Theta1 * a1; %WRONG! THIS DOESN'T GIVE US WHAT WE WANT
尽管运行这个计算将输出一个我们期望和需要将其用于下一步的正确维度的矩阵,但是所有计算的值将是错误的,因此所有的计算都将从这里开始表现为错误的。另外,由于没有计算机错误,所以很难判断为什么网络成本函数计算出了错误的成本,如果你注意到了,请记住,当进行矩阵乘法时,得到的矩阵的每个元素ab是第1矩阵中的行a和第二个矩阵中的列b的点积和。如果我们使用上面的代码来计算z^2,则得到的矩阵中的第一个元素将由我们的第一行Theta的[0.1 0.3 . 0.5]与整列偏差单元相乘得到,[1.000;1.000; 1.000],这对我们没有用。这意味着我们需要将范例的输入数据矩阵进行转置,使得矩阵将每个theta与每个输入正确相乘:
z2 = Theta1*a1';
矩阵乘法的运算如下:
图10 矩阵乘法的指数符号表示。列中的结果元素表示单个示例,并且行是隐藏层中的不同激活单元。每个示例中 2个隐藏层导致两个值(或行)。
然后,我们将上述z^2矩阵中的6个元素中的每个单元应用于Sigmoid函数:
a2 = sigmoid(z2);
这为我们提供了三个示例中每两个隐藏单元的隐藏层激活值的2×3矩阵:
图11 隐藏单元的激活值
因为这是作为矩阵乘法完成的,所以我们能够同时计算隐藏层的激活值,而不是在所有这些例子中使用for循环,当使用更大的数据集时,计算变得极其昂贵。更不用说再需要反向传播了。
现在我们具有第二层激活单元的值,它们作为输入到下一层和最后一层,即输出层。该层对于第2层和第3层之间的图5中的每个箭头,都有一组新的权重/参数Theta2,我们继续重复上面的步骤。将连接到每个激活节点的权重的激活值(输入)乘以连接到每个激活节点的产品,然后通过sigmoid函数运行每个激活节点和以获得最终输出。我们的a^2作为我们的输入数据,我们的权重/参数如下:
图12 带指数的Theta2权重/参数。每行表示对每个输出单元贡献的权重。
我们要做以下计算:
z3 = Theta2*a2;
但是在我们这样做之前,我们必须再次添加我们的偏差单元到我们的数据,在这种情况下,隐藏层激活a^2。如果你在图5中再次注意到,隐藏层中的虚线圆圈(0),仅在下一次计算时才添加的偏置单元。因此,我们把它添加到上面图11所示的激活矩阵中。
介绍我犯过的错误就是我写这篇文章的动机。要正向传播激活值,我们将Theta中的一行的每个元素与a2中的每个元素相乘,并且这些乘积的总和将给出所得到的z^3矩阵的单个元素。通常,数据结构的方式是将偏差单元添加为列,但是如果你这样做(我愚蠢地做了),这将会给我们一个错误的结果。所以我们将偏置单位作为一行添加到a^2中。
a2 = [ones(1,m);a2];
图13 将偏移行添加到a^2激活中
在我们运行矩阵乘法以计算z^3之前,请注意,在z^2之前,你必须转置输入数据a 1,使其对于矩阵乘法“正确排列”,以计算出我们想要的结果。这里,我们的矩阵按照我们想要的方式排列,所以没有转置a^2矩阵。这是另一个常见的错误,如果你不了解这个核心的计算,那么很容易犯这个错误(我过去对此非常内疚)。现在我们可以在4x3和3x3矩阵上运行矩阵乘法,得到3个例子中的每一个的4×3矩阵输出假设:
z3 = Theta2*a2;
图14 矩阵乘法的指数符号表示。列中的合成元素代表单个示例,并且行是输出层的不同激活单元,共有四个输出单元。在分类问题中,这意味着四个类/类别。还需要注意的是,每个元素中的所有a的[m]上标指数是示例编号。
然后我们对z^2矩阵中的12个元素中的每一个元素使用sigmoid 函数:
a3 = sigmoid(z3);
这为我们每个输出单元/类提供了一个4x3矩阵的输出层激活(类似于假设):
图15 每个示例网络的每个输出单元的激活值。如果你在所有的示例中做一个循环,这将是一个列向量,而不是一个矩阵。
从这里,你只是计算成本函数。唯一需要注意的是,你必须转置y向量的矩阵,以确保你在成本函数中正在进行的元素操作与每个示例和输出单元完全对齐。
图16 逻辑y向量矩阵的转置
然后我们把它们放在一起来计算成本函数:
图4 多级Logistic回归成本函数
J = (1/m) * (sum(-yv’ .* log(a3) — ((1 — yv’) .* log(1 — a3))));
J = sum(J);
这就是我们的成本,应该注意的是以计算所有类以及所有示例的双倍总和。这就是所有人。矩阵乘法可以使这个代码非常整齐和高效,不需要让循环减慢,但是你必须知道矩阵乘法中发生了什么,以便你可以适当地调整矩阵,无论是乘法顺序,必要时进行转置,并将偏差单元添加到矩阵的正确区域。一旦你把它打破了,掌握得就更加直观,我强烈推荐,如果你仍然不确定,慢慢地像这样通过一个示例,它总是归结为一些非常简单的基本原理。
我希望这对于正向传播所需的线性代数的掌握,同时去神秘化是非常有帮助的。
原文链接:https://medium.com/@
matt.as.ross/under-the-hood-of-neural-network-forward-propagation-the-dreaded-matrix-multiplication-a5360b33426
- 上一篇:BP神经网络在射频识别定位系统中的应用研究
- 下一篇:【电机控制】FOC电机控制
相关推荐
- Node.js 获取文件信息及路径(node.js怎么获取当前文件路径)
-
获取文件信息每个文件都有一组细节,我们可以使用Node.js进行检查。特别是使用fs模块提供的stat()方法。constfs=require('fs');fs.stat(...
- 深入剖析JavaScript中深浅拷贝(js实现深浅拷贝)
-
大家好,我是Echa。最近有一位00后的小妹妹粉丝私信小编说自己很喜欢编程,目前在某公司实习前端开发工作,说到现在为止还没有搞懂JavaScript中深拷贝和浅拷贝这个问题,同时也在网上看了很多关于深...
- 为什么高手写 JS 总是又快又好?这10个技巧你要知道
-
大家好,很高兴又见面了,我是"高级前端进阶",由我带着大家一起关注前端前沿、深入前端底层技术,大家一起进步,也欢迎大家关注、点赞、收藏、转发!JavaScript是前端开发的重要语言...
- IT技术栈:Javascript神器,URL.createObjectURL()
-
URL.createObjectURL()是JavaScript中的一个方法,用于创建一个特殊的URL,该URL可以用于将不支持直接加载的数据(如二进制数据或Blob对象)嵌入到we...
- 如何在 Linux 中创建和管理组?(linux如何建立组)
-
在Linux中,组是用户账户的集合,用于统一管理权限。每个用户至少属于一个主组(PrimaryGroup),还可以加入多个附加组(SupplementaryGroup)。组的权限设置决定了用户对文...
- 付费文库内容无法复制,不用任何工具,学会这4种方法轻松复制
-
关注职场办公,分享实用干货,洞察科技资讯,这里是「职场科技范」。我们在搜索资料的时候,看到非常有用的文库,但往往都是付费的,只能看不能复制。今天就来教大家,学会下面这4种方法,轻松复制文库内容。一、内...
- node.js v24.0.0 正式发布!10大重磅更新助力开发者,性能大幅提升
-
近日,Node.js官方团队正式发布了Node.jsv24.0.0版本,这是一个具有里程碑意义的重大更新。作为"Current"版本,它将在未来六个月内引领Node.js...
- 我理解的网站产品经理之四:网站产品前端姿势
-
来人人都是产品经理【起点学院】,BAT实战派产品总监手把手系统带你学产品、学运营。2016年了,嗨,大家新年好。作为一个网页的产品经理,网页的前端知识可谓是不能不知,本文主讲网站产品的前端姿势。通常,...
- 五一我要看七天小说!免费开源的轻量化书库talebook搭建流程。
-
这次来分享一个简单阅读项目:TaleBook,项目曾用名calibre-webserver。TaleBook是一个基于Calibre的简单的个人图书管理系统,支持在线阅读。不过鉴于各种规章制度,仅...
- “5 分钟 CMake 使用指南,解决我的 C++ 打包问题!”
-
在软件开发的世界里,构建系统扮演着至关重要的角色,它不仅决定了项目的构建效率,还直接影响到团队协作的流畅度。对于许多C++开发者而言,CMake因其强大的功能和广泛的兼容性成为了构建自动化流程的...
- 大佬级鬼才终于把JavaScript整理成了修仙小说,让学习变简单
-
这是一本讲解JavaScript编程语言的技术书籍,只不过,本书采用了一种全新的写作手法。如果你厌倦了厚厚的、如同字典般的编程书籍,不妨尝试一下新的口味,话不多说,直接上干货!目录截图:内容展示:以上...
- JavaScript基础知识点总结(javascript基础入门教程)
-
//逗比小憨憨/*第一章*HTML引用js方法:*1,外部引用:HTML外部引用js:<scriptsrc="js/day1.js"></script>*2,...
- 在Node.js中处理Zip文件(node运行js文件)
-
作者:疯狂的技术宅转发链接:https://mp.weixin.qq.com/s/edJd9-t1AyTGRcha_1k6RA前言Zip文件是常用的压缩文件格式。在本文中,我将演示如何用adm-...
- Python 标准库中鲜为人知的宝藏 | Node.js 22.8.0 发布
-
Python标准库中鲜为人知的宝藏Python标准库功能强大,但有些模块却鲜为人知。本文将介绍一些有趣且实用的模块,助你提升代码效率和功能。数据结构:超越列表和字典除了常用的列表和字典,coll...
- 小程序,wxml页面里如何写JS代码?WXS如何模块化?
-
这篇接着上篇小程序,跳转页面的两种方式及其页面传参数继续讲,小程序wxml页面里如何写JS代码?wxs如何模块化?第一个问题:wxml页面要想类似HTML页面中写js代码,必须在页面中使用wxs标...
- 一周热门
- 最近发表
-
- Node.js 获取文件信息及路径(node.js怎么获取当前文件路径)
- 深入剖析JavaScript中深浅拷贝(js实现深浅拷贝)
- 为什么高手写 JS 总是又快又好?这10个技巧你要知道
- IT技术栈:Javascript神器,URL.createObjectURL()
- 如何在 Linux 中创建和管理组?(linux如何建立组)
- 付费文库内容无法复制,不用任何工具,学会这4种方法轻松复制
- node.js v24.0.0 正式发布!10大重磅更新助力开发者,性能大幅提升
- 我理解的网站产品经理之四:网站产品前端姿势
- 五一我要看七天小说!免费开源的轻量化书库talebook搭建流程。
- “5 分钟 CMake 使用指南,解决我的 C++ 打包问题!”
- 标签列表
-
- 外键约束 oracle (36)
- oracle的row number (32)
- 唯一索引 oracle (34)
- oracle in 表变量 (28)
- oracle导出dmp导出 (28)
- oracle两个表 (20)
- oracle 数据库 字符集 (20)
- oracle安装补丁 (19)
- matlab化简多项式 (20)
- 多线程的创建方式 (29)
- 多线程 python (30)
- java多线程并发处理 (32)
- 宏程序代码一览表 (35)
- c++需要学多久 (25)
- css class选择器用法 (25)
- css样式引入 (30)
- html5和css3新特性 (19)
- css教程文字移动 (33)
- php简单源码 (36)
- php个人中心源码 (25)
- 网站管理平台php源码 (19)
- php小说爬取源码 (23)
- github好玩的php项目 (18)
- 云电脑app源码 (22)
- js创建txt文件 (18)