Coverage for robustML/advertrain/metrics.py: 81%
53 statements
coverage.py v7.9.2, created at 2025-09-10 08:11 +0000
import json
import os
from typing import Any, Dict, Tuple

import torch
from torch import Tensor

from robustML.advertrain.constants import METRICS_FILE


class Metrics:
    """
    Class to track performance metrics for binary classification tasks.

    This class tracks true positives, true negatives, false positives,
    false negatives, and cumulative loss across batches. It calculates
    metrics such as accuracy, precision, recall, and F1-score.

    See the usage sketch at the bottom of this module.
    """

    def __init__(self):
        self.reset_metrics()

    def reset_metrics(self) -> None:
        """Reset all metrics to zero."""
        self.loss = 0.0
        self.TP = self.TN = self.FP = self.FN = self.P = self.N = 0

    def update(self, x: Tensor, y: Tensor, pred: Tensor, loss: Tensor) -> None:
        """
        Update metrics based on inputs, ground truth, model predictions, and loss.

        Args:
            x (Tensor): Input batch; used only to recover the batch size.
            y (Tensor): Target labels (0 or 1).
            pred (Tensor): Model predictions (0 or 1).
            loss (Tensor): Batch loss.
        """
        # Element-wise confusion-matrix membership for each sample.
        TP = torch.logical_and(pred == 1, y == 1)
        TN = torch.logical_and(pred == 0, y == 0)
        FP = torch.logical_and(pred == 1, y == 0)
        FN = torch.logical_and(pred == 0, y == 1)

        # Running counts of positive and negative ground-truth samples.
        # `.item()` converts the 0-dim sum tensors to plain Python numbers,
        # so the metric helpers below really do return floats as annotated.
        self.P += torch.sum(y == 1).item()
        self.N += torch.sum(y == 0).item()

        self.TP += torch.sum(TP).item()
        self.TN += torch.sum(TN).item()
        self.FP += torch.sum(FP).item()
        self.FN += torch.sum(FN).item()

        # Accumulate total loss; assumes `loss` is the batch-mean loss.
        self.loss += loss.item() * len(x)

    def _precision(self) -> float:
        return self.TP / (self.TP + self.FP + 1e-8)

    def _recall(self) -> float:
        return self.TP / (self.P + 1e-8)

    def _f1_score(self) -> float:
        precision = self._precision()
        recall = self._recall()
        return 2 * precision * recall / (precision + recall + 1e-8)

    def get_metrics(self) -> Tuple[float, float, float, float, float]:
        """
        Calculate and return key performance metrics.

        Returns:
            Tuple[float, float, float, float, float]: Accuracy, mean
                per-sample loss, precision, recall, and F1-score.
        """
        acc = (self.TP + self.TN) / (self.P + self.N + 1e-8)
        loss = self.loss / (self.P + self.N + 1e-8)
        precision = self._precision()
        recall = self._recall()
        f1_score = self._f1_score()

        return acc, loss, precision, recall, f1_score

    def save_metrics(self, metrics: Dict[str, Any], checkpoint: str) -> None:
        """
        Save metrics to a JSON file located at `<checkpoint>/metrics.json`.

        This method serializes the provided metrics dictionary into JSON and
        writes it to the METRICS_FILE in the specified checkpoint directory.

        Args:
            metrics (Dict[str, Any]): A dictionary mapping metric names to their values.
            checkpoint (str): The directory path where the metrics file will be saved.
        """
        data = json.dumps(metrics)
        with open(os.path.join(checkpoint, METRICS_FILE), "w") as f:
            f.write(data)

    def load_metrics(self, checkpoint: str) -> Dict[str, Any]:
        """
        Load metrics from a JSON file located at `<checkpoint>/metrics.json`.

        This method reads the metrics file from the specified checkpoint
        directory and returns its contents as a dictionary.

        Args:
            checkpoint (str): The directory path from which the metrics file is loaded.

        Returns:
            Dict[str, Any]: A dictionary containing the loaded metrics.
        """
        with open(os.path.join(checkpoint, METRICS_FILE), "r") as file:
            data = json.load(file)

        return data

    def display(self, title: str) -> None:
        """
        Display the calculated metrics with a title.

        Args:
            title (str): The title for the metrics display.
        """
        acc, loss, precision, recall, f1_score = self.get_metrics()
        print(f"{title}\n"
              f"Loss: {loss:.3f}\t"
              f"Acc: {acc:.3f}\t"
              f"Recall: {recall:.3f}\t"
              f"Precision: {precision:.3f}\t"
              f"F1 Score: {f1_score:.3f}")

    def display_table(self, title: str) -> None:
        """
        Display the metrics as a single Markdown-style table row with a title.

        Args:
            title (str): The row label for the table.
        """
        acc, loss, precision, recall, f1_score = self.get_metrics()
        print(f"| {title} | {acc:.3f} | {recall:.3f} | {precision:.3f} | {f1_score:.3f} |")
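

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of driving
# Metrics for one batch and round-tripping results through save_metrics /
# load_metrics. The tensors, the 0.5 threshold, and the temporary checkpoint
# directory below are illustrative assumptions, not part of the module's API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    metrics = Metrics()

    x = torch.randn(4, 3)                         # dummy batch of 4 inputs
    y = torch.tensor([1, 0, 1, 0])                # ground-truth labels
    logits = torch.randn(4)                       # dummy model outputs
    pred = (torch.sigmoid(logits) > 0.5).long()   # threshold to {0, 1} labels
    loss = torch.tensor(0.42)                     # dummy mean batch loss

    metrics.update(x, y, pred, loss)
    metrics.display("Sanity check")

    # Persist the computed metrics and read them back.
    acc, mean_loss, precision, recall, f1 = metrics.get_metrics()
    with tempfile.TemporaryDirectory() as checkpoint:
        metrics.save_metrics({"accuracy": acc, "f1_score": f1}, checkpoint)
        print(metrics.load_metrics(checkpoint))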