yield mock_instance
-@pytest.mark.asyncio
-async def test_get_client():
- WFTemporal.clients = {}
- with generate_mock_client() as client:
- temporal = WFTemporal(temporal_api="localhost")
- result = await temporal.get_client()
- assert result == client
-
-
-@pytest.mark.asyncio
-async def test_get_cached_client():
- WFTemporal.clients = {}
- client_1 = None
- client_2 = None
- with generate_mock_client() as client:
- temporal = WFTemporal(temporal_api="localhost")
- client_1 = await temporal.get_client()
- assert client_1 == client
-
- with generate_mock_client() as client:
- temporal = WFTemporal(temporal_api="localhost")
- client_2 = await temporal.get_client()
- # We should get the same client as before, not a new one
- assert client_2 == client_1
- assert client_2 != client
-
-
@pytest.mark.asyncio
@mock.patch("osm_common.wftemporal.uuid")
async def test_start_workflow_no_id(mock_uuid):
- WFTemporal.clients = {}
+ WFTemporal._client = None
mock_uuid.uuid4.return_value = "01234567-89abc-def-0123-456789abcdef"
with generate_mock_client() as client:
client.start_workflow.return_value = "handle"
- temporal = WFTemporal(temporal_api="localhost")
+ temporal = WFTemporal()
result = await temporal.start_workflow(
task_queue="q", workflow_name="workflow", workflow_data="data"
)
@pytest.mark.asyncio
async def test_start_workflow_with_id():
- WFTemporal.clients = {}
+ WFTemporal._client = None
with generate_mock_client() as client:
client.start_workflow.return_value = "handle"
- temporal = WFTemporal(temporal_api="localhost")
+ temporal = WFTemporal()
result = await temporal.start_workflow(
task_queue="q", id="id", workflow_name="workflow", workflow_data="data"
)
@pytest.mark.asyncio
async def test_execute_workflow_with_id():
- WFTemporal.clients = {}
+ WFTemporal._client = None
with generate_mock_client() as client:
handle = Mock(WorkflowHandle)
handle.result = AsyncMock(return_value="success")
client.start_workflow.return_value = handle
- temporal = WFTemporal(temporal_api="localhost")
+ temporal = WFTemporal()
result = await temporal.execute_workflow(
task_queue="q",
id="another id",
client.start_workflow.assert_awaited_once_with(
workflow="workflow-with-result", arg="data", id="another id", task_queue="q"
)
+
+
+@pytest.mark.asyncio
+async def test_client_cache():
+ WFTemporal._client = None
+
+ with generate_mock_client() as client:
+ handle = Mock(WorkflowHandle)
+ handle.result = AsyncMock(return_value="success")
+ client.start_workflow.return_value = handle
+
+ temporal = WFTemporal()
+ result = await temporal.execute_workflow(
+ task_queue="q",
+ id="another id",
+ workflow_name="workflow-with-result",
+ workflow_data="data",
+ )
+ assert result == "success"
+ client.start_workflow.assert_awaited_once_with(
+ workflow="workflow-with-result", arg="data", id="another id", task_queue="q"
+ )
+
+ temporal = WFTemporal()
+ result = await temporal.execute_workflow(
+ task_queue="q",
+ id="yet another id",
+ workflow_name="workflow-with-result",
+ workflow_data="more data",
+ )
+ assert result == "success"
+ client.start_workflow.assert_awaited_with(
+ workflow="workflow-with-result",
+ arg="more data",
+ id="yet another id",
+ task_queue="q",
+ )
class WFTemporal(object):
- clients = {}
+ _client = None
+ temporal_api = None
- def __init__(self, temporal_api=None, logger_name="temporal.client"):
+ def __init__(self, logger_name="temporal.client"):
self.logger = logging.getLogger(logger_name)
- self.temporal_api = temporal_api
async def execute_workflow(
self, task_queue: str, workflow_name: str, workflow_data: any, id: str = None
async def start_workflow(
self, task_queue: str, workflow_name: str, workflow_data: any, id: str = None
):
- client = await self.get_client()
+ if WFTemporal._client is None:
+ self.logger.debug(
+ f"No cached client found, connecting to {WFTemporal.temporal_api}"
+ )
+ WFTemporal._client = await Client.connect(WFTemporal.temporal_api)
+
if id is None:
id = str(uuid.uuid4())
self.logger.info(f"Starting workflow {workflow_name}, id {id}")
- handle = await client.start_workflow(
+ handle = await WFTemporal._client.start_workflow(
workflow=workflow_name, arg=workflow_data, id=id, task_queue=task_queue
)
return handle
-
- async def get_client(self):
- if self.temporal_api in WFTemporal.clients:
- client = WFTemporal.clients[self.temporal_api]
- else:
- self.logger.debug(
- f"No cached client found, connecting to {self.temporal_api}"
- )
- client = await Client.connect(self.temporal_api)
- WFTemporal.clients[self.temporal_api] = client
-
- self.logger.debug(f"Using client {client} for {self.temporal_api}")
- return client