Instructions to use zeroentropy/zerank-2-reranker with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use zeroentropy/zerank-2-reranker with sentence-transformers:
```python
from sentence_transformers import CrossEncoder

model = CrossEncoder("zeroentropy/zerank-2-reranker")

query = "Which planet is known as the Red Planet?"
passages = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]

scores = model.predict([(query, passage) for passage in passages])
print(scores)
```

A short sketch of turning these scores into a ranked list follows the notebook links below.
- Notebooks
- Google Colab
- Kaggle
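
`predict` returns one relevance score per (query, passage) pair, in the order the pairs were given. Below is a minimal sketch, reusing the `passages` and `scores` variables from the example above, of sorting passages by score to obtain the reranked order; newer sentence-transformers releases also expose `CrossEncoder.rank()` as a convenience for the same thing.

```python
# Minimal sketch, assuming `passages` and `scores` from the example above.
# Pair each passage with its score and sort descending to get the reranked order.
ranked = sorted(zip(passages, scores), key=lambda pair: pair[1], reverse=True)
for passage, score in ranked:
    print(f"{score:.4f}  {passage}")
```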
Use torch.inference_mode() and disable gradient checkpointing
#4 opened by prathamj31
- config.json +4 -1
- modeling_zeranker.py +33 -16
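
The discussion title summarizes the change: wrap the forward pass in `torch.inference_mode()` and turn gradient checkpointing off, since neither autograd state nor checkpointed activations are needed when the model only scores pairs. The snippet below is a minimal sketch of that pattern, not the repository code itself; `score_batch`, `model`, and `batch_inputs` are hypothetical stand-ins for the loaded reranker and a tokenized batch (the actual edit lives in modeling_zeranker.py, shown in the diff below).

```python
import torch

def score_batch(model, batch_inputs):
    # Hypothetical helper illustrating the PR's pattern: run the forward pass
    # under inference_mode so no autograd graph is built, and retry once after
    # an OOM with the CUDA cache cleared.
    try:
        with torch.inference_mode():
            outputs = model(**batch_inputs, use_cache=False)
    except torch.OutOfMemoryError:
        torch.cuda.empty_cache()
        outputs = model(**batch_inputs, use_cache=False)
    return outputs
```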
config.json
CHANGED
```diff
@@ -64,5 +64,8 @@
   "transformers_version": "4.57.1",
   "use_cache": true,
   "use_sliding_window": false,
-  "vocab_size": 151936
+  "vocab_size": 151936,
+  "auto_map": {
+    "AutoConfig": "modeling_zeranker.ZEConfig"
+  }
 }
```
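
The added `auto_map` entry tells transformers that `AutoConfig` for this repository should resolve to the `ZEConfig` class defined in modeling_zeranker.py (an alias of `Qwen3Config`, per the diff below). Resolving it means executing code from the repo, which requires `trust_remote_code=True`; importing the module is also what applies its `CrossEncoder` patches. A minimal sketch, assuming the config is loaded directly from this repo:

```python
from transformers import AutoConfig

# trust_remote_code=True lets transformers import modeling_zeranker.py from the
# repo so the auto_map entry can resolve to ZEConfig (an alias of Qwen3Config).
config = AutoConfig.from_pretrained(
    "zeroentropy/zerank-2-reranker",
    trust_remote_code=True,
)
print(type(config).__name__)
```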
modeling_zeranker.py
CHANGED
```diff
@@ -1,9 +1,8 @@
 from sentence_transformers import CrossEncoder as _CE
 
 import math
+import logging
 from typing import cast, Any
-import types
-
 
 import torch
 from transformers.configuration_utils import PretrainedConfig
@@ -23,8 +22,10 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 # pyright: reportUnknownMemberType=false
 # pyright: reportUnknownVariableType=false
 
+logger = logging.getLogger(__name__)
+
 MODEL_PATH = "zeroentropy/zerank-2"
-PER_DEVICE_BATCH_SIZE_TOKENS =
+PER_DEVICE_BATCH_SIZE_TOKENS = 10_000
 global_device = (
     torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 )
@@ -74,9 +75,12 @@ def load_model(
     if device is None:
         device = global_device
 
+    logger.info(f"Loading model from {MODEL_PATH} on device: {device}")
+
     config = AutoConfig.from_pretrained(MODEL_PATH)
     assert isinstance(config, PretrainedConfig)
 
+    logger.info(f"Loading model with config type: {config.model_type}")
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_PATH,
         torch_dtype="auto",
@@ -93,6 +97,7 @@ def load_model(
         | Qwen3ForCausalLM,
     )
 
+    logger.info("Loading tokenizer")
     tokenizer = cast(
         AutoTokenizer,
         AutoTokenizer.from_pretrained(
@@ -105,6 +110,7 @@ def load_model(
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
+    logger.info("Model and tokenizer loaded successfully")
     return tokenizer, model
 
 
@@ -125,13 +131,7 @@ def predict(
            raise ValueError("query_documents or sentences must be provided")
        query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
 
-
-    self.inner_tokenizer, self.inner_model = load_model(global_device)
-    self.inner_model.gradient_checkpointing_enable()
-    self.inner_model.eval()
-    self.inner_yes_token_id = self.inner_tokenizer.encode(
-        "Yes", add_special_tokens=False
-    )[0]
+    logger.info(f"Starting prediction for {len(query_documents)} query-document pairs")
 
     model = self.inner_model
     tokenizer = self.inner_tokenizer
@@ -161,9 +161,12 @@ def predict(
        batches[-1].append((query, document))
        max_length = max(max_length, 20 + len(query) + len(document))
 
+    logger.info(f"Created {len(batches)} batches for inference")
+
     # Inference all of the document batches
     all_logits: list[float] = []
-    for batch in batches:
+    for batch_idx, batch in enumerate(batches):
+        logger.debug(f"Processing batch {batch_idx + 1}/{len(batches)} with {len(batch)} pairs")
        batch_inputs = format_pointwise_datapoints(
            tokenizer,
            batch,
@@ -172,11 +175,12 @@ def predict(
        batch_inputs = batch_inputs.to(global_device)
 
        try:
-
+            with torch.inference_mode():
+                outputs = model(**batch_inputs, use_cache=False)
        except torch.OutOfMemoryError:
-
+            logger.warning(f"GPU OOM! Memory reserved: {torch.cuda.memory_reserved()}")
            torch.cuda.empty_cache()
-
+            logger.info(f"GPU cache cleared. Memory reserved: {torch.cuda.memory_reserved()}")
            outputs = model(**batch_inputs, use_cache=False)
 
        # Extract the logits
@@ -199,18 +203,31 @@ def predict(
    # Unsort by indices
    scores = [score for _, score in sorted(zip(permutation, scores, strict=True))]
 
+    logger.info(f"Prediction complete. Generated {len(scores)} scores")
    return scores
 
 
 def to_device(self: _CE, new_device: torch.device) -> None:
    global global_device
+    logger.info(f"Changing device from {global_device} to {new_device}")
    global_device = new_device
 
+    # Load the model now since __init__ patching doesn't work due to timing
+    # (CrossEncoder instance is created before this module is loaded)
+    if not hasattr(self, "inner_model"):
+        logger.info("Loading model during device setup (eager loading)")
+        self.inner_tokenizer, self.inner_model = load_model(global_device)
+        self.inner_model.eval()
+        self.inner_model.gradient_checkpointing_disable()
+        self.inner_yes_token_id = self.inner_tokenizer.encode(
+            "Yes", add_special_tokens=False
+        )[0]
+        logger.info(f"Model loaded successfully. Yes token ID: {self.inner_yes_token_id}")
+
 
 _CE.predict = predict
+_CE.to = to_device
 
 from transformers import Qwen3Config
 
 ZEConfig = Qwen3Config
-
-_CE.to = to_device
```
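
Two design choices in the patch are worth calling out. First, gradient checkpointing only pays off when activations must be recomputed for a backward pass; scoring never runs a backward pass, so the patch swaps `gradient_checkpointing_enable()` for `gradient_checkpointing_disable()` and relies on `torch.inference_mode()` to skip autograd bookkeeping entirely. Second, because the `CrossEncoder` instance is constructed before this remote module is imported, the inner model is loaded lazily the first time the patched `to()` runs rather than in `__init__`. A minimal usage sketch under those assumptions; the `trust_remote_code` flag and the explicit `to()` call are assumptions here, not taken from the model card:

```python
import torch
from sentence_transformers import CrossEncoder

# Assumes CrossEncoder forwards trust_remote_code to transformers so that
# modeling_zeranker.py is imported and its predict/to patches are applied.
model = CrossEncoder("zeroentropy/zerank-2-reranker", trust_remote_code=True)

# The patched to() performs the eager load guarded by hasattr(self, "inner_model"),
# so the inner Qwen3 model, tokenizer, and "Yes" token id exist before predict().
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

scores = model.predict([("Which planet is the Red Planet?", "Mars is the Red Planet.")])
print(scores)
```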