blob: 871132f599f6c0648081ef866bb33975d02a1830 [file] [log] [blame]
#!/usr/bin/env python3
#
# Copyright 2016 RIFT.IO Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import asyncio
import base64
import concurrent.futures
import io
import logging
import os
import sys
import tornado.testing
import tornado.web
import unittest
import uuid
import xmlrunner
from rift.package import convert
from rift.tasklets.rwlaunchpad import onboard
import rift.test.dts
import gi
gi.require_version('NsdYang', '1.0')
gi.require_version('VnfdYang', '1.0')
from gi.repository import (
NsdYang,
VnfdYang,
)
class RestconfDescriptorHandler(tornado.web.RequestHandler):
    """Fake RESTCONF endpoint that validates and records descriptor requests.

    Each PUT/POST is checked for HTTP Basic credentials, the
    ``application/vnd.yang.data+json`` content type, and a body that the
    serializer for the descriptor type in the URL can parse. Details of the
    request are stored on the ``info`` object supplied to :meth:`initialize`
    so tests can inspect them afterwards.
    """

    # Descriptor type captured from the URL -> serializer used for the body.
    DESC_SERIALIZER_MAP = {
        "nsd": convert.NsdSerializer(),
        "vnfd": convert.VnfdSerializer(),
    }

    class AuthError(Exception):
        """Credentials missing, malformed or wrong (a 401 was already sent)."""

    class ContentTypeError(Exception):
        """Content-type header missing or unsupported (a 415 was already sent)."""

    class RequestBodyError(Exception):
        """Request body is not a valid descriptor (a 400 was already sent)."""

    def initialize(self, log, auth, info):
        """Receive the keyword arguments registered with the application.

        Args:
            log: logger for handler diagnostics.
            auth: expected ``(login, password)`` pair, or None to skip the
                authentication check entirely.
            info: object whose ``last_request_message``,
                ``last_descriptor_type`` and ``last_method`` attributes are
                updated as requests arrive.
        """
        self._auth = auth
        # The superclass has self._log already defined so use a different name
        self._logger = log
        self._info = info
        self._logger.debug('Created restconf descriptor handler')

    def _send_rejection(self, status, msg, authenticate=False):
        """Finish the response with ``status`` and an empty body, then log ``msg``.

        When ``authenticate`` is true a Basic ``WWW-Authenticate`` challenge
        header is added, as HTTP requires for 401 responses.
        """
        self.set_status(status)
        if authenticate:
            self.set_header('WWW-Authenticate', 'Basic realm=Restricted')
        # Clear tornado's output transforms before finishing the error response.
        self._transforms = []
        self.finish()
        self._logger.error(msg)

    def _verify_auth(self):
        """Check the request's HTTP Basic credentials against ``self._auth``.

        Returns None immediately when authentication is disabled.

        Raises:
            RestconfDescriptorHandler.AuthError: the Authorization header is
                missing, not Basic, undecodable, or carries wrong credentials.
                A 401 response has already been sent when this is raised.
        """
        if self._auth is None:
            return None

        auth_header = self.request.headers.get('Authorization')
        if auth_header is None or not auth_header.startswith('Basic '):
            msg = "Missing Authorization header"
            self._send_rejection(401, msg, authenticate=True)
            raise RestconfDescriptorHandler.AuthError(msg)

        try:
            auth_decoded = base64.decodebytes(auth_header[6:].encode('ascii')).decode()
            # Split on the first ':' only: the password itself may contain ':'.
            login, password = auth_decoded.split(':', 1)
        except ValueError as exc:
            # binascii.Error, Unicode{En,De}codeError and a missing ':' (failed
            # unpacking) are all ValueErrors; answer 401 instead of a 500.
            msg = "Malformed Authorization header"
            self._send_rejection(401, msg, authenticate=True)
            raise RestconfDescriptorHandler.AuthError(msg) from exc

        if (login, password) != self._auth:
            msg = "Incorrect username and password in auth header: got {}, expected {}".format(
                (login, password), self._auth
            )
            self._send_rejection(401, msg, authenticate=True)
            raise RestconfDescriptorHandler.AuthError(msg)

    def _verify_content_type_header(self):
        """Require the content type ``application/vnd.yang.data+json``.

        Raises:
            RestconfDescriptorHandler.ContentTypeError: the header is missing
                or has any other value (a 415 response has already been sent).
        """
        content_type_header = self.request.headers.get('content-type')
        if content_type_header is None:
            msg = "Missing content-type header"
            self._send_rejection(415, msg)
            raise RestconfDescriptorHandler.ContentTypeError(msg)

        if content_type_header != "application/vnd.yang.data+json":
            msg = "Unsupported content type: %s" % content_type_header
            self._send_rejection(415, msg)
            raise RestconfDescriptorHandler.ContentTypeError(msg)

    def _verify_headers(self):
        """Run the authentication check, then the content-type check."""
        self._verify_auth()
        self._verify_content_type_header()

    def _verify_request_body(self, descriptor_type):
        """Deserialize the body as ``descriptor_type`` and record the message.

        Raises:
            ValueError: ``descriptor_type`` has no serializer.
            RestconfDescriptorHandler.RequestBodyError: the body could not be
                deserialized (a 400 response has already been sent).
        """
        if descriptor_type not in RestconfDescriptorHandler.DESC_SERIALIZER_MAP:
            raise ValueError("Unsupported descriptor type: %s" % descriptor_type)

        serializer = RestconfDescriptorHandler.DESC_SERIALIZER_MAP[descriptor_type]
        bytes_hdl = io.BytesIO(self.request.body)
        try:
            message = serializer.from_file_hdl(bytes_hdl, ".json")
        except convert.SerializationError as e:
            msg = "Descriptor request body not valid"
            self._send_rejection(400, msg)
            raise RestconfDescriptorHandler.RequestBodyError(msg) from e

        self._info.last_request_message = message
        self._logger.debug("Received a valid descriptor request")

    def _handle_descriptor_request(self, method, descriptor_type):
        """Record ``method``/``descriptor_type``, then validate the request.

        Validation failures have already produced an error response, so
        they are swallowed here; on success a placeholder body is written.
        """
        self._info.last_descriptor_type = descriptor_type
        self._info.last_method = method
        try:
            self._verify_headers()
            self._verify_request_body(descriptor_type)
        except (RestconfDescriptorHandler.AuthError,
                RestconfDescriptorHandler.ContentTypeError,
                RestconfDescriptorHandler.RequestBodyError):
            return None
        self.write("Response doesn't matter?")

    def put(self, descriptor_type):
        """Handle a descriptor update (PUT) request."""
        self._handle_descriptor_request("PUT", descriptor_type)

    def post(self, descriptor_type):
        """Handle a descriptor onboarding (POST) request."""
        self._handle_descriptor_request("POST", descriptor_type)
