#
# Forked and modified from https://github.com/DariaKern/Chicks4FreeID/blob/main/run_baseline.py
#
import argparse
from itertools import chain, islice
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from tqdm import tqdm
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
# General torch imports
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch import Tensor
from torch.nn import (
CrossEntropyLoss,
Identity,
Linear,
Module
)
from torch.optim import SGD, Optimizer, AdamW
from torch.utils.data import DataLoader
# For writing the result table to markdown
import pandas as pd
from PIL import Image
# For fully supervised baselines
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchvision.models.vision_transformer import VisionTransformer
# For calculating the metrics
from torchmetrics.classification import MulticlassAveragePrecision, MulticlassAccuracy
from datasets import Dataset, load_dataset
# For the training loop
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor, Callback
from pytorch_lightning.loggers import TensorBoardLogger
from lightly.utils.dist import print_rank_zero
# Some fancy stuff for optimizing vision transformers
from lightly.utils.scheduler import CosineWarmupScheduler
from lightly.models.utils import get_weight_decay_parameters
import timm
@dataclass
class Config:
    """Hyperparameters and run settings for the baseline runner.

    ``Baseline.run`` builds an instance with ``Config(**vars(args))``, so every
    CLI option defined by the argument parser must have a matching field here.
    New fields are appended at the end to keep the positional order stable.
    """

    batch_size_per_device: int = 16
    log_dir: Path = Path("baseline_logs")
    checkpoint_path: Optional[Path] = None
    skip_embedding_training: bool = False
    skip_knn_eval: bool = False
    skip_linear_eval: bool = False
    # Keys of ``Baseline.methods`` to run; ``None`` runs all of them.
    methods: Optional[List[str]] = None
    accelerator: str = "auto"
    precision: str = "16-mixed"
    check_val_every_n_epoch: int = 5
    # One dict per finished evaluation ("Setting", "Evaluation", metric columns).
    experiment_result_metrics: List[Dict[str, Any]] = field(default_factory=list)
    # Sub-directory of ``log_dir`` for one run; ``Baseline.run`` sets a timestamp.
    baseline_id: Optional[str] = None
    # When True, ``Baseline.run`` only aggregates existing metric CSVs.
    # Must default to False: the CLI exposes it as a ``store_true`` flag.
    aggregate_metrics: bool = False
    # Lightning profiler name (e.g. "pytorch"); ``None`` disables profiling.
    # Previously an un-annotated class attribute, i.e. not a dataclass field.
    profile: Optional[str] = None
    # Fields below are read by the parser and BaselineMethod but were missing.
    # NOTE(review): defaults chosen during review — adjust to the intended setup.
    # Number of embedding-training epochs; 0 skips embedding training.
    epochs: int = 100
    num_workers: int = 4
    # NOTE(review): placeholder — set to the number of identities in the dataset.
    num_classes: int = 50
    devices: Union[int, str] = "auto"
    # Debug mode: a single epoch with validation every epoch.
    test_run: bool = False
def clear_cache():
    """Release cached, currently unused allocator memory on the active accelerator.

    Calls ``torch.cuda.empty_cache()`` when CUDA is available, otherwise
    ``torch.mps.empty_cache()`` when the MPS backend is available and the
    installed PyTorch provides it. On CPU there is nothing to release.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        return
    mps = getattr(torch, "mps", None)
    # Guarded with getattr/hasattr so torch builds that lack ``torch.mps`` or
    # ``torch.mps.empty_cache`` fall through instead of raising.
    if mps is not None and torch.backends.mps.is_available() and hasattr(mps, "empty_cache"):
        mps.empty_cache()
    # CPU: no allocator cache to release.
def timing_decorator(func):
    """Wrap ``func`` so every call prints its wall-clock duration.

    The wrapper returns ``func``'s result unchanged and keeps its metadata
    (``__name__``, ``__doc__``, ``__wrapped__``) via ``functools.wraps``.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        duration_timedelta = timedelta(seconds=time.perf_counter() - start_time)
        print(f"Duration: {duration_timedelta}")
        return result

    return wrapper
