package main import ( "bufio" "sync" "encoding/base64" "crypto/rand" "crypto/sha256" "crypto/sha512" "database/sql" "encoding/hex" "encoding/json" "flag" "fmt" "io" "io/fs" "log" "os" "os/exec" "path/filepath" "strings" "time" "unsafe" _ "github.com/mattn/go-sqlite3" "golang.org/x/sys/unix" "github.com/yang3yen/xxtea-go/xxtea" ) const ( configFile = "/etc/execguard/config.json" dbFile = "/etc/execguard/allowed.db" logFile = "/var/log/execguard.log" ) type Config struct { ProtectedDirs []string `json:"protected_dirs"` AlertEmail string `json:"alert_email"` SkipDirs []string `json:"skip_dirs"` ScanInterval int `json:"scan_interval"` // in minutes, 0 disables scan Passphrase string `json:"passphrase"` // optional hash encryption key HashEncryption string `json:"hash_encryption"` // "none", "xor", or "xxtea" HashType string `json:"hash_type"` // "sha256" or "sha512" } var initMode bool var initFile string var updateFile string var migrateMode bool var newKey bool var config *Config var dbMutex sync.Mutex func main() { flag.BoolVar(&initMode, "init", false, "initialize and populate allowed executable database") flag.StringVar(&initFile, "initFile", "", "file containing files to add to allowed database with hash") flag.StringVar(&updateFile, "update", "", "add specified file to allowed database with hash") flag.BoolVar(&migrateMode, "migrate", false, "recompute hashes of all allowed paths using current settings") flag.BoolVar(&newKey, "newKey", false, "generate a new XXTEA-compatible encryption key") flag.Parse() if newKey { // XXTEA key should be 16 bytes total...base64 will padd it... key := make([]byte, 12) if _, err := io.ReadFull(rand.Reader, key); err != nil { log.Fatalf("Failed to generate key: %v", err) } fmt.Printf("Generated XXTEA key (base64): %s\n", base64.StdEncoding.EncodeToString(key)) return } if os.Geteuid() != 0 { log.Fatal("This program must be run as root") os.Exit(1) // Exit with status code 1 } logf, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Fatalf("Error opening log file: %v", err) } defer logf.Close() log.SetOutput(logf) db, err := sql.Open("sqlite3", dbFile) if err != nil { log.Fatalf("Error opening database: %v", err) os.Exit(2) // Exit with status code 2 } defer db.Close() config, err = loadConfig() if err != nil { log.Fatalf("Error loading config: %v", err) os.Exit(3) // Exit with status code 3 } createTable(db) if initFile != "" { absPath, err := filepath.Abs(initFile) if err != nil { log.Fatalf("Invalid init file path: %v", err) os.Exit(1) // Exit with status code 1 } runInit(db, absPath) return } if updateFile != "" { absPath, err := filepath.Abs(updateFile) if err != nil { log.Fatalf("Invalid update file path: %v", err) os.Exit(1) // Exit with status code 1 } addToAllowed(db, absPath) log.Printf("Added to allowed list: %s", absPath) return } if migrateMode { runMigration(db) return } if config.ScanInterval > 0 { go func() { defer func() { if r := recover(); r != nil { log.Printf("Recovered from scan panic: %v", r) } }() periodicScan(config.ProtectedDirs, db) }() } if err := monitorExecutions(db); err != nil { log.Fatalf("Execution monitoring failed: %v", err) os.Exit(4) // Exit with status code 4 } } func randReader() io.Reader { return rand.Reader } func loadConfig() (*Config, error) { data, err := os.ReadFile(configFile) if err != nil { return nil, err } var cfg Config if err := json.Unmarshal(data, &cfg); err != nil { return nil, err } return &cfg, nil } func createTable(db *sql.DB) { query := `CREATE TABLE IF NOT EXISTS allowed ( path TEXT PRIMARY KEY, hash TEXT )` _, err := db.Exec(query) if err != nil { log.Fatalf("Failed to create table: %v", err) os.Exit(5) // Exit with status code 5 } } func readFile(db *sql.DB, input *os.File) { defer input.Close() scanner := bufio.NewScanner(input) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if line != "" { time.Sleep(time.Duration(100) * time.Millisecond) addToAllowed(db, line) log.Printf("Migrated path: %s", line) } } if err := scanner.Err(); err != nil { log.Printf("Error reading Migrate file: %v", err) } } func runInit(db *sql.DB, path string) { input, err := os.Open(path) if err != nil { log.Fatalf("Failed to open temp file: %v", err) } readFile(db, input) } func runMigration(db *sql.DB) { tempFile := "Migrate" f, err := os.CreateTemp("", tempFile) if err != nil { log.Fatalf("Failed to create temp file: %v", err) } defer os.Remove(f.Name()) rows, err := db.Query("SELECT path FROM allowed") if err != nil { log.Fatalf("Failed to query allowed paths: %v", err) } defer rows.Close() for rows.Next() { var path string if err := rows.Scan(&path); err != nil { log.Printf("Failed to read row: %v", err) continue } _, _ = fmt.Fprintln(f, path) } // Seek back to start instead of closing/reopening if _, err := f.Seek(0, 0); err != nil { log.Fatalf("Failed to seek file: %v", err) } readFile(db, f) } func isAllowed(db *sql.DB, path string) bool { var storedHash string hash := computeHash(path) if hash == "" { return false } err := db.QueryRow("SELECT hash FROM allowed WHERE path = ?", path).Scan(&storedHash) return err == nil && storedHash == hash } func addToAllowed(db *sql.DB, path string) { dbMutex.Lock() defer dbMutex.Unlock() hash := "" if initMode || updateFile != "" || migrateMode { hash = computeHash(path) } _, err := db.Exec("INSERT OR REPLACE INTO allowed(path, hash) VALUES(?, ?)", path, hash) if err != nil { log.Printf("Error inserting allowed entry: %v", err) } } func normalizeXXTEAKey(key []byte) []byte { switch { case len(key) == 16: return key case len(key) > 16: hash := sha256.Sum256(key) return hash[:16] default: // len(key) < 16 padded := make([]byte, 16) copy(padded, key) // Simple padding with repeated key pattern for i := len(key); i < 16; i++ { padded[i] = key[i%len(key)] } return padded } } func computeHash(path string) string { data, err := os.ReadFile(path) if err != nil { return "" } var hashBytes []byte switch strings.ToLower(config.HashType) { case "sha256": sum := sha256.Sum256(data) hashBytes = sum[:] case "sha512", "": sum := sha512.Sum512(data) hashBytes = sum[:] default: log.Printf("Unknown hash_type '%s', defaulting to sha512.", config.HashType) sum := sha512.Sum512(data) hashBytes = sum[:] } switch strings.ToLower(config.HashEncryption) { case "none": return hex.EncodeToString(hashBytes) case "xor": if config.Passphrase == "" { log.Println("XOR encryption selected but no passphrase provided.") return hex.EncodeToString(hashBytes) } key := []byte(config.Passphrase) enc := make([]byte, len(hashBytes)) for i := 0; i < len(hashBytes); i++ { enc[i] = hashBytes[i] ^ key[i%len(key)] } return hex.EncodeToString(enc) case "xxtea": if config.Passphrase == "" { log.Println("XXTEA encryption selected but no passphrase provided.") return hex.EncodeToString(hashBytes) } key := normalizeXXTEAKey([]byte(config.Passphrase)) enc, err := xxtea.Encrypt(hashBytes, key, false, 0) if err != nil { log.Println("XXTEA encryption KEY error???") return hex.EncodeToString(hashBytes) } return base64.StdEncoding.EncodeToString(enc) default: log.Printf("Unknown hash_encryption type: %s. Using plain hash.", config.HashEncryption) return hex.EncodeToString(hashBytes) } } func periodicScan(dirs []string, db *sql.DB) { skipSet := make(map[string]struct{}) for _, skip := range config.SkipDirs { if abs, err := filepath.Abs(skip); err == nil { skipSet[abs] = struct{}{} } } interval := time.Duration(config.ScanInterval) * time.Minute // log.Printf("Starting periodic scan every %v...", interval) for { for _, dir := range dirs { filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { if err != nil { return nil } absPath, err := filepath.Abs(path) if err != nil { return nil } // Skip if in any of the SkipDirs for skipDir := range skipSet { if strings.HasPrefix(absPath, skipDir) { return filepath.SkipDir } } if d.Type().IsRegular() { info, err := d.Info() if err != nil || (info.Mode().Perm()&0111 == 0) { return nil } absPath, _ = filepath.EvalSymlinks(absPath) if initMode { addToAllowed(db, absPath) } else if !isAllowed(db, absPath) { log.Printf("Found unauthorized executable: %s", absPath) os.Chmod(absPath, info.Mode()&^0111) go sendAlert(fmt.Sprintf("Unauthorized executable found and blocked: %s", absPath)) } } return nil }) } time.Sleep(interval) } } func monitorExecutions(db *sql.DB) error { fd, err := unix.FanotifyInit(unix.FAN_CLOEXEC|unix.FAN_CLASS_CONTENT, unix.O_RDONLY|unix.O_LARGEFILE) if err != nil { return err } defer unix.Close(fd) for _, dir := range config.ProtectedDirs { if err := unix.FanotifyMark(fd, unix.FAN_MARK_ADD|unix.FAN_MARK_MOUNT, unix.FAN_OPEN_EXEC_PERM, unix.AT_FDCWD, dir); err != nil { log.Printf("Failed to mark %s: %v", dir, err) } } buf := make([]byte, 4096) for { n, err := unix.Read(fd, buf) if err != nil { return err } for offset := 0; offset < n; { meta := (*unix.FanotifyEventMetadata)(unsafe.Pointer(&buf[offset])) if meta.Event_len == 0 { break } resp := unix.FanotifyResponse{Fd: meta.Fd, Response: unix.FAN_ALLOW} defer unix.Close(int(meta.Fd)) if meta.Mask&unix.FAN_OPEN_EXEC_PERM != 0 { fdpath := fmt.Sprintf("/proc/self/fd/%d", meta.Fd) path, err := os.Readlink(fdpath) if err == nil { absPath, _ := filepath.Abs(path) absPath, _ = filepath.EvalSymlinks(absPath) info, statErr := os.Stat(absPath) if statErr == nil && info.Mode().IsRegular() && (info.Mode().Perm()&0111 != 0) { if initMode { addToAllowed(db, absPath) } else if !isAllowed(db, absPath) { log.Printf("Blocked execution attempt: %s", absPath) // To avoid locking up the Whole System...use go function on sendAlert!!! go sendAlert(fmt.Sprintf("Unauthorized execution attempt blocked: %s", absPath)) resp.Response = unix.FAN_DENY } } } } b := (*[unsafe.Sizeof(resp)]byte)(unsafe.Pointer(&resp))[:] if _, err := unix.Write(fd, b); err != nil { log.Printf("Fanotify response write error: %v", err) } offset += int(meta.Event_len) } } } func sendAlert(message string) { if config.AlertEmail == "" { return } if _, err := exec.LookPath("mail"); err != nil { log.Printf("Mail command not found: %v", err) return } cmd := exec.Command("mail", "-s", "ExecGuard Alert", config.AlertEmail) cmd.Stdin = strings.NewReader(message) if err := cmd.Run(); err != nil { log.Printf("Failed to send alert: %v", err) } }