Coverage for robustML/advertrain/training/classical_training.py: 32%
76 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-09-10 08:11 +0000
1import os
2from typing import Any, Dict, Tuple
4import numpy as np
5import torch
6from torch.nn import Module
7from torch.optim import Optimizer
8from torch.utils.data import DataLoader
9from tqdm import tqdm
11from robustML.advertrain.metrics import Metrics
class ClassicalTraining:
    """
    Standard supervised training loop for a PyTorch model.

    The class trains with ``train_batch``, validates with ``val_batch`` and
    drives both from ``fit``, which also handles checkpointing and early
    stopping.

    Attributes:
        model (Module): The PyTorch model to be trained.
        optimizer (Optimizer): The optimizer used for training.
        loss_func: The loss function used for training; called as ``loss_func(output, y)``.
        device (torch.device): The device on which to train the model.
        metrics (Metrics): An instance of Metrics to track training performance.
    """

    def __init__(
        self, model: Module, optimizer: Optimizer, loss_func, device: torch.device
    ) -> None:
        """
        Store the training components and start from freshly reset metrics.

        Args:
            model (Module): The PyTorch model to be trained.
            optimizer (Optimizer): The optimizer that updates the model parameters.
            loss_func: Callable invoked as ``loss_func(output, y)``.
            device (torch.device): Device the batches are moved to before the forward pass.
        """
        self.model = model
        self.loss_func = loss_func
        self.device = device
        self.optimizer = optimizer
        self.metrics = Metrics()
        self.metrics.reset_metrics()

    def preprocess_batch(
        self, x: torch.Tensor, y: torch.Tensor, epoch: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Preprocess a batch of data and labels before training or validation.

        This implementation is the identity: it returns its inputs unchanged.
        No method of this class calls it.

        Args:
            x (torch.Tensor): Input data batch.
            y (torch.Tensor): Corresponding labels batch.
            epoch (int): The current epoch number (unused here).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The preprocessed data and labels.
        """
        return x, y

    def train_batch(
        self, x: torch.Tensor, y: torch.Tensor, epoch: int
    ) -> Tuple[float, int]:
        """
        Run one optimisation step on a single batch and record its metrics.

        The inputs are moved to ``self.device`` and clamped to ``[0, 1]`` before
        the forward pass. The caller's tensor is never modified.

        Args:
            x (torch.Tensor): Input data batch.
            y (torch.Tensor): Corresponding labels batch.
            epoch (int): The current epoch number (unused here).

        Returns:
            Tuple[float, int]: The training loss for the batch and the batch size.
        """
        x, y = self._to_device(x, y)
        # Out-of-place clamp: `.to()` hands back the very same tensor when it is
        # already on `self.device`, so an in-place clamp would silently rewrite
        # the caller's batch.
        x = x.clamp(0, 1)

        self.optimizer.zero_grad()
        output = self.model(x)
        loss = self.loss_func(output, y)
        loss.backward()
        self.optimizer.step()

        pred = torch.argmax(output, dim=1)
        # Pass a detached loss so that whatever the metrics object stores cannot
        # keep this batch's autograd graph alive.
        self.metrics.update(x, y, pred, loss.detach())

        return loss.item(), len(x)

    def val_batch(
        self, x: torch.Tensor, y: torch.Tensor, epoch: int
    ) -> Tuple[float, int]:
        """
        Validate a single batch of data.

        The forward pass and the loss are computed under ``torch.no_grad()``.
        Unlike ``train_batch``, the inputs are not clamped here.

        Args:
            x (torch.Tensor): Input data batch.
            y (torch.Tensor): Corresponding labels batch.
            epoch (int): The current epoch number (unused here).

        Returns:
            Tuple[float, int]: The validation loss for the batch and the batch size.
        """
        x, y = self._to_device(x, y)

        with torch.no_grad():
            output = self.model(x)
            loss = self.loss_func(output, y)

        pred = torch.argmax(output, dim=1)
        self.metrics.update(x, y, pred, loss)

        return loss.item(), len(x)

    def fit(
        self,
        epochs: int,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        patience: int,
        checkpoint: str
    ) -> Dict[str, Any]:
        """
        Train and validate the model over a given number of epochs, implementing early stopping.

        After every epoch the history is saved through ``self.metrics.save_metrics``.
        Whenever the validation loss improves on the best value seen so far, the
        model ``state_dict`` is written to ``<checkpoint>/model.pth``. Training
        stops once more than ``patience`` consecutive epochs pass without such
        an improvement.

        Args:
            epochs (int): The total number of epochs to train.
            train_dataloader (DataLoader): The DataLoader for the training data.
            val_dataloader (DataLoader): The DataLoader for the validation data.
            patience (int): Number of consecutive non-improving epochs tolerated
                before stopping early.
            checkpoint (str): Directory in which ``model.pth`` is saved; also
                passed to ``self.metrics.save_metrics``.

        Returns:
            Dict[str, Any]: Per-epoch lists under "loss", "acc", "val_loss" and
            "val_acc", plus the "optimizer", "loss_func", "epochs" and
            "patience" settings of the run.
        """
        wait = 0
        # `np.inf` is the canonical spelling; the `np.Inf` alias was removed in NumPy 2.0.
        val_loss_min = np.inf

        # Ensure the checkpoint directory exists before anything is written into it.
        # An empty string means the current directory, which `os.makedirs` rejects.
        if checkpoint:
            os.makedirs(checkpoint, exist_ok=True)

        metrics = {
            "loss": [], "acc": [],
            "val_loss": [], "val_acc": [],
            "optimizer": str(self.optimizer),
            "loss_func": str(self.loss_func),
            "epochs": epochs, "patience": patience
        }

        for epoch in range(epochs):
            (
                train_acc,
                train_loss,
                train_precision_defect,
                train_recall_defect,
                train_f1_score_defect,
            ) = self._process_epoch(train_dataloader, epochs, epoch, train=True)
            (
                val_acc,
                val_loss,
                val_precision_defect,
                val_recall_defect,
                val_f1_score_defect,
            ) = self._process_epoch(val_dataloader, epochs, epoch, train=False)

            # The history keeps plain Python numbers, hence the `.item()` calls.
            metrics = self._update_metrics(metrics, train_loss.item(), train_acc.item(), val_loss.item(), val_acc.item())

            print(f"Epoch {epoch + 1}/{epochs}\n"
                  f"Train Loss: {train_loss : .3f}, Acc: {train_acc : .3f}, Recall: {train_recall_defect : 1.3f}, Precision: {train_precision_defect : 1.3f},F1 Score: {train_f1_score_defect : .3f}\n"
                  f"Validation Loss: {val_loss : .3f}, Acc: {val_acc : .3f}, Recall: {val_recall_defect : 1.3f}, Precision: {val_precision_defect : 1.3f},F1 Score: {val_f1_score_defect : .3f}")

            # Persist the history every epoch so an interrupted run keeps its record.
            self.metrics.save_metrics(metrics, checkpoint)

            if val_loss < val_loss_min:
                print(f"Validation loss decreased ({val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...")
                torch.save(self.model.state_dict(), os.path.join(checkpoint, "model.pth"))
                val_loss_min = val_loss
                wait = 0
            else:
                # Early stopping: count consecutive epochs without improvement.
                wait += 1
                if wait > patience:
                    print(
                        f"Terminated training for early stopping at epoch {epoch + 1}"
                    )
                    break

        return metrics

    def _to_device(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``x`` and ``y`` moved to ``self.device``."""
        return x.to(self.device), y.to(self.device)

    def _process_epoch(self, dataloader: DataLoader, epochs: int, epoch: int, train: bool) -> Tuple[Any, ...]:
        """
        Process a single epoch of training or validation.

        Resets ``self.metrics``, switches the model to train or eval mode, feeds
        every batch of ``dataloader`` to ``train_batch`` or ``val_batch`` and
        returns the metrics gathered over the epoch.

        Args:
            dataloader (DataLoader): The DataLoader for the epoch; must yield ``(x, y)`` pairs.
            epochs (int): The total number of epochs (only used in the progress-bar label).
            epoch (int): The current, zero-based epoch number.
            train (bool): Flag indicating whether it's a training epoch.

        Returns:
            Whatever ``self.metrics.get_metrics()`` returns. ``fit`` unpacks it as
            ``(acc, loss, precision, recall, f1_score)`` - presumably a 5-tuple;
            confirm against ``Metrics``.
        """
        self.metrics.reset_metrics()
        if train:
            self.model.train()
            step = self.train_batch
            label = f"Epochs {epoch + 1}/{epochs} : Training"
        else:
            self.model.eval()
            step = self.val_batch
            label = "Validation"

        for x, y in tqdm(dataloader, desc=label, position=0, leave=True):
            step(x, y, epoch)

        return self.metrics.get_metrics()

    def _update_metrics(self, metrics: Dict[str, Any], train_loss: float, train_acc: float, val_loss: float, val_acc: float) -> Dict[str, Any]:
        """
        Append the current epoch's losses and accuracies to the overall metrics dictionary.

        Args:
            metrics (Dict[str, Any]): The overall metrics dictionary; its "loss",
                "acc", "val_loss" and "val_acc" lists are extended in place.
            train_loss (float): Training loss of the epoch.
            train_acc (float): Training accuracy of the epoch.
            val_loss (float): Validation loss of the epoch.
            val_acc (float): Validation accuracy of the epoch.

        Returns:
            Dict[str, Any]: The same ``metrics`` object that was passed in.
        """
        metrics['loss'].append(train_loss)
        metrics['acc'].append(train_acc)
        metrics['val_loss'].append(val_loss)
        metrics['val_acc'].append(val_acc)

        return metrics