| """ |
| Thyroid Ultrasound Evaluation + Grad-CAM (Pure PyTorch, no Trainer) |
| """ |
| import os, sys, math, json, random, warnings, traceback |
| warnings.filterwarnings("ignore") |
|
|
| import numpy as np |
| from PIL import Image |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
| import torch |
| import torch.nn.functional as F |
| from datasets import load_dataset |
| from transformers import AutoImageProcessor, AutoModelForImageClassification |
| from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix |
|
|
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
SEED = 42            # one seed for the Python/NumPy/PyTorch RNGs and the dataset split
BATCH_SIZE = 16      # number of images sent through the model per forward pass
OUTPUT_DIR = "./eval_outputs"   # metrics JSON and Grad-CAM PNGs are written here

DATASET_NAME = "BTX24/thyroid-cancer-classification-ultrasound-dataset"
HF_USERNAME = "Johnyquest7"
MODEL_NAME = f"{HF_USERNAME}/ML-Inter_thyroid"   # Hub repo id of the checkpoint to evaluate

# Seed every random number generator the script relies on, so that the
# sample selection for Grad-CAM is repeatable from run to run.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
|
|
def _compute_metrics(y_true, y_pred, probs):
    """Compute the binary classification metrics reported for the test split.

    Args:
        y_true: 1-D integer array of ground-truth labels (0 = benign,
            1 = malignant, matching the counts printed by ``main``).
        y_pred: 1-D integer array of predicted labels.
        probs: 2-D array of softmax probabilities; column 1 is used as the
            positive-class score for ROC AUC.

    Returns:
        A JSON-serialisable dict of metrics. ``test_roc_auc`` is ``None``
        when ``y_true`` holds a single class, because ROC AUC is undefined
        in that case.
    """
    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="weighted", zero_division=0
    )

    # roc_auc_score raises ValueError when only one class is present; scoring
    # with the class-0 column instead would not help and would invert the AUC.
    if np.unique(y_true).size < 2:
        auc = None
    else:
        auc = float(roc_auc_score(y_true, probs[:, 1]))

    # Pin the label order so the matrix is always 2x2, even if one class is
    # missing from both y_true and y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    positives = cm[1, 1] + cm[1, 0]
    negatives = cm[0, 0] + cm[0, 1]
    sens = cm[1, 1] / positives if positives > 0 else 0
    spec = cm[0, 0] / negatives if negatives > 0 else 0

    return {
        "test_accuracy": float(acc),
        "test_weighted_f1": float(f1),
        "test_weighted_precision": float(prec),
        "test_weighted_recall": float(rec),
        "test_roc_auc": auc,
        "test_sensitivity": float(sens),
        "test_specificity": float(spec),
        "test_confusion_matrix": cm.tolist(),
    }


def _gradcam_heatmap(feat, grads, out_size):
    """Turn one sample's activations and gradients into a Grad-CAM heatmap.

    Args:
        feat: activations of the hooked layer for a single sample (batch axis
            already removed). Either token-shaped ``(N, C)`` with ``N`` a
            perfect square, or spatial ``(H, W, C)``. Channels are assumed to
            be the last axis.
        grads: gradients of the target score w.r.t. ``feat``; same shape.
        out_size: ``(height, width)`` of the returned heatmap.

    Returns:
        A 2-D ``numpy`` array of shape ``out_size`` with values in ``[0, 1]``.

    Raises:
        ValueError: if ``feat`` is neither 2-D nor 3-D, or if a token-shaped
            input cannot be folded into a square grid.
    """
    if feat.dim() == 2:
        # Token sequence: channel weights are the gradient averaged over tokens.
        n_tokens = feat.shape[0]
        side = math.isqrt(n_tokens)
        if side * side != n_tokens:
            raise ValueError(f"Cannot reshape {n_tokens} tokens into a square grid")
        channel_weights = grads.mean(dim=0)
        cam = (feat @ channel_weights).reshape(side, side)
    elif feat.dim() == 3:
        # Spatial map: channel weights are the gradient averaged over H and W.
        channel_weights = grads.mean(dim=(0, 1))
        cam = (feat * channel_weights).sum(dim=-1)
    else:
        raise ValueError(f"Unsupported feature shape for Grad-CAM: {tuple(feat.shape)}")

    cam = F.relu(cam)
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)
    # Bilinear interpolation needs a 4-D (batch, channel, H, W) tensor.
    cam = F.interpolate(
        cam[None, None], size=tuple(out_size), mode="bilinear", align_corners=False
    )
    return cam[0, 0].cpu().numpy()


