Port-scan detection that shows the scanning IP a banner and then blocks it.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
SYN-Scan-Firewall/SYN-Scan-Firewall.go

521 lines
12 KiB

package main
import (
"compress/gzip"
"fmt"
"io"
"log"
"os"
"os/exec"
"os/signal"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/pcap"
"gopkg.in/yaml.v3"
)
// defaultConfigFile is the fixed path from which main loads the YAML configuration.
const defaultConfigFile = "/etc/SYN-Scan-Firewall/config.yaml"
// Config structures

// LoggingConfig is the `logging` section of the YAML configuration file.
// loadConfig fills in defaults for fields left at their zero value.
type LoggingConfig struct {
	FilePath        string `yaml:"filePath"`        // path of the active log file
	MaxSizeMB       int    `yaml:"maxSizeMB"`       // rotate once the file reaches this many MiB
	Backups         int    `yaml:"backups"`         // number of rotated files to keep
	CompressBackups bool   `yaml:"compressBackups"` // gzip older rotated files
	TimestampFormat string `yaml:"timestampFormat"` // Go time layout used for log timestamps
}
// Config mirrors the raw YAML configuration file; loadConfig converts it into
// an AppConfig.
type Config struct {
	BlockDuration   string        `yaml:"blockDuration"` // parsed with time.ParseDuration, e.g. "30m"
	MaxScanAttempts int           `yaml:"maxScanAttempts"`
	Device          string        `yaml:"device"`
	Logging         LoggingConfig `yaml:"logging"`
	IgnoredPorts    []int         `yaml:"ignoredPorts"`
	WhitelistedIPs  []string      `yaml:"whitelistedIPs"`
}
// AppConfig is the runtime configuration produced by loadConfig: the block
// duration is parsed and the port / IP lists are turned into lookup sets.
type AppConfig struct {
	// BlockDuration is how long a blocked IP stays blocked; it is also the idle
	// time after which an IP's attempt count starts over.
	BlockDuration time.Duration
	// MaxScanAttempts is the count an IP must exceed before it is blocked.
	MaxScanAttempts int
	// Device is the capture interface passed to pcap.OpenLive.
	Device  string
	Logging LoggingConfig
	// IgnoredPorts are destination ports whose SYNs never count as attempts.
	IgnoredPorts map[int]bool
	// WhitelistedIPs are source IPs that are never tracked or blocked.
	WhitelistedIPs map[string]bool
}
// UnblockTask records that the iptables rules for IP should be removed once
// UnblockAt has been reached.
type UnblockTask struct {
	IP        string    // source address that was blocked
	UnblockAt time.Time // earliest time at which the block is lifted
}
// Sniffer ties together packet capture, per-IP attempt tracking and the
// schedule of pending unblocks.
type Sniffer struct {
	// unblockTasks holds pending unblocks. handlePacket appends to it on the
	// sniffing goroutine and unblockExpiredIPs rewrites it from main's ticker loop.
	unblockTasks []UnblockTask
	tracker      *ScanTracker // per-source SYN counts
	handle       *pcap.Handle // live capture handle, set by StartSniffing
	config       *AppConfig
	logger       *log.Logger
}
// ScanTracker implementation

// ScanTracker counts connection attempts per source IP. The embedded mutex
// guards entries; Add and GetCount hold it for every access.
type ScanTracker struct {
	sync.Mutex
	entries map[string]*ScanEntry // keyed by source IP string
}
// ScanEntry is the per-IP state kept by ScanTracker.
type ScanEntry struct {
	Count     int       // attempts counted since the count last restarted
	Timestamp time.Time // time of the most recent attempt
}
// NewScanTracker returns a ScanTracker whose entries map is allocated and
// empty, so Add can be called on it immediately.
func NewScanTracker() *ScanTracker {
	tracker := new(ScanTracker)
	tracker.entries = map[string]*ScanEntry{}
	return tracker
}
// RotatingLogger implementation

