Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
"@nestjs/mongoose": "^11.0.3",
"@nestjs/passport": "^11.0.5",
"@nestjs/platform-express": "^8.0.0",
"@nestjs/terminus": "10.0.1",
"@nestjs/throttler": "^6.4.0",
"@types/bcrypt": "^5.0.2",
"@types/passport-jwt": "^4.0.1",
"apollo-server-express": "^3.13.0",
Expand All @@ -43,6 +45,7 @@
"graphql": "^16.11.0",
"graphql-subscriptions": "^3.0.0",
"graphql-upload-ts": "^2.1.2",
"helmet": "^8.1.0",
"husky": "^9.1.7",
"moment": "^2.30.1",
"mongoose": "^8.15.0",
Expand Down
39 changes: 34 additions & 5 deletions src/app.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ import { Module } from '@nestjs/common';
import { GraphQLModule } from '@nestjs/graphql';
import { ApolloDriver, ApolloDriverConfig } from '@nestjs/apollo';
import { MongooseModule } from '@nestjs/mongoose';
import { ThrottlerModule } from '@nestjs/throttler';
import { PubSub } from 'graphql-subscriptions';
import { join } from 'path';
import * as dotenv from 'dotenv';
import { APP_GUARD } from '@nestjs/core';

import { UsersModule } from './modules/users/users.module';
import { HealthModule } from './modules/health/health.module';
import { GqlThrottlerGuard } from './common/guards/gql-throttler.guard';

const pubSub = new PubSub();
dotenv.config();
Expand All @@ -16,20 +20,45 @@ dotenv.config();
GraphQLModule.forRoot<ApolloDriverConfig>({
driver: ApolloDriver,
installSubscriptionHandlers: true,
context: ({ req, connection }) => ({
context: ({ req, res, connection }) => ({
req,
res,
connection,
pubSub,
}),
autoSchemaFile: join(process.cwd(), 'src/graphql/schema.gql'),
playground: true,
formatError: (error) => {
return error
}
return error;
},
}),

MongooseModule.forRoot(process.env.MONGO_URI),
UsersModule
ThrottlerModule.forRoot([
{
name: 'short',
ttl: 1000,
limit: 3,
},
{
name: 'medium',
ttl: 10000,
limit: 20,
},
{
name: 'long',
ttl: 60000,
limit: 100,
},
]),
UsersModule,
HealthModule,
],
providers: [
{
provide: APP_GUARD,
useClass: GqlThrottlerGuard,
},
],
})
export class AppModule { }
export class AppModule {}
121 changes: 121 additions & 0 deletions src/common/guards/gql-throttler.guard.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import { ExecutionContext, Injectable, Logger } from '@nestjs/common';
import { GqlContextType, GqlExecutionContext } from '@nestjs/graphql';
import { ThrottlerGuard } from '@nestjs/throttler';
import { Reflector } from '@nestjs/core';
import { ThrottlerModuleOptions, ThrottlerStorage } from '@nestjs/throttler';

interface RequestWithIP extends Request {
ip?: string;
headers: any;
connection?: {
remoteAddress?: string;
};
socket?: {
remoteAddress?: string;
};
}

@Injectable()
export class GqlThrottlerGuard extends ThrottlerGuard {
private logger = new Logger(GqlThrottlerGuard.name);

// List of GraphQL fields that should bypass throttling
private skipThrottleFields = [
'publicUserStats',
// Add more fields here as needed, for example:
// 'getPublicConfig',
// 'healthCheck',
];

constructor(
options: ThrottlerModuleOptions,
storageService: ThrottlerStorage,
reflector: Reflector,
) {
super(options, storageService, reflector);
}

getRequestResponse(context: ExecutionContext) {
const gqlCtx = GqlExecutionContext.create(context);
const ctx = gqlCtx.getContext();
return { req: ctx.req, res: ctx.res };
}

async canActivate(context: ExecutionContext): Promise<boolean> {
// For non-GraphQL requests, use the parent implementation
if (context.getType<GqlContextType>() !== 'graphql') {
return super.canActivate(context);
}

// For GraphQL requests, get the resolver and field name
const gqlContext = GqlExecutionContext.create(context);
const info = gqlContext.getInfo();
const fieldName = info?.fieldName;

this.logger.debug(`GraphQL field: ${fieldName}`);

// Special case: Manually check for fields that should skip throttling
// This is a workaround since the @SkipThrottle decorator metadata isn't being detected
if (this.skipThrottleFields.includes(fieldName)) {
this.logger.debug(`Explicitly skipping throttle for ${fieldName}`);
return true;
}

// Get the parent class (Resolver) and method (Query/Mutation)
const handler = context.getHandler();
const classRef = context.getClass();

// Check for SkipThrottle at both method and class level
const methodSkipThrottle = this.reflector.get('skipThrottle', handler);
const classSkipThrottle = this.reflector.get('skipThrottle', classRef);

this.logger.debug(
`Method skipThrottle: ${JSON.stringify(methodSkipThrottle)}`,
);
this.logger.debug(
`Class skipThrottle: ${JSON.stringify(classSkipThrottle)}`,
);

// If method explicitly sets skipThrottle
if (methodSkipThrottle !== undefined) {
const shouldSkip = this.shouldSkipThrottle(methodSkipThrottle);
if (shouldSkip) {
this.logger.debug(
`Skipping throttle for ${fieldName} due to method decorator`,
);
return true;
}
}

// If class has skipThrottle and method doesn't override it
if (classSkipThrottle !== undefined && methodSkipThrottle === undefined) {
const shouldSkip = this.shouldSkipThrottle(classSkipThrottle);
if (shouldSkip) {
this.logger.debug(
`Skipping throttle for ${fieldName} due to class decorator`,
);
return true;
}
}

// Apply throttling
return super.canActivate(context);
}

private shouldSkipThrottle(skipThrottle: any): boolean {
if (skipThrottle === true) {
return true;
}

if (typeof skipThrottle === 'object') {
// Check if any throttler should be skipped
for (const key in skipThrottle) {
if (skipThrottle[key] === true) {
return true;
}
}
}

return false;
}
}
26 changes: 15 additions & 11 deletions src/config/app.config.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import { join } from 'path';
import * as dotenv from 'dotenv';

