ZIP-B / test.py
Yiming-M's picture
2025-07-31 15:53 🐣
0ecb9aa verified
import torch
from torch import nn
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from argparse import ArgumentParser
import os
current_dir = os.path.abspath(os.path.dirname(__file__))
from datasets import standardize_dataset_name
from models import get_model
from utils import get_config, get_dataloader, setup, cleanup
from evaluate import evaluate
parser = ArgumentParser(description="Test a trained model on a dataset.")
# Parameters for model
parser.add_argument("--weight_path", type=str, required=True, help="The name of the weight to use.")
parser.add_argument("--output_filename", type=str, default=None, help="The name of the result file.")
# Parameters for evaluation
parser.add_argument("--dataset", type=str, required=True, help="The dataset to evaluate on.")
parser.add_argument("--split", type=str, default="val", choices=["val", "test"], help="The split to evaluate on.")
parser.add_argument("--input_size", type=int, default=224, help="The size of the input image.")
parser.add_argument("--sliding_window", action="store_true", help="Use sliding window strategy for evaluation.")
parser.add_argument("--max_input_size", type=int, default=4096, help="The maximum size of the input image in evaluation. Images larger than this will be processed using sliding window by force to avoid OOM.")
parser.add_argument("--max_num_windows", type=int, default=8, help="The maximum number of windows to be simultaneously processed.")
parser.add_argument("--resize_to_multiple", action="store_true", help="Resize the image to the nearest multiple of the input size.")
parser.add_argument("--stride", type=int, default=None, help="The stride for sliding window strategy.")
parser.add_argument("--amp", action="store_true", help="Use automatic mixed precision for evaluation.")
parser.add_argument("--device", type=str, default="cuda", help="The device to use for evaluation.")
parser.add_argument("--num_workers", type=int, default=8, help="The number of workers for the data loader.")
parser.add_argument("--local_rank", type=int, default=-1, help="The local rank for distributed training.")
def run(local_rank: int, nprocs: int, args: ArgumentParser):
print(f"Rank {local_rank} process among {nprocs} processes.")
setup(local_rank, nprocs)
print(f"Initialized successfully. Training with {nprocs} GPUs.")
device = f"cuda:{local_rank}" if local_rank != -1 else "cuda:0"
print(f"Using device: {device}.")
ddp = nprocs > 1
_ = get_config(vars(args).copy(), mute=False)
model = get_model(model_info_path=args.weight_path).to(device)
model = DDP(nn.SyncBatchNorm.convert_sync_batchnorm(model), device_ids=[local_rank], output_device=local_rank) if ddp else model
model = model.to(device)
model.eval()
args.output_filename = f"{model.model_name}_{args.weight_path.split('/')[-1].split('.')[0]}" if args.output_filename is None else args.output_filename
dataloader = get_dataloader(args, split=args.split)
scores = evaluate(
model=model,
data_loader=dataloader,
sliding_window=args.sliding_window,
max_input_size=args.max_input_size,
window_size=args.input_size,
stride=args.stride,
max_num_windows=args.max_num_windows,
amp=args.amp,
local_rank=local_rank,
nprocs=nprocs,
)
if local_rank == 0:
for k, v in scores.items():
print(f"{k}: {v}")
result_dir = os.path.join(current_dir, "results", args.dataset, args.split)
os.makedirs(result_dir, exist_ok=True)
with open(os.path.join(result_dir, f"{args.output_filename}.txt"), "w") as f:
for k, v in scores.items():
f.write(f"{k}: {v}\n")
cleanup(ddp)
if __name__ == "__main__":
args = parser.parse_args()
args.dataset = standardize_dataset_name(args.dataset)
if args.dataset in ["sha", "shb", "qnrf", "nwpu"]:
assert args.split == "val", f"Split {args.split} is not available for dataset {args.dataset}."
# Sliding window prediction will be used if args.sliding_window is True, or when the image size is larger than args.max_input_size
args.stride = args.stride or args.input_size
assert os.path.exists(args.weight_path), f"Weight path {args.weight_path} does not exist."
args.in_memory_dataset = False
args.nprocs = torch.cuda.device_count()
print(f"Using {args.nprocs} GPUs.")
if args.nprocs > 1:
mp.spawn(run, nprocs=args.nprocs, args=(args.nprocs, args))
else:
run(0, 1, args)