]> git.seodisparate.com - AnotherAURHelper/commitdiff
Handle SIGINT
authorStephen Seo <seo.disparate@gmail.com>
Wed, 25 Oct 2023 13:19:35 +0000 (22:19 +0900)
committerStephen Seo <seo.disparate@gmail.com>
Wed, 25 Oct 2023 13:25:24 +0000 (22:25 +0900)
update.py

index b840b902e891facd329ab590c19562783d9ab24d..237359ab72706947da53139c00cfea06fd5b2c28 100755 (executable)
--- a/update.py
+++ b/update.py
@@ -17,6 +17,7 @@ import tempfile
 import threading
 from pathlib import Path
 from typing import Any, Union
+import signal
 
 # SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
 SUDO_PROC = False
@@ -27,6 +28,8 @@ DEFAULT_EDITOR = "/usr/bin/nano"
 IS_DIGIT_REGEX = re.compile("^[0-9]+$")
 STRFTIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
 STRFTIME_LOCAL_FORMAT = "%Y-%m-%dT%H:%M:%S"
+PKG_STATE = None
+OTHER_STATE = None
 
 
 class ArchPkgVersion:
@@ -1585,7 +1588,9 @@ def confirm_result(pkg: str, state_result: str):
             continue
 
 
-def print_state_info_and_get_update_list(pkg_state: dict[str, Any]):
+def print_state_info_and_get_update_list(
+    other_state: dict[str, Any], pkg_state: dict[str, Any]
+):
     """Prints the current "checked" state of all pkgs in the config."""
 
     to_update = []
@@ -1701,7 +1706,17 @@ def validate_and_verify_paths(other_state: dict[str, Union[None, str]]):
         clones_dir_path.mkdir(parents=True)
 
 
+def signal_handler(sig, frame):
+    """Handle SIGINT"""
+    global OTHER_STATE, PKG_STATE
+    if OTHER_STATE is not None and PKG_STATE is not None:
+        print_state_info_and_get_update_list(OTHER_STATE, PKG_STATE)
+        sys.exit(0)
+    sys.exit(1)
+
+
 if __name__ == "__main__":
+    signal.signal(signal.SIGINT, signal_handler)
     editor = None
     parser = argparse.ArgumentParser(description="Update AUR pkgs")
     parser.add_argument(
@@ -1761,6 +1776,8 @@ if __name__ == "__main__":
 
     pkg_state = {}
     other_state = {}
+    PKG_STATE = pkg_state
+    OTHER_STATE = other_state
     other_state["logs_dir"] = None
     other_state["log_limit"] = 1024 * 1024 * 1024
     other_state["error_on_limit"] = False
@@ -1993,7 +2010,7 @@ if __name__ == "__main__":
         if i > furthest_checked:
             furthest_checked = i
         if not ensure_pkg_dir_exists(pkg_list[i], pkg_state, other_state):
-            print_state_info_and_get_update_list(pkg_state)
+            print_state_info_and_get_update_list(other_state, pkg_state)
             sys.exit(1)
         if (
             "repo_path" not in pkg_state[pkg_list[i]]
@@ -2017,7 +2034,7 @@ if __name__ == "__main__":
                     pkg_list[i],
                     other_state=other_state,
                 )
-                print_state_info_and_get_update_list(pkg_state)
+                print_state_info_and_get_update_list(other_state, pkg_state)
                 sys.exit(1)
         if skip_on_same_ver and i >= furthest_checked:
             check_pkg_version_result = check_pkg_version(
@@ -2051,7 +2068,7 @@ if __name__ == "__main__":
                 i -= 1
             continue
         else:  # check_pkg_build_result == "abort":
-            print_state_info_and_get_update_list(pkg_state)
+            print_state_info_and_get_update_list(other_state, pkg_state)
             sys.exit(1)
         while True:
             if (
@@ -2087,7 +2104,7 @@ if __name__ == "__main__":
                 going_back = True
                 break
             else:  # confirm_result_result == "abort"
-                print_state_info_and_get_update_list(pkg_state)
+                print_state_info_and_get_update_list(other_state, pkg_state)
                 sys.exit(1)
         if going_back:
             going_back = False
@@ -2095,7 +2112,9 @@ if __name__ == "__main__":
             i += 1
 
     log_print("Showing current actions:", other_state=other_state)
-    pkgs_to_update = print_state_info_and_get_update_list(pkg_state)
+    pkgs_to_update = print_state_info_and_get_update_list(
+        other_state, pkg_state
+    )
     if len(pkgs_to_update) > 0:
         log_print("Continue? [Y/n]", other_state=other_state)
         user_input = sys.stdin.buffer.readline().decode().strip().lower()