行业报告 AI展会 数据标注 标注供求
数据标注数据集
主页 > 机器学习 正文

如何在Keras中创建自定义损失函数?

如何在Keras中创建自定义损失函数?
Dhruv Deshmukh 发布在 Unsplash 上的照片

我们使用损失函数来计算一个给定的算法与它所训练的数据的匹配程度。损失计算是基于预测值和实际值之间的差异来做的。如果预测值与实际值相差甚远,损失函数将得到一个非常大的数值。

Keras 是一个创建神经网络的库,它是开源的,用 Python 语言编写。Keras 不支持低级计算,但它运行在诸如 Theano 和 TensorFlow 之类的库上。

在本教程中,我们将使用 TensorFlow 作为 Keras backend。backend 是一个 Keras 库,用于执行计算,如张量积、卷积和其他类似的活动。

如何在Keras中创建自定义损失函数?
Karim MANJRA 发布在 Unsplash 上的照片

keras 中常用的损失函数

如上所述,我们可以创建一个我们自己的自定义损失函数;但是在这之前,讨论现有的 Keras 损失函数是很好的。下面是两个最常用的:

  • 均方误差

均方误差(MSE)测量误差平方的平均值。它是预测值和实际值之间的平均平方差。

  • 平均绝对误差

平均绝对误差(MAE)是两个连续变量之间差的度量,通常用 x 和 y 表示。平均绝对误差是绝对误差 e=y-x 的平均值,其中 y 是预测值,x 是实际值。

什么是自定义损失函数?

对于不同的损失函数,计算损失的公式有不同的定义。在某些情况下,我们可能需要使用 Keras 没有提供的损失计算公式。在这种情况下,我们可以考虑定义和使用我们自己的损失函数。这种用户定义的损失函数称为自定义损失函数。

Keras 中的自定义损失函数可以以我们想要的方式提高机器学习模型的性能,并且对于更有效地解决特定问题非常有用。例如,假设我们正在构建一个股票投资组合优化模型。在这种情况下,设计一个定制损失函数将有助于实现对在错误方向上预测价格变动的巨大惩罚。

我们可以通过编写一个返回标量并接受两个参数(即真值和预测值)的函数,在 Keras 中创建一个自定义损失函数。然后,我们将自定义损失函数传递给 model.compile 作为参数,就像处理任何其他损失函数一样。

实现自定义损失函数

现在让我们为我们的 Keras 模型实现一个自定义的损失函数。首先,我们需要定义我们的 Keras 模型。我们的模型实例名是 keras_model,我们使用 keras 的 sequential()函数来创建模型。

我们有三个层,都是形状为 64、64 和 1 的密集层。我们有一个为 1 的输入形状,我们使用 ReLU 激活函数(校正线性单位)。

如何在Keras中创建自定义损失函数?

一旦定义了模型,我们就需要定义我们的自定义损失函数。其实现如下所示。我们将实际值和预测值传递给这个函数。

注意,我们将实际值和预测值的差除以 10,这是损失函数的自定义部分。在缺省损失函数中,实际值和预测值的差值不除以 10。

记住,这完全取决于你的特定用例需要编写什么样的自定义损失函数。在这里我们除以 10,这意味着我们希望在计算过程中降低损失的大小。

在 MSE 的默认情况下,损失的大小将是此自定义实现的 10 倍。因此,当我们的损失值变得非常大并且计算变得非常昂贵时,我们可以使用这种定制的损失函数。

在这里,我们从这个函数返回一个标量自定义损失值。

如何在Keras中创建自定义损失函数?

定义 keras 的自定义损失函数

要进一步使用自定义损失函数,我们需要定义优化器。我们将在这里使用 RMSProp 优化器。RMSprop 代表均方根传播。RMSprop 优化器类似于具有动量的梯度下降。常用的优化器被命名为 rmsprop、Adam 和 sgd。

我们需要将自定义的损失函数和优化器传递给在模型实例上调用的 compile 方法。然后我们打印模型以确保编译时没有错误。

如何在Keras中创建自定义损失函数?

Keras 模型优化器和编译模型

现在是时候训练这个模型,看看它是否正常工作了。为此,我们在模型上使用拟合方法,传递自变量 x 和因变量 y 以及 epochs=100。

这里的目的是确保模型训练没有任何错误,并且随着 epoch 数的增加,损失逐渐减少。你可以查看下图中的模型训练的结果:

如何在Keras中创建自定义损失函数?

epoch=100 的 Keras 模型训练

结语

在本文中,我们了解了什么是自定义损失函数,以及如何在 Keras 模型中定义一个损失函数。然后,我们使用自定义损失函数编译了 Keras 模型。最后,我们成功地训练了模型,实现了自定义损失功能。

微信公众号

声明:本站部分作品是由网友自主投稿和发布、编辑整理上传,对此类作品本站仅提供交流平台,转载的目的在于传递更多信息及用于网络分享,并不代表本站赞同其观点和对其真实性负责,不为其版权负责。如果您发现网站上有侵犯您的知识产权的作品,请与我们取得联系,我们会及时修改或删除。

网友评论:

发表评论
请自觉遵守互联网相关的政策法规,严禁发布色情、暴力、反动的言论。
评价:
表情:
用户名: 验证码:点击我更换图片
SEM推广服务

Copyright©2005-2028 Sykv.com 可思数据 版权所有    京ICP备14056871号

关于我们   免责声明   广告合作   版权声明   联系我们   原创投稿   网站地图  

可思数据 数据标注

扫码入群
扫码关注

微信公众号

返回顶部