Coverage for robustML/advertrain/training/classical_training.py: 32%

76 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-09-10 08:11 +0000

1import os 

2from typing import Any, Dict, Tuple 

3 

4import numpy as np 

5import torch 

6from torch.nn import Module 

7from torch.optim import Optimizer 

8from torch.utils.data import DataLoader 

9from tqdm import tqdm 

10 

11from robustML.advertrain.metrics import Metrics 

12 

13 

class ClassicalTraining:
    """
    A class representing the classical training process for a PyTorch model.

    Attributes:
        model (Module): The PyTorch model to be trained.
        optimizer (Optimizer): The optimizer used for training.
        loss_func: The loss function used for training; called as ``loss_func(output, y)``.
        device (torch.device): The device on which to train the model.
        metrics (Metrics): An instance of Metrics to track training performance.
    """

    def __init__(
        self, model: Module, optimizer: Optimizer, loss_func, device: torch.device
    ) -> None:
        self.model = model
        self.loss_func = loss_func
        self.device = device
        self.optimizer = optimizer
        self.metrics = Metrics()
        self.metrics.reset_metrics()

    def preprocess_batch(
        self, x: torch.Tensor, y: torch.Tensor, epoch: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Preprocess a batch of data and labels before training or validation.

        The base implementation is the identity. Note that ``train_batch`` and
        ``val_batch`` in this class do not call it; presumably it is a hook for
        subclasses (verify against subclasses before relying on it).

        Args:
            x (torch.Tensor): Input data batch.
            y (torch.Tensor): Corresponding labels batch.
            epoch (int): The current epoch number.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The preprocessed data and labels.
        """
        return x, y

    def train_batch(
        self, x: torch.Tensor, y: torch.Tensor, epoch: int
    ) -> Tuple[float, int]:
        """
        Run one optimisation step on a single batch.

        The inputs are moved to ``self.device`` and clamped to ``[0, 1]`` before
        the forward pass. The batch predictions (argmax over dim 1) and loss are
        recorded in ``self.metrics``.

        Args:
            x (torch.Tensor): Input data batch.
            y (torch.Tensor): Corresponding labels batch.
            epoch (int): The current epoch number (unused here; part of the
                batch-hook signature).

        Returns:
            Tuple[float, int]: The training loss for the batch and the batch size.
        """
        x, y = self._to_device(x, y)
        # Out-of-place clamp: ``.to(device)`` returns the very same tensor when it
        # is already on ``self.device``, so an in-place ``clamp_`` would silently
        # modify the caller's (e.g. the DataLoader's) batch.
        x = x.clamp(0, 1)

        self.optimizer.zero_grad()
        output = self.model(x)
        loss = self.loss_func(output, y)
        loss.backward()
        self.optimizer.step()

        pred = torch.argmax(output, dim=1)
        self.metrics.update(x, y, pred, loss)

        return loss.item(), len(x)

    def val_batch(
        self, x: torch.Tensor, y: torch.Tensor, epoch: int
    ) -> Tuple[float, int]:
        """
        Evaluate a single batch without tracking gradients.

        Unlike ``train_batch``, the inputs are not clamped here. The batch
        predictions (argmax over dim 1) and loss are recorded in ``self.metrics``.

        Args:
            x (torch.Tensor): Input data batch.
            y (torch.Tensor): Corresponding labels batch.
            epoch (int): The current epoch number (unused here).

        Returns:
            Tuple[float, int]: The validation loss for the batch and the batch size.
        """
        x, y = self._to_device(x, y)

        with torch.no_grad():
            output = self.model(x)
            loss = self.loss_func(output, y)

        pred = torch.argmax(output, dim=1)
        self.metrics.update(x, y, pred, loss)

        return loss.item(), len(x)

    def fit(
        self,
        epochs: int,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        patience: int,
        checkpoint: str
    ) -> Dict[str, Any]:
        """
        Train and validate the model over a given number of epochs, implementing early stopping.

        After every epoch the accumulated metrics are saved via
        ``self.metrics.save_metrics(metrics, checkpoint)``. Whenever the
        validation loss improves on its best value so far, the model
        ``state_dict`` is written to ``<checkpoint>/model.pth``.

        Early stopping: the wait counter is reset on improvement and incremented
        otherwise. Training stops once it exceeds ``patience``, i.e. after
        ``patience + 1`` consecutive non-improving epochs.

        Args:
            epochs (int): The total number of epochs to train.
            train_dataloader (DataLoader): The DataLoader for the training data.
            val_dataloader (DataLoader): The DataLoader for the validation data.
            patience (int): Number of non-improving epochs tolerated before stopping.
            checkpoint (str): Directory in which to save the model checkpoints
                and metrics. It is created if it does not exist.

        Returns:
            Dict[str, Any]: A dictionary with per-epoch ``loss``/``acc``/``val_loss``/
            ``val_acc`` lists plus the optimizer, loss function, epochs and patience.
        """
        wait = 0
        # ``np.Inf`` was removed in NumPy 2.0; ``float("inf")`` works everywhere.
        val_loss_min = float("inf")

        # Create the checkpoint directory up front so a missing directory does
        # not make saving fail after a full epoch of training. An empty path
        # means "current directory" for os.path.join, so skip it.
        if checkpoint:
            os.makedirs(checkpoint, exist_ok=True)

        metrics = {
            "loss": [], "acc": [],
            "val_loss": [], "val_acc": [],
            "optimizer": str(self.optimizer),
            "loss_func": str(self.loss_func),
            "epochs": epochs, "patience": patience
        }

        for epoch in range(epochs):
            (
                train_acc,
                train_loss,
                train_precision_defect,
                train_recall_defect,
                train_f1_score_defect,
            ) = self._process_epoch(train_dataloader, epochs, epoch, train=True)
            (
                val_acc,
                val_loss,
                val_precision_defect,
                val_recall_defect,
                val_f1_score_defect,
            ) = self._process_epoch(val_dataloader, epochs, epoch, train=False)

            # update metrics
            metrics = self._update_metrics(metrics, train_loss.item(), train_acc.item(), val_loss.item(), val_acc.item())

            print(f"Epoch {epoch + 1}/{epochs}\n"
                  f"Train Loss: {train_loss : .3f}, Acc: {train_acc : .3f}, Recall: {train_recall_defect : 1.3f}, Precision: {train_precision_defect : 1.3f},F1 Score: {train_f1_score_defect : .3f}\n"
                  f"Validation Loss: {val_loss : .3f}, Acc: {val_acc : .3f}, Recall: {val_recall_defect : 1.3f}, Precision: {val_precision_defect : 1.3f},F1 Score: {val_f1_score_defect : .3f}")

            # Checkpoint
            self.metrics.save_metrics(metrics, checkpoint)

            if val_loss < val_loss_min:
                print(f"Validation loss decreased ({val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...")
                torch.save(self.model.state_dict(), os.path.join(checkpoint, "model.pth"))
                val_loss_min = val_loss
                wait = 0
            else:
                # Early stopping
                wait += 1
                if wait > patience:
                    print(
                        f"Terminated training for early stopping at epoch {epoch + 1}"
                    )
                    break

        return metrics

    def _to_device(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Move the input and label tensors to ``self.device``."""
        return x.to(self.device), y.to(self.device)

    def _process_epoch(self, dataloader: DataLoader, epochs: int, epoch: int, train: bool) -> Tuple[Any, ...]:
        """
        Process a single epoch of training or validation.

        Resets ``self.metrics``, switches the model to train/eval mode, runs
        ``train_batch`` or ``val_batch`` over every ``(x, y)`` pair yielded by
        ``dataloader``, and returns ``self.metrics.get_metrics()``.

        Args:
            dataloader (DataLoader): The DataLoader for the epoch.
            epochs (int): Total number of epochs (used in the progress-bar label).
            epoch (int): The current epoch number (0-based).
            train (bool): Flag indicating whether it's a training epoch.

        Returns:
            Tuple[Any, ...]: Whatever ``Metrics.get_metrics()`` returns. ``fit``
            unpacks it as ``(acc, loss, precision, recall, f1)``; verify that this
            order matches ``Metrics.get_metrics``.
        """
        self.metrics.reset_metrics()

        if train:
            self.model.train()
            run_batch = self.train_batch
            desc = f"Epochs {epoch + 1}/{epochs} : Training"
        else:
            self.model.eval()
            run_batch = self.val_batch
            desc = "Validation"

        for x, y in tqdm(
            dataloader,
            desc=desc,
            position=0,
            leave=True,
        ):
            run_batch(x, y, epoch)

        return self.metrics.get_metrics()

    def _update_metrics(self, metrics: Dict[str, Any], train_loss: float, train_acc: float, val_loss: float, val_acc: float) -> Dict[str, Any]:
        """
        Append the current epoch's losses and accuracies to the metrics history.

        Args:
            metrics (Dict[str, Any]): The overall metrics dictionary; must contain
                ``loss``, ``acc``, ``val_loss`` and ``val_acc`` lists.
            train_loss (float): Training loss for the epoch.
            train_acc (float): Training accuracy for the epoch.
            val_loss (float): Validation loss for the epoch.
            val_acc (float): Validation accuracy for the epoch.

        Returns:
            Dict[str, Any]: The same ``metrics`` object, updated in place.
        """
        metrics['loss'].append(train_loss)
        metrics['acc'].append(train_acc)
        metrics['val_loss'].append(val_loss)
        metrics['val_acc'].append(val_acc)

        return metrics