// RotatingLogger is an io.Writer that appends to config.FilePath and, once
// the file reaches config.MaxSizeMB MiB, rotates it into numbered backups:
// FilePath.1 is the newest and FilePath.<Backups> the oldest. When
// config.CompressBackups is set, every backup except FilePath.1 is stored
// gzip-compressed with an extra ".gz" suffix.
//
// mu guards currentFile and is held across the size check, the rotation and
// the write, so concurrent writers cannot rotate twice or write to a file
// that is being swapped out.
type RotatingLogger struct {
	config      LoggingConfig
	currentFile *os.File
	mu          sync.Mutex
}

// NewRotatingLogger creates the log directory if needed and opens (or
// creates) config.FilePath for appending.
func NewRotatingLogger(config LoggingConfig) (*RotatingLogger, error) {
	rl := &RotatingLogger{config: config}
	if err := rl.openFile(); err != nil {
		return nil, err
	}
	return rl, nil
}

// openFile ensures the parent directory exists and opens the active log file
// in append mode, replacing rl.currentFile. Callers must hold rl.mu or, as in
// NewRotatingLogger, own rl exclusively.
func (rl *RotatingLogger) openFile() error {
	if err := os.MkdirAll(filepath.Dir(rl.config.FilePath), 0755); err != nil {
		return fmt.Errorf("failed to create log directory: %w", err)
	}
	file, err := os.OpenFile(rl.config.FilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0640)
	if err != nil {
		return fmt.Errorf("failed to open log file: %w", err)
	}
	rl.currentFile = file
	return nil
}

// rotate closes the active file, shifts the existing backups one slot older
// (dropping the oldest), moves the active file to FilePath.1 and opens a fresh
// active file. The active file is reopened even when an earlier step fails,
// so a failed rotation never leaves the logger with nothing to write to.
// Callers must hold rl.mu.
func (rl *RotatingLogger) rotate() error {
	closeErr := rl.currentFile.Close()

	var moveErr error
	if err := rl.shiftBackups(); err != nil {
		// Do not move the active file onto FilePath.1: that slot may still
		// hold a backup the failed shift did not relocate.
		moveErr = err
	} else if err := os.Rename(rl.config.FilePath, rl.getBackupName(0)); err != nil {
		moveErr = err
	}

	if err := rl.openFile(); err != nil {
		return err
	}
	if moveErr != nil {
		return moveErr
	}
	return closeErr
}

// shiftBackups frees slot FilePath.1 by moving every existing backup one slot
// older. It first deletes whatever occupies the oldest slot, in either plain
// or ".gz" form. Plain backups are gzip-compressed as they move when
// config.CompressBackups is set. Callers must hold rl.mu.
func (rl *RotatingLogger) shiftBackups() error {
	slots := rl.config.Backups
	if slots < 1 {
		// Always keep at least FilePath.1.
		slots = 1
	}

	oldest := rl.getBackupName(slots - 1)
	for _, name := range []string{oldest, oldest + ".gz"} {
		if err := os.Remove(name); err != nil && !os.IsNotExist(err) {
			return err
		}
	}

	for i := slots - 2; i >= 0; i-- {
		src, ok := rl.existingBackup(i)
		if !ok {
			continue
		}
		dst := rl.getBackupName(i + 1)
		switch {
		case strings.HasSuffix(src, ".gz"):
			if err := os.Rename(src, dst+".gz"); err != nil {
				return err
			}
		case rl.config.CompressBackups:
			if err := rl.compressFile(src); err != nil {
				return err
			}
			if err := os.Rename(src+".gz", dst+".gz"); err != nil {
				return err
			}
		default:
			if err := os.Rename(src, dst); err != nil {
				return err
			}
		}
	}
	return nil
}

// existingBackup returns the path of the backup in slot index, trying the
// plain name first and then its ".gz" variant.
func (rl *RotatingLogger) existingBackup(index int) (string, bool) {
	name := rl.getBackupName(index)
	for _, candidate := range []string{name, name + ".gz"} {
		if _, err := os.Stat(candidate); err == nil {
			return candidate, true
		}
	}
	return "", false
}

