package main import ( "compress/gzip" "fmt" "io" "log" "os" "os/exec" "os/signal" "path/filepath" "sort" "strings" "sync" "time" "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/google/gopacket/pcap" "gopkg.in/yaml.v3" ) const ( defaultConfigFile = "/etc/SYN-Scan-Firewall/config.yaml" ) // Config structures type LoggingConfig struct { FilePath string `yaml:"filePath"` MaxSizeMB int `yaml:"maxSizeMB"` Backups int `yaml:"backups"` CompressBackups bool `yaml:"compressBackups"` TimestampFormat string `yaml:"timestampFormat"` } type Config struct { BlockDuration string `yaml:"blockDuration"` MaxScanAttempts int `yaml:"maxScanAttempts"` Device string `yaml:"device"` Logging LoggingConfig `yaml:"logging"` IgnoredPorts []int `yaml:"ignoredPorts"` WhitelistedIPs []string `yaml:"whitelistedIPs"` } type AppConfig struct { BlockDuration time.Duration MaxScanAttempts int Device string Logging LoggingConfig IgnoredPorts map[int]bool WhitelistedIPs map[string]bool } type UnblockTask struct { IP string UnblockAt time.Time } type Sniffer struct { unblockTasks []UnblockTask tracker *ScanTracker handle *pcap.Handle config *AppConfig logger *log.Logger } // ScanTracker implementation type ScanTracker struct { sync.Mutex entries map[string]*ScanEntry } type ScanEntry struct { Count int Timestamp time.Time } // NewScanTracker creates and initializes a new ScanTracker func NewScanTracker() *ScanTracker { return &ScanTracker{ entries: make(map[string]*ScanEntry), } } // RotatingLogger implementation type RotatingLogger struct { config LoggingConfig currentFile *os.File mu sync.Mutex } func NewRotatingLogger(config LoggingConfig) (*RotatingLogger, error) { rl := &RotatingLogger{config: config} if err := rl.openFile(); err != nil { return nil, err } return rl, nil } func (rl *RotatingLogger) openFile() error { if err := os.MkdirAll(filepath.Dir(rl.config.FilePath), 0755); err != nil { return fmt.Errorf("failed to create log directory: %v", err) } file, err := os.OpenFile(rl.config.FilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0640) if err != nil { return fmt.Errorf("failed to open log file: %v", err) } rl.currentFile = file return nil } func (rl *RotatingLogger) rotate() error { rl.mu.Lock() defer rl.mu.Unlock() if err := rl.currentFile.Close(); err != nil { return err } for i := rl.config.Backups - 1; i >= 0; i-- { src := rl.getBackupName(i) if _, err := os.Stat(src); err == nil { dst := rl.getBackupName(i + 1) if i+1 >= rl.config.Backups { os.Remove(dst) } else { if rl.config.CompressBackups && !strings.HasSuffix(src, ".gz") { if err := rl.compressFile(src); err != nil { return err } src += ".gz" dst += ".gz" } os.Rename(src, dst) } } } if err := os.Rename(rl.config.FilePath, rl.getBackupName(0)); err != nil { return err } return rl.openFile() } func (rl *RotatingLogger) compressFile(src string) error { in, err := os.Open(src) if err != nil { return err } defer in.Close() out, err := os.Create(src + ".gz") if err != nil { return err } defer out.Close() gz := gzip.NewWriter(out) defer gz.Close() if _, err = io.Copy(gz, in); err != nil { return err } return os.Remove(src) } func (rl *RotatingLogger) getBackupName(index int) string { if index == 0 { return rl.config.FilePath + ".1" } return fmt.Sprintf("%s.%d", rl.config.FilePath, index+1) } func (rl *RotatingLogger) needsRotation() (bool, error) { info, err := rl.currentFile.Stat() if err != nil { return false, err } return info.Size() >= int64(rl.config.MaxSizeMB*1024*1024), nil } func (rl *RotatingLogger) Write(p []byte) (n int, err error) { if rotate, err := rl.needsRotation(); rotate && err == nil { if err := rl.rotate(); err != nil { return 0, err } } rl.mu.Lock() defer rl.mu.Unlock() return rl.currentFile.Write(p) } func (rl *RotatingLogger) Close() error { rl.mu.Lock() defer rl.mu.Unlock() return rl.currentFile.Close() } // Helper functions func loadConfig(path string) (*AppConfig, error) { data, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("failed to read config file: %v", err) } var cfg Config if err := yaml.Unmarshal(data, &cfg); err != nil { return nil, fmt.Errorf("failed to parse config: %v", err) } // Set defaults if cfg.Logging.FilePath == "" { cfg.Logging.FilePath = "/var/log/SYN-Scan-Firewall.log" } if cfg.Logging.MaxSizeMB == 0 { cfg.Logging.MaxSizeMB = 10 } if cfg.Logging.Backups == 0 { cfg.Logging.Backups = 5 } if cfg.Logging.TimestampFormat == "" { cfg.Logging.TimestampFormat = "2006-01-02 15:04:05" } blockDuration, err := time.ParseDuration(cfg.BlockDuration) if err != nil { return nil, fmt.Errorf("invalid blockDuration format: %v", err) } ignoredPorts := make(map[int]bool) for _, port := range cfg.IgnoredPorts { ignoredPorts[port] = true } whitelistedIPs := make(map[string]bool) for _, ip := range cfg.WhitelistedIPs { whitelistedIPs[ip] = true } return &AppConfig{ BlockDuration: blockDuration, MaxScanAttempts: cfg.MaxScanAttempts, Device: cfg.Device, Logging: cfg.Logging, IgnoredPorts: ignoredPorts, WhitelistedIPs: whitelistedIPs, }, nil } func (st *ScanTracker) Add(ip string, config *AppConfig) { st.Lock() defer st.Unlock() entry, exists := st.entries[ip] if !exists { entry = &ScanEntry{} st.entries[ip] = entry } now := time.Now() if entry.Timestamp.Add(config.BlockDuration).Before(now) { entry.Count = 0 } entry.Count++ entry.Timestamp = now } func (st *ScanTracker) GetCount(ip string) int { st.Lock() defer st.Unlock() if entry, exists := st.entries[ip]; exists { return entry.Count } return 0 } func isIPBlocked(ip string) bool { cmd := exec.Command("sudo", "iptables", "-L", "-n") output, err := cmd.Output() if err != nil { return false } return strings.Contains(string(output), ip) } func ruleExists(ip string, drop bool) (bool, error) { var cmd *exec.Cmd if drop { cmd = exec.Command("sudo", "iptables", "-C", "INPUT", "-s", ip, "-j", "DROP") } else { cmd = exec.Command("sudo", "iptables", "-t", "nat", "-C", "PREROUTING", "-s", ip, "-p", "tcp", "--dport", "1:65535", "-j", "REDIRECT", "--to-port", "9999") } err := cmd.Run() if err == nil { return true, nil } // Check if error is because rule doesn't exist if exitErr, ok := err.(*exec.ExitError); ok { if exitErr.ExitCode() == 1 { return false, nil } } return false, err } func blockIP(ip string, logger *log.Logger) { if isIPBlocked(ip) { logger.Printf("IP %s is already blocked", ip) return } logger.Printf("Redirecting IP: %s to port 9999 BANNER", ip) cmd := exec.Command("sudo", "iptables", "-t", "nat", "-A", "PREROUTING", "-s", ip, "-p", "tcp", "--dport", "1:65535", "-j", "REDIRECT", "--to-port", "9999") if err := cmd.Run(); err != nil { logger.Printf("Error redirecting IP %s to banner service: %v", ip, err) } // Delay for 3 seconds before executing the command time.Sleep(3 * time.Second) logger.Printf("Blocking IP: %s", ip) cmd = exec.Command("sudo", "iptables", "-A", "INPUT", "-s", ip, "-j", "DROP") if err := cmd.Run(); err != nil { logger.Printf("Error blocking IP %s: %v", ip, err) } } func unblockIP(ip string, logger *log.Logger) { route_exists, _ := ruleExists(ip, false) if route_exists { logger.Printf("Unblocking IP: %s", ip) deleteCmd := exec.Command("sudo", "iptables", "-t", "nat", "-D", "PREROUTING", "-s", ip, "-p", "tcp", "--dport", "1:65535", "-j", "REDIRECT", "--to-port", "9999") if err := deleteCmd.Run(); err != nil { logger.Printf("Error unRedirecting IP %s: %v", ip, err) } } drop_exists, _ := ruleExists(ip, true) if drop_exists { cmd := exec.Command("sudo", "iptables", "-D", "INPUT", "-s", ip, "-j", "DROP") if err := cmd.Run(); err != nil { logger.Printf("Error unBlocking IP %s: %v", ip, err) } } } // Sniffer methods func (s *Sniffer) isWhitelisted(ip string) bool { return s.config.WhitelistedIPs[ip] } func (s *Sniffer) handlePacket(packet gopacket.Packet) { tcpLayer := packet.Layer(layers.LayerTypeTCP) if tcpLayer == nil { return } tcp, _ := tcpLayer.(*layers.TCP) if tcp.SYN && !tcp.ACK { ipLayer := packet.Layer(layers.LayerTypeIPv4) if ipLayer == nil { return } ip, _ := ipLayer.(*layers.IPv4) srcIP := ip.SrcIP.String() dstPort := int(tcp.DstPort) if s.isWhitelisted(srcIP) { return } if s.config.IgnoredPorts[dstPort] { return } s.logger.Printf("Scan detected on port %d from %s", dstPort, srcIP) s.tracker.Add(srcIP, s.config) count := s.tracker.GetCount(srcIP) if count > s.config.MaxScanAttempts { s.logger.Printf("IP %s exceeded scan limit (%d attempts), blocking for %.0f minutes", srcIP, s.config.MaxScanAttempts, s.config.BlockDuration.Minutes()) blockIP(srcIP, s.logger) unblockTime := time.Now().Add(s.config.BlockDuration) s.logger.Printf("IP %s will be unblocked at %s", srcIP, unblockTime.Format(s.config.Logging.TimestampFormat)) s.unblockTasks = append(s.unblockTasks, UnblockTask{ IP: srcIP, UnblockAt: unblockTime, }) } } } func (s *Sniffer) unblockExpiredIPs() { now := time.Now() var remainingTasks []UnblockTask for _, task := range s.unblockTasks { if now.After(task.UnblockAt) || now.Equal(task.UnblockAt) { unblockIP(task.IP, s.logger) } else { remainingTasks = append(remainingTasks, task) } } s.unblockTasks = remainingTasks } func (s *Sniffer) StartSniffing() error { s.logger.Printf("Starting port scan detection") s.logger.Printf("Configuration:") s.logger.Printf(" Device: %s", s.config.Device) s.logger.Printf(" Ignored ports: %v", s.getSortedPorts()) s.logger.Printf(" Whitelisted IPs: %v", s.getSortedIPs()) s.logger.Printf(" Block duration: %.0f minutes", s.config.BlockDuration.Minutes()) s.logger.Printf(" Max scan attempts before blocking: %d", s.config.MaxScanAttempts) handle, err := pcap.OpenLive(s.config.Device, 1600, false, pcap.BlockForever) if err != nil { return fmt.Errorf("error opening device %s: %v", s.config.Device, err) } s.handle = handle err = handle.SetBPFFilter("tcp") if err != nil { return fmt.Errorf("error setting filter: %v", err) } packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) for packet := range packetSource.Packets() { s.handlePacket(packet) } return nil } func (s *Sniffer) getSortedPorts() []int { ports := make([]int, 0, len(s.config.IgnoredPorts)) for port := range s.config.IgnoredPorts { ports = append(ports, port) } sort.Ints(ports) return ports } func (s *Sniffer) getSortedIPs() []string { ips := make([]string, 0, len(s.config.WhitelistedIPs)) for ip := range s.config.WhitelistedIPs { ips = append(ips, ip) } sort.Strings(ips) return ips } // Main program func main() { if os.Geteuid() != 0 { log.Fatal("This program must be run as root (sudo)") } config, err := loadConfig(defaultConfigFile) if err != nil { log.Fatalf("Configuration error: %v", err) } rl, err := NewRotatingLogger(config.Logging) if err != nil { log.Fatalf("Failed to initialize logger: %v", err) } defer rl.Close() logger := log.New(rl, "", 0) logger.SetPrefix(fmt.Sprintf("[%s] ", time.Now().Format(config.Logging.TimestampFormat))) tracker := NewScanTracker() sniffer := &Sniffer{ tracker: tracker, config: config, logger: logger, } sniffer.tracker.entries = make(map[string]*ScanEntry) go func() { if err := sniffer.StartSniffing(); err != nil { logger.Fatalf("Sniffer error: %v", err) } }() ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt) logger.Println("Running... Press Ctrl+C to stop") for { select { case <-ticker.C: sniffer.unblockExpiredIPs() case <-sigChan: logger.Println("Stopping...") if sniffer.handle != nil { sniffer.handle.Close() } return } } }