From: Daniel Arndt Date: Tue, 4 Jul 2023 21:54:42 +0000 (-0300) Subject: Test that the exception thrown is as expected X-Git-Url: https://osm.etsi.org/gitweb/?a=commitdiff_plain;h=refs%2Fchanges%2F27%2F13627%2F7;p=osm%2FLCM.git Test that the exception thrown is as expected Change-Id: Ie434adb94eebb01b8fd910c2205701c31c0cfa06 Signed-off-by: Daniel Arndt --- diff --git a/osm_lcm/temporal/utils.py b/osm_lcm/temporal/utils.py new file mode 100644 index 00000000..cbec8da4 --- /dev/null +++ b/osm_lcm/temporal/utils.py @@ -0,0 +1,31 @@ +####################################################################################### +# Copyright ETSI Contributors and Others. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from temporalio.client import WorkflowFailureError + + +def get_root_cause(exception: WorkflowFailureError) -> BaseException: + """Get the root cause of a WorkflowFailureError + + Temporal nests the cause of an exception at each layer as it bubbles up in + the framework from activities and child workflows. This function will + return the root cause of the exception. + """ + cause = getattr(exception, "cause", None) + while cause: + exception = cause + cause = getattr(exception, "cause", None) + return exception diff --git a/osm_lcm/tests/test_vim_workflows.py b/osm_lcm/tests/test_vim_workflows.py index 500ce820..b8e5cfa3 100644 --- a/osm_lcm/tests/test_vim_workflows.py +++ b/osm_lcm/tests/test_vim_workflows.py @@ -14,41 +14,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asynctest from datetime import timedelta from unittest.mock import Mock -from osm_common.temporal_task_queues.task_queues_mappings import LCM_TASK_QUEUE -from osm_common.temporal.activities.paas import ( - TestVimConnectivity, -) +import asynctest +from osm_common.temporal.activities.paas import TestVimConnectivity from osm_common.temporal.activities.vim import ( - UpdateVimState, - UpdateVimOperationState, DeleteVimRecord, + UpdateVimOperationState, + UpdateVimState, ) -from osm_common.temporal.states import ( - VimState, - VimOperationState, -) +from osm_common.temporal.states import VimOperationState, VimState +from osm_common.temporal_task_queues.task_queues_mappings import LCM_TASK_QUEUE +from parameterized import parameterized_class +from temporalio import activity +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker + from osm_lcm.temporal.vim_workflows import ( VimCreateWorkflow, - VimUpdateWorkflow, - VimDeleteWorkflow, VimCreateWorkflowImpl, - VimUpdateWorkflowImpl, + VimDeleteWorkflow, VimDeleteWorkflowImpl, + VimUpdateWorkflow, + VimUpdateWorkflowImpl, ) -from parameterized import parameterized_class -from temporalio import activity -from temporalio.client import WorkflowFailureError -from temporalio.testing import WorkflowEnvironment -from temporalio.worker import Worker +from osm_lcm.tests.utils import validate_workflow_failure_error_type # Prevent the tasks from running indefinitely -TASK_TIMEOUT = timedelta(seconds=5) +TASK_TIMEOUT = timedelta(seconds=0.1) # Prevent the workflow from running indefinitely -EXECUTION_TIMEOUT = timedelta(seconds=10) +EXECUTION_TIMEOUT = timedelta(seconds=5) class TestException(Exception): @@ -186,7 +182,7 @@ class TestVimWorkflow(TestVimWorkflowsBase): expected_vim_state = [VimState.ENABLED] * retry_policy expected_vim_op_state = [VimOperationState.COMPLETED] async with self.env, self.get_worker(activities): - with self.assertRaises(WorkflowFailureError): + with validate_workflow_failure_error_type(self, TestException): await self.client.execute_workflow( self.workflow_name, arg=self.vim_operation_input, @@ -208,7 +204,7 @@ class TestVimWorkflow(TestVimWorkflowsBase): retry_policy = 3 expected_vim_op_state = [VimOperationState.COMPLETED] * retry_policy async with self.env, self.get_worker(activities): - with self.assertRaises(WorkflowFailureError): + with validate_workflow_failure_error_type(self, TestException): await self.client.execute_workflow( self.workflow_name, arg=self.vim_operation_input, @@ -231,7 +227,7 @@ class TestVimWorkflow(TestVimWorkflowsBase): expected_vim_state = [VimState.ERROR] expected_vim_op_state = [VimOperationState.FAILED] async with self.env, self.get_worker(activities): - with self.assertRaises(WorkflowFailureError): + with validate_workflow_failure_error_type(self, TestException): await self.client.execute_workflow( self.workflow_name, arg=self.vim_operation_input, @@ -253,7 +249,7 @@ class TestVimWorkflow(TestVimWorkflowsBase): expected_vim_state = [VimState.ERROR] * retry_policy expected_vim_op_state = [VimOperationState.FAILED] async with self.env, self.get_worker(activities): - with self.assertRaises(WorkflowFailureError): + with validate_workflow_failure_error_type(self, TestException): await self.client.execute_workflow( self.workflow_name, arg=self.vim_operation_input, @@ -273,7 +269,7 @@ class TestVimWorkflow(TestVimWorkflowsBase): expected_vim_state = [VimState.ERROR] expected_vim_op_state = [VimOperationState.FAILED] * retry_policy async with self.env, self.get_worker(activities): - with self.assertRaises(WorkflowFailureError): + with validate_workflow_failure_error_type(self, TestException): await self.client.execute_workflow( self.workflow_name, arg=self.vim_operation_input, @@ -293,7 +289,7 @@ class TestVimWorkflow(TestVimWorkflowsBase): expected_vim_state = [VimState.ERROR] * retry_policy expected_vim_op_state = [VimOperationState.FAILED] * retry_policy async with self.env, self.get_worker(activities): - with self.assertRaises(WorkflowFailureError): + with validate_workflow_failure_error_type(self, TestException): await self.client.execute_workflow( self.workflow_name, arg=self.vim_operation_input, @@ -322,7 +318,7 @@ class TestVimDeleteWorkflow(TestVimWorkflowsBase): async def test_vim_delete_exception(self): activities = [mock_delete_vim_record_raises] async with self.env, self.get_worker(activities): - with self.assertRaises(WorkflowFailureError): + with validate_workflow_failure_error_type(self, TestException): result = await self.client.execute_workflow( VimDeleteWorkflow.__name__, arg=self.vim_operation_input, diff --git a/osm_lcm/tests/utils.py b/osm_lcm/tests/utils.py new file mode 100644 index 00000000..8081fa21 --- /dev/null +++ b/osm_lcm/tests/utils.py @@ -0,0 +1,44 @@ +####################################################################################### +# Copyright ETSI Contributors and Others. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import contextmanager +from typing import Type +import unittest +from temporalio.client import WorkflowFailureError +from temporalio.exceptions import RetryState + +from osm_lcm.temporal.utils import get_root_cause + + +@contextmanager +def validate_workflow_failure_error_type( + test_case: unittest.TestCase, cause_type: Type[Exception] +): + """Validates that the workflow failed with the given cause type. + + args: + cause_type: The type of the exception that caused the workflow to fail. + """ + with test_case.assertRaises(WorkflowFailureError) as e_info: + yield e_info + exception = e_info.exception + test_case.assertNotEquals( + exception.cause.retry_state, # type: ignore + RetryState.TIMEOUT, + "Workflow timed out. You may need to increase the execution timeout", + ) + exception = get_root_cause(exception) + test_case.assertEqual(exception.type, cause_type.__name__) # type: ignore