本文提出一种用于Few-Shot learning的标签传播算法,将标签从有标记的样本向无标记的测试点传播。传播的过程借助于一个graph construction module.
网络结构
作者提出的网络由两部分组成,feature embedding与graph construction. Feature embedding $f_\varphi$用于提取特征,而graph construction $g_\phi$给出计算给定的两个特征对应的图的结点之间的边的权重的参数。
具体来说,每一个episode中包含一个support set $\mathcal{S} = \{ (x_1, y_1), (x_2, y_2), \cdots, (x_{N\times K}, y_{N\times K}) \}$, 一个query set $\mathcal{Q} =\{ (x_1^, y_1^), (x_2^, y_2^), \cdots, (x_T^, y_T^) \}$.
图的构建
训练时,将$\mathcal{S}\cup\mathcal{Q}$中的所有数据全部通过$f_\varphi$得到特征, 并两两之间利用$g_\phi$计算距离
其中$W_{ij}\in \mathbb{R}^{(N\times K+T)\times(N\times K+T)}$, $\sigma_i, \sigma_j = g_\phi(f_\varphi(x_i), f_\varphi(x_j))$.
在$W_{ij}$中保留每一行最大的k个值,并对其做normalized graph laplacians
其中$D$为$W$的度矩阵。
标签传播
标签传播采用如下递推式
这里$F_t, Y\in \mathbb{R}^{(N\times K+T)\times N}$, $F_t$为t时刻的预测标签, $Y$为已知的标签
上述递推式存在闭式解
分类与损失函数
首先将$F^*$转化为对应的概率
损失函数采用cross-entropy loss
这样就训练出了端到端的网络。