1 """
2 This module defines a GapMetric class responsible for calculating
3 the Domain Gap distance between source and target data using various methods and models.
4 It utilizes various methods and models for this purpose.
5
6 Authors:
7 Yoann RANDON
8 Sabrina CHAOUCHE
9 Faouzi ADJED
10
11 Dependencies:
12 time
13 json
14 argparse
15 typing
16 mlflow
17 torchvision.models (resnet50, ResNet50_Weights, resnet18, ResNet18_Weights)
18 utils (DomainMeter)
19 twe_logger
20
21 Classes:
22 DomainGapMetrics: Class for calculating Central Moment Discrepancy (CMD)
23 distance between source and target data.
24
25 Functions: None
26
27 Usage:
28 1. Create an instance of GapMetric.
29 2. Parse the configuration using parse_config().
30 3. Load the CNN model using load_model(cfg).
31 4. Compute the CMD distance using compute_distance(cfg).
32 5. Log MLflow parameters using set_mlflow_params(cfg).
33 6. Process multiple tasks using process_tasks(cfg, tsk).
34 """

from typing import Dict

import numpy as np
import ot
import torch
import torch.nn as nn
import torch.nn.functional as F

from scipy.stats import wasserstein_distance
from scipy.linalg import sqrtm, eigh

from sklearn import svm

from dqm.domain_gap.utils import (
    extract_nth_layer_feature,
    generate_transform,
    load_model,
    compute_features,
    construct_dataloader,
)


class Metric:
    """Base class for defining a metric."""

    def __init__(self) -> None:
        """Initialize the Metric instance."""

    def compute(self) -> float:
        """Compute the value of the metric."""
        raise NotImplementedError


# ==========================================================================#
#                     MMD - Maximum Mean Discrepancy                        #
# ==========================================================================#
class MMD(Metric):
    """Maximum Mean Discrepancy metric class definition."""

    def __init__(self) -> None:
        super().__init__()

    def __rbf_kernel(self, x, y, gamma: float) -> torch.Tensor:
        """
        Computes the Radial Basis Function (RBF) kernel between two sets of vectors.

        The kernel is K(x, y) = exp(-gamma * ||x - y||^2).

        Args:
            x (torch.Tensor): Tensor of shape (N, D), where N is the number of samples.
            y (torch.Tensor): Tensor of shape (M, D), where M is the number of samples.
            gamma (float): Kernel coefficient, typically 1 / (2 * sigma^2).

        Returns:
            torch.Tensor: Kernel matrix of shape (N, M) with RBF similarities.
        """
        # Squared Euclidean distances, as required by the RBF kernel
        k = torch.cdist(x, y, p=2.0) ** 2
        k = -gamma * k
        return torch.exp(k)

    def __polynomial_kernel(
        self, x, y, degree: float, gamma: float, coefficient0: float
    ) -> torch.Tensor:
        """
        Computes the Polynomial Kernel between two tensors.

        The polynomial kernel is defined as:
            K(x, y) = (γ * ⟨x, y⟩ + c) ^ d

        where:
            - ⟨x, y⟩ is the dot product of `x` and `y`
            - γ (gamma) is a scaling factor
            - c (coefficient0) is a bias term
            - d (degree) is the polynomial degree

        Args:
            x (torch.Tensor): A tensor of shape (N, D), where N is the number of samples.
            y (torch.Tensor): A tensor of shape (D, M), i.e. a sample matrix that is
                already transposed, so that `torch.matmul(x, y)` yields the (N, M)
                Gram matrix.
            degree (float): The degree of the polynomial.
            gamma (float): The scaling factor for the dot product.
            coefficient0 (float): The bias term.

        Returns:
            torch.Tensor: A kernel matrix of shape (N, M) containing polynomial similarities.
        """
        k = torch.matmul(x, y) * gamma + coefficient0
        return torch.pow(k, degree)

    @torch.no_grad()
    def compute(self, cfg) -> float:
        """
        Computes a domain gap metric between two datasets using a specified kernel method.

        This function extracts features from source and target datasets using a deep
        learning model, applies a specified kernel function (linear, RBF, or polynomial),
        and computes a similarity measure between the datasets.

        Args:
            cfg (dict): Configuration dictionary containing:
                - `DATA`:
                    - `source` (str): Path to the source dataset.
                    - `target` (str): Path to the target dataset.
                    - `batch_size` (int): Batch size for dataloaders.
                    - `width` (int): Width of input images.
                    - `height` (int): Height of input images.
                    - `norm_mean` (tuple): Mean for normalization.
                    - `norm_std` (tuple): Standard deviation for normalization.
                - `MODEL`:
                    - `arch` (str): Model architecture.
                    - `n_layer_feature` (int): Layer from which features are extracted.
                    - `device` (str): Device to run computations ('cpu' or 'cuda').
                - `METHOD`:
                    - `kernel` (str): Kernel type ('linear', 'rbf', 'poly').
                    - `kernel_params` (dict): Parameters for the chosen kernel.

        Returns:
            float: Computed domain gap value based on the selected kernel.

        Raises:
            AssertionError: If source and target datasets have different sizes.
            ValueError: If `kernel` is not one of 'linear', 'rbf', or 'poly'.
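
        Example (illustrative sketch; paths and parameter values below are
        placeholders, not shipped defaults):

            cfg = {
                "DATA": {
                    "source": "path/to/source_images",
                    "target": "path/to/target_images",
                    "batch_size": 16,
                    "width": 224,
                    "height": 224,
                    "norm_mean": (0.485, 0.456, 0.406),
                    "norm_std": (0.229, 0.224, 0.225),
                },
                "MODEL": {"arch": "resnet18", "n_layer_feature": -2, "device": "cpu"},
                "METHOD": {"kernel": "rbf", "kernel_params": {"gamma": 1.0}},
            }
            gap = MMD().compute(cfg)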
158 """
159 source_folder_path = cfg["DATA"]["source"]
160 target_folder_path = cfg["DATA"]["target"]
161 batch_size = cfg["DATA"]["batch_size"]
162 image_size = (cfg["DATA"]["width"], cfg["DATA"]["height"])
163 norm_mean = cfg["DATA"]["norm_mean"]
164 norm_std = cfg["DATA"]["norm_std"]
165 model = cfg["MODEL"]["arch"]
166 n_layer_feature = cfg["MODEL"]["n_layer_feature"]
167 device = cfg["MODEL"]["device"]
168 kernel = cfg["METHOD"]["kernel"]
169 kernel_params = cfg["METHOD"]["kernel_params"]
170 device = device
171
172 transform = generate_transform(image_size, norm_mean, norm_std)
173 source_loader = construct_dataloader(source_folder_path, transform, batch_size)
174 target_loader = construct_dataloader(target_folder_path, transform, batch_size)
175
176 loaded_model = load_model(model, device)
177 feature_extractor = extract_nth_layer_feature(loaded_model, n_layer_feature)
178
179 source_features_t = compute_features(source_loader, feature_extractor, device)
180 target_features_t = compute_features(target_loader, feature_extractor, device)
181
182 # flatten features to compute on matricial features
183 source_features = source_features_t.view(source_features_t.size(0), -1)
184 target_features = target_features_t.view(target_features_t.size(0), -1)
185
186 # Both datasets (source and target) have to have the same size
187 assert len(source_features) == len(target_features)
188
189 feature_extractor.eval()
190
191 # Get the features of the source and target datasets using the model
192 if kernel == "linear":
193 xx = torch.matmul(source_features, source_features.t())
194 yy = torch.matmul(target_features, target_features.t())
195 xy = torch.matmul(source_features, target_features.t())
196
197 return torch.mean(xx + yy - 2.0 * xy).item()

        if kernel == "rbf":
            gamma = kernel_params.get("gamma", 1.0)
            if source_features.dim() == 1:
                source_features = torch.unsqueeze(source_features, 0)
            if target_features.dim() == 1:
                target_features = torch.unsqueeze(target_features, 0)
            xx = self.__rbf_kernel(source_features, source_features, gamma)
            yy = self.__rbf_kernel(target_features, target_features, gamma)
            xy = self.__rbf_kernel(source_features, target_features, gamma)

            return torch.mean(xx + yy - 2.0 * xy).item()

        if kernel == "poly":
            degree = kernel_params.get("degree", 3.0)
            gamma = kernel_params.get("gamma", 1.0)
            coefficient0 = kernel_params.get("coefficient0", 1.0)
            xx = self.__polynomial_kernel(
                source_features, source_features.t(), degree, gamma, coefficient0
            )
            yy = self.__polynomial_kernel(
                target_features, target_features.t(), degree, gamma, coefficient0
            )
            xy = self.__polynomial_kernel(
                source_features, target_features.t(), degree, gamma, coefficient0
            )

            return torch.mean(xx + yy - 2.0 * xy).item()

        raise ValueError(f"Unknown kernel '{kernel}'; expected 'linear', 'rbf' or 'poly'")


# ==========================================================================#
#                  CMD - Central Moments Discrepancy v2                     #
# ==========================================================================#


class RMSELoss(nn.Module):
    """
    Compute the Root Mean Squared Error (RMSE) loss between the predicted values and
    the target values.

    This class provides a PyTorch module for calculating the RMSE loss, a common metric
    for evaluating the accuracy of regression models. Here the RMSE is the square root
    of the sum of squared differences between predicted and target values (the
    underlying nn.MSELoss is configured with reduction="sum").

    Attributes:
        mse (nn.MSELoss): Mean Squared Error loss module with reduction set to "sum".
        eps (float): A small value added under the square root for numerical stability
            (e.g., to keep the gradient finite when the loss is zero).

    Methods:
        forward(yhat, y): Compute the RMSE loss between the predicted values `yhat`
            and the target values `y`.
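
    Example (illustrative):

        rmse = RMSELoss()
        loss = rmse(torch.tensor([2.0, 3.0]), torch.tensor([1.0, 5.0]))
        # sqrt((2 - 1)**2 + (3 - 5)**2) = sqrt(5) ≈ 2.236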
247 """
248
249 def __init__(self, eps=0):
250 super().__init__()
251 self.mse = nn.MSELoss(reduction="sum")
252 self.eps = eps
253
254 def forward(self, yhat, y):
255 loss = torch.sqrt(self.mse(yhat, y) + self.eps)
256 return loss


class CMD(Metric):
    """Central Moment Discrepancy metric class definition."""

    def __init__(self) -> None:
        super().__init__()

    def __get_unbiased(self, n: int, k: int) -> int:
        """
        Computes an unbiased normalization factor for higher-order statistical moments.

        This function calculates the product `(n-1) * (n-2) * ... * (n-k+1)`,
        which is used to adjust higher-order moment estimations to be unbiased.

        Args:
            n (int): Total number of samples.
            k (int): Order of the moment being computed.

        Returns:
            int: The unbiased normalization factor.

        Raises:
            AssertionError: If `n <= 0`, `k <= 0`, or `n <= k`.
        """
        assert n > 0 and k > 0 and n > k
        output = 1
        for i in range(n - 1, n - k, -1):
            output *= i
        return output

    def __compute_moments(
        self,
        dataloader,
        feature_extractor,
        k,
        device,
        shapes: dict,
        axis_config: dict[str, tuple] = None,
        apply_sigmoid: bool = True,
        unbiased: bool = False,
    ) -> dict:
        """
        Computes the first `k` statistical moments of feature maps extracted from a dataset.

        Args:
            dataloader (torch.utils.data.DataLoader): DataLoader providing batches of input data.
            feature_extractor (callable): Function or model that extracts features from input data.
            k (int): Number of moments to compute (e.g., mean, variance, skewness, etc.).
            device (torch.device): Device on which to perform computations (e.g., "cuda" or "cpu").
            shapes (dict): Dictionary mapping layer names to their corresponding tensor shapes.
            axis_config (dict[str, tuple], optional): Dictionary specifying summation and viewing axes.
                Defaults to `{"sum_axis": (0, 2, 3), "view_axis": (1, -1, 1, 1)}`.
            apply_sigmoid (bool, optional): Whether to apply a sigmoid function to extracted features.
                Defaults to True.
            unbiased (bool, optional): Whether to apply unbiased estimation for higher-order moments.
                Defaults to False.

        Returns:
            dict: A dictionary containing computed moments for each layer. The structure is:
                {
                    "layer_name": {
                        0: mean tensor,
                        1: second moment tensor,
                        ...
                        k-1: kth moment tensor
                    },
                    ...
                }
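
        Note (for reference, matching the implementation below): per channel, the
        mean is m_1 = sum(f) / n, and for j >= 1 the entry at index j holds the
        (j+1)-th central moment estimated as

            m_{j+1} = sum((f - m_1) ** (j + 1)) / n

        where n counts samples times spatial positions.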
325 """
326 # Initialize axis_config if None
327 if axis_config is None:
328 axis_config = {"sum_axis": (0, 2, 3), "view_axis": (1, -1, 1, 1)}
329
330 # Initialize statistics dictionary
331 moments = {layer_name: dict() for layer_name in shapes.keys()}
332 for layer_name, shape in shapes.items():
333 channels = shape[1]
334 for j in range(k):
335 moments[layer_name][j] = torch.zeros(channels).to(device)
336
337 # Initialize normalization factors for each layer
338 nb_samples = {layer_name: 0 for layer_name in shapes.keys()} # TOTOTOTO

        # Iterate through the DataLoader
        for batch in dataloader:
            batch = batch.to(device)
            batch_size = batch.size(0)

            # Update the sample count for normalization
            for layer_name, shape in shapes.items():
                nb_samples[layer_name] += batch_size * shape[2] * shape[3]

            # Compute features for the current batch
            features = feature_extractor(batch)

            # Accumulate the mean (1st moment)
            for layer_name, feature in features.items():
                if apply_sigmoid:
                    feature = torch.sigmoid(feature)
                moments[layer_name][0] += feature.sum(axis_config.get("sum_axis"))

        # Normalize the first moment (mean)
        for layer_name, n in nb_samples.items():
            moments[layer_name][0] /= n

        # Compute higher-order moments (k >= 2)
        for batch in dataloader:
            batch = batch.to(device)
            features = feature_extractor(batch)

            for layer_name, feature in features.items():
                if apply_sigmoid:
                    feature = torch.sigmoid(feature)

                # Calculate differences from the mean
                difference = feature - moments[layer_name][0].view(
                    axis_config.get("view_axis")
                )

                # Accumulate moments for k >= 2
                for j in range(1, k):
                    moments[layer_name][j] += (difference ** (j + 1)).sum(
                        axis_config.get("sum_axis")
                    )

        # Normalize higher-order moments
        for layer_name, n in nb_samples.items():
            for j in range(1, k):
                moments[layer_name][j] /= n
                if unbiased:
                    nb_samples_unbiased = self.__get_unbiased(n, j)
                    moments[layer_name][j] *= n**j / nb_samples_unbiased

        return moments

    @torch.no_grad()
    def compute(self, cfg) -> float:
        """
        Compute the Central Moment Discrepancy (CMD) loss between source and target
        datasets using a pre-trained model.

        This method calculates the CMD loss, which measures the discrepancy between the
        distributions of features extracted from source and target datasets. The
        features are extracted from specified layers of the model, and the loss is
        computed as a weighted sum of the differences in moments of the feature
        distributions.

        Args:
            cfg (Dict): A configuration dictionary containing the following keys:
                - "DATA": Dictionary with data-related configurations:
                    - "source" (str): Path to the source folder containing images.
                    - "target" (str): Path to the target folder containing images.
                    - "batch_size" (int): The batch size for data loading.
                    - "width" (int): The width of the images.
                    - "height" (int): The height of the images.
                    - "norm_mean" (list of float): Mean values for image normalization.
                    - "norm_std" (list of float): Standard deviation values for image normalization.
                - "MODEL": Dictionary with model-related configurations:
                    - "arch" (str): The architecture of the model to use.
                    - "n_layer_feature" (list): The layers (return nodes) from which to extract features.
                    - "feature_extractors_layers_weights" (list of float): Weights for each feature layer.
                    - "device" (str): The device to run the model on (e.g., "cpu" or "cuda").
                - "METHOD": Dictionary with method-related configurations:
                    - "k" (int): The number of moments to consider in the CMD calculation.

        Returns:
            float: The computed CMD loss between the source and target datasets.

        The method performs the following steps:
            1. Constructs data loaders for the source and target datasets with specified transformations.
            2. Loads the model and sets it up on the specified device.
            3. Extracts features from the specified layers of the model for both datasets.
            4. Computes the moments of the feature distributions for both datasets.
            5. Calculates the CMD loss as a weighted sum of the differences in moments.
            6. Returns the total CMD loss.

        Raises:
            AssertionError: If the source and target samples do not have the same dimensions.
            AssertionError: If the keys of the feature weights dictionary do not match the specified feature layers.
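
        Example (illustrative sketch only; the paths, layer identifiers, and weights
        below are placeholders whose exact format follows the conventions of
        dqm.domain_gap.utils.extract_nth_layer_feature):

            cfg = {
                "DATA": {
                    "source": "path/to/source_images",
                    "target": "path/to/target_images",
                    "batch_size": 16,
                    "width": 224,
                    "height": 224,
                    "norm_mean": [0.485, 0.456, 0.406],
                    "norm_std": [0.229, 0.224, 0.225],
                },
                "MODEL": {
                    "arch": "resnet18",
                    "n_layer_feature": ["layer3", "layer4"],
                    "feature_extractors_layers_weights": [0.5, 0.5],
                    "device": "cpu",
                },
                "METHOD": {"k": 3},
            }
            cmd_loss = CMD().compute(cfg)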
433 """
434 source_folder_path = cfg["DATA"]["source"]
435 target_folder_path = cfg["DATA"]["target"]
436 batch_size = cfg["DATA"]["batch_size"]
437 image_size = (cfg["DATA"]["width"], cfg["DATA"]["height"])
438 norm_mean = cfg["DATA"]["norm_mean"]
439 norm_std = cfg["DATA"]["norm_std"]
440 model = cfg["MODEL"]["arch"]
441 feature_extractors_layers = cfg["MODEL"]["n_layer_feature"]
442 k = cfg["METHOD"]["k"]
443 feature_extractors_layers_weights = cfg["MODEL"][
444 "feature_extractors_layers_weights"
445 ]
446 device = cfg["MODEL"]["device"]
447
448 transform = generate_transform(image_size, norm_mean, norm_std)
449 source_loader = construct_dataloader(source_folder_path, transform, batch_size)
450 target_loader = construct_dataloader(target_folder_path, transform, batch_size)
451
452 # Both datasets (source and target) have to have the same dimension (number of samples)
453 assert (
454 source_loader.dataset[0].size() == target_loader.dataset[0].size()
455 ), "dataset must have the same size"
456
457 loaded_model = load_model(model, device)
458 loaded_model.eval()
459 feature_extractor = extract_nth_layer_feature(
460 loaded_model, feature_extractors_layers
461 )

        # Initialize RMSE Loss
        rmse = RMSELoss()

        # Build the feature weights dictionary by pairing each layer with its weight.
        # TODO: pass the feature weights dict (layer weights dict) directly as an
        # input of the function.
        feature_weights = dict(
            zip(feature_extractors_layers, feature_extractors_layers_weights)
        )
        # The keys of the feature weights dict have to be the same as the
        # return nodes specified in the cfg file
        assert set(feature_weights.keys()) == set(feature_extractors_layers)

        # Get channel info for each layer
        sample = torch.randn(1, 3, image_size[1], image_size[0])  # (N,C,H,W)
        with torch.no_grad():
            output = feature_extractor(sample.to(device))
            shapes = {name: tensor.size() for name, tensor in output.items()}

        # Compute the per-layer moments of both datasets
        source_moments = self.__compute_moments(
            source_loader, feature_extractor, k, device, shapes
        )
        target_moments = self.__compute_moments(
            target_loader, feature_extractor, k, device, shapes
        )

        # Compute CMD Loss
        total_loss = 0.0
        for layer_name, weight in feature_weights.items():
            layer_loss = 0.0
            for statistic_order, statistic_weight in enumerate(
                feature_extractors_layers_weights
            ):
                source_moment = source_moments[layer_name][statistic_order]
                target_moment = target_moments[layer_name][statistic_order]
                layer_loss += statistic_weight * rmse(source_moment, target_moment) / k
            total_loss += weight * layer_loss / len(feature_weights)

        return total_loss.item()


# ========================================================================== #
#                             PROXY-A-DISTANCE                               #
# ========================================================================== #


class ProxyADistance(Metric):
    """Proxy-A-Distance (PAD) metric class definition."""

    def __init__(self):
        super().__init__()

    def adapt_format_like_pred(self, y, pred):
        """
        Convert a list of class indices into a one-hot encoded tensor matching the
        format of the predictions.

        This method takes a list of class indices and converts it into a one-hot
        encoded tensor that matches the shape and format of the provided predictions
        tensor. This is useful for comparing ground truth labels with model
        predictions in a consistent format.

        Args:
            y (torch.Tensor or list): A 1D tensor or list containing class indices.
                Each element should be an integer representing the class index.
            pred (torch.Tensor): A 2D tensor containing predicted probabilities or
                scores for each class. The shape should be (N, C), where N is the
                number of samples and C is the number of classes.

        Returns:
            torch.Tensor: A one-hot encoded tensor of the same shape as `pred`, where
                each row has a 1 at the index of the true class and 0 elsewhere.

        The method performs the following steps:
            1. Initializes a zero tensor with the same shape as `pred`.
            2. Iterates over each class index in `y` and sets the corresponding
               position in the new tensor to 1.
        """
        # iterate over pred
        new_y_test = torch.zeros_like(pred)
        for i in range(len(y)):
            new_y_test[i][int(y[i])] = 1
        return new_y_test

    def function_pad(self, x, y, error_metric) -> float:
        """
        Computes the Proxy-A-Distance (PAD) value using an SVM classifier.

        A linear SVM is trained to separate source samples from target samples; its
        classification error is turned into the PAD value PAD = 2 * (1 - 2 * error).

        Args:
            x (np.ndarray or torch.Tensor): Training features.
            y (np.ndarray or torch.Tensor): Domain labels (0 for source, 1 for target).
            error_metric (str): Error metric to use, either "mse" or "mae".

        Returns:
            float: The PAD value computed with the chosen error metric.
        """
        c = 1
        kernel = "linear"
        pad_model = svm.SVC(C=c, kernel=kernel, probability=True, verbose=0)
        pad_model.fit(x, y)
        pred = torch.from_numpy(pad_model.predict_proba(x))
        adapt_y_test = self.adapt_format_like_pred(y, pred)

        # Calculate the MSE
        if error_metric == "mse":
            error = F.mse_loss(adapt_y_test, pred)
        # Calculate the MAE
        elif error_metric == "mae":
            error = torch.mean(torch.abs(adapt_y_test - pred))
        else:
            raise ValueError(
                f"Unknown error_metric '{error_metric}'; expected 'mse' or 'mae'"
            )

        pad_value = 2.0 * (1 - 2.0 * error)

        return pad_value

    def compute_image_distance(self, cfg: Dict) -> float:
        """
        Compute the average image distance between source and target datasets using
        multiple models.

        This method calculates the average image distance between features extracted
        from source and target image datasets using multiple pre-trained models. The
        distance is computed using a specified evaluation function for each model, and
        the average distance across all models is returned.

        Args:
            cfg (Dict): A configuration dictionary containing the following keys:
                - "DATA": Dictionary with data-related configurations:
                    - "source" (str): Path to the source folder containing images.
                    - "target" (str): Path to the target folder containing images.
                    - "batch_size" (int): The batch size for data loading.
                    - "width" (int): The width of the images.
                    - "height" (int): The height of the images.
                    - "norm_mean" (list of float): Mean values for image normalization.
                    - "norm_std" (list of float): Standard deviation values for image normalization.
                - "MODEL": Dictionary with model-related configurations:
                    - "arch" (list of str): List of model architectures to use.
                    - "n_layer_feature" (int): The layer number from which to extract features.
                    - "device" (str): The device to run the models on (e.g., "cpu" or "cuda").
                - "METHOD": Dictionary with method-related configurations:
                    - "evaluator" (str): The error metric used to compute the distance ("mse" or "mae").

        Returns:
            float: The computed average image distance between the source and target datasets across all models.

        The method performs the following steps:
            1. Constructs data loaders for the source and target datasets with specified transformations.
            2. Iterates over each model specified in the configuration.
            3. Loads each model and sets it up on the specified device.
            4. Extracts features from the specified layer of the model for both datasets.
            5. Computes the combined features and labels for the source and target datasets.
            6. Calculates the distance using the specified evaluation function.
            7. Returns the average distance across all models.
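
        Example (illustrative sketch only; paths and values are placeholders):

            cfg = {
                "DATA": {
                    "source": "path/to/source_images",
                    "target": "path/to/target_images",
                    "batch_size": 16,
                    "width": 224,
                    "height": 224,
                    "norm_mean": [0.485, 0.456, 0.406],
                    "norm_std": [0.229, 0.224, 0.225],
                },
                "MODEL": {
                    "arch": ["resnet18", "resnet50"],
                    "n_layer_feature": -2,
                    "device": "cpu",
                },
                "METHOD": {"evaluator": "mae"},
            }
            pad = ProxyADistance().compute_image_distance(cfg)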
611 """
612 source_folder_path = cfg["DATA"]["source"]
613 target_folder_path = cfg["DATA"]["target"]
614 batch_size = cfg["DATA"]["batch_size"]
615 image_size = (cfg["DATA"]["width"], cfg["DATA"]["height"])
616 norm_mean = cfg["DATA"]["norm_mean"]
617 norm_std = cfg["DATA"]["norm_std"]
618 models = cfg["MODEL"]["arch"]
619 n_layer_feature = cfg["MODEL"]["n_layer_feature"]
620 device = cfg["MODEL"]["device"]
621 evaluator = cfg["METHOD"]["evaluator"]
622
623 transform = generate_transform(image_size, norm_mean, norm_std)
624 source_loader = construct_dataloader(source_folder_path, transform, batch_size)
625 target_loader = construct_dataloader(target_folder_path, transform, batch_size)
626
627 sum_pad = 0
628 for model in models:
629 loaded_model = load_model(model, device)
630
631 feature_extractor = extract_nth_layer_feature(loaded_model, n_layer_feature)
632
633 source_features = compute_features(source_loader, feature_extractor, device)
634 target_features = compute_features(target_loader, feature_extractor, device)
635
636 combined_features = torch.cat((source_features, target_features), dim=0)
637 combined_labels = torch.cat(
638 (
639 torch.zeros(source_features.size(0)),
640 torch.ones(target_features.size(0)),
641 ),
642 dim=0,
643 )
644
645 # Compute pad
646 pad_value = self.function_pad(combined_features, combined_labels, evaluator)
647
648 sum_pad += pad_value
649
650 return sum_pad / len(models)


# ========================================================================== #
#                            Wasserstein_Distance                            #
# ========================================================================== #


class Wasserstein(Metric):
    """Wasserstein distance metric class definition."""

    def __init__(self):
        super().__init__()

    def compute_cov_matrix(self, tensor):
        """
        Compute the covariance matrix of a given tensor.

        This method calculates the covariance matrix for a given tensor, which
        represents a set of feature vectors. The covariance matrix provides a measure
        of how much the dimensions of the feature vectors vary from the mean with
        respect to each other.

        Args:
            tensor (torch.Tensor): A 2D tensor where each row represents a feature
                vector. The tensor should have shape (N, D), where N is the number of
                samples and D is the dimensionality of the features.

        Returns:
            torch.Tensor: The computed covariance matrix of the feature vectors, with shape (D, D).

        The method performs the following steps:
            1. Computes the mean vector of the feature vectors.
            2. Centers the feature vectors by subtracting the mean vector.
            3. Computes the covariance matrix using the centered feature vectors.
        """
        mean = torch.mean(tensor, dim=0)
        centered_tensor = tensor - mean
        return torch.mm(centered_tensor.t(), centered_tensor) / (tensor.shape[0] - 1)

    def compute_1D_distance(self, cfg):
        """
        Compute the average 1D Wasserstein Distance between corresponding features
        from source and target datasets.

        This method calculates the average 1D Wasserstein Distance between features
        extracted from source and target image datasets using a pre-trained model.
        The features are extracted from a specified layer of the model, and the
        distance is computed for each corresponding feature dimension.

        Args:
            cfg (dict): A configuration dictionary containing the following keys:
                - "MODEL": Dictionary with model-related configurations:
                    - "arch" (str): The architecture of the model to use.
                    - "device" (str): The device to run the model on (e.g., "cpu" or "cuda").
                    - "n_layer_feature" (int): The layer number from which to extract features.
                - "DATA": Dictionary with data-related configurations:
                    - "width" (int): The width of the images.
                    - "height" (int): The height of the images.
                    - "norm_mean" (list of float): Mean values for image normalization.
                    - "norm_std" (list of float): Standard deviation values for image normalization.
                    - "batch_size" (int): The batch size for data loading.
                    - "source" (str): Path to the source folder containing images.
                    - "target" (str): Path to the target folder containing images.

        Returns:
            float: The computed average 1D Wasserstein Distance between the source and target image features.

        The method performs the following steps:
            1. Constructs data loaders for the source and target datasets with specified transformations.
            2. Loads the model and sets it up on the specified device.
            3. Extracts features from the specified layer of the model for both datasets.
            4. Computes the 1D Wasserstein Distance for each corresponding feature dimension.
            5. Returns the average distance across all feature dimensions.
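
        Example (illustrative sketch only; paths and values are placeholders):

            cfg = {
                "DATA": {
                    "source": "path/to/source_images",
                    "target": "path/to/target_images",
                    "batch_size": 16,
                    "width": 224,
                    "height": 224,
                    "norm_mean": [0.485, 0.456, 0.406],
                    "norm_std": [0.229, 0.224, 0.225],
                },
                "MODEL": {"arch": "resnet18", "n_layer_feature": -2, "device": "cpu"},
            }
            distance = Wasserstein().compute_1D_distance(cfg)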
719 """
720 model = cfg["MODEL"]["arch"]
721 device = cfg["MODEL"]["device"]
722 loaded_model = load_model(model, device)
723 n_layer_feature = cfg["MODEL"]["n_layer_feature"]
724 image_size = (cfg["DATA"]["width"], cfg["DATA"]["height"])
725 norm_mean = cfg["DATA"]["norm_mean"]
726 norm_std = cfg["DATA"]["norm_std"]
727 device = cfg["MODEL"]["device"]
728 batch_size = cfg["DATA"]["batch_size"]
729 source_folder_path = cfg["DATA"]["source"]
730 target_folder_path = cfg["DATA"]["target"]
731
732 transform = generate_transform(image_size, norm_mean, norm_std)
733 source_loader = construct_dataloader(source_folder_path, transform, batch_size)
734 target_loader = construct_dataloader(target_folder_path, transform, batch_size)
735
736 loaded_model = load_model(model, device)
737 feature_extractor = extract_nth_layer_feature(loaded_model, n_layer_feature)
738
739 source_features = compute_features(source_loader, feature_extractor, device)
740 target_features = compute_features(target_loader, feature_extractor, device)
741
742 sum_wass_distance = 0
743 for n in range(min(len(source_features), len(target_features))):
744 source_feature_n = source_features[:, n]
745 target_feature_n = target_features[:, n]
746 sum_wass_distance += wasserstein_distance(
747 source_feature_n, target_feature_n
748 )
749 return sum_wass_distance / len(source_features)

    def compute_slice_wasserstein_distance(self, cfg):
        """
        Compute the Sliced Wasserstein Distance between two sets of image features.

        This method calculates the Sliced Wasserstein Distance between features
        extracted from source and target image datasets using a pre-trained model.
        The features are projected onto a lower-dimensional space using the
        eigenvectors corresponding to the largest eigenvalues of the covariance
        matrix. The distance is then computed between these projections.

        Args:
            cfg (dict): A configuration dictionary containing the following keys:
                - "MODEL": Dictionary with model-related configurations:
                    - "arch" (str): The architecture of the model to use.
                    - "device" (str): The device to run the model on (e.g., "cpu" or "cuda").
                    - "n_layer_feature" (int): The layer number from which to extract features.
                - "DATA": Dictionary with data-related configurations:
                    - "width" (int): The width of the images.
                    - "height" (int): The height of the images.
                    - "norm_mean" (list of float): Mean values for image normalization.
                    - "norm_std" (list of float): Standard deviation values for image normalization.
                    - "batch_size" (int): The batch size for data loading.
                    - "source" (str): Path to the source folder containing images.
                    - "target" (str): Path to the target folder containing images.

        Returns:
            float: The computed Sliced Wasserstein Distance between the source and target image features.

        The method performs the following steps:
            1. Constructs data loaders for the source and target datasets with specified transformations.
            2. Loads the model and sets it up on the specified device.
            3. Extracts features from the specified layer of the model for both datasets.
            4. Concatenates the features and computes the covariance matrix.
            5. Computes the eigenvalues and eigenvectors of the covariance matrix.
            6. Projects the features onto a lower-dimensional space using the eigenvectors.
            7. Computes the Sliced Wasserstein Distance between the projected features.
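
        Example (illustrative sketch only; paths and values are placeholders):

            cfg = {
                "DATA": {
                    "source": "path/to/source_images",
                    "target": "path/to/target_images",
                    "batch_size": 16,
                    "width": 224,
                    "height": 224,
                    "norm_mean": [0.485, 0.456, 0.406],
                    "norm_std": [0.229, 0.224, 0.225],
                },
                "MODEL": {"arch": "resnet18", "n_layer_feature": -2, "device": "cpu"},
            }
            swd = Wasserstein().compute_slice_wasserstein_distance(cfg)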
786 """
787 model = cfg["MODEL"]["arch"]
788 device = cfg["MODEL"]["device"]
789
790 n_layer_feature = cfg["MODEL"]["n_layer_feature"]
791 image_size = (cfg["DATA"]["width"], cfg["DATA"]["height"])
792 norm_mean = cfg["DATA"]["norm_mean"]
793 norm_std = cfg["DATA"]["norm_std"]
794 batch_size = cfg["DATA"]["batch_size"]
795 source_folder_path = cfg["DATA"]["source"]
796 target_folder_path = cfg["DATA"]["target"]
797
798 transform = generate_transform(image_size, norm_mean, norm_std)
799 source_loader = construct_dataloader(source_folder_path, transform, batch_size)
800 target_loader = construct_dataloader(target_folder_path, transform, batch_size)
801
802 loaded_model = load_model(model, device)
803 feature_extractor = extract_nth_layer_feature(loaded_model, n_layer_feature)
804
805 source_features = compute_features(source_loader, feature_extractor, device)
806 target_features = compute_features(target_loader, feature_extractor, device)
807
808 all_features = torch.concat((source_features, target_features))
809 labels = torch.concat(
810 (torch.zeros(len(source_features)), torch.ones(len(target_features)))
811 )
812 cov_matrix = self.compute_cov_matrix(all_features)
813
814 values, vectors = eigh(cov_matrix.detach().numpy())
815
816 # Select the last two eigenvalues and corresponding eigenvectors
817 values = values[-2:] # Get the last two eigenvalues
818 vectors = vectors[:, -2:] # Get the last two eigenvectors
819 values, vectors = torch.from_numpy(values), torch.from_numpy(vectors)
820 vectors = vectors.T
821
822 new_coordinates = torch.mm(vectors, all_features.T).T
823 mask_source = labels == 0
824 mask_target = labels == 1
825
826 x0 = new_coordinates[mask_source]
827 x1 = new_coordinates[mask_target]
828
829 return ot.sliced_wasserstein_distance(x0, x1)


# ========================================================================== #
#                         Frechet Inception Distance                         #
# ========================================================================== #


class FID(Metric):
    """Frechet Inception Distance (FID) metric class definition."""

    def __init__(self):
        super().__init__()
        self.model = "inception_v3"

    def calculate_statistics(self, features: torch.Tensor):
        """
        Calculate the mean and covariance matrix of a set of features.

        This method computes the mean vector and the covariance matrix for a given set
        of features. It converts the features from a PyTorch tensor to a NumPy array
        for easier manipulation and statistical calculations.

        Args:
            features (torch.Tensor): A 2D tensor where each row represents a feature
                vector. The tensor should have shape (N, D), where N is the number of
                samples and D is the dimensionality of the features.

        Returns:
            tuple: A tuple containing:
                - mu (numpy.ndarray): The mean vector of the features, with shape (D,).
                - sigma (numpy.ndarray): The covariance matrix of the features, with shape (D, D).

        The function performs the following steps:
            1. Converts the features tensor to a NumPy array for easier manipulation.
            2. Computes the mean vector of the features.
            3. Computes the covariance matrix of the features.
        """
        # Convert features to numpy for easier manipulation
        features_np = features.detach().cpu().numpy()

        # Compute the mean and covariance
        mu = np.mean(features_np, axis=0)
        sigma = np.cov(features_np, rowvar=False)

        return mu, sigma

    def compute_image_distance(self, cfg: dict):
        """
        Compute the Frechet Inception Distance (FID) between two sets of images.

        This method calculates the FID between images from a source and target dataset
        using a pre-trained InceptionV3 model to extract features. The FID is a
        measure of the similarity between two distributions of images, commonly used
        to evaluate the quality of generated images.

        Args:
            cfg (dict): A configuration dictionary containing the following keys:
                - "MODEL": Dictionary with model-related configurations:
                    - "device" (str): The device to run the model on (e.g., "cpu" or "cuda").
                    - "n_layer_feature" (int): The layer number from which to extract features.
                - "DATA": Dictionary with data-related configurations:
                    - "width" (int): The width of the images.
                    - "height" (int): The height of the images.
                    - "norm_mean" (list of float): Mean values for image normalization.
                    - "norm_std" (list of float): Standard deviation values for image normalization.
                    - "batch_size" (int): The batch size for data loading.
                    - "source" (str): Path to the source folder containing images.
                    - "target" (str): Path to the target folder containing images.

        Returns:
            torch.Tensor: The computed FID score, representing the distance between
                the source and target image distributions.

        The method performs the following steps:
            1. Constructs data loaders for the source and target datasets with specified transformations.
            2. Loads the InceptionV3 model and sets it up on the specified device.
            3. Extracts features from the specified layer of the model for both datasets.
            4. Calculates the mean and covariance of the features for both datasets.
            5. Computes the FID score using the means and covariances of the features.
            6. Ensures the FID score is positive by taking the absolute value.
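
        For reference, the score matches the standard closed form
        FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2)).

        Example (illustrative sketch only; paths and values are placeholders):

            cfg = {
                "DATA": {
                    "source": "path/to/source_images",
                    "target": "path/to/target_images",
                    "batch_size": 16,
                    "width": 299,
                    "height": 299,
                    "norm_mean": [0.485, 0.456, 0.406],
                    "norm_std": [0.229, 0.224, 0.225],
                },
                "MODEL": {"n_layer_feature": -2, "device": "cpu"},
            }
            fid_score = FID().compute_image_distance(cfg)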
907 """
908 device = cfg["MODEL"]["device"]
909 n_layer_feature = cfg["MODEL"]["n_layer_feature"]
910 img_size = (cfg["DATA"]["width"], cfg["DATA"]["height"])
911 norm_mean = cfg["DATA"]["norm_mean"]
912 norm_std = cfg["DATA"]["norm_std"]
913 batch_size = cfg["DATA"]["batch_size"]
914 source_folder_path = cfg["DATA"]["source"]
915 target_folder_path = cfg["DATA"]["target"]
916
917 transform = generate_transform(img_size, norm_mean, norm_std)
918 source_loader = construct_dataloader(source_folder_path, transform, batch_size)
919 target_loader = construct_dataloader(target_folder_path, transform, batch_size)
920
921 inception_v3 = load_model(self.model, device)
922 feature_extractor = extract_nth_layer_feature(inception_v3, n_layer_feature)
923
924 # compute features as tensor
925 source_features = compute_features(source_loader, feature_extractor, device)
926 target_features = compute_features(target_loader, feature_extractor, device)
927
928 # Calculate statistics for source features
929 mu1, sigma1 = self.calculate_statistics(source_features)
930
931 # Calculate statistics for target features
932 mu2, sigma2 = self.calculate_statistics(target_features)
933
934 diff = mu1 - mu2
935
936 # Compute the square root of the product of the covariance matrices
937 covmean, _ = sqrtm(sigma1.dot(sigma2), disp=False)
938
939 if np.iscomplexobj(covmean):
940 covmean = covmean.real
941
942 fid = (
943 diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
944 )
945
946 positive_fid = torch.abs(torch.tensor(fid))
947 return positive_fid


# ========================================================================== #
#        Kullback-Leibler divergence for MultiVariate Normal distribution    #
# ========================================================================== #


class KLMVN(Metric):
    """Kullback-Leibler divergence between multivariate normal fits (KLMVN) metric."""

    def __init__(self):
        super().__init__()

    def calculate_statistics(self, features: torch.Tensor):
        """
        Calculate the mean and covariance matrix of a set of features.

        This function computes the mean vector and the covariance matrix for a given
        set of features. It ensures that the feature matrix has full rank, which is
        necessary for certain statistical operations.

        Args:
            features (torch.Tensor): A 2D tensor where each row represents a feature
                vector. The tensor should have shape (N, D), where N is the number of
                samples and D is the dimensionality of the features.

        Returns:
            tuple: A tuple containing:
                - mu (torch.Tensor): The mean vector of the features, with shape (D,).
                - sigma (torch.Tensor): The covariance matrix of the features, with shape (D, D).

        Raises:
            AssertionError: If the feature matrix does not have full rank.

        The function performs the following steps:
            1. Computes the mean vector of the features.
            2. Centers the features by subtracting the mean vector.
            3. Computes the covariance matrix using the centered features.
            4. Checks the rank of the feature matrix to ensure it has full rank.
        """
        # Compute the mean of the features
        mu = torch.mean(features, dim=0)

        # Center the features by subtracting the mean
        centered_features = features - mu

        # Compute the covariance matrix (similar to np.cov with rowvar=False)
        # (N - 1) is used for unbiased estimation
        sigma = torch.mm(centered_features.T, centered_features) / (
            features.size(0) - 1
        )

        # Compute the rank of the feature matrix
        rank_feature = torch.linalg.matrix_rank(features)

        # Ensure the feature matrix has full rank
        assert rank_feature == features.size(0), "The feature matrix is not full rank."

        return mu, sigma

    def regularize_covariance(self, cov_matrix, epsilon=1e-6):
        """
        Regularize a covariance matrix by adding a small value to its diagonal elements.

        This function enhances the numerical stability of a covariance matrix by
        adding a small constant to its diagonal. This is particularly useful when the
        covariance matrix is nearly singular or when performing operations that
        require the matrix to be positive definite.

        Args:
            cov_matrix (torch.Tensor): The covariance matrix to be regularized. It
                should be a square matrix.
            epsilon (float, optional): A small value to add to the diagonal elements
                of the covariance matrix. Default is 1e-6.

        Returns:
            torch.Tensor: The regularized covariance matrix with the small value added
                to its diagonal.
        """
        # Add a small value to the diagonal for numerical stability
        # (torch.eye is used because callers pass torch tensors)
        return cov_matrix + epsilon * torch.eye(
            cov_matrix.shape[0], device=cov_matrix.device
        )

    def klmvn(self, mu1, cov1, mu2, cov2, device):
        """
        Compute the Kullback-Leibler (KL) divergence between two multivariate normal distributions.

        This method calculates the KL divergence between two multivariate normal
        distributions defined by their mean vectors and covariance matrices. It
        assumes the covariance matrices are diagonal (off-diagonal terms are discarded).

        Args:
            mu1 (torch.Tensor): Mean vector of the first multivariate normal distribution.
            cov1 (torch.Tensor): Covariance matrix of the first distribution; only its
                diagonal is used.
            mu2 (torch.Tensor): Mean vector of the second multivariate normal distribution.
            cov2 (torch.Tensor): Covariance matrix of the second distribution; only its
                diagonal is used.
            device (torch.device): The device (CPU or GPU) on which to perform the computation.

        Returns:
            torch.Tensor: The computed KL divergence between the two distributions.

        The method performs the following steps:
            1. Constructs diagonal covariance matrices from the provided diagonal elements.
            2. Creates multivariate normal distributions using the mean vectors and covariance matrices.
            3. Computes the KL divergence between the two distributions.
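
        For reference, with p = N(mu1, S1) and q = N(mu2, S2) the divergence has the
        closed form

            KL(p || q) = 0.5 * (tr(S2^-1 S1) + (mu2 - mu1)^T S2^-1 (mu2 - mu1)
                         - D + ln(det S2 / det S1))

        which torch.distributions.kl_divergence evaluates for the two
        MultivariateNormal objects built below.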
1052 """
1053 # assume diagonal matrix
1054 p_cov = torch.eye(len(cov1), device=device) * cov1
1055 q_cov = torch.eye(len(cov2), device=device) * cov2
1056
1057 # build pdf
1058 p = torch.distributions.multivariate_normal.MultivariateNormal(mu1, p_cov)
1059 q = torch.distributions.multivariate_normal.MultivariateNormal(mu2, q_cov)
1060
1061 # compute KL Divergence
1062 kld = torch.distributions.kl_divergence(p, q)
1063 return kld

    def compute_image_distance(self, cfg: dict) -> float:
        """
        Compute the distance between image features from source and target datasets
        using a pre-trained model.

        This method calculates the distance between the statistical representations of
        image features extracted from two datasets. It uses a pre-trained model to
        extract features from a specified layer and computes the Kullback-Leibler
        divergence between the distributions of these features.

        Args:
            cfg (dict): A configuration dictionary containing the following keys:
                - "MODEL": Dictionary with model-related configurations:
                    - "device" (str): The device to run the model on (e.g., "cpu" or "cuda").
                    - "arch" (str): The architecture of the model to use.
                    - "n_layer_feature" (int): The layer number from which to extract features.
                - "DATA": Dictionary with data-related configurations:
                    - "width" (int): The width of the images.
                    - "height" (int): The height of the images.
                    - "norm_mean" (list of float): Mean values for image normalization.
                    - "norm_std" (list of float): Standard deviation values for image normalization.
                    - "batch_size" (int): The batch size for data loading.
                    - "source" (str): Path to the source folder containing images.
                    - "target" (str): Path to the target folder containing images.

        Returns:
            float: The computed distance between the source and target image features.

        The method performs the following steps:
            1. Constructs data loaders for the source and target datasets with specified transformations.
            2. Loads the model and sets it up on the specified device.
            3. Extracts features from the specified layer of the model for both datasets.
            4. Calculates the mean and covariance of the features for both datasets.
            5. Regularizes the covariance matrices to ensure numerical stability.
            6. Computes the Kullback-Leibler divergence between the feature distributions.
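
        Example (illustrative sketch only; paths and values are placeholders):

            cfg = {
                "DATA": {
                    "source": "path/to/source_images",
                    "target": "path/to/target_images",
                    "batch_size": 16,
                    "width": 224,
                    "height": 224,
                    "norm_mean": [0.485, 0.456, 0.406],
                    "norm_std": [0.229, 0.224, 0.225],
                },
                "MODEL": {"arch": "resnet18", "n_layer_feature": -2, "device": "cpu"},
            }
            kl = KLMVN().compute_image_distance(cfg)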
1098 """
1099
1100 device = cfg["MODEL"]["device"]
1101 model = cfg["MODEL"]["arch"]
1102 n_layer_feature = cfg["MODEL"]["n_layer_feature"]
1103 img_size = (cfg["DATA"]["width"], cfg["DATA"]["height"])
1104 norm_mean = cfg["DATA"]["norm_mean"]
1105 norm_std = cfg["DATA"]["norm_std"]
1106 batch_size = cfg["DATA"]["batch_size"]
1107 source_folder_path = cfg["DATA"]["source"]
1108 target_folder_path = cfg["DATA"]["target"]
1109
1110 transform = generate_transform(img_size, norm_mean, norm_std)
1111 source_loader = construct_dataloader(source_folder_path, transform, batch_size)
1112 target_loader = construct_dataloader(target_folder_path, transform, batch_size)
1113
1114 loaded_model = load_model(model, device)
1115 feature_extractor = extract_nth_layer_feature(loaded_model, n_layer_feature)
1116
1117 # compute features as tensor
1118 source_features = compute_features(source_loader, feature_extractor, device)
1119 target_features = compute_features(target_loader, feature_extractor, device)
1120
1121 # Calculate statistics for source features
1122 mu1, cov1 = self.calculate_statistics(source_features)
1123 cov1 = self.regularize_covariance(cov1)
1124
1125 # Calculate statistics for target features
1126 mu2, cov2 = self.calculate_statistics(target_features)
1127 cov2 = self.regularize_covariance(cov2)
1128
1129 dist = self.klmvn(mu1, cov1, mu2, cov2, device)
1130 return dist