// compressFile writes a gzip-compressed copy of src to src+".gz". It deletes
// src only after the archive has been fully flushed and closed. On failure the
// partial archive is removed and src is left untouched.
func (rl *RotatingLogger) compressFile(src string) error {
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	dst := src + ".gz"
	out, err := os.Create(dst)
	if err != nil {
		in.Close()
		return err
	}

	gz := gzip.NewWriter(out)
	_, copyErr := io.Copy(gz, in)
	in.Close()
	// gz.Close writes the final compressed block and trailer, so its error
	// must be checked before the source is deleted.
	gzErr := gz.Close()
	outErr := out.Close()
	for _, e := range []error{copyErr, gzErr, outErr} {
		if e != nil {
			os.Remove(dst) // best effort: do not leave a truncated archive behind
			return e
		}
	}
	return os.Remove(src)
}

// getBackupName returns the path of zero-based backup slot index: slot 0 is
// FilePath.1 (newest), slot 1 is FilePath.2, and so on. The ".gz" suffix of
// compressed backups is not included.
func (rl *RotatingLogger) getBackupName(index int) string {
	return fmt.Sprintf("%s.%d", rl.config.FilePath, index+1)
}

// needsRotation reports whether the active file has reached the configured
// size limit. Callers must hold rl.mu.
func (rl *RotatingLogger) needsRotation() (bool, error) {
	info, err := rl.currentFile.Stat()
	if err != nil {
		return false, err
	}
	// Widen before multiplying so large limits cannot overflow a 32-bit int.
	limit := int64(rl.config.MaxSizeMB) * 1024 * 1024
	return info.Size() >= limit, nil
}

// Write appends p to the active log file, rotating first when the file has
// reached the size limit. If rotation fails, p is still written to the
// reopened active file so no log entry is lost. The rotation error is then
// returned together with the byte count of the successful write.
func (rl *RotatingLogger) Write(p []byte) (int, error) {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	var rotateErr error
	if due, err := rl.needsRotation(); err == nil && due {
		rotateErr = rl.rotate()
	}
	n, err := rl.currentFile.Write(p)
	if err != nil {
		return n, err
	}
	if rotateErr != nil {
		return n, fmt.Errorf("rotating log file: %w", rotateErr)
	}
	return n, nil
}

// Close closes the active log file.
func (rl *RotatingLogger) Close() error {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	return rl.currentFile.Close()
}
// Helper functions

// loadConfig reads the YAML file at path and builds an AppConfig from it.
// It fills in logging defaults for fields left at their zero value and parses
// blockDuration with time.ParseDuration. The ignoredPorts and whitelistedIPs
// lists become lookup sets. Every returned error wraps its cause with %w, so
// callers can inspect it with errors.Is / errors.As.
func loadConfig(path string) (*AppConfig, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("failed to read config file: %w", err)
	}
	var cfg Config
	if err := yaml.Unmarshal(data, &cfg); err != nil {
		return nil, fmt.Errorf("failed to parse config: %w", err)
	}
	applyLoggingDefaults(&cfg.Logging)

	blockDuration, err := time.ParseDuration(cfg.BlockDuration)
	if err != nil {
		return nil, fmt.Errorf("invalid blockDuration format: %w", err)
	}

	ignoredPorts := make(map[int]bool, len(cfg.IgnoredPorts))
	for _, port := range cfg.IgnoredPorts {
		ignoredPorts[port] = true
	}
	whitelistedIPs := make(map[string]bool, len(cfg.WhitelistedIPs))
	for _, ip := range cfg.WhitelistedIPs {
		whitelistedIPs[ip] = true
	}

	return &AppConfig{
		BlockDuration:   blockDuration,
		MaxScanAttempts: cfg.MaxScanAttempts,
		Device:          cfg.Device,
		Logging:         cfg.Logging,
		IgnoredPorts:    ignoredPorts,
		WhitelistedIPs:  whitelistedIPs,
	}, nil
}

