Coverage for robustAI/advertrain/training/classical_training.py: 32%
76 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-10-01 08:42 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-10-01 08:42 +0000
1"""
2This module contains the class implementing the classical training process
3"""
5import os
6from typing import Any, Dict, Tuple
8import numpy as np
9import torch
10from torch.nn import Module
11from torch.optim import Optimizer
12from torch.utils.data import DataLoader
13from tqdm import tqdm
15from robustAI.advertrain.metrics import Metrics
class ClassicalTraining:
    """
    Run a classical (non-adversarial) training loop for a PyTorch model.

    Attributes:
        model (Module): The PyTorch model to be trained.
        optimizer (Optimizer): The optimizer used for training.
        loss_func: The loss function used for training; called as
            ``loss_func(output, y)``.
        device (torch.device): The device on which to train the model.
        metrics (Metrics): An instance of Metrics to track training performance.
    """

    def __init__(
        self, model: Module, optimizer: Optimizer, loss_func, device: torch.device
    ) -> None:
        self.model = model
        self.loss_func = loss_func
        self.device = device
        self.optimizer = optimizer
        self.metrics = Metrics()
        self.metrics.reset_metrics()

    def preprocess_batch(
        self, x: torch.Tensor, y: torch.Tensor, epoch: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Preprocess a batch of data and labels before training or validation.

        This base implementation returns its inputs unchanged.

        Args:
            x (torch.Tensor): Input data batch.
            y (torch.Tensor): Corresponding labels batch.
            epoch (int): The current epoch number (unused here).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The preprocessed data and labels.
        """
        return x, y

    def train_batch(
        self, x: torch.Tensor, y: torch.Tensor, epoch: int
    ) -> Tuple[float, int]:
        """
        Run one optimisation step on a single batch of data.

        Args:
            x (torch.Tensor): Input data batch.
            y (torch.Tensor): Corresponding labels batch.
            epoch (int): The current epoch number (unused here).

        Returns:
            Tuple[float, int]: The training loss for the batch and the batch size.
        """
        x, y = self._to_device(x, y)
        # In-place clamp of the inputs to [0, 1]; note that val_batch does not clamp.
        x.clamp_(0, 1)

        self.optimizer.zero_grad()
        output = self.model(x)
        loss = self.loss_func(output, y)
        loss.backward()
        self.optimizer.step()

        # Predicted class is the index of the largest output along dim 1.
        pred = torch.argmax(output, dim=1)
        self.metrics.update(x, y, pred, loss)

        return loss.item(), len(x)

    def val_batch(
        self, x: torch.Tensor, y: torch.Tensor, epoch: int
    ) -> Tuple[float, int]:
        """
        Evaluate a single batch of data without computing gradients.

        Args:
            x (torch.Tensor): Input data batch.
            y (torch.Tensor): Corresponding labels batch.
            epoch (int): The current epoch number (unused here).

        Returns:
            Tuple[float, int]: The validation loss for the batch and the batch size.
        """
        x, y = self._to_device(x, y)

        with torch.no_grad():
            output = self.model(x)
            loss = self.loss_func(output, y)

        pred = torch.argmax(output, dim=1)
        self.metrics.update(x, y, pred, loss)

        return loss.item(), len(x)

    def fit(
        self,
        epochs: int,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        patience: int,
        checkpoint: str
    ) -> Dict[str, Any]:
        """
        Train and validate the model over a given number of epochs, implementing early stopping.

        After every epoch the accumulated metrics are saved through
        ``self.metrics.save_metrics``. Whenever the validation loss improves
        on the best value seen so far, the model ``state_dict`` is written to
        ``<checkpoint>/model.pth``.

        Args:
            epochs (int): The total number of epochs to train.
            train_dataloader (DataLoader): The DataLoader for the training data.
            val_dataloader (DataLoader): The DataLoader for the validation data.
            patience (int): Training stops once more than ``patience``
                consecutive epochs pass without a validation-loss improvement.
            checkpoint (str): Directory in which metrics and model checkpoints
                are saved. It is created if it does not exist.

        Returns:
            Dict[str, Any]: A dictionary containing the per-epoch training and
            validation loss/accuracy plus the run configuration.
        """
        wait = 0
        # np.inf (lowercase): the np.Inf alias was removed in NumPy 2.0.
        val_loss_min = np.inf

        # Make sure the checkpoint directory exists before anything is saved
        # into it. An empty path means "current directory" for os.path.join,
        # and os.makedirs("") would raise, so it is skipped.
        if checkpoint:
            os.makedirs(checkpoint, exist_ok=True)

        metrics = {
            "loss": [], "acc": [],
            "val_loss": [], "val_acc": [],
            "optimizer": str(self.optimizer),
            "loss_func": str(self.loss_func),
            "epochs": epochs, "patience": patience
        }

        for epoch in range(epochs):
            (
                train_acc,
                train_loss,
                train_precision_defect,
                train_recall_defect,
                train_f1_score_defect,
            ) = self._process_epoch(train_dataloader, epochs, epoch, train=True)
            (
                val_acc,
                val_loss,
                val_precision_defect,
                val_recall_defect,
                val_f1_score_defect,
            ) = self._process_epoch(val_dataloader, epochs, epoch, train=False)

            # Record this epoch's loss/accuracy as plain Python numbers.
            metrics = self._update_metrics(
                metrics, train_loss.item(), train_acc.item(),
                val_loss.item(), val_acc.item()
            )

            print(
                f"Epoch {epoch + 1}/{epochs}\n"
                f"Train Loss: {train_loss:.3f}, Acc: {train_acc:.3f}, "
                f"Recall: {train_recall_defect:1.3f}, Precision: {train_precision_defect:1.3f}, "
                f"F1 Score: {train_f1_score_defect:.3f}\n"
                f"Validation Loss: {val_loss:.3f}, Acc: {val_acc:.3f}, "
                f"Recall: {val_recall_defect:1.3f}, Precision: {val_precision_defect:1.3f}, "
                f"F1 Score: {val_f1_score_defect:.3f}"
            )

            # Checkpoint the metrics every epoch.
            self.metrics.save_metrics(metrics, checkpoint)

            if val_loss < val_loss_min:
                print(f"Validation loss decreased ({val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...")
                torch.save(self.model.state_dict(), os.path.join(checkpoint, "model.pth"))
                val_loss_min = val_loss
                wait = 0
            else:
                # Early stopping: count consecutive epochs without improvement.
                wait += 1
                if wait > patience:
                    print(
                        f"Terminated training for early stopping at epoch {epoch + 1}"
                    )
                    break

        return metrics

    def _to_device(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Moves tensors to the specified device."""
        return x.to(self.device), y.to(self.device)

    def _process_epoch(self, dataloader: DataLoader, epochs: int, epoch: int, train: bool) -> Tuple[Any, ...]:
        """
        Process a single epoch of training or validation.

        The metrics tracker is reset first, the model is switched to train or
        eval mode, and every batch of ``dataloader`` is passed to
        ``train_batch`` or ``val_batch`` respectively.

        Args:
            dataloader (DataLoader): The DataLoader for the epoch.
            epochs (int): The total number of epochs (used in the progress bar label).
            epoch (int): The current epoch number (zero-based).
            train (bool): Flag indicating whether it's a training epoch.

        Returns:
            Tuple[Any, ...]: Whatever ``self.metrics.get_metrics()`` returns;
            ``fit`` unpacks it as (accuracy, loss, precision, recall, F1 score).
        """
        self.metrics.reset_metrics()
        if train:
            self.model.train()
            for x, y in tqdm(
                dataloader,
                desc=f"Epochs {epoch + 1}/{epochs} : Training",
                position=0,
                leave=True,
            ):
                self.train_batch(x, y, epoch)
        else:
            self.model.eval()
            for x, y in tqdm(
                dataloader,
                desc="Validation",
                position=0,
                leave=True,
            ):
                self.val_batch(x, y, epoch)

        return self.metrics.get_metrics()

    def _update_metrics(self, metrics: Dict[str, Any], train_loss: float, train_acc: float,
                        val_loss: float, val_acc: float) -> Dict[str, Any]:
        """
        Update the overall metrics dictionary with the metrics from the current epoch.

        Args:
            metrics (Dict[str, Any]): The overall metrics dictionary; must hold
                lists under the keys 'loss', 'acc', 'val_loss' and 'val_acc'.
            train_loss (float): Training loss for the epoch.
            train_acc (float): Training accuracy for the epoch.
            val_loss (float): Validation loss for the epoch.
            val_acc (float): Validation accuracy for the epoch.

        Returns:
            Dict[str, Any]: The same ``metrics`` dictionary, updated in place.
        """
        metrics['loss'].append(train_loss)
        metrics['acc'].append(train_acc)
        metrics['val_loss'].append(val_loss)
        metrics['val_acc'].append(val_acc)

        return metrics