def knn_predict(
    feature: Tensor,
    feature_bank: Tensor,
    feature_labels: Tensor,
    num_classes: int,
    knn_k: int = 200,
    knn_t: float = 0.1,
) -> Tensor:
    """Compute weighted kNN class scores for ``feature`` against a feature bank.

    Modified from lightly's ``knn_predict``: this version returns per-class
    scores instead of sorted predicted labels, so the scores can be fed to
    ranking metrics such as mAP. The default parameters are the ones used in
    InstDisc (https://arxiv.org/pdf/1805.01978v1.pdf). The kNN code follows
    https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb

    Args:
        feature:
            Tensor with shape (B, D) for which you want predictions.
        feature_bank:
            Tensor of shape (D, N) of a database of features used for kNN.
        feature_labels:
            Integer labels with shape (N,) for the features in the feature_bank.
        num_classes:
            Number of classes (e.g. `10` for CIFAR-10).
        knn_k:
            Number of neighbors. Clamped to N when the bank holds fewer than
            ``knn_k`` entries (``topk`` would raise otherwise).
        knn_t:
            Temperature used to reweight the similarities.

    Returns:
        Tensor of shape (B, num_classes) with the summed neighbor weight per class.

    Examples:
        >>> images, targets, _ = batch
        >>> feature = backbone(images).squeeze()
        >>> # we recommend to normalize the features
        >>> feature = F.normalize(feature, dim=1)
        >>> pred_scores = knn_predict(
        >>>     feature,
        >>>     feature_bank,
        >>>     targets_bank,
        >>>     num_classes=10,
        >>> )
        >>> pred_labels = pred_scores.argmax(dim=-1)
    """
    batch_size = feature.size(0)
    k = min(knn_k, feature_bank.size(-1))
    # Similarity of each feature to every bank entry (cosine similarity when
    # both sides are L2-normalized) ---> (B, N)
    sim_matrix = torch.mm(feature, feature_bank)
    # Top-k similarities and the bank indices they come from ---> (B, K)
    sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1)
    # Labels of the k nearest neighbors ---> (B, K)
    sim_labels = torch.gather(
        feature_labels.expand(batch_size, -1), dim=-1, index=sim_indices
    )
    # Temperature-scaled exponential reweighting of the similarities.
    sim_weight = (sim_weight / knn_t).exp()
    # One-hot encoded neighbor labels ---> (B*K, C)
    one_hot_label = torch.zeros(
        batch_size * k, num_classes, device=sim_labels.device
    ).scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    # Weighted vote per class ---> (B, C). The explicit K (not -1) keeps the
    # reshape valid when the bank is empty.
    pred_scores = torch.sum(
        one_hot_label.view(batch_size, k, num_classes) * sim_weight.unsqueeze(dim=-1),
        dim=1,
    )
    return pred_scores
class MetricCallback(Callback):
    """Callback (adapted from lightly) that records scalar logged metrics.

    Values are read from ``trainer.callback_metrics``. Unlike lightly's version,
    this one reads them in ``on_train_end`` (once, when fitting ends) and in
    ``on_validation_end`` (after every validation run) instead of the
    ``*_epoch_end`` hooks. Nothing is recorded during sanity checking, and
    non-scalar tensors are skipped.

    Attributes:
        train_metrics:
            Metric name -> recorded values, appended once when training ends.
        val_metrics:
            Metric name -> recorded values, appended after every validation run.
    """

    def __init__(self) -> None:
        super().__init__()
        self.train_metrics: Dict[str, List[float]] = {}
        self.val_metrics: Dict[str, List[float]] = {}

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        if not trainer.sanity_checking:
            self._append_metrics(metrics_dict=self.train_metrics, trainer=trainer)

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        if not trainer.sanity_checking:
            self._append_metrics(metrics_dict=self.val_metrics, trainer=trainer)

    def _append_metrics(
        self, metrics_dict: Dict[str, List[float]], trainer: Trainer
    ) -> None:
        """Append every scalar entry of ``trainer.callback_metrics`` to ``metrics_dict``."""
        for name, value in trainer.callback_metrics.items():
            if isinstance(value, Tensor) and value.numel() != 1:
                # Skip non-scalar tensors.
                print("skipping metric", name, value)
                continue
            metrics_dict.setdefault(name, []).append(float(value))
