Coverage for robustAI/advertrain/training/classical_training.py: 32%

76 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-10-01 08:42 +0000

1""" 

2This module contains the class implementing the classical training process. 

3""" 

4 

5import os 

6from typing import Any, Dict, Tuple 

7 

8import numpy as np 

9import torch 

10from torch.nn import Module 

11from torch.optim import Optimizer 

12from torch.utils.data import DataLoader 

13from tqdm import tqdm 

14 

15from robustAI.advertrain.metrics import Metrics 

16 

17 

class ClassicalTraining:
    """
    A class representing the classical training process for a PyTorch model.

    Attributes:
        model (Module): The PyTorch model to be trained.
        optimizer (Optimizer): The optimizer used for training.
        loss_func: The loss function used for training; called as
            ``loss_func(output, y)``.
        device (torch.device): The device on which to train the model.
        metrics (Metrics): An instance of Metrics to track training performance.
    """

    def __init__(
        self, model: Module, optimizer: Optimizer, loss_func, device: torch.device
    ) -> None:
        """
        Store the training components and start from a freshly reset tracker.

        Args:
            model (Module): The PyTorch model to be trained.
            optimizer (Optimizer): The optimizer stepping the model parameters.
            loss_func: Callable invoked as ``loss_func(output, y)``.
            device (torch.device): Device the batches are moved to.
        """
        self.model = model
        self.loss_func = loss_func
        self.device = device
        self.optimizer = optimizer
        self.metrics = Metrics()
        self.metrics.reset_metrics()

    def preprocess_batch(
        self, x: torch.Tensor, y: torch.Tensor, epoch: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Preprocess a batch of data and labels before training or validation.

        This implementation returns its inputs untouched.
        NOTE(review): nothing in this class calls it; presumably an extension
        point for subclasses - confirm against the rest of the package.

        Args:
            x (torch.Tensor): Input data batch.
            y (torch.Tensor): Corresponding labels batch.
            epoch (int): The current epoch number (unused here).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The same ``x`` and ``y`` objects.
        """
        return x, y

    def train_batch(
        self, x: torch.Tensor, y: torch.Tensor, epoch: int
    ) -> Tuple[float, int]:
        """
        Process and train a single batch of data.

        Moves the batch to ``self.device``, clamps the inputs to ``[0, 1]``
        in place, runs one optimizer step and records the batch in
        ``self.metrics``.

        Args:
            x (torch.Tensor): Input data batch.
            y (torch.Tensor): Corresponding labels batch.
            epoch (int): The current epoch number (unused here).

        Returns:
            Tuple[float, int]: The training loss for the batch and the batch size.
        """
        x, y = self._to_device(x, y)
        # In-place clamp: when the batch already lives on self.device,
        # ``.to`` returns the same tensor, so the caller's batch is clamped too.
        x.clamp_(0, 1)

        self.optimizer.zero_grad()
        output = self.model(x)
        loss = self.loss_func(output, y)
        loss.backward()
        self.optimizer.step()

        pred = torch.argmax(output, dim=1)
        # Hand the tracker a detached loss so that accumulating it across
        # batches cannot keep autograd history alive.
        self.metrics.update(x, y, pred, loss.detach())

        return loss.item(), len(x)

    def val_batch(
        self, x: torch.Tensor, y: torch.Tensor, epoch: int
    ) -> Tuple[float, int]:
        """
        Validate a single batch of data.

        Moves the batch to ``self.device``, evaluates the model without
        gradient tracking and records the batch in ``self.metrics``.
        NOTE(review): unlike ``train_batch`` the inputs are not clamped to
        ``[0, 1]`` here - confirm this asymmetry is intended.

        Args:
            x (torch.Tensor): Input data batch.
            y (torch.Tensor): Corresponding labels batch.
            epoch (int): The current epoch number (unused here).

        Returns:
            Tuple[float, int]: The validation loss for the batch and the batch size.
        """
        x, y = self._to_device(x, y)

        with torch.no_grad():
            output = self.model(x)
            loss = self.loss_func(output, y)
            pred = torch.argmax(output, dim=1)

        self.metrics.update(x, y, pred, loss)

        return loss.item(), len(x)

    def fit(
        self,
        epochs: int,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        patience: int,
        checkpoint: str
    ) -> Dict[str, Any]:
        """
        Train and validate the model over a given number of epochs, implementing early stopping.

        After every epoch the history is handed to ``self.metrics.save_metrics``.
        Whenever the validation loss drops below the best value seen so far,
        the model ``state_dict`` is written to ``<checkpoint>/model.pth``.
        Training stops once more than ``patience`` consecutive epochs pass
        without such an improvement.

        Args:
            epochs (int): The total number of epochs to train.
            train_dataloader (DataLoader): The DataLoader for the training data.
            val_dataloader (DataLoader): The DataLoader for the validation data.
            patience (int): The number of non-improving epochs tolerated before stopping early.
            checkpoint (str): Directory in which metrics and the model weights are saved.

        Returns:
            Dict[str, Any]: Per-epoch ``loss``/``acc``/``val_loss``/``val_acc`` lists
            plus the run configuration (``optimizer``, ``loss_func``, ``epochs``,
            ``patience``).
        """
        wait = 0
        # ``np.Inf`` was removed in NumPy 2.0; the builtin float infinity
        # behaves identically for the comparison and the log message below.
        val_loss_min = float("inf")

        metrics = {
            "loss": [], "acc": [],
            "val_loss": [], "val_acc": [],
            "optimizer": str(self.optimizer),
            "loss_func": str(self.loss_func),
            "epochs": epochs, "patience": patience
        }

        for epoch in range(epochs):
            (
                train_acc,
                train_loss,
                train_precision_defect,
                train_recall_defect,
                train_f1_score_defect,
            ) = self._process_epoch(train_dataloader, epochs, epoch, train=True)
            (
                val_acc,
                val_loss,
                val_precision_defect,
                val_recall_defect,
                val_f1_score_defect,
            ) = self._process_epoch(val_dataloader, epochs, epoch, train=False)

            # Plain Python number used both for the history and for tracking
            # the best validation loss.
            val_loss_value = val_loss.item()

            # update metrics
            metrics = self._update_metrics(
                metrics, train_loss.item(), train_acc.item(),
                val_loss_value, val_acc.item()
            )

            print(
                f"Epoch {epoch + 1}/{epochs}\n"
                f"Train Loss: {train_loss:.3f}, Acc: {train_acc:.3f}, "
                f"Recall: {train_recall_defect:1.3f}, Precision: {train_precision_defect:1.3f}, "
                f"F1 Score: {train_f1_score_defect:.3f}\n"
                f"Validation Loss: {val_loss:.3f}, Acc: {val_acc:.3f}, "
                f"Recall: {val_recall_defect:1.3f}, Precision: {val_precision_defect:1.3f}, "
                f"F1 Score: {val_f1_score_defect:.3f}"
            )

            # Checkpoint
            if checkpoint:
                # Make sure the target directory exists before anything is
                # written into it (an empty path means the working directory).
                os.makedirs(checkpoint, exist_ok=True)
            self.metrics.save_metrics(metrics, checkpoint)

            if val_loss_value < val_loss_min:
                print(f"Validation loss decreased ({val_loss_min:.6f} --> {val_loss_value:.6f}). Saving model ...")
                torch.save(self.model.state_dict(), os.path.join(checkpoint, "model.pth"))
                val_loss_min = val_loss_value
                wait = 0

            # Early stopping
            else:
                wait += 1
                if wait > patience:
                    print(
                        f"Terminated training for early stopping at epoch {epoch + 1}"
                    )
                    break

        return metrics

    def _to_device(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Moves tensors to the specified device."""
        return x.to(self.device), y.to(self.device)

    def _process_epoch(self, dataloader: DataLoader, epochs: int, epoch: int, train: bool) -> Tuple[Any, ...]:
        """
        Process a single epoch of training or validation.

        Resets ``self.metrics``, switches the model to train or eval mode and
        feeds every batch of ``dataloader`` to ``train_batch`` or ``val_batch``.

        Args:
            dataloader (DataLoader): The DataLoader for the epoch.
            epochs (int): The total number of epochs (only used in the progress bar label).
            epoch (int): The current epoch number (zero-based).
            train (bool): Flag indicating whether it's a training epoch.

        Returns:
            Tuple[Any, ...]: Whatever ``self.metrics.get_metrics()`` returns;
            ``fit`` unpacks it as (accuracy, loss, precision, recall, F1 score).
        """
        self.metrics.reset_metrics()

        if train:
            self.model.train()
            run_batch = self.train_batch
            description = f"Epochs {epoch + 1}/{epochs} : Training"
        else:
            self.model.eval()
            run_batch = self.val_batch
            description = "Validation"

        for x, y in tqdm(dataloader, desc=description, position=0, leave=True):
            run_batch(x, y, epoch)

        return self.metrics.get_metrics()

    def _update_metrics(self, metrics: Dict[str, Any], train_loss: float, train_acc: float,
                        val_loss: float, val_acc: float) -> Dict[str, Any]:
        """
        Update the overall metrics dictionary with the metrics from the current epoch.

        Args:
            metrics (Dict[str, Any]): The overall metrics dictionary; modified in place.
            train_loss (float): Training loss of the epoch.
            train_acc (float): Training accuracy of the epoch.
            val_loss (float): Validation loss of the epoch.
            val_acc (float): Validation accuracy of the epoch.

        Returns:
            Dict[str, Any]: The same ``metrics`` dictionary, for convenience.
        """
        metrics['loss'].append(train_loss)
        metrics['acc'].append(train_acc)
        metrics['val_loss'].append(val_loss)
        metrics['val_acc'].append(val_acc)

        return metrics