class HandlerInfo:
    """Scratch record of the most recent request seen by the fake handler.

    Every field starts as None. The handler overwrites the fields as
    requests arrive, so a test can inspect what the onboarder sent.
    """

    def __init__(self):
        # last_request_message: descriptor deserialized from the last valid body.
        # last_descriptor_type: "nsd"/"vnfd" segment taken from the request URL.
        # last_method: HTTP verb of the last request ("PUT" or "POST").
        self.last_request_message = self.last_descriptor_type = self.last_method = None
class OnboardTestCase(tornado.testing.AsyncHTTPTestCase):
    """Exercise onboard.DescriptorOnboarder against a local fake RESTCONF app."""

    # Credentials the fake handler is configured to accept.
    AUTH = ("admin", "admin")

    def setUp(self):
        self._log = logging.getLogger(__file__)
        self._loop = asyncio.get_event_loop()
        self._handler_info = HandlerInfo()
        # get_app() reads self._log and self._handler_info, so they must exist
        # before the superclass setUp builds the application and server.
        super().setUp()
        self._port = self.get_http_port()
        self._onboarder = onboard.DescriptorOnboarder(
            log=self._log, port=self._port
        )

    def get_new_ioloop(self):
        """Serve the test app on tornado's wrapper around the asyncio loop."""
        # Import explicitly: the module only imports tornado.testing and
        # tornado.web, so using tornado.platform.asyncio without importing it
        # depends on some other import loading it as a side effect.
        import tornado.platform.asyncio
        return tornado.platform.asyncio.AsyncIOMainLoop()

    def get_app(self):
        attrs = dict(auth=OnboardTestCase.AUTH, log=self._log, info=self._handler_info)
        return tornado.web.Application([
            (r"/api/config/.*/(nsd|vnfd)", RestconfDescriptorHandler, attrs),
        ])

    @rift.test.dts.async_test
    def test_onboard_nsd(self):
        nsd_msg = NsdYang.YangData_Nsd_NsdCatalog_Nsd(id=str(uuid.uuid4()), name="nsd_name")
        yield from self._loop.run_in_executor(None, self._onboarder.onboard, nsd_msg)
        self.assertEqual(self._handler_info.last_request_message, nsd_msg)
        self.assertEqual(self._handler_info.last_descriptor_type, "nsd")
        self.assertEqual(self._handler_info.last_method, "POST")

    @rift.test.dts.async_test
    def test_update_nsd(self):
        nsd_msg = NsdYang.YangData_Nsd_NsdCatalog_Nsd(id=str(uuid.uuid4()), name="nsd_name")
        yield from self._loop.run_in_executor(None, self._onboarder.update, nsd_msg)
        self.assertEqual(self._handler_info.last_request_message, nsd_msg)
        self.assertEqual(self._handler_info.last_descriptor_type, "nsd")
        self.assertEqual(self._handler_info.last_method, "PUT")

    @rift.test.dts.async_test
    def test_bad_descriptor_type(self):
        # A catalog message is not an onboardable descriptor.
        nsd_msg = NsdYang.YangData_Nsd_NsdCatalog()
        with self.assertRaises(TypeError):
            yield from self._loop.run_in_executor(None, self._onboarder.update, nsd_msg)
        with self.assertRaises(TypeError):
            yield from self._loop.run_in_executor(None, self._onboarder.onboard, nsd_msg)

    @rift.test.dts.async_test
    def test_bad_port(self):
        # Use a port not used by the instantiated server
        # NOTE(review): assumes nothing else is listening on port - 1; this
        # could be flaky on a busy host.
        new_port = self._port - 1
        self._onboarder.port = new_port
        nsd_msg = NsdYang.YangData_Nsd_NsdCatalog_Nsd(id=str(uuid.uuid4()), name="nsd_name")
        with self.assertRaises(onboard.OnboardError):
            yield from self._loop.run_in_executor(None, self._onboarder.onboard, nsd_msg)
        with self.assertRaises(onboard.UpdateError):
            yield from self._loop.run_in_executor(None, self._onboarder.update, nsd_msg)

    @rift.test.dts.async_test
    def test_timeout(self):
        # Set the timeout to something minimal to speed up test
        self._onboarder.timeout = .1
        nsd_msg = NsdYang.YangData_Nsd_NsdCatalog_Nsd(id=str(uuid.uuid4()), name="nsd_name")
        # Force the request to time out by calling onboard() synchronously
        # (not via run_in_executor). The call then presumably blocks the event
        # loop that also serves the test app, so the server cannot respond.
        with self.assertRaises(onboard.OnboardError):
            self._onboarder.onboard(nsd_msg)
        # Same technique for update(): the blocked loop cannot serve the PUT.
        with self.assertRaises(onboard.UpdateError):
            self._onboarder.update(nsd_msg)
