]> git.seodisparate.com - AnotherAURHelper/commitdiff
Impl. pkg state via sqlite
authorStephen Seo <seo.disparate@gmail.com>
Thu, 6 Mar 2025 02:38:39 +0000 (11:38 +0900)
committerStephen Seo <seo.disparate@gmail.com>
Thu, 6 Mar 2025 02:57:39 +0000 (11:57 +0900)
Keep track of trusted PKGBUILDs using sqlite, so that
"hash_compare_PKGBUILD" never skips a package's PKGBUILD that hasn't
been explicitly trusted.

example_config.toml
update.py

index 04e40daa50b29b0f568fc9d55e81e3094cfd8795..f4ee7a611f2d3750735fbfc36fb38f1de737d62d 100644 (file)
@@ -26,6 +26,8 @@ datetime_in_local_time = true
 tmpfs = false
 # If true, only packages to be built will be printed when USR1 is signaled.
 print_state_info_only_building_sigusr1 = true
+# The path to the persistent state.
+persistent_state_db = "/home/stephen/aur_helper_state.db"
 ########## END OF MANDATORY VARIABLES
 
 # Each [[entry]] needs a "name".
index 4e221d3dddc13c31a3858ea394b90c5326f1e5b7..972c45a5d4a284030ede8a50796271f3b56a1ca1 100755 (executable)
--- a/update.py
+++ b/update.py
@@ -19,6 +19,7 @@ from pathlib import Path
 from typing import Any, Union
 import signal
 import pwd
+import sqlite3
 
 # SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
 SUDO_PROC = False
@@ -545,7 +546,13 @@ def check_pkg_build(
 
     pkgdir = os.path.join(other_state["clones_dir"], pkg)
 
-    if pkg_state[pkg]["hash_compare_PKGBUILD"]:
+    if not pkg in other_state["state_db_state"]:
+        other_state["state_db_state"][pkg] = False
+
+    if (
+        pkg_state[pkg]["hash_compare_PKGBUILD"]
+        and other_state["state_db_state"][pkg]
+    ):
         log_print(
             "Checking PKGBUILD (hash_compare_PKGBUILD enabled for this pkg)...",
             other_state=other_state,
@@ -564,6 +571,8 @@ def check_pkg_build(
                     other_state=other_state,
                 )
                 return "ok"
+            else:
+                other_state["state_db_state"][pkg] = False
         except subprocess.CalledProcessError:
             log_print(
                 'WARNING: Failed to get sha256sum of PKGBUILD pkg "{}"!'.format(
@@ -591,18 +600,24 @@ def check_pkg_build(
         user_input = sys.stdin.buffer.readline().decode().strip().lower()
         if user_input == "y" or len(user_input) == 0:
             log_print("User decided PKGBUILD is ok", other_state=other_state)
+            other_state["state_db_state"][pkg] = True
+            save_persistent_state_from_other(other_state)
             return "ok"
         elif user_input == "n":
             log_print(
                 "User decided PKGBUILD is not ok", other_state=other_state
             )
+            other_state["state_db_state"][pkg] = False
+            save_persistent_state_from_other(other_state)
             return "not_ok"
         elif user_input == "c":
             log_print("User will check PKGBUILD again", other_state=other_state)
             return check_pkg_build(pkg, pkg_state, other_state, editor)
         elif user_input == "a":
+            save_persistent_state_from_other(other_state)
             return "abort"
         elif user_input == "f":
+            save_persistent_state_from_other(other_state)
             return "force_build"
         elif user_input == "b":
             return "back"
@@ -1952,6 +1967,7 @@ def signal_handler(sig, frame):
         if signal.Signals(sig) is not signal.SIGINT:
             return
         OTHER_STATE["stop_building"] = True
+        save_persistent_state_from_other(OTHER_STATE)
         sys.exit(0)
     if signal.Signals(sig) is not signal.SIGINT:
         return
@@ -2136,6 +2152,61 @@ def prefetch_dependencies(pkg_names: [str], other_state: dict[str, Any]):
     return "fetched"
 
 
+def load_peristent_state(db_path: str):
+    """Returns a dict that is the "persistent_state"."""
+    state = {}
+    con = sqlite3.connect(db_path)
+    cur = con.cursor()
+    cur.execute(
+        "CREATE TABLE IF NOT EXISTS state (pkgname TEXT PRIMARY KEY, trusted INTEGER)"
+    )
+    con.commit()
+    res = cur.execute("SELECT * FROM state")
+    for item in res.fetchall():
+        if item[1] == 0:
+            state[item[0]] = False
+        else:
+            state[item[0]] = True
+    return state
+
+
+def save_persistent_state(db_path: str, persistent_state: dict[str, bool]):
+    """Saves a dict that is the "persistent_state"."""
+    con = sqlite3.connect(db_path)
+    cur = con.cursor()
+    cur.execute(
+        "CREATE TABLE IF NOT EXISTS state (pkgname TEXT PRIMARY KEY, trusted INTEGER)"
+    )
+    for pkg_trust in list(persistent_state):
+        trusted = 0
+        if persistent_state[pkg_trust]:
+            trusted = 1
+        res = cur.execute(
+            "SELECT pkgname FROM state WHERE pkgname = '{}'".format(pkg_trust)
+        )
+        if res.fetchone() is None:
+            cur.execute(
+                "INSERT INTO state VALUES ('{}', {})".format(pkg_trust, trusted)
+            )
+        else:
+            cur.execute(
+                "UPDATE state SET trusted = {} WHERE pkgname = '{}'".format(
+                    trusted, pkg_trust
+                )
+            )
+    con.commit()
+
+
+def save_persistent_state_from_other(other_state: dict[str, Any]):
+    """Attempts to save persistent state from inside other_state."""
+    try:
+        save_persistent_state(
+            other_state["state_db_path"], other_state["state_db_state"]
+        )
+    except BaseException:
+        log_print("WARNING: Failed to save state db!", other_state=other_state)
+
+
 def main():
     """The main function."""
     signal.signal(signal.SIGINT, signal_handler)
@@ -2215,6 +2286,8 @@ def main():
     other_state["error_on_limit"] = False
     other_state["print_state_SIGUSR1"] = False
     other_state["print_state_info_only_building_sigusr1"] = True
+    other_state["state_db_path"] = "aur_helper_state.db"
+    other_state["state_db_state"] = {}
     if args.pkg and not args.config:
         for pkg in args.pkg:
             pkg_state[pkg] = {}
@@ -2342,6 +2415,14 @@ def main():
                 and entry["link_cargo_registry"]
             ):
                 pkg_state[entry["name"]]["link_cargo_registry"] = True
+        if "persistent_state_db" in d and type(d["persistent_state_db"]) is str:
+            other_state["state_db_path"] = d["persistent_state_db"]
+        else:
+            log_print(
+                "ERROR: persistent_state_db in config is invalid!",
+                other_state=other_state,
+            )
+            sys.exit(1)
         other_state["chroot"] = d["chroot"]
         other_state["pkg_out_dir"] = d["pkg_out_dir"]
         other_state["repo"] = d["repo"]
@@ -2449,6 +2530,10 @@ def main():
         )
         sys.exit(1)
 
+    other_state["state_db_state"] = load_peristent_state(
+        other_state["state_db_path"]
+    )
+
     while len(other_state["chroot"]) > 1 and other_state["chroot"][-1] == "/":
         other_state["chroot"] = other_state["chroot"][:-1]