Coverage for dqm/domain_gap/utils.py: 49%

88 statements  

coverage.py v7.10.6, created at 2025-09-05 14:00 +0000

1""" 

2Domain Gap Metric Calculation Script 

3 

4This script defines a PyTorch module, DomainMeter, for domain adaptation using 

5Central Moment Discrepancy (CMD) and Kullback-Leibler (KL) divergence. 

6The script also includes custom dataset classes (RMSELoss and PandasDatasets) 

7for loading images from a Pandas DataFrame and implementing a custom root mean  

8square error (RMSE) loss. 

9 

10Authors: 

11 Sabrina CHAOUCHE 

12 Yoann RANDON 

13 Faouzi ADJED 

14 

15Classes: 

16 ModelConfiguration 

17 RMSELoss 

18 PandasDatasets 

19 DomainMeter 

20 

21 

22Dependencies: 

23 torch 

24 torchvision 

25 mlflow 

26 cmd (DomainMeter class from cmd module) 

27 twe_logger (Custom logging module) 

28 

29Usage: 

30Run the script with optional arguments '--cfg' 

31for the JSON config file path and '--tsk' 

32for the JSON task config file path. 

33""" 

from typing import Tuple, List
import json
import os

import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models.feature_extraction import create_feature_extractor

from dqm.domain_gap import custom_datasets


def load_config(config_file):
    """Load configuration from a JSON file."""
    try:
        with open(config_file, "r") as file:
            config = json.load(file)
            return config
    except FileNotFoundError:
        print(f"Error: The file '{config_file}' does not exist.")
        exit(1)
    except json.JSONDecodeError:
        print(f"Error: Could not parse JSON in '{config_file}'.")
        exit(1)
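
# Usage sketch (illustrative only): assuming a JSON file at a hypothetical path
# "configs/domain_gap.json" containing the DATA/MODEL/METHOD keys read elsewhere
# in this module, the function is called as:
#
#     cfg = load_config("configs/domain_gap.json")
#     batch_size = cfg["DATA"].get("batch_size", 16)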


def display_resume(cfg, dist, time_lapse):
    """Display a summary of the computation."""
    print("-" * 80)
    print("Summary")
    print(f"source: {cfg['DATA']['source']}")
    print(f"target: {cfg['DATA']['target']}")
    if "batch_size" in cfg["DATA"]:
        print(f"batch size: {cfg['DATA']['batch_size']}")
    if "device" in cfg["MODEL"]:
        print(f"device: {cfg['MODEL']['device']}")
    if "arch" in cfg["MODEL"]:
        print(f"model: {cfg['MODEL']['arch']}")
    if "archs" in cfg["MODEL"]:
        print(f"models: {cfg['MODEL']['archs']}")
    # If 'dist' is a tensor, convert it to a Python float
    if dist is not None:
        distance = dist.item() if isinstance(dist, torch.Tensor) else dist
    else:
        distance = None
    print(f"distance: {distance}")
    print(f"method: {cfg['METHOD']['name']}")
    if "evaluator" in cfg["METHOD"]:
        print(f"evaluator: {cfg['METHOD']['evaluator']}")
    print(f"compute time: {round(time_lapse, 2)} seconds")
    print("-" * 80)
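
# Usage sketch (illustrative only): a minimal config dict with the keys that
# display_resume reads; paths and values below are placeholders.
#
#     example_cfg = {
#         "DATA": {"source": "data/source", "target": "data/target", "batch_size": 16},
#         "MODEL": {"arch": "resnet18", "device": "cpu"},
#         "METHOD": {"name": "cmd"},
#     }
#     display_resume(example_cfg, torch.tensor(0.42), time_lapse=3.1)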


def generate_transform(
    img_size: Tuple[int, int], norm_mean: List[float], norm_std: List[float]
):
    """
    Generate a transform that converts input images into model-compatible tensors.

    Args:
        img_size (Tuple[int, int]): Target size (height, width) for resizing.
        norm_mean (List[float]): Per-channel normalization mean.
        norm_std (List[float]): Per-channel normalization standard deviation.

    Returns:
        transform: A callable that resizes, converts to tensor, and normalizes an image.
    """
    transform = transforms.Compose(
        [
            transforms.Resize(img_size),  # Resize image
            transforms.ToTensor(),  # Convert to tensor
            transforms.Normalize(mean=norm_mean, std=norm_std),  # Normalize
        ]
    )
    return transform
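
# Usage sketch: the ImageNet normalization statistics below are a common default,
# not a value mandated by this module.
#
#     transform = generate_transform(
#         img_size=(224, 224),
#         norm_mean=[0.485, 0.456, 0.406],
#         norm_std=[0.229, 0.224, 0.225],
#     )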


def extract_nth_layer_feature(model, n):
    """
    Build a feature extractor returning the output of the requested layer(s).

    Args:
        model (torch.nn.Module): Model to extract features from.
        n (int | str | list): Layer index (negative indices count from the end),
            layer name, or list of layer names.

    Returns:
        A feature extractor module created with create_feature_extractor.
    """
    # Get the model's named layers
    layer_names = list(dict(model.named_modules()).keys())

    if isinstance(n, list):
        # Create a feature extractor returning the listed layers
        feature_extractor = create_feature_extractor(model, return_nodes=n)
        return feature_extractor

    # Handle integer input (layer index)
    if isinstance(n, int):
        # Convert negative index to positive (e.g., -1 means last layer)
        if n < 0:
            n = len(layer_names) + n

        # Ensure the layer index is valid
        if n >= len(layer_names) or n < 0:
            raise ValueError(
                f"Layer index {n} is out of range for the model with {len(layer_names)} layers."
            )

        # Extract the n-th layer's name
        nth_layer_name = layer_names[n]

    # Handle string input (layer name)
    elif isinstance(n, str):
        if n not in layer_names:
            raise ValueError(
                f"Layer name '{n}' not found in the model. Available layers are: {layer_names}"
            )
        nth_layer_name = n

    else:
        raise TypeError(
            "The argument 'n' must be an integer (layer index), a string (layer name), "
            "or a list of strings (layer names)."
        )

    # Create a feature extractor with the nth layer
    feature_extractor = create_feature_extractor(
        model, return_nodes={nth_layer_name: "features"}
    )

    return feature_extractor
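
# Usage sketch: the three accepted forms of 'n', shown on a torchvision ResNet-18
# (the node names are examples; inspect the model to confirm them).
#
#     backbone = torchvision.models.resnet18(weights="DEFAULT")
#     fx_by_index = extract_nth_layer_feature(backbone, -2)         # layer index
#     fx_by_name = extract_nth_layer_feature(backbone, "avgpool")   # layer name
#     fx_by_list = extract_nth_layer_feature(backbone, ["layer4", "avgpool"])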


def load_model(model_str: str, device: str) -> torch.nn.Module:
    """
    Load a model based on the input string.

    If the string ends with '.pt' or '.pth', tries to load a saved PyTorch model
    from that file. Otherwise, the string is treated as the name of a torchvision
    model (e.g., 'resnet18') and the corresponding pretrained model is loaded.

    Args:
        model_str (str): The model name or file path.
        device (str): Device on which to place the model (e.g., 'cpu' or 'cuda').

    Returns:
        model (torch.nn.Module): The loaded PyTorch model.
    """
    # Check if the string is a path to a saved model file
    if model_str.endswith((".pt", ".pth")):
        # Verify the file exists
        if os.path.exists(model_str):
            # Attempt to load the model directly
            try:
                model = torch.load(model_str, map_location=device)
                print(f"Loaded model from {model_str}")
                return model
            except Exception as e:
                raise ValueError(f"Error loading model from file: {e}")
        else:
            raise FileNotFoundError(f"Model file '{model_str}' not found.")

    else:
        model = torchvision.models.get_model(model_str, weights="DEFAULT").to(device)
        return model
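
# Usage sketch: both accepted forms of 'model_str'; the checkpoint path is
# hypothetical and must point to a model saved with torch.save(model, path).
#
#     model = load_model("resnet18", device="cpu")         # torchvision model name
#     # model = load_model("weights/encoder.pth", "cpu")   # saved model file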


def compute_features(dataloader, model, device):
    """
    Compute features from a model for all images in the DataLoader, batch by batch.

    Args:
        dataloader (DataLoader): DataLoader object that yields image batches.
        model (torch.nn.Module): Feature-extractor model returning a dict with a "features" key.
        device (torch.device): Device on which to run the model (e.g., CPU or GPU).

    Returns:
        torch.Tensor: A concatenated tensor of features for all images.
    """
    model.eval()  # Set the model to evaluation mode
    all_features = []

    with torch.no_grad():  # Disable gradient calculation
        for batch in dataloader:
            batch = batch.to(device)  # Move the batch to the target device (GPU/CPU)
            features = model(batch)["features"].squeeze()  # Extract features
            all_features.append(features)

    return torch.cat(all_features)  # Concatenate features from all batches
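
# Usage sketch: 'model' is assumed to be a feature extractor built with
# extract_nth_layer_feature (so model(batch) returns a dict with a "features"
# key); 'dataloader' can come from construct_dataloader defined below.
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     extractor = extract_nth_layer_feature(load_model("resnet18", device), -2)
#     features = compute_features(dataloader, extractor, device)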


def construct_dataloader(folder_path: str, transform, batch_size: int):
    """
    Load images from a folder and return a DataLoader for batch-wise processing.

    Args:
        folder_path (str): Path to the folder containing images.
        transform (transform): Transform applied to each image before batching.
        batch_size (int): Number of images per batch.

    Returns:
        DataLoader: A DataLoader object that yields batches of transformed images.
    """
    dataset = custom_datasets.ImagesFromFolderDataset(folder_path, transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader
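
# Usage sketch (illustrative only): "data/source" is a placeholder folder path;
# this pairs construct_dataloader with generate_transform defined above.
#
#     transform = generate_transform((224, 224), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#     source_loader = construct_dataloader("data/source", transform, batch_size=16)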