| | import gradio as gr |
| | from PIL import Image |
| | from joblib import load |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torchvision.models import efficientnet_b0 |
| | import torchvision.transforms as transforms |
| |
|
| |
|
class MultiModalClassifier(nn.Module):
    """EfficientNet-B0 image backbone fused with three scalar tabular
    features (age, anatomical site, sex) through a small MLP head.

    Parameters
    ----------
    num_classes : int
        Size of the output layer (1 at the call site => single-logit
        binary output; sigmoid is applied by the caller, not here).
    num_features : int
        Dimensionality of the flattened image features fed into fc1
        (1280 at the call site — presumably EfficientNet-B0's final
        channel count; confirm against torchvision).
    """

    def __init__(self, num_classes: int, num_features: int):
        super(MultiModalClassifier, self).__init__()

        # ImageNet-pretrained backbone.  NOTE(review): `pretrained=` is the
        # legacy torchvision API; newer releases prefer `weights=`.
        efficientnet = efficientnet_b0(pretrained=True)

        # Keep every child module except the final classifier head.
        # NOTE(review): this likely retains EfficientNet's own avgpool as
        # well, which would make the F.avg_pool2d in forward() a no-op on
        # a (B, C, 1, 1) map — confirm before refactoring.  Attribute
        # names here must stay fixed: they are keys in the checkpoint
        # loaded via load_state_dict below.
        self.efficientnet_features = nn.Sequential(*list(efficientnet.children())[:-1])

        # Each tabular input contributes one scalar per sample.
        self.age_dim = 1
        self.anatom_site_dim = 1
        self.sex_dim = 1

        # Two-layer head over the concatenated image + tabular features.
        self.fc1 = nn.Linear(num_features + self.age_dim + self.anatom_site_dim + self.sex_dim, 256)
        self.fc2 = nn.Linear(256, num_classes)

        # Regularization between the two head layers (active only in train mode).
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, image, age, anatom_site, sex):
        # Image pathway: backbone -> global average pool -> flatten to (B, num_features).
        image_features = self.efficientnet_features(image)
        image_features = F.avg_pool2d(image_features, image_features.size()[2:]).view(image.size(0), -1)

        # Force each tabular tensor into column shape (B, 1) regardless of input shape.
        age = age.view(-1, 1)
        anatom_site = anatom_site.view(-1, 1)
        sex = sex.view(-1, 1)

        # Fuse: append the (B, 3) tabular block to the image features.
        additional_features = torch.cat((age, anatom_site, sex), dim=1)
        combined_features = torch.cat((image_features, additional_features), dim=1)

        # MLP head; raw logits are returned (no activation on the output).
        combined_features = F.relu(self.fc1(combined_features))
        combined_features = self.dropout(combined_features)
        output = self.fc2(combined_features)

        return output
| |
|
| | |
# Single logit for binary benign/malignant classification; 1280 matches
# the flattened feature width the backbone produces (see fc1 sizing).
num_classes = 1
num_features = 1280
model = MultiModalClassifier(num_classes, num_features)

# Load the fine-tuned checkpoint onto CPU.
# NOTE(review): torch.load unpickles arbitrary objects; if the file is a
# plain state_dict, pass weights_only=True (torch >= 1.13) to avoid the
# pickle attack surface — confirm torch version and checkpoint contents.
model.load_state_dict(torch.load(r'best_epoch_weights.pth',map_location=torch.device('cpu')))

# Inference mode: disables dropout (and any batch-norm statistic updates).
model.eval()

# Scaler fitted on training-set ages; applied to the user's age in predict().
age_scaler = load(r'age_approx_scaler.joblib')

# Match training-time preprocessing: resize to 224x224, convert to a
# [0, 1] float tensor.  NOTE(review): no mean/std normalization here —
# assumed consistent with how the model was trained; confirm.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Hard 0/1 decision index -> human-readable diagnosis string.
diagnosis_map = {0: 'benign', 1: 'malignant'}

# Categorical encodings; must match the integer codes used during training.
sexes_mapping = {'male': 0, 'female': 1}

# Anatomical-site codes; keys double as the dropdown choices in the UI.
anatom_site_mapping = {
    'torso': 0,
    'lower extremity': 1,
    'head/neck': 2,
    'upper extremity': 3,
    'palms/soles': 4,
    'oral/genital': 5,
}
| |
|
def predict(image, age, gender, anatom_site):
    """Classify a skin-lesion image as benign or malignant.

    Parameters
    ----------
    image : numpy.ndarray
        Image array as delivered by the Gradio ``"image"`` component
        (may be RGB, RGBA, or grayscale).
    age : float
        Patient age in years; scaled with the fitted ``age_scaler``.
    gender : str
        ``'Male'`` or ``'Female'`` (case-insensitive; mapped via
        ``sexes_mapping``).
    anatom_site : str
        One of the keys of ``anatom_site_mapping``.

    Returns
    -------
    str
        Human-readable sentence naming the predicted class.
    """
    # Force 3 channels: Gradio can hand back RGBA (PNG with alpha) or
    # grayscale arrays, while the backbone expects RGB input.
    image = Image.fromarray(image).convert('RGB')
    # Preprocess and add the batch dimension -> (1, 3, 224, 224).
    image = test_transform(image).float().unsqueeze(0)

    # Tabular features as (1, 1) float tensors, matching forward()'s view(-1, 1).
    sex = torch.tensor([[sexes_mapping[gender.lower()]]], dtype=torch.float32)
    anatom_site = torch.tensor([[anatom_site_mapping[anatom_site]]], dtype=torch.float32)

    # Scale age with the scaler fitted at training time (expects 2-D input).
    scaled_age = age_scaler.transform([[age]])
    age_tensor = torch.tensor(np.array(scaled_age), dtype=torch.float32)

    # Pure inference: disable autograd so no graph is built.
    with torch.no_grad():
        output = model(image, age_tensor, anatom_site, sex)

    # Single logit -> probability -> hard 0/1 decision at the 0.5 threshold.
    output_sigmoid = torch.sigmoid(output)
    predicted_class = (output_sigmoid > 0.5).float()

    return f"The predicted_class is a {diagnosis_map[int(predicted_class)]}."
| |
|
| |
|
# Short helper text rendered under the app title.
description_html = """
Fill in the required parameters and click 'classify'.
"""

# Pre-filled example rows: [image path, age, gender, anatomical site].
# The image files are presumably shipped next to this script — confirm.
example_data = [
    ["ISIC_0000060_downsampled.jpg", 35, "Female", "torso"],
    ["ISIC_0068279.jpg", 45.0, "Female", "head/neck"]
]

# Input widgets, in the same order as predict()'s parameters.
inputs = [
    "image",
    gr.Number(label="Age", minimum=0, maximum=120),
    gr.Dropdown(['Male', 'Female'], label="Gender"),
    gr.Dropdown(['torso', 'lower extremity', 'head/neck', 'upper extremity', 'palms/soles', 'oral/genital'], label="Anatomical Site")
]

# Build and start the web UI (launch() blocks until the server stops).
# NOTE(review): allow_flagging= is deprecated in Gradio 4.x in favor of
# flagging_mode= — confirm the installed version before upgrading.
gr.Interface(
    predict,
    inputs,
    outputs = gr.Textbox(label="Output", type="text"),
    title="Skin Cancer Diagnosis",
    description=description_html,
    allow_flagging='never',
    examples=example_data
).launch()