import numpy as np
import torch

from typing import Any
from transformers import AutoTokenizer


def splade_max(features, attention_mask):
    """
    SPLADE max pooling: applies log(1 + ReLU(logits)), masks out padding
    positions, and takes the max over the sequence dimension.
    """
    relu = torch.nn.ReLU(inplace=False)
    values, ids_ = torch.max(
        torch.log(1 + relu(features)) * attention_mask.unsqueeze(-1), dim=1
    )
    return values, ids_
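

# Illustrative example (assumed shapes, not from the original file): for MLM
# `features` of shape (batch, seq_len, vocab_size) and an `attention_mask` of
# shape (batch, seq_len), splade_max returns sparse activations of shape
# (batch, vocab_size):
#
#   feats = torch.randn(2, 5, 30522)
#   mask = torch.ones(2, 5)
#   values, ids_ = splade_max(feats, mask)  # values.shape == (2, 30522)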


def encode(
    self,
    sentences: list[str],
    max_length: int = 1024,
    prompt_type: str = "document",
    return_dict: bool = False,
    print_dict: bool = False,
    batch_size: int = 8,
    top_k_q: int = -1,
    top_k_d: int = -1,
    **kwargs: Any,
) -> np.ndarray:
    """
    Encodes sentences in batches into sparse SPLADE vectors, optionally
    keeping only the top-k activations per query/document.
    """
    all_embeddings = []
    for i in range(0, len(sentences), batch_size):
        batch_texts = sentences[i : i + batch_size]
        batch_dict = self.create_batch_dict(batch_texts, max_length)
        batch_dict = {
            key: value.to(self.model.device) for key, value in batch_dict.items()
        }
        with torch.no_grad():
            splade_reps = self(**batch_dict)[0]
            if prompt_type == "query" and top_k_q > 0:
                splade_reps = top_k(splade_reps, top_k_q)
            if prompt_type == "document" and top_k_d > 0:
                splade_reps = top_k(splade_reps, top_k_d)
            all_embeddings.append(splade_reps.cpu().float().numpy())
    if return_dict:
        d = bow_dict(self, np.concatenate(all_embeddings, axis=0))
        if print_dict:
            print_bow_bars(sentences, d)
        return d
    else:
        return np.concatenate(all_embeddings, axis=0)


def bow_dict(self, embeddings):
    """
    Converts sparse vectors into {token: weight} dicts sorted by weight,
    using the model's reverse vocabulary to map ids back to tokens.
    """
    out = []
    for vector in embeddings:
        idx = np.nonzero(vector)[0]
        weights = vector[idx]
        d = {k: v for k, v in zip(idx.tolist(), weights.tolist())}
        sorted_d = {
            self.reverse_voc[k]: float(v)
            for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)
        }
        out.append(sorted_d)
    return out


def print_bow_bars(sentences, bow_list, width=20):
    ascii_header("TOP ACTIVATED WORDS")
    for sent, bow in zip(sentences, bow_list):
        print(f"* INPUT: {sent}\n")
        max_w = max(bow.values(), default=1.0)
        for k, v in sorted(bow.items(), key=lambda x: x[1], reverse=True):
            bar = "█" * int(v / max_w * width)
            print(f"{k[:25]:25} | {bar} {v:.2f}")
        print("\n")


def ascii_header(title, width=70):
    title = f" {title} "
    print("+" + "-" * (width - 2) + "+")
    print("|" + title.center(width - 2) + "|")
    print("+" + "-" * (width - 2) + "+")
    print("\n")


def similarity(self, a, b) -> torch.Tensor:
    """
    Dot-product similarity between sparse vectors (required by MTEB eval).
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)
    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)

    def _dot_score_core(a_tensor, b_tensor):
        if len(a_tensor.shape) == 1:
            a_tensor = a_tensor.unsqueeze(0)
        if len(b_tensor.shape) == 1:
            b_tensor = b_tensor.unsqueeze(0)
        return a_tensor @ b_tensor.transpose(0, 1)

    return _dot_score_core(a, b)


def prepare_tokenizer(tokenizer_name: str, padding_side="right"):
    """
    Loads the tokenizer and sets a pad token and padding side.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenizer.pad_token = (
        tokenizer.bos_token or tokenizer.pad_token or tokenizer.eos_token
    )
    tokenizer.padding_side = padding_side
    return tokenizer


def top_k(x: torch.Tensor, k: int) -> torch.Tensor:
    """
    Zeroes out all but the top-k values in the last dimension of x.
    """
    _, topk_indices = x.topk(k, dim=-1)
    # create a zero tensor of the same shape as x
    mask = torch.zeros_like(x, dtype=torch.bool)
    # use scatter along the last dimension
    mask.scatter_(-1, topk_indices, True)
    # zero out all but the top-k
    return x * mask
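

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original file).
# The helpers above expect to be bound to a model wrapper exposing
# `create_batch_dict`, `reverse_voc`, `model`, and a callable forward pass
# that returns SPLADE representations as its first output. `SpladeWrapper`
# below is a hypothetical minimal wrapper around a masked-LM backbone, shown
# only to demonstrate how `encode`, `similarity`, and `prepare_tokenizer`
# fit together; adapt it to the actual model class this file accompanies.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForMaskedLM

    class SpladeWrapper:
        # Bind the module-level helpers as methods (hypothetical wiring).
        encode = encode
        similarity = similarity

        def __init__(self, model_name: str):
            self.tokenizer = prepare_tokenizer(model_name)
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.model.eval()
            # Map token ids back to token strings for bow_dict().
            self.reverse_voc = {
                v: k for k, v in self.tokenizer.get_vocab().items()
            }

        def create_batch_dict(self, texts, max_length):
            return self.tokenizer(
                texts,
                max_length=max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )

        def __call__(self, **batch_dict):
            # Forward pass: MLM logits pooled with splade_max, matching the
            # `self(**batch_dict)[0]` call inside encode().
            logits = self.model(**batch_dict).logits
            return splade_max(logits, batch_dict["attention_mask"])

    # Example usage (requires a SPLADE-style masked-LM checkpoint):
    # wrapper = SpladeWrapper("<checkpoint-name>")
    # docs = wrapper.encode(["a cat sits on the mat"], prompt_type="document")
    # queries = wrapper.encode(["where is the cat"], prompt_type="query", top_k_q=64)
    # print(wrapper.similarity(queries, docs))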