"""
Domain Gap Metric Calculation Script
This script defines a PyTorch module, DomainMeter, for domain adaptation using
Central Moment Discrepancy (CMD) and Kullback-Leibler (KL) divergence.
The script also includes custom dataset classes (RMSELoss and PandasDatasets)
for loading images from a Pandas DataFrame and implementing a custom root mean
square error (RMSE) loss.
Authors:
Sabrina CHAOUCHE
Yoann RANDON
Faouzi ADJED
Classes:
ModelConfiguration
RMSELoss
PandasDatasets
DomainMeter
Dependencies:
torch
torchvision
mlflow
cmd (DomainMeter class from cmd module)
twe_logger (Custom logging module)
Usage:
Run the script with optional arguments '--cfg'
for the JSON config file path and '--tsk'
for the JSON task config file path.
"""
from typing import Tuple, List
import json
import os
import torch
from dqm.domain_gap import custom_datasets
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchvision.models.feature_extraction import create_feature_extractor
def load_config(config_file):
    """Load and return a configuration dictionary from a JSON file.

    Parameters:
        config_file (str): Path to the JSON configuration file.

    Returns:
        dict: The parsed configuration.

    Exits:
        Terminates the process with status 1 when the file is missing or
        does not contain valid JSON (an error message is printed first).
    """
    try:
        with open(config_file, "r", encoding="utf-8") as file:
            return json.load(file)
    except FileNotFoundError:
        print(f"Error: The file '{config_file}' does not exist.")
        # raise SystemExit instead of the site-module exit() helper,
        # which is not guaranteed to exist outside interactive sessions.
        raise SystemExit(1)
    except json.JSONDecodeError:
        print(f"Error: Could not parse JSON in '{config_file}'.")
        raise SystemExit(1)
def display_resume(cfg, dist, time_lapse):
    """Print a human-readable summary of a domain-gap computation.

    Parameters:
        cfg (dict): Configuration with 'DATA', 'MODEL' and 'METHOD'
            sections (optional keys are printed only when present).
        dist (float | torch.Tensor | None): Computed distance.
        time_lapse (float): Elapsed computation time in seconds.
    """
    print("-" * 80)
    print("Summary")
    print(f"source: {cfg['DATA']['source']}")
    print(f"target: {cfg['DATA']['target']}")
    if "batch_size" in cfg["DATA"]:
        print(f"batch size: {cfg['DATA']['batch_size']}")
    if "device" in cfg["MODEL"]:
        print(f"device: {cfg['MODEL']['device']}")
    if "arch" in cfg["MODEL"]:
        print(f"model: {cfg['MODEL']['arch']}")
    if "archs" in cfg["MODEL"]:
        print(f"models: {cfg['MODEL']['archs']}")
    # Unwrap a tensor to a plain Python number so it prints cleanly.
    # None passes through unchanged (isinstance(None, Tensor) is False),
    # so no separate None branch is needed.
    distance = dist.item() if isinstance(dist, torch.Tensor) else dist
    print(f"distance: {distance}")
    print(f"method : {cfg['METHOD']['name']}")
    if "evaluator" in cfg["METHOD"]:
        print(f"evaluator : {cfg['METHOD']['evaluator']}")
    print(f"compute time: {round(time_lapse, 2)} seconds")
    print("-" * 80)
# Model loading helper
def load_model(model_str: str, device: str) -> torch.nn.Module:
    """
    Load a model from a checkpoint file or by torchvision model name.

    If `model_str` ends with '.pt' or '.pth', the checkpoint is
    deserialized from disk. Otherwise it is treated as a torchvision
    model name (e.g. 'resnet18') and built with pretrained weights.

    Parameters:
        model_str (str): Checkpoint path or torchvision model name.
        device (str): Target device (e.g. 'cpu', 'cuda').

    Returns:
        torch.nn.Module: The loaded model.

    Raises:
        FileNotFoundError: If a checkpoint path does not exist.
        ValueError: If the checkpoint cannot be deserialized.
    """
    if model_str.endswith((".pt", ".pth")):
        # Guard clause: fail fast on a missing checkpoint.
        if not os.path.exists(model_str):
            raise FileNotFoundError(f"Model file '{model_str}' not found.")
        try:
            # map_location avoids deserialization errors when a CUDA-saved
            # checkpoint is loaded on a CPU-only host, and places the
            # tensors on the requested device.
            # NOTE(review): torch.load unpickles arbitrary objects — only
            # load checkpoints from trusted sources.
            model = torch.load(model_str, map_location=device)
        except Exception as e:
            raise ValueError(f"Error loading model from file: {e}") from e
        print(f"Loaded model from {model_str}")
        return model
    # 'pretrained=True' was removed from torchvision model builders;
    # the get_model API takes a 'weights' argument instead.
    return torchvision.models.get_model(model_str, weights="DEFAULT").to(device)
def compute_features(dataloader, model, device):
    """
    Compute features for all images in the DataLoader, batch by batch.

    Args:
        dataloader (DataLoader): Yields batches of image tensors.
        model (torch.nn.Module): Feature extractor whose forward returns a
            dict with a 'features' entry (see create_feature_extractor).
        device (torch.device or str): Device to run inference on.

    Returns:
        torch.Tensor: Features for all images, concatenated along dim 0.
    """
    model.eval()  # disable dropout / use running batch-norm statistics
    all_features = []
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for batch in dataloader:
            batch = batch.to(device)
            features = model(batch)["features"]
            # Drop singleton dims (e.g. the 1x1 spatial dims of pooled CNN
            # features) but never the batch dim: a plain .squeeze() on a
            # final batch of size 1 would remove dim 0 as well, breaking
            # the torch.cat below. Identical to .squeeze() for batch > 1.
            if features.dim() > 1:
                kept = [s for s in features.shape[1:] if s != 1]
                features = features.reshape(features.size(0), *kept)
            all_features.append(features)
    return torch.cat(all_features)
def construct_dataloader(folder_path: str, transform, batch_size: int):
    """
    Build a shuffled DataLoader over every image found in a folder.

    Args:
        folder_path (str): Directory containing the image files.
        transform (transform): Preprocessing applied to each image.
        batch_size (int): Number of images yielded per batch.

    Returns:
        DataLoader: Yields batches of transformed image tensors.
    """
    image_dataset = custom_datasets.ImagesFromFolderDataset(
        folder_path, transform
    )
    return DataLoader(image_dataset, batch_size=batch_size, shuffle=True)