Reformat POL to standardized format
[osm/POL.git] / osm_policy_module / tests / integration / test_kafka_messages.py
index d17d2b2..725cc3f 100644 (file)
 # For those usages not covered by the Apache License, Version 2.0 please
 # contact: bdiaz@whitestack.com or glavado@whitestack.com
 ##
+import asyncio
 import json
 import logging
 import os
 import sys
 import unittest
 
-from kafka import KafkaProducer, KafkaConsumer
+from aiokafka import AIOKafkaProducer, AIOKafkaConsumer
 from kafka.errors import KafkaError
 
-from osm_policy_module.core.agent import PolicyModuleAgent
 from osm_policy_module.core.config import Config
 
 log = logging.getLogger()
@@ -41,39 +41,60 @@ log.addHandler(stream_handler)
 
 class KafkaMessagesTest(unittest.TestCase):
     def setUp(self):
-        try:
-            cfg = Config.instance()
-            kafka_server = '{}:{}'.format(cfg.OSMPOL_MESSAGE_HOST,
-                                          cfg.OSMPOL_MESSAGE_PORT)
-            self.producer = KafkaProducer(bootstrap_servers=kafka_server,
-                                          key_serializer=str.encode,
-                                          value_serializer=str.encode)
-            self.consumer = KafkaConsumer(bootstrap_servers=kafka_server,
-                                          key_deserializer=bytes.decode,
-                                          value_deserializer=bytes.decode,
-                                          auto_offset_reset='earliest',
-                                          consumer_timeout_ms=5000)
-            self.consumer.subscribe(['ns'])
-        except KafkaError:
-            self.skipTest('Kafka server not present.')
+        super()
+        cfg = Config()
+        self.kafka_server = "{}:{}".format(
+            cfg.get("message", "host"), cfg.get("message", "port")
+        )
+        self.loop = asyncio.new_event_loop()
 
     def tearDown(self):
-        self.producer.close()
-        self.consumer.close()
+        super()
 
     def test_send_instantiated_msg(self):
-        with open(
-                os.path.join(os.path.dirname(__file__), '../examples/instantiated.json')) as file:
-            payload = json.load(file)
-            self.producer.send('ns', json.dumps(payload), key="instantiated")
-            self.producer.flush()
+        async def test_send_instantiated_msg():
+            producer = AIOKafkaProducer(
+                loop=self.loop,
+                bootstrap_servers=self.kafka_server,
+                key_serializer=str.encode,
+                value_serializer=str.encode,
+            )
+            await producer.start()
+            consumer = AIOKafkaConsumer(
+                "ns",
+                loop=self.loop,
+                bootstrap_servers=self.kafka_server,
+                consumer_timeout_ms=10000,
+                auto_offset_reset="earliest",
+                value_deserializer=bytes.decode,
+                key_deserializer=bytes.decode,
+            )
+            await consumer.start()
+            try:
+                with open(
+                    os.path.join(
+                        os.path.dirname(__file__), "../examples/instantiated.json"
+                    )
+                ) as file:
+                    payload = json.load(file)
+                    await producer.send_and_wait(
+                        "ns", key="instantiated", value=json.dumps(payload)
+                    )
+            finally:
+                await producer.stop()
+            try:
+                async for message in consumer:
+                    if message.key == "instantiated":
+                        self.assertIsNotNone(message.value)
+                        return
+            finally:
+                await consumer.stop()
 
-        for message in self.consumer:
-            if message.key == 'instantiated':
-                self.assertIsNotNone(message.value)
-                return
-        self.fail("No message received in consumer")
+        try:
+            self.loop.run_until_complete(test_send_instantiated_msg())
+        except KafkaError:
+            self.skipTest("Kafka server not present.")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()