python
5 days, 21 hours ago
import unittest
from unittest.mock import patch, MagicMock
from embeddings_util import EmbeddingsUtil # Replace with the correct module name
class TestEmbeddingsUtil(unittest.TestCase):
    """Unit tests for ``EmbeddingsUtil.get_embeddings`` and ``split_list``.

    Every ``get_embeddings`` test patches ``embeddings_util.Auth`` and
    ``embeddings_util.requests.session``, so the code under test talks to
    ``MagicMock`` stand-ins instead of the real auth and HTTP session objects.

    Note on argument order: ``@patch`` decorators are applied bottom-up, so
    the session patch (closest to the ``def``) is the first mock argument and
    the ``Auth`` patch is the second.
    """

    def setUp(self):
        """Build a minimal config dict and a fresh ``EmbeddingsUtil`` per test."""
        self.config = {
            'cloud-ai-urls': {'EMBEDDINGS_API': 'http://example.com/embeddings'},
            'gcp-config': {'PROJECT_ID': 'test_project'},
            'certs': {'CERT_PATH': '/path/to/cert', 'CERT_X5T': 'test_x5t'},
        }
        self.util = EmbeddingsUtil(self.config)

    @staticmethod
    def _wire_mocks(mock_session, mock_auth):
        """Attach stub instances to the patched factories; return the session stub.

        The ``Auth`` stub hands out a fixed assertion token. The session stub
        is returned unconfigured so each test can set ``post.return_value`` or
        ``post.side_effect`` as it needs.
        """
        auth_stub = MagicMock()
        auth_stub.generate_assertion_token.return_value = "fake_token"
        mock_auth.return_value = auth_stub

        session_stub = MagicMock()
        mock_session.return_value = session_stub
        return session_stub

    @patch('embeddings_util.Auth')
    @patch('embeddings_util.requests.session')
    def test_get_embeddings_success(self, mock_session, mock_auth):
        """A full response yields one vector per input string, in input order."""
        session_stub = self._wire_mocks(mock_session, mock_auth)

        api_response = MagicMock()
        api_response.json.return_value = {
            'DATA': {'string1': [0.1, 0.2, 0.3], 'string2': [0.4, 0.5, 0.6]}
        }
        api_response.text = 'API response'
        session_stub.post.return_value = api_response

        list_of_strings = ['string1', 'string2']
        result = self.util.get_embeddings(list_of_strings)

        # The request body is compared as an exact string, so key order and
        # spacing in the serialized payload matter here.
        session_stub.post.assert_called_once_with(
            self.config['cloud-ai-urls']['EMBEDDINGS_API'],
            data='{"original_strings": ["string1", "string2"], "projectId": "test_project"}'
        )
        self.assertEqual(result, [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

    @patch('embeddings_util.Auth')
    @patch('embeddings_util.requests.session')
    def test_get_embeddings_retries(self, mock_session, mock_auth):
        """When every POST raises, the result is falsy and the give-up message is logged."""
        session_stub = self._wire_mocks(mock_session, mock_auth)
        # post() itself raises, so no response object is ever produced.
        session_stub.post.side_effect = Exception("Connection error")

        list_of_strings = ['string1', 'string2']
        with self.assertLogs(level='INFO') as log:
            result = self.util.get_embeddings(list_of_strings)

        self.assertFalse(result)
        self.assertIn("More than 3 retries getting out of function", log.output[-1])

    def test_split_list(self):
        """``split_list`` chunks in order and keeps a shorter final chunk."""
        list_of_strings = ['a', 'b', 'c', 'd', 'e']
        chunk_size = 2
        result = self.util.split_list(list_of_strings, chunk_size)
        expected = [['a', 'b'], ['c', 'd'], ['e']]
        self.assertEqual(result, expected)

    @patch('embeddings_util.Auth')
    @patch('embeddings_util.requests.session')
    def test_get_embeddings_handles_missing_data(self, mock_session, mock_auth):
        """An input string absent from the response ``DATA`` maps to ``None``."""
        session_stub = self._wire_mocks(mock_session, mock_auth)

        api_response = MagicMock()
        api_response.json.return_value = {'DATA': {'string1': [0.1, 0.2, 0.3]}}
        session_stub.post.return_value = api_response

        list_of_strings = ['string1', 'string2']
        result = self.util.get_embeddings(list_of_strings)
        self.assertEqual(result, [[0.1, 0.2, 0.3], None])

    @patch('embeddings_util.Auth')
    @patch('embeddings_util.requests.session')
    def test_get_embeddings_empty_input(self, mock_session, mock_auth):
        """An empty input list yields an empty result list."""
        # The mocks are left unconfigured; the patches only keep the real
        # Auth and session objects out of the test.
        list_of_strings = []
        result = self.util.get_embeddings(list_of_strings)
        self.assertEqual(result, [])
# Run the test suite only when this file is executed directly, not on import.
if __name__ == '__main__':
    unittest.main()
0 Comments
Please Login to Comment Here