// applyLoggingDefaults replaces zero-valued logging settings with the
// program's defaults.
func applyLoggingDefaults(l *LoggingConfig) {
	if l.FilePath == "" {
		l.FilePath = "/var/log/SYN-Scan-Firewall.log"
	}
	if l.MaxSizeMB == 0 {
		l.MaxSizeMB = 10
	}
	if l.Backups == 0 {
		l.Backups = 5
	}
	if l.TimestampFormat == "" {
		l.TimestampFormat = "2006-01-02 15:04:05"
	}
}
// Add records one attempt from ip, creating an entry if ip has none yet. If
// the previous attempt is more than config.BlockDuration in the past, the
// count starts over from zero before being incremented. The entry's timestamp
// becomes the current time.
func (st *ScanTracker) Add(ip string, config *AppConfig) {
	st.Lock()
	defer st.Unlock()

	entry, found := st.entries[ip]
	if !found {
		entry = new(ScanEntry)
		st.entries[ip] = entry
	}
	now := time.Now()
	windowEnd := entry.Timestamp.Add(config.BlockDuration)
	if windowEnd.Before(now) {
		entry.Count = 0
	}
	entry.Count, entry.Timestamp = entry.Count+1, now
}
// GetCount returns the attempt count currently recorded for ip, or 0 when ip
// has no entry.
func (st *ScanTracker) GetCount(ip string) int {
	st.Lock()
	defer st.Unlock()
	entry, ok := st.entries[ip]
	if !ok {
		return 0
	}
	return entry.Count
}
// isIPBlocked reports whether ip appears as a whole address in the output of
// `sudo iptables -L -n` (the filter table). A failure to run iptables is
// treated as "not blocked".
//
// Whole whitespace-separated fields are compared rather than substrings. This
// stops 10.0.0.1 from being reported as blocked just because a rule mentions
// 10.0.0.10 or 110.0.0.1.
func isIPBlocked(ip string) bool {
	output, err := exec.Command("sudo", "iptables", "-L", "-n").Output()
	if err != nil {
		return false
	}
	return listingContainsIP(string(output), ip)
}

