blob: b8ea2570baa758abfccb139d56a205ead9828f81 [file] [log] [blame]
#######################################################################################
# Copyright ETSI Contributors and Others.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#######################################################################################
from contextlib import contextmanager
from typing import Generator
from unittest import mock
from unittest.mock import AsyncMock, Mock
from osm_common.wftemporal import WFTemporal
import pytest
from temporalio.client import Client, WorkflowHandle
@contextmanager
def generate_mock_client() -> Generator[mock.MagicMock, None, None]:
    """Patch ``Client.connect`` so it returns an autospecced mock client.

    The patch is only active inside the ``with`` block that uses this helper.

    Yields:
        The mock standing in for the temporal ``Client``. It is autospecced
        from ``Client``, so async methods such as ``start_workflow`` can be
        checked with ``assert_awaited_*``.
    """
    # The mock has to be created before patching: autospec inspects
    # Client.connect, and once patched it would see the patch's mock
    # instead of the real method.
    mock_instance = mock.create_autospec(Client)
    mock_instance.connect.return_value = mock_instance
    with mock.patch("osm_common.wftemporal.Client.connect") as connect:
        connect.return_value = mock_instance
        yield mock_instance
@pytest.mark.asyncio
async def test_get_client():
    """Check that get_client returns the client produced by Client.connect.

    ``WFTemporal.clients`` is replaced by an empty dict only for the duration
    of the test. No previously cached client is reused, and nothing cached
    here leaks into later tests, because patch.object restores the original.
    """
    with mock.patch.object(WFTemporal, "clients", {}):
        with generate_mock_client() as client:
            temporal = WFTemporal(temporal_api="localhost")
            result = await temporal.get_client()
            # Identity, not mere equality: the very same mock must come back.
            assert result is client
@pytest.mark.asyncio
async def test_get_cached_client():
    """Check that a second WFTemporal for the same endpoint reuses the cached client.

    The class-level ``WFTemporal.clients`` cache is patched with one fresh dict
    that spans both lookups. The original cache is restored when the test ends.
    """
    with mock.patch.object(WFTemporal, "clients", {}):
        with generate_mock_client() as first_mock:
            temporal = WFTemporal(temporal_api="localhost")
            client_1 = await temporal.get_client()
            assert client_1 is first_mock
        with generate_mock_client() as second_mock:
            temporal = WFTemporal(temporal_api="localhost")
            client_2 = await temporal.get_client()
            # The client cached by the first call must be returned, not the
            # fresh mock that the re-patched Client.connect would now produce.
            assert client_2 is client_1
            assert client_2 is not second_mock
@pytest.mark.asyncio
@mock.patch("osm_common.wftemporal.uuid")
async def test_start_workflow_no_id(mock_uuid):
    """Check that start_workflow uses uuid4's value as the id when none is given.

    The ``WFTemporal.clients`` cache is isolated for the duration of the test
    and restored afterwards.
    """
    generated_id = "01234567-89abc-def-0123-456789abcdef"
    mock_uuid.uuid4.return_value = generated_id
    with mock.patch.object(WFTemporal, "clients", {}):
        with generate_mock_client() as client:
            client.start_workflow.return_value = "handle"
            temporal = WFTemporal(temporal_api="localhost")
            result = await temporal.start_workflow(
                task_queue="q", workflow_name="workflow", workflow_data="data"
            )
            assert result == "handle"
            client.start_workflow.assert_awaited_once_with(
                workflow="workflow",
                arg="data",
                id=generated_id,
                task_queue="q",
            )
@pytest.mark.asyncio
async def test_start_workflow_with_id():
    """Check that an explicit workflow id is forwarded unchanged to Client.start_workflow.

    The ``WFTemporal.clients`` cache is isolated for the duration of the test
    and restored afterwards.
    """
    with mock.patch.object(WFTemporal, "clients", {}):
        with generate_mock_client() as client:
            client.start_workflow.return_value = "handle"
            temporal = WFTemporal(temporal_api="localhost")
            result = await temporal.start_workflow(
                task_queue="q",
                id="id",
                workflow_name="workflow",
                workflow_data="data",
            )
            assert result == "handle"
            client.start_workflow.assert_awaited_once_with(
                workflow="workflow", arg="data", id="id", task_queue="q"
            )
@pytest.mark.asyncio
async def test_execute_workflow_with_id():
    """Check that execute_workflow starts the workflow and returns its awaited result.

    The ``WFTemporal.clients`` cache is isolated for the duration of the test
    and restored afterwards.
    """
    with mock.patch.object(WFTemporal, "clients", {}):
        with generate_mock_client() as client:
            handle = Mock(WorkflowHandle)
            handle.result = AsyncMock(return_value="success")
            client.start_workflow.return_value = handle
            temporal = WFTemporal(temporal_api="localhost")
            result = await temporal.execute_workflow(
                task_queue="q",
                id="another id",
                workflow_name="workflow-with-result",
                workflow_data="data",
            )
            assert result == "success"
            # The handle's result coroutine must be awaited exactly once.
            handle.result.assert_awaited_once()
            client.start_workflow.assert_awaited_once_with(
                workflow="workflow-with-result",
                arg="data",
                id="another id",
                task_queue="q",
            )