#   Copyright 2016 RIFT.IO Inc
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
import argparse
import asyncio
import base64
import concurrent.futures
import io
import logging
import os
import sys
import unittest
import uuid

import tornado.platform.asyncio
import tornado.testing
import tornado.web
import xmlrunner

from rift.package import convert
from rift.tasklets.rwlaunchpad import onboard
import rift.test.dts

import gi
gi.require_version('ProjectNsdYang', '1.0')
gi.require_version('ProjectVnfdYang', '1.0')

from gi.repository import (
    ProjectNsdYang as NsdYang,
    ProjectVnfdYang as VnfdYang,
)
class RestconfDescriptorHandler(tornado.web.RequestHandler):
    """Fake RESTCONF endpoint that records the descriptors it receives.

    Every PUT/POST is checked, in order, for HTTP Basic credentials, the
    ``application/vnd.yang.data+json`` content type, and a body that the
    serializer for the requested descriptor type can parse.  The HTTP
    method, the descriptor type and (on success) the parsed message are
    stored on the shared ``info`` object so tests can inspect what the
    client sent.
    """

    DESC_SERIALIZER_MAP = {
        "nsd": convert.NsdSerializer(),
        "vnfd": convert.VnfdSerializer(),
    }

    # Media type the onboarding client must send.
    EXPECTED_CONTENT_TYPE = "application/vnd.yang.data+json"

    class AuthError(Exception):
        """The Basic auth header was missing, malformed or wrong."""

    class ContentTypeError(Exception):
        """The content-type header was missing or unsupported."""

    class RequestBodyError(Exception):
        """The request body could not be deserialized."""

    def initialize(self, log, auth, info):
        """Store the per-application arguments passed in by tornado.

        Args:
            log: Logger used by this handler.
            auth: ``(login, password)`` tuple the Basic auth header must
                match, or ``None`` to skip authentication.
            info: Object whose ``last_*`` attributes record the most recent
                request.
        """
        # The superclass has self._log already defined so use a different name
        self._logger = log
        self._auth = auth
        self._info = info
        self._logger.debug('Created restconf descriptor handler')

    def _finish_with_error(self, status):
        """Finish the request immediately with *status* and an empty body."""
        self.set_status(status)
        # Skip output transforms (e.g. gzip) for the empty error response.
        # NOTE(review): confirm this workaround is still needed for the
        # tornado version in use.
        self._transforms = []
        self.finish()

    def _unauthorized(self, msg):
        """Send a 401 Basic challenge, log *msg* and return an AuthError.

        The caller raises the returned exception, which keeps the control
        flow visible at the call site.
        """
        self.set_header('WWW-Authenticate', 'Basic realm=Restricted')
        self._finish_with_error(401)
        self._logger.error(msg)
        return RestconfDescriptorHandler.AuthError(msg)

    def _verify_auth(self):
        """Validate the HTTP Basic ``Authorization`` header.

        Authentication is skipped when ``self._auth`` is ``None``.

        Raises:
            RestconfDescriptorHandler.AuthError: the header is missing, not
                decodable, or carries credentials other than ``self._auth``.
                A 401 response with a ``WWW-Authenticate`` challenge has
                already been sent when this is raised.
        """
        if self._auth is None:
            return None

        auth_header = self.request.headers.get('Authorization')
        if auth_header is None or not auth_header.startswith('Basic '):
            raise self._unauthorized("Missing Authorization header")

        try:
            auth_decoded = base64.decodebytes(
                auth_header[6:].encode('ascii')
            ).decode()
        except ValueError as e:
            # binascii.Error (bad base64) and the Unicode encode/decode
            # errors are all ValueError subclasses.
            raise self._unauthorized("Malformed Authorization header") from e

        # RFC 7617: the user-id may not contain ':' but the password may, so
        # split on the first colon only. (split(':', 2) could return three
        # parts and crash the two-name unpacking.)
        login, sep, password = auth_decoded.partition(':')
        if not sep:
            raise self._unauthorized("Malformed Authorization header")

        if (login, password) != self._auth:
            msg = "Incorrect username and password in auth header: got {}, expected {}".format(
                (login, password), self._auth
            )
            raise self._unauthorized(msg)

    def _verify_content_type_header(self):
        """Require the ``application/vnd.yang.data+json`` content type.

        Media-type parameters such as ``; charset=utf-8`` are ignored and the
        type itself is compared case-insensitively.

        Raises:
            RestconfDescriptorHandler.ContentTypeError: the header is missing
                or names another media type; a 415 response has been sent.
        """
        content_type_header = self.request.headers.get('content-type')
        if content_type_header is None:
            self._finish_with_error(415)

            msg = "Missing content-type header"
            self._logger.error(msg)
            raise RestconfDescriptorHandler.ContentTypeError(msg)

        media_type = content_type_header.split(';', 1)[0].strip().lower()
        if media_type != self.EXPECTED_CONTENT_TYPE:
            self._finish_with_error(415)

            msg = "Unsupported content type: %s" % content_type_header
            self._logger.error(msg)
            raise RestconfDescriptorHandler.ContentTypeError(msg)

    def _verify_headers(self):
        """Run the authentication check, then the content-type check."""
        self._verify_auth()
        self._verify_content_type_header()

    def _verify_request_body(self, descriptor_type):
        """Deserialize the JSON request body and record it on ``self._info``.

        Args:
            descriptor_type: Key into ``DESC_SERIALIZER_MAP`` ("nsd"/"vnfd").

        Raises:
            ValueError: *descriptor_type* has no registered serializer.
            RestconfDescriptorHandler.RequestBodyError: the body could not be
                deserialized; a 400 response has been sent.
        """
        if descriptor_type not in RestconfDescriptorHandler.DESC_SERIALIZER_MAP:
            raise ValueError("Unsupported descriptor type: %s" % descriptor_type)

        serializer = RestconfDescriptorHandler.DESC_SERIALIZER_MAP[descriptor_type]
        bytes_hdl = io.BytesIO(self.request.body)

        try:
            message = serializer.from_file_hdl(bytes_hdl, ".json")
        except convert.SerializationError as e:
            self._finish_with_error(400)

            msg = "Descriptor request body not valid"
            self._logger.error(msg)
            raise RestconfDescriptorHandler.RequestBodyError(msg) from e

        self._info.last_request_message = message

        self._logger.debug("Received a valid descriptor request")

    def _handle_descriptor_request(self, method, descriptor_type):
        """Record the request, validate it, and acknowledge it if valid.

        Args:
            method: HTTP method name stored in ``info.last_method``.
            descriptor_type: Descriptor type captured from the URL.
        """
        self._info.last_descriptor_type = descriptor_type
        self._info.last_method = method

        try:
            self._verify_headers()
            self._verify_request_body(descriptor_type)
        except (RestconfDescriptorHandler.AuthError,
                RestconfDescriptorHandler.ContentTypeError,
                RestconfDescriptorHandler.RequestBodyError):
            # The failing check has already sent the error response.
            return None

        self.write("Response doesn't matter?")

    def put(self, descriptor_type):
        """Handle a descriptor update request."""
        self._handle_descriptor_request("PUT", descriptor_type)

    def post(self, descriptor_type):
        """Handle a descriptor onboarding request."""
        self._handle_descriptor_request("POST", descriptor_type)
