首页 基于神经网络的工具变量生成与反事实推理方法及装置

基于神经网络的工具变量生成与反事实推理方法及装置

举报
开通vip

基于神经网络的工具变量生成与反事实推理方法及装置(19)中华人民共和国国家知识产权局(12)发明专利申请(10)申请公布号CN112633503A(43)申请公布日2021.04.09(21)申请号202011493947.2(22)申请日2020.12.16(71)申请人浙江大学地址310058浙江省杭州市西湖区余杭塘路866号(72)发明人况琨 袁俊坤 吴飞 林兰芬 (74)专利代理机构杭州求是专利事务所有限公司33200代理人傅朝栋 张法高(51)Int.Cl.G06N5/04(2006.01)G06N3/04(2006.01)权利要求书4页说明书12页附...

基于神经网络的工具变量生成与反事实推理方法及装置
(19)中华人民共和国国家知识产权局(12)发明专利MATCH_ word word文档格式规范word作业纸小票打印word模板word简历模板免费word简历 _1713972055153_2(10)申请公布号CN112633503A(43)申请公布日2021.04.09(21)申请号202011493947.2(22)申请日2020.12.16(71)申请人浙江大学地址310058浙江省杭州市西湖区余杭塘路866号(72)发明人况琨 袁俊坤 吴飞 林兰芬 (74)专利代理机构杭州求是专利事务所有限公司33200代理人傅朝栋 张法高(51)Int.Cl.G06N5/04(2006.01)G06N3/04(2006.01)权利要求书4页说明书12页附图2页(54)发明名称基于神经网络的工具变量生成与反事实推理方法及装置(57)摘要本发明公开了一种基于神经网络的工具变量生成与反事实推理方法及装置。针对之前的基于工具变量的反事实推理(如手写数字识别)方法需要预先定义和可获取的工具变量的问 快递公司问题件快递公司问题件货款处理关于圆的周长面积重点题型关于解方程组的题及答案关于南海问题 ,本发明直接从可观测变量中学习和解耦出工具变量,大大提升了因果推断效率,节省了时间和成本。本发明首次自动地从可观测变量中提取出工具变量,在算法和运用上有独创性和独特性。将本发明应用于现有的基于工具变量的反事实预测方法,与使用真实工具变量的方法相比性能因果推断有明显提升。本发明着重于从可观测变量中解耦出工具变量的表征,解决了基于工具变量的反事实预测技术需要预先使用先验知识和高昂成本获取工具变量数据的难题,提升了手写数字识别等领域精度。CN112633503ACN112633503A权 利 要 求 书1/4页1.一种基于神经网络的工具变量生成与反事实推理方法,其特征在于,包括如下步骤:S1:获取手写数字图片数据作为干预,获取手写数字图片的标签数据作为结果,将手写数字图片和标签构建成反事实预测数据集;S2:使用互信息约束的方法,对工具变量和其他协变量的表征设置约束,用于进行初步的表征学习;S3:基于两阶段反事实预测技术设置额外约束,用于对初步学习到的解耦表征进一步优化;S4:基于所述的反事实预测数据集,通过交替优化S2和S3中设置的约束,获得优化后的工具变量和其他协变量的表征模型;S5:针对待识别的手写数字图片,利用优化后的表征模型,得到工具变量和其他协变量的表征,并将其应用于基于工具变量的反事实预测模型中,输出手写数字图片中手写数字的识别结果。2.如权利要求1所述的基于神经网络的工具变量生成与反事实推理方法,其特征在于,步骤S1中,所述反事实预测数据集表示为其中vi,xi,yi分别为第i个样本的可观测变量、干预和结果,其中样本的可观测变量以该样本对应的手写数字图片本身代替,N为样本总数。3.如权利要求2所述的基于神经网络的工具变量生成与反事实推理方法,其特征在于,所述的步骤S2具体包括以下子步骤:S201:基于神经网络构建以可观测变量V为输入以工具变量Z为输出的第一表征模型φZ(.),同时基于神经网络构建以可观测变量V为输入以其他协变量C为输出的第二表征模型φC(.);S202:基于神经网络构建以工具变量Z为输入以干预变量X为输出的第一约束网络fZX(.),设定第一约束网络的损失函数为:ZXZ其中:为第一约束网络f(·)中以φ(vi)为输入去预测xi时得到的变ZZ分分布;φ(vi)为第一表征模型φ(·)中输入vi时得到的输出结果;log表示对数似然函数;另外,针对第一约束网络设定互信息最大化损失函数为:S203:基于神经网络构建以工具变量Z为输入以结果变量Y为输出的第二约束网络fZY(·),设定第二约束网络的损失函数为:ZYZ其中:为第二约束网络f(·)中以φ(vi)为输入去预测yi时得到的变2CN112633503A权 利 要 求 书2/4页分分布;另外,针对第二约束网络设定互信息最大化损失函数为:其中:ωij为由第i个样本的干预xi和第j个样本的干预xj之间距离决定的权重;S204:基于神经网络构建以其他协变量C为输入以干预变量X为输出的第三约束网络fCX(·),设定第三约束网络的损失函数为:CXC其中:为第三约束网络f(·)中以φ(vi)为输入去预测xi时得到的变CC分分布;φ(vi)表示第二表征模型φ(·)中输入vi时得到的输出结果;另外,针对第三约束网络设定互信息最大化损失函数为:S205:基于神经网络构建以其他协变量C为输入以结果变量Y为输出的第四约束网络fCY(·),设定第四约束网络的损失函数为:CYC其中:为第四约束网络f(·)中以φ(vi)为输入去预测yi时得到的变分分布;另外,针对第四约束网络设定互信息最大化损失函数为:S206:基于神经网络构建以工具变量Z为输入以其他协变量C为输出的第五约束网络fZC(·),设定第五约束网络的损失函数为:ZCZC其中:为第五约束网络f(·)中以φ(vi)为输入去预测φ(vi)时得到的变分分布;另外,针对第五约束网络设定互信息最大化损失函数为:4.如权利要求3所述的基于神经网络的工具变量生成与反事实推理方法,其特征在于,步骤S203中,所述权重ωij通过RBF核函数计算,公式如下:3CN112633503A权 利 要 求 书3/4页其中σ是一个用于调节的超参数。5.如权利要求4所述的基于神经网络的工具变量生成与反事实推理方法,其特征在于,所述的步骤S3具体包括以下子步骤:ZCS301:基于神经网络构建以工具变量Z的表征φ(vi)和其他协变量C的表征φ(vi)为输入以干预变量X为输出的第一阶段回归网络fX(·),并设定第一阶段回归网络的损失函数为:其中l(·)表示计算平方误差;CS302:基于神经网络构建以和其他协变量C的表征φ(vi)为输入以结果变量Y为输出的第二阶段回归网络fY(·),并设定第二阶段回归网络的损失函数为:其中:femb(·)为用于扩充干预变量维度的映射网络,表示第一阶段回归网络fX(·)输出的干预变量X估计值,6.如权利要求5所述的基于神经网络的工具变量生成与反事实推理方法,其特征在于,所述的步骤S4具体包括以下步骤:S401:将所有五个约束网络的损失函数进行整合得到综合损失函数:利用所述反事实预测数据集对五个约束网络进行训练,通过最小化所述综合损失函数分别优化各约束网络中的网络参数;S402:将所有五个约束网络的互信息最大化损失函数进行整合得到综合互信息损失函数:其中:α、β、∈、η是权重超参数;利用所述反事实预测数据集对第一表征模型φZ(·)和第二表征模型φC(·)进行训练,通过最小化所述综合互信息损失函数分别优化两个表征模型中的网络参数;S403:利用所述反事实预测数据集对第一表征模型φZ(·)和第二表征模型φC(·)进行训练,通过最小化所述第一阶段回归网络的损失函数优化第一阶段回归网络以及两个表征模型中的网络参数;S404:利用所述反事实预测数据集继续对第一表征模型φZ(·)和第二表征模型φC(·)进行训练,通过最小化所述第二阶段回归网络的损失函数优化第二阶段回归网络、映射网络以及两个表征模型中的网络参数;S405:不断迭代重复S401~S405,使被用于交替训练对应的网络参4CN112633503A权 利 要 求 书4/4页数,直至迭代终止,得到参数优化后的第一表征模型φ′Z(·)和第二表征模型φ′C(·)。7.如权利要求5所述的基于神经网络的工具变量生成与反事实推理方法,其特征在于,所述的步骤S5具体包括以下步骤:S51:针对待识别的目标手写数字图片,将目标手写数字图片作为可观测变量,输入参数优化后的第一表征模型φ′Z(·)和第二表征模型φ′C(·)中,得到工具变量Z的表征和其他协变量C的表征;S52:将目标手写数字图片以及S51中得到的工具变量Z的表征和其他协变量C的表征,一并输入经过训练的基于工具变量的反事实预测模型中,输出目标手写数字图片中手写数字的识别结果。8.如权利要求1所述的基于神经网络的工具变量生成与反事实推理方法,其特征在于,所述基于工具变量的反事实预测模型为2SLS、Deep IV、Kernel IV或DeepGMM模型。9.一种基于深度网络的工具变量解耦与手写数字识别装置,其特征在于,包括存储器和处理器;所述存储器,用于存储计算机程序;所述处理器,用于当执行所述计算机程序时,实现如权利要求1~8任一项所述的基于神经网络的工具变量生成与反事实推理方法。5CN112633503A说 明 书1/12页基于神经网络的工具变量生成与反事实推理方法及装置技术领域[0001]本发明涉及因果推断领域,尤其涉及一种自动的工具变量解耦方法,实现可直接从可观测变量中提取出工具变量的反事实预测方法,从而提升手写数字识别的效率和精度。背景技术[0002]因果推断致力于对干预产生的反事实结果进行估计,辅助决策者进行选择,以达到使得结果最优化的目标。因果推断的黄金方法是使用随机控制实验随机分配干预值进行因果推断,但是此类方法的成本过高甚至无法实现。一些方法通过加权、匹配的方式来对影响因果推断的混淆变量进行约束的目的,但是此类方法仅仅只能在混淆完全可观测的情形下使用,当混淆存在不可观测的情况下该类方法仍然存在较大缺陷。[0003]工具变量提出用来解决不可观测的混淆问题,它和干预变量相关同时和结果变量条件独立。当下的基于工具变量的因果推断方法都需要一个预先定义的工具变量,但是这在现实情况下往往并不实用。如何直接从所有可观测变量中解耦出工具变量,并自动地进行因果推断是一个亟待解决的问题。[0004]手写数字识别作为因果推断的一个典型应用领域,其存在同样的技术问题。针对手写数字的识别,如何通过自动的工具变量解耦,获取仅仅和标签条件相关的工具变量信息,从而辅助手写数字识别以达到最大的精度,是本发明需要解决的主要技术问题。发明内容[0005]本发明的目的是解决当下基于工具变量的因果推断技术手写数字识别需要预先定义的工具变量这个问题,提出一种基于神经网络的工具变量生成与反事实推理方法及装置,它能够直接从可观测的变量中解耦出工具变量,实现自动工具变量解耦和因果推断从而提升手写数字识别的效率和精度。[0006]本发明具体采用的技术 方案 气瓶 现场处置方案 .pdf气瓶 现场处置方案 .doc见习基地管理方案.doc关于群访事件的化解方案建筑工地扬尘治理专项方案下载 如下:[0007]一种基于神经网络的工具变量生成与反事实推理方法,其包括如下步骤:[0008]S1:获取手写数字图片数据作为干预,获取手写数字图片的标签数据作为结果,将手写数字图片和标签构建成反事实预测数据集;[0009]S2:使用互信息约束的方法,对工具变量和其他协变量的表征设置约束,用于进行初步的表征学习;[0010]S3:基于两阶段反事实预测技术设置额外约束,用于对初步学习到的解耦表征进一步优化;[0011]S4:基于所述的反事实预测数据集,通过交替优化S2和S3中设置的约束,获得优化后的工具变量和其他协变量的表征模型;[0012]S5:针对待识别的手写数字图片,利用优化后的表征模型,得到工具变量和其他协变量的表征,并将其应用于基于工具变量的反事实预测模型中,输出手写数字图片中手写6CN112633503A说 明 书2/12页数字的识别结果。[0013]作为优选,步骤S1中,所述反事实预测数据集表示为其中vi,xi,yi分别为第i个样本的可观测变量、干预和结果,其中样本的可观测变量以该样本对应的手写数字图片本身代替,N为样本总数。[0014]进一步的,所述的步骤S2具体包括以下子步骤:[0015]S201:基于神经网络构建以可观测变量V为输入以工具变量Z为输出的第一表征模型φZ(·),同时基于神经网络构建以可观测变量V为输入以其他协变量C为输出的第二表征模型φC(·);[0016]S202:基于神经网络构建以工具变量Z为输入以干预变量X为输出的第一约束网络fZX(·),设定第一约束网络的损失函数为:[0017][0018]ZXZ其中:为第一约束网络f(·)中以φ(vi)为输入去预测xi时得到ZZ的变分分布;φ(vi)为第一表征模型φ(·)中输入vi时得到的输出结果;log表示对数似然函数;[0019]另外,针对第一约束网络设定互信息最大化损失函数为:[0020][0021]S203:基于神经网络构建以工具变量Z为输入以结果变量Y为输出的第二约束网络fZY(·),设定第二约束网络的损失函数为:[0022][0023]ZYZ其中:为第二约束网络f(·)中以φ(vi)为输入去预测yi时得到的变分分布;[0024]另外,针对第二约束网络设定互信息最大化损失函数为:[0025][0026]其中:ωij为由第i个样本的干预xi和第j个样本的干预xj之间距离决定的权重;[0027]S204:基于神经网络构建以其他协变量C为输入以干预变量X为输出的第三约束网络fCX(·),设定第三约束网络的损失函数为:[0028][0029]CXC其中:为第三约束网络f(·)中以φ(vi)为输入去预测xi时得到7CN112633503A说 明 书3/12页CC的变分分布;φ(vi)表示第二表征模型φ(·)中输入vi时得到的输出结果;[0030]另外,针对第三约束网络设定互信息最大化损失函数为:[0031][0032]S205:基于神经网络构建以其他协变量C为输入以结果变量Y为输出的第四约束网络fCY(·),设定第四约束网络的损失函数为:[0033][0034]CYC其中:为第四约束网络f(·)中以φ(vi)为输入去预测yi时得到的变分分布;[0035]另外,针对第四约束网络设定互信息最大化损失函数为:[0036][0037]S206:基于神经网络构建以工具变量Z为输入以其他协变量C为输出的第五约束网络fZC(·),设定第五约束网络的损失函数为:[0038][0039]ZCZC其中:为第五约束网络f(·)中以φ(vi)为输入去预测φ(vi)时得到的变分分布;[0040]另外,针对第五约束网络设定互信息最大化损失函数为:[0041][0042]进一步的,步骤S203中,所述权重ωij通过RBF核函数计算,公式如下:[0043][0044]其中σ是一个用于调节的超参数。[0045]进一步的,所述的步骤S3具体包括以下子步骤:[0046]ZCS301:基于神经网络构建以工具变量Z的表征φ(vi)和其他协变量C的表征φX(vi)为输入以干预变量X为输出的第一阶段回归网络f(·),并设定第一阶段回归网络的损失函数为:[0047][0048]其中l(·)表示计算平方误差;8CN112633503A说 明 书4/12页[0049]CS302:基于神经网络构建以和其他协变量C的表征φ(vi)为输入以结果变量Y为输出的第二阶段回归网络fY(·),并设定第二阶段回归网络的损失函数为:[0050][0051]其中:femb(·)为用于扩充干预变量维度的映射网络,表示第一阶段回归网络fX(·)输出的干预变量X估计值,[0052]进一步的,所述的步骤S4具体包括以下步骤:[0053]S401:将所有五个约束网络的损失函数进行整合得到综合损失函数:[0054][0055]利用所述反事实预测数据集对五个约束网络进行训练,通过最小化所述综合损失函数分别优化各约束网络中的网络参数;[0056]S402:将所有五个约束网络的互信息最大化损失函数进行整合得到综合互信息损失函数:[0057][0058]其中:α、β、∈、η是权重超参数;[0059]利用所述反事实预测数据集对第一表征模型φZ(·)和第二表征模型φC(·)进行训练,通过最小化所述综合互信息损失函数分别优化两个表征模型中的网络参数;[0060]S403:利用所述反事实预测数据集对第一表征模型φZ(·)和第二表征模型φC(·)进行训练,通过最小化所述第一阶段回归网络的损失函数优化第一阶段回归网络以及两个表征模型中的网络参数;[0061]S404:利用所述反事实预测数据集继续对第一表征模型φZ(·)和第二表征模型φC(·)进行训练,通过最小化所述第二阶段回归网络的损失函数优化第二阶段回归网络、映射网络以及两个表征模型中的网络参数;[0062]S405:不断迭代重复S401~S405,使被用于交替训练对应的网络参数,直至迭代终止,得到参数优化后的第一表征模型φ′Z(·)和第二表征模型φ′C(·)。[0063]进一步的,所述的步骤S5具体包括以下步骤:[0064]S51:针对待识别的目标手写数字图片,将目标手写数字图片作为可观测变量,输入参数优化后的第一表征模型φ′Z(·)和第二表征模型φ′C(·)中,得到工具变量Z的表征和其他协变量C的表征;[0065]S52:将目标手写数字图片以及S51中得到的工具变量Z的表征和其他协变量C的表征,一并输入经过训练的基于工具变量的反事实预测模型中,输出目标手写数字图片中手写数字的识别结果。[0066]进一步的,所述基于工具变量的反事实预测模型为2SLS、Deep IV、Kernel IV或DeepGMM模型。[0067]另一方面,本发明提供了一种基于深度网络的工具变量解耦与手写数字识别装9CN112633503A说 明 书5/12页置,其包括存储器和处理器;[0068]所述存储器,用于存储计算机程序;[0069]所述处理器,用于当执行所述计算机程序时,实现如前述任一方案所述的基于神经网络的工具变量生成与反事实推理方法。[0070]本发明使用表征学习技术进行自动的工具变量解耦。针对之前的基于工具变量的反事实预测方法需要预先定义和可获取的工具变量的问题,本发明直接从可观测变量中学习和解耦出工具变量,大大提升了因果推断的效率,节省了大量时间和成本。本发明首次自动地从可观测变量中提取出工具变量,在算法和运用上有自己的独创性和独特性。将本发明应用于现有的基于工具变量的反事实预测方法,并自动地进行它的性能与假设使用真实工具变量的该方法相比因果推断,可以达到相当、甚至更好的性能表现。附图说明[0071]图1为基于神经网络的工具变量生成与反事实推理方法流程图[0072]图2为基于神经网络的工具变量生成与反事实推理结构示意图。具体实施方式[0073]下面结合附图和具体实施方式对本发明做进一步阐述和说明。[0074]如图1所示,一种基于神经网络的工具变量生成与反事实推理方法,该实施方式中的反事实推理用于实现手写数字识别,其包括如下步骤:[0075]S1:获取手写数字图片数据作为干预,获取手写数字图片的标签数据作为结果,将手写数字图片和标签构建成反事实预测数据集;[0076]S2:使用互信息约束的方法,对工具变量和其他协变量的表征设置约束,用于进行初步的表征学习;[0077]S3:基于两阶段反事实预测技术设置额外约束,用于对初步学习到的解耦表征进一步优化;[0078]S4:基于所述的反事实预测数据集,通过交替优化S2和S3中设置的约束,获得优化后的工具变量和其他协变量的表征模型;[0079]S5:针对待识别的手写数字图片,利用优化后的表征模型,得到工具变量和其他协变量的表征,并将其应用于基于工具变量的反事实预测模型中,输出手写数字图片中手写数字的识别结果。[0080]在上述S1~S5步骤中,具体实现方式如下:[0081]本发明中步骤S1具体如下:每一组手写数字图片及其对应的数字标签作为一组样本,构建成反事实预测数据集,表示为其中vi,xi,yi分别为第i个样本的可观测变量、干预和结果,N为样本总数。其中对于手写数字图片而言,由于其本身难以提取可观测变量,因此本发明中实际将样本的可观测变量vi也直接以该样本对应的手写数字图片本身代替,即vi=xi。[0082]参见图2所示,在S1中,假设干预变量X(手写数字图片)和结果变量Y(手写数字图片对应的标签)之间的数据关系为:[0083]Y=g(X)+e10CN112633503A说 明 书6/12页[0084]其中g(·)是一个未知的因果反馈函数(结构函数),它可能是非线性的连续函数。e是一个误差项,它包含了同时和X、Y都有关的不可观测的混淆。其中e满足零期望和有限方差的要求,即且这里允许e和X相关,即使得X成为了一个内生性变量同时[0085]工具变量Z用于解决内生性干预变量问题,它需要满足干预相关和结果排除两个条件。干预相关指X直接和Z相关,即使得结果排除指Z对Y仅仅只能通过X施加影响,即使得除此之外,Z应该是无混淆的,即需要使得基于工具变量的反事实预测的目的就是对真实的反馈函数进行预测。[0086]如果存在其他外生性的变量C,可以直接将其合并入工具变量和干预变量,即X=(X′,C)和Z=(Z′,C),其中X′和Z′是真实的干预变量和工具变量。由于C是严格外生的,即它和无关观测的误差e无关,因此这样的操作并不会对结果产生影响。[0087]假设可获取的可观测变量是V、干预变量是X、结果变量是Y,可获取N个样本,即本发明的目标就是使用这N个样本,获取工具变量Z的解耦表征。[0088]本发明中步骤S2具体包括以下子步骤:[0089]S201:基于神经网络构建第一表征模型φZ(·),其中第一表征模型φZ(·)以可观测变量V为输入,以工具变量Z为输出。同样的,基于神经网络构建第二表征模型φC(·),其中第二表征模型φC(·)以可观测变量V为输入,以其他协变量C为输出。[0090]本步骤中,使用神经网络构建工具变量Z和其他协变量C的表征,即φZ(·)和φC(·),使得φZ(·)和X相关、和Y关于X条件独立,也使得φC(·)同时和X、Y相关。同时可以通过使得φZ(·)和φC(·)尽可能独立,来对进入Z和C的信息进行正则约束。[0091]S202:基于神经网络构建第一约束网络fZX(·),其中第一约束网络fZX(·)以工具变量Z为输入,以干预变量X为输出,同时设定第一约束网络的损失函数为:[0092][0093]ZXZ其中:为第一约束网络f(·)中以φ(vi)为输入去预测xi时得到ZZ的变分分布;φ(vi)为第一表征模型φ(·)中输入vi时得到的输出结果;log表示对数似然函数。[0094]设置本步骤是由于首先工具变量Z需要满足干预相关条件,即因此需要鼓励可观测变量V中和X相关的信息能够进入Z的表征中。由于互信息需要使用的是条件分布信息,而数据是基于样本的,因此首先使用变分分布近似真实的条件分布后续通过最小化损失函数就可以获得最优的变分近似。[0095]另外,为了增加Z和X的关联性,针对第一约束网络设定互信息最大化损失函数为:[0096]11CN112633503A说 明 书7/12页[0097]其中是正样本对(vi,xi)的条件似然,是负样本对(vi,xj)的条件似然。后续通过最小化即可增大正负样本对之间的差异,以此来优化工具变量的表征φZ(V)。[0098]S203:基于神经网络构建第二约束网络fZY(·),其中第二约束网络fZY(·)以工具变量Z为输入,以结果变量Y为输出。同时设定第二约束网络的损失函数为:[0099][0100]ZYZ其中:为第二约束网络f(·)中以φ(vi)为输入去预测yi时得到的变分分布。[0101]工具变量Z还需要满足结果排除条件,即因此需要对Z和Y的条件互信息进行最小化。由于X是连续的变量,因此此处通过让正样本和负样本的似然期望相近来使得Z和Y条件独立。[0102]针对第二约束网络设定互信息最大化损失函数为:[0103][0104]其中:其中是正样本对(vi,yi)的条件似然,是负样本对(vi,yj)的条件似然;ωij为由第i个样本的干预xi和第j个样本的干预xj之间距离决定的权重。此处权重ωij通过RBF核函数计算,公式如下:[0105][0106]其中σ是一个用于调节的超参数。如果正负样本的xi和xj相接近,则它们的权重增大,也就是本发明着重于解决具有相近X的样本对。[0107]S204:协变量C的表征φC(V)需要首先和X相关,因此基于神经网络构建第三约束网络fCX(·),其中第三约束网络fCX(·)以其他协变量C为输入,以干预变量X为输出。同时设定第三约束网络的损失函数为:[0108][0109]CXC其中:为第三约束网络f(·)中以φ(vi)为输入去预测xi时得到CC的变分分布;φ(vi)表示第二表征模型φ(·)中输入vi时得到的输出结果;[0110]另外,针对第三约束网络设定互信息最大化损失函数为:[0111][0112]S205:同时需要使得协变量C的表征φC(V)需要和Y相关,因此基于神经网络构建12CN112633503A说 明 书8/12页第四约束网络fCY(·),其中第四约束网络fCY(·)以其他协变量C为输入,以结果变量Y为输出。同时设定第四约束网络的损失函数为:[0113][0114]CYC其中:为第四约束网络f(·)中以φ(vi)为输入去预测yi时得到的变分分布;[0115]另外,针对第四约束网络设定互信息最大化损失函数为:[0116][0117]S206:基于神经网络构建第五约束网络fZC(·),其中第五约束网络fZC(·)以工具变量Z为输入,以其他协变量C为输出。同时设定第五约束网络的损失函数为:[0118][0119]ZCZC其中:为第五约束网络f(·)中以φ(vi)为输入去预测φ(vi)时得到的变分分布。[0120]本步骤中,如果协变量C的信息进入工具变量Z中,会破坏Z的结果排除条件。同时如果Z的信息进入C中,则会对反事实预测带来一定的偏差。因此通过最小化Z和C的互信息来对它们进行约束,针对第五约束网络设定互信息最大化损失函数为:[0121][0122]在本发明中,步骤S3具体包括以下子步骤:[0123]S301:第一阶段(干预)首先使用工具变量Z和其他协变量C的表征去回归干预变量X。具体而言,基于神经网络构建第一阶段回归网络fX(·),其中第一阶段回归网络fX(·)以ZC工具变量Z的表征φ(vi)和其他协变量C的表征φ(vi)为输入,以干预变量X为输出。同时,设定第一阶段回归网络的损失函数为:[0124][0125]其中l(·)表示计算平方误差;[0126]S302:第二阶段(结果)进一步使用预测出来的来回归Y。具体而言,基于神经网络构建第二阶段回归网络fY(·),其中第二阶段回归网络fY(·)以和其他协变量CC的表征φ(vi)为输入,以结果变量Y为输出。同时,设定第二阶段回归网络的损失函数为:[0127]13CN112633503A说 明 书9/12页[0128]其中:femb(·)为用于扩充干预变量维度的映射网络,表示第一阶段回归网络fX(·)输出的干预变量X估计值,[0129]在本发明中,步骤S4具体包括以下步骤:[0130]S401:将所有五个约束网络的损失函数进行整合得到综合损失函数:[0131][0132]利用S1中的反事实预测数据集对五个约束网络进行训练,通过最小化综合损失函数分别优化各约束网络中的网络参数。该损失函数的各个部分会优化各自的参数,互相之间不会干扰,因此不需要超参数。[0133]S402:将所有五个约束网络的互信息最大化损失函数进行整合得到综合互信息损失函数:[0134][0135]其中:α、β、∈、η是权重超参数。[0136]利用S1中的反事实预测数据集对第一表征模型φZ(·)和第二表征模型φC(·)进行训练,通过最小化综合互信息损失函数分别优化第一表征模型φZ(·)和第二表征模型φC(·)中的网络参数。[0137]S403:利用S1中的反事实预测数据集对第一表征模型φZ(·)和第二表征模型φC(·)进行训练,通过最小化第一阶段回归网络的损失函数优化第一阶段回归网络以及两个表征模型中的网络参数;[0138]S404:利用S1中的反事实预测数据集继续对第一表征模型φZ(·)和第二表征模型φC(·)进行训练,通过最小化第二阶段回归网络的损失函数优化第二阶段回归网络、映射网络以及两个表征模型中的网络参数;[0139]S405:不断迭代重复S401~S405,使被用于交替训练对应的网络参数,直至迭代终止,得到参数优化后的第一表征模型φ′Z(·)和第二表征模型φ′C(·)。[0140]当上述解耦模型完成了优化,可以直接将其作为基于工具变量的方法的输入,将其用于反事实预测,获得更准确的反事实预测精度。[0141]在本发明中,步骤S5具体包括以下步骤:[0142]S51:针对待识别的目标手写数字图片,将目标手写数字图片作为可观测变量,输入参数优化后的第一表征模型φ′Z(·)和第二表征模型φ′C(·)中,得到工具变量Z的表征和其他协变量C的表征;[0143]S52:将目标手写数字图片以及S51中得到的工具变量Z的表征和其他协变量C的表征,一并输入经过训练的基于工具变量的反事实预测模型中,输出目标手写数字图片中手写数字的识别结果。[0144]在本发明中,基于工具变量的反事实预测模型可以是任何能够通过工具变量实现预测的模型,例如可选的为2SLS、Deep IV、Kernel IV或DeepGMM模型。[0145]上述方法的各步骤中的具体参数可以根据实际进行调整。[0146]本发明的关键技术在于基于表征学习进行自动的工具变量解耦,获得有效的工具变量的表征,并将其应用于基于工具变量的反事实预测方法,使得这些方法可以在无法获14CN112633503A说 明 书10/12页得工具变量的场景中得以较好的应用,达到相当甚至更好的反事实预测精度。[0147]另外,在另一实施例中,本发明提供了一种基于神经网络的工具变量生成与反事实推理方法及装置,它包括存储器和处理器;[0148]其中存储器,用于存储计算机程序;[0149]处理器,用于当执行所述计算机程序时,实现前述实施例中的基于神经网络的工具变量生成与反事实推理方法及装置。[0150]上述S1~S5的方法具体可以通过计算机程序来实现,举例而言,计算机程序中的模块可以按照功能划分如下:[0151]采样模块,对干干预变量、结果变量、可观测变量进行采样,约束可观测变量严格外生;[0152]互信息约束模块,对工具变量和协变量的表征通过互信息约束它们与干预变量和结果变量之间的关系;[0153]两阶段反事实预测模块,分别对干预变量和结果变量进行预测,两次预测的偏差用于进一步优化初步学习到的表征;[0154]反事实预测模块,交替优化表征,应用学习到的表征到现有的反事实预测方法进行反事实预测,提升反事实预测的精度。[0155]其中,采样模块包括:[0156]干预变量采样模块,用于从原始数据中采样干预变量,对其进行控制来进行反事实推断;[0157]结果变量采样模块,用于从原始数据中采样结果变量,结果变量是对干预变量变化的反映;[0158]可观测变量采样模块,可观测变量反映每个样本的特征,我们使得它严格外生,用于工具变量的解耦。[0159]其中,互信息约束模块包括:[0160]工具变量约束模块,对工具变量的表征进行互信息约束,使得它与干预变量相关,同时和结果变量条件独立;[0161]协变量约束模块,对协变量的表征进行互信息约束,使得它和干预变量、结果变量都相关;[0162]表征正交模块,对工具变量和协变量的表征进行正交约束,使得工具变量和协变量的表征尽可能独立。[0163]其中,两阶段反事实预测模块包括:[0164]干预变量预测模块,将初步解耦到的工具变量、协变量表征用于干预变量的预测,获取干预变量回归值;[0165]结果变量预测模块,将干预变量回归值和协变量用于结果变量的预测,得到反事实结果预测值。[0166]其中,反事实预测模块包括:[0167]表征优化模块,综合以上的互信息约束模块和两阶段反事实预测模块,通过交替优化的方式获取最优的表征;[0168]反事实预测模块,将得到的最优表征用于现有的方法中进行反事实预测,提升反15CN112633503A说 明 书11/12页事实预测的精度。[0169]当然,以上具体的功能模块的 设计 领导形象设计圆作业设计ao工艺污水处理厂设计附属工程施工组织设计清扫机器人结构设计 可以根据实际需要调整,以满足功能实现为准。[0170]需要注意的是,存储器可以包括随机存取存储器(Random Access Memory,RAM),也可以包括非易失性存储器(Non‑Volatile Memory,NVM),例如至少一个磁盘存储器。上述的处理器可以是通用处理器,包括中央处理器(Central Processing Unit,CPU)、网络处理器(Network Processor,NP)等;还可以是数字信号处理器(Digital Signal Processing,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现场可编程门阵列(Field‑Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。当然,还装置中还应当具有实现程序运行的必要组件,例如电源、通信总线等等。[0171]在另一实施例中,本发明提供了一种计算机可读存储介质,该存储介质上存储有计算机程序,当所述计算机程序被处理器执行时,实现前述实施例中的基于神经网络的工具变量生成与反事实推理方法及装置。[0172]下面利用前述的基于神经网络的工具变量生成与反事实推理方法及装置,通过一个具体的应用实例来展示本发明分类方法的具体效果。具体的方法步骤如前所述,不再赘述,下面仅展示其具体效果。[0173]实施例[0174]本实施例在手写数字图片和仿真数据集上进行测试。该方法主要针对手写数字图片和对应的标签之间的关系,通过自动的工具变量解耦,获取仅仅和标签条件相关的工具变量信息,从而辅助手写数字识别以达到最大的精度。[0175]我们给定手写数字图片X,手写数字图片对应的标签Y之间的关系为:[0176]Y=g(X)+e+σ[0177]其中为不可观测的混淆变量,为误差项,g是手写数字图片X和手写数字图片对应的标签Y之间真实的潜在关系(非线性映射函数),此处我们假设他们之间的关系为g(X)=‑X。同时手写数字图片受到潜在的工具变量Z~Unif([‑3,3]2)、不可观测的混淆变量e和误差项的影响:[0178]X=Z1+e+γ[0179]算法训练和测试中,分别采样500个样本用于训练、验证、测试。每个样本都包含了手写数字图片、对应的标签和其他相关的混合数据。为了展示该方法解耦出的工具变量的性能,使用了辅助的基于工具变量的反事实预测模型来进行手写数字图片预测。本实施例中所采用的基于工具变量的反事实预测模型包括五种,分别为2SLS(van)、2SLS(poly)、2SLS(NN)、DeepIV、KernelIV、DeepGMM。这些模型算法均属于现有技术,不再赘述。若需了解其具体的实现算法,可实现参见以下现有技术文献:[0180]2SLS(van):Angrist J D,Pischke J S.Mostly harmless econometrics:An empiricist's companion[M].Princeton university press,2008.[0181]2SLS(poly):Darolles S,Fan Y,Florens J P,et al.Nonparametric instrumental regression[J].Econometrica,2011,79(5):1541‑1565.[0182]2SLS(NN):Darolles S,Fan Y,Florens J P,et al.Nonparametric 16CN112633503A说 明 书12/12页instrumental regression[J].Econometrica,2011,79(5):1541‑1565.[0183]DeepIV:Hartford J,Lewis G,Leyton‑Brown K,et al.Deep IV:A flexible approach for counterfactual prediction[C]//International Conference on Machine Learning.2017:1414‑1423.[0184]KernelIV:Singh R,Sahani M,Gretton A.Kernel instrumental variable regression[C]//Advances in Neural Information Processing Systems.2019:4593‑4605.[0185]DeepGMM:Bennett A,Kallus N,Schnabel T.Deep generalized method of moments for instrumental variable analysis[C]//Advances in Neural Information Processing Systems.2019:3564‑3574.[0186]为了客观评估本算法的性能,使用手写数字图片的预测结果与真实的结果的均方误差(MSE)对该方法进行评价。[0187]所得实验结果如表1所示,结果表明,本发明的方法具有极高的手写数字图片识别精度,从而能够显著提升手写数字识别的效率和准确性。[0188]表1不同辅助方法下手写数字识别的均方误差及其标准差[0189]2SLS(van)2SLS(poly)2SLS(NN)DeepIVKernelIVDeepGMM0.00(0.00)0.00(0.00)0.14(0.03)0.09(0.03)0.11(0.04)0.01(0.01)[0190]以上所述的实施例只是本发明的一种较佳的方案,然其并非用以限制本发明。有关技术领域的普通技术人员,在不脱离本发明的精神和范围的情况下,还可以做出各种变化和变型。因此凡采取等同替换或等效变换的方式所获得的技术方案,均落在本发明的保护范围内。17CN112633503A说 明 书 附 图1/2页图118CN112633503A说 明 书 附 图2/2页图219
本文档为【基于神经网络的工具变量生成与反事实推理方法及装置】,请使用软件OFFICE或WPS软件打开。作品中的文字与图均可以修改和编辑, 图片更改请在作品中右键图片并更换,文字修改请直接点击文字进行修改,也可以新增和删除文档中的内容。
该文档来自用户分享,如有侵权行为请发邮件ishare@vip.sina.com联系网站客服,我们会及时删除。
[版权声明] 本站所有资料为用户分享产生,若发现您的权利被侵害,请联系客服邮件isharekefu@iask.cn,我们尽快处理。
本作品所展示的图片、画像、字体、音乐的版权可能需版权方额外授权,请谨慎使用。
网站提供的党政主题相关内容(国旗、国徽、党徽..)目的在于配合国家政策宣传,仅限个人学习分享使用,禁止用于任何广告和商用目的。
下载需要: ¥10.0 已有0 人下载
最新资料
资料动态
专题动态
机构认证用户
掌桥科研
掌桥科研向科研人提供中文文献、外文文献、中文专利、外文专利、政府科技报告、OA文献、外军国防科技文献等多种科研资源的推广、发现、揭示和辅助获取服务,以及自动文档翻译、人工翻译、文档格式转换、收录引证等科研服务,涵盖了理、工、医、农、社科、军事、法律、经济、哲学等诸多学科和行业的中外文献资源。
格式:pdf
大小:1MB
软件:PDF阅读器
页数:19
分类:
上传时间:2022-01-25
浏览量:10