Coverage for robustML/advertrain/metrics.py: 81%

53 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-09-10 08:11 +0000

1import json 

2import os 

3from typing import Any, Dict 

4 

5import torch 

6from torch import Tensor 

7 

8from robustML.advertrain.constants import METRICS_FILE 

9 

10 

class Metrics:
    """
    Class to track performance metrics for binary classification tasks.

    This class tracks true positives, true negatives, false positives,
    false negatives, and cumulative loss across batches. It calculates metrics
    like accuracy, precision, recall, and F1-score.

    Note on value types: the counters start as Python ints (0) and become
    tensors as soon as ``update`` has been called, because they accumulate
    ``torch.sum`` results. Consequently the values returned by ``get_metrics``
    (and the private helpers) are 0-dim tensors after at least one update, and
    plain floats before any update.
    """

    def __init__(self):
        self.reset_metrics()

    def reset_metrics(self) -> None:
        """Reset all counters and the accumulated loss to zero."""
        self.loss = 0
        self.TP = self.TN = self.FP = self.FN = self.P = self.N = 0

    def update(self, x: Tensor, y: Tensor, pred: Tensor, loss: Tensor) -> None:
        """
        Accumulate confusion-matrix counts and loss for one batch.

        Args:
            x (Tensor): Input batch; only ``len(x)`` is used, to weight the loss.
            y (Tensor): Target labels; entries equal to 1 count as positives
                and entries equal to 0 as negatives.
            pred (Tensor): Predicted labels, compared element-wise with ``y``.
            loss (Tensor): Scalar batch loss. It is multiplied by ``len(x)``,
                which presumes a per-sample mean loss — TODO confirm with callers.
        """
        # Element-wise boolean masks for each confusion-matrix cell.
        TP = torch.logical_and(pred == 1, y == 1)
        TN = torch.logical_and(pred == 0, y == 0)
        FP = torch.logical_and(pred == 1, y == 0)
        FN = torch.logical_and(pred == 0, y == 1)

        self.P += torch.sum(y == 1)
        self.N += torch.sum(y == 0)

        self.TP += torch.sum(TP)
        self.TN += torch.sum(TN)
        self.FP += torch.sum(FP)
        self.FN += torch.sum(FN)

        # Undo the batch mean so get_metrics can average over all samples.
        self.loss += loss.item() * len(x)

    def _precision(self) -> float:
        """Return TP / (TP + FP); the 1e-8 term avoids division by zero."""
        return self.TP / (self.TP + self.FP + 1e-8)

    def _recall(self) -> float:
        """Return TP / P (all positive labels); the 1e-8 term avoids division by zero."""
        return self.TP / (self.P + 1e-8)

    def _f1_score(self) -> float:
        """Return the harmonic mean of precision and recall (epsilon-guarded)."""
        precision = self._precision()
        recall = self._recall()
        return 2 * precision * recall / (precision + recall + 1e-8)

    def get_metrics(self) -> tuple:
        """
        Calculate and return key performance metrics.

        Returns:
            tuple: ``(accuracy, loss, precision, recall, f1_score)``. The loss is
            the accumulated weighted loss divided by the number of labelled
            samples (P + N). Values are 0-dim tensors once ``update`` has run.
        """
        acc = (self.TP + self.TN) / (self.P + self.N + 1e-8)
        loss = self.loss / (self.P + self.N + 1e-8)
        precision = self._precision()
        recall = self._recall()
        f1_score = self._f1_score()

        return acc, loss, precision, recall, f1_score

    def save_metrics(self, metrics: Dict[str, Any], checkpoint: str) -> None:
        """
        Save metrics as JSON to ``os.path.join(checkpoint, METRICS_FILE)``.

        Tensor values (such as the 0-dim tensors returned by ``get_metrics``)
        are converted with ``Tensor.tolist()``: 0-dim tensors become plain
        Python numbers, larger ones become nested lists. Any other value that
        JSON cannot encode still raises ``TypeError``.

        Args:
            metrics (Dict[str, Any]): Metric names mapped to their values.
            checkpoint (str): Directory in which the metrics file is written.

        Raises:
            TypeError: If a value is neither JSON-serializable nor a Tensor.
        """
        def _encode_extra(obj: Any) -> Any:
            # json.dumps only calls this for objects it cannot encode natively.
            if isinstance(obj, Tensor):
                return obj.tolist()
            raise TypeError(
                f"Object of type {obj.__class__.__name__} is not JSON serializable"
            )

        # Serialize before opening the file so a failure leaves no partial file.
        data = json.dumps(metrics, default=_encode_extra)
        with open(os.path.join(checkpoint, METRICS_FILE), "w", encoding="utf-8") as f:
            f.write(data)

    def load_metrics(self, checkpoint: str) -> Dict[str, Any]:
        """
        Load metrics from the JSON file at ``os.path.join(checkpoint, METRICS_FILE)``.

        Args:
            checkpoint (str): Directory from which the metrics file is read.

        Returns:
            Dict[str, Any]: The decoded JSON content.
        """
        with open(os.path.join(checkpoint, METRICS_FILE), "r", encoding="utf-8") as file:
            data = json.load(file)

        return data

    def display(self, title: str) -> None:
        """
        Print the title followed by loss, accuracy, recall, precision and F1.

        Args:
            title (str): The title printed on the first line.
        """
        acc, loss, precision, recall, f1_score = self.get_metrics()
        print(f"{title}\n"
              f"Loss: {loss:.3f}\t"
              f"Acc: {acc:.3f}\t"
              f"Recall: {recall:.3f}\t"
              f"Precision: {precision:.3f}\t"
              f"F1 Score: {f1_score:.3f}")

    def display_table(self, title: str) -> None:
        """
        Print one Markdown-style table row: title, accuracy, recall, precision, F1.

        Args:
            title (str): The label placed in the first column.
        """
        acc, loss, precision, recall, f1_score = self.get_metrics()
        print(f"| {title} | {acc:.3f} | {recall:.3f} | {precision:.3f} | {f1_score:.3f} |")