python
5 hours, 27 minutes ago
import unittest
from unittest.mock import patch, MagicMock
from azure_ai_service import AzureAiService # Replace with the actual module name
class TestAzureAiService(unittest.TestCase):
    """Unit tests for ``AzureAiService`` with the ``bnym_eliza`` SDK patched out.

    Each test replaces the ``bnym_eliza`` class attributes that the service
    method under test is expected to call with mocks, then checks both the
    value returned by the service and how the mocked SDK call was made.
    """

    # Sampling parameters shared by both run_azure_llm tests. The success test
    # also asserts that they are forwarded unchanged to ChatCompletion.create,
    # so keeping them in one place stops the call and the assertion drifting.
    _LLM_PARAMS = {
        "temperature": 0.7,
        "max_tokens": 100,
        "top_p": 0.9,
        "frequency_penalty": 0.1,
        "presence_penalty": 0.2,
        "stop": None,
    }

    def setUp(self):
        """Build an ``AzureAiService`` from a mocked config and mock collaborators."""
        # Mock configuration: only the "bnym-eliza" section is supplied.
        self.config = {
            "bnym-eliza": {
                "URL": "https://mocked-eliza-url.com",
                "CLIENT_ID": "mock_client_id",
                "CLIENT_SECRET": "mock_client_secret"
            }
        }
        self.oracle_service = MagicMock()
        self.moog_service = MagicMock()
        self.service = AzureAiService(
            self.config, self.oracle_service, self.moog_service
        )

    # Stacked @patch decorators inject mocks bottom-up: Embedding.create is
    # the first mock argument, Session.connect the second.
    @patch('bnym_eliza.Session.connect')
    @patch('bnym_eliza.Embedding.create')
    def test_get_embeddings_success(self, mock_embedding_create, mock_session_connect):
        """get_embeddings maps each input string to the embedding the SDK returns."""
        # Mock embedding response
        mock_embedding_create.return_value = {
            'data': [{'embedding': [0.1, 0.2, 0.3]}]
        }
        mock_session_connect.return_value = MagicMock()
        original_strings = ["test string"]
        model_name = "mock_model"
        trace_id = "trace123"

        result = self.service.get_embeddings(original_strings, model_name, trace_id)

        # Assertions
        self.assertIn("test string", result)
        self.assertEqual(result["test string"], [0.1, 0.2, 0.3])
        mock_embedding_create.assert_called_once_with(model=model_name, input="test string")

    # Session.connect is patched here as well, exactly as in the success test,
    # so this test never reaches the real SDK session for the same code path.
    @patch('bnym_eliza.Session.connect')
    @patch('bnym_eliza.Embedding.create')
    def test_get_embeddings_exception(self, mock_embedding_create, _mock_session_connect):
        """get_embeddings returns an empty dict when the SDK call raises."""
        # Mock an exception
        mock_embedding_create.side_effect = Exception("Mocked Exception")
        original_strings = ["test string"]
        model_name = "mock_model"
        trace_id = "trace123"

        result = self.service.get_embeddings(original_strings, model_name, trace_id)

        # Assertions
        self.assertEqual(result, {})
        mock_embedding_create.assert_called_once()

    # NOTE(review): if run_azure_llm also opens a bnym_eliza.Session, patch
    # Session.connect in the two tests below as the embeddings tests do.
    @patch('bnym_eliza.ChatCompletion.create')
    def test_run_azure_llm_success(self, mock_chat_create):
        """run_azure_llm returns the SDK response and forwards all sampling params."""
        # Mock chat response
        mock_chat_create.return_value = {"choices": [{"message": "response"}]}
        messages = [{"role": "user", "content": "Hello"}]
        model_name = "mock_model"
        trace_id = "trace123"

        response = self.service.run_azure_llm(
            messages=messages,
            model_name=model_name,
            trace_id=trace_id,
            **self._LLM_PARAMS,
        )

        # Assertions
        self.assertEqual(response, {"choices": [{"message": "response"}]})
        # trace_id is not expected to reach the SDK; model_name is renamed to model.
        mock_chat_create.assert_called_once_with(
            model=model_name,
            messages=messages,
            **self._LLM_PARAMS,
        )

    @patch('bnym_eliza.ChatCompletion.create')
    def test_run_azure_llm_exception(self, mock_chat_create):
        """run_azure_llm returns None when the SDK call raises."""
        # Mock an exception
        mock_chat_create.side_effect = Exception("Mocked Exception")
        messages = [{"role": "user", "content": "Hello"}]
        model_name = "mock_model"
        trace_id = "trace123"

        response = self.service.run_azure_llm(
            messages=messages,
            model_name=model_name,
            trace_id=trace_id,
            **self._LLM_PARAMS,
        )

        # Assertions
        self.assertIsNone(response)
        mock_chat_create.assert_called_once()
# Running this file directly (rather than via a test runner) executes the suite.
if __name__ == "__main__":
    unittest.main()
0 Comments
Please Login to Comment Here