import gzip
import pickle
from collections import Counter

import faiss
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
|
|
| |
# File the FAISS index is read from at import time.
INDEX_PATH = "faiss.index"
# Gzip-compressed pickle providing the "texts" and "statuses" metadata.
META_PATH = "metadata.pkl.gz"
# Sentence-Transformers model used to embed incoming text.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Default maximum number of characters per chunk when splitting input text.
CHUNK_SIZE = 2000
|
|
| |
# Vector index that /predict searches for nearest neighbours.
index = faiss.read_index(INDEX_PATH)

# SECURITY: unpickling can execute arbitrary code embedded in the file, so
# META_PATH must only ever point at a trusted, locally produced artifact.
with gzip.open(META_PATH, "rb") as f:
    meta = pickle.load(f)

# Lookups addressed by the ids returned from the index search.
texts, statuses = meta["texts"], meta["statuses"]
|
|
| |
# Embedding model, loaded once at import time and reused by every request.
model = SentenceTransformer(MODEL_NAME)

# ASGI application object that the route decorators attach to.
app = FastAPI(title="Text Embedding Predictor")
|
|
| |
class Query(BaseModel):
    """Request body accepted by the ``/predict`` endpoint."""

    # Raw text to classify.
    text: str
    # Number of nearest neighbours retrieved for each chunk of the text.
    k: int = 5
|
|
| |
def split_text(text: str, chunk_size: int = CHUNK_SIZE) -> list:
    """Split *text* into consecutive chunks of at most *chunk_size* characters.

    Args:
        text: The string to split.
        chunk_size: Maximum length of each chunk; must be positive.
            Defaults to ``CHUNK_SIZE``.

    Returns:
        A list of substrings which, concatenated in order, reproduce *text*.
        Every chunk except possibly the last is exactly *chunk_size*
        characters long. An empty *text* yields an empty list.

    Raises:
        ValueError: If *chunk_size* is not positive. Without this check a
            negative size would silently return no chunks at all, and zero
            would surface as an opaque ``range()`` error.
    """
    if chunk_size <= 0:
        raise ValueError(
            f"chunk_size must be a positive integer, got {chunk_size!r}"
        )
    return [
        text[start:start + chunk_size]
        for start in range(0, len(text), chunk_size)
    ]
|
|
| |
@app.post("/predict")
def predict(query: Query):
    """Predict a status for the submitted text by nearest-neighbour voting.

    The text is split into chunks, every chunk is embedded and searched
    against the FAISS index, and the statuses of all retrieved neighbours
    are pooled into one majority vote.

    Args:
        query: Request body with the text to classify and ``k``, the number
            of neighbours to retrieve per chunk.

    Returns:
        A dict with three keys:

        * ``"prediction"``: the most common status among all neighbours.
        * ``"votes"``: mapping of status to its count over all neighbours.
        * ``"top_k"``: up to ``k`` neighbour records (``rank``, ``text``,
          ``status``, ``distance``), the best matches across all chunks,
          best first.

    Raises:
        HTTPException: 422 if ``k`` is not positive or the text is empty;
            404 if the index returned no neighbours at all.
    """
    if query.k < 1:
        raise HTTPException(
            status_code=422, detail="k must be a positive integer."
        )

    text_chunks = split_text(query.text)
    if not text_chunks:
        # Without this guard the vote below would index into an empty list.
        raise HTTPException(status_code=422, detail="text must not be empty.")

    # NOTE(review): backslashes are doubled before embedding; presumably this
    # mirrors how the indexed texts were preprocessed -- confirm.
    escaped_chunks = [chunk.replace("\\", "\\\\") for chunk in text_chunks]

    # Embed and search all chunks in one batch instead of one call per chunk.
    embeddings = model.encode(escaped_chunks).astype("float32")
    distances, indices = index.search(embeddings, query.k)

    all_results = []
    for chunk_distances, chunk_indices in zip(distances, indices):
        for distance, idx in zip(chunk_distances, chunk_indices):
            # FAISS pads with -1 when fewer than k neighbours exist; using it
            # as an index would silently pick the last metadata record.
            if idx < 0:
                continue
            idx = int(idx)
            all_results.append({
                "text": texts[idx],
                "status": statuses[idx],
                "distance": float(distance),
            })

    if not all_results:
        raise HTTPException(
            status_code=404, detail="No neighbours found in the index."
        )

    # Count before sorting so ties keep their original first-seen order.
    votes = Counter(result["status"] for result in all_results)
    prediction = votes.most_common(1)[0][0]

    # Inner-product indexes report similarities (higher is better); other
    # metrics report distances (lower is better). The sort is stable, so a
    # single-chunk request keeps the order FAISS returned.
    higher_is_better = index.metric_type == faiss.METRIC_INNER_PRODUCT
    all_results.sort(
        key=lambda result: result["distance"], reverse=higher_is_better
    )
    top_k = [
        {"rank": rank, **result}
        for rank, result in enumerate(all_results[:query.k], start=1)
    ]

    return {
        "prediction": prediction,
        "votes": dict(votes),
        "top_k": top_k,
    }
|
|
| |
if __name__ == "__main__":
    # Serve on all interfaces, port 7860, with auto-reload enabled.
    # NOTE(review): the "app:app" import string presumes this module is
    # saved as app.py -- confirm.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        reload=True,
    )
|
|