Skip to content

Commit 2ba01a8

Browse files
chore(rust): add support for circular references
1 parent 520e995 commit 2ba01a8

File tree

8 files changed

+453
-56
lines changed

8 files changed

+453
-56
lines changed

generators/rust/base/src/context/AbstractRustGeneratorContext.ts

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import {
1010
escapeRustKeyword,
1111
escapeRustReservedType,
1212
generateDefaultCrateName,
13+
RustCycleDetector,
1314
validateAndSanitizeCrateName
1415
} from "../utils";
1516

@@ -28,6 +29,11 @@ export abstract class AbstractRustGeneratorContext<
2829
) {
2930
super(config, generatorNotificationService);
3031

32+
// Detect illegal recursive type cycles before any generation
33+
// This will throw an error if the schema has cycles that cannot be represented in Rust
34+
const cycleDetector = new RustCycleDetector(ir);
35+
cycleDetector.detectIllegalCycles();
36+
3137
// Extract publish config from output mode
3238
config.output.mode._visit<void>({
3339
github: (github) => {
@@ -148,8 +154,11 @@ export abstract class AbstractRustGeneratorContext<
148154
* Check if IR uses a specific primitive type
149155
*/
150156
private irUsesType(typeName: "DATE_TIME" | "DATE" | "UUID" | "BIG_INTEGER"): boolean {
157+
// Use a visited set to prevent infinite recursion on circular types
158+
const visited = new Set<string>();
159+
151160
for (const typeDecl of Object.values(this.ir.types)) {
152-
if (this.typeShapeUsesBuiltin(typeDecl.shape, typeName)) {
161+
if (this.typeShapeUsesBuiltin(typeDecl.shape, typeName, visited)) {
153162
return true;
154163
}
155164
}
@@ -159,12 +168,12 @@ export abstract class AbstractRustGeneratorContext<
159168
if (endpoint.requestBody != null) {
160169
if (endpoint.requestBody.type === "inlinedRequestBody") {
161170
for (const property of endpoint.requestBody.properties) {
162-
if (this.typeReferenceUsesBuiltin(property.valueType, typeName)) {
171+
if (this.typeReferenceUsesBuiltin(property.valueType, typeName, visited)) {
163172
return true;
164173
}
165174
}
166175
} else if (endpoint.requestBody.type === "reference") {
167-
if (this.typeReferenceUsesBuiltin(endpoint.requestBody.requestBodyType, typeName)) {
176+
if (this.typeReferenceUsesBuiltin(endpoint.requestBody.requestBodyType, typeName, visited)) {
168177
return true;
169178
}
170179
}
@@ -175,9 +184,9 @@ export abstract class AbstractRustGeneratorContext<
175184
json: (json: FernIr.JsonResponse) => {
176185
return json._visit({
177186
response: (response: FernIr.JsonResponseBody) =>
178-
this.typeReferenceUsesBuiltin(response.responseBodyType, typeName),
187+
this.typeReferenceUsesBuiltin(response.responseBodyType, typeName, visited),
179188
nestedPropertyAsResponse: (nested: FernIr.JsonResponseBodyWithProperty) =>
180-
this.typeReferenceUsesBuiltin(nested.responseBodyType, typeName),
189+
this.typeReferenceUsesBuiltin(nested.responseBodyType, typeName, visited),
181190
_other: () => false
182191
});
183192
},
@@ -194,19 +203,19 @@ export abstract class AbstractRustGeneratorContext<
194203
}
195204

196205
for (const param of endpoint.queryParameters) {
197-
if (this.typeReferenceUsesBuiltin(param.valueType, typeName)) {
206+
if (this.typeReferenceUsesBuiltin(param.valueType, typeName, visited)) {
198207
return true;
199208
}
200209
}
201210

202211
for (const param of endpoint.pathParameters) {
203-
if (this.typeReferenceUsesBuiltin(param.valueType, typeName)) {
212+
if (this.typeReferenceUsesBuiltin(param.valueType, typeName, visited)) {
204213
return true;
205214
}
206215
}
207216

208217
for (const header of endpoint.headers) {
209-
if (this.typeReferenceUsesBuiltin(header.valueType, typeName)) {
218+
if (this.typeReferenceUsesBuiltin(header.valueType, typeName, visited)) {
210219
return true;
211220
}
212221
}
@@ -218,14 +227,16 @@ export abstract class AbstractRustGeneratorContext<
218227

219228
/**
220229
* Check if a type shape uses a specific builtin type
230+
* @param visited Set of type IDs already visited to prevent infinite recursion
221231
*/
222-
private typeShapeUsesBuiltin(shape: FernIr.Type, typeName: string): boolean {
232+
private typeShapeUsesBuiltin(shape: FernIr.Type, typeName: string, visited: Set<string>): boolean {
223233
return shape._visit({
224-
alias: (alias: FernIr.AliasTypeDeclaration) => this.typeReferenceUsesBuiltin(alias.aliasOf, typeName),
234+
alias: (alias: FernIr.AliasTypeDeclaration) =>
235+
this.typeReferenceUsesBuiltin(alias.aliasOf, typeName, visited),
225236
enum: () => false,
226237
object: (obj: FernIr.ObjectTypeDeclaration) => {
227238
for (const property of obj.properties) {
228-
if (this.typeReferenceUsesBuiltin(property.valueType, typeName)) {
239+
if (this.typeReferenceUsesBuiltin(property.valueType, typeName, visited)) {
229240
return true;
230241
}
231242
}
@@ -235,11 +246,15 @@ export abstract class AbstractRustGeneratorContext<
235246
for (const variant of union.types) {
236247
const usesBuiltin = variant.shape._visit({
237248
singleProperty: (property: FernIr.SingleUnionTypeProperty) =>
238-
this.typeReferenceUsesBuiltin(property.type, typeName),
249+
this.typeReferenceUsesBuiltin(property.type, typeName, visited),
239250
samePropertiesAsObject: (declaredType: FernIr.DeclaredTypeName) => {
251+
// Prevent infinite recursion by checking if we've visited this type
252+
if (visited.has(declaredType.typeId)) {
253+
return false;
254+
}
240255
const typeDecl = this.ir.types[declaredType.typeId];
241256
if (typeDecl) {
242-
return this.typeShapeUsesBuiltin(typeDecl.shape, typeName);
257+
return this.typeShapeUsesBuiltin(typeDecl.shape, typeName, visited);
243258
}
244259
return false;
245260
},
@@ -254,7 +269,7 @@ export abstract class AbstractRustGeneratorContext<
254269
},
255270
undiscriminatedUnion: (union: FernIr.UndiscriminatedUnionTypeDeclaration) => {
256271
for (const member of union.members) {
257-
if (this.typeReferenceUsesBuiltin(member.type, typeName)) {
272+
if (this.typeReferenceUsesBuiltin(member.type, typeName, visited)) {
258273
return true;
259274
}
260275
}
@@ -266,29 +281,40 @@ export abstract class AbstractRustGeneratorContext<
266281

267282
/**
268283
* Check if a type reference uses a specific builtin type
284+
* @param visited Set of type IDs already visited to prevent infinite recursion
269285
*/
270-
private typeReferenceUsesBuiltin(typeRef: FernIr.TypeReference, typeName: string): boolean {
286+
private typeReferenceUsesBuiltin(typeRef: FernIr.TypeReference, typeName: string, visited: Set<string>): boolean {
271287
return typeRef._visit({
272288
primitive: (primitive: FernIr.PrimitiveType) => {
273289
return primitive.v1 === typeName;
274290
},
275291
container: (container: FernIr.ContainerType) => {
276292
return container._visit({
277-
list: (list: FernIr.TypeReference) => this.typeReferenceUsesBuiltin(list, typeName),
278-
set: (set: FernIr.TypeReference) => this.typeReferenceUsesBuiltin(set, typeName),
279-
optional: (optional: FernIr.TypeReference) => this.typeReferenceUsesBuiltin(optional, typeName),
280-
nullable: (nullable: FernIr.TypeReference) => this.typeReferenceUsesBuiltin(nullable, typeName),
293+
list: (list: FernIr.TypeReference) => this.typeReferenceUsesBuiltin(list, typeName, visited),
294+
set: (set: FernIr.TypeReference) => this.typeReferenceUsesBuiltin(set, typeName, visited),
295+
optional: (optional: FernIr.TypeReference) =>
296+
this.typeReferenceUsesBuiltin(optional, typeName, visited),
297+
nullable: (nullable: FernIr.TypeReference) =>
298+
this.typeReferenceUsesBuiltin(nullable, typeName, visited),
281299
map: (map: FernIr.MapType) =>
282-
this.typeReferenceUsesBuiltin(map.keyType, typeName) ||
283-
this.typeReferenceUsesBuiltin(map.valueType, typeName),
300+
this.typeReferenceUsesBuiltin(map.keyType, typeName, visited) ||
301+
this.typeReferenceUsesBuiltin(map.valueType, typeName, visited),
284302
literal: () => false,
285303
_other: () => false
286304
});
287305
},
288306
named: (named: FernIr.NamedType) => {
307+
// Prevent infinite recursion by checking if we've already visited this type
308+
if (visited.has(named.typeId)) {
309+
return false;
310+
}
311+
312+
// Mark this type as visited
313+
visited.add(named.typeId);
314+
289315
const typeDecl = this.ir.types[named.typeId];
290316
if (typeDecl) {
291-
return this.typeShapeUsesBuiltin(typeDecl.shape, typeName);
317+
return this.typeShapeUsesBuiltin(typeDecl.shape, typeName, visited);
292318
}
293319
return false;
294320
},

0 commit comments

Comments
 (0)