Skip to content

Commit 4574154

Browse files
committed
fix: make sure the logical DMMF respects auth() in @default
fixes #1893
1 parent d5c30f9 commit 4574154

File tree

2 files changed

+85
-1
lines changed

2 files changed

+85
-1
lines changed

packages/plugins/openapi/tests/openapi-rpc.test.ts

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,36 @@ model post_Item {
457457

458458
await OpenAPIParser.validate(output);
459459
});
460+
461+
it('auth() in @default()', async () => {
462+
const { projectDir } = await loadSchema(`
463+
plugin openapi {
464+
provider = '${normalizePath(path.resolve(__dirname, '../dist'))}'
465+
output = '$projectRoot/openapi.yaml'
466+
flavor = 'rpc'
467+
}
468+
469+
model User {
470+
id Int @id
471+
posts Post[]
472+
}
473+
474+
model Post {
475+
id Int @id
476+
title String
477+
author User @relation(fields: [authorId], references: [id])
478+
authorId Int @default(auth().id)
479+
}
480+
`);
481+
482+
const output = path.join(projectDir, 'openapi.yaml');
483+
console.log('OpenAPI specification generated:', output);
484+
485+
await OpenAPIParser.validate(output);
486+
const parsed = YAML.parse(fs.readFileSync(output, 'utf-8'));
487+
expect(parsed.components.schemas.PostCreateInput.required).not.toContain('author');
488+
expect(parsed.components.schemas.PostCreateManyInput.required).not.toContain('authorId');
489+
});
460490
});
461491

462492
function buildOptions(model: Model, modelFile: string, output: string) {

packages/schema/src/plugins/enhancer/enhance/index.ts

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { ReadonlyDeep } from '@prisma/generator-helper';
12
import { DELEGATE_AUX_RELATION_PREFIX } from '@zenstackhq/runtime';
23
import {
34
PluginError,
@@ -6,8 +7,10 @@ import {
67
getAuthDecl,
78
getDataModelAndTypeDefs,
89
getDataModels,
10+
getForeignKeyFields,
911
getLiteral,
1012
getRelationField,
13+
hasAttribute,
1114
isDelegateModel,
1215
isDiscriminatorField,
1316
normalizedRelative,
@@ -311,7 +314,8 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
311314
// make a bunch of typing fixes to the generated prisma client
312315
await this.processClientTypes(path.join(this.outDir, LOGICAL_CLIENT_GENERATION_PATH));
313316

314-
const dmmf = await getDMMF({ datamodel: fs.readFileSync(logicalPrismaFile, { encoding: 'utf-8' }) });
317+
// get the dmmf of the logical prisma schema
318+
const dmmf = await this.getLogicalDMMF(logicalPrismaFile);
315319

316320
try {
317321
// clean up temp schema
@@ -329,6 +333,56 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
329333
};
330334
}
331335

336+
private async getLogicalDMMF(logicalPrismaFile: string) {
337+
const dmmf = await getDMMF({ datamodel: fs.readFileSync(logicalPrismaFile, { encoding: 'utf-8' }) });
338+
339+
// make necessary fixes
340+
341+
// fields that use `auth()` in `@default` are not handled by Prisma so in the DMMF
342+
// they may be incorrectly represented as required, we need to fix that for input types
343+
// also, if a FK field is of such case, its corresponding relation field should be optional
344+
const createInputPattern = new RegExp(`^(.+?)(Unchecked)?Create.*Input$`);
345+
for (const inputType of dmmf.schema.inputObjectTypes.prisma) {
346+
const match = inputType.name.match(createInputPattern);
347+
const modelName = match?.[1];
348+
if (modelName) {
349+
const dataModel = this.model.declarations.find(
350+
(d): d is DataModel => isDataModel(d) && d.name === modelName
351+
);
352+
if (dataModel) {
353+
for (const field of inputType.fields) {
354+
if (field.isRequired && this.shouldBeOptional(field, dataModel)) {
355+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
356+
(field as any).isRequired = false;
357+
}
358+
}
359+
}
360+
}
361+
}
362+
return dmmf;
363+
}
364+
365+
private shouldBeOptional(field: ReadonlyDeep<DMMF.SchemaArg>, dataModel: DataModel) {
366+
const dmField = dataModel.fields.find((f) => f.name === field.name);
367+
if (!dmField) {
368+
return false;
369+
}
370+
371+
if (hasAttribute(dmField, '@default')) {
372+
return true;
373+
}
374+
375+
if (isDataModel(dmField.type.reference?.ref)) {
376+
// if FK field should be optional, the relation field should too
377+
const fkFields = getForeignKeyFields(dmField);
378+
if (fkFields.length > 0 && fkFields.every((f) => hasAttribute(f, '@default'))) {
379+
return true;
380+
}
381+
}
382+
383+
return false;
384+
}
385+
332386
private getPrismaClientGeneratorName(model: Model) {
333387
for (const generator of model.declarations.filter(isGeneratorDecl)) {
334388
if (

0 commit comments

Comments
 (0)