Coverage for dqm/domain_gap/utils.py: 49%
88 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-05 14:00 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-05 14:00 +0000
1"""
2Domain Gap Metric Calculation Script
4This script defines a PyTorch module, DomainMeter, for domain adaptation using
5Central Moment Discrepancy (CMD) and Kullback-Leibler (KL) divergence.
6The script also includes custom dataset classes (RMSELoss and PandasDatasets)
7for loading images from a Pandas DataFrame and implementing a custom root mean
8square error (RMSE) loss.
10Authors:
11 Sabrina CHAOUCHE
12 Yoann RANDON
13 Faouzi ADJED
15Classes:
16 ModelConfiguration
17 RMSELoss
18 PandasDatasets
19 DomainMeter
22Dependencies:
23 torch
24 torchvision
25 mlflow
26 cmd (DomainMeter class from cmd module)
27 twe_logger (Custom logging module)
29Usage:
30Run the script with optional arguments '--cfg'
31for the JSON config file path and '--tsk'
32for the JSON task config file path.
33"""
35from typing import Tuple, List
36import json
37import os
38import torch
39from dqm.domain_gap import custom_datasets
40from torch.utils.data import DataLoader
41import torchvision
42from torchvision import transforms
43from torchvision.models.feature_extraction import create_feature_extractor
def load_config(config_file):
    """Load configuration from a JSON file.

    Args:
        config_file (str): Path to the JSON configuration file.

    Returns:
        The parsed JSON content (a ``dict`` when the file holds a JSON object).

    Raises:
        SystemExit: With exit code 1 if the file does not exist or does not
            contain valid JSON. An explanatory message is printed to stderr
            before exiting.
    """
    # Local import: sys is only needed here, for stderr and exit on the error paths.
    import sys

    try:
        # JSON text is UTF-8 by specification; don't depend on the locale encoding.
        with open(config_file, "r", encoding="utf-8") as file:
            return json.load(file)
    except FileNotFoundError:
        print(f"Error: The file '{config_file}' does not exist.", file=sys.stderr)
        # sys.exit is always available; the bare exit() builtin is injected by
        # the `site` module and may be missing (python -S, embedded runtimes).
        sys.exit(1)
    except json.JSONDecodeError:
        print(f"Error: Could not parse JSON in '{config_file}'.", file=sys.stderr)
        sys.exit(1)
def display_resume(cfg, dist, time_lapse):
    """Print a summary of a domain-gap computation to stdout.

    Args:
        cfg (dict): Run configuration. ``cfg["DATA"]["source"]``,
            ``cfg["DATA"]["target"]``, ``cfg["MODEL"]`` and
            ``cfg["METHOD"]["name"]`` are required; a missing one raises
            ``KeyError``. ``DATA.batch_size``, ``MODEL.device``,
            ``MODEL.arch``, ``MODEL.archs`` and ``METHOD.evaluator`` are
            printed only when present.
        dist: Computed distance. A ``torch.Tensor`` is converted to a Python
            number with ``.item()``, so it must hold a single element. Any
            other value, including ``None``, is printed unchanged.
        time_lapse (float): Elapsed computation time in seconds, printed
            rounded to two decimals.
    """
    print("-" * 80)
    print("Summary")
    print(f"source: {cfg['DATA']['source']}")
    print(f"target: {cfg['DATA']['target']}")
    if "batch_size" in cfg["DATA"]:
        print(f"batch size: {cfg['DATA']['batch_size']}")

    # Optional model entries: (config key, printed label), in display order.
    for key, label in (("device", "device"), ("arch", "model"), ("archs", "models")):
        if key in cfg["MODEL"]:
            print(f"{label}: {cfg['MODEL'][key]}")

    # None is not a Tensor, so it passes through unchanged; no separate check needed.
    distance = dist.item() if isinstance(dist, torch.Tensor) else dist
    print(f"distance: {distance}")
    print(f"method : {cfg['METHOD']['name']}")
    if "evaluator" in cfg["METHOD"]:
        print(f"evaluator : {cfg['METHOD']['evaluator']}")
    print(f"compute time: {round(time_lapse, 2)} seconds")
    print("-" * 80)
def generate_transform(
    img_size: Tuple[int, int], norm_mean: List[float], norm_std: List[float]
):
    """
    Build the preprocessing pipeline that turns an input image into a model input.

    The steps run in this order: resize to ``img_size``, convert to a tensor,
    then normalize each channel.

    Args:
        img_size (Tuple[int, int]): target size passed to ``transforms.Resize``
        norm_mean (List[float]): per-channel mean passed to ``transforms.Normalize``
        norm_std (List[float]): per-channel standard deviation passed to
            ``transforms.Normalize``

    Returns:
        transforms.Compose: the composed transform
    """
    pipeline_steps = [
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ]
    return transforms.Compose(pipeline_steps)
