import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from simple_mlp_configuration import SimpleMLPConfig


class SimpleMLPForClassification(PreTrainedModel):
    """A two-layer MLP classifier wrapped in the Hugging Face model API."""

    config_class = SimpleMLPConfig

    def __init__(self, config):
        super().__init__(config)  # also sets self.config
        self.num_labels = config.num_classes
        # Two-layer MLP: input -> hidden -> logits, with ReLU and dropout.
        self.fc1 = nn.Linear(config.input_dim, config.hidden_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(config.dropout_rate)
        self.fc2 = nn.Linear(config.hidden_dim, config.num_classes)
        # Initialize weights and apply any final processing (HF convention).
        self.post_init()

    def forward(self, inputs_embeds, labels=None, return_dict=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        x = self.fc1(inputs_embeds)
        x = self.activation(x)
        x = self.dropout(x)
        logits = self.fc2(x)

        # Compute cross-entropy loss only when labels are provided,
        # so the same forward() serves both training and inference.
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        # Tuple output path, matching the standard HF return convention.
        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )
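
# --- Usage sketch (illustrative, not part of the model definition) ---
# A minimal smoke test, assuming SimpleMLPConfig accepts input_dim,
# hidden_dim, dropout_rate, and num_classes as keyword arguments; those
# are the fields the model reads in __init__ above, but the config
# class itself lives in simple_mlp_configuration and is not shown here.
if __name__ == "__main__":
    config = SimpleMLPConfig(
        input_dim=32,
        hidden_dim=64,
        dropout_rate=0.1,
        num_classes=4,
    )
    model = SimpleMLPForClassification(config)

    features = torch.randn(8, 32)        # batch of 8 feature vectors
    labels = torch.randint(0, 4, (8,))   # one class index per example

    outputs = model(inputs_embeds=features, labels=labels)
    print(outputs.loss)          # scalar cross-entropy loss
    print(outputs.logits.shape)  # torch.Size([8, 4])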