package main import ( "crypto/aes" "crypto/cipher" "crypto/rand" "encoding/json" "fmt" "io" "net" "sync" "time" "strings" "gopkg.in/yaml.v2" "os" "github.com/yang3yen/xxtea-go/xxtea" ) // Copyright (c) 2025 Robert Strutts, License MIT // Globals Vars: var ( encryptor Encryptor config Config userManagerTCP *UserManagerTCP userManagerUDP *UserManagerUDP ) type Message struct { Username string Text string Time string } type Encryptor interface { Encrypt([]byte) ([]byte, error) Decrypt([]byte) ([]byte, error) } type AESEncryptor struct { key []byte } type XXTEAEncryptor struct { key []byte } // Config represents the server configuration type Config struct { Server struct { Timezone string `yaml:"timezone"` Address string `yaml:"address"` Port int `yaml:"port"` Protocol string `yaml:"protocol"` Encryption struct { Type string `yaml:"type"` Key string `yaml:"key"` } `yaml:"encryption"` } `yaml:"server"` } type ClientTCP struct { conn net.Conn username string } type ClientUDP struct { addr *net.UDPAddr // Use UDPAddr for UDP clients username string } type UserManagerTCP struct { clients map[net.Conn]*ClientTCP mutex sync.Mutex } type UserManagerUDP struct { clients map[string]*ClientUDP // Use a string key (e.g., addr.String()) mutex sync.Mutex } func NewUserManagerTCP() *UserManagerTCP { return &UserManagerTCP{ clients: make(map[net.Conn]*ClientTCP), } } func NewUserManagerUDP() *UserManagerUDP { return &UserManagerUDP{ clients: make(map[string]*ClientUDP), } } func (um *UserManagerTCP) AddClientTCP(conn net.Conn, username string) { um.mutex.Lock() defer um.mutex.Unlock() um.clients[conn] = &ClientTCP{conn: conn, username: username} } func (um *UserManagerUDP) AddClientUDP(addr *net.UDPAddr, username string) { um.mutex.Lock() defer um.mutex.Unlock() um.clients[addr.String()] = &ClientUDP{addr: addr, username: username} } func (um *UserManagerTCP) RemoveClientTCP(conn net.Conn) { um.mutex.Lock() defer um.mutex.Unlock() delete(um.clients, conn) } func (um *UserManagerUDP) RemoveClientUDP(addr *net.UDPAddr) { um.mutex.Lock() defer um.mutex.Unlock() delete(um.clients, addr.String()) } func (um *UserManagerTCP) GetUserListTCP() string { um.mutex.Lock() defer um.mutex.Unlock() var userList string for _, client := range um.clients { userList += client.username + ", " } return userList } func (um *UserManagerTCP) UserMatchTCP(user string, encrypted []byte, sender net.Conn, message string) { um.mutex.Lock() defer um.mutex.Unlock() for clientCon, client := range um.clients { if user == client.username { clientCon.Write(encrypted) } else if clientCon == sender { sayTCP(clientCon, message, "@YOU-Said(In Private)") } } } func (um *UserManagerUDP) GetUserListUDP() string { um.mutex.Lock() defer um.mutex.Unlock() var userList string for _, client := range um.clients { userList += client.username + ", " } return userList } func (a *AESEncryptor) Encrypt(plaintext []byte) ([]byte, error) { block, err := aes.NewCipher(a.key) if err != nil { return nil, err } gcm, err := cipher.NewGCM(block) if err != nil { return nil, err } nonce := make([]byte, gcm.NonceSize()) if _, err = io.ReadFull(rand.Reader, nonce); err != nil { return nil, err } return gcm.Seal(nonce, nonce, plaintext, nil), nil } func (a *AESEncryptor) Decrypt(ciphertext []byte) ([]byte, error) { block, err := aes.NewCipher(a.key) if err != nil { return nil, err } gcm, err := cipher.NewGCM(block) if err != nil { return nil, err } nonceSize := gcm.NonceSize() if len(ciphertext) < nonceSize { return nil, fmt.Errorf("ciphertext too short") } nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] return gcm.Open(nil, nonce, ciphertext, nil) } func (x *XXTEAEncryptor) Encrypt(plaintext []byte) ([]byte, error) { result, err := xxtea.Encrypt(plaintext, x.key, true, 0) if err != nil { return nil, err } return result, nil } func (x *XXTEAEncryptor) Decrypt(ciphertext []byte) ([]byte, error) { result, err := xxtea.Decrypt(ciphertext, x.key, true, 0) if err != nil { return nil, err } return result, nil } func loadConfig(filename string) error { var key []byte data, err := os.ReadFile(filename) if err != nil { return fmt.Errorf("error reading config file: %v", err) } err = yaml.Unmarshal(data, &config) if err != nil { return fmt.Errorf("error parsing config file: %v", err) } key = []byte(config.Server.Encryption.Key) // Initialize the appropriate encryptor switch config.Server.Encryption.Type { case "aes": // AES requires exactly 32 bytes key if len(key) != 32 { return fmt.Errorf("AES key must be exactly 32 bytes") } encryptor = &AESEncryptor{key: key} case "xxtea": // XXTEA Recommended 16 bytes key if len(key) < 16 { return fmt.Errorf("XXTEA key should be at least 16 bytes") } key := key[:16] // Take the first 16 bytes encryptor = &XXTEAEncryptor{key: key} default: return fmt.Errorf("unsupported encryption type: %s", config.Server.Encryption.Type) } return nil } func sayTCP(conn net.Conn, text string, who string) { encrypted, err := sayServer(text, who) if err != nil { return } conn.Write(encrypted) } func sayUDP(conn *net.UDPConn, addr *net.UDPAddr, text string, who string) { encrypted, err := sayServer(text, who) if err != nil { return } conn.WriteToUDP(encrypted, addr) } func sayServer(text string, who string) ([]byte, error) { // Load timezone timezone, err := time.LoadLocation(config.Server.Timezone) if err != nil { fmt.Println("Error loading timezone:", err) timezone = time.Local // Fallback to local time if timezone loading fails } msg := Message{ Username: who, Text: text, Time: time.Now().In(timezone).Format("2006-01-02 03:04:05 PM"), } jsonMsg, _ := json.Marshal(msg) encrypted, err := encryptor.Encrypt(jsonMsg) if err != nil { fmt.Printf("Encryption error: %v\n", err) return nil, err } return encrypted, nil } func parseInput(input string) (username string, message string) { // Check if the input starts with '@' if strings.HasPrefix(input, "@") { // Trim the '@' and split by spaces parts := strings.SplitN(input[1:], " ", 2) username = parts[0] if len(parts) > 1 { message = parts[1] return username, message } } return "", "" } func handleTCPClient(conn net.Conn) { defer conn.Close() buf := make([]byte, 1024) n, err := conn.Read(buf) if err != nil { return } decrypted, err := encryptor.Decrypt(buf[:n]) if err != nil { return } var msg Message json.Unmarshal(decrypted, &msg) username := msg.Username // Add the client to the user manager userManagerTCP.AddClientTCP(conn, username) fmt.Printf("%s connected\n", username) // Say Hello to all, but self broadcastTCP(decrypted, conn, "") for { buf := make([]byte, 1024) n, err := conn.Read(buf) if err != nil { break } decrypted, err := encryptor.Decrypt(buf[:n]) if err != nil { continue } var msg Message json.Unmarshal(decrypted, &msg) message := string(msg.Text) if message == "users!" { // Handle the 'users' command userList := userManagerTCP.GetUserListTCP() sayTCP(conn, "User's Online: "+userList, "@SERVER") } else { privateUser, text := parseInput(message) if privateUser != "" && text != "" { privateTCP(privateUser, decrypted, conn, message) } else { // Broadcast the message to all other clients broadcastTCP(decrypted, conn, message) } } } // Remove the client from the user manager when they disconnect userManagerTCP.RemoveClientTCP(conn) fmt.Printf("%s disconnected\n", username) } func startTCPServer() error { address := fmt.Sprintf("%s:%d", config.Server.Address, config.Server.Port) ln, err := net.Listen("tcp", address) if err != nil { return fmt.Errorf("error starting TCP server: %v", err) } defer ln.Close() fmt.Printf("TCP server listening on %s using %s encryption\n", address, config.Server.Encryption.Type) for { conn, err := ln.Accept() if err != nil { fmt.Printf("Error accepting connection: %v\n", err) continue } go handleTCPClient(conn) } } func startUDPServer() error { address := fmt.Sprintf("%s:%d", config.Server.Address, config.Server.Port) addr, err := net.ResolveUDPAddr("udp", address) if err != nil { return fmt.Errorf("error resolving UDP address: %v", err) } conn, err := net.ListenUDP("udp", addr) if err != nil { return fmt.Errorf("error starting UDP server: %v", err) } defer conn.Close() fmt.Printf("UDP server listening on %s using %s encryption\n", address, config.Server.Encryption.Type) for { buf := make([]byte, 1024) n, addr, err := conn.ReadFromUDP(buf) if err != nil { fmt.Printf("Error reading UDP: %v\n", err) continue } decrypted, err := encryptor.Decrypt(buf[:n]) if err != nil { fmt.Printf("Decryption error: %v\n", err) continue } var msg Message if err := json.Unmarshal(decrypted, &msg); err != nil { fmt.Printf("Error unmarshalling message: %v\n", err) continue } // Add the client to the user manager if they're new if _, exists := userManagerUDP.clients[addr.String()]; !exists { userManagerUDP.AddClientUDP(addr, msg.Username) fmt.Printf("%s connected from %s\n", msg.Username, addr.String()) } message := msg.Text if message == "users!" { // Handle the 'users' command userList := userManagerUDP.GetUserListUDP() sayUDP(conn, addr, "User's Online: "+userList, "@SERVER") } else { // Broadcast the message to all other clients broadcastUDP(conn, decrypted, addr) } } } func privateTCP(user string, data []byte, sender net.Conn, message string) { encrypted, err := encryptor.Encrypt(data) if err != nil { fmt.Printf("Encryption error: %v\n", err) return } userManagerTCP.UserMatchTCP(user, encrypted, sender, message) } func broadcastTCP(data []byte, sender net.Conn, message string) { encrypted, err := encryptor.Encrypt(data) if err != nil { fmt.Printf("Encryption error: %v\n", err) return } userManagerTCP.mutex.Lock() defer userManagerTCP.mutex.Unlock() for client := range userManagerTCP.clients { if client != sender { client.Write(encrypted) } else if message != "" { sayTCP(client, message, "@YOU-Said") } } } func broadcastUDP(conn *net.UDPConn, data []byte, sender *net.UDPAddr) { encrypted, err := encryptor.Encrypt(data) if err != nil { fmt.Printf("Encryption error: %v\n", err) return } userManagerUDP.mutex.Lock() defer userManagerUDP.mutex.Unlock() for _, client := range userManagerUDP.clients { if client.addr.String() != sender.String() { conn.WriteToUDP(encrypted, client.addr) } } } func main() { if err := loadConfig("chat_server.yaml"); err != nil { fmt.Printf("Failed to load configuration: %v\n", err) os.Exit(1) } fmt.Printf("Starting %s server...\n", config.Server.Protocol) var err error if config.Server.Protocol == "tcp" { userManagerTCP = NewUserManagerTCP() err = startTCPServer() } else if config.Server.Protocol == "udp" { userManagerUDP = NewUserManagerUDP() err = startUDPServer() } else { fmt.Printf("Invalid protocol specified: %s\n", config.Server.Protocol) os.Exit(1) } if err != nil { fmt.Printf("Server error: %v\n", err) os.Exit(1) } }