class MetricModule(LightningModule):
    """LightningModule base that tracks mAP and top-1/top-5 accuracy.

    Subclasses pass class scores of shape (B, num_classes) and integer targets
    to ``update_train_metrics`` / ``update_val_metrics``. The accumulated
    metrics are logged as ``train_mAP``/``train_top1``/``train_top5`` and
    ``val_mAP``/``val_top1``/``val_top5`` in the epoch-end hooks, but only if
    at least one update happened in that epoch.

    ``enable_logging`` gates everything. If it is False while ``__init__`` runs,
    no metric objects are created. If it is False at update or log time,
    nothing is updated or logged.
    """

    # ``__init__`` reads this before any subclass can assign an instance value,
    # so a class-level default must exist. Subclasses may override it at class
    # level or set it on the instance afterwards.
    enable_logging: bool = True

    def __init__(self, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        if self.enable_logging:
            self.train_map = MulticlassAveragePrecision(num_classes=num_classes)
            self.val_map = MulticlassAveragePrecision(num_classes=num_classes)
            self.train_top1 = MulticlassAccuracy(num_classes=num_classes, top_k=1)
            self.train_top5 = MulticlassAccuracy(num_classes=num_classes, top_k=5)
            self.val_top1 = MulticlassAccuracy(num_classes=num_classes, top_k=1)
            self.val_top5 = MulticlassAccuracy(num_classes=num_classes, top_k=5)

    def update_train_metrics(self, pred_scores: Tensor, targets: Tensor) -> None:
        """Accumulate one training batch (no-op when logging is disabled)."""
        if self.enable_logging:
            self.train_map(pred_scores, targets)
            self.train_top1(pred_scores, targets)
            self.train_top5(pred_scores, targets)

    def update_val_metrics(self, pred_scores: Tensor, targets: Tensor) -> None:
        """Accumulate one validation batch (no-op when logging is disabled)."""
        if self.enable_logging:
            self.val_map(pred_scores, targets)
            self.val_top1(pred_scores, targets)
            self.val_top5(pred_scores, targets)

    def on_train_epoch_end(self):
        super().on_train_epoch_end()
        # ``update_called`` is False when no batch was accumulated this epoch.
        if self.enable_logging and self.train_map.update_called:
            self.log("train_mAP", self.train_map, prog_bar=True)
            self.log("train_top1", self.train_top1, prog_bar=True)
            self.log("train_top5", self.train_top5, prog_bar=True)

    def on_validation_epoch_end(self):
        super().on_validation_epoch_end()
        if self.enable_logging and self.val_map.update_called:
            self.log("val_mAP", self.val_map, prog_bar=True)
            self.log("val_top1", self.val_top1, prog_bar=True)
            self.log("val_top5", self.val_top5, prog_bar=True)
class KNNClassifier(MetricModule):
    """kNN classifier adapted from lightly to log mean average precision.

    It inherits the metric handling from ``MetricModule``. The training
    "epoch" only collects (optionally normalized) features of the training
    set. Validation scores every validation sample with ``knn_predict`` against
    that feature bank.
    """

    # Read by ``MetricModule.__init__``; kNN always logs its metrics.
    enable_logging: bool = True

    def __init__(
        self,
        model: Module,
        num_classes: int,
        knn_k: int = 200,
        knn_t: float = 0.1,
        feature_dtype: torch.dtype = torch.float32,
        normalize: bool = True,
    ):
        """KNN classifier to compute baseline performance of embedding models.

        Settings based on InstDisc [0]. Code adapted from MoCo [1].
            - [0]: InstDisc, 2018, https://arxiv.org/pdf/1805.01978v1.pdf
            - [1]: MoCo, 2019, https://github.com/facebookresearch/moco

        Args:
            model:
                Model used for feature extraction. Its forward(images) must
                return a feature tensor. It is put into eval mode.
            num_classes:
                Number of classes in the dataset.
            knn_k:
                Number of neighbors used for KNN search.
            knn_t:
                Temperature parameter to reweight similarities.
            feature_dtype:
                Torch data type of the stored features. Reduce to float16 for
                memory-efficient KNN search.
            normalize:
                Whether to L2-normalize the features for KNN search.

        Examples:
            >>> model = SimCLR()  # any nn.Module whose forward returns features
            >>> knn_classifier = KNNClassifier(model, num_classes=10)
            >>> trainer = Trainer(max_epochs=1)
            >>> trainer.fit(knn_classifier, train_dataloader, val_dataloader)
        """
        super().__init__(num_classes=num_classes)
        self.save_hyperparameters(
            {
                "num_classes": num_classes,
                "knn_k": knn_k,
                "knn_t": knn_t,
                "feature_dtype": str(feature_dtype),
            }
        )
        self.model = model
        self.model.eval()
        self.num_classes = num_classes
        self.knn_k = knn_k
        self.knn_t = knn_t
        self.feature_dtype = feature_dtype
        self.normalize = normalize
        # Per-batch features/targets collected during the training epoch (CPU).
        self._train_features: List[Tensor] = []
        self._train_targets: List[Tensor] = []
        # Concatenated bank built at validation start: features (D, N), targets (N,).
        self._train_features_tensor: Optional[Tensor] = None
        self._train_targets_tensor: Optional[Tensor] = None

    def _extract_features(self, images: Tensor) -> Tensor:
        """Flatten, optionally L2-normalize, and cast the model's features."""
        features = self.model(images).flatten(start_dim=1)
        if self.normalize:
            features = F.normalize(features, dim=1)
        return features.to(self.feature_dtype)

    @torch.no_grad()
    def training_step(self, batch, batch_idx) -> None:
        images, targets = batch[0], batch[1]
        features = self._extract_features(images)
        self._train_features.append(features.detach().cpu())
        self._train_targets.append(targets.detach().cpu())

    def validation_step(self, batch, batch_idx) -> None:
        # No feature bank yet (no training batch collected): nothing to score.
        if self._train_features_tensor is None or self._train_targets_tensor is None:
            return
        images, targets = batch[0], batch[1]
        with torch.no_grad():
            features = self._extract_features(images)
            pred_scores = knn_predict(
                feature=features,
                feature_bank=self._train_features_tensor,
                feature_labels=self._train_targets_tensor,
                num_classes=self.num_classes,
                knn_k=self.knn_k,
                knn_t=self.knn_t,
            )
        self.update_val_metrics(pred_scores, targets)

    def on_validation_epoch_start(self) -> None:
        if self._train_features and self._train_targets:
            # Features and targets have size (world_size, batch_size, dim) and
            # (world_size, batch_size) after gather. For non-distributed training,
            # features and targets have size (batch_size, dim) and (batch_size,).
            features = self.all_gather(torch.cat(self._train_features, dim=0))
            self._train_features = []
            targets = self.all_gather(torch.cat(self._train_targets, dim=0))
            self._train_targets = []
            # Reshape to (dim, world_size * batch_size)
            features = features.flatten(end_dim=-2).t().contiguous()
            self._train_features_tensor = features.to(self.device)
            # Reshape to (world_size * batch_size,)
            targets = targets.flatten().contiguous()
            self._train_targets_tensor = targets.to(self.device)

    def on_train_epoch_start(self) -> None:
        # Set model to eval mode to disable norm layer updates.
        self.model.eval()
        # Reset features and targets.
        self._train_features = []
        self._train_targets = []
        self._train_features_tensor = None
        self._train_targets_tensor = None

    def on_validation_end(self) -> None:
        super().on_validation_end()
        # Drop the feature bank to free memory after validation. Reset instead
        # of ``del`` so later hooks that read these attributes still work.
        self._train_features = []
        self._train_targets = []
        self._train_features_tensor = None
        self._train_targets_tensor = None

    def configure_optimizers(self) -> None:
        # Lightning requires this hook. Returning None means no optimizer, so
        # training_step only collects features.
        return None
class LinearClassifier(MetricModule):
    """Linear classifier adapted from lightly.

    Changes compared to lightly: it logs mean average precision, inherits the
    metric handling from ``MetricModule``, and can also train fully supervised
    models (``freeze_model=False``). Subclasses can swap the head or loss via
    ``build_classification_head`` / ``build_critierion``.
    """

    # Read by ``MetricModule.__init__`` before ``__init__`` below assigns the
    # instance value, so a class-level default is required.
    enable_logging: bool = True

    def __init__(
        self,
        model: Module,
        batch_size_per_device: int,
        feature_dim: int,
        num_classes: int,
        freeze_model: bool = False,
        enable_logging: bool = True,
    ) -> None:
        """Linear classifier for computing baseline performance.

        Settings based on SimCLR [0].
            - [0]: https://arxiv.org/abs/2002.05709

        Args:
            model:
                Model used for feature extraction. Its forward(images) must
                return a feature tensor. May be None if a subclass assigns
                ``self.model`` right after calling this constructor.
            batch_size_per_device:
                Batch size per device; scales the learning rate.
            feature_dim:
                Dimension of the features returned by the model.
            num_classes:
                Number of classes in the dataset.
            freeze_model:
                If True, the model is frozen and only the classification head
                (and any learnable loss parameters) are trained. This is the
                linear eval setting. Set to False for finetuning.
            enable_logging:
                Whether to update and log accuracy/mAP metrics.

        Examples:
            >>> linear_classifier = LinearClassifier(
            >>>     model,  # nn.Module whose forward returns (B, 512) features
            >>>     batch_size_per_device=256,
            >>>     feature_dim=512,
            >>>     num_classes=10,
            >>>     freeze_model=True,  # linear evaluation, set to False for finetune
            >>> )
            >>> trainer = Trainer(max_epochs=90)
            >>> trainer.fit(linear_classifier, train_dataloader, val_dataloader)
        """
        super().__init__(num_classes=num_classes)
        self.save_hyperparameters(ignore="model")
        self.model = model
        self.batch_size_per_device = batch_size_per_device
        self.feature_dim = feature_dim
        self.num_classes = num_classes
        self.freeze_model = freeze_model
        self.enable_logging = enable_logging
        self.classification_head = self.build_classification_head(
            feature_dim=feature_dim, num_classes=num_classes
        )
        self.criterion = self.build_critierion()

    def build_classification_head(self, feature_dim: int, num_classes: int):
        return Linear(feature_dim, num_classes)

    def build_critierion(self):
        return CrossEntropyLoss()

    def forward(self, images: Tensor) -> Tensor:
        if self.freeze_model:
            # Frozen backbone: never build an autograd graph through it.
            with torch.no_grad():
                features = self.model(images).flatten(start_dim=1)
        else:
            # Respect the caller's grad mode (e.g. no_grad during validation)
            # instead of force-enabling gradients.
            features = self.model(images).flatten(start_dim=1)
        return self.classification_head(features)

    def training_step(self, batch: Tuple[Tensor, ...], batch_idx: int) -> Tensor:
        images, targets = batch[0], batch[1]
        predictions = self.forward(images)
        loss = self.criterion(predictions, targets)
        self.log("train_loss", loss, prog_bar=True, sync_dist=True, batch_size=images.size(0))
        self.update_train_metrics(predictions, targets)
        return loss

    @torch.no_grad()
    def validation_step(self, batch: Tuple[Tensor, ...], batch_idx: int) -> None:
        images, targets = batch[0], batch[1]
        predictions = self.forward(images)
        loss = self.criterion(predictions, targets)
        self.log("val_loss", loss, prog_bar=True, sync_dist=True, batch_size=images.size(0))
        self.update_val_metrics(predictions, targets)

    def on_train_epoch_start(self) -> None:
        if self.freeze_model:
            # Set model to eval mode to disable norm layer updates.
            self.model.eval()

    def configure_optimizers(self):
        """SGD + per-step cosine schedule without warmup (lightly / SimCLR settings).

        Learning rate is ``0.1 * batch_size_per_device * world_size / 256``.
        Momentum is 0.9 and weight decay is 0.0. The model's parameters are
        included only when it is not frozen.
        """
        parameters = list(self.classification_head.parameters())
        # Learnable loss parameters (e.g. an ArcFace weight matrix) must be
        # optimized too; CrossEntropyLoss contributes none.
        parameters += list(self.criterion.parameters())
        if not self.freeze_model:
            parameters += list(self.model.parameters())
        optimizer = SGD(
            parameters,
            lr=0.1 * self.batch_size_per_device * self.trainer.world_size / 256,
            momentum=0.9,
            weight_decay=0.0,
        )
        scheduler = {
            "scheduler": CosineWarmupScheduler(
                optimizer=optimizer,
                warmup_epochs=0,
                max_epochs=int(self.trainer.estimated_stepping_batches),
            ),
            "interval": "step",
        }
        return [optimizer], [scheduler]
class ViT_B_16Classifier(LinearClassifier):
    """Fully supervised torchvision ViT-B/16 trained with cross entropy.

    Uses ``vit_b_16`` with the ``IMAGENET1K_SWAG_E2E_V1`` weights and the
    cross-entropy loss inherited from ``LinearClassifier``. The torchvision
    head is replaced by ``Identity`` so the backbone returns features, which
    the inherited ``classification_head`` maps to class scores.

    NOTE(review): per the torchvision docs these weights are meant for 384x384
    inputs. Make sure the dataset transforms produce that size.
    """

    model: VisionTransformer

    def __init__(
        self,
        batch_size_per_device: int,
        feature_dim: int,
        num_classes: int,
    ) -> None:
        # ``model=None``: the backbone is created below, after the
        # LightningModule base has been initialized by ``super().__init__``.
        super().__init__(
            model=None,
            feature_dim=feature_dim,
            num_classes=num_classes,
            batch_size_per_device=batch_size_per_device,
            freeze_model=False,
        )
        self.model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1)
        # Use the Identity head to get to the features. ``feature_dim`` must
        # match ``self.model.hidden_dim`` for the classification head to fit.
        self.model.heads = Identity()