class HandlerInfo(object):
    """Record of the most recent request seen by RestconfDescriptorHandler.

    Tests read these attributes after driving the onboarder.  Each one stays
    ``None`` until a request sets it.
    """

    def __init__(self):
        # Descriptor message deserialized from the last valid request body.
        self.last_request_message = None
        # Descriptor type captured from the request URL (e.g. "nsd").
        self.last_descriptor_type = None
        # HTTP method of the last request ("PUT" or "POST").
        self.last_method = None

    def __repr__(self):
        return "{}(last_method={!r}, last_descriptor_type={!r}, last_request_message={!r})".format(
            type(self).__name__,
            self.last_method,
            self.last_descriptor_type,
            self.last_request_message,
        )
class OnboardTestCase(tornado.testing.AsyncHTTPTestCase):
    """Drive onboard.DescriptorOnboarder against RestconfDescriptorHandler."""

    AUTH = ("admin", "admin")

    def setUp(self):
        self._log = logging.getLogger(__file__)
        self._loop = asyncio.get_event_loop()

        # get_app() is invoked by the superclass setUp and reads these
        # attributes, so they must exist before super().setUp() runs.
        self._handler_info = HandlerInfo()
        super().setUp()

        # The HTTP server port is only known once the superclass setUp ran.
        self._port = self.get_http_port()
        self._onboarder = onboard.DescriptorOnboarder(
            log=self._log, port=self._port
        )

    def get_new_ioloop(self):
        return tornado.platform.asyncio.AsyncIOMainLoop()

    def get_app(self):
        attrs = dict(auth=OnboardTestCase.AUTH, log=self._log, info=self._handler_info)
        return tornado.web.Application([
            (r"/api/config/.*/(nsd|vnfd)", RestconfDescriptorHandler, attrs),
        ])

    @rift.test.dts.async_test
    def test_onboard_nsd(self):
        """onboard() sends the NSD with POST."""
        nsd_msg = NsdYang.YangData_RwProject_Project_NsdCatalog_Nsd(id=str(uuid.uuid4()), name="nsd_name")
        yield from self._loop.run_in_executor(None, self._onboarder.onboard, nsd_msg)
        self.assertEqual(self._handler_info.last_request_message, nsd_msg)
        self.assertEqual(self._handler_info.last_descriptor_type, "nsd")
        self.assertEqual(self._handler_info.last_method, "POST")

    @rift.test.dts.async_test
    def test_update_nsd(self):
        """update() sends the NSD with PUT."""
        nsd_msg = NsdYang.YangData_RwProject_Project_NsdCatalog_Nsd(id=str(uuid.uuid4()), name="nsd_name")
        yield from self._loop.run_in_executor(None, self._onboarder.update, nsd_msg)
        self.assertEqual(self._handler_info.last_request_message, nsd_msg)
        self.assertEqual(self._handler_info.last_descriptor_type, "nsd")
        self.assertEqual(self._handler_info.last_method, "PUT")

    @rift.test.dts.async_test
    def test_bad_descriptor_type(self):
        """A catalog object instead of a single descriptor raises TypeError."""
        nsd_msg = NsdYang.YangData_RwProject_Project_NsdCatalog()
        with self.assertRaises(TypeError):
            yield from self._loop.run_in_executor(None, self._onboarder.update, nsd_msg)

        with self.assertRaises(TypeError):
            yield from self._loop.run_in_executor(None, self._onboarder.onboard, nsd_msg)

    @rift.test.dts.async_test
    def test_bad_port(self):
        """Requests to a port with no server raise OnboardError/UpdateError."""
        # Use a port not used by the instantiated server.
        # NOTE(review): port - 1 is assumed to be free; another process on
        # the host could be listening there and make this test flaky.
        new_port = self._port - 1
        self._onboarder.port = new_port
        nsd_msg = NsdYang.YangData_RwProject_Project_NsdCatalog_Nsd(id=str(uuid.uuid4()), name="nsd_name")

        with self.assertRaises(onboard.OnboardError):
            yield from self._loop.run_in_executor(None, self._onboarder.onboard, nsd_msg)

        with self.assertRaises(onboard.UpdateError):
            yield from self._loop.run_in_executor(None, self._onboarder.update, nsd_msg)

    @rift.test.dts.async_test
    def test_timeout(self):
        """A request that gets no reply in time raises OnboardError/UpdateError."""
        # Set the timeout to something minimal to speed up test
        self._onboarder.timeout = .1

        nsd_msg = NsdYang.YangData_RwProject_Project_NsdCatalog_Nsd(id=str(uuid.uuid4()), name="nsd_name")

        # Force the request to time out by calling onboard() synchronously:
        # this blocks the event loop that also runs the test HTTP server
        # (tornado runs on the asyncio loop, see get_new_ioloop), so the
        # server cannot reply before the timeout expires.
        with self.assertRaises(onboard.OnboardError):
            self._onboarder.onboard(nsd_msg)

        # Same technique for update(): the blocked loop cannot serve the PUT.
        with self.assertRaises(onboard.UpdateError):
            self._onboarder.update(nsd_msg)
