You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
483 lines
11 KiB
483 lines
11 KiB
package main
|
|
|
|
import (
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/rand"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
"strings"
|
|
"gopkg.in/yaml.v2"
|
|
"os"
|
|
"github.com/yang3yen/xxtea-go/xxtea"
|
|
)
|
|
|
|
// Copyright (c) 2025 Robert Strutts, License MIT
|
|
|
|
// Globals Vars:
|
|
var (
|
|
encryptor Encryptor
|
|
config Config
|
|
userManagerTCP *UserManagerTCP
|
|
userManagerUDP *UserManagerUDP
|
|
)
|
|
|
|
// Message is the JSON payload exchanged between clients and the server.
//
// The fields carry no json tags, so encoding/json uses the Go field names
// ("Username", "Text", "Time") as the object keys.
type Message struct {
	// Username names the sender. Server-generated messages use labels such as
	// "@SERVER" or "@YOU-Said".
	Username string
	// Text is the message body.
	Text string
	// Time is a preformatted timestamp; sayServer fills it using the layout
	// "2006-01-02 03:04:05 PM".
	Time string
}
|
|
|
|
// Encryptor abstracts the symmetric cipher applied to every payload the
// server sends or receives. AESEncryptor and XXTEAEncryptor implement it.
type Encryptor interface {
	// Encrypt returns the encrypted form of plaintext.
	Encrypt(plaintext []byte) ([]byte, error)
	// Decrypt returns the plaintext recovered from ciphertext.
	Decrypt(ciphertext []byte) ([]byte, error)
}
|
|
|
|
// AESEncryptor implements Encryptor using AES in GCM mode.
type AESEncryptor struct {
	// key is the raw AES key handed to aes.NewCipher. loadConfig only accepts
	// a 32-byte key for this encryptor.
	key []byte
}
|
|
|
|
// XXTEAEncryptor implements Encryptor on top of the xxtea-go package.
type XXTEAEncryptor struct {
	// key is the XXTEA key. loadConfig stores the first 16 bytes of the
	// configured key here.
	key []byte
}
|
|
|
|
// Config mirrors the layout of the server's YAML configuration file. All
// settings live under the top-level "server" key.
type Config struct {
	Server struct {
		// Timezone is the location name passed to time.LoadLocation when
		// stamping server-generated messages.
		Timezone string `yaml:"timezone"`
		// Address and Port are joined as "address:port" to form the listen
		// address.
		Address string `yaml:"address"`
		Port    int    `yaml:"port"`
		// Protocol selects the transport; main accepts "tcp" or "udp".
		Protocol string `yaml:"protocol"`
		// Encryption selects and keys the cipher used for all traffic.
		Encryption struct {
			// Type is "aes" or "xxtea"; loadConfig rejects anything else.
			Type string `yaml:"type"`
			// Key is the secret used as raw bytes: exactly 32 bytes for
			// "aes", at least 16 bytes for "xxtea".
			Key string `yaml:"key"`
		} `yaml:"encryption"`
	} `yaml:"server"`
}
|
|
|
|
// ClientTCP describes one connected TCP chat participant.
type ClientTCP struct {
	// conn is the client's connection; UserManagerTCP also uses it as the
	// map key for this client.
	conn net.Conn
	// username is the name taken from the client's first message.
	username string
}
|
|
// ClientUDP describes one known UDP chat participant.
type ClientUDP struct {
	// addr is the remote address datagrams for this client are sent to. UDP
	// has no connection object, so the address identifies the client.
	addr *net.UDPAddr
	// username is the name carried in the first message seen from addr.
	username string
}
|
|
|
|
type UserManagerTCP struct {
|
|
clients map[net.Conn]*ClientTCP
|
|
mutex sync.Mutex
|
|
}
|
|
type UserManagerUDP struct {
|
|
clients map[string]*ClientUDP // Use a string key (e.g., addr.String())
|
|
mutex sync.Mutex
|
|
}
|
|
|
|
func NewUserManagerTCP() *UserManagerTCP {
|
|
return &UserManagerTCP{
|
|
clients: make(map[net.Conn]*ClientTCP),
|
|
}
|
|
}
|
|
func NewUserManagerUDP() *UserManagerUDP {
|
|
return &UserManagerUDP{
|
|
clients: make(map[string]*ClientUDP),
|
|
}
|
|
}
|
|
|
|
func (um *UserManagerTCP) AddClientTCP(conn net.Conn, username string) {
|
|
um.mutex.Lock()
|
|
defer um.mutex.Unlock()
|
|
um.clients[conn] = &ClientTCP{conn: conn, username: username}
|
|
}
|
|
func (um *UserManagerUDP) AddClientUDP(addr *net.UDPAddr, username string) {
|
|
um.mutex.Lock()
|
|
defer um.mutex.Unlock()
|
|
um.clients[addr.String()] = &ClientUDP{addr: addr, username: username}
|
|
}
|
|
|
|
func (um *UserManagerTCP) RemoveClientTCP(conn net.Conn) {
|
|
um.mutex.Lock()
|
|
defer um.mutex.Unlock()
|
|
delete(um.clients, conn)
|
|
}
|
|
func (um *UserManagerUDP) RemoveClientUDP(addr *net.UDPAddr) {
|
|
um.mutex.Lock()
|
|
defer um.mutex.Unlock()
|
|
delete(um.clients, addr.String())
|
|
}
|
|
|
|
func (um *UserManagerTCP) GetUserListTCP() string {
|
|
um.mutex.Lock()
|
|
defer um.mutex.Unlock()
|
|
var userList string
|
|
for _, client := range um.clients {
|
|
userList += client.username + ", "
|
|
}
|
|
return userList
|
|
}
|
|
func (um *UserManagerTCP) UserMatchTCP(user string, encrypted []byte, sender net.Conn, message string) {
|
|
um.mutex.Lock()
|
|
defer um.mutex.Unlock()
|
|
for clientCon, client := range um.clients {
|
|
if user == client.username {
|
|
clientCon.Write(encrypted)
|
|
} else if clientCon == sender {
|
|
sayTCP(clientCon, message, "@YOU-Said(In Private)")
|
|
}
|
|
}
|
|
}
|
|
|
|
func (um *UserManagerUDP) GetUserListUDP() string {
|
|
um.mutex.Lock()
|
|
defer um.mutex.Unlock()
|
|
var userList string
|
|
for _, client := range um.clients {
|
|
userList += client.username + ", "
|
|
}
|
|
return userList
|
|
}
|
|
|
|
func (a *AESEncryptor) Encrypt(plaintext []byte) ([]byte, error) {
|
|
block, err := aes.NewCipher(a.key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
gcm, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
nonce := make([]byte, gcm.NonceSize())
|
|
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
|
return nil, err
|
|
}
|
|
return gcm.Seal(nonce, nonce, plaintext, nil), nil
|
|
}
|
|
|
|
func (a *AESEncryptor) Decrypt(ciphertext []byte) ([]byte, error) {
|
|
block, err := aes.NewCipher(a.key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
gcm, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
nonceSize := gcm.NonceSize()
|
|
if len(ciphertext) < nonceSize {
|
|
return nil, fmt.Errorf("ciphertext too short")
|
|
}
|
|
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
|
|
return gcm.Open(nil, nonce, ciphertext, nil)
|
|
}
|
|
|
|
func (x *XXTEAEncryptor) Encrypt(plaintext []byte) ([]byte, error) {
|
|
result, err := xxtea.Encrypt(plaintext, x.key, true, 0)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (x *XXTEAEncryptor) Decrypt(ciphertext []byte) ([]byte, error) {
|
|
result, err := xxtea.Decrypt(ciphertext, x.key, true, 0)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func loadConfig(filename string) error {
|
|
var key []byte
|
|
|
|
data, err := os.ReadFile(filename)
|
|
if err != nil {
|
|
return fmt.Errorf("error reading config file: %v", err)
|
|
}
|
|
|
|
err = yaml.Unmarshal(data, &config)
|
|
if err != nil {
|
|
return fmt.Errorf("error parsing config file: %v", err)
|
|
}
|
|
|
|
key = []byte(config.Server.Encryption.Key)
|
|
|
|
// Initialize the appropriate encryptor
|
|
switch config.Server.Encryption.Type {
|
|
case "aes":
|
|
// AES requires exactly 32 bytes key
|
|
if len(key) != 32 {
|
|
return fmt.Errorf("AES key must be exactly 32 bytes")
|
|
}
|
|
encryptor = &AESEncryptor{key: key}
|
|
case "xxtea":
|
|
// XXTEA Recommended 16 bytes key
|
|
if len(key) < 16 {
|
|
return fmt.Errorf("XXTEA key should be at least 16 bytes")
|
|
}
|
|
key := key[:16] // Take the first 16 bytes
|
|
encryptor = &XXTEAEncryptor{key: key}
|
|
default:
|
|
return fmt.Errorf("unsupported encryption type: %s", config.Server.Encryption.Type)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func sayTCP(conn net.Conn, text string, who string) {
|
|
encrypted, err := sayServer(text, who)
|
|
if err != nil {
|
|
return
|
|
}
|
|
conn.Write(encrypted)
|
|
}
|
|
func sayUDP(conn *net.UDPConn, addr *net.UDPAddr, text string, who string) {
|
|
encrypted, err := sayServer(text, who)
|
|
if err != nil {
|
|
return
|
|
}
|
|
conn.WriteToUDP(encrypted, addr)
|
|
}
|
|
|
|
func sayServer(text string, who string) ([]byte, error) {
|
|
// Load timezone
|
|
timezone, err := time.LoadLocation(config.Server.Timezone)
|
|
if err != nil {
|
|
fmt.Println("Error loading timezone:", err)
|
|
timezone = time.Local // Fallback to local time if timezone loading fails
|
|
}
|
|
|
|
msg := Message{
|
|
Username: who,
|
|
Text: text,
|
|
Time: time.Now().In(timezone).Format("2006-01-02 03:04:05 PM"),
|
|
}
|
|
|
|
jsonMsg, _ := json.Marshal(msg)
|
|
encrypted, err := encryptor.Encrypt(jsonMsg)
|
|
if err != nil {
|
|
fmt.Printf("Encryption error: %v\n", err)
|
|
return nil, err
|
|
}
|
|
return encrypted, nil
|
|
}
|
|
|
|
// parseInput recognises the private-message syntax "@username message".
//
// When input starts with '@' and contains a space after it, the text between
// the '@' and the first space is returned as username and everything after
// that space as message. Either part may be empty (for example "@ hi" yields
// an empty username). Any other input yields two empty strings.
func parseInput(input string) (username string, message string) {
	if !strings.HasPrefix(input, "@") {
		return "", ""
	}

	rest := input[1:]
	sep := strings.Index(rest, " ")
	if sep < 0 {
		// A bare "@name" with no message is not treated as a private message.
		return "", ""
	}
	return rest[:sep], rest[sep+1:]
}
|
|
|
|
func handleTCPClient(conn net.Conn) {
|
|
defer conn.Close()
|
|
|
|
buf := make([]byte, 1024)
|
|
n, err := conn.Read(buf)
|
|
if err != nil {
|
|
return
|
|
}
|
|
decrypted, err := encryptor.Decrypt(buf[:n])
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
var msg Message
|
|
json.Unmarshal(decrypted, &msg)
|
|
username := msg.Username
|
|
|
|
// Add the client to the user manager
|
|
userManagerTCP.AddClientTCP(conn, username)
|
|
fmt.Printf("%s connected\n", username)
|
|
|
|
// Say Hello to all, but self
|
|
broadcastTCP(decrypted, conn, "")
|
|
|
|
for {
|
|
buf := make([]byte, 1024)
|
|
n, err := conn.Read(buf)
|
|
if err != nil {
|
|
break
|
|
}
|
|
decrypted, err := encryptor.Decrypt(buf[:n])
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
var msg Message
|
|
json.Unmarshal(decrypted, &msg)
|
|
|
|
message := string(msg.Text)
|
|
if message == "users!" {
|
|
// Handle the 'users' command
|
|
userList := userManagerTCP.GetUserListTCP()
|
|
sayTCP(conn, "User's Online: "+userList, "@SERVER")
|
|
} else {
|
|
privateUser, text := parseInput(message)
|
|
if privateUser != "" && text != "" {
|
|
privateTCP(privateUser, decrypted, conn, message)
|
|
} else {
|
|
// Broadcast the message to all other clients
|
|
broadcastTCP(decrypted, conn, message)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Remove the client from the user manager when they disconnect
|
|
userManagerTCP.RemoveClientTCP(conn)
|
|
fmt.Printf("%s disconnected\n", username)
|
|
}
|
|
|
|
func startTCPServer() error {
|
|
address := fmt.Sprintf("%s:%d", config.Server.Address, config.Server.Port)
|
|
ln, err := net.Listen("tcp", address)
|
|
if err != nil {
|
|
return fmt.Errorf("error starting TCP server: %v", err)
|
|
}
|
|
defer ln.Close()
|
|
|
|
fmt.Printf("TCP server listening on %s using %s encryption\n",
|
|
address, config.Server.Encryption.Type)
|
|
|
|
for {
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
fmt.Printf("Error accepting connection: %v\n", err)
|
|
continue
|
|
}
|
|
go handleTCPClient(conn)
|
|
}
|
|
}
|
|
|
|
func startUDPServer() error {
|
|
address := fmt.Sprintf("%s:%d", config.Server.Address, config.Server.Port)
|
|
addr, err := net.ResolveUDPAddr("udp", address)
|
|
if err != nil {
|
|
return fmt.Errorf("error resolving UDP address: %v", err)
|
|
}
|
|
|
|
conn, err := net.ListenUDP("udp", addr)
|
|
if err != nil {
|
|
return fmt.Errorf("error starting UDP server: %v", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
fmt.Printf("UDP server listening on %s using %s encryption\n",
|
|
address, config.Server.Encryption.Type)
|
|
|
|
for {
|
|
buf := make([]byte, 1024)
|
|
n, addr, err := conn.ReadFromUDP(buf)
|
|
if err != nil {
|
|
fmt.Printf("Error reading UDP: %v\n", err)
|
|
continue
|
|
}
|
|
|
|
decrypted, err := encryptor.Decrypt(buf[:n])
|
|
if err != nil {
|
|
fmt.Printf("Decryption error: %v\n", err)
|
|
continue
|
|
}
|
|
|
|
var msg Message
|
|
if err := json.Unmarshal(decrypted, &msg); err != nil {
|
|
fmt.Printf("Error unmarshalling message: %v\n", err)
|
|
continue
|
|
}
|
|
|
|
// Add the client to the user manager if they're new
|
|
if _, exists := userManagerUDP.clients[addr.String()]; !exists {
|
|
userManagerUDP.AddClientUDP(addr, msg.Username)
|
|
fmt.Printf("%s connected from %s\n", msg.Username, addr.String())
|
|
}
|
|
|
|
message := msg.Text
|
|
if message == "users!" {
|
|
// Handle the 'users' command
|
|
userList := userManagerUDP.GetUserListUDP()
|
|
sayUDP(conn, addr, "User's Online: "+userList, "@SERVER")
|
|
} else {
|
|
// Broadcast the message to all other clients
|
|
broadcastUDP(conn, decrypted, addr)
|
|
}
|
|
}
|
|
}
|
|
|
|
func privateTCP(user string, data []byte, sender net.Conn, message string) {
|
|
encrypted, err := encryptor.Encrypt(data)
|
|
if err != nil {
|
|
fmt.Printf("Encryption error: %v\n", err)
|
|
return
|
|
}
|
|
userManagerTCP.UserMatchTCP(user, encrypted, sender, message)
|
|
}
|
|
|
|
func broadcastTCP(data []byte, sender net.Conn, message string) {
|
|
encrypted, err := encryptor.Encrypt(data)
|
|
if err != nil {
|
|
fmt.Printf("Encryption error: %v\n", err)
|
|
return
|
|
}
|
|
|
|
userManagerTCP.mutex.Lock()
|
|
defer userManagerTCP.mutex.Unlock()
|
|
for client := range userManagerTCP.clients {
|
|
if client != sender {
|
|
client.Write(encrypted)
|
|
} else if message != "" {
|
|
sayTCP(client, message, "@YOU-Said")
|
|
}
|
|
}
|
|
}
|
|
|
|
func broadcastUDP(conn *net.UDPConn, data []byte, sender *net.UDPAddr) {
|
|
encrypted, err := encryptor.Encrypt(data)
|
|
if err != nil {
|
|
fmt.Printf("Encryption error: %v\n", err)
|
|
return
|
|
}
|
|
|
|
userManagerUDP.mutex.Lock()
|
|
defer userManagerUDP.mutex.Unlock()
|
|
for _, client := range userManagerUDP.clients {
|
|
if client.addr.String() != sender.String() {
|
|
conn.WriteToUDP(encrypted, client.addr)
|
|
}
|
|
}
|
|
}
|
|
|
|
func main() {
|
|
if err := loadConfig("chat_server.yaml"); err != nil {
|
|
fmt.Printf("Failed to load configuration: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
fmt.Printf("Starting %s server...\n", config.Server.Protocol)
|
|
|
|
var err error
|
|
if config.Server.Protocol == "tcp" {
|
|
userManagerTCP = NewUserManagerTCP()
|
|
err = startTCPServer()
|
|
} else if config.Server.Protocol == "udp" {
|
|
userManagerUDP = NewUserManagerUDP()
|
|
err = startUDPServer()
|
|
} else {
|
|
fmt.Printf("Invalid protocol specified: %s\n", config.Server.Protocol)
|
|
os.Exit(1)
|
|
}
|
|
|
|
if err != nil {
|
|
fmt.Printf("Server error: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|