# Language: python
# Posted 5 days, 22 hours ago
import json
import time
import unittest
import urllib.error
from pathlib import Path
from unittest.mock import MagicMock, mock_open, patch

import jwt

from azure_auth_util import (
    Client,
    _make_client_assertion,
    _use_client_assertion,
    azure_auth,
    client_certificate_method,
    client_thumbprint_method,
    fetch,
    is_azure_token_valid,
)
class TestAzureAuthUtil(unittest.TestCase):
    """Unit tests for the helpers exported by ``azure_auth_util``.

    JWT decoding/encoding, file reads, certificate parsing, hashing and HTTP
    are all replaced with mocks, so these tests check how the helpers wire
    those dependencies together rather than talking to Azure.
    """

    def setUp(self):
        """Create a fake client plus placeholder certificate/key paths."""
        self.client = Client(
            tenant_id="tenant123",
            client_id="client123",
            scopes=["scope1", "scope2"],
        )
        # Placeholder paths: the tests below either patch ``Path.read_bytes``
        # or mock the callee that receives these paths, so no files are needed.
        self.certificate_key_path = "mock_key_path.pem"
        self.certificate_path = "mock_cert_path.pem"

    # ------------------------------------------------------------------
    # is_azure_token_valid
    # ------------------------------------------------------------------

    @patch("jwt.decode")
    def test_is_azure_token_valid_valid_token(self, mock_decode):
        """A token whose decoded ``exp`` lies in the future is reported valid."""
        mock_decode.return_value = {"exp": time.time() + 1000}
        self.assertTrue(is_azure_token_valid("mock_token"))

    @patch("jwt.decode")
    def test_is_azure_token_valid_expired_token(self, mock_decode):
        """A token whose decoded ``exp`` lies in the past is reported invalid."""
        mock_decode.return_value = {"exp": time.time() - 1000}
        self.assertFalse(is_azure_token_valid("mock_token"))

    @patch("jwt.decode", side_effect=Exception("Invalid token"))
    def test_is_azure_token_valid_invalid_token(self, mock_decode):
        """A token that cannot be decoded is reported invalid, not raised."""
        self.assertFalse(is_azure_token_valid("mock_token"))

    # ------------------------------------------------------------------
    # Certificate / assertion flow
    # ------------------------------------------------------------------

    @patch("azure_auth_util.client_certificate_method")
    def test_azure_auth(self, mock_client_cert_method):
        """``azure_auth`` returns the token produced by the certificate flow."""
        mock_client_cert_method.return_value = "mock_token"
        result = azure_auth(
            self.client, self.certificate_key_path, self.certificate_path
        )
        self.assertEqual(result, "mock_token")

    @patch("pathlib.Path.read_bytes", return_value=b"mock_cert_data")
    @patch("cryptography.x509.load_pem_x509_certificate")
    @patch("hashlib.sha1")
    @patch("azure_auth_util.client_thumbprint_method")
    def test_client_certificate_method(
        self, mock_thumbprint_method, mock_sha1, mock_load_cert, mock_read_bytes
    ):
        """The certificate flow delegates to ``client_thumbprint_method``."""
        # ``hashlib.sha1(...)`` returns a hash *object*. Configure its usual
        # accessors rather than returning a bare str, which would fail with
        # AttributeError on ``.hexdigest()`` / ``.digest()``.
        mock_sha1.return_value.hexdigest.return_value = "mock_sha1"
        mock_sha1.return_value.digest.return_value = b"mock_sha1"
        mock_thumbprint_method.return_value = "mock_token"

        result = client_certificate_method(
            self.client, Path(self.certificate_key_path), Path(self.certificate_path)
        )

        self.assertEqual(result, "mock_token")
        mock_thumbprint_method.assert_called_once()

    @patch("azure_auth_util._make_client_assertion")
    @patch("azure_auth_util._use_client_assertion")
    def test_client_thumbprint_method(self, mock_use_assertion, mock_make_assertion):
        """A freshly built client assertion is exchanged for the returned token."""
        mock_make_assertion.return_value = "mock_assertion"
        mock_use_assertion.return_value = "mock_token"

        result = client_thumbprint_method(
            self.client, Path(self.certificate_key_path), "mock_thumbprint"
        )

        self.assertEqual(result, "mock_token")
        mock_make_assertion.assert_called_once()
        # The assertion that was built must be the one sent, whether it is
        # passed positionally or by keyword.
        sent_args = mock_use_assertion.call_args
        self.assertIn(
            "mock_assertion", [*sent_args.args, *sent_args.kwargs.values()]
        )

    @patch("pathlib.Path.read_bytes", return_value=b"mock_key_data")
    @patch("uuid.uuid4", return_value="mock_uuid")
    def test_make_client_assertion(self, mock_uuid, mock_read_bytes):
        """``_make_client_assertion`` returns the JWT produced by ``JWT.encode``."""
        # NOTE: the clock is deliberately NOT frozen. ``datetime.datetime`` is an
        # immutable built-in type, so ``patch("datetime.datetime.now")`` raises
        # TypeError before the test body runs. Freezing it is also unnecessary:
        # ``JWT.encode`` is mocked, so the result does not depend on the time.
        mock_thumb_print = "mock_thumb_print"
        with patch("jwt.jwk.jwk_from_pem", return_value="mock_key"):
            with patch(
                "jwt.jwt.JWT.encode", return_value="mock_jwt_token"
            ) as mock_encode:
                result = _make_client_assertion(
                    Path(self.certificate_key_path), self.client, mock_thumb_print
                )
        self.assertEqual(result, "mock_jwt_token")
        mock_encode.assert_called_once()

    # ------------------------------------------------------------------
    # HTTP helpers
    # ------------------------------------------------------------------

    @patch("urllib.request.urlopen")
    def test_use_client_assertion(self, mock_urlopen):
        """The ``access_token`` field of the mocked JSON response is returned."""
        mock_response = MagicMock()
        mock_response.read.return_value = json.dumps(
            {"access_token": "mock_token"}
        ).encode("utf-8")
        # Support both ``urlopen(req).read()`` and ``with urlopen(req) as r:``.
        mock_response.__enter__.return_value = mock_response
        mock_urlopen.return_value = mock_response

        result = _use_client_assertion(self.client, "mock_assertion")

        self.assertEqual(result, "mock_token")

    @patch("urllib.request.urlopen")
    def test_fetch_success(self, mock_urlopen):
        """``fetch`` returns the response body decoded as UTF-8 text."""
        mock_response = MagicMock()
        mock_response.read.return_value = "mock_response".encode("utf-8")
        # Support both ``urlopen(req).read()`` and ``with urlopen(req) as r:``.
        mock_response.__enter__.return_value = mock_response
        mock_urlopen.return_value = mock_response

        result = fetch(MagicMock())

        self.assertEqual(result, "mock_response")

    @patch(
        "urllib.request.urlopen",
        side_effect=urllib.error.HTTPError(
            url=None, code=400, msg="Bad Request", hdrs=None, fp=None
        ),
    )
    def test_fetch_http_error(self, mock_urlopen):
        """An HTTP error surfaces as an exception mentioning the status code."""
        with self.assertRaises(Exception) as context:
            fetch(MagicMock())
        self.assertIn("http error: 400", str(context.exception))
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
# (Original post footer: 0 Comments — Please Login to Comment Here)