b8ea2570baa758abfccb139d56a205ead9828f81
[osm/common.git] / osm_common / tests / test_wftemporal.py
1 #######################################################################################
2 # Copyright ETSI Contributors and Others.
3 #
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
7 #
8 # http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
13 # implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16 #######################################################################################
17
18 from contextlib import contextmanager
19 from typing import Generator
20 from unittest import mock
21 from unittest.mock import AsyncMock, Mock
22
23 from osm_common.wftemporal import WFTemporal
24 import pytest
25 from temporalio.client import Client, WorkflowHandle
26
27
@contextmanager
def generate_mock_client() -> Generator[Mock, None, None]:
    """Patch ``osm_common.wftemporal.Client.connect`` to return an autospecced mock.

    Yields the ``Client`` autospec that the patched ``connect`` resolves to, so a
    test can configure it (e.g. ``start_workflow.return_value``) and assert on
    the calls made to it. The patch is removed when the ``with`` block exits.

    Note: ``mock.create_autospec`` on a class yields a ``MagicMock``-based mock,
    not an ``AsyncMock``; only its coroutine methods (such as ``start_workflow``)
    are ``AsyncMock`` instances, which is why ``assert_awaited_*`` works on them.
    """
    # Note the mock instance has to be created before patching for autospec to work
    mock_instance = mock.create_autospec(Client)
    mock_instance.connect.return_value = mock_instance

    with mock.patch("osm_common.wftemporal.Client.connect") as client:
        # Awaiting the patched Client.connect(...) resolves to the autospec mock.
        client.return_value = mock_instance
        yield mock_instance
37
38
@pytest.mark.asyncio
async def test_get_client():
    """get_client() returns the client produced by the patched Client.connect."""
    # Patch the class-level client cache instead of overwriting it, so the test
    # starts from an empty cache and leaves no mock client behind afterwards.
    with mock.patch.object(WFTemporal, "clients", {}):
        with generate_mock_client() as client:
            temporal = WFTemporal(temporal_api="localhost")
            result = await temporal.get_client()
            assert result is client
46
47
@pytest.mark.asyncio
async def test_get_cached_client():
    """A second WFTemporal for the same API reuses the first cached client."""
    # The cache must survive across both mock scopes below, so a single patch
    # of the class-level cache wraps them both; it is restored on exit.
    with mock.patch.object(WFTemporal, "clients", {}):
        with generate_mock_client() as first_mock:
            temporal = WFTemporal(temporal_api="localhost")
            client_1 = await temporal.get_client()
            assert client_1 is first_mock

        with generate_mock_client() as second_mock:
            temporal = WFTemporal(temporal_api="localhost")
            client_2 = await temporal.get_client()
            # We should get the same client as before, not a new one
            assert client_2 is client_1
            assert client_2 is not second_mock
63 assert client_2 != client
64
65
@pytest.mark.asyncio
@mock.patch("osm_common.wftemporal.uuid")
async def test_start_workflow_no_id(mock_uuid):
    """Without an explicit id, start_workflow uses the value from uuid.uuid4()."""
    # Bind the fake id once so the mocked uuid4 and the assertion cannot drift.
    generated_id = "01234567-89abc-def-0123-456789abcdef"
    mock_uuid.uuid4.return_value = generated_id

    # Scope the class-level client cache to this test.
    with mock.patch.object(WFTemporal, "clients", {}):
        with generate_mock_client() as client:
            client.start_workflow.return_value = "handle"
            temporal = WFTemporal(temporal_api="localhost")
            result = await temporal.start_workflow(
                task_queue="q", workflow_name="workflow", workflow_data="data"
            )
            assert result == "handle"
            client.start_workflow.assert_awaited_once_with(
                workflow="workflow",
                arg="data",
                id=generated_id,
                task_queue="q",
            )
86
87
@pytest.mark.asyncio
async def test_start_workflow_with_id():
    """An explicit id is forwarded unchanged to Client.start_workflow."""
    # Scope the class-level client cache to this test.
    with mock.patch.object(WFTemporal, "clients", {}):
        with generate_mock_client() as client:
            client.start_workflow.return_value = "handle"
            temporal = WFTemporal(temporal_api="localhost")
            result = await temporal.start_workflow(
                task_queue="q", id="id", workflow_name="workflow", workflow_data="data"
            )
            assert result == "handle"
            client.start_workflow.assert_awaited_once_with(
                workflow="workflow", arg="data", id="id", task_queue="q"
            )
102
103
@pytest.mark.asyncio
async def test_execute_workflow_with_id():
    """execute_workflow starts the workflow and returns the awaited handle result."""
    # Scope the class-level client cache to this test.
    with mock.patch.object(WFTemporal, "clients", {}):
        with generate_mock_client() as client:
            handle = Mock(WorkflowHandle)
            handle.result = AsyncMock(return_value="success")
            client.start_workflow.return_value = handle

            temporal = WFTemporal(temporal_api="localhost")
            result = await temporal.execute_workflow(
                task_queue="q",
                id="another id",
                workflow_name="workflow-with-result",
                workflow_data="data",
            )
            assert result == "success"
            # The handle's result coroutine must be awaited exactly once.
            handle.result.assert_awaited_once()
            client.start_workflow.assert_awaited_once_with(
                workflow="workflow-with-result",
                arg="data",
                id="another id",
                task_queue="q",
            )