blob: b8ea2570baa758abfccb139d56a205ead9828f81 [file] [log] [blame]
#######################################################################################
# Copyright ETSI Contributors and Others.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#######################################################################################
17
18from contextlib import contextmanager
19from typing import Generator
20from unittest import mock
21from unittest.mock import AsyncMock, Mock
22
23from osm_common.wftemporal import WFTemporal
24import pytest
25from temporalio.client import Client, WorkflowHandle
26
27
@contextmanager
def generate_mock_client() -> Generator[AsyncMock, None, None]:
    """Yield an autospecced ``Client`` mock while ``Client.connect`` is patched.

    Inside the ``with`` block, ``osm_common.wftemporal.Client.connect`` is
    replaced by a mock whose return value is the yielded instance.
    """
    # Build the autospec from the real Client first: it has to exist before
    # the patch below is applied, otherwise autospec does not work.
    fake_client = mock.create_autospec(Client)
    fake_client.connect.return_value = fake_client

    connect_patcher = mock.patch("osm_common.wftemporal.Client.connect")
    with connect_patcher as patched_connect:
        patched_connect.return_value = fake_client
        yield fake_client
37
38
@pytest.mark.asyncio
async def test_get_client():
    """With ``WFTemporal.clients`` emptied, ``get_client`` returns the mocked client."""
    WFTemporal.clients = {}
    with generate_mock_client() as expected_client:
        wf_temporal = WFTemporal(temporal_api="localhost")
        obtained_client = await wf_temporal.get_client()
        assert obtained_client == expected_client
46
47
@pytest.mark.asyncio
async def test_get_cached_client():
    """A second ``WFTemporal`` for the same API gets the first client back.

    Two separate mock clients are generated; the second ``get_client`` call
    must return the client obtained in the first round, not the new mock.
    """
    WFTemporal.clients = {}

    with generate_mock_client() as first_mock:
        first_client = await WFTemporal(temporal_api="localhost").get_client()
        assert first_client == first_mock

    with generate_mock_client() as second_mock:
        second_client = await WFTemporal(temporal_api="localhost").get_client()
        # The earlier client must be handed back rather than a fresh one
        assert second_client == first_client
        assert second_client != second_mock
64
65
@pytest.mark.asyncio
@mock.patch("osm_common.wftemporal.uuid")
async def test_start_workflow_no_id(mock_uuid):
    """Without an explicit id, the client is awaited with the patched uuid4 value as id."""
    WFTemporal.clients = {}

    generated_id = "01234567-89abc-def-0123-456789abcdef"
    mock_uuid.uuid4.return_value = generated_id

    with generate_mock_client() as client:
        client.start_workflow.return_value = "handle"
        wf_temporal = WFTemporal(temporal_api="localhost")
        returned_handle = await wf_temporal.start_workflow(
            task_queue="q",
            workflow_name="workflow",
            workflow_data="data",
        )
        assert returned_handle == "handle"
        client.start_workflow.assert_awaited_once_with(
            workflow="workflow",
            arg="data",
            id=generated_id,
            task_queue="q",
        )
86
87
@pytest.mark.asyncio
async def test_start_workflow_with_id():
    """A caller-supplied id is forwarded unchanged to the client's ``start_workflow``."""
    WFTemporal.clients = {}

    # Keyword arguments the mocked client is expected to be awaited with
    expected_call = {
        "workflow": "workflow",
        "arg": "data",
        "id": "id",
        "task_queue": "q",
    }

    with generate_mock_client() as client:
        client.start_workflow.return_value = "handle"
        wf_temporal = WFTemporal(temporal_api="localhost")
        returned_handle = await wf_temporal.start_workflow(
            task_queue="q",
            id="id",
            workflow_name="workflow",
            workflow_data="data",
        )
        assert returned_handle == "handle"
        client.start_workflow.assert_awaited_once_with(**expected_call)
102
103
@pytest.mark.asyncio
async def test_execute_workflow_with_id():
    """``execute_workflow`` returns the handle's awaited result and starts the workflow once."""
    WFTemporal.clients = {}

    with generate_mock_client() as client:
        # Stand-in workflow handle whose awaited result() yields "success"
        result_coro = AsyncMock(return_value="success")
        workflow_handle = Mock(spec=WorkflowHandle)
        workflow_handle.result = result_coro
        client.start_workflow.return_value = workflow_handle

        wf_temporal = WFTemporal(temporal_api="localhost")
        outcome = await wf_temporal.execute_workflow(
            task_queue="q",
            id="another id",
            workflow_name="workflow-with-result",
            workflow_data="data",
        )
        assert outcome == "success"
        client.start_workflow.assert_awaited_once_with(
            workflow="workflow-with-result",
            arg="data",
            id="another id",
            task_queue="q",
        )