#
#   Copyright 2016 RIFT.IO Inc
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
23 import concurrent
.futures
29 import tornado
.testing
35 from rift
.package
import convert
36 from rift
.tasklets
.rwlaunchpad
import onboard
40 gi
.require_version('NsdYang', '1.0')
41 gi
.require_version('VnfdYang', '1.0')
42 gi
.require_version('ProjectNsdYang', '1.0')
43 gi
.require_version('ProjectVnfdYang', '1.0')
45 from gi
.repository
import (
class RestconfDescriptorHandler(tornado.web.RequestHandler):
    """Tornado handler that stands in for the RESTCONF descriptor endpoint.

    For every PUT/POST it records the descriptor type, the HTTP method and
    (once validated) the decoded JSON body on the shared ``info`` object so
    that test cases can assert on what the client actually sent.
    """

    class AuthError(Exception):
        """The Authorization header is missing or carries wrong credentials."""

    class ContentTypeError(Exception):
        """The content-type header is missing or not the supported type."""

    class RequestBodyError(Exception):
        """The request body could not be parsed as JSON."""

    def initialize(self, log, auth, info):
        """Store the per-application arguments handed over by tornado.

        Args:
            log: logger used for all handler diagnostics.
            auth: expected ``(login, password)`` tuple, or None to disable
                the authorization check.
            info: object on which ``last_request_message``,
                ``last_descriptor_type`` and ``last_method`` are recorded.
        """
        self._auth = auth
        # The superclass has self._log already defined so use a different name
        self._logger = log
        self._info = info
        self._logger.debug('Created restconf descriptor handler')

    def _verify_auth(self):
        """Check the HTTP Basic credentials against the configured pair.

        Does nothing when no credentials were configured.

        Raises:
            RestconfDescriptorHandler.AuthError: the header is missing, is
                not a Basic header, or the credentials do not match.
        """
        if self._auth is None:
            return

        auth_header = self.request.headers.get('Authorization')
        if auth_header is None or not auth_header.startswith('Basic '):
            # A WWW-Authenticate challenge belongs to a 401 response.
            self.set_status(401)
            self.set_header('WWW-Authenticate', 'Basic realm=Restricted')

            msg = "Missing Authorization header"
            self._logger.error(msg)
            raise RestconfDescriptorHandler.AuthError(msg)

        auth_header = auth_header.encode('ascii')
        # Skip the 6-character "Basic " prefix before base64-decoding.
        auth_decoded = base64.decodebytes(auth_header[6:]).decode()
        # Split on the FIRST colon only: a password may itself contain ':',
        # and a value without any colon must fail the comparison below
        # instead of raising an unpacking error.
        login, _, password = auth_decoded.partition(':')

        if (login, password) != self._auth:
            self.set_status(401)
            self.set_header('WWW-Authenticate', 'Basic realm=Restricted')

            # NOTE(review): this logs the expected credentials; tolerable
            # only because this handler is a test fixture.
            msg = "Incorrect username and password in auth header: got {}, expected {}".format(
                (login, password), self._auth
            )
            self._logger.error(msg)
            raise RestconfDescriptorHandler.AuthError(msg)

    def _verify_content_type_header(self):
        """Require the ``application/vnd.yang.data+json`` content type.

        Raises:
            RestconfDescriptorHandler.ContentTypeError: the header is
                missing or holds any other value.
        """
        content_type_header = self.request.headers.get('content-type')
        if content_type_header is None:
            msg = "Missing content-type header"
        elif content_type_header != "application/vnd.yang.data+json":
            msg = "Unsupported content type: %s" % content_type_header
        else:
            return

        # 415 Unsupported Media Type.
        self.set_status(415)
        # NOTE(review): resets tornado's private output-transform list;
        # presumably so the error response goes out untransformed - confirm.
        self._transforms = []
        self._logger.error(msg)
        raise RestconfDescriptorHandler.ContentTypeError(msg)

    def _verify_headers(self):
        """Run the authorization check, then the content-type check.

        Raises:
            RestconfDescriptorHandler.AuthError: see ``_verify_auth``.
            RestconfDescriptorHandler.ContentTypeError: see
                ``_verify_content_type_header``.
        """
        self._verify_auth()
        self._verify_content_type_header()

    def _verify_request_body(self, descriptor_type):
        """Parse the request body and record it on the info object.

        Args:
            descriptor_type: descriptor kind captured from the URL; must be
                ``'nsd'`` or ``'vnfd'``.

        Raises:
            ValueError: ``descriptor_type`` is not a supported kind.
            RestconfDescriptorHandler.RequestBodyError: the body is not
                valid JSON.
        """
        if descriptor_type not in ['nsd', 'vnfd']:
            raise ValueError("Unsupported descriptor type: %s" % descriptor_type)

        body = convert.decode(self.request.body)
        self._logger.debug("Received msg: {}".format(body))

        try:
            message = json.loads(body)
        except (ValueError, convert.SerializationError) as e:
            # json.loads reports malformed input as ValueError
            # (json.JSONDecodeError); SerializationError is kept so existing
            # behaviour is a strict subset of the new one.
            self.set_status(400)
            # NOTE(review): same private-attribute reset as in the
            # content-type check - confirm it is still needed.
            self._transforms = []

            msg = "Descriptor request body not valid"
            self._logger.error(msg)
            raise RestconfDescriptorHandler.RequestBodyError(msg) from e

        self._info.last_request_message = message

        self._logger.debug("Received a valid descriptor request: {}".format(message))

    def _handle_descriptor_request(self, method, descriptor_type):
        """Shared implementation of ``put`` and ``post``.

        The descriptor type and method are recorded before any validation,
        so they are updated even for requests that end up rejected.
        """
        self._info.last_descriptor_type = descriptor_type
        self._info.last_method = method

        try:
            self._verify_headers()
            self._verify_request_body(descriptor_type)
        except (RestconfDescriptorHandler.AuthError,
                RestconfDescriptorHandler.ContentTypeError,
                RestconfDescriptorHandler.RequestBodyError):
            # The failing check already set the status and logged the
            # reason; tornado finishes the response when we return.
            return

        self.write("Response doesn't matter?")

    def put(self, descriptor_type):
        """Handle a descriptor update (HTTP PUT)."""
        self._handle_descriptor_request("PUT", descriptor_type)

    def post(self, descriptor_type):
        """Handle a descriptor onboard (HTTP POST)."""
        self._handle_descriptor_request("POST", descriptor_type)
class HandlerInfo(object):
    """Mutable record of the most recent request seen by the mock handler.

    All three fields start out as None and are overwritten by the request
    handler as requests arrive.
    """

    def __init__(self):
        # Decoded JSON body of the last request that passed validation.
        self.last_request_message = None
        # Descriptor kind captured from the last request's URL.
        self.last_descriptor_type = None
        # HTTP method of the last request ("PUT" or "POST").
        self.last_method = None
class OnboardTestCase(tornado.testing.AsyncHTTPTestCase):
    """Exercise ``onboard.DescriptorOnboarder`` against a local mock server.

    The mock server is a tornado application serving
    ``RestconfDescriptorHandler``; what it received is inspected through a
    shared ``HandlerInfo`` instance.
    """

    DESC_SERIALIZER_MAP = {
        "nsd": convert.NsdSerializer(),
        "vnfd": convert.VnfdSerializer(),
    }

    AUTH = ("admin", "admin")

    def setUp(self):
        """Create the logger, handler info, mock server and onboarder."""
        self._log = logging.getLogger(__file__)
        self._loop = asyncio.get_event_loop()

        # Must exist before super().setUp(): the base class calls get_app(),
        # which hands this object (and the logger) to the handler.
        self._handler_info = HandlerInfo()
        super().setUp()

        # Only valid once the base class has bound the test server.
        self._port = self.get_http_port()
        self._onboarder = onboard.DescriptorOnboarder(
            log=self._log, port=self._port
        )

    def get_new_ioloop(self):
        """Return the asyncio-backed IO loop for the test server."""
        return tornado.platform.asyncio.AsyncIOMainLoop()

    def get_app(self):
        """Build the tornado application that serves the mock endpoint."""
        attrs = dict(auth=OnboardTestCase.AUTH, log=self._log, info=self._handler_info)
        return tornado.web.Application([
            (r"/api/config/project/default/.*/(nsd|vnfd)",
             RestconfDescriptorHandler, attrs),
        ])

    def get_msg(self, desc=None):
        """Return a descriptor message by round-tripping ``desc`` via JSON.

        Args:
            desc: descriptor to convert; when None a fresh NSD with a random
                id and the name ``"nsd_name"`` is created.

        Returns:
            Whatever the NSD serializer's ``from_file_hdl`` produces for the
            JSON form of ``desc``.
        """
        if desc is None:
            desc = NsdYang.YangData_Nsd_NsdCatalog_Nsd(id=str(uuid.uuid4()), name="nsd_name")

        serializer = OnboardTestCase.DESC_SERIALIZER_MAP['nsd']
        jstr = serializer.to_json_string(desc, project_ns=False)

        hdl = io.BytesIO(str.encode(jstr))
        return serializer.from_file_hdl(hdl, ".json")

    def get_json(self, msg):
        """Return ``msg`` as a parsed JSON object, serialized with project_ns=True."""
        serializer = OnboardTestCase.DESC_SERIALIZER_MAP['nsd']
        json_data = serializer.to_json_string(msg, project_ns=True)
        return json.loads(json_data)

    def _run_in_executor(self, func, msg):
        """Run the blocking ``func(msg)`` on the loop's default executor.

        Keeps the IO loop free to serve the request the onboarder makes.
        """
        return self._loop.run_in_executor(None, func, msg)

    def _assert_last_request(self, msg, method):
        """Assert the mock server last saw ``msg`` as an NSD via ``method``."""
        self.assertEqual(self._handler_info.last_request_message, self.get_json(msg))
        self.assertEqual(self._handler_info.last_descriptor_type, "nsd")
        self.assertEqual(self._handler_info.last_method, method)

    @rift.test.dts.async_test
    def test_onboard_nsd(self):
        """Onboarding an NSD issues a POST carrying its JSON form."""
        nsd_msg = self.get_msg()
        yield from self._run_in_executor(self._onboarder.onboard, nsd_msg)
        self._assert_last_request(nsd_msg, "POST")

    @rift.test.dts.async_test
    def test_update_nsd(self):
        """Updating an NSD issues a PUT carrying its JSON form."""
        nsd_msg = self.get_msg()
        yield from self._run_in_executor(self._onboarder.update, nsd_msg)
        self._assert_last_request(nsd_msg, "PUT")

    @rift.test.dts.async_test
    def test_bad_descriptor_type(self):
        """A descriptor of an unexpected type is rejected with TypeError."""
        # NOTE(review): presumably rejected because it is the raw NsdYang
        # type rather than the one get_msg() returns - confirm.
        nsd_msg = NsdYang.YangData_Nsd_NsdCatalog_Nsd()
        with self.assertRaises(TypeError):
            yield from self._run_in_executor(self._onboarder.update, nsd_msg)

        with self.assertRaises(TypeError):
            yield from self._run_in_executor(self._onboarder.onboard, nsd_msg)

    @rift.test.dts.async_test
    def test_bad_port(self):
        """Requests to a port nobody listens on raise the onboarder errors."""
        # Use a port not used by the instantiated server
        # (assumes nothing else is listening on the neighbouring port).
        new_port = self._port - 1
        self._onboarder.port = new_port
        nsd_msg = self.get_msg()

        with self.assertRaises(onboard.OnboardError):
            yield from self._run_in_executor(self._onboarder.onboard, nsd_msg)

        with self.assertRaises(onboard.UpdateError):
            yield from self._run_in_executor(self._onboarder.update, nsd_msg)

    @rift.test.dts.async_test
    def test_timeout(self):
        """A request that is never answered raises the onboarder errors."""
        # Set the timeout to something minimal to speed up test
        self._onboarder.timeout = .1

        nsd_msg = self.get_msg()

        # Force the request to timeout by running the call synchronously.
        # NOTE(review): presumably this blocks the IO loop so the mock
        # server can never answer - confirm.
        with self.assertRaises(onboard.OnboardError):
            self._onboarder.onboard(nsd_msg)

        # Same approach for the update path.
        with self.assertRaises(onboard.UpdateError):
            self._onboarder.update(nsd_msg)
def main(argv=sys.argv[1:]):
    """Parse the test options and run this module's tests with unittest.

    Args:
        argv: command-line arguments without the program name. ``-v`` /
            ``--verbose`` switches logging to DEBUG, ``-n`` / ``--no-runner``
            skips the XML runner; anything else is passed on to unittest.
    """
    logging.basicConfig(format='TEST %(message)s')

    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('-n', '--no-runner', action='store_true')

    args, unknown = parser.parse_known_args(argv)

    # Build the XML runner only when it is wanted, so --no-runner works
    # without RIFT_MODULE_TEST being set. testRunner=None makes unittest
    # fall back to its default text runner.
    runner = None
    if not args.no_runner:
        runner = xmlrunner.XMLTestRunner(output=os.environ["RIFT_MODULE_TEST"])

    # Set the global logging level
    logging.getLogger().setLevel(logging.DEBUG if args.verbose else logging.ERROR)

    # The unittest framework requires a program name, so use the name of this
    # file instead (we do not want to have to pass a fake program name to main
    # when this is called from the interpreter).
    unittest.main(argv=[__file__] + unknown + ["-v"], testRunner=runner)
312 if __name__
== '__main__':