⬅ robustML/advertrain/training/classical_training.py source

1 import os
2 from typing import Any, Dict, Tuple
3  
4 import numpy as np
5 import torch
6 from torch.nn import Module
7 from torch.optim import Optimizer
8 from torch.utils.data import DataLoader
9 from tqdm import tqdm
10  
11 from robustML.advertrain.metrics import Metrics
12  
13  
class ClassicalTraining:
    """
    A class representing the classical training process for a PyTorch model.

    Attributes:
        model (Module): The PyTorch model to be trained.
        optimizer (Optimizer): The optimizer used for training.
        loss_func: The loss function used for training, called as ``loss_func(output, y)``.
        device (torch.device): The device on which to train the model.
        metrics (Metrics): An instance of Metrics to track training performance.
    """

    def __init__(
        self, model: Module, optimizer: Optimizer, loss_func, device: torch.device
    ) -> None:
        self.model = model
        self.loss_func = loss_func
        self.device = device
        self.optimizer = optimizer
        self.metrics = Metrics()
        self.metrics.reset_metrics()

    def preprocess_batch(
        self, x: torch.Tensor, y: torch.Tensor, epoch: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Preprocess a batch of data and labels before training or validation.

        The base implementation returns its inputs unchanged. NOTE(review): neither
        ``train_batch`` nor ``val_batch`` in this class calls it; presumably it is a
        hook that subclasses override and invoke themselves -- verify against subclasses.

        Args:
            x (torch.Tensor): Input data batch.
            y (torch.Tensor): Corresponding labels batch.
            epoch (int): The current epoch number.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The preprocessed data and labels.
        """
        return x, y

    def train_batch(
        self, x: torch.Tensor, y: torch.Tensor, epoch: int
    ) -> Tuple[float, int]:
        """
        Process and train a single batch of data.

        Moves the batch to ``self.device``, clamps the inputs to [0, 1], runs one
        optimizer step and feeds the batch predictions into ``self.metrics``.

        Args:
            x (torch.Tensor): Input data batch.
            y (torch.Tensor): Corresponding labels batch.
            epoch (int): The current epoch number.

        Returns:
            Tuple[float, int]: The training loss for the batch and the batch size.
        """
        x, y = self._to_device(x, y)
        # Out-of-place clamp: ``.to()`` returns the very same tensor when it is already
        # on ``self.device``, so an in-place ``clamp_`` would modify the caller's batch.
        x = x.clamp(0, 1)

        self.optimizer.zero_grad()
        output = self.model(x)
        loss = self.loss_func(output, y)
        loss.backward()
        self.optimizer.step()

        pred = torch.argmax(output, dim=1)
        # Detach so the metrics tracker cannot keep this step's autograd graph alive
        # (in case it accumulates the loss tensor); the value is unchanged.
        self.metrics.update(x, y, pred, loss.detach())

        return loss.item(), len(x)

    def val_batch(
        self, x: torch.Tensor, y: torch.Tensor, epoch: int
    ) -> Tuple[float, int]:
        """
        Validate a single batch of data.

        Unlike ``train_batch``, inputs are not clamped; the forward pass and the loss
        are computed under ``torch.no_grad()``.

        Args:
            x (torch.Tensor): Input data batch.
            y (torch.Tensor): Corresponding labels batch.
            epoch (int): The current epoch number.

        Returns:
            Tuple[float, int]: The validation loss for the batch and the batch size.
        """
        x, y = self._to_device(x, y)

        with torch.no_grad():
            output = self.model(x)
            loss = self.loss_func(output, y)

        pred = torch.argmax(output, dim=1)
        self.metrics.update(x, y, pred, loss)

        return loss.item(), len(x)

    def fit(
        self,
        epochs: int,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        patience: int,
        checkpoint: str
    ) -> Dict[str, Any]:
        """
        Train and validate the model over a given number of epochs, implementing early stopping.

        After every epoch the metrics dictionary is handed to ``Metrics.save_metrics``,
        and the model ``state_dict`` is written to ``<checkpoint>/model.pth`` whenever
        the validation loss reaches a new minimum.

        Args:
            epochs (int): The total number of epochs to train.
            train_dataloader (DataLoader): The DataLoader for the training data.
            val_dataloader (DataLoader): The DataLoader for the validation data.
            patience (int): Training stops once the validation loss has failed to improve
                for more than ``patience`` consecutive epochs.
            checkpoint (str): Directory in which metrics and the best model are saved;
                it is created if it does not exist.

        Returns:
            Dict[str, Any]: A dictionary containing training and validation metrics.
        """
        if checkpoint:
            # torch.save below writes into this directory, so make sure it exists.
            os.makedirs(checkpoint, exist_ok=True)

        wait = 0
        # np.Inf was removed in NumPy 2.0; np.inf works on every NumPy version.
        val_loss_min = np.inf

        metrics = {
            "loss": [], "acc": [],
            "val_loss": [], "val_acc": [],
            "optimizer": str(self.optimizer),
            "loss_func": str(self.loss_func),
            "epochs": epochs, "patience": patience
        }

        for epoch in range(epochs):
            # Convert the epoch metrics to plain floats once, so formatting, comparison
            # and storage do not depend on the exact scalar/tensor type returned by
            # Metrics.get_metrics().
            train_acc, train_loss, train_precision, train_recall, train_f1 = [
                float(value)
                for value in self._process_epoch(train_dataloader, epochs, epoch, train=True)
            ]
            val_acc, val_loss, val_precision, val_recall, val_f1 = [
                float(value)
                for value in self._process_epoch(val_dataloader, epochs, epoch, train=False)
            ]

            # update metrics
            metrics = self._update_metrics(metrics, train_loss, train_acc, val_loss, val_acc)

            print(
                f"Epoch {epoch + 1}/{epochs}\n"
                + self._format_phase_summary("Train", train_loss, train_acc, train_recall, train_precision, train_f1)
                + "\n"
                + self._format_phase_summary("Validation", val_loss, val_acc, val_recall, val_precision, val_f1)
            )

            # Checkpoint
            self.metrics.save_metrics(metrics, checkpoint)

            if val_loss < val_loss_min:
                print(f"Validation loss decreased ({val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...")
                torch.save(self.model.state_dict(), os.path.join(checkpoint, "model.pth"))
                val_loss_min = val_loss
                wait = 0
            else:
                # Early stopping: triggers on the (patience + 1)-th consecutive epoch
                # without a new validation-loss minimum.
                wait += 1
                if wait > patience:
                    print(
                        f"Terminated training for early stopping at epoch {epoch + 1}"
                    )
                    break

        return metrics

    @staticmethod
    def _format_phase_summary(
        label: str, loss: float, acc: float, recall: float, precision: float, f1: float
    ) -> str:
        """Build the one-line epoch summary printed by ``fit`` for the train or validation phase."""
        return (
            f"{label} Loss: {loss: .3f}, Acc: {acc: .3f}, Recall: {recall: 1.3f}, "
            f"Precision: {precision: 1.3f},F1 Score: {f1: .3f}"
        )

    def _to_device(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Moves tensors to the specified device."""
        return x.to(self.device), y.to(self.device)

    def _process_epoch(self, dataloader: DataLoader, epochs: int, epoch: int, train: bool) -> Tuple[Any, ...]:
        """
        Process a single epoch of training or validation.

        Resets ``self.metrics``, switches the model to train/eval mode and runs
        ``train_batch`` or ``val_batch`` over every batch of ``dataloader``.

        Args:
            dataloader (DataLoader): The DataLoader for the epoch.
            epochs (int): The total number of epochs (only used in the progress-bar label).
            epoch (int): The current epoch number.
            train (bool): Flag indicating whether it's a training epoch.

        Returns:
            Tuple[Any, ...]: Whatever ``self.metrics.get_metrics()`` returns; ``fit``
            unpacks it as ``(acc, loss, precision, recall, f1)``.
        """
        self.metrics.reset_metrics()
        if train:
            self.model.train()
            progress = tqdm(
                dataloader,
                desc=f"Epochs {epoch + 1}/{epochs} : Training",
                position=0,
                leave=True,
            )
            for x, y in progress:
                self.train_batch(x, y, epoch)
        else:
            self.model.eval()
            progress = tqdm(
                dataloader,
                desc="Validation",
                position=0,
                leave=True,
            )
            for x, y in progress:
                self.val_batch(x, y, epoch)

        return self.metrics.get_metrics()

    def _update_metrics(
        self, metrics: Dict[str, Any], train_loss: float, train_acc: float, val_loss: float, val_acc: float
    ) -> Dict[str, Any]:
        """
        Append the current epoch's losses and accuracies to the overall metrics dictionary.

        Args:
            metrics (Dict[str, Any]): The overall metrics dictionary; must contain the list
                entries ``loss``, ``acc``, ``val_loss`` and ``val_acc``. Mutated in place.
            train_loss (float): Training loss of the epoch.
            train_acc (float): Training accuracy of the epoch.
            val_loss (float): Validation loss of the epoch.
            val_acc (float): Validation accuracy of the epoch.

        Returns:
            Dict[str, Any]: The same (updated) ``metrics`` object.
        """
        metrics['loss'].append(train_loss)
        metrics['acc'].append(train_acc)
        metrics['val_loss'].append(val_loss)
        metrics['val_acc'].append(val_acc)

        return metrics