class ViTEmbedding(LightningModule):
    """Feature extractor around the backbone of a ``ViT_B_16Classifier``.

    The backbone is shared with the classifier, not copied. Calling
    ``eval()`` here therefore also switches the classifier's backbone to eval
    mode. ``forward`` runs without gradient tracking.
    """

    def __init__(self, model: ViT_B_16Classifier) -> None:
        super().__init__()
        self.save_hyperparameters(ignore="model")
        self.model = model.model
        self.model.eval()

    @torch.no_grad()
    def forward(self, x: Tensor) -> Tensor:
        """Return the backbone's features for ``x``."""
        return self.model(x)
class SwinL384(LinearClassifier):
    """timm Swin-L (patch4, window12, 384px) fine-tuned as an embedding model with ArcFace.

    The backbone is created with ``num_classes=0``, i.e. without timm's
    classifier head, and ``forward`` returns its output directly. The criterion
    therefore receives embeddings rather than class scores, and ``feature_dim``
    must equal the backbone's embedding size.

    NOTE(review): this loads timm's generic pretrained Swin-L weights. If the
    MegaDescriptor checkpoint was intended, confirm the model name.
    """

    # Metric logging is disabled during embedding training because ArcFaceLoss
    # takes an embedding instead of class scores. Without class scores, the
    # accuracy/mAP updates would fail.
    enable_logging: bool = False

    def __init__(
        self,
        batch_size_per_device: int,
        feature_dim: int,
        num_classes: int,
    ) -> None:
        super().__init__(
            model=timm.create_model('swin_large_patch4_window12_384', num_classes=0, pretrained=True),
            feature_dim=feature_dim,
            num_classes=num_classes,
            batch_size_per_device=batch_size_per_device,
            freeze_model=False,
            enable_logging=self.enable_logging,
        )

    def build_critierion(self):
        # NOTE(review): ``ArcFaceLoss`` is not imported anywhere in this file.
        # It is presumably ``pytorch_metric_learning.losses.ArcFaceLoss``; add
        # that import, otherwise construction raises NameError.
        return ArcFaceLoss(num_classes=self.num_classes, embedding_size=self.feature_dim, margin=0.5, scale=64)

    def forward(self, x: Tensor) -> Tensor:
        """Return backbone embeddings (no classification head is applied)."""
        return self.model(x)
