Changing singleton usage 91/13091/1
authorMark Beierl <mark.beierl@canonical.com>
Thu, 23 Mar 2023 19:19:08 +0000 (19:19 +0000)
committerMark Beierl <mark.beierl@canonical.com>
Thu, 23 Mar 2023 19:19:08 +0000 (19:19 +0000)
We won't be talking to any more than one temporal cluster at a
time, so it does not make sense to have clients cached by API
endpoint.  Instead the main() of any program wanting to use
temporal can just set the class level variable and then simply
instantiate the class anywhere it is needed and it will
manage the cached client without needing the temporal API URL

Change-Id: Ia22635dc454e8df14ca22bc1e095f625d7e7337b
Signed-off-by: Mark Beierl <mark.beierl@canonical.com>
osm_common/tests/test_wftemporal.py
osm_common/wftemporal.py

index b8ea257..1b104a3 100644 (file)
@@ -36,43 +36,16 @@ def generate_mock_client() -> Generator[AsyncMock, None, None]:
         yield mock_instance
 
 
-@pytest.mark.asyncio
-async def test_get_client():
-    WFTemporal.clients = {}
-    with generate_mock_client() as client:
-        temporal = WFTemporal(temporal_api="localhost")
-        result = await temporal.get_client()
-        assert result == client
-
-
-@pytest.mark.asyncio
-async def test_get_cached_client():
-    WFTemporal.clients = {}
-    client_1 = None
-    client_2 = None
-    with generate_mock_client() as client:
-        temporal = WFTemporal(temporal_api="localhost")
-        client_1 = await temporal.get_client()
-        assert client_1 == client
-
-    with generate_mock_client() as client:
-        temporal = WFTemporal(temporal_api="localhost")
-        client_2 = await temporal.get_client()
-        # We should get the same client as before, not a new one
-        assert client_2 == client_1
-        assert client_2 != client
-
-
 @pytest.mark.asyncio
 @mock.patch("osm_common.wftemporal.uuid")
 async def test_start_workflow_no_id(mock_uuid):
-    WFTemporal.clients = {}
+    WFTemporal._client = None
 
     mock_uuid.uuid4.return_value = "01234567-89abc-def-0123-456789abcdef"
 
     with generate_mock_client() as client:
         client.start_workflow.return_value = "handle"
-        temporal = WFTemporal(temporal_api="localhost")
+        temporal = WFTemporal()
         result = await temporal.start_workflow(
             task_queue="q", workflow_name="workflow", workflow_data="data"
         )
@@ -87,11 +60,11 @@ async def test_start_workflow_no_id(mock_uuid):
 
 @pytest.mark.asyncio
 async def test_start_workflow_with_id():
-    WFTemporal.clients = {}
+    WFTemporal._client = None
 
     with generate_mock_client() as client:
         client.start_workflow.return_value = "handle"
-        temporal = WFTemporal(temporal_api="localhost")
+        temporal = WFTemporal()
         result = await temporal.start_workflow(
             task_queue="q", id="id", workflow_name="workflow", workflow_data="data"
         )
@@ -103,14 +76,14 @@ async def test_start_workflow_with_id():
 
 @pytest.mark.asyncio
 async def test_execute_workflow_with_id():
-    WFTemporal.clients = {}
+    WFTemporal._client = None
 
     with generate_mock_client() as client:
         handle = Mock(WorkflowHandle)
         handle.result = AsyncMock(return_value="success")
         client.start_workflow.return_value = handle
 
-        temporal = WFTemporal(temporal_api="localhost")
+        temporal = WFTemporal()
         result = await temporal.execute_workflow(
             task_queue="q",
             id="another id",
@@ -121,3 +94,40 @@ async def test_execute_workflow_with_id():
         client.start_workflow.assert_awaited_once_with(
             workflow="workflow-with-result", arg="data", id="another id", task_queue="q"
         )
+
+
+@pytest.mark.asyncio
+async def test_client_cache():
+    WFTemporal._client = None
+
+    with generate_mock_client() as client:
+        handle = Mock(WorkflowHandle)
+        handle.result = AsyncMock(return_value="success")
+        client.start_workflow.return_value = handle
+
+        temporal = WFTemporal()
+        result = await temporal.execute_workflow(
+            task_queue="q",
+            id="another id",
+            workflow_name="workflow-with-result",
+            workflow_data="data",
+        )
+        assert result == "success"
+        client.start_workflow.assert_awaited_once_with(
+            workflow="workflow-with-result", arg="data", id="another id", task_queue="q"
+        )
+
+        temporal = WFTemporal()
+        result = await temporal.execute_workflow(
+            task_queue="q",
+            id="yet another id",
+            workflow_name="workflow-with-result",
+            workflow_data="more data",
+        )
+        assert result == "success"
+        client.start_workflow.assert_awaited_with(
+            workflow="workflow-with-result",
+            arg="more data",
+            id="yet another id",
+            task_queue="q",
+        )
index 0f7d421..90b6497 100644 (file)
@@ -22,11 +22,11 @@ from temporalio.client import Client
 
 
 class WFTemporal(object):
-    clients = {}
+    _client = None
+    temporal_api = None
 
-    def __init__(self, temporal_api=None, logger_name="temporal.client"):
+    def __init__(self, logger_name="temporal.client"):
         self.logger = logging.getLogger(logger_name)
-        self.temporal_api = temporal_api
 
     async def execute_workflow(
         self, task_queue: str, workflow_name: str, workflow_data: any, id: str = None
@@ -44,24 +44,16 @@ class WFTemporal(object):
     async def start_workflow(
         self, task_queue: str, workflow_name: str, workflow_data: any, id: str = None
     ):
-        client = await self.get_client()
+        if WFTemporal._client is None:
+            self.logger.debug(
+                f"No cached client found, connecting to {WFTemporal.temporal_api}"
+            )
+            WFTemporal._client = await Client.connect(WFTemporal.temporal_api)
+
         if id is None:
             id = str(uuid.uuid4())
         self.logger.info(f"Starting workflow {workflow_name}, id {id}")
-        handle = await client.start_workflow(
+        handle = await WFTemporal._client.start_workflow(
             workflow=workflow_name, arg=workflow_data, id=id, task_queue=task_queue
         )
         return handle
-
-    async def get_client(self):
-        if self.temporal_api in WFTemporal.clients:
-            client = WFTemporal.clients[self.temporal_api]
-        else:
-            self.logger.debug(
-                f"No cached client found, connecting to {self.temporal_api}"
-            )
-            client = await Client.connect(self.temporal_api)
-            WFTemporal.clients[self.temporal_api] = client
-
-        self.logger.debug(f"Using client {client} for {self.temporal_api}")
-        return client