# python
# 14 hours, 17 minutes ago
# test_clustering_service.py
import unittest
from unittest.mock import Mock, patch
from clustering_service import ClusteringService
import pandas as pd
import numpy as np
class TestClusteringService(unittest.TestCase):
    """Unit tests for ``ClusteringService``.

    The embeddings utility passed to the service is a ``Mock``, so no real
    embedding service is contacted. Tests of ``get_embeddings_for_strings``
    patch the clustering helper they route through; the clustering helpers
    themselves are exercised directly on a tiny hand-written data set.
    """

    @staticmethod
    def _toy_embeddings():
        """Return a fresh copy of the small 2-D data set shared by the
        direct clustering tests (two nearby pairs of points)."""
        # A new list per call so one test can never affect another through
        # shared mutable state.
        return [[0.1, 0.2], [0.2, 0.3], [0.8, 0.9], [0.9, 1.0]]

    def setUp(self):
        # Mock embeddings utility to simulate external service; each test
        # sets the ``get_embeddings`` return value it needs.
        self.embeddings_util = Mock()
        self.clustering_service = ClusteringService(self.embeddings_util)

    def test_get_embeddings_for_strings_success(self):
        """Strings are grouped by the labels the DBSCAN helper returns."""
        original_strings = ["Test1", "Test2", "Test3", "Test4"]
        number_clusters = 2
        # Mocked embeddings returned by the (fake) embeddings utility.
        embeddings_df = pd.DataFrame({
            'Data': original_strings,
            'Embedding': [[0.1, 0.2], [0.1, 0.3], [0.4, 0.5], [0.5, 0.6]],
        })
        self.embeddings_util.get_embeddings.return_value = embeddings_df

        # Patch the DBSCAN helper so this test depends only on the grouping
        # logic, not on the clustering algorithm's output.
        with patch.object(self.clustering_service, 'get_dbscan_clusters',
                          return_value=[0, 0, 1, 1]):
            result = self.clustering_service.get_embeddings_for_strings(
                original_strings, number_clusters,
                cluster_method='db_scan', eps=0.5)

        expected_result = {
            0: ["Test1", "Test2"],
            1: ["Test3", "Test4"],
        }
        self.assertEqual(result, expected_result)

    def test_get_embeddings_for_strings_no_embeddings(self):
        """When no embeddings are retrieved the service returns None."""
        original_strings = ["Test1", "Test2"]
        number_clusters = 2
        self.embeddings_util.get_embeddings.return_value = None

        result = self.clustering_service.get_embeddings_for_strings(
            original_strings, number_clusters)

        self.assertIsNone(result)

    def test_get_spectral_clusters(self):
        """Spectral clustering labels every point and finds num_clusters groups."""
        embeddings = self._toy_embeddings()
        num_clusters = 2

        labels = self.clustering_service.get_spectral_clusters(
            embeddings, num_clusters)

        # One label per input embedding ...
        self.assertEqual(len(labels), len(embeddings))
        # ... and the number of unique labels equals num_clusters.
        self.assertEqual(len(set(labels)), num_clusters)

    def test_get_dbscan_clusters(self):
        """DBSCAN assigns an integer label to every point (-1 indicates noise)."""
        embeddings = self._toy_embeddings()
        eps = 0.5

        labels = self.clustering_service.get_dbscan_clusters(embeddings, eps)

        # Every embedding must have been assigned a label.
        self.assertEqual(len(labels), len(embeddings))
        # Accept NumPy integer scalars too: np.int64 and friends are not
        # subclasses of Python ``int``, and array-returning clustering code
        # yields them.
        for label in labels:
            self.assertIsInstance(label, (int, np.integer))

    def test_get_hierarchy_clusters(self):
        """Agglomerative clustering labels every point and finds num_clusters groups."""
        embeddings = self._toy_embeddings()
        num_clusters = 2

        labels = self.clustering_service.get_hierarchy_clusters(
            embeddings, num_clusters)

        # One label per input embedding ...
        self.assertEqual(len(labels), len(embeddings))
        # ... and the number of unique labels equals num_clusters.
        self.assertEqual(len(set(labels)), num_clusters)

    def test_ensure_data_type_conversion(self):
        """String-encoded embeddings are converted to lists of floats."""
        embedding_list = ["[0.1, 0.2]", "[0.3, 0.4]"]
        expected_result = [[0.1, 0.2], [0.3, 0.4]]

        result = self.clustering_service.ensure_data_type(embedding_list)

        self.assertEqual(result, expected_result)

    def test_ensure_data_type_no_conversion_needed(self):
        """Embeddings that are already lists are returned unchanged."""
        embedding_list = [[0.1, 0.2], [0.3, 0.4]]

        result = self.clustering_service.ensure_data_type(embedding_list)

        self.assertEqual(result, embedding_list)

    def test_get_embeddings_for_strings_with_failed_values(self):
        """Rows whose embedding is -1 are reported under PROCESS_AI_FAILED."""
        original_strings = ["Test1", "Test2", "Test3"]
        number_clusters = 2
        # -1 in the Embedding column marks an embedding that failed.
        embeddings_df = pd.DataFrame({
            'Data': ["Test1", "Test2", "Test3"],
            'Embedding': [[0.1, 0.2], [0.3, 0.4], -1],
        })
        self.embeddings_util.get_embeddings.return_value = embeddings_df

        # Only the two valid embeddings are expected to be clustered, so the
        # patched helper returns two labels.
        with patch.object(self.clustering_service, 'get_dbscan_clusters',
                          return_value=[0, 1]):
            result = self.clustering_service.get_embeddings_for_strings(
                original_strings, number_clusters,
                cluster_method='db_scan', eps=0.5)

        expected_result = {
            0: ["Test1"],
            1: ["Test2"],
            "PROCESS_AI_FAILED": ["Test3"],
        }
        self.assertEqual(result, expected_result)
# 0 Comments
# Please Login to Comment Here