如何更好地理解神经网络的正向传播?我们需要从「矩阵乘法」入手
moboyou 2025-05-15 20:00 24 浏览
图:pixabay
原文来源:medium
作者:Matt Ross
「机器人圈」编译:嗯~阿童木呀、多啦A亮
介绍
我为什么要写这篇文章呢?主要是因为我在构建神经网络的过程中遇到了一个令人沮丧的bug,最终迫使我进入该系统,并且真正了解了神经网络的核心的线性代数。我发现我已经做得很好了,而这只需要确保两个相乘矩阵的内部维度相匹配,而当发生bug时,我只是将各自将矩阵转置于不同的位置,直到事情解决。但是其中有一个隐藏的事实,那就是我并没有真正了解矩阵乘法运算的每一步。
我们将通过遍历正向传播的每一个步骤来计算一个相当简单的神经网络的成本函数。当然,如果你想知道我的矩阵乘法那些由于疏忽所引起的错误,那就是我把偏差单元(1的向量)加在一个列上,而它的正确位置应该是在一个行上。我这样做是因为在这一步骤之前,我并没有真正了解矩阵乘法的全部输出,所以没有意识到我必须做出改变。首先,我将介绍一个在神经网络正向传播中所发生事情的高级背景原因,然后我们将仔细研究一个特定的范例,并使用指数和代码让事情更清晰。
因此,神经网络对复杂关系进行建模是不可思议的。我们即将谈论的只是网络的前馈传播部分。现在,神经网络的输入单元可以是任何东西。例如,它们可以是表示一堆手写数字的20像素×20像素图像的灰度强度(介于0和1之间)。在这种情况下,你将拥有400个输入单元。现在我们有2个输入单元,加上我们的+1偏差单元(为什么要有偏差单元?答案在这里
https://www.quora.com/What-is-bias-in-artificial-neural-network)。正向传播本质上是从一个例子(例如手写数字的那些图像中的一个)中获取每个输入,然后将输入值乘以单元/节点之间的每个连接的权重(见图5),然后将所有连接的所有乘积相加到你正在计算激活的节点中,然后将获取总和(z),这一点可通过sigmoid函数实现(见下文)。
图1 Sigmoid函数
因此你就会获得隐藏层的每一个单元的激活。然后,你可以使用相同的方法来计算下一层,但是这次你将使用隐藏层的激活作为“输入”值。你将所有a^2激活(即隐藏层)单元乘以第二组权重Theta2,将连接到单个最终输出单元的每个乘积相加,并将该乘积通过Sigmoid函数加以计算,以获得最终的输出激活a^3。g(z)是Sigmoid函数,z是x输入(或隐藏层中的激活)和权重θ(由图5中的正常神经网络图中的单个箭头表示)的乘积。
图2 应用Sigmoid函数的假设函数
一旦你拥有这一切,你想计算网络的成本(图4)。你的成本函数本质上将计算出给出范例的输出假设h(x)和实际y值之间的成本/差异。所以,在我继续使用的范例中,y是由输入所表示的实际数字的值。如果在网络中有一个“4”馈送图像,则y就是值“4”。由于存在多个输出单元,所以成本函数将h(x)与输出相比较,相对于列向量,其中第4行为1,其余的都为0。意思是表示“4”输出的输出单元为真,其余为假。对于1,2或n的输出,请参见下文。
图3 我们的样本数据y值表示为逻辑真/假列向量
图4 多级Logistic回归成本函数
上述成本函数J(theta)中的两个Sigma将总结出你通过网络(m)和每个单个输出级(K)提供的每个示例的成本。现在,你可以通过单独进行每个计算来实现这一点,但是事实证明,人类已经定义的矩阵乘法的方法使得它能够完美地同时执行所有这些正向传播计算。我们那些专攻数值计算的朋友已经优化了矩阵乘法函数,使得神经网络可以极大地输出假设。要编写我们的代码,以便我们能够同时进行所有的计算,而不是说在所有输入示例中的for循环中运行所有内容,这是一个称为向量化代码的过程。这在神经网络中是非常重要的,因为它们的计算成本已经足够高了,我们不需要任何循环来减慢我们的运算速度。
我们的网络范例
在我们的网络中,我们将有四级,即1,2,3,4,并将遍历这个计算的每个步骤。我们假设拥有的是一个已训练的网络,并已通过反向传播训练了Theta参数/权重。它将是一个3层网络(2个输入单元,2个隐藏层单元和4个输出单元)。网络和参数(又名权重)可以表示如下:
图5 具有权重的神经网络
在进一步深入之前,如果你不知道矩阵乘法是如何工作的,那就花费7分钟来看看可汗学院(Khan Academy)(
https://www.khanacademy.org/math/precalculus/precalc-matrices/multiplying-matrices-by-matrices/v/matrix-multiplication-intro),然后再看一两个范例,确保你对它的工作原理有一个直观的认识。再次强调,在进一步深入研究之前了解这一点很重要。
那我们就从所有的数据开始。我们的3个示例数据和相应的y输出值。这些数据不代表任何东西,它们只是数字,用来显示我们将要做的计算:
图6 数据
当然,如上所述,由于有4个输出单元,我们的数据必须表示为一个逻辑向量矩阵,三个示例输出中的每一个都要如此。我使用的是MATLAB,以便将我们的y向量转换成一个逻辑向量矩阵:
yv=[1:4] == y; %creating logical vectors of y values
图7 样本矩阵输出y数据变为逻辑向量
另外,请注意,我们的X数据没有足够的特征。在图5的神经网络中,当我们计算权重/参数和输入值的乘积时,我们就有了那个虚线偏差单元x(0)。这意味着我们需要将偏差单元添加到数据中,也就是说我们在矩阵的开头添加一个列:
X = [ones(m,1),X];
图8 附有偏差的数据,偏差即由图5中的神经网络虚线单位/节点所指代
数据X被定义为第一个输入层a^1的第一个激活值,所以如果你在代码(第3行)中看到了一个a^1,它只是指初始输入数据。网络中每个连接/箭头的权值或参数如下所示:
图9 我们的神经网络的第一组权重/参数,其指数与图5神经网络图的箭头匹配。
下面是我们将用于计算逻辑成本函数的完整代码,我们已经解决了第2行和第9行,但是我们将在本代码的其余部分慢慢地分解矩阵乘法和重要的矩阵操作:
1: m = size(X, 1);
2: X = [ones(m,1),X];
3: a1 = X;
4: z2 = Theta1*a1';
5: a2 = sigmoid(z2);
6: a2 = [ones(1,m);a2];
7: z3 = Theta2*a2;
8: a3 = sigmoid(z3);
9: yv=[1:4] == y;
10: J = (1/m) * (sum(-yv’ .* log(a3) — ((1 — yv’) .* log(1 — a3))));
11: J = sum(J);
首先,我们来进行正向传播的第一步,第4行代码。将每个范例的输入值乘以与其相应的权重。我总是想象输入值在图5的网络中沿着箭头流动,乘以权重,然后在激活单元/节点等待其他箭头进行乘法运算。然后,特定单元的整个激活值首先由这些箭头(输入x权重)的总和组成,然后该和通过Sigmoid函数进行操作(参见上面的图1)。
所以在这里很容易犯你的第一个矩阵乘法错误。由于我们的附有偏差的单元被添加到X(在这里也称为a^1)是一个3x3矩阵,而我们的Theta1是一个2x3矩阵。由于Theta:2x3和X:3x3的两个内部维度是相同的,因此把两个矩阵相乘就变得很简单了,其结果应该是正确的,且会给出我们的2x3合成矩阵?对不起,错了!
z2 = Theta1 * a1; %WRONG! THIS DOESN'T GIVE US WHAT WE WANT
尽管运行这个计算将输出一个我们期望和需要将其用于下一步的正确维度的矩阵,但是所有计算的值将是错误的,因此所有的计算都将从这里开始表现为错误的。另外,由于没有计算机错误,所以很难判断为什么网络成本函数计算出了错误的成本,如果你注意到了,请记住,当进行矩阵乘法时,得到的矩阵的每个元素ab是第1矩阵中的行a和第二个矩阵中的列b的点积和。如果我们使用上面的代码来计算z^2,则得到的矩阵中的第一个元素将由我们的第一行Theta的[0.1 0.3 . 0.5]与整列偏差单元相乘得到,[1.000;1.000; 1.000],这对我们没有用。这意味着我们需要将范例的输入数据矩阵进行转置,使得矩阵将每个theta与每个输入正确相乘:
z2 = Theta1*a1';
矩阵乘法的运算如下:
图10 矩阵乘法的指数符号表示。列中的结果元素表示单个示例,并且行是隐藏层中的不同激活单元。每个示例中 2个隐藏层导致两个值(或行)。
然后,我们将上述z^2矩阵中的6个元素中的每个单元应用于Sigmoid函数:
a2 = sigmoid(z2);
这为我们提供了三个示例中每两个隐藏单元的隐藏层激活值的2×3矩阵:
图11 隐藏单元的激活值
因为这是作为矩阵乘法完成的,所以我们能够同时计算隐藏层的激活值,而不是在所有这些例子中使用for循环,当使用更大的数据集时,计算变得极其昂贵。更不用说再需要反向传播了。
现在我们具有第二层激活单元的值,它们作为输入到下一层和最后一层,即输出层。该层对于第2层和第3层之间的图5中的每个箭头,都有一组新的权重/参数Theta2,我们继续重复上面的步骤。将连接到每个激活节点的权重的激活值(输入)乘以连接到每个激活节点的产品,然后通过sigmoid函数运行每个激活节点和以获得最终输出。我们的a^2作为我们的输入数据,我们的权重/参数如下:
图12 带指数的Theta2权重/参数。每行表示对每个输出单元贡献的权重。
我们要做以下计算:
z3 = Theta2*a2;
但是在我们这样做之前,我们必须再次添加我们的偏差单元到我们的数据,在这种情况下,隐藏层激活a^2。如果你在图5中再次注意到,隐藏层中的虚线圆圈(0),仅在下一次计算时才添加的偏置单元。因此,我们把它添加到上面图11所示的激活矩阵中。
介绍我犯过的错误就是我写这篇文章的动机。要正向传播激活值,我们将Theta中的一行的每个元素与a2中的每个元素相乘,并且这些乘积的总和将给出所得到的z^3矩阵的单个元素。通常,数据结构的方式是将偏差单元添加为列,但是如果你这样做(我愚蠢地做了),这将会给我们一个错误的结果。所以我们将偏置单位作为一行添加到a^2中。
a2 = [ones(1,m);a2];
图13 将偏移行添加到a^2激活中
在我们运行矩阵乘法以计算z^3之前,请注意,在z^2之前,你必须转置输入数据a 1,使其对于矩阵乘法“正确排列”,以计算出我们想要的结果。这里,我们的矩阵按照我们想要的方式排列,所以没有转置a^2矩阵。这是另一个常见的错误,如果你不了解这个核心的计算,那么很容易犯这个错误(我过去对此非常内疚)。现在我们可以在4x3和3x3矩阵上运行矩阵乘法,得到3个例子中的每一个的4×3矩阵输出假设:
z3 = Theta2*a2;
图14 矩阵乘法的指数符号表示。列中的合成元素代表单个示例,并且行是输出层的不同激活单元,共有四个输出单元。在分类问题中,这意味着四个类/类别。还需要注意的是,每个元素中的所有a的[m]上标指数是示例编号。
然后我们对z^2矩阵中的12个元素中的每一个元素使用sigmoid 函数:
a3 = sigmoid(z3);
这为我们每个输出单元/类提供了一个4x3矩阵的输出层激活(类似于假设):
图15 每个示例网络的每个输出单元的激活值。如果你在所有的示例中做一个循环,这将是一个列向量,而不是一个矩阵。
从这里,你只是计算成本函数。唯一需要注意的是,你必须转置y向量的矩阵,以确保你在成本函数中正在进行的元素操作与每个示例和输出单元完全对齐。
图16 逻辑y向量矩阵的转置
然后我们把它们放在一起来计算成本函数:
图4 多级Logistic回归成本函数
J = (1/m) * (sum(-yv’ .* log(a3) — ((1 — yv’) .* log(1 — a3))));
J = sum(J);
这就是我们的成本,应该注意的是以计算所有类以及所有示例的双倍总和。这就是所有人。矩阵乘法可以使这个代码非常整齐和高效,不需要让循环减慢,但是你必须知道矩阵乘法中发生了什么,以便你可以适当地调整矩阵,无论是乘法顺序,必要时进行转置,并将偏差单元添加到矩阵的正确区域。一旦你把它打破了,掌握得就更加直观,我强烈推荐,如果你仍然不确定,慢慢地像这样通过一个示例,它总是归结为一些非常简单的基本原理。
我希望这对于正向传播所需的线性代数的掌握,同时去神秘化是非常有帮助的。
原文链接:https://medium.com/@
matt.as.ross/under-the-hood-of-neural-network-forward-propagation-the-dreaded-matrix-multiplication-a5360b33426
- 上一篇:电力系统EI会议·权威期刊推荐!
- 下一篇:我拿导弹公式算桃花,结果把自己炸成了烟花
相关推荐
- 原神:“天理”是什么?至今还有很多玩家没搞明白
-
原神已经更新到层岩巨渊,关于提瓦特的秘密却越来越多。然而,直到今天还有很多玩家不明白天理以及天理维系者的关系。这并不怪大家,因为剧情里根本没提,只能靠玩家去猜。天理是什么?在看完渊下宫的剧情之后,不少...
- 《原神》爆火3年仍无竞品:它的“致命武器”竟不是开放世界?
-
#原神的最大特点是什么?#《原神》爆火3年仍无竞品:它的“致命武器”竟不是开放世界?【独家观察】2023年8月,《原神》4.0版本“枫丹”上线首日登顶68国畅销榜,这个现象级产品再次向行业抛出灵魂拷问...
- 原神:每个人都是氪金大佬?除非在梦里!或许还有一种方法
-
游戏中的笔杆王者,每日靠玩游戏过日子,玩网游也有20个年头,我有自己独特的见解,作为一个10年不脱坑的老玩家,如果文章写的有什么问题,请重喷!如果大家觉得好,请转发加点赞!非常感谢!原神每个人都是氪金...
- 原神:丝柯克,又传新消息!入池时间,武器确定!第八元素无了!
-
原神:丝柯克,又传新消息!入池时间,武器确定!第八元素无了!新角色丝柯克已经确定会在5.7版本入池!作为公子的师傅,早在公子14岁时,丝柯克就已经是一位畅行于深渊的剑客了,如今成为执行官的公子,却只希...
- 原神服务端架构搭建工具+环境配置资料
-
我是艾西,今天给大家分享一份详细的原神服务端结构资料教程,从服务端的获取到端口的使用以及安卓和ios的DAIL签名等一文让你明白怎么架设原神服务端,哪些工具资料又代表着什么意思(保姆级教学)Grass...
- 在原神里钓鱼,有人竟然用上了深度强化学习,还把它开源了
-
机器之心报道机器之心编辑部还愁在《原神》里钓不到鱼吗?这有一份迟到的提瓦特钓鱼指南。在游戏圈,你可以没有玩过,但一定听过《原神》。虽然这是一款口碑两极分化的游戏,但不得不承认《原神》是当前最为火热的游...
- BetterGI:让原神游戏更便捷的自动化工具
-
技术背景BetterGI是一个基于计算机视觉技术的项目,旨在让原神游戏变得更加便捷。它利用视觉算法和模拟操作,实现了多种游戏内的自动化功能,帮助玩家节省时间和精力。实现步骤系统要求操作系统:Wind...
- 原神:2.6服务端泄露?9999纠缠之缘秒到账,米哈游跻身全球15强
-
首先,恭喜米哈游凭借原神跻身全球应用开发商第14名,这个榜单记录了全球开发商在iOS&GooglePlay综合收入前52名的数据,第一腾讯第二网易,字节跳动第7,米哈游则排14名。值得一提的是,这个...
- 观鸣潮1.3前瞻有感,《原神》是屎山代码?七个问题拷打米哈游!
-
哈喽大家好啊。前天看完鸣潮1.3直播后,感触良多啊。虽然我对1.3的前瞻内容觉得中规中矩,没有太满意,但是对面策划的态度让我看到了差距。所以今天来拷打一下原神。就是对比隔壁策划面对玩家的反馈,所回答的...
- PHP中的九大缓存技术(php中的九大缓存技术是什么)
-
1、全页面静态化缓存也就是将页面全部生成html静态页面,用户访问时直接访问的静态页面,而不会去走php服务器解析的流程。此种方式,在CMS系统中比较常见,比如dedecms;一种比较常用的实现方式是...
- 使用PhpStorm将代码同步到开发环境
-
配置步骤1、选择Tools>Deployment>Configuration:2、选择SFTP:3、输入servername:4、配置SSHconfiguration:5、配置...
- PM小技术:使用SAE发布在线Axure文档
-
俗话说,不会写代码的产品经理不是好的射鸡湿。关于产品经理与技术之间的微妙关系,扯开了讲可是长篇大论,比如知乎上这个问题:IT行业产品经理(尤其是创业的)需要懂技术吗?懂到什么程度?,以及这个:产品经...
- PHP新手如何提高代码质量(php代码教程)
-
1.不要使用相对路径常常会看到:require_once('../../lib/some_>该方法有很多缺点:它首先查找指定的php包含路径,然后查找当前目录.因此会检查过多路径.如果该脚本...
- PHP代码中常用的优化策略(php性能优化及安全策略)
-
1、如果能将类的方法定义成static,就尽量定义成static,它的速度会提升将近4倍。2、$row['id']的速度是$row[id]的7倍。3、echo比print快,并...
- PHP 没你想的那么差(php ml)
-
PHP现在名声很糟糕,因为它曾经是“可怕”的。本文试着回答一些常见的关于PHP的断言,目的是向非技术人员解释,PHP并不像许多人所说的那么糟糕。它是不是鼓励糟糕的实践?不再是了。过去,许多开发者...
- 一周热门
- 最近发表
- 标签列表
-
- 外键约束 oracle (36)
- oracle的row number (32)
- 唯一索引 oracle (34)
- oracle in 表变量 (28)
- oracle导出dmp导出 (28)
- oracle两个表 (20)
- oracle 数据库 字符集 (20)
- oracle安装补丁 (19)
- matlab化简多项式 (20)
- 多线程的创建方式 (29)
- 多线程 python (30)
- java多线程并发处理 (32)
- 宏程序代码一览表 (35)
- c++需要学多久 (25)
- c语言编程小知识大全 (17)
- css class选择器用法 (25)
- css样式引入 (30)
- html5和css3新特性 (19)
- css教程文字移动 (33)
- php简单源码 (36)
- php个人中心源码 (25)
- 网站管理平台php源码 (19)
- php小说爬取源码 (23)
- github好玩的php项目 (18)
- 云电脑app源码 (22)