dotenv.config();

export const AppConfig = {
port: process.env.PORT || 3000,
graphqlUpload: {
maxFileSize: 10000000,
maxFiles: 10,
},
staticFiles: {
uploadsPath: join(__dirname, '..', '..', 'uploads'),
},
cors: true,
bodyParser: true,
};
port: process.env.PORT || 3000,
env: process.env.NODE_ENV || 'development',
graphqlUpload: {
maxFileSize: 10000000,
maxFiles: 10,
},
staticFiles: {
uploadsPath: join(__dirname, '..', '..', 'uploads'),
},
cors: true,
bodyParser: true,
};
8 changes: 8 additions & 0 deletions src/graphql/schema.gql
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,12 @@ scalar DateTime

type Query {
users: [User!]!
publicUserStats: String!

"""Health check endpoint"""
healthCheck: String!
}

type Mutation {
login(email: String!, password: String!): String!
}
28 changes: 24 additions & 4 deletions src/main.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { NestFactory } from '@nestjs/core';
import { ValidationPipe } from '@nestjs/common';
import { ValidationPipe, Logger } from '@nestjs/common';
import { ExpressAdapter } from '@nestjs/platform-express';
import { NestExpressApplication } from '@nestjs/platform-express';
import * as express from 'express';
import { graphqlUploadExpress } from 'graphql-upload-ts';
import helmet from 'helmet';

import { AppModule } from './app.module';
import { AppConfig } from './config/app.config';
Expand All @@ -15,17 +16,36 @@ async function bootstrap() {
{
cors: AppConfig.cors,
bodyParser: AppConfig.bodyParser,
}
logger: ['debug', 'error', 'warn', 'log'],
},
);

app.use(
'/graphql',
graphqlUploadExpress(AppConfig.graphqlUpload),
helmet({
contentSecurityPolicy: false,
crossOriginEmbedderPolicy: false,
crossOriginOpenerPolicy: false,
crossOriginResourcePolicy: false,
}),
);

app.use('/graphql', graphqlUploadExpress(AppConfig.graphqlUpload));

app.use((req, res, next) => {
if (req.path !== '/graphql') {
helmet()(req, res, next);
} else {
next();
}
});

app.useGlobalPipes(new ValidationPipe({ transform: true }));
app.use('/uploads', express.static(AppConfig.staticFiles.uploadsPath));

const logger = new Logger('Bootstrap');
await app.listen(AppConfig.port);
logger.log(`Application is running on port ${AppConfig.port}`);
logger.log(`GraphQL endpoint available at /graphql`);
}
bootstrap();
bootstrap();
11 changes: 11 additions & 0 deletions src/modules/health/health.module.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { Module } from '@nestjs/common';
import { TerminusModule } from '@nestjs/terminus';
import { MongooseModule } from '@nestjs/mongoose';
import { HealthResolver } from './health.resolver';
import { HealthService } from './health.service';

@Module({
imports: [TerminusModule, MongooseModule.forFeature([])],
providers: [HealthResolver, HealthService],
})
export class HealthModule {}
31 changes: 31 additions & 0 deletions src/modules/health/health.resolver.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import { Query, Resolver } from '@nestjs/graphql';
import { HealthService } from './health.service';
import { HealthCheckService, MongooseHealthIndicator } from '@nestjs/terminus';
import { SkipThrottle } from '@nestjs/throttler';

@Resolver('Health')
export class HealthResolver {
constructor(
private health: HealthCheckService,
private mongooseHealth: MongooseHealthIndicator,
private healthService: HealthService,
) {}

@Query(() => String, { description: 'Health check endpoint' })
@SkipThrottle()
async healthCheck() {
try {
const healthCheck = await this.health.check([
() => this.mongooseHealth.pingCheck('mongodb'),
() => this.healthService.checkApiStatus(),
]);

return JSON.stringify(healthCheck);
} catch (error) {
return JSON.stringify({
status: 'error',
info: { error: error.message },
});
}
}
}
19 changes: 19 additions & 0 deletions src/modules/health/health.service.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { Injectable } from '@nestjs/common';
import { HealthIndicatorResult, HealthIndicatorStatus } from '@nestjs/terminus';

@Injectable()
export class HealthService {
async checkApiStatus(): Promise<HealthIndicatorResult> {
const isHealthy = true; // Replace with actual health logic if needed

const result: HealthIndicatorResult = {
api: {
status: isHealthy
? ('up' as HealthIndicatorStatus)
: ('down' as HealthIndicatorStatus),
},
};

return result;
}
}
Loading