RIFT OSM R1 Initial Submission
[osm/SO.git] / rwlaunchpad / plugins / rwlaunchpadtasklet / test / utest_onboard.py
diff --git a/rwlaunchpad/plugins/rwlaunchpadtasklet/test/utest_onboard.py b/rwlaunchpad/plugins/rwlaunchpadtasklet/test/utest_onboard.py
new file mode 100755 (executable)
index 0000000..871132f
--- /dev/null
@@ -0,0 +1,294 @@
+#!/usr/bin/env python3
+
+# 
+#   Copyright 2016 RIFT.IO Inc
+#
+#   Licensed under the Apache License, Version 2.0 (the "License");
+#   you may not use this file except in compliance with the License.
+#   You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+#   Unless required by applicable law or agreed to in writing, software
+#   distributed under the License is distributed on an "AS IS" BASIS,
+#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#   See the License for the specific language governing permissions and
+#   limitations under the License.
+#
+
+
+import argparse
+import asyncio
+import base64
+import concurrent.futures
+import io
+import logging
+import os
+import sys
+import tornado.testing
+import tornado.web
+import unittest
+import uuid
+import xmlrunner
+
+from rift.package import convert
+from rift.tasklets.rwlaunchpad import onboard
+import rift.test.dts
+
+import gi
+gi.require_version('NsdYang', '1.0')
+gi.require_version('VnfdYang', '1.0')
+
+from gi.repository import (
+        NsdYang,
+        VnfdYang,
+        )
+
+
+class RestconfDescriptorHandler(tornado.web.RequestHandler):
+    DESC_SERIALIZER_MAP = {
+            "nsd": convert.NsdSerializer(),
+            "vnfd": convert.VnfdSerializer(),
+            }
+
+    class AuthError(Exception):
+        pass
+
+
+    class ContentTypeError(Exception):
+        pass
+
+
+    class RequestBodyError(Exception):
+        pass
+
+
+    def initialize(self, log, auth, info):
+        self._auth = auth
+        # The superclass has self._log already defined so use a different name
+        self._logger = log
+        self._info = info
+        self._logger.debug('Created restconf descriptor handler')
+
+    def _verify_auth(self):
+        if self._auth is None:
+            return None
+
+        auth_header = self.request.headers.get('Authorization')
+        if auth_header is None or not auth_header.startswith('Basic '):
+            self.set_status(401)
+            self.set_header('WWW-Authenticate', 'Basic realm=Restricted')
+            self._transforms = []
+            self.finish()
+
+            msg = "Missing Authorization header"
+            self._logger.error(msg)
+            raise RestconfDescriptorHandler.AuthError(msg)
+
+        auth_header = auth_header.encode('ascii')
+        auth_decoded = base64.decodebytes(auth_header[6:]).decode()
+        login, password = auth_decoded.split(':', 2)
+        login = login
+        password = password
+        is_auth = ((login, password) == self._auth)
+
+        if not is_auth:
+            self.set_status(401)
+            self.set_header('WWW-Authenticate', 'Basic realm=Restricted')
+            self._transforms = []
+            self.finish()
+
+            msg = "Incorrect username and password in auth header: got {}, expected {}".format(
+                    (login, password), self._auth
+                    )
+            self._logger.error(msg)
+            raise RestconfDescriptorHandler.AuthError(msg)
+
+    def _verify_content_type_header(self):
+        content_type_header = self.request.headers.get('content-type')
+        if content_type_header is None:
+            self.set_status(415)
+            self._transforms = []
+            self.finish()
+
+            msg = "Missing content-type header"
+            self._logger.error(msg)
+            raise RestconfDescriptorHandler.ContentTypeError(msg)
+
+        if content_type_header != "application/vnd.yang.data+json":
+            self.set_status(415)
+            self._transforms = []
+            self.finish()
+
+            msg = "Unsupported content type: %s" % content_type_header
+            self._logger.error(msg)
+            raise RestconfDescriptorHandler.ContentTypeError(msg)
+
+    def _verify_headers(self):
+        self._verify_auth()
+        self._verify_content_type_header()
+
+    def _verify_request_body(self, descriptor_type):
+        if descriptor_type not in RestconfDescriptorHandler.DESC_SERIALIZER_MAP:
+            raise ValueError("Unsupported descriptor type: %s" % descriptor_type)
+
+        body = self.request.body
+        bytes_hdl = io.BytesIO(body)
+
+        serializer = RestconfDescriptorHandler.DESC_SERIALIZER_MAP[descriptor_type]
+
+        try:
+            message = serializer.from_file_hdl(bytes_hdl, ".json")
+        except convert.SerializationError as e:
+            self.set_status(400)
+            self._transforms = []
+            self.finish()
+
+            msg = "Descriptor request body not valid"
+            self._logger.error(msg)
+            raise RestconfDescriptorHandler.RequestBodyError() from e
+
+        self._info.last_request_message = message
+
+        self._logger.debug("Received a valid descriptor request")
+
+    def put(self, descriptor_type):
+        self._info.last_descriptor_type = descriptor_type
+        self._info.last_method = "PUT"
+
+        try:
+            self._verify_headers()
+        except (RestconfDescriptorHandler.AuthError,
+                RestconfDescriptorHandler.ContentTypeError):
+            return None
+
+        try:
+            self._verify_request_body(descriptor_type)
+        except RestconfDescriptorHandler.RequestBodyError:
+            return None
+
+        self.write("Response doesn't matter?")
+
+    def post(self, descriptor_type):
+        self._info.last_descriptor_type = descriptor_type
+        self._info.last_method = "POST"
+
+        try:
+            self._verify_headers()
+        except (RestconfDescriptorHandler.AuthError,
+                RestconfDescriptorHandler.ContentTypeError):
+            return None
+
+        try:
+            self._verify_request_body(descriptor_type)
+        except RestconfDescriptorHandler.RequestBodyError:
+            return None
+
+        self.write("Response doesn't matter?")
+
+
+class HandlerInfo(object):
+    def __init__(self):
+        self.last_request_message = None
+        self.last_descriptor_type = None
+        self.last_method = None
+
+
+class OnboardTestCase(tornado.testing.AsyncHTTPTestCase):
+    AUTH = ("admin", "admin")
+    def setUp(self):
+        self._log = logging.getLogger(__file__)
+        self._loop = asyncio.get_event_loop()
+
+        self._handler_info = HandlerInfo()
+        super().setUp()
+        self._port = self.get_http_port()
+        self._onboarder = onboard.DescriptorOnboarder(
+                log=self._log, port=self._port
+                )
+
+    def get_new_ioloop(self):
+        return tornado.platform.asyncio.AsyncIOMainLoop()
+
+    def get_app(self):
+        attrs = dict(auth=OnboardTestCase.AUTH, log=self._log, info=self._handler_info)
+        return tornado.web.Application([
+            (r"/api/config/.*/(nsd|vnfd)", RestconfDescriptorHandler, attrs),
+            ])
+
+    @rift.test.dts.async_test
+    def test_onboard_nsd(self):
+        nsd_msg = NsdYang.YangData_Nsd_NsdCatalog_Nsd(id=str(uuid.uuid4()), name="nsd_name")
+        yield from self._loop.run_in_executor(None, self._onboarder.onboard, nsd_msg)
+        self.assertEqual(self._handler_info.last_request_message, nsd_msg)
+        self.assertEqual(self._handler_info.last_descriptor_type, "nsd")
+        self.assertEqual(self._handler_info.last_method, "POST")
+
+    @rift.test.dts.async_test
+    def test_update_nsd(self):
+        nsd_msg = NsdYang.YangData_Nsd_NsdCatalog_Nsd(id=str(uuid.uuid4()), name="nsd_name")
+        yield from self._loop.run_in_executor(None, self._onboarder.update, nsd_msg)
+        self.assertEqual(self._handler_info.last_request_message, nsd_msg)
+        self.assertEqual(self._handler_info.last_descriptor_type, "nsd")
+        self.assertEqual(self._handler_info.last_method, "PUT")
+
+    @rift.test.dts.async_test
+    def test_bad_descriptor_type(self):
+        nsd_msg = NsdYang.YangData_Nsd_NsdCatalog()
+        with self.assertRaises(TypeError):
+            yield from self._loop.run_in_executor(None, self._onboarder.update, nsd_msg)
+
+        with self.assertRaises(TypeError):
+            yield from self._loop.run_in_executor(None, self._onboarder.onboard, nsd_msg)
+
+    @rift.test.dts.async_test
+    def test_bad_port(self):
+        # Use a port not used by the instantiated server
+        new_port = self._port - 1
+        self._onboarder.port = new_port
+        nsd_msg = NsdYang.YangData_Nsd_NsdCatalog_Nsd(id=str(uuid.uuid4()), name="nsd_name")
+
+        with self.assertRaises(onboard.OnboardError):
+            yield from self._loop.run_in_executor(None, self._onboarder.onboard, nsd_msg)
+
+        with self.assertRaises(onboard.UpdateError):
+            yield from self._loop.run_in_executor(None, self._onboarder.update, nsd_msg)
+
+    @rift.test.dts.async_test
+    def test_timeout(self):
+        # Set the timeout to something minimal to speed up test
+        self._onboarder.timeout = .1
+
+        nsd_msg = NsdYang.YangData_Nsd_NsdCatalog_Nsd(id=str(uuid.uuid4()), name="nsd_name")
+
+        # Force the request to timeout by running the call synchronously so the
+        with self.assertRaises(onboard.OnboardError):
+            self._onboarder.onboard(nsd_msg)
+
+        # Force the request to timeout by running the call synchronously so the
+        with self.assertRaises(onboard.UpdateError):
+            self._onboarder.update(nsd_msg)
+
+
+def main(argv=sys.argv[1:]):
+    logging.basicConfig(format='TEST %(message)s')
+
+    runner = xmlrunner.XMLTestRunner(output=os.environ["RIFT_MODULE_TEST"])
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-v', '--verbose', action='store_true')
+    parser.add_argument('-n', '--no-runner', action='store_true')
+
+    args, unknown = parser.parse_known_args(argv)
+    if args.no_runner:
+        runner = None
+
+    # Set the global logging level
+    logging.getLogger().setLevel(logging.DEBUG if args.verbose else logging.ERROR)
+
+    # The unittest framework requires a program name, so use the name of this
+    # file instead (we do not want to have to pass a fake program name to main
+    # when this is called from the interpreter).
+    unittest.main(argv=[__file__] + unknown + ["-v"], testRunner=runner)
+
+if __name__ == '__main__':
+    main()