How to dump a confusion matrix using the TensorBoard logger in pytorch-lightning?


Problem description

The official documentation only gives this example:

>>> import torch
>>> from pytorch_lightning.metrics import ConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(num_classes=2)
>>> confmat(preds, target)
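For reference, Lightning's metric documentation describes entry (i, j) of the result as the number of samples whose true class is i and whose predicted class is j. The pure-Python sketch below reproduces that count for the same inputs. It uses a hypothetical confusion_matrix helper that returns nested lists rather than a tensor. It illustrates the row/column convention only and is not Lightning's implementation.

```python
# Hypothetical helper, not Lightning's implementation.
# Convention assumed from the docs: rows = true class, columns = predicted class.
def confusion_matrix(preds, target, num_classes):
    """Return a num_classes x num_classes nested list of counts."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, target):
        matrix[t][p] += 1  # one more sample of true class t predicted as p
    return matrix


# Same inputs as the documentation example above
target = [1, 1, 0, 0]
preds = [0, 1, 0, 0]
print(confusion_matrix(preds, target, num_classes=2))  # [[2, 0], [1, 1]]
```

Reading the result row by row: both class-0 samples are predicted correctly. Of the two class-1 samples, one is mistaken for class 0.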

This doesn't show how to use the metric with the framework.

My attempt (methods are not complete and only show relevant parts):

def __init__(...):
    self.val_confusion = pl.metrics.classification.ConfusionMatrix(num_classes=self._config.n_clusters)

def validation_step(self, batch, batch_index):
    ...
    log_probs = self.forward(orig_batch)
    loss = self._criterion(log_probs, label_batch)
   
    self.val_confusion.update(log_probs, label_batch)
    self.log('validation_confusion_step', self.val_confusion, on_step=True, on_epoch=False)

def validation_step_end(self, outputs):
    return outputs

def validation_epoch_end(self, outs):
    self.log('validation_confusion_epoch', self.val_confusion.compute())

After epoch 0 finishes, this raises:

    Traceback (most recent call last):
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 521, in train
        self.train_loop.run_training_epoch()
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\training_loop.py", line 588, in run_training_epoch
        self.trainer.run_evaluation(test_mode=False)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 613, in run_evaluation
        self.evaluation_loop.log_evaluation_step_metrics(output, batch_idx)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\evaluation_loop.py", line 346, in log_evaluation_step_metrics
        self.__log_result_step_metrics(step_log_metrics, step_pbar_metrics, batch_idx)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\evaluation_loop.py", line 350, in __log_result_step_metrics
        cached_batch_pbar_metrics, cached_batch_log_metrics = cached_results.update_logger_connector()
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 378, in update_logger_connector
        batch_log_metrics = self.get_latest_batch_log_metrics()
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 418, in get_latest_batch_log_metrics
        batch_log_metrics = self.run_batch_from_func_name("get_batch_log_metrics")
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 414, in run_batch_from_func_name
        results = [func(include_forked_originals=False) for func in results]
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 414, in <listcomp>
        results = [func(include_forked_originals=False) for func in results]
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 122, in get_batch_log_metrics
        return self.run_latest_batch_metrics_with_func_name("get_batch_log_metrics", *args, **kwargs)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 115, in run_latest_batch_metrics_with_func_name
        for dl_idx in range(self.num_dataloaders)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 115, in <listcomp>
        for dl_idx in range(self.num_dataloaders)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 100, in get_latest_from_func_name
        results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs))
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\core\step_result.py", line 298, in get_batch_log_metrics
        result[dl_key] = self[k]._forward_cache.detach()
    AttributeError: 'NoneType' object has no attribute 'detach'

It does pass the sanity validation check before training.

The failure happens on the return in validation_step_end, which makes little sense to me.

The exact same way of using metrics works fine with accuracy.

How can I get a correct confusion matrix?
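The last frame of the traceback fails on self[k]._forward_cache.detach(). In Lightning 1.x, a Metric fills _forward_cache only when the metric object itself is called (its forward). Calling update() directly only accumulates state. So if validation_step calls only update() and then logs the metric object with on_step=True, the logger is handed None. Even when the cache is filled, a confusion matrix is a 2-D tensor rather than a scalar, which is why the answer renders it as a figure instead.

The class below is a hypothetical, heavily simplified model of the update/forward distinction. It is not Lightning's actual Metric code.

```python
class ToyMetric:
    """Hypothetical, simplified stand-in for a Lightning 1.x Metric.

    Only calling the object (forward) fills _forward_cache;
    update() just accumulates state for compute().
    """

    def __init__(self):
        self.total = 0
        self._forward_cache = None  # per-step value, set only by __call__

    def update(self, value):
        self.total += value  # accumulate state across batches

    def compute(self):
        return self.total  # epoch-level result from the accumulated state

    def __call__(self, value):
        # forward: accumulate AND cache a value for the current step
        self.update(value)
        self._forward_cache = value
        return self._forward_cache


metric = ToyMetric()
metric.update(3)
print(metric._forward_cache)  # None: what the step-level logger tried to .detach()
metric(4)
print(metric._forward_cache)  # 4
print(metric.compute())       # 7
```

In this model, update() followed by compute() at epoch end still works. Only the per-step logging of the object breaks.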

Solution

You can report the confusion matrix as a figure using self.logger.experiment.add_figure(tag, figure, global_step).

With the TensorBoard logger (Lightning's default), self.logger.experiment is actually a SummaryWriter from PyTorch (torch.utils.tensorboard), not a Lightning class. SummaryWriter provides the method add_figure (see its documentation).

You can use it as follows (MNIST example):

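    import matplotlib.pyplot as plt
    import pandas as pd
    import pytorch_lightning as pl
    import seaborn as sns
    import torch
    import torch.nn.functional as F

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)  # log-probabilities, shape (N, 10)
        loss = F.nll_loss(preds, y)
        return {'loss': loss, 'preds': preds, 'target': y}

    def validation_epoch_end(self, outputs):
        # Gather predictions and targets from every validation batch
        preds = torch.cat([tmp['preds'] for tmp in outputs])
        targets = torch.cat([tmp['target'] for tmp in outputs])
        # Convert log-probabilities to predicted class indices
        confusion_matrix = pl.metrics.functional.confusion_matrix(
            preds.argmax(dim=1), targets, num_classes=10)

        # .cpu() is required before .numpy() when validating on a GPU
        df_cm = pd.DataFrame(confusion_matrix.cpu().numpy(), index=range(10), columns=range(10))
        plt.figure(figsize=(10, 7))
        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
        plt.close(fig_)

        # self.logger.experiment is a torch.utils.tensorboard.SummaryWriter
        self.logger.experiment.add_figure("Confusion matrix", fig_, self.current_epoch)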
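The answer builds the matrix once per epoch from the concatenated batch outputs. Confusion-matrix counts simply add up, so summing one matrix per batch (the accumulate-then-compute pattern the question's update()/compute() calls rely on) gives the same result. The pure-Python sketch below checks this on made-up log-probabilities. It assumes hypothetical helpers argmax, epoch_confusion and accumulated_confusion, and uses nested lists instead of tensors. Taking the argmax of log-probabilities picks the same class as the argmax of probabilities, because log is monotonic.

```python
import math


def argmax(row):
    """Index of the largest entry, i.e. the predicted class for one sample."""
    return max(range(len(row)), key=row.__getitem__)


def confusion_matrix(preds, target, num_classes):
    """Rows = true class, columns = predicted class."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, target):
        matrix[t][p] += 1
    return matrix


def epoch_confusion(outputs, num_classes):
    """Like the answer's validation_epoch_end: concatenate all batches, count once."""
    preds = [argmax(row) for out in outputs for row in out['preds']]
    targets = [t for out in outputs for t in out['target']]
    return confusion_matrix(preds, targets, num_classes)


def accumulated_confusion(outputs, num_classes):
    """Update-style alternative: add up one matrix per batch."""
    total = [[0] * num_classes for _ in range(num_classes)]
    for out in outputs:
        batch_preds = [argmax(row) for row in out['preds']]
        batch = confusion_matrix(batch_preds, out['target'], num_classes)
        for i in range(num_classes):
            for j in range(num_classes):
                total[i][j] += batch[i][j]
    return total


# Two fake validation batches of log-probabilities over 3 classes (made-up numbers)
log = math.log
outputs = [
    {'preds': [[log(0.7), log(0.2), log(0.1)], [log(0.1), log(0.8), log(0.1)]], 'target': [0, 2]},
    {'preds': [[log(0.2), log(0.2), log(0.6)]], 'target': [2]},
]
print(epoch_confusion(outputs, 3))  # [[1, 0, 0], [0, 0, 0], [0, 1, 1]]
print(accumulated_confusion(outputs, 3) == epoch_confusion(outputs, 3))  # True
```

Concatenating at epoch end is the simplest option here. Accumulating counts per batch avoids holding every prediction in memory until the epoch ends.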
