#
#   Copyright 2016 RIFT.IO Inc
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
23 import concurrent
.futures
28 import tornado
.testing
# When running unit tests outside a full RIFT environment, default
# RIFT_VAR_ROOT to the unit-test area under RIFT_INSTALL. An existing
# RIFT_VAR_ROOT is left untouched (and RIFT_INSTALL is then not consulted).
if os.environ.get("RIFT_VAR_ROOT") is None:
    os.environ["RIFT_VAR_ROOT"] = os.path.join(
        os.environ["RIFT_INSTALL"],
        "var/rift/unittest",
    )
38 from rift
.package
import convert
39 from rift
.tasklets
.rwlaunchpad
import onboard
43 gi
.require_version('NsdYang', '1.0')
44 gi
.require_version('VnfdYang', '1.0')
46 from gi
.repository
import (
class RestconfDescriptorHandler(tornado.web.RequestHandler):
    """Fake RESTCONF descriptor endpoint that the onboarder under test talks to.

    Every PUT/POST records its descriptor type and HTTP method in the shared
    ``info`` object, plus the parsed message when the body is valid, so test
    cases can assert on what the onboarder sent. Requests with bad
    credentials, a bad content type or an unparsable body are answered with
    an empty error response instead of the normal acknowledgement.
    """

    # Parser for the request body, keyed by the descriptor type captured from
    # the request URL.
    DESC_SERIALIZER_MAP = {
        "nsd": convert.NsdSerializer(),
        "vnfd": convert.VnfdSerializer(),
    }

    class AuthError(Exception):
        """The request was rejected for missing or incorrect credentials."""

    class ContentTypeError(Exception):
        """The request was rejected for a missing or unsupported content type."""

    class RequestBodyError(Exception):
        """The request was rejected because its body is not a valid descriptor."""

    def initialize(self, log, auth, info):
        """Store the per-route arguments supplied by the Application.

        Args:
            log: logger for handler diagnostics.
            auth: expected ``(login, password)`` tuple, or None to disable
                the credential check.
            info: object whose ``last_request_message``,
                ``last_descriptor_type`` and ``last_method`` attributes are
                updated by each request.
        """
        # The superclass has self._log already defined so use a different name
        self._logger = log
        self._auth = auth
        self._info = info
        self._logger.debug('Created restconf descriptor handler')

    def _finish_with_error(self, status, msg):
        """End the request now with an empty ``status`` response and log ``msg``."""
        self.set_status(status)
        # Drop tornado's output transforms before finishing this early error
        # response from inside the verb handler.
        self._transforms = []
        self.finish()
        self._logger.error(msg)

    def _reject_auth(self, msg):
        """Send a 401 response with a Basic-auth challenge and log ``msg``."""
        self.set_header('WWW-Authenticate', 'Basic realm=Restricted')
        self._finish_with_error(401, msg)

    def _verify_auth(self):
        """Check the request's HTTP Basic credentials against ``self._auth``.

        The check is skipped entirely when ``self._auth`` is None.

        Raises:
            RestconfDescriptorHandler.AuthError: the Authorization header is
                missing, not Basic, undecodable, or carries the wrong
                credentials. A 401 challenge has already been sent.
        """
        if self._auth is None:
            return None

        auth_header = self.request.headers.get('Authorization')
        if auth_header is None or not auth_header.startswith('Basic '):
            msg = "Missing Authorization header"
            self._reject_auth(msg)
            raise RestconfDescriptorHandler.AuthError(msg)

        try:
            auth_decoded = base64.decodebytes(
                auth_header[len('Basic '):].encode('ascii')
            ).decode()
        except ValueError as e:
            # binascii.Error and UnicodeError both subclass ValueError. Treat
            # an undecodable payload as bad credentials, not a server error.
            msg = "Malformed credentials in Authorization header"
            self._reject_auth(msg)
            raise RestconfDescriptorHandler.AuthError(msg) from e

        # RFC 7617: the user-id cannot contain ':' but the password may, so
        # split on the first colon only. A payload without any colon yields
        # an empty password and simply fails the comparison below.
        login, _, password = auth_decoded.partition(':')
        if (login, password) != self._auth:
            msg = "Incorrect username and password in auth header: got {}, expected {}".format(
                (login, password), self._auth
            )
            self._reject_auth(msg)
            raise RestconfDescriptorHandler.AuthError(msg)

    def _verify_content_type_header(self):
        """Require the ``application/vnd.yang.data+json`` content type.

        Raises:
            RestconfDescriptorHandler.ContentTypeError: the header is missing
                or has any other value. An error response has already been
                sent.
        """
        content_type_header = self.request.headers.get('content-type')
        if content_type_header is None:
            msg = "Missing content-type header"
        elif content_type_header != "application/vnd.yang.data+json":
            msg = "Unsupported content type: %s" % content_type_header
        else:
            return None

        # NOTE(review): 415 Unsupported Media Type assumed for both cases;
        # confirm against what the onboarder expects.
        self._finish_with_error(415, msg)
        raise RestconfDescriptorHandler.ContentTypeError(msg)

    def _verify_headers(self):
        """Run the credential check, then the content-type check."""
        self._verify_auth()
        self._verify_content_type_header()

    def _verify_request_body(self, descriptor_type):
        """Parse the JSON body as ``descriptor_type`` and record the message.

        Raises:
            ValueError: ``descriptor_type`` has no entry in
                DESC_SERIALIZER_MAP.
            RestconfDescriptorHandler.RequestBodyError: the body could not be
                deserialized. An error response has already been sent.
        """
        if descriptor_type not in RestconfDescriptorHandler.DESC_SERIALIZER_MAP:
            raise ValueError("Unsupported descriptor type: %s" % descriptor_type)

        serializer = RestconfDescriptorHandler.DESC_SERIALIZER_MAP[descriptor_type]
        bytes_hdl = io.BytesIO(self.request.body)

        try:
            message = serializer.from_file_hdl(bytes_hdl, ".json")
        except convert.SerializationError as e:
            msg = "Descriptor request body not valid"
            # NOTE(review): 400 Bad Request assumed; confirm.
            self._finish_with_error(400, msg)
            raise RestconfDescriptorHandler.RequestBodyError(msg) from e

        self._info.last_request_message = message

        self._logger.debug("Received a valid descriptor request")

    def _handle_descriptor_request(self, descriptor_type, method):
        """Record ``method``/``descriptor_type``, validate, and acknowledge.

        When validation fails, the error response has already been sent by
        the failing check, so nothing more is written.
        """
        self._info.last_descriptor_type = descriptor_type
        self._info.last_method = method

        try:
            self._verify_headers()
            self._verify_request_body(descriptor_type)
        except (RestconfDescriptorHandler.AuthError,
                RestconfDescriptorHandler.ContentTypeError,
                RestconfDescriptorHandler.RequestBodyError):
            return None

        self.write("Response doesn't matter?")

    def put(self, descriptor_type):
        """Handle a descriptor update (HTTP PUT)."""
        self._handle_descriptor_request(descriptor_type, "PUT")

    def post(self, descriptor_type):
        """Handle a descriptor onboard (HTTP POST)."""
        self._handle_descriptor_request(descriptor_type, "POST")
class HandlerInfo:
    """Mutable record of the latest request seen by RestconfDescriptorHandler.

    The handler fills these attributes in and the test cases read them back.
    """

    def __init__(self):
        # Nothing has been received yet.
        self.last_method = None
        self.last_descriptor_type = None
        self.last_request_message = None
class OnboardTestCase(tornado.testing.AsyncHTTPTestCase):
    """Exercise onboard.DescriptorOnboarder against a local fake RESTCONF server."""

    # Credentials the fake server requires.
    # NOTE(review): assumed to match what DescriptorOnboarder sends by
    # default, since the onboarder is built without explicit credentials.
    AUTH = ("admin", "admin")

    def setUp(self):
        self._log = logging.getLogger(__file__)
        self._loop = asyncio.get_event_loop()

        # get_app() reads this, and AsyncHTTPTestCase.setUp() calls
        # get_app(), so it must exist before the superclass setUp runs.
        self._handler_info = HandlerInfo()
        super().setUp()

        self._port = self.get_http_port()
        self._onboarder = onboard.DescriptorOnboarder(
            log=self._log, port=self._port
        )

    def get_new_ioloop(self):
        # Load the submodule explicitly; importing tornado.testing is not
        # guaranteed to make tornado.platform.asyncio available.
        import tornado.platform.asyncio
        return tornado.platform.asyncio.AsyncIOMainLoop()

    def get_app(self):
        attrs = dict(auth=OnboardTestCase.AUTH, log=self._log, info=self._handler_info)
        return tornado.web.Application([
            (r"/api/config/.*/(nsd|vnfd)", RestconfDescriptorHandler, attrs),
        ])

    def _new_nsd_msg(self):
        """Return a minimal NSD message with a fresh random id."""
        return NsdYang.YangData_Nsd_NsdCatalog_Nsd(id=str(uuid.uuid4()), name="nsd_name")

    def _assert_last_request(self, msg, descriptor_type, method):
        """Assert the fake server last received ``msg`` via ``method``."""
        self.assertEqual(self._handler_info.last_request_message, msg)
        self.assertEqual(self._handler_info.last_descriptor_type, descriptor_type)
        self.assertEqual(self._handler_info.last_method, method)

    @rift.test.dts.async_test
    def test_onboard_nsd(self):
        nsd_msg = self._new_nsd_msg()
        yield from self._loop.run_in_executor(None, self._onboarder.onboard, nsd_msg)
        self._assert_last_request(nsd_msg, "nsd", "POST")

    @rift.test.dts.async_test
    def test_update_nsd(self):
        nsd_msg = self._new_nsd_msg()
        yield from self._loop.run_in_executor(None, self._onboarder.update, nsd_msg)
        self._assert_last_request(nsd_msg, "nsd", "PUT")

    @rift.test.dts.async_test
    def test_bad_descriptor_type(self):
        # A catalog is not a descriptor; the onboarder must refuse it.
        nsd_msg = NsdYang.YangData_Nsd_NsdCatalog()
        with self.assertRaises(TypeError):
            yield from self._loop.run_in_executor(None, self._onboarder.update, nsd_msg)

        with self.assertRaises(TypeError):
            yield from self._loop.run_in_executor(None, self._onboarder.onboard, nsd_msg)

    @rift.test.dts.async_test
    def test_bad_port(self):
        # Point the onboarder at a port nothing listens on. Binding an
        # OS-assigned port and immediately closing it avoids guessing a
        # neighbour of the server's port, which an unrelated service on the
        # host might own.
        sock, new_port = tornado.testing.bind_unused_port()
        sock.close()
        self._onboarder.port = new_port
        nsd_msg = self._new_nsd_msg()

        with self.assertRaises(onboard.OnboardError):
            yield from self._loop.run_in_executor(None, self._onboarder.onboard, nsd_msg)

        with self.assertRaises(onboard.UpdateError):
            yield from self._loop.run_in_executor(None, self._onboarder.update, nsd_msg)

    @rift.test.dts.async_test
    def test_timeout(self):
        # Set the timeout to something minimal to speed up test
        self._onboarder.timeout = .1

        nsd_msg = self._new_nsd_msg()

        # Force the request to time out. Calling onboard()/update() directly
        # (instead of through run_in_executor as the other tests do) blocks
        # the event loop that presumably also serves the test app, so the
        # server cannot answer before the timeout expires.
        with self.assertRaises(onboard.OnboardError):
            self._onboarder.onboard(nsd_msg)

        with self.assertRaises(onboard.UpdateError):
            self._onboarder.update(nsd_msg)
def main(argv=None):
    """Run this module's unit tests.

    Args:
        argv: command-line arguments without the program name. Defaults to
            ``sys.argv[1:]`` read at call time. ``-v/--verbose`` enables
            DEBUG logging. ``-n/--no-runner`` uses unittest's default runner
            instead of the XML runner. Unrecognized arguments are passed
            through to ``unittest.main``.
    """
    if argv is None:
        argv = sys.argv[1:]

    logging.basicConfig(format='TEST %(message)s')

    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('-n', '--no-runner', action='store_true')

    args, unknown = parser.parse_known_args(argv)

    # Build the XML runner only when it will actually be used, so that
    # --no-runner does not require RIFT_MODULE_TEST to be set.
    runner = None
    if not args.no_runner:
        runner = xmlrunner.XMLTestRunner(output=os.environ["RIFT_MODULE_TEST"])

    # Set the global logging level
    logging.getLogger().setLevel(logging.DEBUG if args.verbose else logging.ERROR)

    # The unittest framework requires a program name, so use the name of this
    # file instead (we do not want to have to pass a fake program name to main
    # when this is called from the interpreter).
    unittest.main(argv=[__file__] + unknown + ["-v"], testRunner=runner)
297 if __name__
== '__main__':