1"""
2This module defines a GapMetric class responsible for calculating
3the Domain Gap distance between source and target data using various methods and models.
4It utilizes various methods and models for this purpose.
6Authors:
7 Yoann RANDON
8 Sabrina CHAOUCHE
9 Faouzi ADJED
11Dependencies:
12 time
13 json
14 argparse
15 typing
16 mlflow
17 torchvision.models (resnet50, ResNet50_Weights, resnet18, ResNet18_Weights)
18 utils (DomainMeter)
19 twe_logger
21Classes:
22 DomainGapMetrics: Class for calculating Central Moment Discrepancy (CMD)
23 distance between source and target data.
25Functions: None
27Usage:
281. Create an instance of GapMetric.
292. Parse the configuration using parse_config().
303. Load the CNN model using load_model(cfg).
314. Compute the CMD distance using compute_distance(cfg).
325. Log MLflow parameters using set_mlflow_params(cfg).
336. Process multiple tasks using process_tasks(cfg, tsk).
34"""

from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from dqm.domain_gap.utils import (
    extract_nth_layer_feature,
    generate_transform,
    load_model,
    compute_features,
    construct_dataloader,
)

from scipy.stats import wasserstein_distance
from scipy.linalg import sqrtm, eigh

import ot
import numpy as np

from sklearn import svm


class Metric:
    """Base class for defining a metric."""

    def __init__(self) -> None:
        """Initialize the Metric instance."""
        pass

    def compute(self) -> float:
        """Compute the value of the metric."""
        pass


# ==========================================================================#
#                      MMD - Maximum Mean Discrepancy                       #
# ==========================================================================#
class MMD(Metric):
    """Maximum Mean Discrepancy metric class definition"""

    def __init__(self) -> None:
        super().__init__()

    def __rbf_kernel(self, x, y, gamma: float) -> torch.Tensor:
        """
        Computes the Radial Basis Function (RBF) kernel between two sets of vectors.

        Args:
            x (torch.Tensor): Tensor of shape (N, D), where N is the number of samples.
            y (torch.Tensor): Tensor of shape (M, D), where M is the number of samples.
            gamma (float): Kernel coefficient, typically 1 / (2 * sigma^2).

        Returns:
            torch.Tensor: Kernel matrix of shape (N, M) with RBF similarities.
        """
        # RBF kernel: K(x, y) = exp(-gamma * ||x - y||^2), using the squared
        # Euclidean distance as implied by gamma = 1 / (2 * sigma^2)
        k = torch.cdist(x, y, p=2.0) ** 2
        k = -gamma * k
        return torch.exp(k)

    def __polynomial_kernel(
        self, x, y, degree: float, gamma: float, coefficient0: float
    ) -> torch.Tensor:
        """
        Computes the Polynomial Kernel between two tensors.

        The polynomial kernel is defined as:
            K(x, y) = (γ * ⟨x, y⟩ + c) ^ d

        where:
            - ⟨x, y⟩ is the dot product of `x` and `y`
            - γ (gamma) is a scaling factor
            - c (coefficient0) is a bias term
            - d (degree) is the polynomial degree

        Args:
            x (torch.Tensor): A tensor of shape (N, D), where N is the number of samples.
            y (torch.Tensor): A tensor of shape (D, M), i.e., the transpose of an
                (M, D) feature matrix, as passed by `compute`.
            degree (float): The degree of the polynomial.
            gamma (float): The scaling factor for the dot product.
            coefficient0 (float): The bias term.

        Returns:
            torch.Tensor: A kernel matrix of shape (N, M) containing polynomial similarities.
        """
        k = torch.matmul(x, y) * gamma + coefficient0
        return torch.pow(k, degree)

    @torch.no_grad()
    def compute(self, cfg) -> float:
        """
        Computes a domain gap metric between two datasets using a specified kernel method.

        This function extracts features from source and target datasets using a deep learning model,
        applies a specified kernel function (linear, RBF, or polynomial), and computes a similarity
        measure between the datasets.

        Args:
            cfg (dict): Configuration dictionary containing:
                - `DATA`:
                    - `source` (str): Path to the source dataset.
                    - `target` (str): Path to the target dataset.
                    - `batch_size` (int): Batch size for dataloaders.
                    - `width` (int): Width of input images.
                    - `height` (int): Height of input images.
                    - `norm_mean` (tuple): Mean for normalization.
                    - `norm_std` (tuple): Standard deviation for normalization.
                - `MODEL`:
                    - `arch` (str): Model architecture.
                    - `n_layer_feature` (int): Layer from which features are extracted.
                    - `device` (str): Device to run computations ('cpu' or 'cuda').
                - `METHOD`:
                    - `kernel` (str): Kernel type ('linear', 'rbf', 'poly').
                    - `kernel_params` (dict): Parameters for the chosen kernel.

        Returns:
            float: Computed domain gap value based on the selected kernel.

        Raises:
            AssertionError: If source and target datasets have different sizes.
            ValueError: If an unknown kernel type is specified.
        """
        source_folder_path = cfg["DATA"]["source"]
        target_folder_path = cfg["DATA"]["target"]
        batch_size = cfg["DATA"]["batch_size"]
        image_size = (cfg["DATA"]["width"], cfg["DATA"]["height"])
        norm_mean = cfg["DATA"]["norm_mean"]
        norm_std = cfg["DATA"]["norm_std"]
        model = cfg["MODEL"]["arch"]
        n_layer_feature = cfg["MODEL"]["n_layer_feature"]
        device = cfg["MODEL"]["device"]
        kernel = cfg["METHOD"]["kernel"]
        kernel_params = cfg["METHOD"]["kernel_params"]

        transform = generate_transform(image_size, norm_mean, norm_std)
        source_loader = construct_dataloader(source_folder_path, transform, batch_size)
        target_loader = construct_dataloader(target_folder_path, transform, batch_size)

        loaded_model = load_model(model, device)
        feature_extractor = extract_nth_layer_feature(loaded_model, n_layer_feature)
        feature_extractor.eval()

        source_features_t = compute_features(source_loader, feature_extractor, device)
        target_features_t = compute_features(target_loader, feature_extractor, device)

        # Flatten features to compute on matricial features
        source_features = source_features_t.view(source_features_t.size(0), -1)
        target_features = target_features_t.view(target_features_t.size(0), -1)

        # Both datasets (source and target) must have the same size
        assert len(source_features) == len(target_features)

        # Compare source and target features with the selected kernel
        if kernel == "linear":
            xx = torch.matmul(source_features, source_features.t())
            yy = torch.matmul(target_features, target_features.t())
            xy = torch.matmul(source_features, target_features.t())

            return torch.mean(xx + yy - 2.0 * xy).item()

        if kernel == "rbf":
            gamma = kernel_params.get("gamma", 1.0)
            if source_features.dim() == 1:
                source_features = torch.unsqueeze(source_features, 0)
            if target_features.dim() == 1:
                target_features = torch.unsqueeze(target_features, 0)
            xx = self.__rbf_kernel(source_features, source_features, gamma)
            yy = self.__rbf_kernel(target_features, target_features, gamma)
            xy = self.__rbf_kernel(source_features, target_features, gamma)

            return torch.mean(xx + yy - 2.0 * xy).item()

        if kernel == "poly":
            degree = kernel_params.get("degree", 3.0)
            gamma = kernel_params.get("gamma", 1.0)
            coefficient0 = kernel_params.get("coefficient0", 1.0)
            xx = self.__polynomial_kernel(
                source_features, source_features.t(), degree, gamma, coefficient0
            )
            yy = self.__polynomial_kernel(
                target_features, target_features.t(), degree, gamma, coefficient0
            )
            xy = self.__polynomial_kernel(
                source_features, target_features.t(), degree, gamma, coefficient0
            )

            return torch.mean(xx + yy - 2.0 * xy).item()

        raise ValueError(f"Unknown kernel type: {kernel}")
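

# A minimal usage sketch for MMD.compute. The dataset paths, architecture
# name, and layer index below are illustrative assumptions, not values
# shipped with this module; adapt them to what load_model and
# extract_nth_layer_feature support in your setup.
def _example_mmd_cfg() -> dict:
    """Return a hypothetical configuration for the MMD metric."""
    return {
        "DATA": {
            "source": "path/to/source_images",  # hypothetical path
            "target": "path/to/target_images",  # hypothetical path
            "batch_size": 16,
            "width": 224,
            "height": 224,
            "norm_mean": (0.485, 0.456, 0.406),  # common ImageNet statistics
            "norm_std": (0.229, 0.224, 0.225),
        },
        "MODEL": {
            "arch": "resnet18",  # assumed to be supported by load_model
            "n_layer_feature": 8,  # hypothetical layer index
            "device": "cpu",
        },
        "METHOD": {"kernel": "rbf", "kernel_params": {"gamma": 1.0}},
    }
    # Usage: gap = MMD().compute(_example_mmd_cfg())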


# ==========================================================================#
#                 CMD - Central Moments Discrepancy v2                      #
# ==========================================================================#
class RMSELoss(nn.Module):
    """
    Compute the Root Mean Squared Error (RMSE) loss between the predicted values and the target values.

    This class provides a PyTorch module for calculating the RMSE loss, which is a common metric for
    evaluating the accuracy of regression models. The RMSE is the square root of the average of squared
    differences between predicted values and target values.

    Attributes:
        mse (nn.MSELoss): Mean Squared Error loss module with reduction set to "sum".
        eps (float): A small value added to the loss to ensure numerical stability.

    Methods:
        forward(yhat, y): Compute the RMSE loss between the predicted values `yhat` and the target values `y`.
    """

    def __init__(self, eps=0):
        super().__init__()
        self.mse = nn.MSELoss(reduction="sum")
        self.eps = eps

    def forward(self, yhat, y):
        loss = torch.sqrt(self.mse(yhat, y) + self.eps)
        return loss
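

# A quick illustrative check of RMSELoss (included only to document the
# sum-reduction behaviour; not part of the public API).
def _example_rmse() -> float:
    """RMSELoss with sum-reduced MSE: sqrt((1-0)^2 + (2-0)^2) = sqrt(5) ~= 2.2361."""
    rmse = RMSELoss()
    return rmse(torch.tensor([1.0, 2.0]), torch.zeros(2)).item()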


class CMD(Metric):
    """Central Moment Discrepancy (CMD) metric class definition."""

    def __init__(self) -> None:
        super().__init__()

    def __get_unbiased(self, n: int, k: int) -> int:
        """
        Computes an unbiased normalization factor for higher-order statistical moments.

        This function calculates the product `(n-1) * (n-2) * ... * (n-k+1)`,
        which is used to adjust higher-order moment estimations to be unbiased.

        Args:
            n (int): Total number of samples.
            k (int): Order of the moment being computed.

        Returns:
            int: The unbiased normalization factor.

        Raises:
            AssertionError: If `n <= 0`, `k <= 0`, or `n <= k`.
        """
        assert n > 0 and k > 0 and n > k
        output = 1
        for i in range(n - 1, n - k, -1):
            output *= i
        return output

    def __compute_moments(
        self,
        dataloader,
        feature_extractor,
        k,
        device,
        shapes: dict,
        axis_config: Optional[dict[str, tuple]] = None,
        apply_sigmoid: bool = True,
        unbiased: bool = False,
    ) -> dict:
        """
        Computes the first `k` statistical moments of feature maps extracted from a dataset.

        Args:
            dataloader (torch.utils.data.DataLoader): DataLoader providing batches of input data.
            feature_extractor (callable): Function or model that extracts features from input data.
            k (int): Number of moments to compute (e.g., mean, variance, skewness, etc.).
            device (torch.device): Device on which to perform computations (e.g., "cuda" or "cpu").
            shapes (dict): Dictionary mapping layer names to their corresponding tensor shapes.
            axis_config (dict[str, tuple], optional): Dictionary specifying summation and viewing axes.
                Defaults to `{"sum_axis": (0, 2, 3), "view_axis": (1, -1, 1, 1)}`.
            apply_sigmoid (bool, optional): Whether to apply a sigmoid function to extracted features.
                Defaults to True.
            unbiased (bool, optional): Whether to apply unbiased estimation for higher-order moments.
                Defaults to False.

        Returns:
            dict: A dictionary containing computed moments for each layer. The structure is:
                {
                    "layer_name": {
                        0: mean tensor,
                        1: second moment tensor,
                        ...
                        k-1: kth moment tensor
                    },
                    ...
                }
        """
        # Initialize axis_config if None
        if axis_config is None:
            axis_config = {"sum_axis": (0, 2, 3), "view_axis": (1, -1, 1, 1)}

        # Initialize statistics dictionary
        moments = {layer_name: dict() for layer_name in shapes.keys()}
        for layer_name, shape in shapes.items():
            channels = shape[1]
            for j in range(k):
                moments[layer_name][j] = torch.zeros(channels).to(device)

        # Initialize normalization factors for each layer
        nb_samples = {layer_name: 0 for layer_name in shapes.keys()}

        # Iterate through the DataLoader
        for batch in dataloader:
            batch = batch.to(device)
            batch_size = batch.size(0)

            # Update the sample count for normalization
            for layer_name, shape in shapes.items():
                nb_samples[layer_name] += batch_size * shape[2] * shape[3]

            # Compute features for the current batch
            features = feature_extractor(batch)

            # Compute mean (1st moment)
            for layer_name, feature in features.items():
                if apply_sigmoid:
                    feature = torch.sigmoid(feature)
                moments[layer_name][0] += feature.sum(axis_config.get("sum_axis"))

        # Normalize the first moment (mean)
        for layer_name, n in nb_samples.items():
            moments[layer_name][0] /= n

        # Compute higher-order moments (k >= 2)
        for batch in dataloader:
            batch = batch.to(device)
            features = feature_extractor(batch)

            for layer_name, feature in features.items():
                if apply_sigmoid:
                    feature = torch.sigmoid(feature)

                # Calculate differences from the mean
                difference = feature - moments[layer_name][0].view(
                    axis_config.get("view_axis")
                )

                # Accumulate moments for k >= 2
                for j in range(1, k):
                    moments[layer_name][j] += (difference ** (j + 1)).sum(
                        axis_config.get("sum_axis")
                    )

        # Normalize higher-order moments
        for layer_name, n in nb_samples.items():
            for j in range(1, k):
                moments[layer_name][j] /= n
                if unbiased:
                    nb_samples_unbiased = self.__get_unbiased(n, j)
                    moments[layer_name][j] *= n**j / nb_samples_unbiased

        return moments

    @torch.no_grad()
    def compute(self, cfg) -> float:
        """
        Compute the Central Moment Discrepancy (CMD) loss between source and target datasets using a pre-trained model.

        This method calculates the CMD loss, which measures the discrepancy between the distributions of features
        extracted from source and target datasets. The features are extracted from specified layers of the model,
        and the loss is computed as a weighted sum of the differences in moments of the feature distributions.

        Args:
            cfg (Dict): A configuration dictionary containing the following keys:
                - "DATA": Dictionary with data-related configurations:
                    - "source" (str): Path to the source folder containing images.
                    - "target" (str): Path to the target folder containing images.
                    - "batch_size" (int): The batch size for data loading.
                    - "width" (int): The width of the images.
                    - "height" (int): The height of the images.
                    - "norm_mean" (list of float): Mean values for image normalization.
                    - "norm_std" (list of float): Standard deviation values for image normalization.
                - "MODEL": Dictionary with model-related configurations:
                    - "arch" (str): The architecture of the model to use.
                    - "n_layer_feature" (list of int): List of layer numbers from which to extract features.
                    - "feature_extractors_layers_weights" (list of float): Weights for each feature layer.
                    - "device" (str): The device to run the model on (e.g., "cpu" or "cuda").
                - "METHOD": Dictionary with method-related configurations:
                    - "k" (int): The number of moments to consider in the CMD calculation.

        Returns:
            float: The computed CMD loss between the source and target datasets.

        The method performs the following steps:
        1. Constructs data loaders for the source and target datasets with specified transformations.
        2. Loads the model and sets it up on the specified device.
        3. Extracts features from the specified layers of the model for both datasets.
        4. Computes the moments of the feature distributions for both datasets.
        5. Calculates the CMD loss as a weighted sum of the differences in moments.
        6. Returns the total CMD loss.

        Raises:
            AssertionError: If the source and target samples do not have the same dimensions.
            AssertionError: If the keys of the feature weights dictionary do not match the specified feature layers.
        """
        source_folder_path = cfg["DATA"]["source"]
        target_folder_path = cfg["DATA"]["target"]
        batch_size = cfg["DATA"]["batch_size"]
        image_size = (cfg["DATA"]["width"], cfg["DATA"]["height"])
        norm_mean = cfg["DATA"]["norm_mean"]
        norm_std = cfg["DATA"]["norm_std"]
        model = cfg["MODEL"]["arch"]
        feature_extractors_layers = cfg["MODEL"]["n_layer_feature"]
        k = cfg["METHOD"]["k"]
        feature_extractors_layers_weights = cfg["MODEL"][
            "feature_extractors_layers_weights"
        ]
        device = cfg["MODEL"]["device"]

        transform = generate_transform(image_size, norm_mean, norm_std)
        source_loader = construct_dataloader(source_folder_path, transform, batch_size)
        target_loader = construct_dataloader(target_folder_path, transform, batch_size)

        # Source and target samples must have the same dimensions
        assert (
            source_loader.dataset[0].size() == target_loader.dataset[0].size()
        ), "datasets must have samples of the same size"

        loaded_model = load_model(model, device)
        loaded_model.eval()
        feature_extractor = extract_nth_layer_feature(
            loaded_model, feature_extractors_layers
        )

        # Initialize RMSE Loss
        rmse = RMSELoss()

        # Initialize feature weights dictionary (one weight per layer) => TODO:
        # pass the feature weights dict (layer weights dict) as an input of the function
        feature_weights = dict(
            zip(feature_extractors_layers, feature_extractors_layers_weights)
        )
        # The keys of the feature weights dict have to be the same as the
        # return nodes specified in the cfg file
        assert set(feature_weights.keys()) == set(feature_extractors_layers)

        # Get channel info for each layer
        sample = torch.randn(1, 3, image_size[1], image_size[0])  # (N,C,H,W)
        with torch.no_grad():
            output = feature_extractor(sample.to(device))
        shapes = {layer_name: v.size() for layer_name, v in output.items()}

        # Compute source and target moments
        source_moments = self.__compute_moments(
            source_loader, feature_extractor, k, device, shapes
        )
        target_moments = self.__compute_moments(
            target_loader, feature_extractor, k, device, shapes
        )

        # Compute CMD Loss
        total_loss = 0.0
        for layer_name, weight in feature_weights.items():
            layer_loss = 0.0
            for statistic_order, statistic_weight in enumerate(
                feature_extractors_layers_weights
            ):
                source_moment = source_moments[layer_name][statistic_order]
                target_moment = target_moments[layer_name][statistic_order]
                layer_loss += statistic_weight * rmse(source_moment, target_moment) / k
            total_loss += weight * layer_loss / len(feature_weights)

        return total_loss.item()
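

# A minimal configuration sketch for CMD.compute. Layer names and weights are
# illustrative assumptions; CMD expects one weight per feature layer and `k`
# moments per layer.
def _example_cmd_cfg() -> dict:
    """Return a hypothetical configuration for the CMD metric."""
    return {
        "DATA": {
            "source": "path/to/source_images",  # hypothetical path
            "target": "path/to/target_images",  # hypothetical path
            "batch_size": 16,
            "width": 224,
            "height": 224,
            "norm_mean": [0.485, 0.456, 0.406],
            "norm_std": [0.229, 0.224, 0.225],
        },
        "MODEL": {
            "arch": "resnet18",  # assumed to be supported by load_model
            "n_layer_feature": ["layer1", "layer2"],  # assumed return nodes
            "feature_extractors_layers_weights": [0.5, 0.5],
            "device": "cpu",
        },
        "METHOD": {"k": 3},  # first three central moments
    }
    # Usage: gap = CMD().compute(_example_cmd_cfg())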


# ========================================================================== #
#                              PROXY-A-DISTANCE                              #
# ========================================================================== #
class ProxyADistance(Metric):
    def __init__(self):
        super().__init__()

    def adapt_format_like_pred(self, y, pred):
        """
        Convert a list of class indices into a one-hot encoded tensor matching the format of the predictions.

        This method takes a list of class indices and converts it into a one-hot encoded tensor that matches the
        shape and format of the provided predictions tensor. This is useful for comparing ground truth labels
        with model predictions in a consistent format.

        Args:
            y (torch.Tensor or list): A 1D tensor or list containing class indices. Each element should be an
                                      integer representing the class index.
            pred (torch.Tensor): A 2D tensor containing predicted probabilities or scores for each class.
                                 The shape should be (N, C), where N is the number of samples and C is the
                                 number of classes.

        Returns:
            torch.Tensor: A one-hot encoded tensor of the same shape as `pred`, where each row has a 1 at the
                          index of the true class and 0 elsewhere.

        The method performs the following steps:
        1. Initializes a zero tensor with the same shape as `pred`.
        2. Iterates over each class index in `y` and sets the corresponding position in the new tensor to 1.
        """
        # Build a one-hot tensor with the same shape as pred
        new_y_test = torch.zeros_like(pred)
        for i in range(len(y)):
            new_y_test[i][int(y[i])] = 1
        return new_y_test

    def function_pad(self, x, y, error_metric) -> float:
        """
        Computes the Proxy-A-Distance (PAD) value using an SVM classifier.

        Args:
            x (np.ndarray): Training features.
            y (np.ndarray): Training labels (domain indices).
            error_metric (str): Error metric to use, either "mse" or "mae".

        Returns:
            float: The PAD value, computed as 2 * (1 - 2 * error).
        """
        c = 1
        kernel = "linear"
        pad_model = svm.SVC(C=c, kernel=kernel, probability=True, verbose=0)
        pad_model.fit(x, y)
        pred = torch.from_numpy(pad_model.predict_proba(x))
        adapt_y_test = self.adapt_format_like_pred(y, pred)

        # Calculate the MSE
        if error_metric == "mse":
            error = F.mse_loss(adapt_y_test, pred)

        # Calculate the MAE
        elif error_metric == "mae":
            error = torch.mean(torch.abs(adapt_y_test - pred))

        else:
            raise ValueError(f"Unknown error metric: {error_metric}")

        pad_value = 2.0 * (1 - 2.0 * error)

        return pad_value

    def compute_image_distance(self, cfg: Dict) -> float:
        """
        Compute the average image distance between source and target datasets using multiple models.

        This method calculates the average image distance between features extracted from source and target
        image datasets using multiple pre-trained models. The distance is computed using a specified evaluation
        function for each model, and the average distance across all models is returned.

        Args:
            cfg (Dict): A configuration dictionary containing the following keys:
                - "DATA": Dictionary with data-related configurations:
                    - "source" (str): Path to the source folder containing images.
                    - "target" (str): Path to the target folder containing images.
                    - "batch_size" (int): The batch size for data loading.
                    - "width" (int): The width of the images.
                    - "height" (int): The height of the images.
                    - "norm_mean" (list of float): Mean values for image normalization.
                    - "norm_std" (list of float): Standard deviation values for image normalization.
                - "MODEL": Dictionary with model-related configurations:
                    - "arch" (list of str): List of model architectures to use.
                    - "n_layer_feature" (int): The layer number from which to extract features.
                    - "device" (str): The device to run the models on (e.g., "cpu" or "cuda").
                - "METHOD": Dictionary with method-related configurations:
                    - "evaluator" (str): The evaluation function to use for computing the distance.

        Returns:
            float: The computed average image distance between the source and target datasets across all models.

        The method performs the following steps:
        1. Constructs data loaders for the source and target datasets with specified transformations.
        2. Iterates over each model specified in the configuration.
        3. Loads each model and sets it up on the specified device.
        4. Extracts features from the specified layer of the model for both datasets.
        5. Computes the combined features and labels for the source and target datasets.
        6. Calculates the distance using the specified evaluation function.
        7. Returns the average distance across all models.
        """
        source_folder_path = cfg["DATA"]["source"]
        target_folder_path = cfg["DATA"]["target"]
        batch_size = cfg["DATA"]["batch_size"]
        image_size = (cfg["DATA"]["width"], cfg["DATA"]["height"])
        norm_mean = cfg["DATA"]["norm_mean"]
        norm_std = cfg["DATA"]["norm_std"]
        models = cfg["MODEL"]["arch"]
        n_layer_feature = cfg["MODEL"]["n_layer_feature"]
        device = cfg["MODEL"]["device"]
        evaluator = cfg["METHOD"]["evaluator"]

        transform = generate_transform(image_size, norm_mean, norm_std)
        source_loader = construct_dataloader(source_folder_path, transform, batch_size)
        target_loader = construct_dataloader(target_folder_path, transform, batch_size)

        sum_pad = 0
        for model in models:
            loaded_model = load_model(model, device)

            feature_extractor = extract_nth_layer_feature(loaded_model, n_layer_feature)

            source_features = compute_features(source_loader, feature_extractor, device)
            target_features = compute_features(target_loader, feature_extractor, device)

            combined_features = torch.cat((source_features, target_features), dim=0)
            combined_labels = torch.cat(
                (
                    torch.zeros(source_features.size(0)),
                    torch.ones(target_features.size(0)),
                ),
                dim=0,
            )

            # Compute PAD for this model (domain labels: 0 = source, 1 = target)
            pad_value = self.function_pad(combined_features, combined_labels, evaluator)

            sum_pad += pad_value

        return sum_pad / len(models)
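

# Illustrative sanity check for Proxy-A-Distance (an assumption for
# documentation, not part of the public API): clearly separated domains give
# a classification error near 0 and hence a PAD near 2, while
# indistinguishable domains give an error near 0.5 and a PAD near 0.
def _example_pad() -> float:
    """Compute PAD on two well-separated synthetic 'domains'."""
    rng = np.random.default_rng(0)
    source = rng.normal(loc=0.0, scale=0.1, size=(50, 4))
    target = rng.normal(loc=5.0, scale=0.1, size=(50, 4))
    x = np.concatenate([source, target])
    y = np.concatenate([np.zeros(50), np.ones(50)])
    return float(ProxyADistance().function_pad(x, y, "mae"))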


# ========================================================================== #
#                            Wasserstein Distance                            #
# ========================================================================== #
class Wasserstein(Metric):
    def __init__(self):
        super().__init__()

    def compute_cov_matrix(self, tensor):
        """
        Compute the covariance matrix of a given tensor.

        This method calculates the covariance matrix for a given tensor, which represents a set of feature vectors.
        The covariance matrix provides a measure of how much the dimensions of the feature vectors vary from the mean
        with respect to each other.

        Args:
            tensor (torch.Tensor): A 2D tensor where each row represents a feature vector.
                                   The tensor should have shape (N, D), where N is the number of samples
                                   and D is the dimensionality of the features.

        Returns:
            torch.Tensor: The computed covariance matrix of the feature vectors, with shape (D, D).

        The method performs the following steps:
        1. Computes the mean vector of the feature vectors.
        2. Centers the feature vectors by subtracting the mean vector.
        3. Computes the covariance matrix using the centered feature vectors.
        """
        mean = torch.mean(tensor, dim=0)
        centered_tensor = tensor - mean
        return torch.mm(centered_tensor.t(), centered_tensor) / (tensor.shape[0] - 1)

    def compute_1D_distance(self, cfg):
        """
        Compute the average 1D Wasserstein Distance between corresponding features from source and target datasets.

        This method calculates the average 1D Wasserstein Distance between features extracted from source and target
        image datasets using a pre-trained model. The features are extracted from a specified layer of the model,
        and the distance is computed for each corresponding feature dimension.

        Args:
            cfg (dict): A configuration dictionary containing the following keys:
                - "MODEL": Dictionary with model-related configurations:
                    - "arch" (str): The architecture of the model to use.
                    - "device" (str): The device to run the model on (e.g., "cpu" or "cuda").
                    - "n_layer_feature" (int): The layer number from which to extract features.
                - "DATA": Dictionary with data-related configurations:
                    - "width" (int): The width of the images.
                    - "height" (int): The height of the images.
                    - "norm_mean" (list of float): Mean values for image normalization.
                    - "norm_std" (list of float): Standard deviation values for image normalization.
                    - "batch_size" (int): The batch size for data loading.
                    - "source" (str): Path to the source folder containing images.
                    - "target" (str): Path to the target folder containing images.

        Returns:
            float: The computed average 1D Wasserstein Distance between the source and target image features.

        The method performs the following steps:
        1. Loads the model and sets it up on the specified device.
        2. Constructs data loaders for the source and target datasets with specified transformations.
        3. Extracts features from the specified layer of the model for both datasets.
        4. Computes the 1D Wasserstein Distance for each corresponding feature dimension.
        5. Returns the average distance across all feature dimensions.
        """
        model = cfg["MODEL"]["arch"]
        device = cfg["MODEL"]["device"]
        n_layer_feature = cfg["MODEL"]["n_layer_feature"]
        image_size = (cfg["DATA"]["width"], cfg["DATA"]["height"])
        norm_mean = cfg["DATA"]["norm_mean"]
        norm_std = cfg["DATA"]["norm_std"]
        batch_size = cfg["DATA"]["batch_size"]
        source_folder_path = cfg["DATA"]["source"]
        target_folder_path = cfg["DATA"]["target"]

        transform = generate_transform(image_size, norm_mean, norm_std)
        source_loader = construct_dataloader(source_folder_path, transform, batch_size)
        target_loader = construct_dataloader(target_folder_path, transform, batch_size)

        loaded_model = load_model(model, device)
        feature_extractor = extract_nth_layer_feature(loaded_model, n_layer_feature)

        source_features = compute_features(source_loader, feature_extractor, device)
        target_features = compute_features(target_loader, feature_extractor, device)

        # Average the 1D Wasserstein distance over feature dimensions
        sum_wass_distance = 0
        n_features = min(source_features.size(1), target_features.size(1))
        for n in range(n_features):
            source_feature_n = source_features[:, n]
            target_feature_n = target_features[:, n]
            sum_wass_distance += wasserstein_distance(
                source_feature_n, target_feature_n
            )
        return sum_wass_distance / n_features

    def compute_slice_wasserstein_distance(self, cfg):
        """
        Compute the Sliced Wasserstein Distance between two sets of image features.

        This method calculates the Sliced Wasserstein Distance between features extracted from source and target
        image datasets using a pre-trained model. The features are projected onto a lower-dimensional space using
        the eigenvectors corresponding to the largest eigenvalues of the covariance matrix. The distance is then
        computed between these projections.

        Args:
            cfg (dict): A configuration dictionary containing the following keys:
                - "MODEL": Dictionary with model-related configurations:
                    - "arch" (str): The architecture of the model to use.
                    - "device" (str): The device to run the model on (e.g., "cpu" or "cuda").
                    - "n_layer_feature" (int): The layer number from which to extract features.
                - "DATA": Dictionary with data-related configurations:
                    - "width" (int): The width of the images.
                    - "height" (int): The height of the images.
                    - "norm_mean" (list of float): Mean values for image normalization.
                    - "norm_std" (list of float): Standard deviation values for image normalization.
                    - "batch_size" (int): The batch size for data loading.
                    - "source" (str): Path to the source folder containing images.
                    - "target" (str): Path to the target folder containing images.

        Returns:
            float: The computed Sliced Wasserstein Distance between the source and target image features.

        The method performs the following steps:
        1. Loads the model and sets it up on the specified device.
        2. Constructs data loaders for the source and target datasets with specified transformations.
        3. Extracts features from the specified layer of the model for both datasets.
        4. Concatenates the features and computes the covariance matrix.
        5. Computes the eigenvalues and eigenvectors of the covariance matrix.
        6. Projects the features onto a lower-dimensional space using the eigenvectors.
        7. Computes the Sliced Wasserstein Distance between the projected features.
        """
        model = cfg["MODEL"]["arch"]
        device = cfg["MODEL"]["device"]

        n_layer_feature = cfg["MODEL"]["n_layer_feature"]
        image_size = (cfg["DATA"]["width"], cfg["DATA"]["height"])
        norm_mean = cfg["DATA"]["norm_mean"]
        norm_std = cfg["DATA"]["norm_std"]
        batch_size = cfg["DATA"]["batch_size"]
        source_folder_path = cfg["DATA"]["source"]
        target_folder_path = cfg["DATA"]["target"]

        transform = generate_transform(image_size, norm_mean, norm_std)
        source_loader = construct_dataloader(source_folder_path, transform, batch_size)
        target_loader = construct_dataloader(target_folder_path, transform, batch_size)

        loaded_model = load_model(model, device)
        feature_extractor = extract_nth_layer_feature(loaded_model, n_layer_feature)

        source_features = compute_features(source_loader, feature_extractor, device)
        target_features = compute_features(target_loader, feature_extractor, device)

        all_features = torch.concat((source_features, target_features))
        labels = torch.concat(
            (torch.zeros(len(source_features)), torch.ones(len(target_features)))
        )
        cov_matrix = self.compute_cov_matrix(all_features)

        values, vectors = eigh(cov_matrix.detach().numpy())

        # Select the two largest eigenvalues and corresponding eigenvectors
        # (eigh returns them in ascending order)
        values = values[-2:]
        vectors = vectors[:, -2:]
        values, vectors = torch.from_numpy(values), torch.from_numpy(vectors)
        vectors = vectors.T

        # Project all features onto the two principal directions
        new_coordinates = torch.mm(vectors, all_features.T).T
        mask_source = labels == 0
        mask_target = labels == 1

        x0 = new_coordinates[mask_source]
        x1 = new_coordinates[mask_target]

        return ot.sliced_wasserstein_distance(x0, x1)
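

# Illustrative check of the sliced Wasserstein step used above (assumption:
# POT's ot.sliced_wasserstein_distance with its default number of random
# projections). For two Gaussian blobs offset along one axis, the distance
# should be on the order of the offset between them.
def _example_sliced_wasserstein() -> float:
    """Sliced Wasserstein distance between two synthetic 2D feature clouds."""
    torch.manual_seed(0)
    x0 = torch.randn(100, 2)
    x1 = torch.randn(100, 2) + torch.tensor([3.0, 0.0])
    return float(ot.sliced_wasserstein_distance(x0, x1))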


# ========================================================================== #
#                         Frechet Inception Distance                         #
# ========================================================================== #
class FID(Metric):
    """Frechet Inception Distance (FID) metric class definition."""

    def __init__(self):
        super().__init__()
        self.model = "inception_v3"

    def calculate_statistics(self, features: torch.Tensor):
        """
        Calculate the mean and covariance matrix of a set of features.

        This method computes the mean vector and the covariance matrix for a given set of features.
        It converts the features from a PyTorch tensor to a NumPy array for easier manipulation and
        statistical calculations.

        Args:
            features (torch.Tensor): A 2D tensor where each row represents a feature vector.
                                     The tensor should have shape (N, D), where N is the number of
                                     samples and D is the dimensionality of the features.

        Returns:
            tuple: A tuple containing:
                - mu (numpy.ndarray): The mean vector of the features, with shape (D,).
                - sigma (numpy.ndarray): The covariance matrix of the features, with shape (D, D).
        """
        # Convert features to numpy for easier manipulation
        features_np = features.detach().numpy()

        # Compute the mean and covariance
        mu = np.mean(features_np, axis=0)
        sigma = np.cov(features_np, rowvar=False)

        return mu, sigma

    def compute_image_distance(self, cfg: dict):
        """
        Compute the Frechet Inception Distance (FID) between two sets of images.

        This method calculates the FID between images from a source and target dataset using a pre-trained
        InceptionV3 model to extract features. The FID is a measure of the similarity between two distributions
        of images, commonly used to evaluate the quality of generated images.

        Args:
            cfg (dict): A configuration dictionary containing the following keys:
                - "MODEL": Dictionary with model-related configurations:
                    - "device" (str): The device to run the model on (e.g., "cpu" or "cuda").
                    - "n_layer_feature" (int): The layer number from which to extract features.
                - "DATA": Dictionary with data-related configurations:
                    - "width" (int): The width of the images.
                    - "height" (int): The height of the images.
                    - "norm_mean" (list of float): Mean values for image normalization.
                    - "norm_std" (list of float): Standard deviation values for image normalization.
                    - "batch_size" (int): The batch size for data loading.
                    - "source" (str): Path to the source folder containing images.
                    - "target" (str): Path to the target folder containing images.

        Returns:
            torch.Tensor: The computed FID score, representing the distance between the source and target image
                          distributions.

        The method performs the following steps:
        1. Loads the InceptionV3 model and sets it up on the specified device.
        2. Constructs data loaders for the source and target datasets with specified transformations.
        3. Extracts features from the specified layer of the model for both datasets.
        4. Calculates the mean and covariance of the features for both datasets.
        5. Computes the FID score using the means and covariances of the features.
        6. Ensures the FID score is positive by taking the absolute value.
        """
        device = cfg["MODEL"]["device"]
        n_layer_feature = cfg["MODEL"]["n_layer_feature"]
        img_size = (cfg["DATA"]["width"], cfg["DATA"]["height"])
        norm_mean = cfg["DATA"]["norm_mean"]
        norm_std = cfg["DATA"]["norm_std"]
        batch_size = cfg["DATA"]["batch_size"]
        source_folder_path = cfg["DATA"]["source"]
        target_folder_path = cfg["DATA"]["target"]

        transform = generate_transform(img_size, norm_mean, norm_std)
        source_loader = construct_dataloader(source_folder_path, transform, batch_size)
        target_loader = construct_dataloader(target_folder_path, transform, batch_size)

        inception_v3 = load_model(self.model, device)
        feature_extractor = extract_nth_layer_feature(inception_v3, n_layer_feature)

        # Compute features as tensors
        source_features = compute_features(source_loader, feature_extractor, device)
        target_features = compute_features(target_loader, feature_extractor, device)

        # Calculate statistics for source features
        mu1, sigma1 = self.calculate_statistics(source_features)

        # Calculate statistics for target features
        mu2, sigma2 = self.calculate_statistics(target_features)

        diff = mu1 - mu2

        # Compute the square root of the product of the covariance matrices
        covmean, _ = sqrtm(sigma1.dot(sigma2), disp=False)

        if np.iscomplexobj(covmean):
            covmean = covmean.real

        fid = (
            diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
        )

        positive_fid = torch.abs(torch.tensor(fid))
        return positive_fid
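

# Numerical sanity check of the FID formula used above (illustrative): for two
# Gaussians with identity covariance, the trace terms cancel and FID reduces
# to ||mu1 - mu2||^2.
def _example_fid_identity() -> float:
    """FID between N(0, I) and N(1, I) in 3 dimensions: expected 3.0."""
    mu1, mu2 = np.zeros(3), np.ones(3)
    sigma1 = sigma2 = np.eye(3)
    covmean, _ = sqrtm(sigma1.dot(sigma2), disp=False)
    diff = mu1 - mu2
    fid = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
    return float(fid)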


# ========================================================================== #
#      Kullback-Leibler divergence for MultiVariate Normal distribution      #
# ========================================================================== #
class KLMVN(Metric):
    """Instantiate the KLMVN class to compute the KLMVN metric."""

    def __init__(self):
        super().__init__()

    def calculate_statistics(self, features: torch.Tensor):
        """
        Calculate the mean and covariance matrix of a set of features.

        This function computes the mean vector and the covariance matrix for a given set of features.
        It ensures that the feature matrix has full rank, which is necessary for certain statistical
        operations.

        Args:
            features (torch.Tensor): A 2D tensor where each row represents a feature vector.
                                     The tensor should have shape (N, D), where N is the number of
                                     samples and D is the dimensionality of the features.

        Returns:
            tuple: A tuple containing:
                - mu (torch.Tensor): The mean vector of the features, with shape (D,).
                - sigma (torch.Tensor): The covariance matrix of the features, with shape (D, D).

        Raises:
            AssertionError: If the feature matrix does not have full rank.

        The function performs the following steps:
        1. Computes the mean vector of the features.
        2. Centers the features by subtracting the mean vector.
        3. Computes the covariance matrix using the centered features.
        4. Checks the rank of the feature matrix to ensure it has full rank.
        """
        # Compute the mean of the features
        mu = torch.mean(features, dim=0)

        # Center the features by subtracting the mean
        centered_features = features - mu

        # Compute the covariance matrix (similar to np.cov with rowvar=False)
        # (N - 1) is used for unbiased estimation
        sigma = torch.mm(centered_features.T, centered_features) / (
            features.size(0) - 1
        )

        # Compute the rank of the feature matrix
        rank_feature = torch.linalg.matrix_rank(features)

        # Ensure the feature matrix has full rank (rank == min(N, D))
        assert rank_feature == min(
            features.size(0), features.size(1)
        ), "The feature matrix is not full rank."

        return mu, sigma

    def regularize_covariance(self, cov_matrix, epsilon=1e-6):
        """
        Regularize a covariance matrix by adding a small value to its diagonal elements.

        This function enhances the numerical stability of a covariance matrix by adding a small constant
        to its diagonal. This is particularly useful when the covariance matrix is nearly singular or
        when performing operations that require the matrix to be positive definite.

        Args:
            cov_matrix (torch.Tensor): The covariance matrix to be regularized. It should be a square matrix.
            epsilon (float, optional): A small value to add to the diagonal elements of the covariance matrix.
                                       Default is 1e-6.

        Returns:
            torch.Tensor: The regularized covariance matrix with the small value added to its diagonal.
        """
        # Add a small value to the diagonal for numerical stability
        return cov_matrix + epsilon * torch.eye(
            cov_matrix.shape[0], device=cov_matrix.device
        )

    def klmvn(self, mu1, cov1, mu2, cov2, device):
        """
        Compute the Kullback-Leibler (KL) divergence between two multivariate normal distributions.

        This method calculates the KL divergence between two multivariate normal distributions defined by
        their mean vectors and covariance matrices. Only the diagonal of each covariance matrix is used,
        i.e., the covariances are assumed to be diagonal.

        Args:
            mu1 (torch.Tensor): Mean vector of the first multivariate normal distribution.
            cov1 (torch.Tensor): Covariance matrix of the first distribution (off-diagonal entries are discarded).
            mu2 (torch.Tensor): Mean vector of the second multivariate normal distribution.
            cov2 (torch.Tensor): Covariance matrix of the second distribution (off-diagonal entries are discarded).
            device (torch.device): The device (CPU or GPU) on which to perform the computation.

        Returns:
            torch.Tensor: The computed KL divergence between the two distributions.

        The method performs the following steps:
        1. Constructs diagonal covariance matrices from the provided matrices.
        2. Creates multivariate normal distributions using the mean vectors and covariance matrices.
        3. Computes the KL divergence between the two distributions.
        """
        # Keep only the diagonal of each covariance matrix
        p_cov = torch.eye(len(cov1), device=device) * cov1
        q_cov = torch.eye(len(cov2), device=device) * cov2

        # Build the probability density functions
        p = torch.distributions.multivariate_normal.MultivariateNormal(mu1, p_cov)
        q = torch.distributions.multivariate_normal.MultivariateNormal(mu2, q_cov)

        # Compute the KL divergence
        kld = torch.distributions.kl_divergence(p, q)
        return kld

    def compute_image_distance(self, cfg: dict) -> float:
        """
        Compute the distance between image features from source and target datasets using a pre-trained model.

        This method calculates the distance between the statistical representations of image features extracted
        from two datasets. It uses a pre-trained model to extract features from specified layers and computes
        the Kullback-Leibler divergence between the distributions of these features.

        Args:
            cfg (dict): A configuration dictionary containing the following keys:
                - "MODEL": Dictionary with model-related configurations:
                    - "device" (str): The device to run the model on (e.g., "cpu" or "cuda").
                    - "arch" (str): The architecture of the model to use.
                    - "n_layer_feature" (int): The layer number from which to extract features.
                - "DATA": Dictionary with data-related configurations:
                    - "width" (int): The width of the images.
                    - "height" (int): The height of the images.
                    - "norm_mean" (list of float): Mean values for image normalization.
                    - "norm_std" (list of float): Standard deviation values for image normalization.
                    - "batch_size" (int): The batch size for data loading.
                    - "source" (str): Path to the source folder containing images.
                    - "target" (str): Path to the target folder containing images.

        Returns:
            float: The computed distance between the source and target image features.

        The method performs the following steps:
        1. Loads the model and sets it up on the specified device.
        2. Constructs data loaders for the source and target datasets with specified transformations.
        3. Extracts features from the specified layer of the model for both datasets.
        4. Calculates the mean and covariance of the features for both datasets.
        5. Regularizes the covariance matrices to ensure numerical stability.
        6. Computes the Kullback-Leibler divergence between the feature distributions.
        """
        device = cfg["MODEL"]["device"]
        model = cfg["MODEL"]["arch"]
        n_layer_feature = cfg["MODEL"]["n_layer_feature"]
        img_size = (cfg["DATA"]["width"], cfg["DATA"]["height"])
        norm_mean = cfg["DATA"]["norm_mean"]
        norm_std = cfg["DATA"]["norm_std"]
        batch_size = cfg["DATA"]["batch_size"]
        source_folder_path = cfg["DATA"]["source"]
        target_folder_path = cfg["DATA"]["target"]

        transform = generate_transform(img_size, norm_mean, norm_std)
        source_loader = construct_dataloader(source_folder_path, transform, batch_size)
        target_loader = construct_dataloader(target_folder_path, transform, batch_size)

        loaded_model = load_model(model, device)
        feature_extractor = extract_nth_layer_feature(loaded_model, n_layer_feature)

        # Compute features as tensors
        source_features = compute_features(source_loader, feature_extractor, device)
        target_features = compute_features(target_loader, feature_extractor, device)

        # Calculate statistics for source features
        mu1, cov1 = self.calculate_statistics(source_features)
        cov1 = self.regularize_covariance(cov1)

        # Calculate statistics for target features
        mu2, cov2 = self.calculate_statistics(target_features)
        cov2 = self.regularize_covariance(cov2)

        dist = self.klmvn(mu1, cov1, mu2, cov2, device)
        return dist.item()
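

# A minimal end-to-end sketch (illustrative): pick a metric and run it on a
# configuration dictionary. The paths and model settings in _example_mmd_cfg
# are assumptions; adapt them to your datasets before running.
if __name__ == "__main__":
    cfg = _example_mmd_cfg()
    print("MMD domain gap:", MMD().compute(cfg))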