def main(argv=None):
    """Parse command-line options and run this module's tests.

    Args:
        argv: argument list without the program name. Defaults to
            ``sys.argv[1:]``, read at call time. Options not recognised here
            are forwarded to ``unittest.main``.

    Recognised options:
        -v/--verbose: set the root logger to DEBUG instead of ERROR.
        -n/--no-runner: use unittest's default runner instead of the XML
            runner. The XML runner writes to the directory named by
            $RIFT_MODULE_TEST.
    """
    if argv is None:
        argv = sys.argv[1:]

    logging.basicConfig(format='TEST %(message)s')

    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('-n', '--no-runner', action='store_true')
    args, unknown = parser.parse_known_args(argv)

    # Build the XML runner only when it will be used, so --no-runner (and
    # --help) work without RIFT_MODULE_TEST being set.
    runner = None
    if not args.no_runner:
        runner = xmlrunner.XMLTestRunner(output=os.environ["RIFT_MODULE_TEST"])

    # Set the global logging level
    logging.getLogger().setLevel(logging.DEBUG if args.verbose else logging.ERROR)

    # The unittest framework requires a program name, so use the name of this
    # file instead (we do not want to have to pass a fake program name to main
    # when this is called from the interpreter).
    unittest.main(argv=[__file__] + unknown + ["-v"], testRunner=runner)
if __name__ == "__main__":
    # Script entry point: run the tests with this process's CLI arguments.
    main(sys.argv[1:])