import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from simple_mlp_configuration import SimpleMLPConfig


class SimpleMLPForClassification(PreTrainedModel):
    config_class = SimpleMLPConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_classes

        # A two-layer MLP head: project, apply a non-linearity,
        # regularize with dropout, then classify.
        self.fc1 = nn.Linear(config.input_dim, config.hidden_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(config.dropout_rate)
        self.fc2 = nn.Linear(config.hidden_dim, config.num_classes)

        # Initialize weights and apply any final processing
        # (standard PreTrainedModel hook).
        self.post_init()

    def forward(self, inputs_embeds, labels=None, return_dict=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        x = self.fc1(inputs_embeds)
        x = self.activation(x)
        x = self.dropout(x)
        logits = self.fc2(x)

        # Compute the loss only when labels are provided, so the same
        # forward pass serves both training and inference.
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )
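

# ---------------------------------------------------------------------------
# For reference: a minimal sketch of simple_mlp_configuration.py, the module
# imported above. The source does not show that file; only the four fields the
# model reads (input_dim, hidden_dim, dropout_rate, num_classes) come from the
# code above. The defaults and the "simple_mlp" model_type string are assumed
# here for illustration.
# ---------------------------------------------------------------------------
from transformers import PretrainedConfig


class SimpleMLPConfig(PretrainedConfig):
    model_type = "simple_mlp"  # assumed identifier; lets Auto* classes find the config

    def __init__(
        self,
        input_dim=768,      # assumed defaults, not from the source
        hidden_dim=256,
        dropout_rate=0.1,
        num_classes=2,
        **kwargs,
    ):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.dropout_rate = dropout_rate
        self.num_classes = num_classes
        super().__init__(**kwargs)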
|
|
|
|
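# ---------------------------------------------------------------------------
# A quick smoke test under the assumptions above: build a model from a fresh
# config, run a dummy batch, and round-trip it through save/load. The batch
# size and tensor values are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = SimpleMLPConfig(input_dim=768, hidden_dim=256, num_classes=2)
    model = SimpleMLPForClassification(config)
    model.eval()  # disable dropout for the check

    batch = torch.randn(4, config.input_dim)  # 4 pre-computed "embeddings"
    labels = torch.tensor([0, 1, 1, 0])

    with torch.no_grad():
        outputs = model(inputs_embeds=batch, labels=labels)

    print(outputs.logits.shape)  # torch.Size([4, 2])
    print(outputs.loss)          # scalar cross-entropy loss

    # Because the class follows the PreTrainedModel/PretrainedConfig contract,
    # the standard serialization round-trip works out of the box.
    model.save_pretrained("simple-mlp-demo")
    reloaded = SimpleMLPForClassification.from_pretrained("simple-mlp-demo")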