| ####################################################################################### |
| # Copyright ETSI Contributors and Others. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or |
| # implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| from contextlib import contextmanager |
| from typing import Type |
| import unittest |
| from temporalio.client import WorkflowFailureError |
| from temporalio.exceptions import RetryState |
| |
| from osm_lcm.temporal.utils import get_root_cause |
| |
| |
@contextmanager
def validate_workflow_failure_error_type(
    test_case: unittest.TestCase, cause_type: Type[Exception]
):
    """Assert that the wrapped code fails with a WorkflowFailureError of a given cause.

    Intended usage::

        with validate_workflow_failure_error_type(self, SomeError):
            await client.execute_workflow(...)

    The wrapped code must raise ``WorkflowFailureError``. After the ``with``
    body exits, two checks run:

    1. The error's immediate cause must not have ``retry_state`` equal to
       ``RetryState.TIMEOUT``. A timeout is reported with a hint to raise the
       execution timeout.
    2. The ``type`` attribute of the root cause (as returned by
       ``get_root_cause``) must equal ``cause_type.__name__``.

    Args:
        test_case: The running test case, used for its assertion methods.
        cause_type: The type of the exception that caused the workflow to fail.
            Only its ``__name__`` is compared.

    Yields:
        The ``assertRaises`` context. Its ``exception`` attribute holds the
        captured ``WorkflowFailureError`` once the ``with`` body has exited.
    """
    with test_case.assertRaises(WorkflowFailureError) as e_info:
        yield e_info
    exception = e_info.exception
    # The cause's concrete type is not constrained here, and it may not carry
    # ``retry_state``. Treat a missing attribute as "did not time out" instead
    # of crashing with AttributeError.
    retry_state = getattr(exception.cause, "retry_state", None)
    # Use assertNotEqual: the assertNotEquals alias was removed in Python 3.12.
    test_case.assertNotEqual(
        retry_state,
        RetryState.TIMEOUT,
        "Workflow timed out. You may need to increase the execution timeout",
    )
    root_cause = get_root_cause(exception)
    # A root cause without ``type`` now yields a readable assertion failure
    # that shows the offending exception, rather than an AttributeError.
    test_case.assertEqual(
        getattr(root_cause, "type", None),
        cause_type.__name__,
        f"Unexpected root cause: {root_cause!r}",
    )