// listingContainsIP reports whether any whitespace-separated field of an
// `iptables -L -n` listing equals ip, either bare or with an explicit /32 mask.
func listingContainsIP(listing, ip string) bool {
	for _, field := range strings.Fields(listing) {
		if field == ip || field == ip+"/32" {
			return true
		}
	}
	return false
}
// ruleExists asks iptables (`iptables -C`) whether a rule for ip is present.
// With drop set it checks the filter-table INPUT DROP rule; otherwise it
// checks the nat-table PREROUTING REDIRECT of all TCP ports to port 9999.
// Exit status 1 means "no such rule" and yields (false, nil); any other
// failure is returned as an error.
func ruleExists(ip string, drop bool) (bool, error) {
	args := []string{"iptables", "-t", "nat", "-C", "PREROUTING", "-s", ip,
		"-p", "tcp", "--dport", "1:65535", "-j", "REDIRECT", "--to-port", "9999"}
	if drop {
		args = []string{"iptables", "-C", "INPUT", "-s", ip, "-j", "DROP"}
	}

	runErr := exec.Command("sudo", args...).Run()
	switch exitErr, isExitErr := runErr.(*exec.ExitError); {
	case runErr == nil:
		return true, nil
	case isExitErr && exitErr.ExitCode() == 1:
		return false, nil
	default:
		return false, runErr
	}
}
// blockIP first redirects all TCP traffic from ip to the local banner port
// 9999 (nat PREROUTING REDIRECT). It then waits 3 seconds and adds an INPUT
// DROP rule for ip. It does nothing if isIPBlocked already reports ip.
// Failures are logged, not returned.
//
// The wait runs on the caller's goroutine (the packet-handling loop), so
// packet processing pauses while it lasts. It is therefore skipped when the
// redirect could not be installed and there is no banner to show.
func blockIP(ip string, logger *log.Logger) {
	if isIPBlocked(ip) {
		logger.Printf("IP %s is already blocked", ip)
		return
	}

	logger.Printf("Redirecting IP: %s to port 9999 BANNER", ip)
	redirect := exec.Command("sudo", "iptables", "-t", "nat", "-A", "PREROUTING", "-s", ip, "-p", "tcp", "--dport", "1:65535", "-j", "REDIRECT", "--to-port", "9999")
	if err := redirect.Run(); err != nil {
		logger.Printf("Error redirecting IP %s to banner service: %v", ip, err)
	} else {
		// Give the scanner a short window to see the banner before the DROP
		// rule is added.
		time.Sleep(3 * time.Second)
	}

	logger.Printf("Blocking IP: %s", ip)
	drop := exec.Command("sudo", "iptables", "-A", "INPUT", "-s", ip, "-j", "DROP")
	if err := drop.Run(); err != nil {
		logger.Printf("Error blocking IP %s: %v", ip, err)
	}
}
// unblockIP removes the banner REDIRECT rule (nat PREROUTING) and the INPUT
// DROP rule that blockIP installs for ip. Each rule is deleted until ruleExists
// reports it gone, so duplicate copies are cleaned up as well. Errors,
// including failures to check whether a rule exists, are logged rather than
// returned.
func unblockIP(ip string, logger *log.Logger) {
	logger.Printf("Unblocking IP: %s", ip)

	redirectArgs := []string{"iptables", "-t", "nat", "-D", "PREROUTING", "-s", ip, "-p", "tcp", "--dport", "1:65535", "-j", "REDIRECT", "--to-port", "9999"}
	if err := deleteAllMatchingRules(ip, false, redirectArgs); err != nil {
		logger.Printf("Error unRedirecting IP %s: %v", ip, err)
	}

	dropArgs := []string{"iptables", "-D", "INPUT", "-s", ip, "-j", "DROP"}
	if err := deleteAllMatchingRules(ip, true, dropArgs); err != nil {
		logger.Printf("Error unBlocking IP %s: %v", ip, err)
	}
}

// deleteAllMatchingRules runs `sudo <deleteArgs...>` while ruleExists(ip, drop)
// reports the rule present. The number of deletions is capped so a
// misbehaving iptables cannot keep this loop spawning processes forever.
func deleteAllMatchingRules(ip string, drop bool, deleteArgs []string) error {
	const maxDeletions = 32
	for i := 0; i < maxDeletions; i++ {
		exists, err := ruleExists(ip, drop)
		if err != nil {
			return fmt.Errorf("checking rule: %w", err)
		}
		if !exists {
			return nil
		}
		if err := exec.Command("sudo", deleteArgs...).Run(); err != nil {
			return err
		}
	}
	return fmt.Errorf("rule still present after %d deletions", maxDeletions)
}
// Sniffer methods

// isWhitelisted reports whether ip is present, with a true value, in the
// configured WhitelistedIPs set. The comparison is on the exact string.
func (s *Sniffer) isWhitelisted(ip string) bool {
	allowed, found := s.config.WhitelistedIPs[ip]
	return found && allowed
}
// unblockTasksMu guards Sniffer.unblockTasks. handlePacket reads and appends
// to the slice on the sniffing goroutine, while unblockExpiredIPs rewrites it
// from main's ticker loop. It is package-level so the Sniffer struct layout is
// unchanged; main creates a single Sniffer.
var unblockTasksMu sync.Mutex

