博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
BatchNormalization总结
阅读量:4154 次
发布时间:2019-05-25

本文共 1407 字,大约阅读时间需要 4 分钟。

目录

1、Internal Covariate Shift现象

网络在训练过程中,中间层的权重在不断发生变化,导致该层输出数据的分布发生改变,这种数据分布的改变称为’Internal Covariate Shift’。

2、BatchNormalization算法

为了解决Internal Covariate Shift现象,Sergey Ioffe and Christian Szegedy提出了BatchNormalization算法,This operation simply zero-centers and normalizes each input, then scales and shifts the result using two new parameter vectors per layer: one for scaling, the other for shifting,具体步骤如下所示:

在这里插入图片描述

  • μ B是输入均值向量,该向量元素个数等于输入个数(或者输入通道数),该值是在一个batch上计算得到的;
  • σ B是输入标准差向量,该向量元素个数等于输入个数(或者输入通道数),该值是在一个batch上计算得到的;
  • mB是batch size大小;
  • 在这里插入图片描述是一个批次中第i个输入实例的零中心和归一化后的值;
  • γ为缩放因子向量,该向量元素个数等于输入个数(或者输入通道数),在训练过程中不断学习更新;
  • β为平移因子向量,该向量元素个数等于输入个数(或者输入通道数),在训练过程中不断学习更新;
  • z(i)是一个批次中第i个输入实例经过缩放和平移后的结果。

在推理过程中,输入的batch size可能很小,甚至是1,因此求得的均值和标准差不具有代表性,实际此时输入的均值 μ B和标准差σ B实际采用的是整个训练集的均值和标准差,而这两个值在训练过程中是随着训练的推进逐渐迭代更新的,而不是训练完后再在整个数据集上计算均值和标准差,这两个在训练过程中不断迭代更新的均值和标准差分别称为moving_mean, moving_variance,是由训练集决定的,不需训练。

3、BatchNormalization层的参数量

由以上所述可知,一个BatchNormalization层涉及到的参数包含在四个矩阵(向量)中,分别是gamma, beta, moving_mean, moving_variance,这四个矩阵的元素数都是相同的,即等于输入个数(或者输入通道数),其中一半是可训练的,而另一半是不需要训练的。

4、BatchNormalization层的计算量

按下式计算:

FLOPs = 2*C*W*H

C是BatchNormalization层的输入通道数,W和H分别是BatchNormalization层输入特征图的宽和高分辨率,2指的是包括乘法和加法。

实际应用中往往将BatchNormalization层合并到卷积层,而不单独考虑其计算量。

5、BatchNormalization层的优点

  • BatchNormalization提高了模型的泛化能力,采用BatchNormalization层后,可以考虑去掉较早的dropout或l2等正则化方法;
  • 提高训练速度,采用BatchNormalization层后。可以设置一个较大的初始学习率,而且在此基础上的训练速度也会大大提高。

转载地址:http://harti.baihongyu.com/

你可能感兴趣的文章
异常 Java学习Day_15
查看>>
Mysql初始化的命令
查看>>
MySQL关键字的些许问题
查看>>
浅谈HTML
查看>>
css基础
查看>>
Servlet进阶和JSP基础
查看>>
servlet中的cookie和session
查看>>
过滤器及JSP九大隐式对象
查看>>
软件(项目)的分层
查看>>
【Python】学习笔记——-7.0、面向对象编程
查看>>
【Python】学习笔记——-7.2、访问限制
查看>>
【Python】学习笔记——-7.3、继承和多态
查看>>
【Python】学习笔记——-7.5、实例属性和类属性
查看>>
git中文安装教程
查看>>
虚拟机 CentOS7/RedHat7/OracleLinux7 配置静态IP地址 Ping 物理机和互联网
查看>>
Jackson Tree Model Example
查看>>
常用js收集
查看>>
如何防止sql注入
查看>>
springmvc传值
查看>>
在Eclipse中查看Android源码
查看>>