def main(argv=None):
    """Parse command-line options and run the tests in this module.

    Args:
        argv: Arguments excluding the program name. Defaults to
            ``sys.argv[1:]``, read when ``main`` is called rather than when
            the module is imported. Unrecognised arguments are forwarded to
            ``unittest.main``.

    Options:
        -v/--verbose: log at DEBUG instead of ERROR.
        -n/--no-runner: use unittest's default runner instead of the XML
            runner, which writes to ``$RIFT_MODULE_TEST``.
    """
    if argv is None:
        argv = sys.argv[1:]

    logging.basicConfig(format='TEST %(message)s')

    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('-n', '--no-runner', action='store_true')

    args, unknown = parser.parse_known_args(argv)

    # Build the XML runner only when it will be used, so --no-runner works
    # even when RIFT_MODULE_TEST is not set. testRunner=None makes
    # unittest.main fall back to its default text runner.
    runner = None
    if not args.no_runner:
        runner = xmlrunner.XMLTestRunner(output=os.environ["RIFT_MODULE_TEST"])

    # Set the global logging level
    logging.getLogger().setLevel(logging.DEBUG if args.verbose else logging.ERROR)

    # The unittest framework requires a program name, so use the name of this
    # file instead (we do not want to have to pass a fake program name to main
    # when this is called from the interpreter).
    unittest.main(argv=[__file__] + unknown + ["-v"], testRunner=runner)
293 if __name__
== '__main__':