// handlePacket inspects one captured packet. It counts IPv4 TCP segments with
// SYN set and ACK clear against their source IP, skipping whitelisted sources
// and ignored destination ports. When a source's count exceeds
// MaxScanAttempts, the source is scheduled for unblocking BlockDuration later
// and then blocked via blockIP. A source that already has a pending unblock
// task is still counted, but it is not blocked or scheduled again. This keeps
// a long scan from running iptables and queuing a task for every packet.
func (s *Sniffer) handlePacket(packet gopacket.Packet) {
	tcpLayer := packet.Layer(layers.LayerTypeTCP)
	if tcpLayer == nil {
		return
	}
	tcp, ok := tcpLayer.(*layers.TCP)
	if !ok || !tcp.SYN || tcp.ACK {
		return
	}
	ipLayer := packet.Layer(layers.LayerTypeIPv4)
	if ipLayer == nil {
		return
	}
	ip, ok := ipLayer.(*layers.IPv4)
	if !ok {
		return
	}

	srcIP := ip.SrcIP.String()
	dstPort := int(tcp.DstPort)
	if s.isWhitelisted(srcIP) || s.config.IgnoredPorts[dstPort] {
		return
	}

	s.logger.Printf("Scan detected on port %d from %s", dstPort, srcIP)
	s.tracker.Add(srcIP, s.config)
	if s.tracker.GetCount(srcIP) <= s.config.MaxScanAttempts {
		return
	}

	unblockTime, scheduled := s.scheduleUnblock(srcIP)
	if !scheduled {
		return // already blocked and waiting for its unblock time
	}
	s.logger.Printf("IP %s exceeded scan limit (%d attempts), blocking for %.0f minutes",
		srcIP, s.config.MaxScanAttempts, s.config.BlockDuration.Minutes())
	blockIP(srcIP, s.logger)
	s.logger.Printf("IP %s will be unblocked at %s", srcIP, unblockTime.Format(s.config.Logging.TimestampFormat))
}

// scheduleUnblock adds an unblock task for ip at BlockDuration from now,
// unless ip already has one. It returns the (new or existing) unblock time and
// whether a task was added. The check and the append happen under one lock
// acquisition.
func (s *Sniffer) scheduleUnblock(ip string) (time.Time, bool) {
	unblockTasksMu.Lock()
	defer unblockTasksMu.Unlock()
	for _, task := range s.unblockTasks {
		if task.IP == ip {
			return task.UnblockAt, false
		}
	}
	unblockAt := time.Now().Add(s.config.BlockDuration)
	s.unblockTasks = append(s.unblockTasks, UnblockTask{IP: ip, UnblockAt: unblockAt})
	return unblockAt, true
}

// unblockExpiredIPs calls unblockIP for every task whose unblock time has been
// reached and keeps only the tasks still pending. The lock is held while
// unblockIP runs, so handlePacket cannot re-schedule an IP while its rules are
// being removed.
func (s *Sniffer) unblockExpiredIPs() {
	unblockTasksMu.Lock()
	defer unblockTasksMu.Unlock()

	now := time.Now()
	pending := s.unblockTasks[:0] // in-place filter; writes never overtake reads
	for _, task := range s.unblockTasks {
		if now.Before(task.UnblockAt) {
			pending = append(pending, task)
			continue
		}
		unblockIP(task.IP, s.logger)
	}
	s.unblockTasks = pending
}
// StartSniffing logs the active configuration, opens a live capture on
// s.config.Device and passes every captured packet to handlePacket. It returns
// an error if the device cannot be opened or the capture filter cannot be
// installed; in the latter case the handle is closed first. It returns nil
// once the packet source's channel closes.
//
// NOTE(review): s.handle is assigned here on the sniffing goroutine while main
// reads it from its signal handler without synchronisation.
func (s *Sniffer) StartSniffing() error {
	s.logger.Printf("Starting port scan detection")
	s.logger.Printf("Configuration:")
	s.logger.Printf(" Device: %s", s.config.Device)
	s.logger.Printf(" Ignored ports: %v", s.getSortedPorts())
	s.logger.Printf(" Whitelisted IPs: %v", s.getSortedIPs())
	s.logger.Printf(" Block duration: %.0f minutes", s.config.BlockDuration.Minutes())
	s.logger.Printf(" Max scan attempts before blocking: %d", s.config.MaxScanAttempts)

	handle, err := pcap.OpenLive(s.config.Device, 1600, false, pcap.BlockForever)
	if err != nil {
		return fmt.Errorf("error opening device %s: %w", s.config.Device, err)
	}

	// handlePacket only acts on IPv4 TCP segments with SYN set and ACK clear.
	// Filtering for exactly those in the capture layer avoids copying and
	// decoding all other TCP traffic.
	const synOnlyFilter = "ip and tcp[tcpflags] & (tcp-syn|tcp-ack) == tcp-syn"
	if err := handle.SetBPFFilter(synOnlyFilter); err != nil {
		handle.Close() // not yet published in s.handle, so nothing else would close it
		return fmt.Errorf("error setting filter: %w", err)
	}
	s.handle = handle

	packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
	for packet := range packetSource.Packets() {
		s.handlePacket(packet)
	}
	return nil
}
// getSortedPorts returns every key of the configured IgnoredPorts set in
// ascending order. StartSniffing uses it when logging the configuration.
func (s *Sniffer) getSortedPorts() []int {
	ignored := s.config.IgnoredPorts
	sorted := make([]int, len(ignored))
	next := 0
	for port := range ignored {
		sorted[next] = port
		next++
	}
	sort.Sort(sort.IntSlice(sorted))
	return sorted
}
// getSortedIPs returns every key of the configured WhitelistedIPs set in
// lexical order. StartSniffing uses it when logging the configuration.
func (s *Sniffer) getSortedIPs() []string {
	whitelist := s.config.WhitelistedIPs
	sorted := make([]string, len(whitelist))
	next := 0
	for addr := range whitelist {
		sorted[next] = addr
		next++
	}
	sort.Sort(sort.StringSlice(sorted))
	return sorted
}
// Main program

