--- /dev/null
+#######################################################################################
+# Copyright ETSI Contributors and Others.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#######################################################################################
+
+from contextlib import contextmanager
+from typing import Generator
+from unittest import mock
+from unittest.mock import AsyncMock, Mock
+
+from osm_common.wftemporal import WFTemporal
+import pytest
+from temporalio.client import Client, WorkflowHandle
+
+
+@contextmanager
+def generate_mock_client() -> Generator[AsyncMock, None, None]:
+ # Note the mock instance has to be created before patching for autospec to work
+ mock_instance = mock.create_autospec(Client)
+ mock_instance.connect.return_value = mock_instance
+
+ with mock.patch("osm_common.wftemporal.Client.connect") as client:
+ client.return_value = mock_instance
+ yield mock_instance
+
+
+@pytest.mark.asyncio
+async def test_get_client():
+ WFTemporal.clients = {}
+ with generate_mock_client() as client:
+ temporal = WFTemporal(temporal_api="localhost")
+ result = await temporal.get_client()
+ assert result == client
+
+
+@pytest.mark.asyncio
+async def test_get_cached_client():
+ WFTemporal.clients = {}
+ client_1 = None
+ client_2 = None
+ with generate_mock_client() as client:
+ temporal = WFTemporal(temporal_api="localhost")
+ client_1 = await temporal.get_client()
+ assert client_1 == client
+
+ with generate_mock_client() as client:
+ temporal = WFTemporal(temporal_api="localhost")
+ client_2 = await temporal.get_client()
+ # We should get the same client as before, not a new one
+ assert client_2 == client_1
+ assert client_2 != client
+
+
+@pytest.mark.asyncio
+@mock.patch("osm_common.wftemporal.uuid")
+async def test_start_workflow_no_id(mock_uuid):
+ WFTemporal.clients = {}
+
+ mock_uuid.uuid4.return_value = "01234567-89abc-def-0123-456789abcdef"
+
+ with generate_mock_client() as client:
+ client.start_workflow.return_value = "handle"
+ temporal = WFTemporal(temporal_api="localhost")
+ result = await temporal.start_workflow(
+ task_queue="q", workflow_name="workflow", workflow_data="data"
+ )
+ assert result == "handle"
+ client.start_workflow.assert_awaited_once_with(
+ workflow="workflow",
+ arg="data",
+ id="01234567-89abc-def-0123-456789abcdef",
+ task_queue="q",
+ )
+
+
+@pytest.mark.asyncio
+async def test_start_workflow_with_id():
+ WFTemporal.clients = {}
+
+ with generate_mock_client() as client:
+ client.start_workflow.return_value = "handle"
+ temporal = WFTemporal(temporal_api="localhost")
+ result = await temporal.start_workflow(
+ task_queue="q", id="id", workflow_name="workflow", workflow_data="data"
+ )
+ assert result == "handle"
+ client.start_workflow.assert_awaited_once_with(
+ workflow="workflow", arg="data", id="id", task_queue="q"
+ )
+
+
+@pytest.mark.asyncio
+async def test_execute_workflow_with_id():
+ WFTemporal.clients = {}
+
+ with generate_mock_client() as client:
+ handle = Mock(WorkflowHandle)
+ handle.result = AsyncMock(return_value="success")
+ client.start_workflow.return_value = handle
+
+ temporal = WFTemporal(temporal_api="localhost")
+ result = await temporal.execute_workflow(
+ task_queue="q",
+ id="another id",
+ workflow_name="workflow-with-result",
+ workflow_data="data",
+ )
+ assert result == "success"
+ client.start_workflow.assert_awaited_once_with(
+ workflow="workflow-with-result", arg="data", id="another id", task_queue="q"
+ )
--- /dev/null
+#######################################################################################
+# Copyright ETSI Contributors and Others.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#######################################################################################
+
+import logging
+import uuid
+
+from temporalio.client import Client
+
+
+class WFTemporal(object):
+ clients = {}
+
+ def __init__(self, temporal_api=None, logger_name="temporal.client"):
+ self.logger = logging.getLogger(logger_name)
+ self.temporal_api = temporal_api
+
+ async def execute_workflow(
+ self, task_queue: str, workflow_name: str, workflow_data: any, id: str = None
+ ):
+ handle = await self.start_workflow(
+ task_queue=task_queue,
+ workflow_name=workflow_name,
+ workflow_data=workflow_data,
+ id=id,
+ )
+ result = await handle.result()
+ self.logger.info(f"Completed workflow {workflow_name}, id {id}")
+ return result
+
+ async def start_workflow(
+ self, task_queue: str, workflow_name: str, workflow_data: any, id: str = None
+ ):
+ client = await self.get_client()
+ if id is None:
+ id = str(uuid.uuid4())
+ self.logger.info(f"Starting workflow {workflow_name}, id {id}")
+ handle = await client.start_workflow(
+ workflow=workflow_name, arg=workflow_data, id=id, task_queue=task_queue
+ )
+ return handle
+
+ async def get_client(self):
+ if self.temporal_api in WFTemporal.clients:
+ client = WFTemporal.clients[self.temporal_api]
+ else:
+ self.logger.debug(
+ f"No cached client found, connecting to {self.temporal_api}"
+ )
+ client = await Client.connect(self.temporal_api)
+ WFTemporal.clients[self.temporal_api] = client
+
+ self.logger.debug(f"Using client {client} for {self.temporal_api}")
+ return client
--- /dev/null
+#######################################################################################
+# Copyright ETSI Contributors and Others.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#######################################################################################
+---
+features:
+ - |
+ Temporal client: a library to manage the execution of workflows in a
+ particular temporal cluster.