package sys_database // Copyright (c) 2025 Robert Strutts // License: MIT // GIT: https://git.mysnippetsofcode.com/bobs/execguard import ( "execguard/core/hasher" "bufio" "os" "fmt" "sync" "database/sql" "log" "strings" "time" ) var ( initMode bool initFile string updateFile string migrateMode bool dbMutex sync.Mutex ) func SetModes(mode bool, file string, update string, migrate bool) { initMode = mode initFile = file updateFile = update migrateMode = migrate } func CreateTable(db *sql.DB, log log.Logger) { query := `CREATE TABLE IF NOT EXISTS allowed ( path TEXT PRIMARY KEY, hash TEXT )` _, err := db.Exec(query) if err != nil { log.Fatalf("Failed to create table: %v", err) os.Exit(5) // Exit with status code 5 } } func readFile(db *sql.DB, log log.Logger, input *os.File) { defer input.Close() scanner := bufio.NewScanner(input) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if line != "" { time.Sleep(time.Duration(100) * time.Millisecond) AddToAllowed(db, log, line) log.Printf("Migrated path: %s", line) } } if err := scanner.Err(); err != nil { log.Printf("Error reading Migrate file: %v", err) } } func RunInit(db *sql.DB, log log.Logger, path string) { input, err := os.Open(path) if err != nil { log.Fatalf("Failed to open temp file: %v", err) } readFile(db, log, input) } func RunMigration(db *sql.DB, log log.Logger) { tempFile := "Migrate" f, err := os.CreateTemp("", tempFile) if err != nil { log.Fatalf("Failed to create temp file: %v", err) } defer os.Remove(f.Name()) rows, err := db.Query("SELECT path FROM allowed") if err != nil { log.Fatalf("Failed to query allowed paths: %v", err) } defer rows.Close() for rows.Next() { var path string if err := rows.Scan(&path); err != nil { log.Printf("Failed to read row: %v", err) continue } _, _ = fmt.Fprintln(f, path) } // Seek back to start instead of closing/reopening if _, err := f.Seek(0, 0); err != nil { log.Fatalf("Failed to seek file: %v", err) } readFile(db, log, f) } func IsAllowed(db *sql.DB, log log.Logger, path string) bool { var storedHash string hash := hasher.ComputeHash(path, log) if hash == "" { return false } err := db.QueryRow("SELECT hash FROM allowed WHERE path = ?", path).Scan(&storedHash) return err == nil && storedHash == hash } func AddToAllowed(db *sql.DB, log log.Logger, path string) { dbMutex.Lock() defer dbMutex.Unlock() hash := "" if initMode || updateFile != "" || migrateMode { hash = hasher.ComputeHash(path, log) } _, err := db.Exec("INSERT OR REPLACE INTO allowed(path, hash) VALUES(?, ?)", path, hash) if err != nil { log.Printf("Error inserting allowed entry: %v", err) } }