最近火到出圈的一篇数学论文,到底说了什么?为什么能掀起波澜?
moboyou 2025-05-18 14:31 2 浏览
四月,arXiv上出现了一篇题为《KAN: Kolmogorov-Arnold Networks》的论文。该论文获得约5000个赞,对于一篇学术论文来说,可谓是相当火爆。随附的GitHub库已有7600多个星标,且数字还在持续增长。
Kolmogorov-Arnold 网络(KAN)是一种全新的神经网络构建块。它比多层感知器(MLP)更具表达力、更不易过拟合且更易于解释。多层感知器在深度学习模型中无处不在。例如,我们知道它们被用于GPT-2、3以及(可能的)4等模型的Transformer模块之间。对MLP的改进将对机器学习世界产生广泛的影响。
MLP
MLP实际上是一种非常古老的架构,可以追溯到50年代。其设计初衷是模仿大脑结构;由许多互联的神经元组成,这些神经元将信息向前传递,因此得名前馈网络(feed-forward network)。
MLP通常通过类似上图的示意图来展示。对于外行来说,这很有用,但在我看来,它并没有传达出真正正在发生的事情的深刻理解。用数学来表示它要容易得多。
假设有一些输入x和一些输出y。一个两层的MLP将如下所示:
其中W是可学习权重的矩阵,b是偏差向量。函数f是一个非线性函数。看到这些方程,很明显,一个MLP是一系列带有非线性间隔的线性回归模型。这是一个非常基本的设置。
尽管基本,但它表达力极强。有数学保证,MLP是通用逼近器,即:它们可以逼近任何函数,类似于所有函数都可以用泰勒级数来表示。
为了训练模型的权重,我们使用了反向传播(backpropagation),这要归功于自动微分(autodiff)。我不会在这里深入讨论,但重要的是要注意自动微分可以对任何可微函数起作用,这在后面会很重要。
MLP的问题
MLP在广泛的用例中被使用,但存在一些严重的缺点。
- 因为它们作为模型极其灵活,可以很好地适应任何数据。结果,它们很可能过拟合。
- 模型中往往包含大量的权重,解释这些权重以从数据中得出结论变得非常困难。我们常说深度学习模型是“黑盒”。
- 拥有大量的权重还意味着它们的训练可能会很长,GPT-3的大部分参数都在MLP层中。
Kolmogorov-Arnold 网络
Kolmogorov-Arnold 表示定理
Kolmogorov-Arnold 表示定理的目标类似于支撑MLP的通用逼近定理,但前提不同。它本质上说,任何多变量函数都可以用1维非线性函数的加法来表示。例如:向量v=(x1, x2)的除法运算可以用对数和指数代替:
为什么这会有用呢?这究竟实现了什么?
这为我们提供了一种不同但简单的范式来开始构建神经网络架构。作者声称,这种架构比使用多层感知器(MLP)更易于解释、更高效地使用参数,并且具有更好的泛化能力。在MLP中,非线性函数是固定的,在训练过程中从未改变。而在KAN中,不再有权重矩阵或偏差,只有适应数据的一维非线性函数。然后将这些非线性函数相加。我们可以堆叠越来越多的层来创建更复杂的函数。
B样条(B-splines)
在KAN中表示非线性的方式中有一点重要的是需要注意的。与MLP中明确定义的非线性函数(如ReLU()、Tanh()、silu()等)不同,KAN的作者使用样条。这些基本上是分段多项式。它们源自计算机图形领域,在该领域中,过度参数化并不是一个问题。
样条解决了在多个点之间平滑插值的问题。如果你熟悉机器学习理论,你会知道要在n个数据点之间完美插值,需要一个n-1阶的多项式。问题是高阶多项式可能变得非常曲折,看起来不平滑。
- 10个数据点被一个9阶多项式完美拟合
通过将分段多项式函数适应于数据点之间的部分,样条解决了这个问题。这里我们使用三次样条。
- 三次样条插值更好,但不能泛化
对于三次样条(样条的一种类型),为了确保平滑,需要在数据点(或结点)的位置对一阶和二阶导数设置约束。数据点两侧的曲线必须在数据点处具有匹配的一阶导数和二阶导数。
KAN使用的是B样条,另一种类型的样条,具有局部性(移动一个点不会影响曲线的整体形状)和匹配的二阶导数(也称为C2连续性)的特性。这样做的代价是实际上不会通过这些点(除了在极端情况下)。
- 3条B样条对应5个数据点。注意曲线是如何不通过数据点的。
在机器学习中,特别是在应用于物理学时,不经过每一个数据点是可以接受的,因为我们预计测量会有噪声。
这就是在KAN的计算图的每一个边缘发生的事情。一维数据用一组B样条进行拟合。
进入KAN
因此,现在我们在计算图的每个边缘都有一个分段的参数曲线。在每个节点,这些曲线被求和:我们之前看到,可以通过这种方式逼近任何函数。
为了训练这样的模型,我们可以使用标准的反向传播。在这种情况下,作者使用的是LBFGS(Limited-memory
Broyden-Fletcher-Goldfarb-Shanno),这是一种二阶优化方法(与Adam这种一阶方法相比)。另一个需要注意的细节是:在每个代表一维函数的边上,有一个B样条,但作者还增加了一个非线性函数:silu函数。
对此的解释不是很清楚,但很可能是由于梯度消失(这是我的猜测)。
我们来试用一下
我打算使用作者提供的代码,它运行得非常出色,有许多示例可以帮助我们更好地理解它。
他们使用由以下函数生成的合成数据:
定义模型
model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
这里定义了三个参数:
- 宽度,其定义方式与多层感知器(MLP)类似:一个列表,其中每个元素对应一个层,元素值是该层的宽度。在这种情况下,有三层;输入维度为2,有5个隐藏维度,输出维度为1
- 网格与B样条相关,它描述了数据点之间的网格可以有多细致。增加这个参数可以创建更多曲折的函数。
- k是B样条的多项式阶数,一般来说,三次曲线是个不错的选择,因为三次曲线对样条有很好的属性。
- seed,随机种子:样条的权重用高斯噪声随机初始化(就像在常规MLP中一样)。
训练
model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.0)
该库的API非常直观,我们可以看到我们正在使用LBFGS优化器,训练20步。接下来的两个参数与网络的正则化相关。
训练后的下一步是修剪模型,这会移除低于相关性阈值的边和节点,完成后建议重新训练一下。然后将每个样条边转换为符号函数(log、exp、sin等)。这可以手动或自动完成。库提供了一个极好的工具,借助model.plot()方法可以看到模型内部的情况。
# Code to fit symbolic functions to the fitted splines
if mode == "manual":
# manual mode
model.fix_symbolic(0, 0, 0, "sin")
model.fix_symbolic(0, 1, 0, "x^2")
model.fix_symbolic(1, 0, 0, "exp")
elif mode == "auto":
# automatic mode
lib = ["x", "x^2", "x^3", "x^4", "exp", "log", "sqrt", "sin", "abs"]
model.auto_symbolic(lib=lib)
一旦在每个边上设置了符号函数,就会进行最终的再训练,以确保每个边的仿射参数是合理的。
整个训练过程在下面的图表中总结。
- 使用KAN进行符号回归的示例。图片来自论文。
完整的训练代码如下所示:
# Define the model
model = KAN(width=[2, 5, 1], grid=5, k=3, seed=0)
# First training
model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.0)
# Prune edges that have low importance
model = model.prune()
# Retrain the pruned model with no regularisation
model.train(dataset, opt="LBFGS", steps=50)
# Find the symbolic functions
model.auto_symbolic(lib=["x", "x^2", "x^3", "x^4", "exp", "log", "sqrt", "sin", "abs"])
# Find the afine parameters of the fitted functions without regularisation
model.train(dataset, opt="LBFGS", steps=50)
# Display the resultant equation
model.symbolic_formula()[0][0] # Print the resultant symbolic function
一些思考
模型中有相当多的超参数可以调整。这些可以产生非常不同的结果。例如,在上面的示例中:将隐藏神经元的数量从5改为6意味着KAN找不到正确的函数。
在机器学习中,“超参数”(hyperparameters)是指那些在学习过程开始之前需要设置的参数。这些参数控制着训练过程的各个方面,但它们并不是通过训练数据自动学习得到的。超参数的设置对模型的性能和效率有着重要的影响。
- 由KAN[2,6,1]找到的结果函数
这种变化性是预期的,因为这种架构是全新的。花了几十年时间,人们才找到了调整MLP超参数(如学习率、批大小、初始化等)的最佳方式。
结论
MLP已经存在很长时间了,早该升级了。我们知道这种改变是可能的,大约6年前,LSTMs在序列建模中无处不在,后来被transformers作为标准的语言模型架构构建块所取代。如果MLP也能发生这种变化,那将是令人兴奋的。另一方面,这种架构仍然不稳定,而且运行效果并不是非常出色。时间将告诉我们,否能找到一种方法来绕过这种不稳定性并释放KAN的真正潜力,或者KAN是否会被遗忘,成为机器学习的一个小知识点。
我对这种新架构感到非常兴奋,但我也持怀疑态度。
相关推荐
- 人工智能所有必要的数学概念:机器学习和深度学习
-
人工智能和数学之间的这种联系的快速概述是:缺乏数学技能的人工智能专家相当于缺乏说服力的政治家。每个人都有一个需要关注的领域!我不会进一步详细说明理解数学对AI的重要性,而是直奔本文的要点。为AI...
- 「数学」微分方程第一步,吃透概念-复数,多项式方程及矩阵理论
-
最近我开启了“量子力学之路”系列,旨在从数理角度从零解释量子力学。正如我在系列的第一篇文章量子力学之路——坚实的数理基础至关重要,没有捷径可走中提到的那样,学习量子力学有一些先决条件,而一些先决条件并...
- 量子计算(七):量子系统
-
量子系统前言对于一个非物理专业的人而言,量子力学概念晦涩难懂。鉴于此,本文仅介绍量子力学的一些基础概念加之部分数学的相关知识,甚至不涉及薛定谔方程,就足够开始量子计算机的应用。这如同不需去了解CPU的...
- 什么是正定矩阵?它的几何解释有助于我们直观地理解它。
-
正定矩阵定义为每个特征值为正的对称矩阵。好吧,但你可能想知道,“我们为什么要定义这样的东西?它在某种程度上有用吗?为什么特征值的符号很重要?”这很好,但是你能提供更多的想法来支持它吗?正定矩阵的几何解...
- 实对称矩阵的几个性质
-
实对称矩阵是一种非常重要的矩阵,这里列出它的几个重要性质,以供参考:证明过程中用到的方法就是取转置和共轭,以及两个复数乘积的共轭等于两个复数共轭的乘积的性质。因为A是对称阵,所以A可以相似对角化,A=...
- 三分钟秒懂矩阵的所有概念
-
(1)矩阵矩阵就像是一幅由许多小格子组成的画,每个格子都是一个颜色或图案。比如,一个17x11的矩阵画就是一个17行11列的画,每个小格子都有不同的颜色或图案。(2)矩阵的秩秩就像是画中的“独立颜色数...
- 大一新生开发的小工具火了!可视化Python编程体验了解一下
-
鱼羊发自凹非寺量子位报道|公众号QbitAI普普通通黑底白字地敲代码太枯燥?那么,把Python脚本可视化怎么样?就像这样,从输入图片、调整尺寸到双边滤波,每一步都能看得清清楚楚明明白白。...
- Python 数据分析——SciPy 线性代数-linalg
-
NumPy和SciPy都提供了线性代数函数库linalg,SciPy的线性代数库比NumPy更加全面。一、解线性方程组numpy.linalg.solve(A,b)和scipy.linalg.sol...
- 广义切比雪夫滤波器函数综合
-
主要分享《通信系统微波滤波器——基础、设计与应用》书籍中相关章节的个人理解与感悟,如有错误欢迎批评指正!这一节主要计算广义切比雪夫滤波器的多项式函数。如果一个二端口网络是无耗并且互易的,则S参数矩阵可...
- 基于基扩展模型的LTE-R信道估计算法
-
邓玲,陈忠辉,赵宜升(福州大学物理与信息工程学院,福建福州350108)摘要:针对LTER通信系统,对快时变信道估计问题进行了研究。采用基扩展模型对高速铁路通信环境的快时变信道进行拟合,将信道冲...
- 一种基于相干波束形成的零陷加宽算法
-
摘要:针对干扰信号和期望信号相干导致“干扰欠相消”以及由于干扰扰动而无法去除的问题,提出了一种基于前后向空间平滑的零陷加宽算法。该算法首先通过前后空间平滑方法去相干,并利用最佳下降的递推方法求得最...
- [高等数学] 矩阵的奇异值分解的详细证明及计算实例
-
[高等数学]矩阵的奇异值分解的详细证明及计算实例目录1定义及介绍2详细证明3计算实例4程序正文1定义及介绍奇异值分解(SingularValueDecomposition,SVD)是...
- 运动控制功能开了挂的S7-200 SMART V3-凸轮功能
-
1、S7-200SMARTV3凸轮功能限制2、组态凸轮电子凸轮根据预定义的电子凸轮表,使用脉冲串控制从轴与主轴同步凸轮表是一份数据表,用于指定跟随主轴移动的从轴的位置。水平轴代表主轴相位,而垂直轴...
- 平均7倍实测加速,MIT提出高效、硬件友好的三维深度学习方法
-
机器之心发布机器之心编辑部随着三维深度学习越来越成为近期研究的热点,基于栅格化的数据处理方法也越来越受欢迎。但这种处理方法往往受限于高分辨下巨大的内存和计算开销,因此麻省理工学院HANLab的研...
- Python数学建模系列(四):数值逼近
-
若文中数学公式显示有问题可查看文章原文Python数学建模系列(四):数值逼近菜鸟学习记:第四十二天1.一维插值插值:求过已知有限个数据点的近似函数。插值函数经过样本点,拟合函数一般基于最小二乘法...
- 一周热门
- 最近发表
- 标签列表
-
- curseforge官网网址 (16)
- 外键约束 oracle (36)
- oracle的row number (32)
- 唯一索引 oracle (34)
- oracle in 表变量 (28)
- oracle导出dmp导出 (28)
- oracle 数据导出导入 (16)
- oracle两个表 (20)
- oracle 数据库 使用 (12)
- 启动oracle的监听服务 (13)
- oracle 数据库 字符集 (20)
- powerdesigner oracle (13)
- oracle修改端口 (15)
- 左连接 oracle (15)
- oracle 标准版 (13)
- oracle 转义字符 (14)
- asp 连接 oracle (12)
- oracle安装补丁 (19)
- matlab三维图 (12)
- matlab归一化 (16)
- matlab求解方程 (13)
- matlab脚本 (14)
- matlab多项式拟合 (13)
- matlab阶跃函数 (14)
- 三次样条插值matlab (14)