# Language: python
# 5 hours, 19 minutes ago
import unittest
from unittest.mock import MagicMock, patch
from clustering_service import ClusteringService # Replace with the correct module name
import numpy as np
class TestClusteringService(unittest.TestCase):
    """Unit tests for ``ClusteringService``.

    The scikit-learn estimators referenced by ``clustering_service`` are
    patched out, and the GCP AI dependency is a ``MagicMock``, so no real
    clustering or network calls happen. The tests check which estimator is
    constructed, with which arguments, and how its labels are returned or
    used to group the input strings.
    """

    # Input strings shared by the ``get_embeddings`` tests. Stored as a tuple
    # so no test can mutate the shared value; each call gets a fresh list.
    _TEXTS = ("text1", "text2", "text3", "text4", "text5", "text6")

    def setUp(self):
        self.mock_gcp_ai_service = MagicMock()
        self.service = ClusteringService(self.mock_gcp_ai_service)
        # Seeded generator so any failure reproduces with identical inputs.
        self.rng = np.random.default_rng(0)

    @staticmethod
    def _make_estimator_mock(labels):
        """Return a stand-in for a fitted scikit-learn clustering estimator.

        ``labels_`` is set to *labels*. ``fit`` returns the mock itself and
        ``fit_predict`` returns *labels*, so the service may use
        ``est.fit(X); est.labels_``, ``Est(...).fit(X).labels_`` or
        ``est.fit_predict(X)`` and still see the canned labels.
        """
        estimator = MagicMock()
        estimator.labels_ = labels
        # scikit-learn's fit() returns self; without this a chained
        # ``.fit(X).labels_`` would read a fresh MagicMock instead.
        estimator.fit.return_value = estimator
        estimator.fit_predict.return_value = labels
        return estimator

    def _stub_embeddings(self):
        """Make the mocked GCP service return a (6, 5) array and return it."""
        embeddings = self.rng.random((6, 5))
        self.mock_gcp_ai_service.get_embeddings_single.return_value = embeddings
        return embeddings

    def test_get_spectral_clusters(self):
        embeddings = self.rng.random((10, 5))  # 10 samples with 5 dimensions
        num_clusters = 3
        expected = [0, 1, 2, 0, 1, 2, 0, 1, 2, 0]
        with patch('clustering_service.SpectralClustering') as mock_spectral:
            mock_spectral.return_value = self._make_estimator_mock(expected)
            labels = self.service.get_spectral_clusters(embeddings, num_clusters)
        mock_spectral.assert_called_once_with(
            n_clusters=num_clusters, assign_labels='discretize', random_state=0
        )
        # list() lets the service return either a list or an ndarray.
        self.assertEqual(list(labels), expected)

    def test_get_dbscan_clusters(self):
        embeddings = self.rng.random((10, 5))
        eps = 0.5
        expected = [0, 0, -1, 1, 1, -1, 2, 2, -1, 2]
        with patch('clustering_service.DBSCAN') as mock_dbscan:
            mock_dbscan.return_value = self._make_estimator_mock(expected)
            labels = self.service.get_dbscan_clusters(embeddings, eps)
        mock_dbscan.assert_called_once_with(eps=eps, min_samples=2, metric='cosine')
        self.assertEqual(list(labels), expected)

    def test_get_hierarchy_clusters(self):
        embeddings = self.rng.random((10, 5))
        num_clusters = 4
        expected = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1]
        with patch('clustering_service.AgglomerativeClustering') as mock_agglo:
            mock_agglo.return_value = self._make_estimator_mock(expected)
            labels = self.service.get_hierarchy_clusters(embeddings, num_clusters)
        mock_agglo.assert_called_once_with(compute_full_tree=True, n_clusters=num_clusters)
        self.assertEqual(list(labels), expected)

    # NOTE: assert_called_once_with compares args via tuple equality, which
    # short-circuits on identity; with ndarrays it only passes when the
    # service forwards the very same array object it got from the GCP mock.

    def test_get_embeddings_with_spectral_clustering(self):
        embeddings = self._stub_embeddings()
        num_clusters = 2
        with patch.object(
            self.service, 'get_spectral_clusters', return_value=[0, 1, 0, 1, 0, 1]
        ) as mock_method:
            result = self.service.get_embeddings(list(self._TEXTS), num_clusters, 'spectral')
        mock_method.assert_called_once_with(embeddings, num_clusters)
        self.assertEqual(result[0], ["text1", "text3", "text5"])
        self.assertEqual(result[1], ["text2", "text4", "text6"])

    def test_get_embeddings_with_dbscan_clustering(self):
        embeddings = self._stub_embeddings()
        eps = 0.5
        with patch.object(
            self.service, 'get_dbscan_clusters', return_value=[0, 0, -1, 1, 1, -1]
        ) as mock_method:
            # num_clusters is not used by DBSCAN, hence None.
            result = self.service.get_embeddings(list(self._TEXTS), None, 'db_scan', eps)
        mock_method.assert_called_once_with(embeddings, eps)
        self.assertEqual(result[0], ["text1", "text2"])
        self.assertEqual(result[1], ["text4", "text5"])
        # Label -1 is the group DBSCAN labels marked as noise.
        self.assertEqual(result[-1], ["text3", "text6"])

    def test_get_embeddings_with_hierarchy_clustering(self):
        embeddings = self._stub_embeddings()
        num_clusters = 3
        with patch.object(
            self.service, 'get_hierarchy_clusters', return_value=[0, 1, 2, 0, 1, 2]
        ) as mock_method:
            result = self.service.get_embeddings(list(self._TEXTS), num_clusters, 'hierarchy')
        mock_method.assert_called_once_with(embeddings, num_clusters)
        self.assertEqual(result[0], ["text1", "text4"])
        self.assertEqual(result[1], ["text2", "text5"])
        self.assertEqual(result[2], ["text3", "text6"])

    def test_get_embeddings_with_invalid_cluster_method(self):
        self._stub_embeddings()
        with self.assertRaises(ValueError):
            self.service.get_embeddings(list(self._TEXTS), 2, 'invalid_method')
if __name__ == "__main__":
    # Run every TestCase defined in this module when the file is executed
    # directly rather than collected by a test runner.
    unittest.main()
# 0 Comments
# Please Login to Comment Here