python
4 days ago
import unittest
from unittest.mock import MagicMock, patch
from fastapi.responses import JSONResponse
from llm_model_response_service import LlmModelResponseService
class TestLlmModelResponseService(unittest.TestCase):
    """Unit tests for ``LlmModelResponseService``.

    Every collaborator passed to the service constructor (config, Oracle,
    Azure and GCP services) is replaced by a ``MagicMock`` in ``setUp``, so the
    tests exercise only the service's own logic.
    """

    # Positional arguments shared by both ``map_phrase_to_category`` tests, in
    # the order the original tests passed them.
    MAP_PHRASE_ARGS = (
        "test phrases",
        "test categories",
        "text",
        "model-name",
        "application/json",
        "trace-id",
    )

    def setUp(self):
        """Build a fresh service instance with mocked dependencies per test."""
        self.mock_config = MagicMock()
        self.mock_oracle_service = MagicMock()
        self.mock_azure_service = MagicMock()
        self.mock_gcp_service = MagicMock()
        self.service = LlmModelResponseService(
            self.mock_config,
            self.mock_oracle_service,
            self.mock_azure_service,
            self.mock_gcp_service,
        )

    def _call_get_label(self, provider):
        """Call ``get_label`` with fixed test parameters and the given provider.

        Both ``get_label`` tests share every argument except ``provider``.
        """
        return self.service.get_label(
            phrases="test phrases",
            provider=provider,
            model_name="test-model",
            prompt="",
            use_case_id=7,
            temperature=0.7,
            max_tokens=100,
            top_p=0.8,
            top_k=5,
            frequency_penalty=0.1,
            presence_penalty=0.1,
            stop=None,
            response_mime_type="application/json",
            trace_id="trace-id",
        )

    def test_map_phrase_to_category_success(self):
        """A successful prompt lookup + model call yields a SUCCESS JSONResponse."""
        # The DB mock returns a *string* containing a list literal, mirroring
        # what the original test assumed execute_select_query produces.
        self.mock_oracle_service.execute_select_query.return_value = (
            "[['TestPrompt', 'Classify the following: {}']]"
        )
        self.mock_gcp_service.call_google_text_model.return_value = ["Test Result"]

        # JSONResponse is deliberately NOT patched: the assertions need a real
        # response object (isinstance check and ``.body``). Patching it with a
        # mock that returns a dict made both assertions impossible to satisfy.
        result = self.service.map_phrase_to_category(*self.MAP_PHRASE_ARGS)

        self.assertIsInstance(result, JSONResponse)
        self.assertIn("SUCCESS", result.body.decode())

    def test_map_phrase_to_category_exception(self):
        """A failing DB query is reported as a FAILED JSONResponse."""
        self.mock_oracle_service.execute_select_query.side_effect = Exception(
            "Database error"
        )

        # As above, use the real JSONResponse so ``.body`` can be inspected.
        result = self.service.map_phrase_to_category(*self.MAP_PHRASE_ARGS)

        self.assertIsInstance(result, JSONResponse)
        self.assertIn("FAILED", result.body.decode())

    def test_format_and_call_gemini_multimodal_success(self):
        """The GCP multimodal response is passed back to the caller."""
        self.mock_gcp_service.call_gemini_multimodal.return_value = {
            "Status": "SUCCESS"
        }

        result = self.service.format_and_call_gemini_multimodal(
            content="[{'text': 'test text'}]",
            file_list=[],
            model_name="gemini-model",
            max_output_tokens=100,
            temperature=0.5,
            top_p=0.8,
            response_mime_type="application/json",
            trace_id="trace-id",
        )

        self.assertEqual(result["Status"], "SUCCESS")

    def test_call_google_text_llm_success(self):
        """The GCP text-model output is returned unchanged."""
        self.mock_gcp_service.call_google_text_model.return_value = "Test Output"

        result = self.service.call_google_text_llm(
            model_name="text-model",
            content="test content",
            temperature=0.7,
            max_output_tokens=150,
            top_p=0.9,
            top_k=5,
            tuned_model_name="",
            response_mime_type="application/json",
            trace_id="trace-id",
        )

        self.assertEqual(result, "Test Output")

    def test_get_label_google_provider_success(self):
        """With provider="google" the label comes from the GCP text model."""
        self.mock_oracle_service.execute_select_query.return_value = (
            "[['TestPrompt', 'Label this: {}']]"
        )
        self.mock_gcp_service.call_google_text_model.return_value = [
            "Labelled Output"
        ]

        result = self._call_get_label("google")

        self.assertEqual(result.strip(), "Labelled Output")

    def test_get_label_invalid_provider(self):
        """An unknown provider produces an explanatory message, not a label."""
        result = self._call_get_label("unsupported")

        self.assertEqual(
            result, "The provider specified: unsupported is not supported."
        )

    def test_get_multiple_labels_success(self):
        """The JSON in the Azure chat completion is parsed into a dict."""
        self.mock_oracle_service.execute_select_query.return_value = (
            "[['Prompt', 'Cluster the following: {}']]"
        )
        self.mock_azure_service.run_azure_llm.return_value = {
            "choices": [
                {"message": {"content": '{"Cluster1": ["Phrase1", "Phrase2"]}'}}
            ]
        }

        result = self.service.get_multiple_labels(
            phrases="test phrases",
            azure_model_name="azure-model",
            temperature=0.7,
            max_tokens=100,
            top_p=0.8,
            frequency_penalty=0.1,
            presence_penalty=0.1,
            stop=None,
            trace_id="trace-id",
        )

        self.assertEqual(result, {"Cluster1": ["Phrase1", "Phrase2"]})

    def test_format_and_call_gemini_multimodal_base64_failure(self):
        """An exception from the GCP call is returned as a FAILED status dict."""
        self.mock_gcp_service.call_gemini_multimodal.side_effect = Exception(
            "Processing error"
        )

        result = self.service.format_and_call_gemini_multimodal_base64(
            content="[{'text': 'test text'}]",
            file_list=[],
            model_name="gemini-model",
            max_output_tokens=100,
            temperature=0.5,
            top_p=0.8,
            response_mime_type="application/json",
            trace_id="trace-id",
        )

        self.assertEqual(result["Status"], "FAILED")
        self.assertIn("Processing error", result["Reason"])
# Run this module's TestCase classes when the file is executed directly
# rather than collected by a test runner.
if __name__ == "__main__":
    unittest.main(module="__main__")
0 Comments
Please Login to Comment Here