| """Tests for ARMS-HAT Python bindings.""" |
|
|
| import pytest |
| import tempfile |
| import os |
|
|
|
|
def test_import():
    """Check that the public classes can be imported from ``arms_hat``.

    Beyond mere importability, each exported name must be a class (a type
    object), so a re-export accidentally bound to ``None`` or to a plain
    function is caught here rather than in a later, less obvious test.
    """
    from arms_hat import HatIndex, HatConfig, SearchResult

    for exported in (HatIndex, HatConfig, SearchResult):
        assert isinstance(exported, type), f"{exported!r} is not a class"
|
|
|
|
def test_create_index():
    """A freshly constructed cosine index reports that it holds no points."""
    from arms_hat import HatIndex

    dimensions = 128
    empty_index = HatIndex.cosine(dimensions)

    # Both the length protocol and the explicit predicate must agree.
    assert len(empty_index) == 0
    assert empty_index.is_empty()
|
|
|
|
def test_add_and_query():
    """Insert ten sparse vectors, then check nearest-neighbour retrieval."""
    from arms_hat import HatIndex

    dims = 64
    index = HatIndex.cosine(dims)

    def two_hot(slot):
        # Weight 1.0 at ``slot`` and 0.5 at the following slot (wrapping).
        vec = [0.0] * dims
        vec[slot % dims] = 1.0
        vec[(slot + 1) % dims] = 0.5
        return vec

    inserted_ids = []
    for slot in range(10):
        new_id = index.add(two_hot(slot))
        inserted_ids.append(new_id)
        # Returned IDs are expected to be 32 long (presumably a hex string
        # form of a 128-bit id -- TODO confirm against the bindings).
        assert len(new_id) == 32

    assert len(index) == 10
    assert not index.is_empty()

    # The query is exactly the first inserted vector.
    hits = index.near(two_hot(0), k=5)
    assert len(hits) == 5

    best = hits[0]
    assert best.id == inserted_ids[0]
    assert best.score > 0.9
|
|
|
|
def test_sessions():
    """Points added on both sides of ``new_session()`` are all counted."""
    from arms_hat import HatIndex

    width = 32
    index = HatIndex.cosine(width)

    def one_hot(pos):
        # 1.0 at ``pos`` (mod width), 0.0 elsewhere.
        return [float(pos % width == k) for k in range(width)]

    # First session: one-hot vectors at positions 0..4.
    for pos in range(5):
        index.add(one_hot(pos))

    index.new_session()

    # Second session: one-hot vectors at positions 10..14.
    for pos in range(10, 15):
        index.add(one_hot(pos))

    stats = index.stats()
    assert stats.session_count >= 1
    assert stats.chunk_count == 10
|
|
|
|
def test_documents():
    """Points added on both sides of ``new_document()`` are all counted."""
    from arms_hat import HatIndex

    index = HatIndex.cosine(32)

    def unit(axis):
        # Standard basis vector with a single 1.0 on ``axis``.
        vec = [0.0] * 32
        vec[axis] = 1.0
        return vec

    # First document: axes 0..2.
    for axis in range(3):
        index.add(unit(axis))

    index.new_document()

    # Second document: axes 10..12.
    for axis in range(10, 13):
        index.add(unit(axis))

    stats = index.stats()
    assert stats.document_count >= 1
    assert stats.chunk_count == 6
|
|
|
|
def test_persistence_bytes():
    """Round-trip an index through ``to_bytes``/``from_bytes``.

    The restored index must hold the same number of points and answer a
    query the same way: the same top hit (the point whose vector equals the
    query) and the same scores.
    """
    import math

    from arms_hat import HatIndex

    dims = 64
    index = HatIndex.cosine(dims)

    # Twenty vectors: a 0.1 floor with a single 1.0 spike at position i.
    ids = []
    for i in range(20):
        embedding = [0.1] * dims
        embedding[i % dims] = 1.0
        ids.append(index.add(embedding))

    data = index.to_bytes()
    assert len(data) > 0

    loaded = HatIndex.from_bytes(data)
    assert len(loaded) == len(index)

    # Identical to the vector stored under ids[0], so that point should rank first.
    query = [0.1] * dims
    query[0] = 1.0

    original_results = index.near(query, k=5)
    loaded_results = loaded.near(query, k=5)

    assert len(original_results) == len(loaded_results)
    assert original_results[0].id == loaded_results[0].id == ids[0]

    # Every non-matching vector has the same similarity to the query, so the
    # IDs ranked 2..k may legitimately differ by tie-breaking. Compare scores
    # instead, with a small tolerance in case of reduced-precision storage.
    for before, after in zip(original_results, loaded_results):
        assert math.isclose(before.score, after.score, rel_tol=1e-6, abs_tol=1e-6)
