Update from master
[osm/common.git] / osm_common / tests / test_wftemporal.py
index b8ea257..1b104a3 100644 (file)
@@ -36,43 +36,16 @@ def generate_mock_client() -> Generator[AsyncMock, None, None]:
         yield mock_instance
 
 
-@pytest.mark.asyncio
-async def test_get_client():
-    WFTemporal.clients = {}
-    with generate_mock_client() as client:
-        temporal = WFTemporal(temporal_api="localhost")
-        result = await temporal.get_client()
-        assert result == client
-
-
-@pytest.mark.asyncio
-async def test_get_cached_client():
-    WFTemporal.clients = {}
-    client_1 = None
-    client_2 = None
-    with generate_mock_client() as client:
-        temporal = WFTemporal(temporal_api="localhost")
-        client_1 = await temporal.get_client()
-        assert client_1 == client
-
-    with generate_mock_client() as client:
-        temporal = WFTemporal(temporal_api="localhost")
-        client_2 = await temporal.get_client()
-        # We should get the same client as before, not a new one
-        assert client_2 == client_1
-        assert client_2 != client
-
-
 @pytest.mark.asyncio
 @mock.patch("osm_common.wftemporal.uuid")
 async def test_start_workflow_no_id(mock_uuid):
-    WFTemporal.clients = {}
+    WFTemporal._client = None
 
     mock_uuid.uuid4.return_value = "01234567-89abc-def-0123-456789abcdef"
 
     with generate_mock_client() as client:
         client.start_workflow.return_value = "handle"
-        temporal = WFTemporal(temporal_api="localhost")
+        temporal = WFTemporal()
         result = await temporal.start_workflow(
             task_queue="q", workflow_name="workflow", workflow_data="data"
         )
@@ -87,11 +60,11 @@ async def test_start_workflow_no_id(mock_uuid):
 
 @pytest.mark.asyncio
 async def test_start_workflow_with_id():
-    WFTemporal.clients = {}
+    WFTemporal._client = None
 
     with generate_mock_client() as client:
         client.start_workflow.return_value = "handle"
-        temporal = WFTemporal(temporal_api="localhost")
+        temporal = WFTemporal()
         result = await temporal.start_workflow(
             task_queue="q", id="id", workflow_name="workflow", workflow_data="data"
         )
@@ -103,14 +76,14 @@ async def test_start_workflow_with_id():
 
 @pytest.mark.asyncio
 async def test_execute_workflow_with_id():
-    WFTemporal.clients = {}
+    WFTemporal._client = None
 
     with generate_mock_client() as client:
         handle = Mock(WorkflowHandle)
         handle.result = AsyncMock(return_value="success")
         client.start_workflow.return_value = handle
 
-        temporal = WFTemporal(temporal_api="localhost")
+        temporal = WFTemporal()
         result = await temporal.execute_workflow(
             task_queue="q",
             id="another id",
@@ -121,3 +94,40 @@ async def test_execute_workflow_with_id():
         client.start_workflow.assert_awaited_once_with(
             workflow="workflow-with-result", arg="data", id="another id", task_queue="q"
         )
+
+
+@pytest.mark.asyncio
+async def test_client_cache():
+    WFTemporal._client = None
+
+    with generate_mock_client() as client:
+        handle = Mock(WorkflowHandle)
+        handle.result = AsyncMock(return_value="success")
+        client.start_workflow.return_value = handle
+
+        temporal = WFTemporal()
+        result = await temporal.execute_workflow(
+            task_queue="q",
+            id="another id",
+            workflow_name="workflow-with-result",
+            workflow_data="data",
+        )
+        assert result == "success"
+        client.start_workflow.assert_awaited_once_with(
+            workflow="workflow-with-result", arg="data", id="another id", task_queue="q"
+        )
+
+        temporal = WFTemporal()
+        result = await temporal.execute_workflow(
+            task_queue="q",
+            id="yet another id",
+            workflow_name="workflow-with-result",
+            workflow_data="more data",
+        )
+        assert result == "success"
+        client.start_workflow.assert_awaited_with(
+            workflow="workflow-with-result",
+            arg="more data",
+            id="yet another id",
+            task_queue="q",
+        )