Test that the exception thrown is as expected
Change-Id: Ie434adb94eebb01b8fd910c2205701c31c0cfa06
Signed-off-by: Daniel Arndt <daniel.arndt@canonical.com>
diff --git a/osm_lcm/temporal/utils.py b/osm_lcm/temporal/utils.py
new file mode 100644
index 0000000..cbec8da
--- /dev/null
+++ b/osm_lcm/temporal/utils.py
@@ -0,0 +1,31 @@
+#######################################################################################
+# Copyright ETSI Contributors and Others.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from temporalio.client import WorkflowFailureError
+
+
+def get_root_cause(exception: WorkflowFailureError) -> BaseException:
+ """Get the root cause of a WorkflowFailureError
+
+ Temporal nests the cause of an exception at each layer as it bubbles up in
+ the framework from activities and child workflows. This function will
+ return the root cause of the exception.
+ """
+ cause = getattr(exception, "cause", None)
+ while cause:
+ exception = cause
+ cause = getattr(exception, "cause", None)
+ return exception
diff --git a/osm_lcm/tests/test_vim_workflows.py b/osm_lcm/tests/test_vim_workflows.py
index 500ce82..b8e5cfa 100644
--- a/osm_lcm/tests/test_vim_workflows.py
+++ b/osm_lcm/tests/test_vim_workflows.py
@@ -14,41 +14,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import asynctest
from datetime import timedelta
from unittest.mock import Mock
-from osm_common.temporal_task_queues.task_queues_mappings import LCM_TASK_QUEUE
-from osm_common.temporal.activities.paas import (
- TestVimConnectivity,
-)
+import asynctest
+from osm_common.temporal.activities.paas import TestVimConnectivity
from osm_common.temporal.activities.vim import (
- UpdateVimState,
- UpdateVimOperationState,
DeleteVimRecord,
+ UpdateVimOperationState,
+ UpdateVimState,
)
-from osm_common.temporal.states import (
- VimState,
- VimOperationState,
-)
-from osm_lcm.temporal.vim_workflows import (
- VimCreateWorkflow,
- VimUpdateWorkflow,
- VimDeleteWorkflow,
- VimCreateWorkflowImpl,
- VimUpdateWorkflowImpl,
- VimDeleteWorkflowImpl,
-)
+from osm_common.temporal.states import VimOperationState, VimState
+from osm_common.temporal_task_queues.task_queues_mappings import LCM_TASK_QUEUE
from parameterized import parameterized_class
from temporalio import activity
-from temporalio.client import WorkflowFailureError
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import Worker
+from osm_lcm.temporal.vim_workflows import (
+ VimCreateWorkflow,
+ VimCreateWorkflowImpl,
+ VimDeleteWorkflow,
+ VimDeleteWorkflowImpl,
+ VimUpdateWorkflow,
+ VimUpdateWorkflowImpl,
+)
+from osm_lcm.tests.utils import validate_workflow_failure_error_type
+
# Prevent the tasks from running indefinitely
-TASK_TIMEOUT = timedelta(seconds=5)
+TASK_TIMEOUT = timedelta(seconds=0.1)
# Prevent the workflow from running indefinitely
-EXECUTION_TIMEOUT = timedelta(seconds=10)
+EXECUTION_TIMEOUT = timedelta(seconds=5)
class TestException(Exception):
@@ -186,7 +182,7 @@
expected_vim_state = [VimState.ENABLED] * retry_policy
expected_vim_op_state = [VimOperationState.COMPLETED]
async with self.env, self.get_worker(activities):
- with self.assertRaises(WorkflowFailureError):
+ with validate_workflow_failure_error_type(self, TestException):
await self.client.execute_workflow(
self.workflow_name,
arg=self.vim_operation_input,
@@ -208,7 +204,7 @@
retry_policy = 3
expected_vim_op_state = [VimOperationState.COMPLETED] * retry_policy
async with self.env, self.get_worker(activities):
- with self.assertRaises(WorkflowFailureError):
+ with validate_workflow_failure_error_type(self, TestException):
await self.client.execute_workflow(
self.workflow_name,
arg=self.vim_operation_input,
@@ -231,7 +227,7 @@
expected_vim_state = [VimState.ERROR]
expected_vim_op_state = [VimOperationState.FAILED]
async with self.env, self.get_worker(activities):
- with self.assertRaises(WorkflowFailureError):
+ with validate_workflow_failure_error_type(self, TestException):
await self.client.execute_workflow(
self.workflow_name,
arg=self.vim_operation_input,
@@ -253,7 +249,7 @@
expected_vim_state = [VimState.ERROR] * retry_policy
expected_vim_op_state = [VimOperationState.FAILED]
async with self.env, self.get_worker(activities):
- with self.assertRaises(WorkflowFailureError):
+ with validate_workflow_failure_error_type(self, TestException):
await self.client.execute_workflow(
self.workflow_name,
arg=self.vim_operation_input,
@@ -273,7 +269,7 @@
expected_vim_state = [VimState.ERROR]
expected_vim_op_state = [VimOperationState.FAILED] * retry_policy
async with self.env, self.get_worker(activities):
- with self.assertRaises(WorkflowFailureError):
+ with validate_workflow_failure_error_type(self, TestException):
await self.client.execute_workflow(
self.workflow_name,
arg=self.vim_operation_input,
@@ -293,7 +289,7 @@
expected_vim_state = [VimState.ERROR] * retry_policy
expected_vim_op_state = [VimOperationState.FAILED] * retry_policy
async with self.env, self.get_worker(activities):
- with self.assertRaises(WorkflowFailureError):
+ with validate_workflow_failure_error_type(self, TestException):
await self.client.execute_workflow(
self.workflow_name,
arg=self.vim_operation_input,
@@ -322,7 +318,7 @@
async def test_vim_delete_exception(self):
activities = [mock_delete_vim_record_raises]
async with self.env, self.get_worker(activities):
- with self.assertRaises(WorkflowFailureError):
+ with validate_workflow_failure_error_type(self, TestException):
result = await self.client.execute_workflow(
VimDeleteWorkflow.__name__,
arg=self.vim_operation_input,
diff --git a/osm_lcm/tests/utils.py b/osm_lcm/tests/utils.py
new file mode 100644
index 0000000..8081fa2
--- /dev/null
+++ b/osm_lcm/tests/utils.py
@@ -0,0 +1,44 @@
+#######################################################################################
+# Copyright ETSI Contributors and Others.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import contextmanager
+from typing import Type
+import unittest
+from temporalio.client import WorkflowFailureError
+from temporalio.exceptions import RetryState
+
+from osm_lcm.temporal.utils import get_root_cause
+
+
+@contextmanager
+def validate_workflow_failure_error_type(
+ test_case: unittest.TestCase, cause_type: Type[Exception]
+):
+ """Validates that the workflow failed with the given cause type.
+
+ args:
+ cause_type: The type of the exception that caused the workflow to fail.
+ """
+ with test_case.assertRaises(WorkflowFailureError) as e_info:
+ yield e_info
+ exception = e_info.exception
+ test_case.assertNotEquals(
+ exception.cause.retry_state, # type: ignore
+ RetryState.TIMEOUT,
+ "Workflow timed out. You may need to increase the execution timeout",
+ )
+ exception = get_root_cause(exception)
+ test_case.assertEqual(exception.type, cause_type.__name__) # type: ignore