package main
import (
"context"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"strconv"
"syscall"
"time"
"go-chat-react/internal/server"
)
// gracefulShutdown blocks until the process receives SIGINT or SIGTERM, then
// gives apiServer up to five seconds to finish in-flight requests and finally
// sends true on done.
//
// Signal interception is released as soon as the first signal arrives, so a
// second Ctrl+C gets the default behaviour and terminates the process, as the
// log message promises.
func gracefulShutdown(apiServer *http.Server, done chan bool) {
	// Create context that listens for the interrupt signal from the OS.
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()

	// Listen for the interrupt signal.
	<-ctx.Done()

	log.Println("shutting down gracefully, press Ctrl+C again to force")
	// Unregister the handler now: relying on the deferred call alone would
	// keep swallowing further signals until Shutdown has returned.
	stop()

	// The context is used to inform the server it has 5 seconds to finish
	// the request it is currently handling.
	shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := apiServer.Shutdown(shutdownCtx); err != nil {
		log.Printf("Server forced to shutdown with error: %v", err)
	}
	log.Println("Server exiting")

	// Notify the main goroutine that the shutdown is complete.
	done <- true
}
// main builds the HTTP server for the port named by the PORT environment
// variable, serves until the listener stops, and then waits for
// gracefulShutdown to report completion.
func main() {
	log.Println("Starting server...")

	rawPort := os.Getenv("PORT")
	port, err := strconv.Atoi(rawPort)
	if err != nil {
		// The value Atoi returned is still used, exactly as before; the
		// misconfiguration is just no longer hidden.
		log.Printf("invalid or missing PORT %q, using port %d: %v", rawPort, port, err)
	}

	// Named apiServer so the local does not shadow the imported server package.
	apiServer := server.NewServer(true, port)

	// Create a done channel to signal when the shutdown is complete.
	done := make(chan bool, 1)

	// Run graceful shutdown in a separate goroutine.
	go gracefulShutdown(apiServer, done)

	if err := apiServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
		panic(fmt.Sprintf("http server error: %s", err))
	}

	// Wait for the graceful shutdown to complete.
	<-done
	log.Println("Graceful shutdown complete.")
}
package database
import (
	"context"
	"crypto/rand"
	"database/sql"
	"encoding/hex"
	"errors"
	"fmt"
	"strconv"
	"time"

	_ "github.com/joho/godotenv/autoload"
	"github.com/mattn/go-sqlite3"
	_ "github.com/mattn/go-sqlite3"
)
// dbConn lists the query and exec methods DBService issues statements through.
// New stores a *sql.DB in it and withTx stores a *sql.Tx, so the same service
// methods can run either directly on the database or inside a transaction.
type dbConn interface {
Query(query string, args ...any) (*sql.Rows, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
Exec(query string, args ...any) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
// AtomitcDBService pairs a DBService with the functions that commit or roll
// back the work done through it. Atomic builds it around a transaction.
type AtomitcDBService struct {
// service is the DBService handed out by Service.
service *DBService
// commit is called by Commit.
commit func() error
// rollback is called by Rollback.
rollback func() error
}
// Service returns the DBService held by a. For values built by Atomic this is
// the service whose statements run on the transaction.
func (a *AtomitcDBService) Service() *DBService {
return a.service
}
// Commit calls the stored commit function and returns its error.
//
// The function is nil on the empty value that Atomic returns together with an
// error, so Commit must not be called on that value.
func (a *AtomitcDBService) Commit() error {
return a.commit()
}
// Rollback calls the stored rollback function and returns its error.
//
// The function is nil on the empty value that Atomic returns together with an
// error, so Rollback must not be called on that value.
func (a *AtomitcDBService) Rollback() error {
return a.rollback()
}
// DBService provides the data-access methods of the application.
type DBService struct {
// db is the database handle; it is used to begin transactions and by Close.
db *sql.DB
// conn is where statements are executed: db itself (New) or a transaction
// (withTx).
conn dbConn
}
// New returns a DBService that executes its statements directly on db.
func New(db *sql.DB) *DBService {
return &DBService{db: db, conn: db}
}
// DeleteUserSessionToken invalidates the session of userid by overwriting the
// stored token with fresh random bytes and setting its expiry 24 hours in the
// past.
//
// It returns ErrRecordNotFound when the update touches no row.
func (r *DBService) DeleteUserSessionToken(userid Id) error {
	// Overwrite with an unguessable value instead of a fixed placeholder that
	// would be identical for every logged-out user.
	raw := make([]byte, 32)
	if _, err := rand.Read(raw); err != nil {
		return fmt.Errorf("generating replacement token: %w", err)
	}
	randomToken := hex.EncodeToString(raw)
	expire := time.Now().Add(-24 * time.Hour)

	result, err := r.conn.Exec(
		"UPDATE UserLoginTable SET token = ?, token_expire_time = ? WHERE userid=? ",
		randomToken,
		expire,
		userid,
	)
	if err != nil {
		return err
	}
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("error getting rows affected: %w", err)
	}
	if rowsAffected == 0 {
		return ErrRecordNotFound
	}
	return nil
}
// ValidateUserLoginInfo loads the login record of userid and reports whether
// password matches it according to comparePassword. A lookup failure is
// returned as the error with a false result.
func (r *DBService) ValidateUserLoginInfo(userid Id, password string) (bool, error) {
	info, err := r.GetUserLoginInfo(userid)
	if err != nil {
		return false, err
	}
	matches := comparePassword(info, password)
	return matches, nil
}
// withTx returns a new DBService that shares the receiver's database handle
// but executes its statements on tx.
func (db *DBService) withTx(tx *sql.Tx) *DBService {
return &DBService{db: db.db, conn: tx}
}
// Atomic begins a transaction with ctx and opts and returns a wrapper whose
// Service executes on that transaction and whose Commit/Rollback end it.
//
// When the transaction cannot be started the error is returned together with
// an empty wrapper; that wrapper has no service and no commit/rollback
// functions and must not be used.
func (r *DBService) Atomic(ctx context.Context, opts *sql.TxOptions) (*AtomitcDBService, error) {
	tx, err := r.db.BeginTx(ctx, opts)
	if err != nil {
		return &AtomitcDBService{}, err
	}
	wrapped := &AtomitcDBService{
		service:  r.withTx(tx),
		commit:   tx.Commit,
		rollback: tx.Rollback,
	}
	return wrapped, nil
}
// hashPassword derives the value stored as a user's password hash.
//
// NOTE(review): despite the name this only concatenates password and salt; no
// hashing takes place. Replacing it needs a migration of the stored values.
func hashPassword(password string, salt string) string {
	combined := password + salt
	return combined
}
// comparePassword reports whether password, combined with the salt stored in
// userinfo via hashPassword, equals the stored PasswordHash.
func comparePassword(userinfo UserLoginInfo, password string) bool {
	candidate := hashPassword(password, userinfo.Salt)
	return candidate == userinfo.PasswordHash
}
// Close closes the underlying database handle.
// It returns nil when the handle is closed successfully and the error from
// sql.DB.Close otherwise. Nothing is logged.
func (s *DBService) Close() error {
return s.db.Close()
}
// AddUserToServer inserts a (userid, serverid, nickname) row into
// UsersServerTable. A failed insert is returned wrapped with the user id.
func (r *DBService) AddUserToServer(userid Id, serverid Id, nickname string) error {
	const stmt = "INSERT INTO UsersServerTable ( userid, serverid, nickname) VALUES ( ?, ?, ?)"
	if _, err := r.conn.Exec(stmt, userid, serverid, nickname); err != nil {
		return fmt.Errorf("add user - userid: %d err: %w", userid, err)
	}
	return nil
}
// CreateServer inserts a row into ServerTable for servername owned by ownerid
// and returns the id reported by LastInsertId.
//
// A failed insert is returned wrapped with the server name and owner id; a
// negative insert id yields ErrNegativeRowIndex.
func (r *DBService) CreateServer(ownerid Id, servername string) (Id, error) {
	const stmt = "INSERT INTO ServerTable (servername, ownerid) VALUES (?, ?)"
	res, err := r.conn.Exec(stmt, servername, ownerid)
	if err != nil {
		return 0, fmt.Errorf(
			"add server - servername: %s ownerid: %d err: %w",
			servername,
			ownerid,
			err,
		)
	}
	rowID, err := res.LastInsertId()
	switch {
	case err != nil:
		return 0, err
	case rowID < 0:
		return 0, ErrNegativeRowIndex
	}
	return Id(rowID), nil
}
// DeleteMessage removes the row with messageid from ChannelMessageTable.
// It returns ErrRecordNotFound when no row was deleted.
func (r *DBService) DeleteMessage(messageid Id) error {
	res, err := r.conn.Exec("DELETE FROM ChannelMessageTable WHERE messageid = ?", messageid)
	if err != nil {
		return err
	}
	deleted, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("error getting rows affected: %w", err)
	case deleted == 0:
		return ErrRecordNotFound
	}
	return nil
}
// UpdateMessage sets the contents column of the row with messageid in
// ChannelMessageTable to message. The number of affected rows is not checked.
func (r *DBService) UpdateMessage(messageid Id, message string) error {
	const stmt = "UPDATE ChannelMessageTable SET contents = ? WHERE messageid=? "
	if _, err := r.conn.Exec(stmt, message, messageid); err != nil {
		return err
	}
	return nil
}
// GetUserIDFromUserName looks up the userid stored in UserTable for username.
//
// It returns ErrRecordNotFound when no row matches, ErrMultipleRecords when
// more than one does, and any query, scan or iteration error unchanged.
func (r *DBService) GetUserIDFromUserName(username string) (Id, error) {
	rows, err := r.conn.Query("SELECT userid FROM UserTable WHERE username = ?", username)
	if err != nil {
		return 0, err
	}
	defer rows.Close()

	count := 0
	var userid Id
	for rows.Next() {
		count++
		if count > 1 {
			return 0, ErrMultipleRecords
		}
		if err := rows.Scan(&userid); err != nil {
			return 0, err
		}
	}
	// Next also returns false when iteration fails; report that failure
	// instead of mistaking it for "no such row".
	if err := rows.Err(); err != nil {
		return 0, err
	}
	if count == 0 {
		return 0, ErrRecordNotFound
	}
	return userid, nil
}
// UpdateUserSessionToken issues a new session token for userid, stores it in
// UserLoginTable with an expiry 24 hours from now, and returns the token and
// that expiry.
//
// The token is 32 bytes from crypto/rand, hex encoded, so it cannot be derived
// from the user id. On failure the returned token is empty.
func (r *DBService) UpdateUserSessionToken(userid Id) (string, time.Time, error) {
	expire := time.Now().Add(24 * time.Hour)

	raw := make([]byte, 32)
	if _, err := rand.Read(raw); err != nil {
		return "", expire, fmt.Errorf("generating session token: %w", err)
	}
	token := hex.EncodeToString(raw)

	_, err := r.conn.Exec(
		"UPDATE UserLoginTable SET token = ?, token_expire_time = ? WHERE userid=? ",
		token,
		expire,
		userid,
	)
	if err != nil {
		return "", expire, err
	}
	return token, expire, nil
}
// GetUserLoginInfo reads the UserLoginTable row of userid.
//
// It returns ErrRecordNotFound when no row matches, ErrMultipleRecords when
// more than one does, and any query, scan or iteration error unchanged.
func (r *DBService) GetUserLoginInfo(userid Id) (UserLoginInfo, error) {
	rows, err := r.conn.Query(
		"SELECT userid, passwordhash, salt, token, token_expire_time FROM UserLoginTable WHERE userid = ?",
		userid,
	)
	if err != nil {
		return UserLoginInfo{}, err
	}
	defer rows.Close()

	count := 0
	var user UserLoginInfo
	for rows.Next() {
		count++
		if count > 1 {
			return UserLoginInfo{}, ErrMultipleRecords
		}
		err := rows.Scan(
			&user.UserId,
			&user.PasswordHash,
			&user.Salt,
			&user.Token,
			&user.TokenExpireTime,
		)
		if err != nil {
			return UserLoginInfo{}, err
		}
	}
	// Next also returns false when iteration fails; report that failure
	// instead of mistaking it for "no such row".
	if err := rows.Err(); err != nil {
		return UserLoginInfo{}, err
	}
	if count == 0 {
		return UserLoginInfo{}, ErrRecordNotFound
	}
	return user, nil
}
// GetUserLoginInfoFromToken reads the UserLoginTable row whose token column
// equals token.
//
// It returns ErrRecordNotFound when no row matches, ErrMultipleRecords when
// more than one does, and any query, scan or iteration error unchanged.
func (r *DBService) GetUserLoginInfoFromToken(token string) (UserLoginInfo, error) {
	rows, err := r.conn.Query(
		"SELECT userid, passwordhash, salt, token, token_expire_time FROM UserLoginTable WHERE token = ?",
		token,
	)
	if err != nil {
		return UserLoginInfo{}, err
	}
	defer rows.Close()

	count := 0
	var user UserLoginInfo
	for rows.Next() {
		count++
		if count > 1 {
			return UserLoginInfo{}, ErrMultipleRecords
		}
		err := rows.Scan(
			&user.UserId,
			&user.PasswordHash,
			&user.Salt,
			&user.Token,
			&user.TokenExpireTime,
		)
		if err != nil {
			return UserLoginInfo{}, err
		}
	}
	// Next also returns false when iteration fails; report that failure
	// instead of mistaking it for "no such row".
	if err := rows.Err(); err != nil {
		return UserLoginInfo{}, err
	}
	if count == 0 {
		return UserLoginInfo{}, ErrRecordNotFound
	}
	return user, nil
}
// GetUser reads the UserTable row of userid.
//
// It returns ErrRecordNotFound when no row matches, ErrMultipleRecords when
// more than one does, and any query, scan or iteration error unchanged.
func (r *DBService) GetUser(userid Id) (User, error) {
	rows, err := r.conn.Query("SELECT userid, username FROM UserTable WHERE userid = ?", userid)
	// NOTE(review): database/sql documents ErrNoRows for Row.Scan, not for
	// Query, so this branch is probably never taken; kept as a harmless guard.
	if errors.Is(err, sql.ErrNoRows) {
		return User{}, ErrRecordNotFound
	}
	if err != nil {
		return User{}, err
	}
	defer rows.Close()

	count := 0
	var user User
	for rows.Next() {
		count++
		if count > 1 {
			return User{}, ErrMultipleRecords
		}
		if err := rows.Scan(&user.UserId, &user.UserName); err != nil {
			return User{}, err
		}
	}
	// Next also returns false when iteration fails; report that failure
	// instead of mistaking it for "no such row".
	if err := rows.Err(); err != nil {
		return User{}, err
	}
	if count == 0 {
		return User{}, ErrRecordNotFound
	}
	return user, nil
}
// CreateUser inserts username into UserTable and then a matching row into
// UserLoginTable (password value from hashPassword, salt, empty token), and
// returns the new user id.
//
// A unique-constraint violation on the first insert is reported as
// ErrRecordAlreadyExists; a negative insert id as ErrNegativeRowIndex.
//
// NOTE(review): the two inserts are separate statements; unless the service
// was obtained through Atomic, a failure of the second leaves the UserTable
// row behind. The salt is "salt" followed by the numeric id, not random.
func (r *DBService) CreateUser(username string, password string) (Id, error) {
	res, err := r.conn.Exec("INSERT INTO UserTable (username) VALUES (?)", username)
	if err != nil {
		var sqliteErr sqlite3.Error
		duplicate := errors.As(err, &sqliteErr) &&
			sqliteErr.Code == sqlite3.ErrConstraint &&
			sqliteErr.ExtendedCode == sqlite3.ErrConstraintUnique
		if duplicate {
			return 0, ErrRecordAlreadyExists
		}
		return 0, fmt.Errorf("add user - username: %s err: %w", username, err)
	}

	rowID, err := res.LastInsertId()
	if err != nil {
		return 0, err
	}
	if rowID < 0 {
		return 0, ErrNegativeRowIndex
	}

	salt := "salt" + strconv.FormatUint(uint64(rowID), 10)
	if _, err := r.conn.Exec(
		"INSERT INTO UserLoginTable (userid, passwordhash, salt, token) VALUES ( ?, ?, ?, ?)",
		rowID,
		hashPassword(password, salt),
		salt,
		"",
	); err != nil {
		return 0, err
	}
	return Id(rowID), nil
}
// UpdateUserName sets the username column of the UserTable row with userid.
// The number of affected rows is not checked.
func (r *DBService) UpdateUserName(userid Id, username string) error {
	const stmt = "UPDATE UserTable SET username = ? WHERE userid=? "
	if _, err := r.conn.Exec(stmt, username, userid); err != nil {
		return err
	}
	return nil
}
// UpdateServerName sets the servername column of the ServerTable row with
// serverid. The number of affected rows is not checked.
func (r *DBService) UpdateServerName(serverid Id, servername string) error {
	const stmt = "UPDATE ServerTable SET servername = ? WHERE serverid=? "
	if _, err := r.conn.Exec(stmt, servername, serverid); err != nil {
		return err
	}
	return nil
}
// GetRecentUsernames returns at most number rows of UserNameLogTable for
// userid, ordered by timestamp descending.
//
// The result is nil when no row matches. On any query, scan or iteration
// error an empty slice is returned together with the error.
func (r *DBService) GetRecentUsernames(userid Id, number uint) ([]UsernameLogEntry, error) {
	rows, err := r.conn.Query(
		"SELECT userid, username, timestamp FROM UserNameLogTable WHERE userid = ? ORDER BY timestamp DESC LIMIT ?",
		userid,
		number,
	)
	if err != nil {
		return []UsernameLogEntry{}, err
	}
	defer rows.Close()

	var names []UsernameLogEntry
	for rows.Next() {
		var name UsernameLogEntry
		if err := rows.Scan(&name.UserId, &name.Username, &name.Timestamp); err != nil {
			return []UsernameLogEntry{}, err
		}
		names = append(names, name)
	}
	// Next also returns false when iteration fails; without this check a
	// truncated list would be returned as if it were complete.
	if err := rows.Err(); err != nil {
		return []UsernameLogEntry{}, err
	}
	return names, nil
}
// GetUsersOfServer returns the users joined to serverid through
// UsersServerTable.
//
// The result is nil when no row matches. On any query, scan or iteration
// error an empty slice is returned together with the error.
func (r *DBService) GetUsersOfServer(serverid Id) ([]User, error) {
	rows, err := r.conn.Query(
		"SELECT U.userid, U.username FROM UsersServerTable as US INNER JOIN UserTable as U ON US.userid = U.userid WHERE US.serverid = ?",
		serverid,
	)
	if err != nil {
		return []User{}, err
	}
	defer rows.Close()

	var names []User
	for rows.Next() {
		var name User
		if err := rows.Scan(&name.UserId, &name.UserName); err != nil {
			return []User{}, err
		}
		names = append(names, name)
	}
	// Next also returns false when iteration fails; without this check a
	// truncated list would be returned as if it were complete.
	if err := rows.Err(); err != nil {
		return []User{}, err
	}
	return names, nil
}
// DeleteServer deletes the ServerTable row with serverid. The number of
// affected rows is not checked.
func (r *DBService) DeleteServer(serverid Id) error {
	if _, err := r.conn.Exec("DELETE FROM ServerTable WHERE serverid = ?", serverid); err != nil {
		return err
	}
	return nil
}
// GetServersOfUser returns the servers userid is joined to through
// UsersServerTable.
//
// The result is nil when no row matches. On any query, scan or iteration
// error an empty slice is returned together with the error.
func (r *DBService) GetServersOfUser(userid Id) ([]Server, error) {
	rows, err := r.conn.Query(
		"SELECT S.serverid, S.ownerid, S.servername FROM UsersServerTable as U INNER JOIN ServerTable as S ON U.serverid = S.serverid WHERE U.userid = ?",
		userid,
	)
	if err != nil {
		return []Server{}, err
	}
	defer rows.Close()

	var servers []Server
	for rows.Next() {
		var s Server
		if err := rows.Scan(&s.ServerId, &s.OwnerId, &s.ServerName); err != nil {
			return []Server{}, err
		}
		servers = append(servers, s)
	}
	// Next also returns false when iteration fails; without this check a
	// truncated list would be returned as if it were complete.
	if err := rows.Err(); err != nil {
		return []Server{}, err
	}
	return servers, nil
}
// GetChannelsOfServer returns the ChannelTable rows that belong to serverid.
//
// The result is nil when no row matches. On any query, scan or iteration
// error an empty slice is returned together with the error.
func (r *DBService) GetChannelsOfServer(serverid Id) ([]Channel, error) {
	rows, err := r.conn.Query(
		"SELECT channelid, serverid, channelname, timestamp FROM ChannelTable WHERE serverid = ?",
		serverid,
	)
	if err != nil {
		return []Channel{}, err
	}
	defer rows.Close()

	var channels []Channel
	for rows.Next() {
		var c Channel
		if err := rows.Scan(&c.ChannelId, &c.ServerId, &c.ChannelName, &c.Timestamp); err != nil {
			return []Channel{}, err
		}
		channels = append(channels, c)
	}
	// Next also returns false when iteration fails; without this check a
	// truncated list would be returned as if it were complete.
	if err := rows.Err(); err != nil {
		return []Channel{}, err
	}
	return channels, nil
}
// IsUserInChannel reports whether UsersChannelTable contains a row for
// (channelid, userid).
//
// The query goes through r.conn, like every other method, so it takes part in
// the transaction when the service was obtained via Atomic.
func (r *DBService) IsUserInChannel(userid Id, channelid Id) (bool, error) {
	query := `SELECT COUNT(1) FROM UsersChannelTable WHERE channelid = ? AND userid = ?`
	var count int
	err := r.conn.QueryRowContext(context.Background(), query, channelid, userid).Scan(&count)
	if err != nil {
		return false, err
	}
	return count > 0, nil
}
// AddUserToChannel inserts a (userid, channelid) row into UsersChannelTable.
//
// A failed insert is returned wrapped with the user id; an error from
// LastInsertId is returned as is and a negative insert id yields
// ErrNegativeRowIndex.
func (r *DBService) AddUserToChannel(userid Id, channelid Id) error {
	const stmt = "INSERT INTO UsersChannelTable ( userid, channelid) VALUES ( ?, ?)"
	res, err := r.conn.Exec(stmt, userid, channelid)
	if err != nil {
		return fmt.Errorf("add user - userid: %d err: %w", userid, err)
	}
	rowID, err := res.LastInsertId()
	switch {
	case err != nil:
		return err
	case rowID < 0:
		return ErrNegativeRowIndex
	}
	return nil
}
// AddChannel inserts a channel named channelname for serverid into
// ChannelTable and returns the id reported by LastInsertId.
//
// A failed insert is returned wrapped with the channel name; a negative
// insert id yields ErrNegativeRowIndex.
func (r *DBService) AddChannel(serverid Id, channelname string) (Id, error) {
	d, err := r.conn.Exec(
		"INSERT INTO ChannelTable ( serverid, channelname) VALUES ( ?, ?)",
		serverid,
		channelname,
	)
	if err != nil {
		// The message used to read "add user - username", copied from the
		// user insert, which mislabelled channel failures.
		return 0, fmt.Errorf("add channel - channelname: %s err: %w", channelname, err)
	}
	id, err := d.LastInsertId()
	if err != nil {
		return 0, err
	}
	if id < 0 {
		return 0, ErrNegativeRowIndex
	}
	return Id(id), nil
}
// IsUserInServer reports whether UsersServerTable contains a row for
// (serverid, userid).
//
// The query goes through r.conn, like every other method, so it takes part in
// the transaction when the service was obtained via Atomic.
func (r *DBService) IsUserInServer(userid Id, serverid Id) (bool, error) {
	query := `SELECT COUNT(1) FROM UsersServerTable WHERE serverid = ? AND userid = ?`
	var count int
	err := r.conn.QueryRowContext(context.Background(), query, serverid, userid).Scan(&count)
	if err != nil {
		return false, err
	}
	return count > 0, nil
}
// GetChannel reads the ChannelTable row of channelid.
//
// It returns ErrRecordNotFound when no row matches, ErrMultipleRecords when
// more than one does, and any query, scan or iteration error unchanged.
func (r *DBService) GetChannel(channelid Id) (Channel, error) {
	rows, err := r.conn.Query(
		"SELECT channelid, channelname, serverid, timestamp FROM ChannelTable WHERE channelid = ?",
		channelid,
	)
	if err != nil {
		return Channel{}, err
	}
	defer rows.Close()

	count := 0
	var channel Channel
	for rows.Next() {
		count++
		if count > 1 {
			return Channel{}, ErrMultipleRecords
		}
		err := rows.Scan(
			&channel.ChannelId,
			&channel.ChannelName,
			&channel.ServerId,
			&channel.Timestamp,
		)
		if err != nil {
			return Channel{}, err
		}
	}
	// Next also returns false when iteration fails; report that failure
	// instead of mistaking it for "no such row".
	if err := rows.Err(); err != nil {
		return Channel{}, err
	}
	if count == 0 {
		return Channel{}, ErrRecordNotFound
	}
	return channel, nil
}
// UpdateChannel sets the channelname column of the ChannelTable row with
// channelid to new_server_name. The number of affected rows is not checked.
func (r *DBService) UpdateChannel(channelid Id, new_server_name string) error {
	const stmt = "UPDATE ChannelTable SET channelname = ? WHERE channelid = ? "
	if _, err := r.conn.Exec(stmt, new_server_name, channelid); err != nil {
		return err
	}
	return nil
}
// GetServer reads the ServerTable row of serverid.
//
// It returns ErrRecordNotFound when no row matches and any query, scan or
// iteration error unchanged. If several rows match, the last one scanned is
// returned.
func (r *DBService) GetServer(serverid Id) (Server, error) {
	rows, err := r.conn.Query(
		"SELECT serverid, ownerid, servername FROM ServerTable WHERE serverid = ? ",
		serverid,
	)
	if err != nil {
		return Server{}, err
	}
	defer rows.Close()

	var server Server
	found := false
	for rows.Next() {
		found = true
		err := rows.Scan(
			&server.ServerId,
			&server.OwnerId,
			&server.ServerName,
		)
		if err != nil {
			return Server{}, err
		}
	}
	// Next also returns false when iteration fails; report that failure
	// instead of mistaking it for "no such row".
	if err := rows.Err(); err != nil {
		return Server{}, err
	}
	if !found {
		return Server{}, ErrRecordNotFound
	}
	return server, nil
}
// GetMessage reads the ChannelMessageTable row of messageid, joined with
// ChannelTable to fill in the server id.
//
// It returns ErrRecordNotFound when no row matches and any query, scan or
// iteration error unchanged. If several rows match, the last one scanned is
// returned.
func (r *DBService) GetMessage(messageid Id) (Message, error) {
	rows, err := r.conn.Query(
		"SELECT m.messageid, m.channelid, m.userid, m.contents, m.timestamp, m.editted, m.edittimestamp, c.serverid FROM ChannelMessageTable m JOIN ChannelTable c on m.channelid = c.channelid WHERE m.messageid = ?",
		messageid,
	)
	if err != nil {
		return Message{}, err
	}
	defer rows.Close()

	count := 0
	var message Message
	for rows.Next() {
		count++
		err := rows.Scan(
			&message.MessageId,
			&message.ChannelId,
			&message.UserId,
			&message.Contents,
			&message.Timestamp,
			&message.Editted,
			&message.EdittedTimeStamp,
			&message.ServerId,
		)
		if err != nil {
			return Message{}, err
		}
	}
	// Next also returns false when iteration fails; report that failure
	// instead of mistaking it for "no such row".
	if err := rows.Err(); err != nil {
		return Message{}, err
	}
	if count == 0 {
		return Message{}, ErrRecordNotFound
	}
	return message, nil
}
// AddMessage inserts message for (userid, channelid) into ChannelMessageTable
// and returns the id reported by LastInsertId.
//
// A zero userid or channelid is rejected before touching the database. A
// failed insert is returned wrapped with both ids; a negative insert id
// yields ErrNegativeRowIndex.
func (r *DBService) AddMessage(channelid Id, userid Id, message string) (Id, error) {
	if userid == 0 || channelid == 0 {
		return 0, fmt.Errorf("add message - zero userid or channel id")
	}
	d, err := r.conn.Exec(
		"INSERT INTO ChannelMessageTable (userid, channelid, contents) VALUES ( ?, ?, ?)",
		userid,
		channelid,
		message,
	)
	if err != nil {
		// The message used to read "add user", copied from the user insert,
		// which mislabelled message failures.
		return 0, fmt.Errorf(
			"add message - userid: %d channelid: %d err: %w",
			userid,
			channelid,
			err,
		)
	}
	id, err := d.LastInsertId()
	if err != nil {
		return 0, err
	}
	if id < 0 {
		return 0, ErrNegativeRowIndex
	}
	return Id(id), nil
}
// GetMessagesInChannel returns at most number messages of channelid, ordered
// by timestamp descending, each joined with ChannelTable to fill in the
// server id.
//
// The result is nil when no row matches. On any query, scan or iteration
// error an empty slice is returned together with the error.
func (r *DBService) GetMessagesInChannel(channelid Id, number uint) ([]Message, error) {
	rows, err := r.conn.Query(
		"SELECT m.messageid, m.channelid, m.userid, m.contents, m.timestamp, m.editted, m.edittimestamp, c.serverid FROM ChannelMessageTable m JOIN ChannelTable c on m.channelid = c.channelid WHERE m.channelid = ? ORDER BY m.timestamp DESC LIMIT ?",
		channelid,
		number,
	)
	if err != nil {
		return []Message{}, err
	}
	defer rows.Close()

	var messages []Message
	for rows.Next() {
		var message Message
		err := rows.Scan(
			&message.MessageId,
			&message.ChannelId,
			&message.UserId,
			&message.Contents,
			&message.Timestamp,
			&message.Editted,
			&message.EdittedTimeStamp,
			&message.ServerId,
		)
		if err != nil {
			return []Message{}, err
		}
		messages = append(messages, message)
	}
	// Next also returns false when iteration fails; without this check a
	// truncated list would be returned as if it were complete.
	if err := rows.Err(); err != nil {
		return []Message{}, err
	}
	return messages, nil
}
// GetUsersInChannel returns the users joined to channelid through
// UsersChannelTable.
//
// The result is nil when no row matches. On any query, scan or iteration
// error an empty slice is returned together with the error.
func (r *DBService) GetUsersInChannel(channelid Id) ([]User, error) {
	rows, err := r.conn.Query(
		"SELECT U.userid, U.username FROM UsersChannelTable as UC INNER JOIN UserTable as U ON UC.userid = U.userid WHERE UC.channelid = ?",
		channelid,
	)
	if err != nil {
		return []User{}, err
	}
	defer rows.Close()

	var names []User
	for rows.Next() {
		var name User
		if err := rows.Scan(&name.UserId, &name.UserName); err != nil {
			return []User{}, err
		}
		names = append(names, name)
	}
	// Next also returns false when iteration fails; without this check a
	// truncated list would be returned as if it were complete.
	if err := rows.Err(); err != nil {
		return []User{}, err
	}
	return names, nil
}
// RemoveUserFromChannel deletes the (channelid, userid) row from
// UsersChannelTable.
//
// A failed delete is returned wrapped with both ids. ErrRecordNotFound is
// returned when no row was deleted.
func (r *DBService) RemoveUserFromChannel(channelid Id, userid Id) error {
	const stmt = "DELETE FROM UsersChannelTable WHERE channelid = ? AND userid = ?"
	res, err := r.conn.Exec(stmt, channelid, userid)
	if err != nil {
		return fmt.Errorf(
			"remove user from channel - channelid: %d userid: %d err: %w",
			channelid,
			userid,
			err,
		)
	}
	// Make sure at least one row was actually deleted.
	deleted, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("error getting rows affected: %w", err)
	case deleted == 0:
		return ErrRecordNotFound
	}
	return nil
}
// DeleteChannel removes the row with channelid from ChannelTable.
// It returns ErrRecordNotFound when no row was deleted.
func (r *DBService) DeleteChannel(channelid Id) error {
	res, err := r.conn.Exec("DELETE FROM ChannelTable WHERE channelid = ?", channelid)
	if err != nil {
		return err
	}
	deleted, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("error getting rows affected: %w", err)
	case deleted == 0:
		return ErrRecordNotFound
	}
	return nil
}
package database
import (
"strconv"
"time"
)
// Id is the identifier type used by the model structs in this package for
// users, servers, channels and messages. It is an alias for uint.
type Id = uint
// ParseIntToID converts id to an Id. Zero and negative values are rejected
// with ErrUnsupportedNegativeValue.
func ParseIntToID(id int) (Id, error) {
	if id > 0 {
		return Id(id), nil
	}
	return Id(0), ErrUnsupportedNegativeValue
}
// ParseStringToID parses id as a decimal integer and converts it with
// ParseIntToID. Text that strconv.Atoi rejects yields ErrParsingValue.
func ParseStringToID(id string) (Id, error) {
	if parsed, err := strconv.Atoi(id); err == nil {
		return ParseIntToID(parsed)
	}
	return Id(0), ErrParsingValue
}
// User holds the userid and username columns read from UserTable.
type User struct {
UserId Id
UserName string
}
// UserLoginInfo holds one row of UserLoginTable: the stored password value
// and salt plus the current session token and its expiry.
type UserLoginInfo struct {
UserId Id
PasswordHash string
Salt string
// Token is the session token; an empty string is what CreateUser stores
// before any session exists.
Token string
TokenExpireTime time.Time
}
// UsernameLogEntry holds one row of UserNameLogTable as read by
// GetRecentUsernames.
type UsernameLogEntry struct {
UserId Id
Username string
Timestamp time.Time
}
// UserNicknameLogEntry describes a nickname of a user on a server at a point
// in time.
// NOTE(review): presumably the per-server counterpart of UsernameLogEntry; no
// code in this file reads or writes it — confirm its table.
type UserNicknameLogEntry struct {
UserId Id
ServerId Id
Nickname string
Timestamp time.Time
}
// Server holds the serverid, ownerid and servername columns read from
// ServerTable.
type Server struct {
ServerId Id
OwnerId Id
ServerName string
}
// Channel holds the channelid, serverid, channelname and timestamp columns
// read from ChannelTable.
type Channel struct {
ChannelId Id
ServerId Id
ChannelName string
Timestamp time.Time
}
// Message holds one ChannelMessageTable row together with the serverid of the
// channel it belongs to (taken from ChannelTable by the queries).
type Message struct {
MessageId Id
UserId Id
ServerId Id
ChannelId Id
Contents string
Timestamp time.Time
// Editted and EdittedTimeStamp are scanned from the editted and
// edittimestamp columns.
// NOTE(review): pointers, presumably because those columns can be NULL —
// confirm against the schema.
Editted *bool
EdittedTimeStamp *time.Time
}
package server
import (
"context"
"log"
"net/http"
"time"
)
// logEndpoint wraps next with request logging: after each request it logs a
// running request number, the request URL and the time next took in
// milliseconds.
func (s *Server) logEndpoint(next http.Handler) http.Handler {
	// Handlers run concurrently, so the running count must not be a bare
	// captured int. It lives in a one-slot channel: taking the value out,
	// incrementing it and putting it back serializes the update.
	hits := make(chan int, 1)
	hits <- 0

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		n := <-hits + 1
		hits <- n

		began := time.Now()
		// Proceed with the next handler.
		next.ServeHTTP(w, r)
		elapsed := time.Since(began)

		log.Printf(
			"%d Endpoint hit: %s took %d ms\n",
			n,
			r.URL,
			elapsed.Milliseconds(),
		)
	})
}
// WithAuthUser wraps next so that it only runs for requests carrying a valid
// session in the "token" cookie.
//
// A missing cookie is answered with 401, any other cookie error with 500, and
// an unknown or invalid token with 400. On success the user id is stored in
// the request context under the key "userid" before next is called.
func (s *Server) WithAuthUser(next http.HandlerFunc) http.HandlerFunc {
	const cookieName = "token"

	return func(w http.ResponseWriter, r *http.Request) {
		cookie, err := r.Cookie(cookieName)
		switch {
		case err == http.ErrNoCookie:
			// The cookie is simply absent.
			http.Error(w, "Token cookie not found", http.StatusUnauthorized)
			return
		case err != nil:
			// Any other failure while reading the cookie.
			http.Error(w, "Error retrieving cookie", http.StatusInternalServerError)
			return
		}

		loginInfo, err := s.db.GetUserLoginInfoFromToken(cookie.Value)
		if err != nil {
			http.Error(w, "unable to locate password", http.StatusBadRequest)
			return
		}
		if !s.validSession(loginInfo, cookie.Value) {
			http.Error(w, "invalid token", http.StatusBadRequest)
			return
		}

		ctx := context.WithValue(r.Context(), "userid", loginInfo.UserId)
		next(w, r.WithContext(ctx))
	}
}
// corsMiddleware sets the CORS response headers on every request and answers
// OPTIONS (preflight) requests itself with 204 without calling next.
//
// NOTE(review): the allowed origin is the bare string "localhost"; browsers
// send origins as scheme://host[:port], so this value may never match —
// confirm the intended origin.
func (s *Server) corsMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Set CORS headers.
		headers := w.Header()
		headers.Set("Access-Control-Allow-Origin", "localhost")
		headers.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS, PATCH")
		headers.Set("Access-Control-Allow-Headers", "Accept, Authorization, Content-Type, X-CSRF-Token")
		headers.Set("Access-Control-Allow-Credentials", "true")

		// Preflight requests end here.
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}

		// Everything else continues down the chain.
		next.ServeHTTP(w, r)
	})
}
package server
import (
"errors"
"fmt"
"net/http"
"strconv"
"time"
"go-chat-react/internal/database"
)
// httpErrorInfo pairs an HTTP status code with a message.
// NOTE(review): not referenced by the code visible in this file; presumably
// used by handlers elsewhere — confirm.
type httpErrorInfo struct {
StatusCode int
Message string
}
// serverVerification groups a validation flag with a user id and a server.
// NOTE(review): not referenced by the code visible in this file; the meaning
// of Validated has to be confirmed at its use sites.
type serverVerification struct {
Validated bool
UserId database.Id
Server database.Server
}
// parsePathFromID reads the path wildcard named field from r and converts it
// to a database.Id.
//
// It returns an error when the value is not a decimal integer or is not
// strictly positive.
func parsePathFromID(r *http.Request, field string) (database.Id, error) {
	fieldIDStr := r.PathValue(field)
	fieldID, err := strconv.Atoi(fieldIDStr)
	if err != nil {
		return 0, fmt.Errorf("invalid request: unable to parse %s", field)
	}
	if fieldID <= 0 {
		// Zero is rejected as well, so the requirement is "> 0"; the message
		// used to claim ">=0".
		return 0, fmt.Errorf("invalid request: valid %s id requires >0", field)
	}
	return database.Id(fieldID), nil
}
// getUserIdFromContext returns the user id stored in the context of r under
// the key "userid" (WithAuthUser stores it there). It returns an error when
// the key is absent or holds a value of another type.
func getUserIdFromContext(r *http.Request) (database.Id, error) {
	// A type assertion on a nil interface value fails as well, so a missing
	// key and a wrongly typed value take the same path.
	if userid, ok := r.Context().Value("userid").(database.Id); ok {
		return userid, nil
	}
	return database.Id(0), errors.New("unable to get userid from context")
}
// validSession reports whether usertoken is a currently valid session for
// userinfo: a token must be stored, it must not be expired, and it must equal
// usertoken.
func (s *Server) validSession(userinfo database.UserLoginInfo, usertoken string) bool {
	switch {
	case userinfo.Token == "":
		// No token has been set.
		return false
	case time.Now().After(userinfo.TokenExpireTime):
		// The token has expired.
		return false
	}
	// The presented token must be the stored one.
	return userinfo.Token == usertoken
}
// GetServerFromRequest loads the server named by the "serverid" path value of
// r. Parse and lookup failures are replaced by fixed error messages; the
// underlying error is not included.
func (s *Server) GetServerFromRequest(r *http.Request) (database.Server, error) {
	id, err := parsePathFromID(r, "serverid")
	if err != nil {
		return database.Server{}, errors.New("invalid request: unable to parse server id")
	}
	found, err := s.db.GetServer(id)
	if err != nil {
		return database.Server{}, errors.New("error: unable to locate server")
	}
	return found, nil
}
// GetChannelFromRequest loads the channel named by the "channelid" path value
// of r. Parse and lookup failures are replaced by fixed error messages; the
// underlying error is not included.
func (s *Server) GetChannelFromRequest(r *http.Request) (database.Channel, error) {
	channelid, err := parsePathFromID(r, "channelid")
	if err != nil {
		// Both messages below used to say "server", copied from
		// GetServerFromRequest.
		return database.Channel{}, errors.New("invalid request: unable to parse channel id")
	}
	channel, err := s.db.GetChannel(channelid)
	if err != nil {
		return database.Channel{}, errors.New("error: unable to locate channel")
	}
	return channel, nil
}
// GetServerFromChannel loads channelid and then the server that channel
// belongs to. Errors from either lookup are returned unchanged with a zero
// Server.
func (s *Server) GetServerFromChannel(channelid database.Id) (database.Server, error) {
	channel, err := s.db.GetChannel(channelid)
	if err != nil {
		return database.Server{}, err
	}
	owner, err := s.db.GetServer(channel.ServerId)
	if err == nil {
		return owner, nil
	}
	return database.Server{}, err
}
package server
import (
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"strconv"
"time"
"go-chat-react/internal/database"
"go-chat-react/internal/websocket"
)
// startTime is the time at which this package was initialized.
var startTime = time.Now()
// ServerResponseMessage is a JSON envelope made of a message type tag and an
// arbitrary payload.
// NOTE(review): presumably used for websocket traffic — confirm at use sites.
type ServerResponseMessage struct {
Message_type string `json:"message_type"`
Payload any `json:"payload"`
}
// ServerMessage is the JSON representation of a chat message: its ids, text
// and a date string.
// NOTE(review): the format of Date is not visible here — confirm where it is
// produced.
type ServerMessage struct {
UserId database.Id `json:"userid"`
MessageID database.Id `json:"messageid"`
ChannelId database.Id `json:"channelid"`
ServerId database.Id `json:"serverid"`
Message string `json:"message"`
Date string `json:"date"`
}
// User is the JSON representation of a database.User sent to clients.
type User struct {
UserID database.Id `json:"userid"`
UserName string `json:"username"`
}
// SubmittedMessage is the JSON shape of a message submitted by a client.
// Note that UserID is a string here while ChannelId is numeric.
type SubmittedMessage struct {
UserID string `json:"userid"`
ChannelId database.Id `json:"channelid"`
Token string `json:"token"`
Message string `json:"message"`
}
// RegisterRoutes builds the HTTP handler of the application: every route is
// registered on a fresh ServeMux, the mux is wrapped with request logging
// when logserver is true, and the result is wrapped with the CORS middleware.
//
// Routes wrapped in WithAuthUser require a valid session cookie; the others
// are registered without that wrapper.
func (s *Server) RegisterRoutes(logserver bool) http.Handler {
	routes := []struct {
		pattern string
		handler http.HandlerFunc
	}{
		{"/", s.redirectToReact},
		{"/websocket", s.WithAuthUser(s.websocketHandler)},

		// Authentication.
		{"POST /api/auth/login", s.loginHandler},
		{"POST /api/auth/session", s.WithAuthUser(s.sessionHandler)},
		{"POST /api/auth/logout", s.WithAuthUser(s.LogoutHandler)},

		// Users.
		{"POST /api/users", s.createUserHandler},
		{"GET /api/users/{userid}", s.GetUserHandler},
		{"PATCH /api/users/{userid}", s.WithAuthUser(s.UpdateUser)},
		{"GET /api/users/{userid}/servers", s.WithAuthUser(s.GetServersOfUser)},

		// Servers.
		{"POST /api/servers", s.WithAuthUser(s.createNewServer)},
		{"GET /api/servers/{serverid}", s.GetServerInformation},
		{"PATCH /api/servers/{serverid}", s.WithAuthUser(s.UpdateServer)},
		{"DELETE /api/servers/{serverid}", s.WithAuthUser(s.DeleteServer)},
		{"GET /api/servers/{serverid}/channels", s.WithAuthUser(s.GetServerChannels)},
		{"POST /api/servers/{serverid}/channels", s.WithAuthUser(s.CreateChannel)},
		{"GET /api/servers/{serverid}/members", s.WithAuthUser(s.GetServerMembersHandler)},
		{"GET /api/servers/{serverid}/messages", s.WithAuthUser(s.GetServerMessages)},

		// Channels.
		{"GET /api/channels/{channelid}", s.WithAuthUser(s.GetChannel)},
		{"PATCH /api/channels/{channelid}", s.WithAuthUser(s.UpdateChannel)},
		{"DELETE /api/channels/{channelid}", s.WithAuthUser(s.DeleteChannel)},
		{"POST /api/channels/{channelid}/members", s.WithAuthUser(s.AddChannelMember)},
		{"GET /api/channels/{channelid}/members", s.WithAuthUser(s.GetChannelMembers)},
		{"DELETE /api/channels/{channelid}/members", s.WithAuthUser(s.RemoveChannelMember)},

		// Messages.
		{"GET /api/channels/{channelid}/messages", s.GetChannelMessages},
		{"POST /api/channels/{channelid}/messages", s.CreateChannelMessage},
		{"GET /api/channels/{channelid}/messages/{messageid}", s.GetMessage},
		{"PATCH /api/channels/{channelid}/messages/{messageid}", s.WithAuthUser(s.UpdateMessage)},
		{"DELETE /api/channels/{channelid}/messages/{messageid}", s.WithAuthUser(s.DeleteMessage)},
	}

	mux := http.NewServeMux()
	for _, route := range routes {
		mux.HandleFunc(route.pattern, route.handler)
	}

	var handler http.Handler = mux
	if logserver {
		handler = s.logEndpoint(handler)
	}
	// CORS is the outermost layer.
	return s.corsMiddleware(handler)
}
// redirectToReact answers with a 307 redirect to the front end at
// http://localhost:5173, so only one port has to be remembered during
// development.
func (s *Server) redirectToReact(w http.ResponseWriter, r *http.Request) {
	const frontendURL = "http://localhost:5173"
	http.Redirect(w, r, frontendURL, http.StatusTemporaryRedirect)
}
// UpdateServer renames the server named by the "serverid" path value to the
// "servername" given in the JSON request body.
//
// The authenticated user (taken from the request context) must own the
// server. A missing user id in the context is answered with 500; every other
// failure with 400. Nothing is written on success.
func (s *Server) UpdateServer(w http.ResponseWriter, r *http.Request) {
	requester, err := getUserIdFromContext(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	serverid, err := parsePathFromID(r, "serverid")
	if err != nil {
		http.Error(w, "invalid request: unable to parse server id", http.StatusBadRequest)
		return
	}

	target, err := s.db.GetServer(serverid)
	if err != nil {
		http.Error(w, "error: unable to locate server", http.StatusBadRequest)
		return
	}
	if target.OwnerId != requester {
		http.Error(w, "error: user not owner of server", http.StatusBadRequest)
		return
	}

	var payload struct {
		ServerName string `json:"servername"`
	}
	if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
		http.Error(w, "error: unable to parse request", http.StatusBadRequest)
		return
	}

	if err := s.db.UpdateServerName(serverid, payload.ServerName); err != nil {
		http.Error(w, "error: unable to update server name", http.StatusBadRequest)
		return
	}
}
// DeleteServer deletes the server named by the "serverid" path value.
//
// The authenticated user (taken from the request context) must own the
// server. A missing user id in the context is answered with 500; every other
// failure with 400. Nothing is written on success.
func (s *Server) DeleteServer(w http.ResponseWriter, r *http.Request) {
	requester, err := getUserIdFromContext(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	serverid, err := parsePathFromID(r, "serverid")
	if err != nil {
		http.Error(w, "invalid request: unable to parse server id", http.StatusBadRequest)
		return
	}

	target, err := s.db.GetServer(serverid)
	if err != nil {
		http.Error(w, "error: unable to locate server", http.StatusBadRequest)
		return
	}
	if target.OwnerId != requester {
		http.Error(w, "error: user not owner of server", http.StatusBadRequest)
		return
	}

	if err := s.db.DeleteServer(serverid); err != nil {
		http.Error(w, "error: unable to locate server", http.StatusBadRequest)
		return
	}
}
// UpdateChannel renames the channel named by the "channelid" path value to
// the "channelname" given in the JSON request body.
//
// Only PATCH is accepted (405 otherwise). The authenticated user (taken from
// the request context) must own the server the channel belongs to. A missing
// user id in the context is answered with 500; every other failure with 400.
// Nothing is written on success.
func (s *Server) UpdateChannel(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPatch {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	userid, err := getUserIdFromContext(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	channelid, err := parsePathFromID(r, "channelid")
	if err != nil {
		http.Error(w, "invalid request: unable to parse server id", http.StatusBadRequest)
		return
	}
	channel_info, err := s.db.GetChannel(channelid)
	if err != nil {
		http.Error(w, "error: unable to locate channel", http.StatusBadRequest)
		return
	}
	server_info, err := s.db.GetServer(channel_info.ServerId)
	if err != nil {
		// This error used to be dropped, so the ownership check below ran
		// against a zero-value server.
		http.Error(w, "error: unable to locate server", http.StatusBadRequest)
		return
	}
	if server_info.OwnerId != userid {
		http.Error(w, "error: user not owner of channel", http.StatusBadRequest)
		return
	}
	new_channel_name := struct {
		UpdatedChannelName string `json:"channelname"`
	}{}
	err = json.NewDecoder(r.Body).Decode(&new_channel_name)
	if err != nil {
		http.Error(w, "error: unable to parse request", http.StatusBadRequest)
		return
	}
	err = s.db.UpdateChannel(channelid, new_channel_name.UpdatedChannelName)
	if err != nil {
		http.Error(w, "error: unable to update channel", http.StatusBadRequest)
		return
	}
}
// GetChannelMembers writes the members of the channel named by the
// "channelid" path value as JSON of the form {"users": [...]}.
//
// Only GET is accepted (405 otherwise) and the authenticated user must be a
// member of the channel. A missing user id in the context, a failed member
// query or a marshalling failure is answered with 500; other failures with
// 400.
func (s *Server) GetChannelMembers(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	requester, err := getUserIdFromContext(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	channelid, err := parsePathFromID(r, "channelid")
	if err != nil {
		http.Error(w, "invalid request: unable to parse server id", http.StatusBadRequest)
		return
	}

	isMember, err := s.db.IsUserInChannel(requester, channelid)
	switch {
	case err != nil:
		http.Error(w, fmt.Sprintf("error: %s", err), http.StatusBadRequest)
		return
	case !isMember:
		http.Error(w, "user not in channel", http.StatusBadRequest)
		return
	}

	members, err := s.db.GetUsersInChannel(channelid)
	if err != nil {
		http.Error(w, "database error", http.StatusInternalServerError)
		return
	}

	// Convert database.User values into the JSON-facing User type.
	payload := make([]User, 0, len(members))
	for _, m := range members {
		payload = append(payload, User{UserID: m.UserId, UserName: m.UserName})
	}

	body, err := json.Marshal(map[string]any{"users": payload})
	if err != nil {
		http.Error(w, "Failed to marshal response", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if _, err := w.Write(body); err != nil {
		log.Printf("Failed to write response: %v", err)
	}
}
// AddChannelMember adds the user given as "userid" (a decimal string) in the
// JSON request body to the channel named by the "channelid" path value.
//
// Only POST is accepted (405 otherwise). The authenticated user must own the
// server the channel belongs to, and the user being added must already be a
// member of that server. A missing user id in the context is answered with
// 500; every other failure with 400. Nothing is written on success.
func (s *Server) AddChannelMember(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	userid, err := getUserIdFromContext(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	post_data := struct {
		UserId string `json:"userid"`
	}{}
	err = json.NewDecoder(r.Body).Decode(&post_data)
	if err != nil {
		fmt.Printf("error: unable to parse request %s", err)
		http.Error(w, fmt.Sprintf("error: unable to parse request %s", err), http.StatusBadRequest)
		return
	}
	newuserid_str, err := strconv.Atoi(post_data.UserId)
	if err != nil {
		http.Error(w, "invalid request: unable to parse user id", http.StatusBadRequest)
		return
	}
	if newuserid_str <= 0 {
		http.Error(w, "invalid request: invalid user id", http.StatusBadRequest)
		return
	}
	newuserid := database.Id(newuserid_str)
	channelid, err := parsePathFromID(r, "channelid")
	if err != nil {
		http.Error(w, "invalid request: unable to parse server id", http.StatusBadRequest)
		return
	}
	channel, err := s.db.GetChannel(channelid)
	if err != nil {
		http.Error(w, "error: unable to locate server", http.StatusBadRequest)
		return
	}
	server_info, err := s.db.GetServer(channel.ServerId)
	if err != nil {
		http.Error(w, "error: unable to locate server", http.StatusBadRequest)
		return
	}
	if server_info.OwnerId != userid {
		http.Error(w, "error: user not owner of server", http.StatusBadRequest)
		return
	}
	inserver, err := s.db.IsUserInServer(newuserid, server_info.ServerId)
	if err != nil {
		http.Error(w, fmt.Sprintf("error: %s", err), http.StatusBadRequest)
		return
	}
	if !inserver {
		http.Error(w, "user not in server", http.StatusBadRequest)
		return
	}
	// The second argument is the channel id. The server id used to be passed
	// here, which inserted the membership under the wrong id.
	err = s.db.AddUserToChannel(newuserid, channelid)
	if err != nil {
		http.Error(w, "error: unable to add user to channel", http.StatusBadRequest)
		return
	}
}
// RemoveChannelMember handles DELETE requests that remove a user from a
// channel.
//
// The channel is the "channelid" value extracted by parsePathFromID; the user
// to remove is the JSON body field "userid" (a decimal string holding a
// positive integer). The caller must own the server that contains the
// channel. Nothing is written to the response body on success.
func (s *Server) RemoveChannelMember(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodDelete {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	requester, err := getUserIdFromContext(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	// Resolve channel -> server so the ownership check can be made before the
	// request body is even looked at.
	channelID, err := parsePathFromID(r, "channelid")
	if err != nil {
		http.Error(w, "invalid request: unable to parse server id", http.StatusBadRequest)
		return
	}
	ch, err := s.db.GetChannel(channelID)
	if err != nil {
		http.Error(w, "error: unable to locate server", http.StatusBadRequest)
		return
	}
	owningServer, err := s.db.GetServer(ch.ServerId)
	if err != nil {
		http.Error(w, "error: unable to locate server", http.StatusBadRequest)
		return
	}
	if owningServer.OwnerId != requester {
		http.Error(w, "error: user not owner of server", http.StatusBadRequest)
		return
	}

	var body struct {
		UserId string `json:"userid"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		fmt.Printf("error: unable to parse request %s", decodeErr)
		http.Error(
			w,
			fmt.Sprintf("error: unable to parse request %s", decodeErr),
			http.StatusBadRequest,
		)
		return
	}

	rawTarget, err := strconv.Atoi(body.UserId)
	if err != nil {
		http.Error(w, "invalid request: unable to parse user id", http.StatusBadRequest)
		return
	}
	if rawTarget <= 0 {
		http.Error(w, "invalid request: invalid user id", http.StatusBadRequest)
		return
	}

	err = s.db.RemoveUserFromChannel(database.Id(channelID), database.Id(rawTarget))
	switch {
	case errors.Is(err, database.ErrRecordNotFound):
		http.Error(w, "user not in channel", http.StatusBadRequest)
	case err != nil:
		http.Error(w, "error: unable to remove user from channel", http.StatusBadRequest)
	}
}
// UpdateMessage handles PATCH requests that replace the text of a message.
//
// The message is identified by the "messageid" value extracted by
// parsePathFromID and the new text is read from the JSON body field
// "message". The caller must be the author of the message and must pass the
// IsUserInChannel check for the message's channel.
func (s *Server) UpdateMessage(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPatch {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	caller, err := getUserIdFromContext(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	msgID, err := parsePathFromID(r, "messageid")
	if err != nil {
		http.Error(w, "error: unable to parse messageid", http.StatusBadRequest)
		return
	}

	stored, err := s.db.GetMessage(msgID)
	if err != nil {
		http.Error(w, "error: unable to fetch message", http.StatusBadRequest)
		return
	}
	if stored.UserId != caller {
		http.Error(w, "error: attempting to modify different user message", http.StatusBadRequest)
		return
	}

	isMember, err := s.db.IsUserInChannel(caller, stored.ChannelId)
	if err != nil {
		http.Error(w, "error: unable to verify channel permissions", http.StatusBadRequest)
		return
	}
	if !isMember {
		http.Error(w, "error: user not in channel", http.StatusBadRequest)
		return
	}

	var edit struct {
		Message string `json:"message"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&edit); decodeErr != nil {
		http.Error(w, "error: unable to parse message from body", http.StatusBadRequest)
		return
	}

	if updateErr := s.db.UpdateMessage(stored.MessageId, edit.Message); updateErr != nil {
		http.Error(w, "error: issue while updating message", http.StatusBadRequest)
		return
	}
}
// DeleteMessage handles DELETE requests for a single message.
//
// The message is identified by the "messageid" value extracted by
// parsePathFromID. Only the author of the message may delete it.
func (s *Server) DeleteMessage(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodDelete {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	caller, err := getUserIdFromContext(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	msgID, err := parsePathFromID(r, "messageid")
	if err != nil {
		http.Error(w, "error: unable to parse messageid", http.StatusBadRequest)
		return
	}

	stored, err := s.db.GetMessage(msgID)
	if err != nil {
		http.Error(w, "error: unable to fetch message", http.StatusBadRequest)
		return
	}

	// Authors may only remove their own messages.
	if caller != stored.UserId {
		http.Error(w, "error: attempting to modify different user message", http.StatusBadRequest)
		return
	}

	if delErr := s.db.DeleteMessage(stored.MessageId); delErr != nil {
		http.Error(w, "error: issue while deleting message", http.StatusBadRequest)
		return
	}
}
// UpdateUser handles PATCH requests that change the authenticated user's
// username.
//
// The new name is read from the JSON body field "username" and must be
// non-empty. The user being updated is always the one returned by
// getUserIdFromContext; a request cannot rename another user.
//
// The previous implementation never read the request and wrote the stored
// username back over itself, so the endpoint had no effect.
func (s *Server) UpdateUser(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPatch {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	userid, err := getUserIdFromContext(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	// Confirm the account still exists before attempting the update.
	user, err := s.db.GetUser(userid)
	if err != nil {
		http.Error(w, "error: unable to fetch user", http.StatusBadRequest)
		return
	}
	update := struct {
		UserName string `json:"username"`
	}{}
	if err := json.NewDecoder(r.Body).Decode(&update); err != nil {
		http.Error(w, "error: unable to parse request", http.StatusBadRequest)
		return
	}
	if update.UserName == "" {
		http.Error(w, "error: username must not be empty", http.StatusBadRequest)
		return
	}
	err = s.db.UpdateUserName(user.UserId, update.UserName)
	if err != nil {
		http.Error(w, "error: unable to update username", http.StatusBadRequest)
		return
	}
}
// DeleteChannel handles DELETE requests that remove a channel.
//
// The channel is identified by the "channelid" value extracted by
// parsePathFromID. The caller must own the server that contains the channel.
// A database.ErrRecordNotFound from the delete itself is reported as 404.
func (s *Server) DeleteChannel(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodDelete {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	caller, err := getUserIdFromContext(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	channelID, err := parsePathFromID(r, "channelid")
	if err != nil {
		http.Error(w, "error: unable to parse channelid", http.StatusBadRequest)
		return
	}

	target, err := s.db.GetChannel(channelID)
	if err != nil {
		http.Error(w, "error: unable to fetch channel", http.StatusBadRequest)
		return
	}

	parent, err := s.db.GetServer(target.ServerId)
	if err != nil {
		http.Error(w, "error: unable to fetch server", http.StatusBadRequest)
		return
	}
	if parent.OwnerId != caller {
		http.Error(w, "error: user not owner of server", http.StatusBadRequest)
		return
	}

	switch delErr := s.db.DeleteChannel(target.ChannelId); {
	case errors.Is(delErr, database.ErrRecordNotFound):
		http.Error(w, "error: unable to locate channel", http.StatusNotFound)
	case delErr != nil:
		http.Error(w, "error: unable to delete channel", http.StatusBadRequest)
	}
}
// CreateChannel handles POST requests that create a channel in a server.
//
// The server is the "serverid" value extracted by parsePathFromID and the
// channel name is the JSON body field "channelname". On success the response
// is a JSON object {"channelid": <new id>}.
//
// NOTE(review): the caller's user id is fetched only to confirm the request
// carries one; no ownership or membership check is made against the server,
// unlike DeleteChannel. Confirm whether that is intended.
func (s *Server) CreateChannel(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	if _, authErr := getUserIdFromContext(r); authErr != nil {
		http.Error(w, authErr.Error(), http.StatusInternalServerError)
		return
	}

	var body struct {
		ChannelName string `json:"channelname"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&body); decodeErr != nil {
		http.Error(w, "error: unable to parse request", http.StatusBadRequest)
		return
	}

	serverID, err := parsePathFromID(r, "serverid")
	if err != nil {
		http.Error(w, "error: unable to parse request", http.StatusBadRequest)
		return
	}

	newChannelID, err := s.db.AddChannel(serverID, body.ChannelName)
	if err != nil {
		http.Error(w, "error: unable to create channel", http.StatusBadRequest)
		return
	}

	encoded, err := json.Marshal(map[string]any{"channelid": newChannelID})
	if err != nil {
		http.Error(w, "Failed to marshal response", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if _, writeErr := w.Write(encoded); writeErr != nil {
		log.Printf("Failed to write response: %v", writeErr)
	}
}
// LogoutHandler handles POST requests that end the caller's session.
//
// It deletes the stored session token for the authenticated user and then
// sends back a "token" cookie that is already expired, so the client drops
// its copy.
func (s *Server) LogoutHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	caller, err := getUserIdFromContext(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	if delErr := s.db.DeleteUserSessionToken(caller); delErr != nil {
		http.Error(w, "error: unable to delete session token", http.StatusBadRequest)
		return
	}

	// Same name and path as the cookie set at login; an Expires in the past
	// together with a negative MaxAge marks it for immediate removal.
	http.SetCookie(w, &http.Cookie{
		Name:     "token",
		Value:    "",
		Path:     "/",
		Expires:  time.Now().Add(-time.Hour),
		MaxAge:   -1,
		Secure:   false,
		HttpOnly: true,
		SameSite: http.SameSiteStrictMode,
	})
}
// CreateChannelMessage handles POST requests that store a new message in a
// channel.
//
// The channel is the "channelid" value extracted by parsePathFromID and the
// text is the JSON body field "message". The caller must pass the
// IsUserInChannel check. On success the response is a JSON object
// {"messageid": <new id>}.
func (s *Server) CreateChannelMessage(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	userid, err := getUserIdFromContext(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	channelid, err := parsePathFromID(r, "channelid")
	if err != nil {
		http.Error(w, "invalid request: unable to parse channel id", http.StatusBadRequest)
		return
	}
	inchannel, err := s.db.IsUserInChannel(userid, channelid)
	if err != nil {
		http.Error(w, fmt.Sprintf("error: %s", err), http.StatusBadRequest)
		return
	}
	if !inchannel {
		http.Error(w, "user not in channel", http.StatusBadRequest)
		return
	}
	message_data := struct {
		Message string `json:"message"`
	}{}
	if err := json.NewDecoder(r.Body).Decode(&message_data); err != nil {
		http.Error(w, "error: unable to parse request", http.StatusBadRequest)
		return
	}
	// MessageService.AddMessage is declared as (channelid, userid, message).
	// Both ids share one type, so the previously swapped arguments compiled
	// but stored the message under the wrong channel and author.
	messageid, err := s.db.AddMessage(channelid, userid, message_data.Message)
	if err != nil {
		http.Error(w, "error: unable to create message", http.StatusBadRequest)
		return
	}
	resp := map[string]any{
		"messageid": messageid,
	}
	jsonResp, err := json.Marshal(resp)
	if err != nil {
		http.Error(w, "Failed to marshal response", http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	if _, err := w.Write(jsonResp); err != nil {
		log.Printf("Failed to write response: %v", err)
	}
}
// GetChannelMessages handles GET requests for the messages of a channel.
//
// The channel is the "channelid" value extracted by parsePathFromID. The
// optional "count" query parameter must be a positive integer and defaults to
// 30; it is passed through to GetMessagesInChannel. The caller must pass the
// IsUserInChannel check. The response is a JSON object {"messages": [...]}.
func (s *Server) GetChannelMessages(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	caller, err := getUserIdFromContext(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	channelID, err := parsePathFromID(r, "channelid")
	if err != nil {
		http.Error(w, "invalid request: unable to parse server id", http.StatusBadRequest)
		return
	}

	isMember, err := s.db.IsUserInChannel(caller, channelID)
	if err != nil {
		http.Error(w, fmt.Sprintf("error: %s", err), http.StatusBadRequest)
		return
	}
	if !isMember {
		http.Error(w, "user not in channel", http.StatusBadRequest)
		return
	}

	var limit uint = 30
	if raw := r.URL.Query().Get("count"); raw != "" {
		requested, convErr := strconv.Atoi(raw)
		if convErr != nil {
			http.Error(w, "invalid request: unable to parse count", http.StatusBadRequest)
			return
		}
		if requested <= 0 {
			http.Error(w, "invalid request: invalid count", http.StatusBadRequest)
			return
		}
		limit = uint(requested)
	}

	found, err := s.db.GetMessagesInChannel(channelID, limit)
	if err != nil {
		http.Error(w, "database error", http.StatusInternalServerError)
		return
	}

	encoded, err := json.Marshal(map[string]any{"messages": found})
	if err != nil {
		http.Error(w, "Failed to marshal response", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if _, writeErr := w.Write(encoded); writeErr != nil {
		log.Printf("Failed to write response: %v", writeErr)
	}
}
// GetChannel handles GET requests for a single channel's metadata.
//
// The channel is the "channelid" value extracted by parsePathFromID. The
// request must carry a user id (getUserIdFromContext must succeed), but the
// id itself is not used further. A database.ErrRecordNotFound is reported as
// 404. The response is a JSON object with the keys channelid, serverid,
// channelname and timestamp.
func (s *Server) GetChannel(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	if _, authErr := getUserIdFromContext(r); authErr != nil {
		http.Error(w, authErr.Error(), http.StatusInternalServerError)
		return
	}

	channelID, err := parsePathFromID(r, "channelid")
	if err != nil {
		http.Error(w, "invalid request: unable to parse server id", http.StatusBadRequest)
		return
	}

	info, err := s.db.GetChannel(channelID)
	switch {
	case errors.Is(err, database.ErrRecordNotFound):
		http.Error(w, "error: unable to locate channel", http.StatusNotFound)
		return
	case err != nil:
		http.Error(w, "error: unable to locate channel", http.StatusBadRequest)
		return
	}

	// Field order is kept as-is: it determines the key order of the JSON.
	type channelView struct {
		ChannelId   database.Id `json:"channelid"`
		ServerId    database.Id `json:"serverid"`
		ChannelName string      `json:"channelname"`
		Timestamp   time.Time   `json:"timestamp"`
	}
	encoded, err := json.Marshal(channelView{
		ChannelId:   info.ChannelId,
		ServerId:    info.ServerId,
		ChannelName: info.ChannelName,
		Timestamp:   info.Timestamp,
	})
	if err != nil {
		http.Error(w, "Failed to marshal response", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if _, writeErr := w.Write(encoded); writeErr != nil {
		log.Printf("Failed to write response: %v", writeErr)
	}
}
// GetMessage handles GET requests for a single message.
//
// The message is the "messageid" value extracted by parsePathFromID. A
// database.ErrRecordNotFound is reported as 404; any other lookup or encoding
// failure is reported as 500. The response is the JSON encoding of the
// ServerMessage built by fromDBMessageToSeverMessage.
//
// NOTE(review): unlike the other message handlers this one does not call
// getUserIdFromContext or check channel membership. Confirm that is intended.
func (s *Server) GetMessage(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		// 405 matches every other handler in this file; this one used 400.
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	messageid, err := parsePathFromID(r, "messageid")
	if err != nil {
		http.Error(w, "invalid request: unable to parse message id", http.StatusBadRequest)
		return
	}
	dbmessage, err := s.db.GetMessage(messageid)
	if errors.Is(err, database.ErrRecordNotFound) {
		http.Error(w, "error: unable to locate message", http.StatusNotFound)
		return
	}
	if err != nil {
		http.Error(w, "error: internal server error", http.StatusInternalServerError)
		return
	}
	message := fromDBMessageToSeverMessage(dbmessage)
	jsonResp, err := json.Marshal(message)
	if err != nil {
		http.Error(
			w,
			"error: internal server error. Unable to process request",
			http.StatusInternalServerError,
		)
		// Without this return the handler went on to write a second
		// (empty) body after the error response.
		return
	}
	w.Header().Set("Content-Type", "application/json")
	if _, err := w.Write(jsonResp); err != nil {
		log.Printf("Failed to write response: %v", err)
	}
}
// sessionHandler handles POST requests that report who the current session
// belongs to.
//
// It responds with a JSON object {"userid": ..., "username": ...} for the
// user id found by getUserIdFromContext. A missing user id is reported as
// 401, a failed user lookup as 500.
func (s *Server) sessionHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	caller, err := getUserIdFromContext(r)
	if err != nil {
		http.Error(w, "Unable to get user from context", http.StatusUnauthorized)
		return
	}

	account, err := s.db.GetUser(caller)
	if err != nil {
		http.Error(w, "Unable to find user", http.StatusInternalServerError)
		return
	}

	encoded, err := json.Marshal(map[string]any{
		"userid":   caller,
		"username": account.UserName,
	})
	if err != nil {
		http.Error(w, "Failed to marshal response", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if _, writeErr := w.Write(encoded); writeErr != nil {
		log.Printf("Failed to write response: %v", writeErr)
	}
}
// loginHandler handles POST requests that log a user in.
//
// Credentials are read from the JSON body fields "username" and "password".
// When they validate, a new session token is issued and returned in the
// HttpOnly "token" cookie, and the response body is a JSON object
// {"userid": <id>}.
func (s *Server) loginHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	// Kept from the original flow; the credentials themselves come from the
	// JSON body decoded below, not from form values.
	err := r.ParseForm()
	if err != nil {
		http.Error(w, "Unable to parse form", http.StatusBadRequest)
		return
	}
	loginData := struct {
		Username string `json:"username"`
		Password string `json:"password"`
	}{}
	// The decode error used to be overwritten by the next assignment, so a
	// malformed body was treated as a login with empty credentials.
	if err := json.NewDecoder(r.Body).Decode(&loginData); err != nil {
		http.Error(w, "unable to parse request", http.StatusBadRequest)
		return
	}
	username := loginData.Username
	password := loginData.Password
	userid, err := s.db.GetUserIDFromUserName(username)
	if err != nil {
		http.Error(w, "unable to locate username", http.StatusBadRequest)
		return
	}
	valid, err := s.db.ValidateUserLoginInfo(userid, password)
	if err != nil {
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}
	if !valid {
		http.Error(w, "invalid password", http.StatusBadRequest)
		return
	}
	token, _, err := s.db.UpdateUserSessionToken(userid)
	if err != nil {
		http.Error(w, "unable to update session token", http.StatusBadRequest)
		return
	}
	resp := map[string]any{
		"userid": userid,
	}
	// The session token travels only in the cookie, never in the JSON body.
	http.SetCookie(w, &http.Cookie{
		Name:     "token",
		Value:    token,
		Path:     "/",
		Secure:   false,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	})
	w.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w).Encode(resp)
	if err != nil {
		http.Error(
			w,
			"internal error: unable to send encoded response",
			http.StatusInternalServerError,
		)
		return
	}
}
// AddUserToServer is a placeholder endpoint: POST requests receive
// 501 Not Implemented and every other method receives 405.
func (s *Server) AddUserToServer(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodPost {
		// Not implemented yet.
		http.Error(w, "not implemented", http.StatusNotImplemented)
		return
	}
	http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
// createNewServer handles POST requests that create a server owned by the
// authenticated user.
//
// The name is read from the JSON body field "servername" and must be between
// 3 and 30 bytes long. After the server is created the owner is added to it
// as a member with an empty nickname. The response is a JSON object
// {"serverid": <new id>}.
func (s *Server) createNewServer(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	userid, err := getUserIdFromContext(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	err = r.ParseForm()
	if err != nil {
		http.Error(w, "Unable to parse form", http.StatusBadRequest)
		return
	}
	newServerData := struct {
		ServerName string `json:"servername"`
	}{}
	err = json.NewDecoder(r.Body).Decode(&newServerData)
	if err != nil {
		http.Error(w, "unable to parse request", http.StatusBadRequest)
		return
	}
	if len(newServerData.ServerName) > 30 {
		http.Error(w, "server name too long", http.StatusBadRequest)
		return
	}
	if len(newServerData.ServerName) < 3 {
		http.Error(w, "server name too short", http.StatusBadRequest)
		return
	}
	serverid, err := s.db.CreateServer(userid, newServerData.ServerName)
	if err != nil {
		http.Error(w, "unable to create server", http.StatusBadRequest)
		return
	}
	// ServerService.AddUserToServer is declared as (serverid, userid,
	// nickname). The ids were previously passed the other way round, which
	// compiles because both share one type.
	// TODO: add proper user name here
	err = s.db.AddUserToServer(serverid, userid, "")
	if err != nil {
		http.Error(w, "unable to add user to server", http.StatusBadRequest)
		return
	}
	resp := map[string]any{
		"serverid": serverid,
	}
	w.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w).Encode(resp)
	if err != nil {
		http.Error(
			w,
			"internal error: unable to send encoded response",
			http.StatusInternalServerError,
		)
		return
	}
}
// createUserHandler handles POST requests that register a new user and log
// them in.
//
// Credentials are read from the JSON body fields "username" and "password".
// After the user is created the password is validated against the stored
// login info, a session token is issued and returned in the HttpOnly "token"
// cookie, and the response body is a JSON object {"userid": <id>}.
// A database.ErrRecordAlreadyExists from CreateUser is reported as
// "user already exists".
func (s *Server) createUserHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	// Kept from the original flow; the credentials themselves come from the
	// JSON body decoded below, not from form values.
	err := r.ParseForm()
	if err != nil {
		http.Error(w, "Unable to parse form", http.StatusBadRequest)
		return
	}
	loginData := struct {
		Username string `json:"username"`
		Password string `json:"password"`
	}{}
	// The decode error used to be overwritten by the next assignment, so a
	// malformed body went on to create a user with empty credentials.
	if err := json.NewDecoder(r.Body).Decode(&loginData); err != nil {
		http.Error(w, "unable to parse request", http.StatusBadRequest)
		return
	}
	username := loginData.Username
	password := loginData.Password
	userid, err := s.db.CreateUser(username, password)
	if errors.Is(err, database.ErrRecordAlreadyExists) {
		http.Error(w, "user already exists", http.StatusBadRequest)
		return
	}
	if err != nil {
		http.Error(w, "unable to create user", http.StatusBadRequest)
		return
	}
	valid_user, err := s.db.ValidateUserLoginInfo(userid, password)
	if err != nil {
		http.Error(w, "internal error", http.StatusInternalServerError)
		return
	}
	if !valid_user {
		http.Error(w, "invalid password", http.StatusBadRequest)
		return
	}
	token, _, err := s.db.UpdateUserSessionToken(userid)
	if err != nil {
		http.Error(w, "unable to update session token", http.StatusBadRequest)
		return
	}
	resp := map[string]any{
		"userid": userid,
	}
	// The session token travels only in the cookie, never in the JSON body.
	http.SetCookie(w, &http.Cookie{
		Name:     "token",
		Value:    token,
		Path:     "/",
		Secure:   false,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	})
	w.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w).Encode(resp)
	if err != nil {
		http.Error(
			w,
			"internal error: unable to send encoded response",
			http.StatusInternalServerError,
		)
		return
	}
}
// GetServerChannels responds with the channels of a server as a JSON object
// {"channels": [...]}.
//
// The server is the "serverid" value extracted by parsePathFromID. The caller
// must pass the IsUserInServer check for that server.
//
// NOTE(review): a non-member receives 511 Network Authentication Required,
// which is unusual for a permission failure (403 would be conventional). It
// is left unchanged here in case clients depend on it.
func (s *Server) GetServerChannels(w http.ResponseWriter, r *http.Request) {
	serverid, err := parsePathFromID(r, "serverid")
	if err != nil {
		http.Error(w, "invalid request: unable to parse server id", http.StatusBadRequest)
		return
	}
	userid, err := getUserIdFromContext(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	userinfo, err := s.db.GetUser(userid)
	// This error used to be overwritten by the membership lookup below, so a
	// failed fetch continued with a zero-valued user.
	if err != nil {
		http.Error(w, "error fetching user", http.StatusInternalServerError)
		return
	}
	isServerMember, err := s.db.IsUserInServer(userinfo.UserId, serverid)
	if err != nil {
		http.Error(w, "database error", http.StatusInternalServerError)
		return
	}
	if !isServerMember {
		http.Error(w, "user not member of server", http.StatusNetworkAuthenticationRequired)
		return
	}
	channels, err := s.db.GetChannelsOfServer(serverid)
	if err != nil {
		http.Error(w, "database error", http.StatusInternalServerError)
		return
	}
	resp := map[string]any{"channels": channels}
	jsonResp, err := json.Marshal(resp)
	if err != nil {
		http.Error(w, "Failed to marshal response", http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	if _, err := w.Write(jsonResp); err != nil {
		log.Printf("Failed to write response: %v", err)
	}
}
// GetServersOfUser responds with the servers a user belongs to as a JSON
// object {"servers": [...]}.
//
// The user is the "userid" value extracted by parsePathFromID and must equal
// the id returned by getUserIdFromContext; users can only list their own
// servers.
func (s *Server) GetServersOfUser(w http.ResponseWriter, r *http.Request) {
	requested, err := parsePathFromID(r, "userid")
	if err != nil {
		http.Error(w, "invalid request: unable to parse server id", http.StatusBadRequest)
		return
	}

	caller, err := getUserIdFromContext(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	if caller != requested {
		http.Error(w, "invalid request: invalid user id", http.StatusBadRequest)
		return
	}

	memberships, err := s.db.GetServersOfUser(requested)
	if err != nil {
		http.Error(w, "database error", http.StatusInternalServerError)
		return
	}

	encoded, err := json.Marshal(map[string]any{"servers": memberships})
	if err != nil {
		http.Error(w, "Failed to marshal response", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if _, writeErr := w.Write(encoded); writeErr != nil {
		log.Printf("Failed to write response: %v", writeErr)
	}
}
// GetServerInformation responds with a server's id, owner id and name as a
// JSON object with the keys serverid, ownerid and servername.
//
// The server is the "serverid" value extracted by parsePathFromID. Any lookup
// failure is reported as 404; database.ErrRecordNotFound gets its own
// message.
func (s *Server) GetServerInformation(w http.ResponseWriter, r *http.Request) {
	serverID, err := parsePathFromID(r, "serverid")
	if err != nil {
		http.Error(w, "invalid request: unable to parse server id", http.StatusBadRequest)
		return
	}

	found, err := s.db.GetServer(serverID)
	switch {
	case errors.Is(err, database.ErrRecordNotFound):
		// TODO: figure out proper method for valid response but no data
		http.Error(w, "server not found", http.StatusNotFound)
		return
	case err != nil:
		http.Error(w, "invalid request: unable to find server", http.StatusNotFound)
		return
	}

	// Field order is kept as-is: it determines the key order of the JSON.
	type serverView struct {
		ServerId   database.Id `json:"serverid"`
		OwnerId    database.Id `json:"ownerid"`
		ServerName string      `json:"servername"`
	}
	encoded, err := json.Marshal(serverView{
		ServerId:   found.ServerId,
		OwnerId:    found.OwnerId,
		ServerName: found.ServerName,
	})
	if err != nil {
		http.Error(w, "Failed to marshal response", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if _, writeErr := w.Write(encoded); writeErr != nil {
		log.Printf("Failed to write response: %v", writeErr)
	}
}
// GetServerMessages responds with messages from every channel of a server as
// a JSON object {"serverid": ..., "messages": [...]}.
//
// The server is the "serverid" value extracted by parsePathFromID. The
// optional "count" query parameter must be a positive integer and defaults to
// 30; it is passed to GetMessagesInChannel once per channel, so it applies to
// each channel separately rather than to the response as a whole. The caller
// must pass the IsUserInServer check.
func (s *Server) GetServerMessages(w http.ResponseWriter, r *http.Request) {
	serverID, err := parsePathFromID(r, "serverid")
	if err != nil {
		http.Error(w, "invalid request: unable to parse server id", http.StatusBadRequest)
		return
	}

	var perChannel uint = 30
	if raw := r.URL.Query().Get("count"); raw != "" {
		requested, convErr := strconv.Atoi(raw)
		if convErr != nil {
			http.Error(w, "invalid request: unable to parse count", http.StatusBadRequest)
			return
		}
		if requested <= 0 {
			http.Error(w, "invalid request: invalid count", http.StatusBadRequest)
			return
		}
		perChannel = uint(requested)
	}

	caller, err := getUserIdFromContext(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	isMember, err := s.db.IsUserInServer(caller, serverID)
	if err != nil {
		http.Error(w, "database error", http.StatusInternalServerError)
		return
	}
	if !isMember {
		http.Error(w, "user not member of server", http.StatusNetworkAuthenticationRequired)
		return
	}

	serverChannels, err := s.db.GetChannelsOfServer(serverID)
	if err != nil {
		http.Error(w, "database error", http.StatusInternalServerError)
		return
	}

	// Left nil when nothing is found, so the JSON carries null as before.
	var collected []ServerMessage
	for _, ch := range serverChannels {
		rows, fetchErr := s.db.GetMessagesInChannel(ch.ChannelId, perChannel)
		if fetchErr != nil {
			http.Error(w, "database error", http.StatusInternalServerError)
			return
		}
		for _, row := range rows {
			collected = append(collected, ServerMessage{
				UserId:    row.UserId,
				MessageID: row.MessageId,
				ServerId:  serverID,
				ChannelId: row.ChannelId,
				Message:   row.Contents,
				Date:      row.Timestamp.Format(time.UnixDate),
			})
		}
	}

	encoded, err := json.Marshal(map[string]any{"serverid": serverID, "messages": collected})
	if err != nil {
		http.Error(w, "Failed to marshal response", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if _, writeErr := w.Write(encoded); writeErr != nil {
		log.Printf("Failed to write response: %v", writeErr)
	}
}
// GetServerMembersHandler responds with the users of a server as a JSON
// object {"users": [...], "serverid": ...}.
//
// The server is the "serverid" value extracted by parsePathFromID.
//
// NOTE(review): no caller identity or membership check is made in this
// handler; confirm whether one is applied elsewhere.
func (s *Server) GetServerMembersHandler(w http.ResponseWriter, r *http.Request) {
	serverID, err := parsePathFromID(r, "serverid")
	if err != nil {
		http.Error(w, "invalid request: unable to parse server id", http.StatusBadRequest)
		return
	}

	members, err := s.db.GetUsersOfServer(serverID)
	if err != nil {
		http.Error(w, "database error", http.StatusInternalServerError)
		return
	}

	encoded, err := json.Marshal(map[string]any{"users": members, "serverid": serverID})
	if err != nil {
		http.Error(w, "Failed to marshal response", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if _, writeErr := w.Write(encoded); writeErr != nil {
		log.Printf("Failed to write response: %v", writeErr)
	}
}
// GetUserHandler responds with a user's name as a JSON object
// {"username": ...}.
//
// The user is the "userid" value extracted by parsePathFromID. A
// database.ErrRecordNotFound is reported as 404 and any other lookup failure
// as 500.
func (s *Server) GetUserHandler(w http.ResponseWriter, r *http.Request) {
	targetID, err := parsePathFromID(r, "userid")
	if err != nil {
		http.Error(w, "invalid request: unable to parse user id", http.StatusBadRequest)
		return
	}

	account, err := s.db.GetUser(targetID)
	switch {
	case errors.Is(err, database.ErrRecordNotFound):
		http.Error(w, "user not found", http.StatusNotFound)
		return
	case err != nil:
		http.Error(w, "database error", http.StatusInternalServerError)
		return
	}

	encoded, err := json.Marshal(map[string]string{"username": account.UserName})
	if err != nil {
		http.Error(w, "Failed to marshal response", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	if _, writeErr := w.Write(encoded); writeErr != nil {
		log.Printf("Failed to write response: %v", writeErr)
	}
}
// rawChannelMessage holds the two values ProcessMessage extracts from an
// incoming websocket payload before the message is stored.
type rawChannelMessage struct {
// channel_id is the target channel, converted from the payload's
// "channel_id" number.
channel_id database.Id
// message is the text taken from the payload's "message" entry.
message string
}
// websocketHandler upgrades the request to a websocket connection and relays
// chat messages for the authenticated user until the connection's incoming
// channel is closed.
//
// The new connection id is registered in s.sessions_in_channel under every
// server the user belongs to. Each incoming message goes through
// ProcessMessage and the resulting bytes are sent to every connection
// registered for the server it returns. On exit the connection is closed and
// its registrations are removed.
//
// NOTE(review): s.sessions_in_channel is read and written here without any
// locking, and this handler presumably runs once per connection on separate
// goroutines. Confirm and guard the map with a mutex.
func (s *Server) websocketHandler(w http.ResponseWriter, r *http.Request) {
	authedID, err := getUserIdFromContext(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	userinfo, err := s.db.GetUser(authedID)
	if err != nil {
		http.Error(w, "error fetching user", http.StatusInternalServerError)
		return
	}

	memberships, err := s.db.GetServersOfUser(userinfo.UserId)
	if err != nil {
		http.Error(w, "error fetching channels of user", http.StatusInternalServerError)
		return
	}

	conn, err := websocket.NewCoderWebSocketConnection(w, r)
	if err != nil {
		log.Printf("error creating websocket connection: %v\n", err)
		http.Error(w, "error creating websocket connection", http.StatusInternalServerError)
		return
	}

	connID, incoming := s.ws_manager.NewConnection(conn)

	// Register this connection under each server the user is a member of.
	for _, srv := range memberships {
		if _, known := s.sessions_in_channel[srv.ServerId]; !known {
			s.sessions_in_channel[srv.ServerId] = make(map[string]bool)
		}
		s.sessions_in_channel[srv.ServerId][connID] = true
	}

	defer func() {
		s.ws_manager.CloseConnection(connID)
		for _, srv := range memberships {
			sessions, known := s.sessions_in_channel[srv.ServerId]
			if !known {
				continue
			}
			// TODO: Add some kind of mutex lock
			delete(sessions, connID)
			// Drop the per-server entry once its last connection is gone.
			if len(sessions) == 0 {
				delete(s.sessions_in_channel, srv.ServerId)
			}
		}
	}()

	fmt.Printf("starting websocket loop: %d ms\n",
		time.Since(startTime).Milliseconds(),
	)

	for msg := range incoming {
		serverid, byte_data, procErr := s.ProcessMessage(userinfo.UserId, msg)
		if procErr != nil {
			log.Printf(
				"websocketHandler: error processing message for user %d: %v",
				userinfo.UserId,
				procErr,
			)
			continue
		}
		for sessionID := range s.sessions_in_channel[serverid] {
			s.ws_manager.SendToClient(sessionID, byte_data)
		}
	}

	// The range loop ends only when the incoming channel has been closed.
	log.Printf("websocketHandler: incoming channel closed for user %d", userinfo.UserId)
}
// ProcessMessage validates one incoming websocket message, stores it, and
// returns the bytes to broadcast.
//
// msg.Payload must be JSON that decodes into a ServerResponseMessage with
// Message_type "channel_message" and a Payload object carrying a positive
// numeric "channel_id" and a string "message" of at most 1000 bytes. The
// message is stored for userid in that channel and read back; the result is
// wrapped in a ServerResponseMessage of type "message" and JSON-encoded.
//
// It returns the id of the server the stored message belongs to, the encoded
// bytes, and a nil error. Every rejection returns a non-nil error. The
// earlier code returned a nil error for validation failures, so the caller
// treated them as success and went on to broadcast.
//
// NOTE(review): nothing here checks that userid is a member of the target
// channel, whereas the HTTP handler CreateChannelMessage does. Confirm
// whether a membership check is needed.
func (s *Server) ProcessMessage(
	userid database.Id,
	msg websocket.IncomingMessage,
) (database.Id, []byte, error) {
	data := ServerResponseMessage{}
	if err := json.Unmarshal(msg.Payload, &data); err != nil {
		return 0, nil, fmt.Errorf("decoding websocket message: %w", err)
	}
	if data.Message_type != "channel_message" {
		return 0, nil, fmt.Errorf("invalid message type %q", data.Message_type)
	}
	paymap, ok := data.Payload.(map[string]any)
	if !ok {
		return 0, nil, fmt.Errorf("invalid payload type %T", data.Payload)
	}
	rawChannelID, ok := paymap["channel_id"]
	if !ok {
		return 0, nil, fmt.Errorf("payload is missing channel_id")
	}
	// encoding/json decodes every JSON number held in an `any` as float64.
	channelidfloat, ok := rawChannelID.(float64)
	if !ok {
		return 0, nil, fmt.Errorf("channel_id has type %T, want a number", rawChannelID)
	}
	// Checked form of the assertion: a missing or non-string "message" used
	// to panic here.
	text, ok := paymap["message"].(string)
	if !ok {
		return 0, nil, fmt.Errorf("message has type %T, want a string", paymap["message"])
	}
	payload := rawChannelMessage{
		channel_id: database.Id(channelidfloat),
		message:    text,
	}
	if payload.channel_id <= 0 {
		return 0, nil, fmt.Errorf("invalid channel id %v", payload.channel_id)
	}
	if len(payload.message) > 1000 {
		return 0, nil, fmt.Errorf("message too large: length=%d", len(payload.message))
	}
	messageid, err := s.db.AddMessage(payload.channel_id, userid, payload.message)
	if err != nil {
		return 0, nil, fmt.Errorf("saving message: %w", err)
	}
	// Read the row back so the broadcast carries the stored timestamp and
	// server id.
	dbmsg, err := s.db.GetMessage(messageid)
	if err != nil {
		return 0, nil, fmt.Errorf("loading saved message: %w", err)
	}
	smsg := ServerMessage{
		UserId:    dbmsg.UserId,
		MessageID: messageid,
		ServerId:  dbmsg.ServerId,
		ChannelId: dbmsg.ChannelId,
		Message:   dbmsg.Contents,
		Date:      dbmsg.Timestamp.Format(time.UnixDate),
	}
	server_msg := ServerResponseMessage{Message_type: "message", Payload: smsg}
	byte_data, err := json.Marshal(server_msg)
	if err != nil {
		return 0, nil, fmt.Errorf("marshalling message: %w", err)
	}
	log.Printf("websocketHandler: sending message to user %d", userid)
	return dbmsg.ServerId, byte_data, nil
}
package server
import (
"time"
"go-chat-react/internal/database"
)
// fromDBMessageToSeverMessage converts a stored database.Message into the
// ServerMessage shape sent to clients. The timestamp is rendered with
// time.UnixDate.
//
// ServerId is copied as well; it was previously left at its zero value,
// unlike the other places that build a ServerMessage.
func fromDBMessageToSeverMessage(message database.Message) ServerMessage {
	return ServerMessage{
		UserId:    message.UserId,
		ServerId:  message.ServerId,
		ChannelId: message.ChannelId,
		MessageID: message.MessageId,
		Message:   message.Contents,
		Date:      message.Timestamp.Format(time.UnixDate),
	}
}
package server
import (
"context"
"database/sql"
"fmt"
"log"
"net/http"
"os"
"time"
_ "github.com/joho/godotenv/autoload"
"go-chat-react/internal/database"
"go-chat-react/internal/websocket"
)
// UserService is the set of user, credential and session-token operations
// the server requires from its data store. It is embedded in Service.
type UserService interface {
// GetUserIDFromUserName looks up a user id by username.
GetUserIDFromUserName(username string) (database.Id, error)
// UpdateUserSessionToken returns a token string and a time for the user
// (presumably a newly issued token and its expiry; confirm against the
// implementation).
UpdateUserSessionToken(userid database.Id) (string, time.Time, error)
DeleteUserSessionToken(userid database.Id) error
GetUserLoginInfoFromToken(token string) (database.UserLoginInfo, error)
GetUserLoginInfo(userid database.Id) (database.UserLoginInfo, error)
// ValidateUserLoginInfo reports whether password is valid for userid.
ValidateUserLoginInfo(userid database.Id, password string) (bool, error)
GetUser(userid database.Id) (database.User, error)
// CreateUser returns the id of the created user.
CreateUser(username string, password string) (database.Id, error)
UpdateUserName(userid database.Id, username string) error
GetRecentUsernames(userid database.Id, number uint) ([]database.UsernameLogEntry, error)
}
// ServerService is the set of chat-server and server-membership operations
// the server requires from its data store. It is embedded in Service.
//
// Argument order is not uniform: IsUserInServer takes (userid, serverid)
// while AddUserToServer takes (serverid, userid, nickname). Both ids are
// database.Id, so the compiler will not catch a swap.
type ServerService interface {
GetUsersOfServer(serverid database.Id) ([]database.User, error)
GetServersOfUser(userid database.Id) ([]database.Server, error)
GetServer(serverid database.Id) (database.Server, error)
// CreateServer returns the id of the created server.
CreateServer(ownerid database.Id, servername string) (database.Id, error)
DeleteServer(serverid database.Id) error
UpdateServerName(serverid database.Id, servername string) error
IsUserInServer(userid database.Id, serverid database.Id) (bool, error)
AddUserToServer(serverid database.Id, userid database.Id, nickname string) error
}
// ChannelService is the set of channel and channel-membership operations the
// server requires from its data store. It is embedded in Service.
//
// Argument order is not uniform: AddUserToChannel and RemoveUserFromChannel
// take (channelid, userid) while IsUserInChannel takes (userid, channelid).
// Both ids are database.Id, so the compiler will not catch a swap.
type ChannelService interface {
// AddChannel returns the id of the created channel.
AddChannel(serverid database.Id, channelname string) (database.Id, error)
DeleteChannel(channelid database.Id) error
GetChannel(channelid database.Id) (database.Channel, error)
GetChannelsOfServer(serverid database.Id) ([]database.Channel, error)
// NOTE(review): the second parameter is named username; presumably it is
// the new channel name. Confirm against the implementation.
UpdateChannel(channelid database.Id, username string) error
AddUserToChannel(channelid database.Id, userid database.Id) error
RemoveUserFromChannel(channelid database.Id, userid database.Id) error
GetUsersInChannel(channelid database.Id) ([]database.User, error)
IsUserInChannel(userid database.Id, channelid database.Id) (bool, error)
}
// MessageService groups the data-store operations on chat messages: fetching
// a single message or a channel's messages, and adding, editing or deleting
// a message.
type MessageService interface {
GetMessage(messageid database.Id) (database.Message, error)
// number is presumably an upper bound on the returned messages — confirm.
GetMessagesInChannel(channelid database.Id, number uint) ([]database.Message, error)
AddMessage(channelid database.Id, userid database.Id, message string) (database.Id, error)
UpdateMessage(messageid database.Id, message string) error
DeleteMessage(messageid database.Id) error
}
// LifecycleService is the shutdown hook of a data store. It is embedded in
// Service so that whoever holds a Service can release it.
type LifecycleService interface {
// Close releases the service.
// NOTE(review): presumably closes the underlying database handle — confirm.
Close() error
}
type (
// AtomicService pairs a Service with explicit Commit and Rollback calls.
// NOTE(review): presumably backed by a database transaction — confirm
// against the implementation.
AtomicService interface {
// Service returns the Service associated with this atomic unit.
Service() Service
Commit() error
Rollback() error
}
// Service is the full data-store contract held by Server: the union of the
// user, server, channel and message operations plus Close.
Service interface {
UserService
ServerService
ChannelService
MessageService
LifecycleService
}
)
var (
// dburl is the data source name handed to sql.Open by NewDB. It is read
// from the BLUEPRINT_DB_URL environment variable once, when the package's
// variables are initialized.
dburl = os.Getenv("BLUEPRINT_DB_URL")
// dbInstance, when non-nil, is returned by NewDB instead of opening a new
// database handle.
dbInstance *database.DBService
)
// executeSQLFile reads the whole file at filename and runs its contents
// against db as a single Exec call.
//
// It returns a wrapped error when the file cannot be read or when the
// database rejects the script, and nil otherwise. The file is read before db
// is touched.
func executeSQLFile(db *sql.DB, filename string) error {
	script, readErr := os.ReadFile(filename)
	if readErr != nil {
		return fmt.Errorf("failed to read file: %w", readErr)
	}
	if _, execErr := db.Exec(string(script)); execErr != nil {
		return fmt.Errorf("failed to execute SQL: %w", execErr)
	}
	return nil
}
// NewInMemoryDB opens an in-memory sqlite3 database, loads the schema and
// then the mock data from SQL files, and wraps the handle in a DBService.
//
// The script paths are relative ("../../..."), so the result depends on the
// process working directory. Any failure terminates the process through
// log.Fatal.
func NewInMemoryDB() *database.DBService {
	conn, openErr := sql.Open("sqlite3", ":memory:")
	if openErr != nil {
		log.Fatal(openErr)
	}
	// Order matters: the schema must exist before the mock rows are inserted.
	for _, script := range []string{"../../schema.sql", "../../mockdata.sql"} {
		if runErr := executeSQLFile(conn, script); runErr != nil {
			log.Fatal(runErr)
		}
	}
	return database.New(conn)
}
// Atomic opens an atomic database session by delegating to the server's
// data store.
//
// The Service interface held in s.db does not declare Atomic, so the concrete
// value is checked for the method at runtime. ctx and opts are forwarded
// unchanged. An error is returned when the configured service does not
// provide Atomic with this signature, or when the service itself fails.
func (s *Server) Atomic(
	ctx context.Context,
	opts *sql.TxOptions,
) (*database.AtomitcDBService, error) {
	type atomicStarter interface {
		Atomic(ctx context.Context, opts *sql.TxOptions) (*database.AtomitcDBService, error)
	}
	// Delegate to the data store; calling s.Atomic here would recurse forever.
	starter, ok := s.db.(atomicStarter)
	if !ok {
		return nil, fmt.Errorf("database service %T does not support atomic sessions", s.db)
	}
	return starter.Atomic(ctx, opts)
}
// NewDB returns the package-wide DBService, creating it on first use.
//
// The first call opens a sqlite3 handle for dburl, wraps it with
// database.New and stores it in dbInstance; later calls return that same
// instance. A failure from sql.Open terminates the process through log.Fatal.
//
// NOTE(review): the check-then-store on dbInstance is not synchronized, so
// this is only safe when the first call does not race with another one.
func NewDB() *database.DBService {
	// Reuse the existing service if one was already created.
	if dbInstance != nil {
		return dbInstance
	}
	db, err := sql.Open("sqlite3", dburl)
	if err != nil {
		// This will not be a connection error, but a DSN parse error or
		// another initialization error.
		log.Fatal(err)
	}
	// Remember the service so the reuse check above can take effect.
	dbInstance = database.New(db)
	return dbInstance
}
// Server carries the state behind the routes built by RegisterRoutes: the
// listen port, a per-channel session index, the websocket manager and the
// data store.
type Server struct {
// port is the TCP port used to build the http.Server address in NewServer.
port int
// sessions_in_channel maps a channel id to a set of string keys.
// NOTE(review): the keys are presumably websocket client/session ids —
// confirm against the handlers that populate it.
sessions_in_channel map[database.Id]map[string]bool
ws_manager *websocket.WebSocketManager
// db is the data store; NewServer fills it with the result of NewDB.
db Service
}
// NewServer builds the application state and returns an *http.Server that
// serves it on the given port.
//
// logserver is forwarded to RegisterRoutes. The data store comes from NewDB.
// The returned server uses a one-minute idle timeout, a 10-second read
// timeout and a 30-second write timeout. The caller is responsible for
// starting it. A failure to open the startup atomic session terminates the
// process through log.Fatal.
func NewServer(logserver bool, port int) *http.Server {
	// Terminate the line so the message does not run into later output.
	fmt.Printf("opening on port %d\n", port)

	db := NewDB()
	app := &Server{
		port:                port,
		sessions_in_channel: make(map[database.Id]map[string]bool),
		ws_manager:          websocket.NewWebSocketManager(),
		db:                  db,
	}

	// Open an atomic session and roll it back straight away.
	// NOTE(review): this looks like a startup smoke test of the data store —
	// confirm it is still wanted. The result of Rollback is not inspected.
	atomicdb, err := db.Atomic(context.Background(), nil)
	if err != nil {
		log.Fatal(err)
	}
	atomicdb.Rollback()

	// Declare Server config
	return &http.Server{
		Addr:         fmt.Sprintf(":%d", app.port),
		Handler:      app.RegisterRoutes(logserver),
		IdleTimeout:  time.Minute,
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 30 * time.Second,
	}
}
package websocket
import (
"context"
"errors"
"net/http"
"github.com/coder/websocket"
)
// CoderWebSocketConnection adapts a github.com/coder/websocket connection to
// this package's own StatusCode and MessageType types through its Close,
// Read and Write methods.
type CoderWebSocketConnection struct {
// conn is the accepted library connection every method delegates to.
conn *websocket.Conn
}
// NewCoderWebSocketConnection upgrades the HTTP request to a websocket via
// websocket.Accept and wraps the resulting connection.
//
// On failure it returns a nil connection and a new error whose text is the
// prefix "failed to open websocket connection: " followed by the library
// error's message.
//
// NOTE(review): the original error is flattened into a string, so callers
// cannot match it with errors.Is / errors.As.
// NOTE(review): InsecureSkipVerify is set to true; per the coder/websocket
// documentation this disables Accept's origin verification — confirm this is
// intended outside local development.
func NewCoderWebSocketConnection(
	w http.ResponseWriter,
	r *http.Request,
) (*CoderWebSocketConnection, error) {
	accepted, acceptErr := websocket.Accept(w, r, &websocket.AcceptOptions{InsecureSkipVerify: true})
	if acceptErr == nil {
		return &CoderWebSocketConnection{conn: accepted}, nil
	}
	return nil, errors.New("failed to open websocket connection: " + acceptErr.Error())
}
// Close closes the underlying connection, converting code to the library's
// status-code type and passing message through unchanged.
func (w *CoderWebSocketConnection) Close(code StatusCode, message string) error {
	libCode := websocket.StatusCode(code)
	return w.conn.Close(libCode, message)
}
// Read reads one message from the underlying connection under ctx and returns
// its type converted to this package's MessageType, together with the payload
// and any error exactly as the library reported them.
func (w *CoderWebSocketConnection) Read(ctx context.Context) (MessageType, []byte, error) {
	kind, payload, readErr := w.conn.Read(ctx)
	return MessageType(kind), payload, readErr
}
// Write sends data as one message of the given type on the underlying
// connection under ctx, converting msgType to the library's message type.
func (w *CoderWebSocketConnection) Write(
	ctx context.Context,
	msgType MessageType,
	data []byte,
) error {
	libType := websocket.MessageType(msgType)
	return w.conn.Write(ctx, libType, data)
}
package websocket
import (
"context"
"errors"
"log"
"sync"
"github.com/google/uuid"
)
// IncomingMessage is one message received from a websocket client, as
// delivered on the channel returned by WebSocketManager.NewConnection.
type IncomingMessage struct {
// Payload holds the message bytes exactly as read from the connection.
Payload []byte
}
type (
// WebSocketConnection is the minimal connection contract the client
// goroutines need: close with a status and reason, read one message, and
// write one message.
WebSocketConnection interface {
Close(StatusCode, string) error
Read(context.Context) (MessageType, []byte, error)
Write(context.Context, MessageType, []byte) error
}
// webSocketClient is the per-connection state owned by the manager. Each
// client has a read goroutine and a write goroutine, both started by
// newWebSocketClient.
webSocketClient struct {
// ID is the key under which the manager stores the client.
ID string
conn WebSocketConnection
// receive is where the read goroutine delivers inbound messages.
receive chan IncomingMessage
// send is drained by the write goroutine; each item is written to conn.
send chan []byte
// cancel cancels the context passed to the read and write goroutines.
cancel context.CancelFunc
// closed is consulted by close to reject repeated calls.
closed bool
}
)
// newWebSocketClient creates the client for conn and starts its read and
// write goroutines, which share one cancellable context.
//
// Id becomes the client's ID and incoming is the channel the read goroutine
// delivers inbound messages to. The outbound queue is buffered: producers
// send to it without blocking, so an unbuffered channel would accept a
// message only while the writer happens to be idle and drop everything that
// arrives while a write is in progress.
func newWebSocketClient(
	Id string,
	conn WebSocketConnection,
	incoming chan IncomingMessage,
) *webSocketClient {
	// Depth of the outbound queue; matches the inbound buffer size used by
	// WebSocketManager.NewConnection.
	const sendBuffer = 100

	ctx, cancel := context.WithCancel(context.Background())
	client := &webSocketClient{
		ID:      Id,
		conn:    conn,
		receive: incoming,
		send:    make(chan []byte, sendBuffer),
		cancel:  cancel,
	}
	go client.read(ctx)
	go client.write(ctx)
	return client
}
// close tears the client down: it cancels the goroutines' context, closes
// the receive channel so consumers see end-of-stream, logs the status, and
// closes the connection with that status and an empty reason.
//
// It returns an error if the client was already closed, otherwise whatever
// the connection's Close returns.
//
// NOTE(review): the closed flag is a plain bool with no lock. Both the read
// and the write goroutine can call close, so two truly concurrent calls can
// still race; a complete fix needs a sync.Once or mutex on the struct.
//
// TODO: find the correct ordering of close operations (closing the channel
// vs. closing the connection).
func (c *webSocketClient) close(status StatusCode) error {
	if c.closed {
		return errors.New("client already closed")
	}
	// Record the closure before tearing anything down so a repeated call is
	// rejected instead of closing c.receive a second time, which would panic.
	c.closed = true

	c.cancel()
	close(c.receive)
	log.Printf("Client %s closed with status %d", c.ID, status)
	return c.conn.Close(status, "") // TODO: determine proper status
}
// read is the client's inbound loop. It reads messages from the connection
// until an error occurs and forwards text and binary messages on c.receive.
//
// A context.Canceled error is treated as an orderly stop: it is logged and
// the loop exits without calling close. Any other read error is logged and
// the client is closed with StatusAbnormalClosure. Messages of other types
// are ignored.
func (c *webSocketClient) read(ctx context.Context) {
	for {
		messageType, message, err := c.conn.Read(ctx)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				log.Printf("Client %s read cancelled", c.ID)
				return
			}
			log.Printf("Client %s read error: %v", c.ID, err)
			c.close(StatusAbnormalClosure)
			return
		}
		if messageType != MessageBinary && messageType != MessageText {
			continue
		}
		log.Printf("Received from client %s (%d bytes): %s", c.ID, len(message), message)
		// Deliver the message, but give up when the context is cancelled so
		// this goroutine cannot block forever on a consumer that has stopped
		// draining c.receive.
		select {
		case c.receive <- IncomingMessage{Payload: message}:
		case <-ctx.Done():
			log.Printf("Client %s read cancelled", c.ID)
			return
		}
	}
}
// write is the client's outbound loop. It takes payloads from c.send and
// writes each one to the connection as a text message.
//
// The loop ends when ctx is cancelled, when c.send is closed, or when a
// write fails; in the last case the error is logged and the client is closed
// with StatusAbnormalClosure.
func (c *webSocketClient) write(ctx context.Context) {
	for {
		select {
		case payload, open := <-c.send:
			if !open {
				return
			}
			if writeErr := c.conn.Write(ctx, MessageText, payload); writeErr != nil {
				log.Printf("Client %s write error: %v", c.ID, writeErr)
				c.close(StatusAbnormalClosure)
				return
			}
		case <-ctx.Done():
			return
		}
	}
}
// WebSocketManager keeps track of the connected websocket clients and routes
// outbound messages to them by client ID.
type WebSocketManager struct {
// clients maps a client ID to its client.
clients map[string]*webSocketClient
// mutex guards clients: NewConnection and CloseConnection take the write
// lock, SendToClient takes the read lock.
mutex sync.RWMutex
}
// NewWebSocketManager returns a manager with an empty client table. The
// mutex needs no explicit setup because its zero value is ready to use.
func NewWebSocketManager() *WebSocketManager {
	manager := new(WebSocketManager)
	manager.clients = make(map[string]*webSocketClient)
	return manager
}
// NewConnection registers conn as a new client under a freshly generated
// UUID and starts its read and write goroutines.
//
// It returns the client's ID and the channel on which inbound messages from
// that client are delivered. The channel is buffered with room for 100
// messages.
func (m *WebSocketManager) NewConnection(
	conn WebSocketConnection,
) (string, chan IncomingMessage) {
	clientID := uuid.New().String()
	inbox := make(chan IncomingMessage, 100)
	// The client is created, and its goroutines started, before the lock is
	// taken; only the table update happens under the lock.
	registered := newWebSocketClient(clientID, conn, inbox)

	m.mutex.Lock()
	defer m.mutex.Unlock()

	m.clients[registered.ID] = registered
	log.Printf("Client %s registered. Total clients: %d", registered.ID, len(m.clients))
	return clientID, inbox
}
// CloseConnection removes the client with the given id from the manager and
// closes its send channel, which makes the client's write loop exit. Unknown
// ids are ignored.
//
// NOTE(review): only the send channel is closed here; the client's context
// and connection are left untouched — confirm whether the caller is expected
// to tear those down.
// TODO: remove later; redundant with the deregister path.
func (m *WebSocketManager) CloseConnection(id string) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	if target, found := m.clients[id]; found {
		delete(m.clients, id)
		close(target.send)
		log.Printf("Client %s unregistered. Total clients: %d", target.ID, len(m.clients))
	}
}
// SendToClient queues message for the client with the given Id without
// blocking.
//
// It returns true when the message was handed to the client's send channel.
// It returns false, after logging, when the client is unknown or when the
// channel cannot accept the message right now, in which case the message is
// dropped.
func (m *WebSocketManager) SendToClient(Id string, message []byte) bool {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	target, known := m.clients[Id]
	if !known {
		log.Printf("Client %s not found.", Id)
		return false
	}

	delivered := false
	select {
	case target.send <- message:
		delivered = true
	default:
		// Never block the caller: drop instead of waiting for the writer.
		log.Printf("Client %s send channel full, dropping message.", Id)
	}
	return delivered
}