class Recall(_BasePrecisionRecall):
    r"""Calculates recall for binary and multiclass data.

    .. math:: \text{Recall} = \frac{ TP }{ TP + FN }

    where :math:`\text{TP}` is true positives and :math:`\text{FN}` is false negatives.

    - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
    - `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...).
    - `y` must be in the following shape (batch_size, ...).

    Args:
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
        average: if True, recall is computed as the unweighted average (across all classes
            in multiclass case), otherwise, returns a tensor with the recall (for each class in multiclass case).
        is_multilabel: flag to use in multilabel case. By default, value is False. If True, average
            parameter should be True and the average is computed across samples, instead of classes.
        device: specifies which device updates are accumulated on. Setting the metric's
            device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
            default, CPU.

    Examples:

        In binary and multilabel cases, the elements of `y` and `y_pred` should have 0 or 1 values. Thresholding of
        predictions can be done as below:

        .. code-block:: python

            def thresholded_output_transform(output):
                y_pred, y = output
                y_pred = torch.round(y_pred)
                return y_pred, y

            recall = Recall(output_transform=thresholded_output_transform)

        In multilabel cases, the average parameter should be True. However, if the user would like to compute
        the F1 metric, for example, the average parameter should be False. This can be done as shown below:

        .. code-block:: python

            precision = Precision(average=False)
            recall = Recall(average=False)
            F1 = precision * recall * 2 / (precision + recall + 1e-20)
            F1 = MetricsLambda(lambda t: torch.mean(t).item(), F1)

    .. warning::

        In multilabel cases, if average is False, the current implementation stores all input data (output and
        target) as tensors before computing a metric. This can potentially lead to a memory error if the input
        data is larger than available RAM.

    """

    def __init__(
        self,
        output_transform: Callable = lambda x: x,
        average: bool = False,
        is_multilabel: bool = False,
        device: Union[str, torch.device] = torch.device("cpu"),
    ):
        super(Recall, self).__init__(
            output_transform=output_transform, average=average, is_multilabel=is_multilabel, device=device
        )

    @reinit__is_reduced
    def update(self, output: Sequence[torch.Tensor]) -> None:
        self._check_shape(output)
        self._check_type(output)
        y_pred, y = output[0].detach(), output[1].detach()

        if self._type == "binary":
            y_pred = y_pred.view(-1)
            y = y.view(-1)
        elif self._type == "multiclass":
            num_classes = y_pred.size(1)
            if y.max() + 1 > num_classes:
                raise ValueError(
                    f"y_pred contains less classes than y. Number of predicted classes is {num_classes}"
                    f" and element in y has invalid class = {y.max().item() + 1}."
                )
            y = to_onehot(y.view(-1), num_classes=num_classes)
            indices = torch.argmax(y_pred, dim=1).view(-1)
            y_pred = to_onehot(indices, num_classes=num_classes)
        elif self._type == "multilabel":
            # if y, y_pred shape is (N, C, ...) -> (C, N x ...)
            num_classes = y_pred.size(1)
            y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1)
            y = torch.transpose(y, 1, 0).reshape(num_classes, -1)

        # Convert from int cuda/cpu to double on self._device
        y_pred = y_pred.to(dtype=torch.float64, device=self._device)
        y = y.to(dtype=torch.float64, device=self._device)
        correct = y * y_pred
        actual_positives = y.sum(dim=0)

        if correct.sum() == 0:
            true_positives = torch.zeros_like(actual_positives)
        else:
            true_positives = correct.sum(dim=0)

        if self._type == "multilabel":
            if not self._average:
                self._true_positives = torch.cat([self._true_positives, true_positives], dim=0)  # type: torch.Tensor
                self._positives = torch.cat([self._positives, actual_positives], dim=0)  # type: torch.Tensor
            else:
                self._true_positives += torch.sum(true_positives / (actual_positives + self.eps))
                self._positives += len(actual_positives)
        else:
            self._true_positives += true_positives
            self._positives += actual_positives

        self._updated = True
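
The multiclass branch of ``update`` above reduces to a simple bookkeeping rule: each target adds one to its class's count of actual positives, and adds one to that class's true positives when the argmax of the prediction agrees with it. The following is a minimal dependency-free sketch of that accumulation, not the library's implementation. ``RecallSketch``, ``argmax``, ``EPS`` and the ``compute`` step are hypothetical names introduced for illustration; in particular, the final division with a small epsilon is an assumption, since the corresponding method is not shown in this section.

```python
from typing import List, Sequence

EPS = 1e-20  # assumed small constant guarding against division by zero


def argmax(scores: Sequence[float]) -> int:
    """Return the index of the largest score (the first one wins on ties)."""
    return max(range(len(scores)), key=lambda i: scores[i])


class RecallSketch:
    """Hypothetical stand-in for the multiclass path; not ignite's Recall."""

    def __init__(self, num_classes: int) -> None:
        # Per-class running totals, accumulated across batches.
        self.true_positives = [0.0] * num_classes
        self.positives = [0.0] * num_classes

    def update(self, y_pred: Sequence[Sequence[float]], y: Sequence[int]) -> None:
        for scores, target in zip(y_pred, y):
            if target >= len(self.positives):
                raise ValueError("y contains a class index not covered by y_pred")
            # Every target counts as an actual positive (TP + FN) for its class.
            self.positives[target] += 1.0
            # It is a true positive only when the predicted class matches.
            if argmax(scores) == target:
                self.true_positives[target] += 1.0

    def compute(self) -> List[float]:
        # Per-class recall; a class with no positives yields 0.0 thanks to EPS.
        return [tp / (p + EPS) for tp, p in zip(self.true_positives, self.positives)]


metric = RecallSketch(num_classes=3)
metric.update([[0.9, 0.05, 0.05], [0.2, 0.7, 0.1]], [0, 2])
metric.update([[0.1, 0.1, 0.8], [0.6, 0.3, 0.1]], [2, 0])
per_class = metric.compute()  # [1.0, 0.0, 0.5]
```

Class 0 is recovered in both of its occurrences, class 1 never appears as a target, and class 2 is recovered once out of twice. Keeping the two totals separate and dividing only at the end is what lets the metric be updated batch by batch without storing the inputs.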