Browse Source

fix remote access not work in windows

wangweimin 3 years ago
parent
commit
0cf86b63c3
2 changed files with 51 additions and 33 deletions
  1. 2 2
      pywebio/__version__.py
  2. 49 31
      pywebio/platform/remote_access.py

+ 2 - 2
pywebio/__version__.py

@@ -1,8 +1,8 @@
 __package__ = 'pywebio'
 __description__ = 'Write interactive web app in script way.'
 __url__ = 'https://pywebio.readthedocs.io'
-__version__ = "1.3.2"
-__version_info__ = (1, 3, 2, 0)
+__version__ = "1.3.3"
+__version_info__ = (1, 3, 3, 0)
 __author__ = 'WangWeimin'
 __author_email__ = 'wang0.618@qq.com'
 __license__ = 'MIT'

+ 49 - 31
pywebio/platform/remote_access.py

@@ -18,6 +18,7 @@ to see if it alive, and when the PyWebIO application exit, the child process kil
 
 import json
 import logging
+import multiprocessing
 import os
 import re
 import shlex
@@ -56,10 +57,17 @@ Note that only rsa and ed25519 keys are supported.
 _ssh_process = None  # type: Popen
 
 
-def remote_access_process(local_port=8080, setup_timeout=60, key_path=None, custom_domain=None):
+def remote_access_service(local_port=8080, setup_timeout=60, key_path=None, custom_domain=None, need_exist=None):
+    """
+    :param local_port: ssh local listen port
+    :param setup_timeout: If the service can't setup successfully in `setup_timeout` seconds, then exit.
+    :param key_path: Use a custom ssh key, the default key path is ~/.ssh/id_xxx. Note that only rsa and ed25519 keys are supported.
+    :param custom_domain: Use a custom domain for your remote access address. This need a subscription to localhost.run
+    :param callable need_exist: The service will call this function periodicity, when it return True, then exit the service.
+    """
+
     global _ssh_process
-    ppid = os.getppid()
-    assert ppid != 1
+
     domain_part = '%s:' % custom_domain if custom_domain is not None else ''
     key_path_arg = '-i %s' % key_path if key_path is not None else ''
     cmd = "ssh %s -oStrictHostKeyChecking=no -R %s80:localhost:%s localhost.run -- --output json" % (
@@ -67,25 +75,25 @@ def remote_access_process(local_port=8080, setup_timeout=60, key_path=None, cust
     args = shlex.split(cmd)
     logger.debug('remote access service command: %s', cmd)
 
-    _ssh_process = proc = Popen(args, stdout=PIPE, stderr=PIPE)
-    logger.debug('remote access process pid: %s', proc.pid)
+    _ssh_process = Popen(args, stdout=PIPE, stderr=PIPE)
+    logger.debug('remote access process pid: %s', _ssh_process.pid)
     success = False
 
     def timeout_killer(wait_sec):
         time.sleep(wait_sec)
-        if not success and proc.poll() is None:
-            proc.kill()
+        if not success and _ssh_process.poll() is None:
+            _ssh_process.kill()
 
     threading.Thread(target=timeout_killer, kwargs=dict(wait_sec=setup_timeout), daemon=True).start()
 
-    stdout = proc.stdout.readline().decode('utf8')
+    stdout = _ssh_process.stdout.readline().decode('utf8')
     connection_info = {}
     try:
         connection_info = json.loads(stdout)
         success = True
     except json.decoder.JSONDecodeError:
-        if not success and proc.poll() is None:
-            proc.kill()
+        if not success and _ssh_process.poll() is None:
+            _ssh_process.kill()
 
     if success:
         if connection_info.get('status', 'fail') != 'success':
@@ -95,17 +103,21 @@ def remote_access_process(local_port=8080, setup_timeout=60, key_path=None, cust
             print(success_msg.format(address=connection_info['address']))
 
     # wait ssh or parent process exit
-    while os.getppid() == ppid and proc.poll() is None:
+    while not need_exist() and _ssh_process.poll() is None:
         time.sleep(1)
 
-    if proc.poll() is None:  # parent process exit, kill ssh process
+    if _ssh_process.poll() is None:  # parent process exit, kill ssh process
         logger.debug('App process exit, killing ssh process')
-        proc.kill()
+        _ssh_process.kill()
     else:  # ssh process exit by itself or by timeout killer
-        stderr = proc.stderr.read().decode('utf8')
+        stderr = _ssh_process.stderr.read().decode('utf8')
+        logger.debug("Stderr from ssh process: %s", stderr)
         conn_id = re.search(r'connection id is (.*?),', stderr)
         logger.debug('Remote access connection id: %s', conn_id.group(1) if conn_id else '')
-        ssh_error_msg = stderr.rsplit('**', 1)[-1].rsplit('===', 1)[-1].lower().strip()
+        try:
+            ssh_error_msg = stderr.rsplit('**', 1)[-1].rsplit('===', 1)[-1].lower().strip()
+        except Exception:
+            ssh_error_msg = stderr
         if 'permission denied' in ssh_error_msg:
             print(ssh_key_gen_msg)
         elif ssh_error_msg:
@@ -114,22 +126,27 @@ def remote_access_process(local_port=8080, setup_timeout=60, key_path=None, cust
             print('PyWebIO application remote access service exit.')
 
 
-def start_remote_access_service(local_port=8080, setup_timeout=60, ssh_key_path=None, custom_domain=None):
-    pid = os.fork()
-    if pid == 0:  # in child process
-        try:
-            remote_access_process(local_port=local_port, setup_timeout=setup_timeout,
-                                  key_path=ssh_key_path, custom_domain=custom_domain)
-        except KeyboardInterrupt:  # ignore KeyboardInterrupt
-            pass
-        finally:
-            if _ssh_process:
-                logger.debug('Exception occurred, killing ssh process')
-                _ssh_process.kill()
-            raise SystemExit
+def start_remote_access_service_(local_port, setup_timeout, ssh_key_path, custom_domain):
+    ppid = os.getppid()
+
+    def need_exist():
+        # only for unix
+        return os.getppid() != ppid
+
+    try:
+        remote_access_service(local_port=local_port, setup_timeout=setup_timeout,
+                              key_path=ssh_key_path, custom_domain=custom_domain, need_exist=need_exist)
+    except KeyboardInterrupt:  # ignore KeyboardInterrupt
+        pass
+    finally:
+        if _ssh_process:
+            logger.debug('Exception occurred, killing ssh process')
+            _ssh_process.kill()
+        raise SystemExit
+
 
-    else:
-        return pid
+def start_remote_access_service(local_port=8080, setup_timeout=60, ssh_key_path=None, custom_domain=None):
+    multiprocessing.Process(target=start_remote_access_service_, kwargs=locals()).start()
 
 
 if __name__ == '__main__':
@@ -143,5 +160,6 @@ if __name__ == '__main__':
     parser.add_argument("--key-path", help="custom SSH key path", default=None)
     args = parser.parse_args()
 
-    start_remote_access_service(local_port=args.local_port, ssh_key_path=args.key_path, custom_domain=args.custom_domain)
+    start_remote_access_service(local_port=args.local_port, ssh_key_path=args.key_path,
+                                custom_domain=args.custom_domain)
     os.wait()  # Wait for completion of a child process