train_server.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. # coding:utf-8
  2. '''
  3. 训练客户端
  4. '''
  5. import nsq
  6. import json
  7. from machine_models import train_model
  8. from loguru import logger
  9. from queue import Queue
  10. import time
  11. from threading import Thread
  12. logger.add('./logs/runtime_{time}.log', rotation='00:00')
  13. queueSave = Queue(maxsize=10000) # 任务队列
  14. def train_start():
  15. # 检查任务列表,开始训练
  16. global queueSave
  17. while True:
  18. if not queueSave.empty():
  19. params = queueSave.get()
  20. train_model(params)
  21. continue
  22. time.sleep(5)
  23. def handler(message):
  24. '''
  25. nsq队列回调函数
  26. :param message:
  27. :return:
  28. '''
  29. global queueSave
  30. try:
  31. body = message.body
  32. body = json.loads(body)
  33. queueSave.put(body)
  34. except Exception as e:
  35. logger.warning("start-->", e)
  36. return True
  37. r = nsq.Reader(message_handler=handler, nsqd_tcp_addresses=['192.168.3.13:4150'], topic='machine_train',
  38. channel='NO.1',
  39. lookupd_poll_interval=5,
  40. lookupd_connect_timeout=10000,
  41. lookupd_request_timeout=10000)
  42. if __name__ == '__main__':
  43. train_thread = Thread(target=train_start)
  44. train_thread.start()
  45. nsq.run()
  46. train_thread.join()