#
#   Copyright 2016 RIFT.IO Inc
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
23 import concurrent
.futures
29 import tornado
.testing
# For unit-test runs, default RIFT_VAR_ROOT to <RIFT_INSTALL>/var/rift/unittest
# when the environment does not already provide it.
if os.environ.get('RIFT_VAR_ROOT') is None:
    os.environ['RIFT_VAR_ROOT'] = os.path.join(
        os.environ['RIFT_INSTALL'], 'var/rift/unittest'
    )
39 from rift
.package
import convert
40 from rift
.tasklets
.rwlaunchpad
import onboard
45 gi
.require_version('NsdYang', '1.0')
46 gi
.require_version('VnfdYang', '1.0')
47 gi
.require_version('ProjectNsdYang', '1.0')
48 gi
.require_version('ProjectVnfdYang', '1.0')
50 from gi
.repository
import (
class RestconfDescriptorHandler(tornado.web.RequestHandler):
    """Fake RESTCONF endpoint that receives descriptors from the onboarder.

    Every PUT/POST records its method and descriptor type in the
    ``HandlerInfo`` given to ``initialize``. The request is then checked for
    HTTP Basic auth, the ``application/vnd.yang.data+json`` content type and a
    JSON body. A valid body is stored in ``info.last_request_message``.
    """

    class AuthError(Exception):
        """The Authorization header is missing, malformed or wrong."""

    class ContentTypeError(Exception):
        """The content-type header is missing or unsupported."""

    class RequestBodyError(Exception):
        """The request body could not be decoded as JSON."""

    def initialize(self, log, auth, info):
        """Store the settings tornado passes from the URL spec.

        Args:
            log: logger used for handler diagnostics.
            auth: expected ``(login, password)`` tuple, or None to skip
                authentication.
            info: ``HandlerInfo`` that records the last request received.
        """
        # The superclass has self._log already defined so use a different name
        self._logger = log
        self._auth = auth
        self._info = info
        self._logger.debug('Created restconf descriptor handler')

    def _error(self, status, msg, error_cls):
        """Prepare an error response and return the exception to raise.

        Sets the HTTP status, clears the handler's output transforms, logs
        ``msg`` and returns ``error_cls(msg)``. Returning the exception lets
        callers chain it with ``raise ... from ...``.
        """
        self.set_status(status)
        self._transforms = []
        self._logger.error(msg)
        return error_cls(msg)

    def _verify_auth(self):
        """Check the request's HTTP Basic credentials against ``self._auth``.

        Authentication is skipped when the handler was configured with
        ``auth=None``.

        Raises:
            RestconfDescriptorHandler.AuthError: the Authorization header is
                missing, cannot be decoded, or carries the wrong credentials.
                The response is marked 401 with a Basic challenge.
        """
        if self._auth is None:
            return None

        auth_header = self.request.headers.get('Authorization')
        if auth_header is None or not auth_header.startswith('Basic '):
            self.set_header('WWW-Authenticate', 'Basic realm=Restricted')
            raise self._error(401, "Missing Authorization header",
                              RestconfDescriptorHandler.AuthError)

        try:
            auth_header = auth_header.encode('ascii')
            # Drop the 6-character "Basic " prefix before decoding.
            auth_decoded = base64.decodebytes(auth_header[6:]).decode()
            # Split on the first ':' only: the user-id cannot contain a colon,
            # but the password may (RFC 7617).
            login, password = auth_decoded.split(':', 1)
        except ValueError as e:
            # Non-ASCII header, bad base64 (binascii.Error), non-UTF-8 payload
            # or no ':' separator -- all are ValueError subclasses.
            self.set_header('WWW-Authenticate', 'Basic realm=Restricted')
            raise self._error(401, "Malformed Authorization header",
                              RestconfDescriptorHandler.AuthError) from e

        is_auth = ((login, password) == self._auth)
        if not is_auth:
            self.set_header('WWW-Authenticate', 'Basic realm=Restricted')
            msg = "Incorrect username and password in auth header: got {}, expected {}".format(
                (login, password), self._auth)
            raise self._error(401, msg, RestconfDescriptorHandler.AuthError)

    def _verify_content_type_header(self):
        """Require exactly ``application/vnd.yang.data+json``.

        Raises:
            RestconfDescriptorHandler.ContentTypeError: header missing or
                different (response marked 415).
        """
        content_type_header = self.request.headers.get('content-type')
        if content_type_header is None:
            raise self._error(415, "Missing content-type header",
                              RestconfDescriptorHandler.ContentTypeError)

        if content_type_header != "application/vnd.yang.data+json":
            msg = "Unsupported content type: %s" % content_type_header
            raise self._error(415, msg,
                              RestconfDescriptorHandler.ContentTypeError)

    def _verify_headers(self):
        """Check authentication first, then the content type.

        Raises:
            RestconfDescriptorHandler.AuthError
            RestconfDescriptorHandler.ContentTypeError
        """
        self._verify_auth()
        self._verify_content_type_header()

    def _verify_request_body(self, descriptor_type):
        """Decode the JSON body and store it in ``info.last_request_message``.

        Raises:
            ValueError: descriptor_type is neither 'nsd' nor 'vnfd'.
            RestconfDescriptorHandler.RequestBodyError: the body cannot be
                decoded (response marked 400).
        """
        if descriptor_type not in ['nsd', 'vnfd']:
            raise ValueError("Unsupported descriptor type: %s" % descriptor_type)

        try:
            body = convert.decode(self.request.body)
            self._logger.debug("Received msg: {}".format(body))
            # json.loads reports bad input with JSONDecodeError (a ValueError
            # subclass), not convert.SerializationError.
            message = json.loads(body)
        except (convert.SerializationError, ValueError) as e:
            raise self._error(400, "Descriptor request body not valid",
                              RestconfDescriptorHandler.RequestBodyError) from e

        self._info.last_request_message = message

        self._logger.debug("Received a valid descriptor request: {}".format(message))

    def _handle_descriptor_request(self, descriptor_type, method):
        """Shared PUT/POST flow: record, validate, then reply.

        Validation failures have already set the error status and been
        logged, so they simply end the request.
        """
        self._info.last_descriptor_type = descriptor_type
        self._info.last_method = method

        try:
            self._verify_headers()
        except (RestconfDescriptorHandler.AuthError,
                RestconfDescriptorHandler.ContentTypeError):
            return None

        try:
            self._verify_request_body(descriptor_type)
        except RestconfDescriptorHandler.RequestBodyError:
            return None

        self.write("Response doesn't matter?")

    def put(self, descriptor_type):
        """Handle a descriptor update."""
        return self._handle_descriptor_request(descriptor_type, "PUT")

    def post(self, descriptor_type):
        """Handle a descriptor onboard."""
        return self._handle_descriptor_request(descriptor_type, "POST")
class HandlerInfo(object):
    """Record of the most recent request seen by RestconfDescriptorHandler.

    One instance is shared between the test case and the handler, so tests
    can inspect what the fake server received.
    """

    def __init__(self):
        # Decoded JSON body of the last request that passed validation.
        self.last_request_message = None
        # Descriptor type captured from the request URL ('nsd' or 'vnfd').
        self.last_descriptor_type = None
        # HTTP method of the last request ("PUT" or "POST").
        self.last_method = None
class OnboardTestCase(tornado.testing.AsyncHTTPTestCase):
    """Drives onboard.DescriptorOnboarder against RestconfDescriptorHandler."""

    # Serializers keyed by descriptor type.
    DESC_SERIALIZER_MAP = {
        "nsd": convert.NsdSerializer(),
        "vnfd": convert.VnfdSerializer(),
    }

    # Credentials the fake server expects and the onboarder sends.
    AUTH = ("admin", "admin")

    def setUp(self):
        self._log = logging.getLogger(__file__)
        self._loop = asyncio.get_event_loop()

        # Must exist before super().setUp(), which calls get_app().
        self._handler_info = HandlerInfo()
        # AsyncHTTPTestCase.setUp() builds the app and starts the HTTP server;
        # get_http_port() is only valid afterwards.
        super().setUp()

        self._port = self.get_http_port()
        self._onboarder = onboard.DescriptorOnboarder(
            log=self._log, port=self._port
        )

    def get_new_ioloop(self):
        # Serve HTTP from tornado's wrapper around the current asyncio event
        # loop (the loop setUp stores in self._loop).
        return tornado.platform.asyncio.AsyncIOMainLoop()

    def get_app(self):
        """Route project descriptor URLs to the fake RESTCONF handler."""
        attrs = dict(auth=OnboardTestCase.AUTH, log=self._log, info=self._handler_info)
        return tornado.web.Application([
            (r"/api/config/project/default/.*/(nsd|vnfd)",
             RestconfDescriptorHandler, attrs),
        ])

    def get_msg(self, desc=None):
        """Round-trip ``desc`` through JSON using the NSD serializer.

        When ``desc`` is None, a new NSD with a random id and the name
        "nsd_name" is used.
        """
        if desc is None:
            desc = NsdYang.YangData_Nsd_NsdCatalog_Nsd(id=str(uuid.uuid4()), name="nsd_name")
        serializer = OnboardTestCase.DESC_SERIALIZER_MAP['nsd']
        jstr = serializer.to_json_string(desc, project_ns=False)

        hdl = io.BytesIO(str.encode(jstr))
        return serializer.from_file_hdl(hdl, ".json")

    def get_json(self, msg):
        """Return ``msg`` as the parsed JSON the handler should have received."""
        serializer = OnboardTestCase.DESC_SERIALIZER_MAP['nsd']
        json_data = serializer.to_json_string(msg, project_ns=True)
        return json.loads(json_data)

    @rift.test.dts.async_test
    def test_onboard_nsd(self):
        nsd_msg = self.get_msg()
        yield from self._loop.run_in_executor(None, functools.partial(self._onboarder.onboard, descriptor_msg=nsd_msg, auth=OnboardTestCase.AUTH))
        self.assertEqual(self._handler_info.last_request_message, self.get_json(nsd_msg))
        self.assertEqual(self._handler_info.last_descriptor_type, "nsd")
        self.assertEqual(self._handler_info.last_method, "POST")

    @rift.test.dts.async_test
    def test_update_nsd(self):
        nsd_msg = self.get_msg()
        yield from self._loop.run_in_executor(None, functools.partial(self._onboarder.update, descriptor_msg=nsd_msg, auth=OnboardTestCase.AUTH))
        self.assertEqual(self._handler_info.last_request_message, self.get_json(nsd_msg))
        self.assertEqual(self._handler_info.last_descriptor_type, "nsd")
        self.assertEqual(self._handler_info.last_method, "PUT")

    @rift.test.dts.async_test
    def test_bad_descriptor_type(self):
        # A bare (non-project) NSD message; the onboarder is expected to
        # reject it with TypeError.
        nsd_msg = NsdYang.YangData_Nsd_NsdCatalog_Nsd()
        with self.assertRaises(TypeError):
            yield from self._loop.run_in_executor(None, self._onboarder.update, nsd_msg)

        with self.assertRaises(TypeError):
            yield from self._loop.run_in_executor(None, self._onboarder.onboard, nsd_msg)

    @rift.test.dts.async_test
    def test_bad_port(self):
        # Point the onboarder at a port nothing listens on. Reserve an unused
        # ephemeral port and release it immediately instead of guessing
        # (self._port - 1 may belong to another listener). A small race
        # remains between close() and the requests below.
        sock, new_port = tornado.testing.bind_unused_port()
        sock.close()
        self._onboarder.port = new_port
        nsd_msg = self.get_msg()

        with self.assertRaises(onboard.OnboardError):
            yield from self._loop.run_in_executor(None, self._onboarder.onboard, nsd_msg)

        with self.assertRaises(onboard.UpdateError):
            yield from self._loop.run_in_executor(None, self._onboarder.update, nsd_msg)

    @rift.test.dts.async_test
    def test_timeout(self):
        # Set the timeout to something minimal to speed up test
        self._onboarder.timeout = .1

        nsd_msg = self.get_msg()

        # Force the request to time out by calling synchronously: this blocks
        # the event loop that also serves the handler, so it cannot answer.
        with self.assertRaises(onboard.OnboardError):
            self._onboarder.onboard(nsd_msg)

        # Same for update: the blocked loop cannot answer the PUT.
        with self.assertRaises(onboard.UpdateError):
            self._onboarder.update(nsd_msg)
def main(argv=sys.argv[1:]):
    """Run this module's unit tests.

    Options:
        -v/--verbose: set the root logger to DEBUG (otherwise ERROR).
        -n/--no-runner: use unittest's default runner instead of the XML
            runner whose output location comes from $RIFT_MODULE_TEST.
    Unrecognised arguments are forwarded to unittest.main().
    """
    logging.basicConfig(format='TEST %(message)s')

    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('-n', '--no-runner', action='store_true')

    args, unknown = parser.parse_known_args(argv)

    # Only build the XML runner (and so only require RIFT_MODULE_TEST) when
    # --no-runner was not given; testRunner=None selects unittest's default.
    runner = None
    if not args.no_runner:
        runner = xmlrunner.XMLTestRunner(output=os.environ["RIFT_MODULE_TEST"])

    # Set the global logging level
    logging.getLogger().setLevel(logging.DEBUG if args.verbose else logging.ERROR)

    # The unittest framework requires a program name, so use the name of this
    # file instead (we do not want to have to pass a fake program name to main
    # when this is called from the interpreter).
    unittest.main(argv=[__file__] + unknown + ["-v"], testRunner=runner)
317 if __name__
== '__main__':