class BaselineMethod:
    """Common driver for one baseline method.

    The class runs:
        - embedding training
        - kNN evaluation
        - linear evaluation
    Reported metrics are:
        - Top-1 accuracy
        - Top-5 accuracy
        - Mean Average Precision (mAP)
    A baseline is configured by subclassing and overriding attributes (e.g.
    ``model``, ``feature_dim``, ``method_specific_augmentation``) or methods
    (e.g. ``get_embedding_model``), and by passing a ``Config``.
    """

    embedding_train_dataset: Iterable[Tuple[Tensor, Tensor]]
    embedding_val_dataset: Iterable[Tuple[Tensor, Tensor]]
    linear_train_dataset: Iterable[Tuple[Tensor, Tensor]]
    linear_val_dataset: Iterable[Tuple[Tensor, Tensor]]
    knn_val_dataset: Iterable[Tuple[Tensor, Tensor]]
    knn_train_dataset: Iterable[Tuple[Tensor, Tensor]]
    # Baseline-specific augmentation pipeline; empty by default.
    method_specific_augmentation = T.Compose([])
    cfg: Config  # A config class specifying the hyperparameters
    model: Module  # The model used for embedding training
    feature_dim: int = 2048  # Feature size used by the linear-eval head
    skip_embedding_training: bool = False  # If True, skips embedding training regardless of cfg
    _name: str = ""  # Display name; the class name is used when empty

    def __init__(self, cfg: Config):
        self.cfg = cfg
        # Per-method log directory: <log_dir>/<baseline_id>/<name>
        self.method_dir = (self.cfg.log_dir / self.cfg.baseline_id / self.name).resolve()
        # NOTE(review): ``VisionDataset`` and the ``embedding_train_transform``,
        # ``eval_train_transform`` and ``val_transform`` attributes are not
        # defined in this file. Confirm where they come from before running.
        # Dataset for embedding training (pretraining of the embedding).
        self.embedding_train_dataset = VisionDataset(
            train=True,
            transform=self.embedding_train_transform,
            test_run=self.cfg.test_run,
        )
        # Training data shared by linear evaluation and the kNN feature bank.
        self.knn_train_dataset = self.linear_train_dataset = VisionDataset(
            train=True,
            transform=self.eval_train_transform,
        )
        # A single validation dataset shared by all three stages.
        self.knn_val_dataset = self.linear_val_dataset = self.embedding_val_dataset = VisionDataset(
            train=False,
            transform=self.val_transform,
        )

    @property
    def name(self) -> str:
        return self._name or self.__class__.__name__

    @timing_decorator
    def run_baseline_method(self):
        """Run embedding training, kNN eval and linear eval, honoring the skip flags.

        A checkpoint from ``cfg.checkpoint_path`` is loaded only when the path
        contains this method's name. Loading one also skips embedding training.
        ``self.model`` is deleted at the end to free memory.
        """
        print_rank_zero(f"## Starting {self.name}... ")
        loaded_checkpoint = False
        if self.cfg.checkpoint_path:
            if self.name not in str(self.cfg.checkpoint_path):
                print_rank_zero(f"Not loading checkpoint for {self.name} because checkpoint path does not contain '{self.name}'.")
            else:
                # map_location="cpu" lets a checkpoint saved on a GPU load on any host.
                checkpoint = torch.load(self.cfg.checkpoint_path, map_location="cpu")
                self.model.load_state_dict(checkpoint["state_dict"])
                loaded_checkpoint = True
                print_rank_zero(f"Loaded checkpoint for {self.name} from {self.cfg.checkpoint_path}")
        skip_embedding_training = (
            self.cfg.skip_embedding_training
            or self.cfg.epochs == 0
            or loaded_checkpoint
            or self.skip_embedding_training
        )
        if skip_embedding_training:
            print_rank_zero("Skipping embedding training")
        else:
            self.embedding_training()
        if self.cfg.skip_knn_eval:
            print_rank_zero("Skipping KNN evaluation")
        else:
            self.knn_eval()
        if self.cfg.skip_linear_eval:
            print_rank_zero("Skipping linear evaluation")
        else:
            self.linear_eval()
        del self.model
        clear_cache()
        print_rank_zero(f"## Finished {self.name}")

    def get_embedding_model(self) -> Module:
        "Must return a model that returns features on forward pass."
        return self.model

    def knn_eval(self) -> None:
        """Runs KNN evaluation on the given model.

        Parameters follow InstDisc [0] settings.
        The most important settings are:
            - Num nearest neighbors: 200
            - Temperature: 0.1
        References:
           - [0]: InstDict, 2018, https://arxiv.org/abs/1805.01978
        """
        print_rank_zero(f"### Running {self.name} KNN evaluation...")
        self.train(
            classifier=KNNClassifier(
                model=self.get_embedding_model(),
                num_classes=self.cfg.num_classes,
                feature_dtype=torch.float16,
            ),
            epochs=1,
            train_dataset=self.knn_train_dataset,
            val_dataset=self.knn_val_dataset,
            log_name="knn_eval",
            # The feature bank must contain every training sample; dropping the
            # last (shuffled) batch would silently remove random samples.
            shuffle_train=False,
            drop_last=False,
        )

    def linear_eval(self) -> None:
        """Runs a linear evaluation on the given model.

        Parameters follow SimCLR [0] settings.
        The most important settings are:
            - Backbone: Frozen
            - Epochs: 90
            - Optimizer: SGD
            - Base Learning Rate: 0.1
            - Momentum: 0.9
            - Weight Decay: 0.0
            - LR Schedule: Cosine without warmup
        References:
            - [0]: SimCLR, 2020, https://arxiv.org/abs/2002.05709
        """
        print_rank_zero(f"### Running {self.name} linear evaluation... ")
        self.train(
            classifier=LinearClassifier(
                model=self.get_embedding_model(),
                batch_size_per_device=self.cfg.batch_size_per_device,
                feature_dim=self.feature_dim,
                num_classes=self.cfg.num_classes,
                freeze_model=True,
            ),
            epochs=90,
            train_dataset=self.linear_train_dataset,
            val_dataset=self.linear_val_dataset,
            log_name="linear_eval",
        )

    def embedding_training(self):
        """Train ``self.model`` on the embedding training dataset for ``cfg.epochs``."""
        print_rank_zero(f"### Training {self.name} embedding model... ")
        self.train(
            classifier=self.model,
            epochs=self.cfg.epochs,
            train_dataset=self.embedding_train_dataset,
            val_dataset=self.embedding_val_dataset,
            log_name="embedding_training",
        )

    @timing_decorator
    def train(
        self,
        classifier,
        epochs,
        train_dataset,
        val_dataset,
        log_name,
        shuffle_train: bool = True,
        drop_last: bool = True,
    ):
        """Fit ``classifier`` with Lightning and record its best validation metrics.

        Args:
            classifier: LightningModule to fit.
            epochs: Maximum number of epochs (forced to 1 when ``cfg.test_run``).
            train_dataset: Dataset for the training loader.
            val_dataset: Dataset for the validation loader.
            log_name: TensorBoard sub-directory and value of the "Evaluation" column.
            shuffle_train: Whether the training loader shuffles.
            drop_last: Whether the training loader drops the last incomplete batch.
        """
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers = self.cfg.num_workers > 0
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.cfg.batch_size_per_device,
            shuffle=shuffle_train,
            num_workers=self.cfg.num_workers,
            drop_last=drop_last,
            persistent_workers=persistent_workers,
        )
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=self.cfg.batch_size_per_device,
            shuffle=False,
            num_workers=self.cfg.num_workers,
            persistent_workers=persistent_workers,
        )
        check_every = 1 if self.cfg.test_run else self.cfg.check_val_every_n_epoch
        metric_callback = MetricCallback()
        trainer = Trainer(
            max_epochs=1 if self.cfg.test_run else epochs,
            accelerator=self.cfg.accelerator,
            devices=self.cfg.devices,
            logger=TensorBoardLogger(save_dir=str(self.method_dir), name=log_name),
            callbacks=[
                DeviceStatsMonitor(),
                metric_callback,
            ],
            num_sanity_val_steps=0,
            log_every_n_steps=1,
            precision=self.cfg.precision,
            # Validate at least once even when training for fewer epochs than
            # the configured interval; never pass 0 to Lightning.
            check_val_every_n_epoch=max(1, min(epochs, check_every)),
            strategy="auto",
            profiler=self.cfg.profile,
        )
        trainer.fit(
            model=classifier,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader,
        )
        self._record_results(
            metric_callback.val_metrics, log_name, write_files=trainer.is_global_zero
        )

    def _record_results(
        self, val_metrics: Dict[str, List[float]], log_name: str, write_files: bool = True
    ) -> None:
        """Print the best validation metrics and update the run's csv/markdown tables."""
        # Only validation quality metrics are reported. Losses and train_*
        # values are excluded, and the rest are "higher is better", so the best
        # value is the max.
        metrics = {
            metric: max(values)
            for metric, values in val_metrics.items()
            if "train" not in metric and "loss" not in metric
        }
        for metric, max_value in metrics.items():
            print_rank_zero(f"{self.name} {log_name} {metric}: {max_value}")
        self.cfg.experiment_result_metrics.append({
            "Setting": self.name,
            "Evaluation": log_name,
            **metrics
        })
        if not write_files:
            # Only the global-zero process writes files in multi-device runs.
            return
        result_metrics_dir = self.cfg.log_dir / self.cfg.baseline_id
        result_metrics_dir.mkdir(parents=True, exist_ok=True)
        result_metrics = pd.DataFrame(self.cfg.experiment_result_metrics)
        result_metrics.to_csv(result_metrics_dir / "metrics.csv", index=False)
        result_metrics.to_markdown(result_metrics_dir / "metrics.md", index=False, floatfmt=".4f")
