How does pytorch backprop through argmax?

Question

I'm building Kmeans in pytorch using gradient descent on centroid locations, instead of expectation-maximisation. Loss is the sum of square distances of each point to its nearest centroid. To identify which centroid is nearest to each point, I use argmin, which is not differentiable everywhere. However, pytorch is still able to backprop and update weights (centroid locations), giving similar performance to sklearn kmeans on the data.

Any ideas how this is working, or how I can figure this out within pytorch? Discussion on pytorch github suggests argmax is not differentiable: https://github.com/pytorch/pytorch/issues/1339.

Example code below (on random pts):

import numpy as np
import torch

num_pts, batch_size, n_dims, num_clusters, lr = 1000, 100, 200, 20, 1e-5

# generate random points
vector = torch.from_numpy(np.random.rand(num_pts, n_dims)).float()

# randomly pick starting centroids
idx = np.random.choice(num_pts, size=num_clusters)
kmean_centroids = vector[idx][:,None,:] # [num_clusters,1,n_dims]
kmean_centroids = kmean_centroids.clone().detach().requires_grad_(True) # leaf tensor that requires grad

for t in range(4001):
    # get batch
    idx = np.random.choice(num_pts, size=batch_size)
    vector_batch = vector[idx]

    distances = vector_batch - kmean_centroids # [num_clusters, #pts, #dims]
    distances = torch.sum(distances**2, dim=2) # [num_clusters, #pts]

    # argmin
    membership = torch.min(distances, 0)[1] # [#pts]

    # cluster distances
    cluster_loss = 0
    for i in range(num_clusters):
        subset = torch.transpose(distances,0,1)[membership==i]
        if len(subset)!=0: # to prevent NaN
            cluster_loss += torch.sum(subset[:,i])

    cluster_loss.backward()
    print(cluster_loss.item())

    with torch.no_grad():
        kmean_centroids -= lr * kmean_centroids.grad
        kmean_centroids.grad.zero_()

Answer

As alvas noted in the comments, argmax is not differentiable. However, once you compute it and assign each datapoint to a cluster, the derivative of loss with respect to the location of these clusters is well-defined. This is what your algorithm does.
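
You can see this directly in autograd: the indices returned by argmin/argmax are plain integer tensors with no grad_fn, so they behave as constants, while the squared distances they select still carry gradients back to the centroids. A minimal sketch (toy sizes chosen here just for illustration, not taken from the question):

import torch

# argmin returns integer indices with no grad_fn -- autograd treats them as constants
pts = torch.randn(5, 2)                                # 5 points, 2 dims
centroids = torch.randn(3, 1, 2, requires_grad=True)   # 3 clusters

d = ((pts - centroids) ** 2).sum(dim=2)   # [3, 5] squared distances
assign = d.argmin(dim=0)                  # [5] integer indices, not differentiable

# pick each point's distance to its assigned centroid and sum
loss = d.gather(0, assign.unsqueeze(0)).sum()
loss.backward()

print(assign.dtype)          # torch.int64, no grad_fn
print(centroids.grad.shape)  # torch.Size([3, 1, 2]) -- a well-defined gradient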

Why does it work? If you had only one cluster (so that the argmax operation didn't matter), your loss function would be quadratic, with minimum at the mean of the data points. Now with multiple clusters, you can see that your loss function is piecewise (in higher dimensions think volumewise) quadratic - for any set of centroids [C1, C2, C3, ...] each data point is assigned to some centroid CN and the loss is locally quadratic. The extent of this locality is given by all alternative centroids [C1', C2', C3', ...] for which the assignment coming from argmax remains the same; within this region the argmax can be treated as a constant, rather than a function and thus the derivative of loss is well-defined.
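
A quick numeric check of that locality (toy data, not from the question): small perturbations of the centroids leave the argmin assignment unchanged, so within that neighbourhood the loss is an ordinary quadratic in the centroid coordinates:

import torch

torch.manual_seed(0)
pts = torch.randn(100, 2)
centroids = torch.randn(4, 1, 2)

def assignments(c):
    return ((pts - c) ** 2).sum(dim=2).argmin(dim=0)

base = assignments(centroids)
nudged = assignments(centroids + 1e-3 * torch.randn_like(centroids))
print(torch.equal(base, nudged))   # True for almost every small perturbation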

Now, in reality, it's unlikely you can treat argmax as constant, but you can still treat the naive "argmax-is-a-constant" derivative as pointing approximately towards a minimum, because the majority of data points are likely to indeed belong to the same cluster between iterations. And once you get close enough to a local minimum such that the points no longer change their assignments, the process can converge to a minimum.

Another, more theoretical way to look at it is that you're doing an approximation of expectation maximization. Normally, you would have the "compute assignments" step, which is mirrored by argmax, and the "minimize" step which boils down to finding the minimizing cluster centers given the current assignments. The minimum is given by d(loss)/d([C1, C2, ...]) == 0, which for a quadratic loss is given analytically by the means of data points within each cluster. In your implementation, you're solving the same equation but with a gradient descent step. In fact, if you used a 2nd order (Newton) update scheme instead of 1st order gradient descent, you would be implicitly reproducing exactly the baseline EM scheme.
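
For the quadratic loss this is easy to verify numerically: restricted to one cluster with N assigned points, the gradient is 2N(C - mean) and the Hessian is 2N·I, so a single Newton step from any starting point lands exactly on the cluster mean, i.e. the EM M-step. A sketch with made-up data:

import torch

x = torch.randn(50, 3)                    # points currently assigned to this cluster
c = torch.randn(3, requires_grad=True)    # current centroid estimate

loss = ((x - c) ** 2).sum()
grad, = torch.autograd.grad(loss, c)      # equals 2 * N * (c - x.mean(0))

n = x.shape[0]
hessian = 2 * n * torch.eye(3)            # Hessian of the quadratic loss
newton_step = torch.linalg.solve(hessian, grad)

# one Newton step recovers the EM update: the mean of the assigned points
print(torch.allclose(c - newton_step, x.mean(dim=0)))   # True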
