blob: 1b104a3972f5e2a15be38e7e5aa0f4f9f2ced123 [file] [log] [blame]
#######################################################################################
# Copyright ETSI Contributors and Others.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#######################################################################################
from contextlib import contextmanager
from typing import Generator
from unittest import mock
from unittest.mock import AsyncMock, Mock
from osm_common.wftemporal import WFTemporal
import pytest
from temporalio.client import Client, WorkflowHandle
@contextmanager
def generate_mock_client() -> Generator[mock.MagicMock, None, None]:
    """Patch ``Client.connect`` so ``WFTemporal`` receives an autospecced mock client.

    Yields:
        The autospecced ``Client`` mock that the patched ``Client.connect``
        returns. Tests configure its methods (e.g. ``start_workflow``) to drive
        ``WFTemporal`` and then assert on the calls it received.

    On exit, ``WFTemporal._client`` is reset to ``None``. The tests in this
    module treat that attribute as the cached client, so clearing it stops the
    mock from outliving the patch and leaking into later tests.
    """
    # Note the mock instance has to be created before patching for autospec to work
    mock_instance = mock.create_autospec(Client)
    mock_instance.connect.return_value = mock_instance
    try:
        with mock.patch("osm_common.wftemporal.Client.connect") as client:
            client.return_value = mock_instance
            yield mock_instance
    finally:
        # Drop any client cached while the patch was active.
        WFTemporal._client = None
@pytest.mark.asyncio
@mock.patch("osm_common.wftemporal.uuid")
async def test_start_workflow_no_id(mock_uuid):
    """Without an explicit id, the id comes from ``wftemporal.uuid.uuid4``.

    The ``uuid`` module seen by ``osm_common.wftemporal`` is patched to return
    a fixed value. The client's ``start_workflow`` must then receive that value
    as ``id``, together with the workflow name, payload and task queue.
    """
    generated_id = "01234567-89abc-def-0123-456789abcdef"
    WFTemporal._client = None
    mock_uuid.uuid4.return_value = generated_id
    with generate_mock_client() as fake_client:
        fake_client.start_workflow.return_value = "handle"
        handle = await WFTemporal().start_workflow(
            task_queue="q",
            workflow_name="workflow",
            workflow_data="data",
        )
        assert handle == "handle"
        fake_client.start_workflow.assert_awaited_once_with(
            workflow="workflow",
            arg="data",
            id=generated_id,
            task_queue="q",
        )
@pytest.mark.asyncio
async def test_start_workflow_with_id():
    """An explicit ``id`` is forwarded unchanged to ``Client.start_workflow``.

    The value returned by the client's ``start_workflow`` must be handed back
    to the caller as-is.
    """
    expected_call = dict(workflow="workflow", arg="data", id="id", task_queue="q")
    WFTemporal._client = None
    with generate_mock_client() as fake_client:
        fake_client.start_workflow.return_value = "handle"
        wf = WFTemporal()
        handle = await wf.start_workflow(
            task_queue="q",
            id="id",
            workflow_name="workflow",
            workflow_data="data",
        )
        assert handle == "handle"
        fake_client.start_workflow.assert_awaited_once_with(**expected_call)
@pytest.mark.asyncio
async def test_execute_workflow_with_id():
    """``execute_workflow`` starts the workflow and returns the awaited result.

    The mocked ``WorkflowHandle.result`` resolves to ``"success"``. That value
    must come back from ``execute_workflow``, and the workflow must be started
    exactly once with the given id.
    """
    WFTemporal._client = None
    with generate_mock_client() as fake_client:
        wf_handle = Mock(spec=WorkflowHandle)
        wf_handle.result = AsyncMock(return_value="success")
        fake_client.start_workflow.return_value = wf_handle

        outcome = await WFTemporal().execute_workflow(
            task_queue="q",
            id="another id",
            workflow_name="workflow-with-result",
            workflow_data="data",
        )

        assert outcome == "success"
        fake_client.start_workflow.assert_awaited_once_with(
            workflow="workflow-with-result",
            arg="data",
            id="another id",
            task_queue="q",
        )
@pytest.mark.asyncio
async def test_client_cache():
    """Separate ``WFTemporal`` instances reuse one cached Temporal client.

    Two workflows are executed through two ``WFTemporal`` objects. Both
    executions must reach the same mocked client, and ``Client.connect``
    (patched by ``generate_mock_client``) must be awaited only once. Otherwise
    the test would pass even if every instance reconnected.
    """
    WFTemporal._client = None
    with generate_mock_client() as client:
        handle = Mock(WorkflowHandle)
        handle.result = AsyncMock(return_value="success")
        client.start_workflow.return_value = handle

        temporal = WFTemporal()
        result = await temporal.execute_workflow(
            task_queue="q",
            id="another id",
            workflow_name="workflow-with-result",
            workflow_data="data",
        )
        assert result == "success"
        client.start_workflow.assert_awaited_once_with(
            workflow="workflow-with-result", arg="data", id="another id", task_queue="q"
        )

        # A fresh instance should go through the already-connected client.
        temporal = WFTemporal()
        result = await temporal.execute_workflow(
            task_queue="q",
            id="yet another id",
            workflow_name="workflow-with-result",
            workflow_data="more data",
        )
        assert result == "success"
        client.start_workflow.assert_awaited_with(
            workflow="workflow-with-result",
            arg="more data",
            id="yet another id",
            task_queue="q",
        )
        assert client.start_workflow.await_count == 2
        # Inside the context, ``Client.connect`` is the mock installed by
        # generate_mock_client. A single await proves the client was cached
        # rather than re-created for the second instance.
        Client.connect.assert_awaited_once()