class SwinL384Baseline(BaselineMethod):
    """Baseline that fine-tunes a timm Swin-L (384px) backbone with ArcFace (``SwinL384``)."""

    # NOTE(review): a ``T.Resize(size=(384, 384))`` step used to be commented
    # out here. Confirm the dataset transforms produce 384x384 inputs for this
    # 384px backbone.
    method_specific_augmentation = T.Compose([
        T.RandAugment(num_ops=2, magnitude=20),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    # Embedding size of swin_large_patch4_window12_384 without its head
    # (embed_dim 192 * 2**3). The inherited 2048 would mismatch both the ArcFace
    # loss and the linear-eval head.
    feature_dim = 1536
    skip_embedding_training = False

    def __init__(self, args):
        # ``args`` is the run's ``Config`` instance (see ``Baseline.run``).
        super().__init__(args)
        self.model = SwinL384(
            batch_size_per_device=self.cfg.batch_size_per_device,
            num_classes=self.cfg.num_classes,
            feature_dim=self.feature_dim,
        )
class ViT_B_16Baseline(BaselineMethod):
    """Fully supervised torchvision ViT-B/16 baseline (``ViT_B_16Classifier``)."""

    method_specific_augmentation = T.Compose([
        T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # ViT-B/16 hidden size, i.e. the feature width once ``heads`` is replaced
    # by Identity. The inherited 2048 would make the classification head and
    # the linear-eval head fail with a shape mismatch.
    feature_dim = 768

    def __init__(self, cfg):
        super().__init__(cfg)
        self.model = ViT_B_16Classifier(
            batch_size_per_device=self.cfg.batch_size_per_device,
            num_classes=self.cfg.num_classes,
            feature_dim=self.feature_dim,
        )

    def get_embedding_model(self):
        # Wrap only the backbone so kNN/linear eval see features, not class scores.
        return ViTEmbedding(model=self.model)
class Baseline:
    """
    The main class that runs the baseline methods.
    """

    # SwAV, AIM and ResNet-50 baselines were removed. ResNet-50 reached roughly
    # 90% top-1 but was not pursued further.
    methods: Dict[str, Type[BaselineMethod]] = {
        "vit_b_16": ViT_B_16Baseline,
        "mega_descriptor_finetune": SwinL384Baseline,
    }

    @timing_decorator
    def run(self, args):
        """
        Run the class as specified in the config.
        args: argparse.Namespace - The CLI arguments, used as kwargs to instantiate a Config object.
        """
        cfg = Config(**vars(args))
        if cfg.aggregate_metrics:
            self.aggregate_metrics(cfg)
            return
        cfg.baseline_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        methods = cfg.methods or list(self.methods)
        print_rank_zero(f"# Running: {methods}...")
        for method in methods:
            self.methods[method](cfg).run_baseline_method()
        print_rank_zero("# All baselines metrics computed!")
        print_rank_zero(f"# Results saved in {cfg.log_dir / cfg.baseline_id}")

    def calculate_mean_std(self, df: pd.DataFrame, groupby: List[str]):
        """
        Summarize every non-group column as "mean ± std" per group.

        Parameters:
            df (pd.DataFrame): Input DataFrame. All columns not in ``groupby``
                must be numeric.
            groupby (List[str]): List of which columns to group by.

        Returns:
            pd.DataFrame: One row per group with the ``groupby`` columns followed
            by each value column formatted as "{mean:.3f} ± {std:.3f}". A group
            with a single row has an undefined (NaN) sample std, reported as 0.
        """
        value_columns = [col for col in df.columns if col not in groupby]
        grouped = df.groupby(groupby)[value_columns]
        mean_df = grouped.mean()
        # pandas returns NaN for the std of one-row groups; report 0 instead.
        std_df = grouped.std().fillna(0.0)
        result_df = pd.DataFrame(index=mean_df.index)
        for col in value_columns:
            result_df[col] = [
                f"{mean:.3f} ± {std:.3f}" for mean, std in zip(mean_df[col], std_df[col])
            ]
        # Move the group keys back into columns (no extra "index" column).
        return result_df.reset_index()

    def aggregate_metrics(self, cfg: Config):
        """Combine all ``metrics*.csv`` files under ``cfg.log_dir`` and summarize them.

        Writes ``agglomerated_metrics.{csv,md}`` (all rows, rows with missing
        values dropped) and ``aggregated_metrics.{csv,md}`` (mean ± std per
        Setting/Evaluation) into ``cfg.log_dir``.
        """
        metric_files = list(cfg.log_dir.glob("**/metrics*.csv"))
        if not metric_files:
            # pd.concat raises on an empty list; report instead of crashing.
            print_rank_zero(f"No metrics*.csv files found in {cfg.log_dir}; nothing to aggregate.")
            return
        result_metrics = pd.concat([pd.read_csv(path) for path in metric_files], ignore_index=True)
        result_metrics = result_metrics.dropna()
        result_metrics.to_csv(cfg.log_dir / "agglomerated_metrics.csv", index=False)
        result_metrics.to_markdown(cfg.log_dir / "agglomerated_metrics.md", index=False, floatfmt=".4f")
        # Aggregate the agglomerated metrics with error bars
        result_metrics = self.calculate_mean_std(result_metrics, ["Setting", "Evaluation"])
        result_metrics.to_csv(cfg.log_dir / "aggregated_metrics.csv", index=False)
        result_metrics.to_markdown(cfg.log_dir / "aggregated_metrics.md", index=False, floatfmt=".4f")
        print_rank_zero(f"Aggregated metrics saved in {cfg.log_dir}")
# Command-line interface. Every option maps onto a ``Config`` field of the same
# name, because ``Baseline.run`` builds ``Config(**vars(args))``.
parser = argparse.ArgumentParser(description='Baseline metrics')
parser.add_argument("--log-dir", type=Path, default=str(Config.log_dir))
parser.add_argument("--batch-size-per-device", type=int, default=Config.batch_size_per_device)
parser.add_argument("--epochs", type=int, default=Config.epochs)
parser.add_argument("--num-workers", type=int, default=Config.num_workers)
parser.add_argument("--checkpoint-path", type=Path, default=Config.checkpoint_path)
parser.add_argument("--methods", type=str, nargs="+", default=Config.methods, choices=Baseline.methods.keys(), required=False)
parser.add_argument("--skip-embedding-training", action="store_true", default=Config.skip_embedding_training)
parser.add_argument("--skip-knn-eval", action="store_true", default=Config.skip_knn_eval)
parser.add_argument("--skip-linear-eval", action="store_true", default=Config.skip_linear_eval)
# A store_true flag must default to False; with a True default it could never
# be switched off, and training would be unreachable from the CLI.
parser.add_argument("--aggregate-metrics", action="store_true", default=False)

if __name__ == "__main__":
    args = parser.parse_args()
    Baseline().run(args)