Source code for advertrain.dependencies.fire

"""
Taken from https://github.com/MarinePICOT/Adversarial-Robustness-via-Fisher-Rao-Regularization

Robust training losses. Based on code from
https://github.com/MarinePICOT/Adversarial-Robustness-via-Fisher-Rao-Regularization/blob/main/src/losses.py
"""

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable


def entropy_loss(unlabeled_logits: torch.Tensor) -> torch.Tensor:
    """
    Calculate the entropy loss for a batch of unlabeled data.

    Args:
        unlabeled_logits (torch.Tensor): A tensor of logits from a model's output.
            It should have a shape of (batch_size, num_classes).

    Returns:
        torch.Tensor: The mean entropy loss across the batch.
    """
    unlabeled_probs = torch.nn.functional.softmax(unlabeled_logits, dim=1)
    return (
        -(unlabeled_probs * torch.nn.functional.log_softmax(unlabeled_logits, dim=1))
        .sum(dim=1)
        .mean(dim=0)
    )
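
# Usage sketch (illustrative only, not part of the original source): `entropy_loss`
# returns the mean Shannon entropy of the softmax distribution over a batch of
# logits. The shapes below are hypothetical:
#
#     logits = torch.randn(8, 10)   # batch of 8 samples, 10 classes
#     h = entropy_loss(logits)      # scalar tensor, shape torch.Size([])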


def fire_loss(
    model: torch.nn.Module,
    x_natural: torch.Tensor,
    y: torch.Tensor,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    device: torch.device,
    step_size: float = 0.003,
    epsilon: float = 0.001,
    perturb_steps: int = 10,
    beta: float = 1.0,
    adversarial: bool = True,
    distance: str = "Linf",
    entropy_weight: float = 0,
    pretrain: int = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Calculate the FIRE (Fisher-Rao regularization) loss: a combination of the natural
    cross-entropy loss, a robust loss between the predictions on clean and perturbed
    inputs, and an entropy loss on unlabeled data. It is used for adversarial training
    and stability training of neural networks.

    Args:
        model (torch.nn.Module): The neural network model to be trained.
        x_natural (torch.Tensor): Input tensor of natural (non-adversarial) images.
        y (torch.Tensor): Tensor of labels. Unlabeled data should have label -1.
        optimizer (torch.optim.Optimizer): The optimizer used for training.
        epoch (int): Current training epoch.
        device (torch.device): The device on which to perform calculations.
        step_size (float): Step size for adversarial example generation.
        epsilon (float): Perturbation size for adversarial example generation.
        perturb_steps (int): Number of steps for adversarial example generation.
        beta (float): Weight of the robust loss in the overall loss.
        adversarial (bool): Flag to enable/disable adversarial training.
        distance (str): Distance metric for adversarial example generation
            ("Linf" for adversarial training, "L2" for stability training).
        entropy_weight (float): Weight of the entropy loss in the overall loss.
        pretrain (int): Number of pretraining epochs.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing
            the total loss, the natural loss, the robust loss, and the entropy loss on
            unlabeled data.
    """
    if beta == 0:
        logits = model(x_natural)
        loss = F.cross_entropy(logits, y)
        inf = torch.Tensor([np.inf])
        zero = torch.Tensor([0.0])
        return loss, loss, inf, zero

    is_unlabeled = y == -1
    if epoch < pretrain:
        logits = model(x_natural)
        loss = F.cross_entropy(logits, y)
        loss_natural = loss
        loss_robust = torch.Tensor([0.0])
        if torch.sum(is_unlabeled) > 0:
            logits_unlabeled = logits[is_unlabeled]
            loss_entropy_unlabeled = entropy_loss(logits_unlabeled)
            loss = loss + entropy_weight * loss_entropy_unlabeled
        else:
            loss_entropy_unlabeled = torch.tensor(0)
    else:
        model.eval()  # moving to eval mode to freeze batchnorm stats
        batch_size = len(x_natural)
        # generate adversarial example
        x_adv = x_natural.detach() + 0.0  # the + 0. is for copying the tensor
        s_nat = model(x_natural).softmax(1).detach()
        if adversarial:
            if distance == "Linf":
                x_adv += 0.001 * torch.randn(x_natural.shape).to(device)
                for _ in range(perturb_steps):
                    x_adv.requires_grad_()
                    with torch.enable_grad():
                        s_adv = model(x_adv).softmax(1)
                        # Bhattacharyya coefficient between clean and adversarial softmax
                        # outputs; acos of it gives the Fisher-Rao (geodesic) distance.
                        # The -1e-7 prevents gradient explosion near 1
                        # (https://github.com/pytorch/pytorch/issues/8069).
                        sqroot_prod = ((s_nat * s_adv) ** 0.5).sum(1)
                        loss_kl = (torch.acos(sqroot_prod - 1e-7) ** 2).mean(0)
                    grad = torch.autograd.grad(loss_kl, [x_adv])[0]
                    x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
                    x_adv = torch.min(
                        torch.max(x_adv, x_natural - epsilon), x_natural + epsilon
                    )
                    x_adv = torch.clamp(x_adv, 0.0, 1.0)
            else:
                raise ValueError(
                    "No support for distance %s in adversarial training" % distance
                )
        else:
            if distance == "L2":
                x_adv = x_adv + epsilon * torch.randn_like(x_adv)
            else:
                raise ValueError(
                    "No support for distance %s in stability training" % distance
                )

        model.train()  # moving to train mode to update batchnorm stats

        # zero gradient
        optimizer.zero_grad()
        x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False)
        logits_nat = model(x_natural)
        logits_adv = model(x_adv)
        s_adv = logits_adv.softmax(1)
        s_nat = logits_nat.softmax(1)

        loss_natural = F.cross_entropy(logits_nat, y, ignore_index=-1)
        sqroot_prod = ((s_nat * s_adv) ** 0.5).sum(1)
        # Squared Fisher-Rao distance; the -1e-7 prevents gradient explosion near 1
        # (https://github.com/pytorch/pytorch/issues/8069).
        loss_robust = (torch.acos(sqroot_prod - 1e-7) ** 2).mean(0)
        loss = loss_natural + beta * loss_robust
        if torch.sum(is_unlabeled) > 0:
            # use the clean logits for the entropy term on unlabeled samples
            logits_unlabeled = logits_nat[is_unlabeled]
            loss_entropy_unlabeled = entropy_loss(logits_unlabeled)
            loss = loss + entropy_weight * loss_entropy_unlabeled
        else:
            loss_entropy_unlabeled = torch.tensor(0)

    return loss, loss_natural, loss_robust, loss_entropy_unlabeled
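
# Training-step sketch (illustrative only, not part of the original source; `model`,
# `optimizer`, `loader`, `device`, `num_epochs`, and the hyperparameter values are
# hypothetical placeholders):
#
#     for epoch in range(num_epochs):
#         for x, y in loader:
#             x, y = x.to(device), y.to(device)
#             loss, loss_nat, loss_rob, loss_ent = fire_loss(
#                 model, x, y, optimizer, epoch, device,
#                 step_size=2.0 / 255, epsilon=8.0 / 255,
#                 perturb_steps=10, beta=6.0,
#             )
#             loss.backward()   # gradients are zeroed inside fire_loss for this configuration
#             optimizer.step()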


def noise_loss(
    model: torch.nn.Module,
    x_natural: torch.Tensor,
    y: torch.Tensor,
    epsilon: float = 0.25,
    clamp_x: bool = True,
) -> torch.Tensor:
    """
    Augment the input data with random noise and compute the loss based on the model's
    predictions for the noisy data.

    Args:
        model (torch.nn.Module): The neural network model.
        x_natural (torch.Tensor): The original (clean) input data.
        y (torch.Tensor): The labels corresponding to the input data.
        epsilon (float, optional): The magnitude of the noise to be added to the input data.
            Defaults to 0.25.
        clamp_x (bool, optional): If True, the noisy data is clamped to the range [0.0, 1.0].
            Defaults to True.

    Returns:
        torch.Tensor: The computed loss based on the model's predictions for the noisy data.
    """
    x_noise = x_natural + epsilon * torch.randn_like(x_natural)
    if clamp_x:
        x_noise = x_noise.clamp(0.0, 1.0)
    logits_noise = model(x_noise)
    loss = F.cross_entropy(logits_noise, y, ignore_index=-1)
    return loss
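
# Usage sketch (illustrative only, not part of the original source; `model`, `optimizer`,
# `x`, and `y` are hypothetical placeholders):
#
#     optimizer.zero_grad()
#     loss = noise_loss(model, x, y, epsilon=0.25, clamp_x=True)
#     loss.backward()
#     optimizer.step()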