def extract_nth_layer_feature(model, n):
    """Build a feature extractor that exposes the output of one or more layers.

    Args:
        model (torch.nn.Module): Model to extract features from.
        n (int | str | list | dict): Which layer(s) to expose.
            - ``int``: index into the ``model.named_modules()`` order. Negative
              values count from the end, so ``-1`` is the last registered
              module. Index 0 is the root module (the whole model, with an
              empty name) and is rejected.
            - ``str``: qualified module name, e.g. ``"layer4"``.
            - ``list`` / ``dict``: passed unchanged to
              ``create_feature_extractor`` as ``return_nodes``. The extractor's
              output is then keyed by the node names (list) or the dict values
              (dict), not by ``"features"``.

    Returns:
        The module built by ``create_feature_extractor``. For an ``int`` or
        ``str`` ``n``, calling it returns a dict with the single key
        ``"features"``.

    Raises:
        ValueError: If the index is out of range, the name is unknown, or the
            selection resolves to the root module.
        TypeError: If ``n`` is not an int, str, list or dict.
    """
    # Explicit node selections need no index/name resolution.
    if isinstance(n, (list, dict)):
        return create_feature_extractor(model, return_nodes=n)

    # named_modules() yields ("", model) first, then every sub-module.
    layer_names = [name for name, _ in model.named_modules()]

    if isinstance(n, int):
        # Negative indices count from the end (e.g. -1 is the last module).
        index = n + len(layer_names) if n < 0 else n
        if not 0 <= index < len(layer_names):
            # Report the caller's index, not the normalized one.
            raise ValueError(
                f"Layer index {n} is out of range for the model with {len(layer_names)} layers."
            )
        nth_layer_name = layer_names[index]
    elif isinstance(n, str):
        if n not in layer_names:
            raise ValueError(
                f"Layer name '{n}' not found in the model. Available layers are: {layer_names}"
            )
        nth_layer_name = n
    else:
        raise TypeError(
            "The argument 'n' must be either an integer (layer index) or a string (layer name) or a list/dict of layer names (return_nodes)."
        )

    # The empty name is the whole model, not a layer; create_feature_extractor
    # cannot use it as a return node, so fail early with a clear message.
    if nth_layer_name == "":
        raise ValueError(
            "The selection resolves to the root module (the whole model, empty name ''); choose a sub-module instead."
        )

    return create_feature_extractor(model, return_nodes={nth_layer_name: "features"})
def load_model(model_str: str, device: str) -> torch.nn.Module:
    """
    Load a model from a checkpoint file or by torchvision model name.

    Parameters:
        model_str (str): Either a path ending in ``.pt`` / ``.pth``, or the name
            of a torchvision model (e.g. ``'resnet18'``), which is resolved with
            ``torchvision.models.get_model``.
        device (str): Target device. For a file, it is passed to ``torch.load``
            as ``map_location`` so the stored tensors land directly on it. For a
            torchvision model, the model is moved there with ``.to(device)``.

    Returns:
        model (torch.nn.Module): The loaded model. For a file, this is whatever
            object ``torch.load`` returns (e.g. a state dict if that is what
            was saved).

    Raises:
        FileNotFoundError: If a ``.pt`` / ``.pth`` path does not exist.
        ValueError: If ``torch.load`` fails. The original error is chained.
    """
    if model_str.endswith((".pt", ".pth")):
        if not os.path.exists(model_str):
            raise FileNotFoundError(f"Model file '{model_str}' not found.")
        # SECURITY NOTE(review): torch.load unpickles the file, so only load
        # checkpoints from trusted sources.
        try:
            # map_location honours `device`: a checkpoint saved on GPU also
            # loads on CPU-only hosts, consistent with the torchvision branch.
            model = torch.load(model_str, map_location=device)
        except Exception as e:
            raise ValueError(f"Error loading model from file: {e}") from e
        print(f"Loaded model from {model_str}")
        return model

    # NOTE: `pretrained=True` is torchvision's deprecated legacy flag. It maps to
    # each model's original weights (e.g. IMAGENET1K_V1), which can differ from
    # weights="DEFAULT". Kept as is so extracted features do not change.
    model = torchvision.models.get_model(model_str, pretrained=True).to(device)
    return model
def compute_features(dataloader, model, device):
    """
    Compute features from a model for images in the DataLoader batch by batch.

    Args:
        dataloader (DataLoader): Iterable yielding image tensors batch by batch,
            with the batch on dimension 0.
        model (torch.nn.Module): Feature extractor whose output is a dict with a
            ``"features"`` entry, as built by ``extract_nth_layer_feature`` for
            an int or str layer selection.
        device (torch.device | str): Device to run the model on (e.g. CPU or GPU).

    Returns:
        torch.Tensor: Features of all images concatenated along dimension 0.
            Singleton non-batch dimensions are removed, e.g. pooled
            ``(B, C, 1, 1)`` outputs become ``(N, C)``.

    Raises:
        RuntimeError: Propagated from ``torch.cat`` if the dataloader yields
            no batches.
    """
    model.eval()  # Set the model to evaluation mode
    all_features = []

    with torch.no_grad():  # Disable gradient calculation
        for images in dataloader:
            features = model(images.to(device))["features"]
            # Drop singleton dims except the batch dim. A plain .squeeze() turns
            # a size-1 batch into (C,), which cannot be concatenated with (B, C)
            # batches; with batch_size=1 it silently flattens all features to 1-D.
            kept_dims = [size for size in features.shape[1:] if size != 1]
            all_features.append(features.reshape(features.shape[0], *kept_dims))

    return torch.cat(all_features)  # Concatenate features from all batches
def construct_dataloader(folder_path: str, transform, batch_size: int):
    """
    Wrap the images of a folder in a shuffling DataLoader.

    Args:
        folder_path (str): Directory holding the images, read through
            ``custom_datasets.ImagesFromFolderDataset``.
        transform (transform): Preprocessing applied to each image by the dataset.
        batch_size (int): Number of images yielded per batch.

    Returns:
        DataLoader: Loader yielding shuffled batches of transformed images.
    """
    return DataLoader(
        custom_datasets.ImagesFromFolderDataset(folder_path, transform),
        batch_size=batch_size,
        shuffle=True,
    )
232 return dataloader