ソースを参照

refine remote access module

wangweimin 3 年 前
コミット
8dd9877d81
1 ファイル変更27 行追加34 行削除
  1. 27 34
      pywebio/platform/remote_access.py

+ 27 - 34
pywebio/platform/remote_access.py

@@ -1,28 +1,16 @@
 """
 * Implementation of remote access
-Use localhost.run ssh remote port forwarding service by running a ssh subprocess in PyWebIO application.
+Use https://github.com/wang0618/localshare service by running a ssh subprocess in PyWebIO application.
 
-The stdout of ssh process is the connection info. 
-
-* Strategy
-Wait at most one minute to get stdout, if it gets a normal out, the connection is successfully established. 
-Otherwise report error.
-
-* One Issue
-When the PyWebIO application process exits, the ssh process becomes an orphan process and does not exit.
-
-* Solution.
-Use a child process to create the ssh process, the child process monitors the PyWebIO application process
-to see if it alive, and when the PyWebIO application exit, the child process kills the ssh process and exit.
+The stdout of ssh process is the connection info.
 """
 
 import json
 import logging
-import multiprocessing
 import os
-import shlex
 import threading
 import time
+import shlex
 from subprocess import Popen, PIPE
 
 logger = logging.getLogger(__name__)
@@ -38,14 +26,24 @@ Remote access address: {address}
 _ssh_process = None  # type: Popen
 
 
-def remote_access_service(local_port=8080, server='app.pywebio.online', server_port=1022, setup_timeout=60,
-                          need_exit=None):
+def am_i_the_only_thread():
+    """Whether the current thread is the only non-Daemon threads in the process"""
+    alive_none_daemonic_thread_cnt = sum(
+        1 for t in threading.enumerate()
+        if t.is_alive() and not t.isDaemon()
+    )
+    return alive_none_daemonic_thread_cnt == 1
+
+
+def remote_access_service(local_port=8080, server='app.pywebio.online', server_port=1022, setup_timeout=60):
     """
+    Wait at most one minute to get the ssh output, if it gets a normal out, the connection is successfully established.
+    Otherwise report error and kill ssh process.
+
     :param local_port: ssh local listen port
     :param server: ssh server domain
     :param server_port: ssh server port
     :param setup_timeout: If the service can't setup successfully in `setup_timeout` seconds, then exit.
-    :param callable need_exit: The service will call this function periodicity, when it return True, then exit the service.
     """
 
     global _ssh_process
@@ -83,31 +81,24 @@ def remote_access_service(local_port=8080, server='app.pywebio.online', server_p
         else:
             print(success_msg.format(address=connection_info['address']))
 
-    # wait ssh or parent process exit
-    while not need_exit() and _ssh_process.poll() is None:
+    # wait ssh or main thread exit
+    while not am_i_the_only_thread() and _ssh_process.poll() is None:
         time.sleep(1)
 
-    if _ssh_process.poll() is None:  # parent process exit, kill ssh process
+    if _ssh_process.poll() is None:  # main thread exit, kill ssh process
         logger.debug('App process exit, killing ssh process')
         _ssh_process.kill()
     else:  # ssh process exit by itself or by timeout killer
         stderr = _ssh_process.stderr.read().decode('utf8')
-        logger.debug("Stderr from ssh process: %s", stderr)
         if stderr:
-            print(stderr)
+            logger.error('PyWebIO application remote access service error: %s', stderr)
         else:
-            print('PyWebIO application remote access service exit.')
+            logger.info('PyWebIO application remote access service exit.')
 
 
 def start_remote_access_service_(**kwargs):
-    ppid = os.getppid()
-
-    def need_exit():
-        # only for unix
-        return os.getppid() != ppid
-
     try:
-        remote_access_service(**kwargs, need_exit=need_exit)
+        remote_access_service(**kwargs)
     except KeyboardInterrupt:  # ignore KeyboardInterrupt
         pass
     finally:
@@ -125,7 +116,9 @@ def start_remote_access_service(**kwargs):
         server, server_port = server.split(':', 1)
     kwargs.setdefault('server', server)
     kwargs.setdefault('server_port', server_port)
-    multiprocessing.Process(target=start_remote_access_service_, kwargs=kwargs).start()
+    thread = threading.Thread(target=start_remote_access_service_, kwargs=kwargs)
+    thread.start()
+    return thread
 
 
 if __name__ == '__main__':
@@ -140,5 +133,5 @@ if __name__ == '__main__':
     parser.add_argument("--server-port", help="the local port to connect the tunnel to", type=int, default=1022)
     args = parser.parse_args()
 
-    start_remote_access_service(local_port=args.local_port, server=args.server, server_port=args.server_port)
-    os.wait()  # Wait for completion of a child process
+    t = start_remote_access_service(local_port=args.local_port, server=args.server, server_port=args.server_port)
+    t.join()