add validation middleware and improve REST API structure

This commit is contained in:
StarAppeal
2025-09-06 03:47:44 +02:00
parent b3381e04e3
commit 3a939c2b36
24 changed files with 4044 additions and 264 deletions
+2651 -58
View File
File diff suppressed because it is too large Load Diff
+8 -2
View File
@@ -7,7 +7,9 @@
"start-local": "tsc && cross-env NODE_ENV=development node dist/index.js",
"clean": "rimraf dist",
"build": "npm run clean & tsc",
"test": "echo \"Error: no test specified, please add them later\""
"test": "vitest run",
"test:watch": "vitest",
"test:coverage": "vitest run --coverage"
},
"keywords": [],
"author": "",
@@ -24,6 +26,7 @@
"cors": "^2.8.5",
"dotenv": "^16.4.4",
"express": "5.0.0",
"express-rate-limit": "^8.1.0",
"jsonwebtoken": "^9.0.2",
"mongoose": "^8.8.2",
"openweather-api-node": "^3.1.5",
@@ -34,7 +37,10 @@
},
"devDependencies": {
"@types/cors": "^2.8.17",
"@types/supertest": "^6.0.3",
"cross-env": "^7.0.3",
"prettier": "^3.2.5"
"prettier": "^3.2.5",
"supertest": "^7.1.4",
"vitest": "^2.1.9"
}
}
+47
View File
@@ -0,0 +1,47 @@
type NodeEnv = "development" | "test" | "production";
function required(name: string, value: string | undefined): string {
if (!value || value.trim() === "") {
throw new Error(`Missing required env var: ${name}`);
}
return value;
}
function optionalNumber(name: string, value: string | undefined, fallback: number): number {
if (value === undefined) return fallback;
const n = Number(value);
if (!Number.isFinite(n) || n <= 0) {
throw new Error(`Env var ${name} must be a positive number`);
}
return n;
}
function optionalString(name: string, value: string | undefined, fallback: string): string {
return value ?? fallback;
}
function isValidUrl(u: string): boolean {
try {
new URL(u);
return true;
} catch {
return false;
}
}
const NODE_ENV = (optionalString("NODE_ENV", process.env.NODE_ENV, "development") as NodeEnv);
const PORT = optionalNumber("PORT", process.env.PORT, 3000);
const FRONTEND_URL = required("FRONTEND_URL", process.env.FRONTEND_URL);
if (!isValidUrl(FRONTEND_URL)) {
throw new Error("FRONTEND_URL must be a valid URL");
}
export const config = {
env: NODE_ENV,
port: PORT,
cors: {
origin: FRONTEND_URL,
credentials: true,
},
};
+90 -16
View File
@@ -1,9 +1,9 @@
import "dotenv/config";
import mongoose, {Schema} from "mongoose";
import {ObjectId} from "mongodb";
import bcrypt from "bcrypt";
import {PasswordUtils} from "../../utils/passwordUtils";
export interface IUser {
id: ObjectId;
name: string;
password?: string;
uuid: string;
@@ -52,24 +52,40 @@ export interface SpotifyConfig {
const matrixStateSchema = new Schema<MatrixState>({
global: {
mode: {type: String, enum: ['image', 'text', 'idle', 'music', 'clock']},
brightness: {type: Number},
mode: {type: String, enum: ['image', 'text', 'idle', 'music', 'clock'], default: 'idle'},
brightness: {type: Number, min: 0, max: 100, default: 50},
},
text: {
text: {type: String},
align: {type: String, enum: ['left', 'center', 'right']},
speed: {type: Number},
size: {type: Number},
color: {type: [Number]},
text: {type: String, default: ""},
align: {type: String, enum: ['left', 'center', 'right'], default: 'center'},
speed: {type: Number, min: 0, max: 10, default: 3},
size: {type: Number, min: 1, max: 64, default: 12},
color: {
type: [Number],
validate: {
validator: (v: number[]) =>
Array.isArray(v) && v.length === 3 && v.every(n => Number.isInteger(n) && n >= 0 && n <= 255),
message: "color must be an array of three integers between 0 and 255",
},
default: [255, 255, 255],
},
},
image: {
image: {type: String},
image: {type: String, default: ""},
},
clock: {
color: {type: [Number]},
color: {
type: [Number],
validate: {
validator: (v: number[]) =>
Array.isArray(v) && v.length === 3 && v.every(n => Number.isInteger(n) && n >= 0 && n <= 255),
message: "color must be an array of three integers between 0 and 255",
},
default: [255, 255, 255],
},
},
music: {
fullscreen: {type: Boolean},
fullscreen: {type: Boolean, default: false},
},
}, {_id: false});
@@ -87,14 +103,72 @@ const userConfigSchema = new Schema<UserConfig>({
}, {_id: false});
const userSchema = new Schema<IUser>({
name: {type: String, required: true},
password: {type: String, required: true},
uuid: {type: String, required: true},
name: {type: String, required: true, index: true},
password: {type: String, required: true, select: false},
uuid: {type: String, required: true, unique: true, index: true},
config: {type: userConfigSchema, required: true},
lastState: {type: matrixStateSchema},
spotifyConfig: {type: spotifyConfigSchema},
timezone: {type: String, required: true},
location: {type: String, required: true},
}, {optimisticConcurrency: true});
}, {
optimisticConcurrency: true,
timestamps: true,
toJSON: {
transform(_doc, ret) {
delete ret.password;
return ret;
},
},
toObject: {
transform(_doc, ret) {
delete ret.password;
return ret;
},
},
});
userSchema.virtual("id").get(function (this: any) {
return this._id?.toHexString?.() ?? this._id;
});
function isBcryptHash(value: unknown): boolean {
return typeof value === "string" && /^\$2[aby]\$\d{2}\$[./A-Za-z0-9]{53}$/.test(value);
}
async function hashIfNeeded(next: Function, user: any) {
if (!user.isModified?.("password")) return next();
if (isBcryptHash(user.password)) return next();
try {
user.password = await PasswordUtils.hashPassword(user.password)
return next();
} catch (e) {
return next(e);
}
}
userSchema.pre("save", function (next) {
// @ts-ignore
return hashIfNeeded(next, this);
});
userSchema.pre("findOneAndUpdate", async function (next) {
const update = this.getUpdate() as any;
if (!update) return next();
const newPassword = update.password ?? update.$set?.password;
if (!newPassword) return next();
if (isBcryptHash(newPassword)) return next();
try {
const saltRounds = 10;
const hashed = await bcrypt.hash(newPassword, saltRounds);
if (update.password) update.password = hashed;
if (update.$set?.password) update.$set.password = hashed;
return next();
} catch (e: Error | any) {
return next(e);
}
});
export const UserModel = mongoose.model<IUser>('User', userSchema);
+32 -5
View File
@@ -1,5 +1,6 @@
import {IUser, SpotifyConfig, UserModel} from "../../models/user";
import {connectToDatabase} from "./database.service";
import {UpdateQuery} from "mongoose";
export class UserService {
private static _instance: UserService;
@@ -23,20 +24,29 @@ export class UserService {
}
public async updateUser(user: IUser): Promise<IUser | null> {
const {id, ...rest} = user;
return this.updateUserById(id.toString(), rest);
const anyUser = user as any;
const targetId: string | undefined = anyUser?.id?.toString?.() ?? anyUser?._id?.toString?.();
if (!targetId) {
throw new Error("updateUser requires user.id or user._id");
}
const { id, _id, ...rest } = anyUser;
return this.updateUserById(targetId, rest as Partial<IUser>);
}
public async getAllUsers(): Promise<IUser[]> {
return await UserModel.find({}, {password: 0, spotifyConfig: 0, lastState: 0}).exec();
return await UserModel.find({}, {spotifyConfig: 0, lastState: 0}).exec();
}
public async getUserById(id: string): Promise<IUser | null> {
return await UserModel.findById(id, {password: 0}).exec();
return await UserModel.findById(id).exec();
}
public async getUserByUUID(uuid: string): Promise<IUser | null> {
return await UserModel.findOne({uuid}, {password: 0}).exec();
return await UserModel.findOne({uuid}).exec();
}
public async getUserByName(name: string): Promise<IUser | null> {
@@ -45,6 +55,14 @@ export class UserService {
.exec();
}
public async getUserAuthByName(name: string): Promise<IUser | null> {
return await UserModel.findOne({name})
.collation({locale: "en", strength: 2})
.select("+password")
.exec();
}
public async getSpotifyConfigByUUID(uuid: string): Promise<SpotifyConfig | undefined> {
return await UserModel.findOne({uuid}, {spotifyConfig: 1}).exec().then(user => user?.spotifyConfig);
}
@@ -60,4 +78,13 @@ export class UserService {
return !!(await UserModel.findOne({name}).exec());
}
public async clearSpotifyConfigByUUID(uuid: string): Promise<IUser | null> {
return await UserModel.findOneAndUpdate(
{ uuid },
{ $unset: { spotifyConfig: 1 } } as UpdateQuery<IUser>,
{ new: true, projection: { password: 0 } }
).exec();
}
}
+43 -9
View File
@@ -7,18 +7,37 @@ import {JwtTokenPropertiesExtractor} from "./rest/jwtTokenPropertiesExtractor";
import cors from "cors";
import {SpotifyTokenGenerator} from "./rest/spotifyTokenGenerator";
import {RestAuth} from "./rest/auth";
import { config } from "./config";
import {authLimiter, spotifyLimiter} from "./rest/middleware/rateLimit";
const app = express();
const port = process.env.PORT || 3000;
const port = config.port;
app.set("trust proxy", 1);
app.use(cors({
origin: config.cors.origin,
credentials: config.cors.credentials,
}));
app.use((_req, res, next) => {
res.set({
"X-DNS-Prefetch-Control": "off",
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "DENY",
"Referrer-Policy": "no-referrer",
"Permissions-Policy": "geolocation=()",
});
next();
});
app.use(express.json({limit: "2mb"}));
app.get("/healthz", (_req, res) => res.status(200).send({status: "ok"}));
const server = app.listen(port, () => {
console.log(`Server is running on port ${port}`);
});
app.use(cors({
origin: process.env.FRONTEND_URL,
}));
app.use(express.json({limit: "15mb"}));
const webSocketServer = new ExtendedWebSocketServer(server);
const restWebSocket = new RestWebSocket(webSocketServer);
@@ -27,6 +46,10 @@ const auth = new RestAuth();
const jwtTokenPropertiesExtractor = new JwtTokenPropertiesExtractor();
const spotify = new SpotifyTokenGenerator();
app.use("/api/auth", authLimiter, auth.createRouter());
app.use("/api/spotify", authenticateJwt, spotifyLimiter, spotify.createRouter());
app.use("/api/websocket", authenticateJwt, restWebSocket.createRouter());
app.use("/api/user", authenticateJwt, restUser.createRouter());
app.use(
@@ -34,6 +57,17 @@ app.use(
authenticateJwt,
jwtTokenPropertiesExtractor.createRouter(),
);
app.use("/api/spotify", authenticateJwt, spotify.createRouter());
app.use("/api/auth", auth.createRouter());
app.use((err: any, _req: express.Request, res: express.Response, _next: express.NextFunction) => {
console.error(err);
res
.status(err?.status || 500)
.send({ ok: false, data: {}, error: err?.message || "Internal Server Error" });
});
process.on("SIGTERM", () => {
server.close(() => {
console.log("HTTP server closed");
process.exit(0);
});
});
+68 -57
View File
@@ -1,79 +1,90 @@
import express from "express";
import {UserService} from "../db/services/db/UserService";
import {IUser} from "../db/models/user";
import {ObjectId} from "mongodb";
import {JwtAuthenticator} from "../utils/jwtAuthenticator";
import crypto from "crypto"
import crypto from "crypto";
import {PasswordUtils} from "../utils/passwordUtils";
import { asyncHandler } from "./middleware/asyncHandler";
import { validateBody, v } from "./middleware/validate";
import {ok, badRequest, unauthorized, created, conflict, notFound} from "./utils/responses";
export class RestAuth {
public createRouter() {
const router = express.Router();
router.post("/register", async (req, res) => {
const username = req.body.username;
const timezone = req.body.timezone;
const location = req.body.location;
const password = req.body.password;
const userService = await UserService.create();
router.post(
"/register",
validateBody({
username: { required: true, validator: v.isString({ nonEmpty: true, min: 3 }) },
password: { required: true, validator: v.isString({ nonEmpty: true, min: 8 }) },
timezone: { required: true, validator: v.isString({ nonEmpty: true }) },
location: { required: true, validator: v.isString({ nonEmpty: true }) },
}),
asyncHandler(async (req, res) => {
const { username, password, timezone, location } = req.body as {
username: string; password: string; timezone: string; location: string;
};
const userService = await UserService.create();
if (await userService.existsUserByName(username)) {
res.status(409).send({message: "Username already exists"});
return;
}
if (await userService.existsUserByName(username)) {
return conflict(res, "Username already exists");
}
const passwordValidation = PasswordUtils.validatePassword(password);
const passwordValidation = PasswordUtils.validatePassword(password);
if (!passwordValidation.valid) {
return badRequest(res, passwordValidation.message ?? "Invalid password");
}
if (!passwordValidation.valid) {
res.status(400).send({success: false, message: passwordValidation.message});
return;
}
const hashedPassword = await PasswordUtils.hashPassword(password);
const newUser: IUser = {
name: username,
password: hashedPassword,
uuid: crypto.randomUUID(),
config: {
isVisible: false,
isAdmin: false,
canBeModified: false
},
timezone,
location
};
const hashedPassword = await PasswordUtils.hashPassword(password);
const newUser: IUser = {
id: ObjectId.createFromTime(Date.now()),
name: username,
password: hashedPassword,
uuid: crypto.randomUUID(),
config: {
isVisible: false,
isAdmin: false,
canBeModified: false
},
timezone,
location
};
const result = await userService.createUser(newUser);
res.status(201).send({success: true, user: result});
});
const result = await userService.createUser(newUser);
return created(res, {user: result });
})
);
router.post("/login", async (req, res) => {
const username = req.body.username;
const password = req.body.password;
const userService = await UserService.create();
const user = await userService.getUserByName(username);
router.post(
"/login",
validateBody({
username: { required: true, validator: v.isString({ nonEmpty: true }) },
password: { required: true, validator: v.isString({ nonEmpty: true }) },
}),
asyncHandler(async (req, res) => {
const { username, password } = req.body as { username: string; password: string };
const userService = await UserService.create();
const user = await userService.getUserAuthByName(username);
if (!user) {
res.status(404).send({success: false, message: "User not found", id: "username"});
return;
}
if (!user) {
return notFound(res, "User not found");
}
const isValid = await PasswordUtils.comparePassword(password, user.password!);
const isValid = await PasswordUtils.comparePassword(password, user.password!);
if (!isValid) {
return unauthorized(res, "Invalid password");
}
if (!isValid) {
res.status(401).send({success: false, message: "Invalid password", id: "password"});
return;
}
// generate JWT token here
const jwtToken = new JwtAuthenticator(
process.env.SECRET_KEY!,
).generateToken({username: user.name, id: user.id.toString(), uuid: user.uuid});
res.status(200).send({success: true, token: jwtToken});
});
const jwtToken = new JwtAuthenticator(process.env.SECRET_KEY!)
.generateToken({
username: user.name,
id: (user as any).id?.toString?.() ?? (user as any)._id?.toString?.(),
uuid: user.uuid
});
return ok(res, { token: jwtToken });
})
);
return router;
}
}
}
+13 -10
View File
@@ -1,22 +1,25 @@
import express from "express";
import type { Request, Response } from "express";
import { asyncHandler } from "./middleware/asyncHandler";
import { ok } from "./utils/responses";
export class JwtTokenPropertiesExtractor {
public createRouter() {
const router = express.Router();
router.get("/id", (req, res) => {
res.status(200).send(req.payload.id);
});
router.get("/id", asyncHandler(async (req: Request, res: Response) => {
return ok(res, req.payload.id);
}));
router.get("/username", (req, res) => {
res.status(200).send(req.payload.username);
});
router.get("/uuid", (req, res) => {
res.status(200).send(req.payload.uuid);
});
router.get("/username", asyncHandler(async (req: Request, res: Response) => {
return ok(res, req.payload.username);
}));
router.get("/uuid", asyncHandler(async (req: Request, res: Response) => {
return ok(res, req.payload.uuid);
}));
return router;
}
}
}
+9
View File
@@ -0,0 +1,9 @@
import type { Request, Response, NextFunction, RequestHandler } from "express";
export function asyncHandler(
fn: (req: Request, res: Response, next: NextFunction) => Promise<any>
): RequestHandler {
return (req, res, next) => {
Promise.resolve(fn(req, res, next)).catch(next);
};
}
+23
View File
@@ -0,0 +1,23 @@
import rateLimit from "express-rate-limit";
import type { Request, Response } from "express";
import { tooManyRequests } from "../utils/responses";
const onLimitReached = (_req: Request, res: Response) => {
return tooManyRequests(res);
};
export const authLimiter = rateLimit({
windowMs: 60_000,
limit: 30,
standardHeaders: true,
legacyHeaders: false,
handler: onLimitReached,
});
export const spotifyLimiter = rateLimit({
windowMs: 60_000,
limit: 60,
standardHeaders: true,
legacyHeaders: false,
handler: onLimitReached,
});
+155
View File
@@ -0,0 +1,155 @@
import type { Request, Response, NextFunction } from "express";
import {badRequest} from "../utils/responses";
/**
* A type definition for a validation function.
* The `Validator` function type accepts a single argument `value`
* of any type and returns either `true` if the validation is successful
* or a `string` containing an error message if the validation fails.
*/
type Validator = (value: any) => true | string;
/**
* A collection of validation functions for validating various data types.
*
* @property {function(opts?: { nonEmpty?: boolean; max?: number; min?: number }): Validator} isString
* Validates whether a value is a string. Additional options can be provided:
* - `nonEmpty`: Ensures the string is not empty when set to `true`.
* - `max`: Specifies the maximum allowed string length.
* - `min`: Specifies the minimum required string length.
*
* @property {function(opts?: { min?: number; max?: number; integer?: boolean }): Validator} isNumber
* Validates whether a value is a number. Additional options can be provided:
* - `min`: Specifies the minimum allowed value.
* - `max`: Specifies the maximum allowed value.
* - `integer`: Ensures the value is an integer when set to `true`.
*
* @property {function(): Validator} isBoolean
* Validates whether a value is a boolean.
*
* @property {function<T extends readonly string[]>(values: T): Validator} isEnum
* Validates whether a value matches one of the provided allowed values in the enum.
* - `values`: An array of valid string options to check against.
*
* @property {function(len: number): Validator} isArrayLength
* Validates whether an array has the exact specified length.
* - `len`: The required array length.
*
* @property {function(): Validator} isUrl
* Validates whether a value is a valid URL. The value must be a string and conform to standard URL formatting rules.
*/
export const v = {
isString: (opts?: { nonEmpty?: boolean; max?: number; min?: number }): Validator => {
return (value: any) => {
if (typeof value !== "string") return "must be a string";
if (opts?.nonEmpty && value.trim().length === 0) return "must be a non-empty string";
if (opts?.max !== undefined && value.length > opts.max) return `must be at most ${opts.max} chars`;
if (opts?.min !== undefined && value.length < opts.min) return `must be at least ${opts.min} chars`;
return true;
};
},
isNumber: (opts?: { min?: number; max?: number; integer?: boolean }): Validator => {
return (value: any) => {
if (typeof value !== "number" || Number.isNaN(value)) return "must be a number";
if (opts?.integer && !Number.isInteger(value)) return "must be an integer";
if (opts?.min !== undefined && value < opts.min) return `must be >= ${opts.min}`;
if (opts?.max !== undefined && value > opts.max) return `must be <= ${opts.max}`;
return true;
};
},
isBoolean: (): Validator => {
return (value: any) => (typeof value === "boolean" ? true : "must be a boolean");
},
isEnum: <T extends readonly string[]>(values: T): Validator => {
return (value: any) => (values.includes(value) ? true : `must be one of: ${values.join(", ")}`);
},
isArrayLength: (len: number): Validator => {
return (value: any) => (Array.isArray(value) && value.length === len ? true : `must be an array of length ${len}`);
},
isUrl: (): Validator => {
return (value: any) => {
if (typeof value !== "string") return "must be a string URL";
try {
new URL(value);
return true;
} catch {
return "must be a valid URL";
}
};
},
};
/**
* Represents a schema definition for validating objects.
*
* Each key in the schema corresponds to a property name in the target object,
* with its value defining the validation rules for that property.
*
* @typedef {Object} Schema
* @property {boolean} [required] - Specifies if the property is mandatory in the target object.
* @property {Validator} validator - A function or object that validates the value of the property.
*/
type Schema = Record<string, { required?: boolean; validator: Validator }>;
/**
* Validates a given source object against a specified schema and returns an array of error messages.
*
* @param {any} source - The object to be validated.
* @param {Schema} schema - The schema containing validation rules for each property.
* @return {string[]} An array of error messages. If there are no validation errors, the array will be empty.
*/
function validate(source: any, schema: Schema): string[] {
const errors: string[] = [];
for (const [key, rule] of Object.entries(schema)) {
const value = source?.[key];
if (value === undefined || value === null) {
if (rule.required) errors.push(`${key} is required`);
continue;
}
const res = rule.validator(value);
if (res !== true) errors.push(`${key} ${res}`);
}
return errors;
}
/**
* Middleware to validate the request body against a specified schema.
*
* @param {Schema} schema The schema against which the request body will be validated.
* @return {Function} Express middleware function that validates the request body and invokes the next middleware if valid; otherwise, it responds with a 400 status and validation errors.
*/
export function validateBody(schema: Schema) {
return (req: Request, res: Response, next: NextFunction) => {
const errs = validate(req.body, schema);
if (errs.length) return badRequest(res, "Validation failed", errs);
next();
};
}
/**
* Middleware to validate the request parameters against the provided schema.
*
* @param {Schema} schema - The validation schema to check the request parameters.
* @return {(req: Request, res: Response, next: NextFunction) => void} A middleware function that validates the request parameters.
*/
export function validateParams(schema: Schema) {
return (req: Request, res: Response, next: NextFunction) => {
const errs = validate(req.params, schema);
if (errs.length) return res.status(400).send({ error: "Validation failed", details: errs });
next();
};
}
/**
* Middleware function to validate the query parameters of a request against a predefined schema.
*
* @param {Schema} schema - The validation schema used to validate the query parameters.
* @return {Function} Middleware function that validates the query parameters and either sends a 400 response with validation errors or proceeds to the next middleware.
*/
export function validateQuery(schema: Schema) {
return (req: Request, res: Response, next: NextFunction) => {
const errs = validate(req.query, schema);
if (errs.length) return res.status(400).send({ error: "Validation failed", details: errs });
next();
};
}
+91 -52
View File
@@ -1,77 +1,116 @@
import express from "express";
import {UserService} from "../db/services/db/UserService";
import {PasswordUtils} from "../utils/passwordUtils";
import {asyncHandler} from "./middleware/asyncHandler";
import {v, validateBody, validateParams} from "./middleware/validate";
import {badRequest, ok} from "./utils/responses";
export class RestUser {
public createRouter() {
const router = express.Router();
router.get("/", async (req, res) => {
router.get("/", asyncHandler(async (_req, res) => {
const userService = await UserService.create();
const users = await userService.getAllUsers();
res.status(200).send({users});
});
return ok(res, { users });
}));
router.get("/me", async (req, res) => {
router.get("/me", asyncHandler(async (req, res) => {
const userService = await UserService.create();
const user = await userService.getUserByUUID(req.payload.uuid);
res.status(200).send(user);
});
return ok(res, user);
}));
router.put("/me/spotify", async (req, res) => {
router.put(
"/me/spotify",
validateBody({
accessToken: { required: true, validator: v.isString({ nonEmpty: true }) },
refreshToken: { required: true, validator: v.isString({ nonEmpty: true }) },
scope: { required: true, validator: v.isString({ nonEmpty: true }) },
expirationDate: { required: true, validator: v.isString({ nonEmpty: true }) },
}),
asyncHandler(async (req, res) => {
const userService = await UserService.create();
const user = await userService.getUserByUUID(req.payload.uuid);
if (!user) {
return badRequest(res, "User not found");
}
const { accessToken, refreshToken, scope, expirationDate } = req.body as {
accessToken: string; refreshToken: string; scope: string; expirationDate: string;
};
user.spotifyConfig = {
accessToken,
refreshToken,
scope,
expirationDate: new Date(expirationDate),
};
await userService.updateUser(user);
return ok(res, { message: "Spotify Config erfolgreich geändert" });
})
);
router.delete("/me/spotify", asyncHandler(async (req, res) => {
const userService = await UserService.create();
const user = await userService.getUserByUUID(req.payload.uuid);
user!.spotifyConfig = req.body;
userService.updateUser(user!)
.then(() => {
res.status(200).send({result: {success: true, message: "Spotify Config erfolgreich geändert"}});
});
});
router.put("/me/password", async (req, res) => {
const userService = await UserService.create();
const user = await userService.getUserByUUID(req.payload.uuid);
const password = req.body.password;
const passwordConfirmation = req.body.passwordConfirmation;
if (password !== passwordConfirmation) {
res.status(400).send({
result: {
success: false,
message: "Passwörter stimmen nicht überein"
}
});
return;
if (!user) {
return badRequest(res, "User not found");
}
const passwordValidation = PasswordUtils.validatePassword(password);
const updated = await userService.clearSpotifyConfigByUUID(req.payload.uuid);
return ok(res, { user: updated });
}));
if (!passwordValidation.valid) {
res.status(400).send({result: passwordValidation});
return;
}
router.put(
"/me/password",
validateBody({
password: { required: true, validator: v.isString({ nonEmpty: true, min: 8 }) },
passwordConfirmation: { required: true, validator: v.isString({ nonEmpty: true, min: 8 }) },
}),
asyncHandler(async (req, res) => {
const userService = await UserService.create();
const user = await userService.getUserByUUID(req.payload.uuid);
if (!user) {
return badRequest(res, "User not found");
}
PasswordUtils.hashPassword(password).then(hashedPassword => {
user!.password = hashedPassword;
userService.updateUser(user!)
.then(() => {
res.status(200).send({result: {success: true, message: "Passwort erfolgreich geändert"}});
});
});
});
const { password, passwordConfirmation } = req.body as { password: string; passwordConfirmation: string };
router.get("/:id", async (req, res) => {
const userService = await UserService.create();
const id = req.params.id;
const user = await userService.getUserById(id);
if (password !== passwordConfirmation) {
return badRequest(res, "Passwörter stimmen nicht überein");
}
user
? res.status(200).send(user)
: res
.status(404)
.send(`Unable to find matching document with id: ${req.params.id}`);
});
const passwordValidation = PasswordUtils.validatePassword(password);
if (!passwordValidation.valid) {
return badRequest(res, passwordValidation.message ?? "Ungültiges Passwort");
}
user.password = await PasswordUtils.hashPassword(password);
await userService.updateUser(user);
return ok(res, { message: "Passwort erfolgreich geändert" });
})
);
router.get(
"/:id",
validateParams({
id: { required: true, validator: v.isString({ nonEmpty: true }) },
}),
asyncHandler(async (req, res) => {
const userService = await UserService.create();
const id = req.params.id;
const user = await userService.getUserById(id);
if (!user) {
return badRequest(res, `Unable to find matching document with id: ${id}`);
}
return ok(res, user);
})
);
return router;
}
}
}
+53 -33
View File
@@ -1,44 +1,64 @@
import express, { Request, Response, Router } from "express";
import express, { Router, Request, Response } from "express";
import { ExtendedWebSocketServer } from "../websocket";
import { DecodedToken } from "../interfaces/decodedToken";
import { asyncHandler } from "./middleware/asyncHandler";
import { validateBody, v } from "./middleware/validate";
import { ok } from "./utils/responses";
import {ExtendedWebSocket} from "../interfaces/extendedWebsocket";
export class RestWebSocket {
constructor(private webSocketServer: ExtendedWebSocketServer) {}
constructor(private webSocketServer: ExtendedWebSocketServer) {}
public createRouter(): Router {
const router = express.Router();
public createRouter(): Router {
const router = express.Router();
router.post("/broadcast", (req: Request, res: Response) => {
const payload: string = JSON.stringify(req.body.payload);
router.post(
"/broadcast",
validateBody({
payload: {
required: true,
// allow any json
validator: (_: unknown) => true,
},
}),
asyncHandler(async (req: Request, res: Response) => {
const payload: string = JSON.stringify(req.body.payload);
this.webSocketServer.broadcast(payload);
return ok(res, { status: "OK" });
})
);
this.webSocketServer.broadcast(payload);
router.post(
"/send-message",
validateBody({
payload: {
required: true,
validator: (_: unknown) => true,
},
users: {
required: true,
validator: (value: any) =>
Array.isArray(value) && value.length > 0 && value.every((s) => typeof s === "string" && s.trim().length > 0)
? true
: "must be a non-empty array of strings",
},
}),
asyncHandler(async (req: Request, res: Response) => {
const payload = JSON.stringify(req.body.payload);
const users: Array<string> = req.body.users;
res.status(200).send("OK");
});
users.forEach((user) => this.webSocketServer.sendMessageToUser(user, payload));
router.post("/send-message", (req, res) => {
const payload = JSON.stringify(req.body.payload);
const users: Array<string> = req.body.users;
return ok(res, { status: "OK" });
})
);
users.forEach((user) =>
this.webSocketServer.sendMessageToUser(user, payload),
);
router.get("/all-clients", asyncHandler(async (_req: Request, res: Response) => {
const connectedClients = this.webSocketServer.getConnectedClients();
const result = Array.from(connectedClients).map((client: ExtendedWebSocket) => client.payload);
return ok(res, { result });
}));
res.status(200).send("OK");
});
router.get("/all-clients", (req, res) => {
const connectedClients = this.webSocketServer.getConnectedClients();
const result: Array<DecodedToken> = [];
connectedClients.forEach((client) => result.push(client.payload));
console.log("Connected clients:", result);
res.status(200).send({ result });
});
return router;
}
}
return router;
}
}
+36 -21
View File
@@ -1,35 +1,50 @@
import express from "express";
import {SpotifyTokenService} from "../db/services/spotifyTokenService";
import {UserService} from "../db/services/db/UserService";
import { asyncHandler } from "./middleware/asyncHandler";
import { validateBody, v } from "./middleware/validate";
import { ok, internalError } from "./utils/responses";
export class SpotifyTokenGenerator {
private tokenService = new SpotifyTokenService();
public createRouter() {
const router = express.Router();
router.get("/token/refresh/:refresh_token", async (req, res) => {
const refreshToken = req.params.refresh_token;
router.post(
"/token/refresh",
validateBody({
refreshToken: { required: true, validator: v.isString({ nonEmpty: true }) },
}),
asyncHandler(async (req, res) => {
const { refreshToken } = req.body as { refreshToken: string };
const token = await new SpotifyTokenService().refreshToken(refreshToken);
const token = await this.tokenService.refreshToken(refreshToken);
res.status(200).send({token});
});
router.get(
"/token/generate/code/:auth_code/redirect-uri/:redirect_uri",
async (req, res) => {
const authCode = req.params.auth_code;
const redirectUri = req.params.redirect_uri;
const token = await new SpotifyTokenService().generateToken(
authCode,
redirectUri,
);
res.status(200).send({token});
},
return ok(res, { token });
})
);
router.post(
"/token/generate",
validateBody({
authCode: { required: true, validator: v.isString({ nonEmpty: true }) },
redirectUri: { required: true, validator: v.isUrl() },
}),
asyncHandler(async (req, res) => {
const { authCode, redirectUri } = req.body as { authCode: string; redirectUri: string };
const token = await this.tokenService.generateToken(authCode, redirectUri);
return ok(res, { token });
})
);
router.use((err: any, _req: any, res: any, _next: any) => {
return internalError(res, "Failed to handle spotify token request");
});
return router;
}
}
}
+54
View File
@@ -0,0 +1,54 @@
import type { Response } from "express";
type ErrorDetails = unknown;
function respondError(
res: Response,
status: number,
message: string,
details?: ErrorDetails
) {
return res.status(status).send({
ok: false,
data: {
message,
details,
},
});
}
export function ok<T>(res: Response, data: T) {
return res.status(200).send({ ok: true, data });
}
export function created<T>(res: Response, data: T) {
return res.status(201).send({ ok: true, data });
}
export function badRequest(res: Response, message = "Bad Request", details?: ErrorDetails) {
return respondError(res, 400, message, details);
}
export function unauthorized(res: Response, message = "Unauthorized", details?: ErrorDetails) {
return respondError(res, 401, message, details);
}
export function forbidden(res: Response, message = "Forbidden", details?: ErrorDetails) {
return respondError(res, 403, message, details);
}
export function notFound(res: Response, message = "Not Found", details?: ErrorDetails) {
return respondError(res, 404, message, details);
}
export function conflict(res: Response, message = "Conflict", details?: ErrorDetails) {
return respondError(res, 409, message, details);
}
export function tooManyRequests(res: Response, message = "Too Many Requests", details?: ErrorDetails) {
return respondError(res, 429, message, details);
}
export function internalError(res: Response, message = "Internal Server Error", details?: ErrorDetails) {
return respondError(res, 500, message, details);
}
-1
View File
@@ -15,7 +15,6 @@ export class PasswordUtils {
}
public static async comparePassword(password: string, hashedPassword: string): Promise<boolean> {
const bcrypt = await import('bcrypt');
return bcrypt.compare(password, hashedPassword);
}
+42
View File
@@ -0,0 +1,42 @@
import { describe, it, expect } from "vitest";
import request from "supertest";
import express from "express";
import { authLimiter, spotifyLimiter } from "../../../src/rest/middleware/rateLimit";
function createTestApp() {
const app = express();
app.set("trust proxy", 1);
app.get("/auth-test", authLimiter, (_req, res) => res.status(200).send({ ok: true }));
app.get("/spotify-test", spotifyLimiter, (_req, res) => res.status(200).send({ ok: true }));
return app;
}
async function hit(app: express.Express, path: string, times: number) {
for (let i = 0; i < times; i++) {
await request(app).get(path);
}
}
describe("RateLimit", () => {
it("limits /auth-test after 30 Requests, returns http 429", async () => {
const app = createTestApp();
// 30 are allowed
await hit(app, "/auth-test", 30);
// afterwards, any request returns 429
const res = await request(app).get("/auth-test");
expect(res.status).toBe(429);
expect(res.headers["ratelimit-policy"]).toBeTruthy();
});
it("limits /spotify-test after 60 requests, returns http 429", async () => {
const app = createTestApp();
await hit(app, "/spotify-test", 60);
const res = await request(app).get("/spotify-test");
expect(res.status).toBe(429);
});
});
+206
View File
@@ -0,0 +1,206 @@
import { describe, it, expect, vi } from "vitest";
import { v, validateBody, validateParams, validateQuery } from "../../../src/rest/middleware/validate";
describe("v.isString", () => {
it("accepts a simple string", () => {
const res = v.isString()("abc");
expect(res).toBe(true);
});
it("rejects non strings", () => {
const res = v.isString()(123 as any);
expect(res).toBe("must be a string");
});
it("forces nonEmpty", () => {
const res = v.isString({ nonEmpty: true })(" ");
expect(res).toBe("must be a non-empty string");
});
it("checks min/max length", () => {
expect(v.isString({ min: 3 })("ab")).toBe("must be at least 3 chars");
expect(v.isString({ max: 3 })("abcd")).toBe("must be at most 3 chars");
expect(v.isString({ min: 2, max: 4 })("abc")).toBe(true);
});
});
describe("v.isNumber", () => {
it("accepts numbers, rejects NaN", () => {
expect(v.isNumber()(10)).toBe(true);
expect(v.isNumber()(Number.NaN)).toBe("must be a number");
expect(v.isNumber()("10")).toBe("must be a number");
});
it("checks integer min/max", () => {
expect(v.isNumber({ integer: true })(1.2)).toBe("must be an integer");
expect(v.isNumber({ min: 5 })(4)).toBe("must be >= 5");
expect(v.isNumber({ max: 5 })(6)).toBe("must be <= 5");
expect(v.isNumber({ integer: true, min: 0, max: 10 })(10)).toBe(true);
});
});
describe("v.isBoolean", () => {
it("accepts true/false", () => {
expect(v.isBoolean()(true)).toBe(true);
expect(v.isBoolean()(false)).toBe(true);
});
it("rejects other types", () => {
expect(v.isBoolean()("true" )).toBe("must be a boolean");
expect(v.isBoolean()(0)).toBe("must be a boolean");
});
});
describe("v.isEnum", () => {
it("accepts only correct values", () => {
const validator = v.isEnum(["A", "B", "C"] as const);
expect(validator("A")).toBe(true);
expect(validator("D")).toBe("must be one of: A, B, C");
});
});
describe("v.isArrayLength", () => {
it("checks the exact length", () => {
expect(v.isArrayLength(2)([1, 2])).toBe(true);
expect(v.isArrayLength(2)([1])).toBe("must be an array of length 2");
expect(v.isArrayLength(2)("not array")).toBe("must be an array of length 2");
});
});
describe("v.isUrl", () => {
it("accepts valid urls", () => {
expect(v.isUrl()("https://example.com")).toBe(true);
});
it("rejects invalid URLs and nom-strings", () => {
expect(v.isUrl()("notaurl")).toBe("must be a valid URL");
expect(v.isUrl()(123 as any)).toBe("must be a string URL");
});
});
describe("validateBody Middleware", () => {
const schema = {
name: { required: true, validator: v.isString({ nonEmpty: true }) },
age: { required: false, validator: v.isNumber({ min: 0, integer: true }) },
};
function makeRes() {
const res: any = {};
res.status = vi.fn().mockReturnValue(res);
res.send = vi.fn().mockReturnValue(res);
return res;
}
it("calls next(), when valid", async () => {
const req: any = { body: { name: "Alice", age: 30 } };
const res = makeRes();
const next = vi.fn();
const mw = validateBody(schema);
mw(req, res, next);
expect(next).toHaveBeenCalledOnce();
expect(res.status).not.toHaveBeenCalled();
expect(res.send).not.toHaveBeenCalled();
});
it("sends 400 with missing/incorrect fields", async () => {
const req: any = { body: { name: " ", age: -1 } };
const res = makeRes();
const next = vi.fn();
const mw = validateBody(schema);
mw(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(400);
expect(res.send).toHaveBeenCalledWith(
expect.objectContaining({
error: "Validation failed",
details: expect.arrayContaining([
"name must be a non-empty string",
"age must be >= 0",
]),
})
);
});
it("correctly checks missing fields", async () => {
const req: any = { body: { /* no name */ } };
const res = makeRes();
const next = vi.fn();
const mw = validateBody(schema);
mw(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(400);
expect(res.send).toHaveBeenCalledWith(
expect.objectContaining({
error: "Validation failed",
details: expect.arrayContaining(["name is required"]),
})
);
});
});
describe("validateParams and validateQuery middleware", () => {
const paramSchema = {
id: { required: true, validator: v.isString({ nonEmpty: true }) },
};
const querySchema = {
limit: { required: false, validator: v.isNumber({ integer: true, min: 1, max: 100 }) },
};
function makeRes() {
const res: any = {};
res.status = vi.fn().mockReturnValue(res);
res.send = vi.fn().mockReturnValue(res);
return res;
}
it("validateParams: ok -> next()", () => {
const req: any = { params: { id: "abc" } };
const res = makeRes();
const next = vi.fn();
validateParams(paramSchema)(req, res, next);
expect(next).toHaveBeenCalledOnce();
});
it("validateParams: error -> 400", () => {
const req: any = { params: { id: " " } };
const res = makeRes();
const next = vi.fn();
validateParams(paramSchema)(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(400);
expect(res.send).toHaveBeenCalled();
});
it("validateQuery: ok without limit", () => {
const req: any = { query: { } };
const res = makeRes();
const next = vi.fn();
validateQuery(querySchema)(req, res, next);
expect(next).toHaveBeenCalledOnce();
});
it("validateQuery: error on limit outside range", () => {
const req: any = { query: { limit: 101 } };
const res = makeRes();
const next = vi.fn();
validateQuery(querySchema)(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(400);
expect(res.send).toHaveBeenCalledWith(
expect.objectContaining({
error: "Validation failed",
details: expect.arrayContaining(["limit must be <= 100"]),
})
);
});
});
+62
View File
@@ -0,0 +1,62 @@
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
vi.mock("jsonwebtoken", () => {
return {
default: {
verify: vi.fn(),
sign: vi.fn(),
},
};
});
import jwt from "jsonwebtoken";
import { JwtAuthenticator } from "../../src/utils/jwtAuthenticator";
describe("JwtAuthenticator", () => {
const secret = "test-secret";
let auth: JwtAuthenticator;
beforeEach(() => {
auth = new JwtAuthenticator(secret);
vi.clearAllMocks();
});
afterEach(() => {
vi.restoreAllMocks();
});
it("verifyToken returns null when no token is passed", () => {
expect(auth.verifyToken(undefined)).toBeNull();
expect(jwt.verify).not.toHaveBeenCalled();
});
it("verifyToken returns DecodedToken when verify was successful ", () => {
const payload = { username: "alice", id: "1", uuid: "u-1" };
(jwt.verify as any).mockReturnValue(payload);
const res = auth.verifyToken("valid.jwt.token");
expect(jwt.verify).toHaveBeenCalledWith("valid.jwt.token", secret);
expect(res).toEqual(payload);
});
it("verifyToken returns null when verify throws error", () => {
const spy = vi.spyOn(console, "error").mockImplementation(() => {});
(jwt.verify as any).mockImplementation(() => {
throw new Error("invalid");
});
const res = auth.verifyToken("broken.token");
expect(res).toBeNull();
expect(spy).toHaveBeenCalled();
spy.mockRestore();
});
it("generateToken signs payload with secret", () => {
(jwt.sign as any).mockReturnValue("signed.jwt");
const payload = { username: "bob" } as any;
const token = auth.generateToken(payload);
expect(jwt.sign).toHaveBeenCalledWith(payload, secret);
expect(token).toBe("signed.jwt");
});
});
+83
View File
@@ -0,0 +1,83 @@
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
const { hashMock, compareMock } = vi.hoisted(() => ({
hashMock: vi.fn(),
compareMock: vi.fn(),
}));
vi.mock("bcrypt", () => {
return {
hash: hashMock,
compare: compareMock,
default: {
hash: hashMock,
compare: compareMock,
},
};
});
import { PasswordUtils } from "../../src/utils/passwordUtils";
describe("PasswordUtils", () => {
beforeEach(() => {
hashMock.mockReset();
compareMock.mockReset();
});
afterEach(() => {
vi.restoreAllMocks();
});
it("hashPassword uses bcrypt.hash with 10 saltrounds", async () => {
hashMock.mockResolvedValue("hashed");
const res = await PasswordUtils.hashPassword("secret");
expect(hashMock).toHaveBeenCalledWith("secret", 10);
expect(res).toBe("hashed");
});
it("comparePassword uses bcrypt.compare", async () => {
compareMock.mockResolvedValue(true);
const ok = await PasswordUtils.comparePassword("secret", "hashed");
expect(compareMock).toHaveBeenCalledWith("secret", "hashed");
expect(ok).toBe(true);
});
describe("validatePassword", () => {
it("fails when password too short", () => {
const res = PasswordUtils.validatePassword("A1!");
expect(res.valid).toBe(false);
expect(res.message).toMatch(/mindestens 8 Zeichen/);
});
it("fails without capital letter", () => {
const res = PasswordUtils.validatePassword("password1!");
expect(res.valid).toBe(false);
expect(res.message).toMatch(/Großbuchstaben/);
});
it("fails without uncapitalized letter", () => {
const res = PasswordUtils.validatePassword("PASSWORD1!");
expect(res.valid).toBe(false);
expect(res.message).toMatch(/Kleinbuchstaben/);
});
it("fails without number", () => {
const res = PasswordUtils.validatePassword("Password!");
expect(res.valid).toBe(false);
expect(res.message).toMatch(/Zahl/);
});
it("fails without special characters", () => {
const res = PasswordUtils.validatePassword("Password1");
expect(res.valid).toBe(false);
expect(res.message).toMatch(/Sonderzeichen/);
});
it("accepts valid password", () => {
const res = PasswordUtils.validatePassword("ValidPassword1!");
expect(res.valid).toBe(true);
});
});
});
+80
View File
@@ -0,0 +1,80 @@
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
vi.mock("../../src/utils/jwtAuthenticator", () => {
return {
JwtAuthenticator: vi.fn().mockImplementation(() => ({
verifyToken: mockVerifyToken,
})),
};
});
const mockVerifyToken = vi.fn();
import type { IncomingMessage } from "node:http";
import { verifyClient } from "../../src/utils/verifyClient";
describe("verifyClient", () => {
const cb = vi.fn();
let consoleSpy: ReturnType<typeof vi.spyOn>;
function makeReq(authHeader?: string) {
const headers: Record<string, string> = {};
if (authHeader) headers["authorization"] = authHeader;
// socket infos just for log
const socket: any = { remoteAddress: "127.0.0.1", remotePort: 12345 };
return { headers, socket } as unknown as IncomingMessage & { [k: string]: any };
}
beforeEach(() => {
cb.mockReset();
mockVerifyToken.mockReset();
consoleSpy = vi.spyOn(console, "log").mockImplementation(() => {});
});
afterEach(() => {
consoleSpy.mockRestore();
});
it("accepts connections with valid token and sets payload", () => {
const req = makeReq("Bearer valid.jwt");
mockVerifyToken.mockReturnValue({ sub: "user-1" });
verifyClient(req, cb);
expect(mockVerifyToken).toHaveBeenCalledWith("valid.jwt");
expect(cb).toHaveBeenCalledWith(true);
expect((req as any).payload).toEqual({ sub: "user-1" });
});
it("Rejects connection if no Authorization header is set", () => {
const req = makeReq(undefined);
mockVerifyToken.mockReturnValue(null);
verifyClient(req, cb);
expect(cb).toHaveBeenCalledWith(false, 401, "Unauthorized");
expect(consoleSpy).toHaveBeenCalled();
});
it("rejects connection, if token is invalid", () => {
const req = makeReq("Bearer bad.jwt");
mockVerifyToken.mockReturnValue(null);
verifyClient(req, cb);
expect(mockVerifyToken).toHaveBeenCalledWith("bad.jwt");
expect(cb).toHaveBeenCalledWith(false, 401, "Unauthorized");
});
it("extracts token correctly after 'Bearer ' prefix", () => {
const expectedToken = " fancy.token.with.spaces ";
const req = makeReq(`Bearer ${expectedToken}`);
mockVerifyToken.mockReturnValue({ ok: true });
verifyClient(req, cb);
expect(mockVerifyToken).toHaveBeenCalledWith(expectedToken);
expect(cb).toHaveBeenCalledWith(true);
});
});
@@ -0,0 +1,40 @@
import { describe, it, expect, vi } from "vitest";
import { getEventListeners } from "../../../../src/utils/websocket/websocketCustomEvents/websocketEventUtils";
import { WebsocketEventType } from "../../../../src/utils/websocket/websocketCustomEvents/websocketEventType";
describe("websocketEventUtils.getEventListeners", () => {
function makeWs() {
return {
user: {
timezone: "Europe/Berlin",
lastState: { global: { mode: "idle", brightness: 42 } },
},
send: vi.fn(),
};
}
it("returns a list of event-handlers incl. GET_STATE/GET_SETTINGS", async () => {
const ws: any = makeWs();
const listeners = getEventListeners(ws);
expect(Array.isArray(listeners)).toBe(true);
expect(listeners.length).toBeGreaterThan(0);
const byType = Object.fromEntries(listeners.map(l => [l.event, l]));
expect(byType[WebsocketEventType.GET_STATE]).toBeTruthy();
byType[WebsocketEventType.GET_STATE].handler({});
expect(ws.send).toHaveBeenCalledWith(
JSON.stringify({ type: "STATE", payload: ws.user.lastState }),
{ binary: false },
);
ws.send.mockClear();
expect(byType[WebsocketEventType.GET_SETTINGS]).toBeTruthy();
byType[WebsocketEventType.GET_SETTINGS].handler({});
expect(ws.send).toHaveBeenCalledWith(
JSON.stringify({ type: "SETTINGS", payload: { timezone: ws.user.timezone } }),
{ binary: false },
);
});
});
@@ -0,0 +1,112 @@
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
const { heartbeatSpy, getUserByUUID } = vi.hoisted(() => ({
heartbeatSpy: vi.fn(),
getUserByUUID: vi.fn(),
}));
vi.mock("../../../src/utils/websocket/websocketServerHeartbeatInterval", () => {
return {
heartbeat: () => heartbeatSpy,
};
});
const userObj = {
name: "tester",
uuid: "uuid-1",
timezone: "Europe/Berlin",
location: "Berlin",
lastState: { global: { mode: "idle", brightness: 50 } },
};
vi.mock("../../../src/db/services/db/UserService", () => {
return {
UserService: {
create: vi.fn().mockResolvedValue({
getUserByUUID,
}),
},
};
});
class FakeWSS {
clients = new Set<any>();
handlers = new Map<string, Function>();
on(event: string, handler: Function) {
this.handlers.set(event, handler);
}
emit(event: string, ...args: any[]) {
const h = this.handlers.get(event);
if (h) h(...args);
}
}
import { WebsocketServerEventHandler } from "../../../src/utils/websocket/websocketServerEventHandler";
describe("WebsocketServerEventHandler", () => {
let wss: FakeWSS;
beforeEach(() => {
wss = new FakeWSS();
heartbeatSpy.mockReset();
getUserByUUID.mockReset();
getUserByUUID.mockResolvedValue(userObj);
});
afterEach(() => {
vi.restoreAllMocks();
});
it("enableConnectionEvent sets user/payload/isAlive/asyncUpdates and calls callback", async () => {
const handler = new WebsocketServerEventHandler(wss as any);
const cb = vi.fn();
const done = new Promise<void>((resolve) => {
cb.mockImplementation(() => resolve());
});
handler.enableConnectionEvent(cb);
const req = { payload: { uuid: "uuid-1" } };
const ws: any = {};
wss.emit("connection", ws, req);
await done;
expect(getUserByUUID).toHaveBeenCalledWith("uuid-1");
expect(ws.user).toEqual(userObj);
expect(ws.payload).toEqual(req.payload);
expect(ws.isAlive).toBe(true);
expect(ws.asyncUpdates).toBeInstanceOf(Map);
expect(cb).toHaveBeenCalledWith(ws, req);
});
it("enableHeartbeat starts interval and calls heartbeat()", () => {
vi.useFakeTimers();
const handler = new WebsocketServerEventHandler(wss as any);
const id = handler.enableHeartbeat(1000);
expect(["number", "object"]).toContain(typeof id);
vi.advanceTimersByTime(3000);
expect(heartbeatSpy).toHaveBeenCalledTimes(3);
clearInterval(id);
vi.useRealTimers();
});
it("enableCloseEvent registers Listener and calls callback on close", () => {
const handler = new WebsocketServerEventHandler(wss as any);
const cb = vi.fn();
const logSpy = vi.spyOn(console, "log").mockImplementation(() => {});
handler.enableCloseEvent(cb);
wss.emit("close");
expect(cb).toHaveBeenCalledTimes(1);
expect(logSpy).toHaveBeenCalledWith("WebSocket server closed");
logSpy.mockRestore();
});
});
@@ -0,0 +1,46 @@
import {describe, it, expect, vi, beforeEach, afterEach} from "vitest";
import {heartbeat} from "../../../src/utils/websocket/websocketServerHeartbeatInterval";
describe("heartbeat(wss)", () => {
let consoleSpy: ReturnType<typeof vi.spyOn>;
beforeEach(() => {
consoleSpy = vi.spyOn(console, "log").mockImplementation(() => {
});
});
afterEach(() => {
consoleSpy.mockRestore();
});
function makeClient({
isAlive,
username,
}: {
isAlive: boolean;
username: string;
}) {
return {
isAlive,
payload: {username},
ping: vi.fn(),
terminate: vi.fn(),
} as any;
}
it("terminated dead clients and pings alive ones, sets isAlive to false", () => {
const alive = makeClient({isAlive: true, username: "alive-user"});
const dead = makeClient({isAlive: false, username: "dead-user"});
const wss = {
clients: new Set<any>([alive, dead]),
} as any;
const hb = heartbeat(wss);
hb();
expect(dead.terminate).toHaveBeenCalledTimes(1);
expect(alive.ping).toHaveBeenCalledTimes(1);
expect(alive.isAlive).toBe(false);
});
});