Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 48 additions & 22 deletions generators/rust/base/src/context/AbstractRustGeneratorContext.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
escapeRustKeyword,
escapeRustReservedType,
generateDefaultCrateName,
RustCycleDetector,
validateAndSanitizeCrateName
} from "../utils";

Expand All @@ -28,6 +29,11 @@ export abstract class AbstractRustGeneratorContext<
) {
super(config, generatorNotificationService);

// Detect illegal recursive type cycles before any generation
// This will throw an error if the schema has cycles that cannot be represented in Rust
const cycleDetector = new RustCycleDetector(ir);
cycleDetector.detectIllegalCycles();

// Extract publish config from output mode
config.output.mode._visit<void>({
github: (github) => {
Expand Down Expand Up @@ -148,8 +154,11 @@ export abstract class AbstractRustGeneratorContext<
* Check if IR uses a specific primitive type
*/
private irUsesType(typeName: "DATE_TIME" | "DATE" | "UUID" | "BIG_INTEGER"): boolean {
// Use a visited set to prevent infinite recursion on circular types
const visited = new Set<string>();

for (const typeDecl of Object.values(this.ir.types)) {
if (this.typeShapeUsesBuiltin(typeDecl.shape, typeName)) {
if (this.typeShapeUsesBuiltin(typeDecl.shape, typeName, visited)) {
return true;
}
}
Expand All @@ -159,12 +168,12 @@ export abstract class AbstractRustGeneratorContext<
if (endpoint.requestBody != null) {
if (endpoint.requestBody.type === "inlinedRequestBody") {
for (const property of endpoint.requestBody.properties) {
if (this.typeReferenceUsesBuiltin(property.valueType, typeName)) {
if (this.typeReferenceUsesBuiltin(property.valueType, typeName, visited)) {
return true;
}
}
} else if (endpoint.requestBody.type === "reference") {
if (this.typeReferenceUsesBuiltin(endpoint.requestBody.requestBodyType, typeName)) {
if (this.typeReferenceUsesBuiltin(endpoint.requestBody.requestBodyType, typeName, visited)) {
return true;
}
}
Expand All @@ -175,9 +184,9 @@ export abstract class AbstractRustGeneratorContext<
json: (json: FernIr.JsonResponse) => {
return json._visit({
response: (response: FernIr.JsonResponseBody) =>
this.typeReferenceUsesBuiltin(response.responseBodyType, typeName),
this.typeReferenceUsesBuiltin(response.responseBodyType, typeName, visited),
nestedPropertyAsResponse: (nested: FernIr.JsonResponseBodyWithProperty) =>
this.typeReferenceUsesBuiltin(nested.responseBodyType, typeName),
this.typeReferenceUsesBuiltin(nested.responseBodyType, typeName, visited),
_other: () => false
});
},
Expand All @@ -194,19 +203,19 @@ export abstract class AbstractRustGeneratorContext<
}

for (const param of endpoint.queryParameters) {
if (this.typeReferenceUsesBuiltin(param.valueType, typeName)) {
if (this.typeReferenceUsesBuiltin(param.valueType, typeName, visited)) {
return true;
}
}

for (const param of endpoint.pathParameters) {
if (this.typeReferenceUsesBuiltin(param.valueType, typeName)) {
if (this.typeReferenceUsesBuiltin(param.valueType, typeName, visited)) {
return true;
}
}

for (const header of endpoint.headers) {
if (this.typeReferenceUsesBuiltin(header.valueType, typeName)) {
if (this.typeReferenceUsesBuiltin(header.valueType, typeName, visited)) {
return true;
}
}
Expand All @@ -218,14 +227,16 @@ export abstract class AbstractRustGeneratorContext<

/**
* Check if a type shape uses a specific builtin type
* @param visited Set of type IDs already visited to prevent infinite recursion
*/
private typeShapeUsesBuiltin(shape: FernIr.Type, typeName: string): boolean {
private typeShapeUsesBuiltin(shape: FernIr.Type, typeName: string, visited: Set<string>): boolean {
return shape._visit({
alias: (alias: FernIr.AliasTypeDeclaration) => this.typeReferenceUsesBuiltin(alias.aliasOf, typeName),
alias: (alias: FernIr.AliasTypeDeclaration) =>
this.typeReferenceUsesBuiltin(alias.aliasOf, typeName, visited),
enum: () => false,
object: (obj: FernIr.ObjectTypeDeclaration) => {
for (const property of obj.properties) {
if (this.typeReferenceUsesBuiltin(property.valueType, typeName)) {
if (this.typeReferenceUsesBuiltin(property.valueType, typeName, visited)) {
return true;
}
}
Expand All @@ -235,11 +246,15 @@ export abstract class AbstractRustGeneratorContext<
for (const variant of union.types) {
const usesBuiltin = variant.shape._visit({
singleProperty: (property: FernIr.SingleUnionTypeProperty) =>
this.typeReferenceUsesBuiltin(property.type, typeName),
this.typeReferenceUsesBuiltin(property.type, typeName, visited),
samePropertiesAsObject: (declaredType: FernIr.DeclaredTypeName) => {
// Prevent infinite recursion by checking if we've visited this type
if (visited.has(declaredType.typeId)) {
return false;
}
const typeDecl = this.ir.types[declaredType.typeId];
if (typeDecl) {
return this.typeShapeUsesBuiltin(typeDecl.shape, typeName);
return this.typeShapeUsesBuiltin(typeDecl.shape, typeName, visited);
}
return false;
},
Expand All @@ -254,7 +269,7 @@ export abstract class AbstractRustGeneratorContext<
},
undiscriminatedUnion: (union: FernIr.UndiscriminatedUnionTypeDeclaration) => {
for (const member of union.members) {
if (this.typeReferenceUsesBuiltin(member.type, typeName)) {
if (this.typeReferenceUsesBuiltin(member.type, typeName, visited)) {
return true;
}
}
Expand All @@ -266,29 +281,40 @@ export abstract class AbstractRustGeneratorContext<

/**
* Check if a type reference uses a specific builtin type
* @param visited Set of type IDs already visited to prevent infinite recursion
*/
private typeReferenceUsesBuiltin(typeRef: FernIr.TypeReference, typeName: string): boolean {
private typeReferenceUsesBuiltin(typeRef: FernIr.TypeReference, typeName: string, visited: Set<string>): boolean {
return typeRef._visit({
primitive: (primitive: FernIr.PrimitiveType) => {
return primitive.v1 === typeName;
},
container: (container: FernIr.ContainerType) => {
return container._visit({
list: (list: FernIr.TypeReference) => this.typeReferenceUsesBuiltin(list, typeName),
set: (set: FernIr.TypeReference) => this.typeReferenceUsesBuiltin(set, typeName),
optional: (optional: FernIr.TypeReference) => this.typeReferenceUsesBuiltin(optional, typeName),
nullable: (nullable: FernIr.TypeReference) => this.typeReferenceUsesBuiltin(nullable, typeName),
list: (list: FernIr.TypeReference) => this.typeReferenceUsesBuiltin(list, typeName, visited),
set: (set: FernIr.TypeReference) => this.typeReferenceUsesBuiltin(set, typeName, visited),
optional: (optional: FernIr.TypeReference) =>
this.typeReferenceUsesBuiltin(optional, typeName, visited),
nullable: (nullable: FernIr.TypeReference) =>
this.typeReferenceUsesBuiltin(nullable, typeName, visited),
map: (map: FernIr.MapType) =>
this.typeReferenceUsesBuiltin(map.keyType, typeName) ||
this.typeReferenceUsesBuiltin(map.valueType, typeName),
this.typeReferenceUsesBuiltin(map.keyType, typeName, visited) ||
this.typeReferenceUsesBuiltin(map.valueType, typeName, visited),
literal: () => false,
_other: () => false
});
},
named: (named: FernIr.NamedType) => {
// Prevent infinite recursion by checking if we've already visited this type
if (visited.has(named.typeId)) {
return false;
}

// Mark this type as visited
visited.add(named.typeId);

const typeDecl = this.ir.types[named.typeId];
if (typeDecl) {
return this.typeShapeUsesBuiltin(typeDecl.shape, typeName);
return this.typeShapeUsesBuiltin(typeDecl.shape, typeName, visited);
}
return false;
},
Expand Down
Loading
Loading