| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| use pyo3::prelude::*; |
| use pyo3::exceptions::{PyValueError, PyIOError}; |
|
|
| use crate::core::{Id, Point}; |
| use crate::adapters::index::{HatIndex as RustHatIndex, HatConfig, ConsolidationConfig, Consolidate}; |
| use crate::ports::Near; |
|
|
| |
| #[pyclass(name = "SearchResult")] |
| #[derive(Clone)] |
| pub struct PySearchResult { |
| |
| #[pyo3(get)] |
| pub id: String, |
|
|
| |
| #[pyo3(get)] |
| pub score: f32, |
| } |
|
|
| #[pymethods] |
| impl PySearchResult { |
| fn __repr__(&self) -> String { |
| format!("SearchResult(id='{}', score={:.4})", self.id, self.score) |
| } |
|
|
| fn __str__(&self) -> String { |
| format!("{}: {:.4}", self.id, self.score) |
| } |
| } |
|
|
| |
| #[pyclass(name = "HatConfig")] |
| #[derive(Clone)] |
| pub struct PyHatConfig { |
| inner: HatConfig, |
| } |
|
|
| #[pymethods] |
| impl PyHatConfig { |
| #[new] |
| fn new() -> Self { |
| Self { inner: HatConfig::default() } |
| } |
|
|
| |
| fn with_beam_width(mut slf: PyRefMut<'_, Self>, width: usize) -> PyRefMut<'_, Self> { |
| slf.inner.beam_width = width; |
| slf |
| } |
|
|
| |
| fn with_temporal_weight(mut slf: PyRefMut<'_, Self>, weight: f32) -> PyRefMut<'_, Self> { |
| slf.inner.temporal_weight = weight; |
| slf |
| } |
|
|
| |
| fn with_propagation_threshold(mut slf: PyRefMut<'_, Self>, threshold: f32) -> PyRefMut<'_, Self> { |
| slf.inner.propagation_threshold = threshold; |
| slf |
| } |
|
|
| fn __repr__(&self) -> String { |
| format!( |
| "HatConfig(beam_width={}, temporal_weight={:.2}, propagation_threshold={:.3})", |
| self.inner.beam_width, self.inner.temporal_weight, self.inner.propagation_threshold |
| ) |
| } |
| } |
|
|
| |
| #[pyclass(name = "SessionSummary")] |
| #[derive(Clone)] |
| pub struct PySessionSummary { |
| #[pyo3(get)] |
| pub id: String, |
|
|
| #[pyo3(get)] |
| pub score: f32, |
|
|
| #[pyo3(get)] |
| pub chunk_count: usize, |
|
|
| #[pyo3(get)] |
| pub timestamp_ms: u64, |
| } |
|
|
| #[pymethods] |
| impl PySessionSummary { |
| fn __repr__(&self) -> String { |
| format!( |
| "SessionSummary(id='{}', score={:.4}, chunks={})", |
| self.id, self.score, self.chunk_count |
| ) |
| } |
| } |
|
|
| |
| #[pyclass(name = "DocumentSummary")] |
| #[derive(Clone)] |
| pub struct PyDocumentSummary { |
| #[pyo3(get)] |
| pub id: String, |
|
|
| #[pyo3(get)] |
| pub score: f32, |
|
|
| #[pyo3(get)] |
| pub chunk_count: usize, |
| } |
|
|
| #[pymethods] |
| impl PyDocumentSummary { |
| fn __repr__(&self) -> String { |
| format!( |
| "DocumentSummary(id='{}', score={:.4}, chunks={})", |
| self.id, self.score, self.chunk_count |
| ) |
| } |
| } |
|
|
| |
| #[pyclass(name = "HatStats")] |
| #[derive(Clone)] |
| pub struct PyHatStats { |
| #[pyo3(get)] |
| pub global_count: usize, |
|
|
| #[pyo3(get)] |
| pub session_count: usize, |
|
|
| #[pyo3(get)] |
| pub document_count: usize, |
|
|
| #[pyo3(get)] |
| pub chunk_count: usize, |
| } |
|
|
| #[pymethods] |
| impl PyHatStats { |
| |
| #[getter] |
| fn total_points(&self) -> usize { |
| self.chunk_count |
| } |
|
|
| fn __repr__(&self) -> String { |
| format!( |
| "HatStats(points={}, sessions={}, documents={}, chunks={})", |
| self.chunk_count, self.session_count, self.document_count, self.chunk_count |
| ) |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| #[pyclass(name = "HatIndex")] |
| pub struct PyHatIndex { |
| inner: RustHatIndex, |
| } |
|
|
| #[pymethods] |
| impl PyHatIndex { |
| |
| |
| |
| |
| #[staticmethod] |
| fn cosine(dimensionality: usize) -> Self { |
| Self { |
| inner: RustHatIndex::cosine(dimensionality), |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| #[staticmethod] |
| fn with_config(dimensionality: usize, config: &PyHatConfig) -> Self { |
| Self { |
| inner: RustHatIndex::cosine(dimensionality).with_config(config.inner.clone()), |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| fn add(&mut self, embedding: Vec<f32>) -> PyResult<String> { |
| let point = Point::new(embedding); |
| let id = Id::now(); |
|
|
| self.inner.add(id, &point) |
| .map_err(|e| PyValueError::new_err(format!("{}", e)))?; |
|
|
| Ok(format!("{}", id)) |
| } |
|
|
| |
| |
| |
| |
| |
| fn add_with_id(&mut self, id_hex: &str, embedding: Vec<f32>) -> PyResult<()> { |
| let id = parse_id_hex(id_hex)?; |
| let point = Point::new(embedding); |
|
|
| self.inner.add(id, &point) |
| .map_err(|e| PyValueError::new_err(format!("{}", e)))?; |
|
|
| Ok(()) |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| fn near(&self, query: Vec<f32>, k: usize) -> PyResult<Vec<PySearchResult>> { |
| let point = Point::new(query); |
|
|
| let results = self.inner.near(&point, k) |
| .map_err(|e| PyValueError::new_err(format!("{}", e)))?; |
|
|
| Ok(results.into_iter().map(|r| PySearchResult { |
| id: format!("{}", r.id), |
| score: r.score, |
| }).collect()) |
| } |
|
|
| |
| |
| |
| fn new_session(&mut self) { |
| self.inner.new_session(); |
| } |
|
|
| |
| |
| |
| |
| fn new_document(&mut self) { |
| self.inner.new_document(); |
| } |
|
|
| |
| fn stats(&self) -> PyHatStats { |
| let s = self.inner.stats(); |
| PyHatStats { |
| global_count: s.global_count, |
| session_count: s.session_count, |
| document_count: s.document_count, |
| chunk_count: s.chunk_count, |
| } |
| } |
|
|
| |
| fn __len__(&self) -> usize { |
| self.inner.len() |
| } |
|
|
| |
| fn is_empty(&self) -> bool { |
| self.inner.is_empty() |
| } |
|
|
| |
| |
| |
| |
| fn remove(&mut self, id_hex: &str) -> PyResult<()> { |
| let id = parse_id_hex(id_hex)?; |
|
|
| self.inner.remove(id) |
| .map_err(|e| PyValueError::new_err(format!("{}", e)))?; |
|
|
| Ok(()) |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| fn near_sessions(&self, query: Vec<f32>, k: usize) -> PyResult<Vec<PySessionSummary>> { |
| let point = Point::new(query); |
|
|
| let results = self.inner.near_sessions(&point, k) |
| .map_err(|e| PyValueError::new_err(format!("{}", e)))?; |
|
|
| Ok(results.into_iter().map(|s| PySessionSummary { |
| id: format!("{}", s.id), |
| score: s.score, |
| chunk_count: s.chunk_count, |
| timestamp_ms: s.timestamp, |
| }).collect()) |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| fn near_documents(&self, session_id: &str, query: Vec<f32>, k: usize) -> PyResult<Vec<PyDocumentSummary>> { |
| let sid = parse_id_hex(session_id)?; |
| let point = Point::new(query); |
|
|
| let results = self.inner.near_documents(sid, &point, k) |
| .map_err(|e| PyValueError::new_err(format!("{}", e)))?; |
|
|
| Ok(results.into_iter().map(|d| PyDocumentSummary { |
| id: format!("{}", d.id), |
| score: d.score, |
| chunk_count: d.chunk_count, |
| }).collect()) |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| fn near_in_document(&self, doc_id: &str, query: Vec<f32>, k: usize) -> PyResult<Vec<PySearchResult>> { |
| let did = parse_id_hex(doc_id)?; |
| let point = Point::new(query); |
|
|
| let results = self.inner.near_in_document(did, &point, k) |
| .map_err(|e| PyValueError::new_err(format!("{}", e)))?; |
|
|
| Ok(results.into_iter().map(|r| PySearchResult { |
| id: format!("{}", r.id), |
| score: r.score, |
| }).collect()) |
| } |
|
|
| |
| |
| |
| |
| fn consolidate(&mut self) { |
| self.inner.consolidate(ConsolidationConfig::light()); |
| } |
|
|
| |
| fn consolidate_full(&mut self) { |
| self.inner.consolidate(ConsolidationConfig::full()); |
| } |
|
|
| |
| |
| |
| |
| fn save(&self, path: &str) -> PyResult<()> { |
| self.inner.save_to_file(std::path::Path::new(path)) |
| .map_err(|e| PyIOError::new_err(format!("{}", e))) |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| #[staticmethod] |
| fn load(path: &str) -> PyResult<Self> { |
| let inner = RustHatIndex::load_from_file(std::path::Path::new(path)) |
| .map_err(|e| PyIOError::new_err(format!("{}", e)))?; |
|
|
| Ok(Self { inner }) |
| } |
|
|
| |
| |
| |
| |
| fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, pyo3::types::PyBytes>> { |
| let data = self.inner.to_bytes() |
| .map_err(|e| PyIOError::new_err(format!("{}", e)))?; |
| Ok(pyo3::types::PyBytes::new_bound(py, &data)) |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| #[staticmethod] |
| fn from_bytes(data: &[u8]) -> PyResult<Self> { |
| let inner = RustHatIndex::from_bytes(data) |
| .map_err(|e| PyIOError::new_err(format!("{}", e)))?; |
|
|
| Ok(Self { inner }) |
| } |
|
|
| fn __repr__(&self) -> String { |
| let stats = self.inner.stats(); |
| format!( |
| "HatIndex(points={}, sessions={})", |
| stats.chunk_count, stats.session_count |
| ) |
| } |
| } |
|
|
| |
| fn parse_id_hex(hex: &str) -> PyResult<Id> { |
| if hex.len() != 32 { |
| return Err(PyValueError::new_err( |
| format!("ID must be 32 hex characters, got {}", hex.len()) |
| )); |
| } |
|
|
| let mut bytes = [0u8; 16]; |
| for (i, chunk) in hex.as_bytes().chunks(2).enumerate() { |
| let high = hex_char_to_nibble(chunk[0])?; |
| let low = hex_char_to_nibble(chunk[1])?; |
| bytes[i] = (high << 4) | low; |
| } |
|
|
| Ok(Id::from_bytes(bytes)) |
| } |
|
|
| fn hex_char_to_nibble(c: u8) -> PyResult<u8> { |
| match c { |
| b'0'..=b'9' => Ok(c - b'0'), |
| b'a'..=b'f' => Ok(c - b'a' + 10), |
| b'A'..=b'F' => Ok(c - b'A' + 10), |
| _ => Err(PyValueError::new_err(format!("Invalid hex character: {}", c as char))), |
| } |
| } |
|
|
| |
| #[pymodule] |
| fn arms_hat(m: &Bound<'_, PyModule>) -> PyResult<()> { |
| m.add_class::<PyHatIndex>()?; |
| m.add_class::<PyHatConfig>()?; |
| m.add_class::<PySearchResult>()?; |
| m.add_class::<PySessionSummary>()?; |
| m.add_class::<PyDocumentSummary>()?; |
| m.add_class::<PyHatStats>()?; |
|
|
| |
| m.add("__doc__", "ARMS-HAT: Hierarchical Attention Tree for AI memory retrieval")?; |
| m.add("__version__", env!("CARGO_PKG_VERSION"))?; |
|
|
| Ok(()) |
| } |
|
|