From: Mark Beierl Date: Thu, 23 Mar 2023 19:19:08 +0000 (+0000) Subject: Changing singleton usage X-Git-Url: https://osm.etsi.org/gitweb/?a=commitdiff_plain;h=6f14650730751aaaf2a0b28768057e55ff0bb9f3;p=osm%2Fcommon.git Changing singleton usage We won't be talking to any more than one temporal cluster at a time, so it does not make sense to have clients cached by API endpoint. Instead the main() of any program wanting to use temporal can just set the class level variable and then simply instantiate the class anywhere it is needed and it will manage the cached client without needing the temporal API URL Change-Id: Ia22635dc454e8df14ca22bc1e095f625d7e7337b Signed-off-by: Mark Beierl --- diff --git a/osm_common/tests/test_wftemporal.py b/osm_common/tests/test_wftemporal.py index b8ea257..1b104a3 100644 --- a/osm_common/tests/test_wftemporal.py +++ b/osm_common/tests/test_wftemporal.py @@ -36,43 +36,16 @@ def generate_mock_client() -> Generator[AsyncMock, None, None]: yield mock_instance -@pytest.mark.asyncio -async def test_get_client(): - WFTemporal.clients = {} - with generate_mock_client() as client: - temporal = WFTemporal(temporal_api="localhost") - result = await temporal.get_client() - assert result == client - - -@pytest.mark.asyncio -async def test_get_cached_client(): - WFTemporal.clients = {} - client_1 = None - client_2 = None - with generate_mock_client() as client: - temporal = WFTemporal(temporal_api="localhost") - client_1 = await temporal.get_client() - assert client_1 == client - - with generate_mock_client() as client: - temporal = WFTemporal(temporal_api="localhost") - client_2 = await temporal.get_client() - # We should get the same client as before, not a new one - assert client_2 == client_1 - assert client_2 != client - - @pytest.mark.asyncio @mock.patch("osm_common.wftemporal.uuid") async def test_start_workflow_no_id(mock_uuid): - WFTemporal.clients = {} + WFTemporal._client = None mock_uuid.uuid4.return_value = "01234567-89abc-def-0123-456789abcdef" with generate_mock_client() as client: client.start_workflow.return_value = "handle" - temporal = WFTemporal(temporal_api="localhost") + temporal = WFTemporal() result = await temporal.start_workflow( task_queue="q", workflow_name="workflow", workflow_data="data" ) @@ -87,11 +60,11 @@ async def test_start_workflow_no_id(mock_uuid): @pytest.mark.asyncio async def test_start_workflow_with_id(): - WFTemporal.clients = {} + WFTemporal._client = None with generate_mock_client() as client: client.start_workflow.return_value = "handle" - temporal = WFTemporal(temporal_api="localhost") + temporal = WFTemporal() result = await temporal.start_workflow( task_queue="q", id="id", workflow_name="workflow", workflow_data="data" ) @@ -103,14 +76,14 @@ async def test_start_workflow_with_id(): @pytest.mark.asyncio async def test_execute_workflow_with_id(): - WFTemporal.clients = {} + WFTemporal._client = None with generate_mock_client() as client: handle = Mock(WorkflowHandle) handle.result = AsyncMock(return_value="success") client.start_workflow.return_value = handle - temporal = WFTemporal(temporal_api="localhost") + temporal = WFTemporal() result = await temporal.execute_workflow( task_queue="q", id="another id", @@ -121,3 +94,40 @@ async def test_execute_workflow_with_id(): client.start_workflow.assert_awaited_once_with( workflow="workflow-with-result", arg="data", id="another id", task_queue="q" ) + + +@pytest.mark.asyncio +async def test_client_cache(): + WFTemporal._client = None + + with generate_mock_client() as client: + handle = Mock(WorkflowHandle) + handle.result = AsyncMock(return_value="success") + client.start_workflow.return_value = handle + + temporal = WFTemporal() + result = await temporal.execute_workflow( + task_queue="q", + id="another id", + workflow_name="workflow-with-result", + workflow_data="data", + ) + assert result == "success" + client.start_workflow.assert_awaited_once_with( + workflow="workflow-with-result", arg="data", id="another id", task_queue="q" + ) + + temporal = WFTemporal() + result = await temporal.execute_workflow( + task_queue="q", + id="yet another id", + workflow_name="workflow-with-result", + workflow_data="more data", + ) + assert result == "success" + client.start_workflow.assert_awaited_with( + workflow="workflow-with-result", + arg="more data", + id="yet another id", + task_queue="q", + ) diff --git a/osm_common/wftemporal.py b/osm_common/wftemporal.py index 0f7d421..90b6497 100644 --- a/osm_common/wftemporal.py +++ b/osm_common/wftemporal.py @@ -22,11 +22,11 @@ from temporalio.client import Client class WFTemporal(object): - clients = {} + _client = None + temporal_api = None - def __init__(self, temporal_api=None, logger_name="temporal.client"): + def __init__(self, logger_name="temporal.client"): self.logger = logging.getLogger(logger_name) - self.temporal_api = temporal_api async def execute_workflow( self, task_queue: str, workflow_name: str, workflow_data: any, id: str = None @@ -44,24 +44,16 @@ class WFTemporal(object): async def start_workflow( self, task_queue: str, workflow_name: str, workflow_data: any, id: str = None ): - client = await self.get_client() + if WFTemporal._client is None: + self.logger.debug( + f"No cached client found, connecting to {WFTemporal.temporal_api}" + ) + WFTemporal._client = await Client.connect(WFTemporal.temporal_api) + if id is None: id = str(uuid.uuid4()) self.logger.info(f"Starting workflow {workflow_name}, id {id}") - handle = await client.start_workflow( + handle = await WFTemporal._client.start_workflow( workflow=workflow_name, arg=workflow_data, id=id, task_queue=task_queue ) return handle - - async def get_client(self): - if self.temporal_api in WFTemporal.clients: - client = WFTemporal.clients[self.temporal_api] - else: - self.logger.debug( - f"No cached client found, connecting to {self.temporal_api}" - ) - client = await Client.connect(self.temporal_api) - WFTemporal.clients[self.temporal_api] = client - - self.logger.debug(f"Using client {client} for {self.temporal_api}") - return client