|
|
|
|
def test_persistence_file():
    """Round-trip an index through ``save``/``load`` on disk.

    The index is written to a not-yet-existing path inside a fresh temporary
    directory. The ``os.path.exists`` check therefore actually proves that
    ``save`` created the file; a pre-created ``NamedTemporaryFile`` would
    exist regardless. The directory is also removed afterwards, even when
    an assertion fails.
    """
    from arms_hat import HatIndex

    dims = 64
    index = HatIndex.cosine(dims)

    # Ten vectors: a 0.1 floor with a single 1.0 spike at position i.
    for i in range(10):
        embedding = [0.1] * dims
        embedding[i % dims] = 1.0
        index.add(embedding)

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "index.hat")

        index.save(path)
        assert os.path.exists(path)
        assert os.path.getsize(path) > 0

        loaded = HatIndex.load(path)
        assert len(loaded) == len(index)

        # The reloaded index should answer a query with the same top hit.
        query = [0.1] * dims
        query[0] = 1.0
        assert loaded.near(query, k=1)[0].id == index.near(query, k=1)[0].id
|
|
|
|
def test_config():
    """An index can be built from a customised ``HatConfig``."""
    from arms_hat import HatIndex, HatConfig

    # Each builder call returns a config, so the calls can be chained.
    config = HatConfig().with_beam_width(5).with_temporal_weight(0.1)

    index = HatIndex.with_config(128, config)
    assert len(index) == 0
|
|
|
|
def test_remove():
    """Removing a point shrinks the index and hides the point from search."""
    from arms_hat import HatIndex

    first_axis = [1.0] + [0.0] * 31
    second_axis = [0.0, 1.0] + [0.0] * 30

    index = HatIndex.cosine(32)
    removed_id = index.add(first_axis)
    kept_id = index.add(second_axis)
    assert len(index) == 2

    index.remove(removed_id)
    assert len(index) == 1

    # Only the surviving point can come back, even though k exceeds the size.
    hits = index.near(second_axis, k=5)
    assert len(hits) == 1
    assert hits[0].id == kept_id
|
|
|
|
def test_consolidate():
    """Incremental and full consolidation both leave every point in place."""
    from arms_hat import HatIndex

    index = HatIndex.cosine(32)

    # 100 one-hot vectors cycling through the 32 axes.
    for n in range(100):
        index.add([1.0 if axis == n % 32 else 0.0 for axis in range(32)])

    index.consolidate()
    index.consolidate_full()

    assert len(index) == 100
|
|
|
|
def test_stats():
    """After ten adds, ``stats()`` reports ten chunks and ten points."""
    from arms_hat import HatIndex

    index = HatIndex.cosine(64)

    # One-hot vectors on axes 0..9.
    vectors = [[float(i % 64 == j) for j in range(64)] for i in range(10)]
    for vector in vectors:
        index.add(vector)

    stats = index.stats()
    assert stats.chunk_count == 10
    assert stats.total_points == 10
|
|
|
|
def test_repr():
    """``repr`` of the index and of the config mentions the class name."""
    from arms_hat import HatIndex, HatConfig

    index_repr = repr(HatIndex.cosine(64))
    assert "HatIndex" in index_repr

    config_repr = repr(HatConfig())
    assert "HatConfig" in config_repr
|
|
|
|
def test_near_sessions():
    """Coarse session-level search returns ranked sessions, best first.

    Two sessions are built around different dominant axes (0 and 10), and
    the query points along axis 0.
    """
    from arms_hat import HatIndex

    dims = 32
    k = 2
    index = HatIndex.cosine(dims)

    # Session 1: dominant axis 0 plus a weaker component on one of axes 1..5.
    for i in range(5):
        embedding = [0.0] * dims
        embedding[0] = 1.0
        embedding[i + 1] = 0.3
        index.add(embedding)

    index.new_session()

    # Session 2: dominant axis 10 plus a weaker component on one of axes 11..15.
    for i in range(5):
        embedding = [0.0] * dims
        embedding[10] = 1.0
        embedding[i + 11] = 0.3
        index.add(embedding)

    query = [0.0] * dims
    query[0] = 1.0

    sessions = index.near_sessions(query, k=k)
    assert 1 <= len(sessions) <= k

    # Scores must be non-increasing across the whole list, not only the
    # first pair, so the check stays meaningful if k is raised.
    scores = [session.score for session in sessions]
    assert scores == sorted(scores, reverse=True)
|
|
|
|
def test_high_dimensions():
    """Index and query 1536-dimensional vectors (OpenAI embedding size).

    Each vector scales its pattern by ``(i + 1)``. With a bare ``i``, the
    first vector (i == 0) would be all zeros, and cosine similarity is
    undefined for a zero vector.
    """
    import math

    from arms_hat import HatIndex

    dims = 1536
    index = HatIndex.cosine(dims)

    # Deterministic sawtooth patterns; every vector has non-zero entries.
    for i in range(10):
        scale = (i + 1) * 0.01
        embedding = [(j * scale) % 1.0 for j in range(dims)]
        index.add(embedding)

    assert len(index) == 10

    query = [0.5] * dims
    results = index.near(query, k=5)
    assert len(results) == 5
    # A degenerate (zero) vector would typically surface as a NaN score.
    assert all(math.isfinite(result.score) for result in results)
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status so a failing run is visible to the
    # shell / CI when this file is executed directly.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|