# python
# 5 hours, 5 minutes ago
import unittest
from unittest.mock import MagicMock, patch
from base_intent_extraction_service import BaseIntentExtractionService # Replace with the correct module name
class TestBaseIntentExtractionService(unittest.TestCase):
    """Unit tests for BaseIntentExtractionService with Mongo and Oracle services mocked out."""

    # Fixture values shared by the Mongo chunk-writing tests.
    COLLECTION_NAME = "test_collection"
    CHUNK_SIZE = 1

    def setUp(self):
        """Create a fresh service wired to MagicMock Mongo/Oracle collaborators."""
        self.mongo_service = MagicMock()
        self.oracle_service = MagicMock()
        self.service = BaseIntentExtractionService(self.mongo_service, self.oracle_service)

    # -- helpers ---------------------------------------------------------------

    @staticmethod
    def _single_record_chunks():
        """Return two one-record chunks (a fresh list on every call)."""
        return [[{"record": 1}], [{"record": 2}]]

    def _write_two_records(self):
        """Call write_data_to_mongo_in_chunks on two records using the shared fixtures."""
        records = [{"record": 1}, {"record": 2}]
        self.service.write_data_to_mongo_in_chunks(records, self.COLLECTION_NAME, self.CHUNK_SIZE)

    def _model_name_for(self, provider, mode):
        """Call get_model_name with fixed text/chat/azure model names."""
        return self.service.get_model_name(provider, mode, "text_model", "chat_model", "azure_model")

    # -- write_data_to_mongo_in_chunks -----------------------------------------

    # NOTE(review): patch() replaces a name where it is *looked up*. If
    # base_intent_extraction_service does `from emona.util.common_util import chunker`,
    # this target does not affect the service. Confirm the patch path.
    @patch('emona.util.common_util.chunker')
    def test_write_data_to_mongo_in_chunks_success(self, mock_chunker):
        """Each chunk reaches bulk_write exactly once, together with the collection name."""
        mock_chunker.return_value = self._single_record_chunks()
        bulk_write = self.mongo_service.bulk_write
        bulk_write.return_value = None

        self._write_two_records()

        for expected_chunk in self._single_record_chunks():
            bulk_write.assert_any_call(self.COLLECTION_NAME, expected_chunk)
        self.assertEqual(bulk_write.call_count, 2)

    @patch('emona.util.common_util.chunker')
    def test_write_data_to_mongo_in_chunks_exception(self, mock_chunker):
        """A raising bulk_write is not propagated, and every chunk is still attempted."""
        mock_chunker.return_value = self._single_record_chunks()
        self.mongo_service.bulk_write.side_effect = Exception("Mocked Exception")

        # The call must not raise; an escaping exception would error the test.
        self._write_two_records()

        self.assertEqual(self.mongo_service.bulk_write.call_count, 2)

    # -- get_prompt_info -------------------------------------------------------

    def test_get_prompt_info_success(self):
        """The first result row is mapped to a dict, and exactly one expected query is issued."""
        select = self.oracle_service.execute_select_query
        # NOTE(review): the stub returns a JSON-encoded *string*. Confirm that
        # execute_select_query really returns str rather than a list of rows.
        select.return_value = '[["prompt1", "This is a prompt", "Example1, Example2"]]'

        result = self.service.get_prompt_info(1, "Azure")

        expected = {
            "prompt_name": "prompt1",
            "prompt": "This is a prompt",
            "examples": "Example1, Example2",
        }
        self.assertEqual(result, expected)
        # The expected SQL has the argument values inlined. If the service moves
        # to bind parameters, this assertion must be updated.
        select.assert_called_once_with(
            "SELECT PROMPT_NAME, PROMPT, EXAMPLES FROM AI_PROMPT_CONFIG "
            "WHERE USE_CASE_ID = 1 AND IS_ACTIVE = 'Y' AND PROVIDER = 'Azure'"
        )

    def test_get_prompt_info_exception(self):
        """A failing Oracle query makes get_prompt_info return None instead of raising."""
        self.oracle_service.execute_select_query.side_effect = Exception("Mocked Exception")
        self.assertIsNone(self.service.get_prompt_info(1, "Azure"))

    # -- get_model_name --------------------------------------------------------

    def test_get_model_name_google_text(self):
        """Google with mode "text" selects the text model."""
        self.assertEqual(self._model_name_for("Google", "text"), "text_model")

    def test_get_model_name_google_chat(self):
        """Google with mode "chat" selects the chat model."""
        self.assertEqual(self._model_name_for("Google", "chat"), "chat_model")

    def test_get_model_name_azure(self):
        """Azure selects the Azure model regardless of mode."""
        self.assertEqual(self._model_name_for("Azure", "text"), "azure_model")

    def test_get_model_name_invalid_provider(self):
        """An unknown provider yields None."""
        self.assertIsNone(self._model_name_for("Invalid", "text"))

    # -- get_sample_input_output_list ------------------------------------------

    def test_get_sample_input_output_list(self):
        """Alternating 'input:'/'output:' lines are split into two parallel lists."""
        # NOTE(review): the pasted source had lost all indentation, so it is unknown
        # whether the example lines were originally indented. This literal uses none.
        prompt_examples = (
            "\n"
            "input: Input1\n"
            "output: Output1\n"
            "input: Input2\n"
            "output: Output2\n"
        )
        inputs, outputs = self.service.get_sample_input_output_list(prompt_examples)
        self.assertEqual(inputs, ["Input1", "Input2"])
        self.assertEqual(outputs, ["Output1", "Output2"])

    def test_get_sample_input_output_list_empty(self):
        """None as the examples text yields two empty lists."""
        inputs, outputs = self.service.get_sample_input_output_list(None)
        self.assertEqual((inputs, outputs), ([], []))
if __name__ == "__main__":
    # Executed directly as a script: discover and run every TestCase in this module.
    unittest.main()
# 0 Comments
# Please Login to Comment Here