4 # Copyright 2016 RIFT.IO Inc
6 # Licensed under the Apache License, Version 2.0 (the "License");
7 # you may not use this file except in compliance with the License.
8 # You may obtain a copy of the License at
10 # http://www.apache.org/licenses/LICENSE-2.0
12 # Unless required by applicable law or agreed to in writing, software
13 # distributed under the License is distributed on an "AS IS" BASIS,
14 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 # See the License for the specific language governing permissions and
16 # limitations under the License.
25 import tornado
.testing
32 from requests_toolbelt
import MultipartEncoder
41 from rift
.rwlib
.util
import certs
43 from rift
.package
.handler
import FileRestApiHandler
44 from rift
.tasklets
.rwstagingmgr
.server
.app
import StagingApplication
, CleanUpStaging
45 from rift
.tasklets
.rwstagingmgr
.model
import StagingArea
48 gi
.require_version('RwStagingMgmtYang', '1.0')
49 from gi
.repository
import (
class TestCase(tornado.testing.AsyncHTTPTestCase):
    """HTTP-level test of ``StagingApplication`` backed by a mocked store.

    A temporary staging area is created on disk for every test, the store
    that normally manages staging areas is replaced by a ``MagicMock``, and
    the application is served by tornado's test HTTP server.
    """

    def setUp(self):
        """Prepare logger and event loop, then start the test HTTP server."""
        self._log = logging.getLogger(__file__)
        # Must be assigned before the base setUp runs: the base class builds
        # the application through get_app(), which reads self._loop.
        self._loop = asyncio.get_event_loop()

        super().setUp()
        self._port = self.get_http_port()

    def get_new_ioloop(self):
        """Return the asyncio-backed tornado IOLoop used by the test server."""
        return tornado.platform.asyncio.AsyncIOMainLoop()

    @staticmethod
    def _remove_if_present(path):
        """Delete *path*, ignoring the case where it is already gone."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    def create_mock_store(self):
        """Create an on-disk staging area and a mocked store that serves it.

        Sets ``self.staging_dir_tmp`` (temporary root), ``self.staging_id``
        (random UUID string) and ``self.staging_dir`` (root/id), and writes
        the staging-area model to ``meta.yaml`` inside the staging directory.

        Returns:
            A ``(store, mock_model)`` tuple: the ``MagicMock`` store and the
            ``StagingArea`` wrapping the yang model.
        """
        self.staging_dir_tmp = tempfile.mkdtemp()
        # Guarantee removal of the temporary tree even when the test fails.
        self.addCleanup(shutil.rmtree, self.staging_dir_tmp, ignore_errors=True)

        self.staging_id = str(uuid.uuid4())
        self.staging_dir = os.path.join(self.staging_dir_tmp, self.staging_id)
        os.makedirs(self.staging_dir)

        mock_model = RwStagingMgmtYang.YangData_RwProject_Project_StagingAreas_StagingArea.from_dict({
            'path': self.staging_dir,
            # Five seconds from now.
            "validity_time": int(time.time()) + 5,
        })

        with open(os.path.join(self.staging_dir, "meta.yaml"), "w") as fh:
            yaml.dump(mock_model.as_dict(), fh, default_flow_style=True)

        mock_model = StagingArea(mock_model)

        store = mock.MagicMock()
        store.get_staging_area.return_value = mock_model
        store.root_dir = self.staging_dir_tmp
        store.tmp_dir = self.staging_dir_tmp
        store.META_YAML = "meta.yaml"
        store.remove_staging_area = mock.Mock(return_value=None)

        return store, mock_model

    def create_tmp_file(self):
        """Create a small text file to upload and return its path.

        The path is also stored on ``self.temp_file``. The file is removed
        automatically when the test finishes.
        """
        fd, self.temp_file = tempfile.mkstemp()
        self.addCleanup(self._remove_if_present, self.temp_file)
        # Reuse the descriptor returned by mkstemp instead of leaking it.
        with os.fdopen(fd, "w") as fh:
            fh.write("Lorem Ipsum")

        return self.temp_file

    def get_app(self):
        """Build the application under test on top of the mocked store."""
        self.store, self.mock_model = self.create_mock_store()
        return StagingApplication(self.store, self._loop, cleanup_interval=5)

    def test_file_upload_and_download(self):
        """
        Asserts:
            1. The file upload succeeds
            2. the response of the file upload
            3. Finally downloads the file and verifies if the uploaded and
               downloaded files are the same.
            4. Verify if the directory is cleaned up after expiry
        """
        temp_file = self.create_tmp_file()
        file_name = os.path.basename(temp_file)

        # Serialise the form while the file is open, then close the handle.
        with open(temp_file, 'rb') as upload_fh:
            form = MultipartEncoder(fields={
                'file': (file_name, upload_fh, 'application/octet-stream')})
            body = form.to_string()

        # Upload
        response = self.fetch("/api/upload/{}".format(self.staging_id),
                              method="POST",
                              body=body,
                              headers={"Content-Type": "multipart/form-data"})

        self.assertEqual(response.code, 200)
        # The uploaded file must land inside the staging area.
        self.assertTrue(os.path.isfile(os.path.join(self.staging_dir, file_name)))

        payload = response.body.decode("utf-8")
        self.assertIn(self.staging_id, payload)
        payload = json.loads(payload)

        # Download
        response = self.fetch(payload['path'])

        fd, downloaded_file = tempfile.mkstemp()
        self.addCleanup(self._remove_if_present, downloaded_file)
        with os.fdopen(fd, 'wb') as fh:
            fh.write(response.body)

        # shallow=False forces a byte-for-byte comparison; the default may
        # accept the files on matching stat signatures alone.
        self.assertTrue(filecmp.cmp(temp_file, downloaded_file, shallow=False))

        self._log.debug("%s", self.get_url('/'))
        self._log.debug("%s", self.staging_dir)

        # NOTE(review): the mock is invoked directly here, so the assertion
        # below is satisfied by this call rather than by the application's
        # own expiry handling -- confirm this is the intended check.
        self.store.remove_staging_area(self.mock_model)
        self.store.remove_staging_area.assert_called_once_with(self.mock_model)
def main(argv=sys.argv[1:]):
    """Run this module's tests, by default through the XML test runner.

    Args:
        argv: Command-line arguments without the program name. ``-v`` /
            ``--verbose`` raises the root log level to DEBUG (ERROR
            otherwise); ``-n`` / ``--no-runner`` skips the XML runner so
            unittest's default runner is used. Unrecognised arguments are
            forwarded to ``unittest.main``.

    Raises:
        KeyError: if the XML runner is requested and the
            ``RIFT_MODULE_TEST`` environment variable is not set.
    """
    logging.basicConfig(format='TEST %(message)s')

    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose', action='store_true')
    parser.add_argument('-n', '--no-runner', action='store_true')

    args, unknown = parser.parse_known_args(argv)

    # Only build the XML runner when it is wanted, so --no-runner works
    # even when RIFT_MODULE_TEST is not defined in the environment.
    runner = None
    if not args.no_runner:
        runner = xmlrunner.XMLTestRunner(output=os.environ["RIFT_MODULE_TEST"])

    # Set the global logging level
    logging.getLogger().setLevel(logging.DEBUG if args.verbose else logging.ERROR)

    # The unittest framework requires a program name, so use the name of this
    # file instead (we do not want to have to pass a fake program name to main
    # when this is called from the interpreter).
    unittest.main(argv=[__file__] + unknown + ["-v"], testRunner=runner)
170 if __name__
== '__main__':