// timestampWriter prefixes every Write with "[<current time>] ", formatted
// with format. log.Logger makes exactly one Write call per log entry, so
// wrapping its output this way gives each entry its own timestamp.
type timestampWriter struct {
	out    io.Writer
	format string
}

// Write prepends the timestamp to p and forwards the result to tw.out. On
// success it reports len(p), so callers see their own data as fully written.
func (tw *timestampWriter) Write(p []byte) (int, error) {
	line := make([]byte, 0, len(tw.format)+len(p)+8)
	line = append(line, '[')
	line = time.Now().AppendFormat(line, tw.format)
	line = append(line, "] "...)
	line = append(line, p...)
	if _, err := tw.out.Write(line); err != nil {
		return 0, err
	}
	return len(p), nil
}

// main requires root and loads defaultConfigFile. It opens the rotating log
// and starts packet capture on a background goroutine. Every 5 seconds it
// lifts blocks whose time has expired. On Ctrl+C (os.Interrupt) it closes the
// capture handle and returns.
func main() {
	if os.Geteuid() != 0 {
		log.Fatal("This program must be run as root (sudo)")
	}
	config, err := loadConfig(defaultConfigFile)
	if err != nil {
		log.Fatalf("Configuration error: %v", err)
	}
	rl, err := NewRotatingLogger(config.Logging)
	if err != nil {
		log.Fatalf("Failed to initialize logger: %v", err)
	}
	defer rl.Close()

	// A log.Logger prefix is a fixed string, so a time formatted into it would
	// stamp every line with the start-up time. Stamp each entry as it is
	// written instead; the "[<time>] <message>" layout stays the same.
	logger := log.New(&timestampWriter{out: rl, format: config.Logging.TimestampFormat}, "", 0)

	// NewScanTracker already allocates the entries map.
	sniffer := &Sniffer{
		tracker: NewScanTracker(),
		config:  config,
		logger:  logger,
	}

	go func() {
		if err := sniffer.StartSniffing(); err != nil {
			logger.Fatalf("Sniffer error: %v", err)
		}
	}()

	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, os.Interrupt)
	logger.Println("Running... Press Ctrl+C to stop")

	for {
		select {
		case <-ticker.C:
			sniffer.unblockExpiredIPs()
		case <-sigChan:
			logger.Println("Stopping...")
			// NOTE(review): IPs with pending unblock tasks are not unblocked
			// here, so their iptables rules remain after the program exits.
			if sniffer.handle != nil {
				sniffer.handle.Close()
			}
			return
		}
	}
}