#######################################################################################
# Copyright ETSI Contributors and Others.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#######################################################################################

from contextlib import contextmanager
from typing import Generator
from unittest import mock
from unittest.mock import AsyncMock, Mock

from osm_common.wftemporal import WFTemporal
from temporalio.client import Client, WorkflowHandle


@contextmanager
def generate_mock_client() -> Generator[Mock, None, None]:
    """Patch ``Client.connect`` in ``osm_common.wftemporal`` and yield a mock client.

    While the ``with`` block is active, ``Client.connect`` as seen by
    ``osm_common.wftemporal`` returns an autospec'd mock of
    ``temporalio.client.Client``; that same mock is yielded so tests can
    configure its methods and assert on them. The patch is undone when the
    ``with`` block exits.
    """
    # Note the mock instance has to be created before patching for autospec to work
    mock_instance = mock.create_autospec(Client)
    mock_instance.connect.return_value = mock_instance

    with mock.patch("osm_common.wftemporal.Client.connect") as client:
        client.return_value = mock_instance
        # Yield the instance (not the patched ``connect``): it is what code
        # under test receives from ``Client.connect``.
        yield mock_instance


async def test_get_client():
    """A fresh ``WFTemporal`` returns the client produced by ``Client.connect``."""
    # Reset the class-level client cache so nothing cached earlier is reused.
    WFTemporal.clients = {}
    with generate_mock_client() as mocked_client:
        wf = WFTemporal(temporal_api="localhost")
        assert await wf.get_client() == mocked_client


async def test_get_cached_client():
    """A second ``WFTemporal`` reuses the cached client instead of connecting again."""
    WFTemporal.clients = {}

    # First connection: the client comes from the first patched ``Client.connect``.
    with generate_mock_client() as first_mock:
        first = await WFTemporal(temporal_api="localhost").get_client()
        assert first == first_mock

    # Second connection under a different mock: the cached client must win.
    with generate_mock_client() as second_mock:
        second = await WFTemporal(temporal_api="localhost").get_client()
        # We should get the same client as before, not a new one
        assert second == first
        assert second != second_mock


@mock.patch("osm_common.wftemporal.uuid")
async def test_start_workflow_no_id(mock_uuid):
    """Starting a workflow without an ``id`` uses a generated UUID as the id.

    ``uuid`` is patched inside ``osm_common.wftemporal`` so the generated id is
    deterministic and the call forwarded to ``Client.start_workflow`` can be
    checked exactly.
    """
    WFTemporal.clients = {}

    # One literal shared by the stub and the expectation so they cannot drift apart.
    fake_uuid = "01234567-89abc-def-0123-456789abcdef"
    mock_uuid.uuid4.return_value = fake_uuid

    with generate_mock_client() as client:
        client.start_workflow.return_value = "handle"
        temporal = WFTemporal(temporal_api="localhost")
        result = await temporal.start_workflow(
            task_queue="q", workflow_name="workflow", workflow_data="data"
        )
        assert result == "handle"
        # workflow_name/workflow_data are forwarded as Temporal's workflow/arg.
        client.start_workflow.assert_awaited_once_with(
            workflow="workflow",
            arg="data",
            id=fake_uuid,
            task_queue="q",
        )


async def test_start_workflow_with_id():
    """Starting a workflow with an explicit ``id`` forwards that id unchanged.

    Also checks that ``workflow_name``/``workflow_data`` reach
    ``Client.start_workflow`` as ``workflow``/``arg``, and that the value it
    returns (the handle) is handed back to the caller.
    """
    WFTemporal.clients = {}

    with generate_mock_client() as client:
        client.start_workflow.return_value = "handle"
        temporal = WFTemporal(temporal_api="localhost")
        result = await temporal.start_workflow(
            task_queue="q", id="id", workflow_name="workflow", workflow_data="data"
        )
        assert result == "handle"
        client.start_workflow.assert_awaited_once_with(
            workflow="workflow", arg="data", id="id", task_queue="q"
        )


105 async def test_execute_workflow_with_id():
106 WFTemporal
.clients
= {}
108 with
generate_mock_client() as client
:
109 handle
= Mock(WorkflowHandle
)
110 handle
.result
= AsyncMock(return_value
="success")
111 client
.start_workflow
.return_value
= handle
113 temporal
= WFTemporal(temporal_api
="localhost")
114 result
= await temporal
.execute_workflow(
117 workflow_name
="workflow-with-result",
118 workflow_data
="data",
120 assert result
== "success"
121 client
.start_workflow
.assert_awaited_once_with(
122 workflow
="workflow-with-result", arg
="data", id="another id", task_queue
="q"