def main():
    """Evaluate the fine-tuned classifier and save Grad-CAM overlays.

    Steps:
      1. Load the image processor and model named by ``MODEL_NAME``.
      2. Rebuild the stratified 80/20 split of ``DATASET_NAME`` with ``SEED``
         and run batched inference on the test part.
      3. Print the metrics and write them to ``OUTPUT_DIR/test_metrics.json``.
      4. Render Grad-CAM overlays for up to five correctly and five
         incorrectly classified test samples into ``OUTPUT_DIR``.

    A failure on one Grad-CAM sample is reported and skipped; the hooks on
    the model are always removed before the function returns.
    """
    print("=" * 60)
    print("Thyroid Ultrasound Evaluation + Grad-CAM")
    print("=" * 60)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice: {device}")
    print(f"Loading model: {MODEL_NAME}")
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME).to(device).eval()
    print(f"Model loaded: {sum(p.numel() for p in model.parameters())/1e6:.1f}M params")

    print(f"\nLoading dataset: {DATASET_NAME}")
    ds = load_dataset(DATASET_NAME, split="train")
    ds = ds.shuffle(seed=SEED)
    train_test = ds.train_test_split(test_size=0.2, stratify_by_column="label", seed=SEED)
    test_ds = train_test["test"]
    # Read only the label column: iterating whole rows would decode every image.
    test_labels = list(test_ds["label"])
    n_benign = sum(1 for y in test_labels if y == 0)
    n_malignant = sum(1 for y in test_labels if y == 1)
    print(f"Test samples: {len(test_ds)} (Benign: {n_benign}, Malignant: {n_malignant})")

    id2label = model.config.id2label

    # ---- Batched inference -------------------------------------------------
    logit_chunks, all_labels = [], []
    n_batches = (len(test_ds) + BATCH_SIZE - 1) // BATCH_SIZE
    print("\nRunning inference...")
    for batch_no, start in enumerate(range(0, len(test_ds), BATCH_SIZE)):
        batch = test_ds[start:start + BATCH_SIZE]  # dict of column -> list
        images = [im if im.mode == "RGB" else im.convert("RGB") for im in batch["image"]]
        pixel_values = processor(images, return_tensors="pt")["pixel_values"].to(device)
        with torch.no_grad():
            outputs = model(pixel_values=pixel_values)
        logit_chunks.append(outputs.logits.cpu().numpy())
        all_labels.extend(batch["label"])
        if batch_no % 5 == 0:
            print(f"  Batch {batch_no + 1}/{n_batches}")

    if not logit_chunks:
        raise RuntimeError("Test split is empty; nothing to evaluate.")

    y_true = np.array(all_labels)
    y_logits = np.concatenate(logit_chunks, axis=0)
    y_pred = np.argmax(y_logits, axis=1)
    probs = F.softmax(torch.from_numpy(y_logits), dim=1).numpy()

    # ---- Metrics -----------------------------------------------------------
    final = _compute_metrics(y_true, y_pred, probs)
    print(f"\n{'='*60}")
    print("FINAL TEST METRICS")
    print(f"{'='*60}")
    for k, v in final.items():
        print(f"  {k}: {v}")
    with open(f"{OUTPUT_DIR}/test_metrics.json", "w") as f:
        json.dump(final, f, indent=2)
    print(f"\nSaved to {OUTPUT_DIR}/test_metrics.json")

    # ---- Grad-CAM sample selection ----------------------------------------
    correct_idx = [i for i, (t, p) in enumerate(zip(y_true, y_pred)) if t == p]
    incorrect_idx = [i for i, (t, p) in enumerate(zip(y_true, y_pred)) if t != p]
    random.shuffle(correct_idx)
    random.shuffle(incorrect_idx)
    selected = correct_idx[:5] + incorrect_idx[:5]
    print(f"\nGenerating Grad-CAM for {len(selected)} samples ({len(correct_idx[:5])} correct, {len(incorrect_idx[:5])} incorrect)...")

    # The hooks stash the target layer's output and the gradient flowing back
    # into that output for the most recent forward/backward pass.
    gradcam_data = {}

    def fwd_hook(module, input, output):
        gradcam_data["feat"] = output.detach()

    def bwd_hook(module, grad_input, grad_output):
        gradcam_data["grad"] = grad_output[0].detach()

    target_layer = model.swinv2.encoder.layers[-1].blocks[-1].layernorm_after
    handles = [
        target_layer.register_forward_hook(fwd_hook),
        target_layer.register_full_backward_hook(bwd_hook),
    ]

    try:
        for idx in selected:
            try:
                item = test_ds[int(idx)]
                img = item["image"].convert("RGB")
                label = item["label"]
                img_tensor = processor(img, return_tensors="pt")["pixel_values"].to(device)
                img_tensor.requires_grad_(True)

                # Drop the previous sample's tensors so a failed pass can never
                # silently reuse stale activations or gradients.
                gradcam_data.clear()
                model.zero_grad()
                outputs = model(pixel_values=img_tensor)
                target_class = int(y_pred[idx])
                outputs.logits[0, target_class].backward()

                # Resize the heatmap to the model input so the overlay lines up
                # with the image whatever resolution the processor produces.
                cam = _gradcam_heatmap(
                    gradcam_data["feat"][0],
                    gradcam_data["grad"][0],
                    img_tensor.shape[-2:],
                )

                # Undo the processor's normalisation only approximately: rescale
                # the input tensor to [0, 1] for display.
                img_np = img_tensor.squeeze().detach().cpu().permute(1, 2, 0).numpy()
                img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)

                pred_name = id2label.get(target_class, str(target_class))
                true_name = id2label.get(label, str(label))
                fname = f"{OUTPUT_DIR}/gradcam_sample_{idx}_pred{pred_name}_true{true_name}.png"

                fig = plt.figure(figsize=(6, 6))
                try:
                    plt.imshow(img_np)
                    plt.imshow(cam, cmap="jet", alpha=0.5)
                    plt.title(f"Pred: {pred_name} | True: {true_name}")
                    plt.axis("off")
                    plt.savefig(fname, bbox_inches="tight", dpi=150)
                finally:
                    # Close even when plotting fails so figures do not pile up.
                    plt.close(fig)
                print(f"  Saved {fname}")
            except Exception as e:
                # Best effort: one bad sample must not abort the remaining ones.
                print(f"  Skipped sample {idx}: {e}")
                traceback.print_exc()
    finally:
        for handle in handles:
            handle.remove()

    print("\nEvaluation complete.")
|
|
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|