Skip to content

Commit 285b258

Browse files
authored
fix: make sure both fk and relation fields are optional in create input types (#1862)
1 parent ad07053 commit 285b258

File tree

7 files changed

+358
-115
lines changed

7 files changed

+358
-115
lines changed

packages/runtime/src/enhancements/node/delegate.ts

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
587587
let curr = args;
588588
let base = this.getBaseModel(model);
589589
let sub = this.getModelInfo(model);
590+
const hasDelegateBase = !!base;
590591

591592
while (base) {
592593
const baseRelationName = this.makeAuxRelationName(base);
@@ -615,6 +616,55 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
615616
sub = base;
616617
base = this.getBaseModel(base.name);
617618
}
619+
620+
if (hasDelegateBase) {
621+
// A delegate base model creation is added, this can be incompatible if
622+
// the user-provided payload assigns foreign keys directly, because Prisma
623+
// doesn't permit mixed "checked" and "unchecked" fields in a payload.
624+
//
625+
// {
626+
// delegate_aux_base: { ... },
627+
// [fkField]: value // <- this is not compatible
628+
// }
629+
//
630+
// We need to convert foreign key assignments to `connect`.
631+
this.fkAssignmentToConnect(model, args);
632+
}
633+
}
634+
635+
// convert foreign key assignments to `connect` payload
636+
// e.g.: { authorId: value } -> { author: { connect: { id: value } } }
637+
private fkAssignmentToConnect(model: string, args: any) {
638+
const keysToDelete: string[] = [];
639+
for (const [key, value] of Object.entries(args)) {
640+
if (value === undefined) {
641+
continue;
642+
}
643+
644+
const fieldInfo = this.queryUtils.getModelField(model, key);
645+
if (
646+
!fieldInfo?.inheritedFrom && // fields from delegate base are handled outside
647+
fieldInfo?.isForeignKey
648+
) {
649+
const relationInfo = this.queryUtils.getRelationForForeignKey(model, key);
650+
if (relationInfo) {
651+
// turn { [fk]: value } into { [relation]: { connect: { [id]: value } } }
652+
const relationName = relationInfo.relation.name;
653+
if (!args[relationName]) {
654+
args[relationName] = {};
655+
}
656+
if (!args[relationName].connect) {
657+
args[relationName].connect = {};
658+
}
659+
if (!(relationInfo.idField in args[relationName].connect)) {
660+
args[relationName].connect[relationInfo.idField] = value;
661+
keysToDelete.push(key);
662+
}
663+
}
664+
}
665+
}
666+
667+
keysToDelete.forEach((key) => delete args[key]);
618668
}
619669

620670
// inject field data that belongs to base type into proper nesting structure

packages/runtime/src/enhancements/node/query-utils.ts

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,4 +232,25 @@ export class QueryUtils {
232232

233233
return model;
234234
}
235+
236+
/**
237+
* Gets relation info for a foreign key field.
238+
*/
239+
getRelationForForeignKey(model: string, fkField: string) {
240+
const modelInfo = getModelInfo(this.options.modelMeta, model);
241+
if (!modelInfo) {
242+
return undefined;
243+
}
244+
245+
for (const field of Object.values(modelInfo.fields)) {
246+
if (field.foreignKeyMapping) {
247+
const entry = Object.entries(field.foreignKeyMapping).find(([, v]) => v === fkField);
248+
if (entry) {
249+
return { relation: field, idField: entry[0], fkField: entry[1] };
250+
}
251+
}
252+
}
253+
254+
return undefined;
255+
}
235256
}

packages/schema/src/plugins/enhancer/enhance/index.ts

Lines changed: 127 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import {
77
getDataModelAndTypeDefs,
88
getDataModels,
99
getLiteral,
10+
getRelationField,
1011
isDelegateModel,
1112
isDiscriminatorField,
1213
normalizedRelative,
@@ -55,12 +56,23 @@ type DelegateInfo = [DataModel, DataModel[]][];
5556
const LOGICAL_CLIENT_GENERATION_PATH = './.logical-prisma-client';
5657

5758
export class EnhancerGenerator {
59+
// regex for matching "ModelCreateXXXInput" and "ModelUncheckedCreateXXXInput" type
60+
// names for models that use `auth()` in `@default` attribute
61+
private readonly modelsWithAuthInDefaultCreateInputPattern: RegExp;
62+
5863
constructor(
5964
private readonly model: Model,
6065
private readonly options: PluginOptions,
6166
private readonly project: Project,
6267
private readonly outDir: string
63-
) {}
68+
) {
69+
const modelsWithAuthInDefault = this.model.declarations.filter(
70+
(d): d is DataModel => isDataModel(d) && d.fields.some((f) => f.attributes.some(isDefaultWithAuth))
71+
);
72+
this.modelsWithAuthInDefaultCreateInputPattern = new RegExp(
73+
`^(${modelsWithAuthInDefault.map((m) => m.name).join('|')})(Unchecked)?Create.*?Input$`
74+
);
75+
}
6476

6577
async generate(): Promise<{ dmmf: DMMF.Document | undefined; newPrismaClientDtsPath: string | undefined }> {
6678
let dmmf: DMMF.Document | undefined;
@@ -69,7 +81,7 @@ export class EnhancerGenerator {
6981
let prismaTypesFixed = false;
7082
let resultPrismaImport = prismaImport;
7183

72-
if (this.needsLogicalClient || this.needsPrismaClientTypeFixes) {
84+
if (this.needsLogicalClient) {
7385
prismaTypesFixed = true;
7486
resultPrismaImport = `${LOGICAL_CLIENT_GENERATION_PATH}/index-fixed`;
7587
const result = await this.generateLogicalPrisma();
@@ -230,11 +242,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
230242
}
231243

232244
private get needsLogicalClient() {
233-
return this.hasDelegateModel(this.model) || this.hasAuthInDefault(this.model);
234-
}
235-
236-
private get needsPrismaClientTypeFixes() {
237-
return this.hasTypeDef(this.model);
245+
return this.hasDelegateModel(this.model) || this.hasAuthInDefault(this.model) || this.hasTypeDef(this.model);
238246
}
239247

240248
private hasDelegateModel(model: Model) {
@@ -449,11 +457,13 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
449457
const auxFields = this.findAuxDecls(variable);
450458
if (auxFields.length > 0) {
451459
structure.declarations.forEach((variable) => {
452-
let source = variable.type?.toString();
453-
auxFields.forEach((f) => {
454-
source = source?.replace(f.getText(), '');
455-
});
456-
variable.type = source;
460+
if (variable.type) {
461+
let source = variable.type.toString();
462+
auxFields.forEach((f) => {
463+
source = this.removeFromSource(source, f.getText());
464+
});
465+
variable.type = source;
466+
}
457467
});
458468
}
459469

@@ -498,72 +508,16 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
498508
// fix delegate payload union type
499509
source = this.fixDelegatePayloadType(typeAlias, delegateInfo, source);
500510

511+
// fix fk and relation fields related to using `auth()` in `@default`
512+
source = this.fixDefaultAuthType(typeAlias, source);
513+
501514
// fix json field type
502515
source = this.fixJsonFieldType(typeAlias, source);
503516

504517
structure.type = source;
505518
return structure;
506519
}
507520

508-
private fixJsonFieldType(typeAlias: TypeAliasDeclaration, source: string) {
509-
const modelsWithTypeField = this.model.declarations.filter(
510-
(d): d is DataModel => isDataModel(d) && d.fields.some((f) => isTypeDef(f.type.reference?.ref))
511-
);
512-
const typeName = typeAlias.getName();
513-
514-
const getTypedJsonFields = (model: DataModel) => {
515-
return model.fields.filter((f) => isTypeDef(f.type.reference?.ref));
516-
};
517-
518-
const replacePrismaJson = (source: string, field: DataModelField) => {
519-
return source.replace(
520-
new RegExp(`(${field.name}\\??\\s*):[^\\n]+`),
521-
`$1: ${field.type.reference!.$refText}${field.type.array ? '[]' : ''}${
522-
field.type.optional ? ' | null' : ''
523-
}`
524-
);
525-
};
526-
527-
// fix "$[Model]Payload" type
528-
const payloadModelMatch = modelsWithTypeField.find((m) => `$${m.name}Payload` === typeName);
529-
if (payloadModelMatch) {
530-
const scalars = typeAlias
531-
.getDescendantsOfKind(SyntaxKind.PropertySignature)
532-
.find((p) => p.getName() === 'scalars');
533-
if (!scalars) {
534-
return source;
535-
}
536-
537-
const fieldsToFix = getTypedJsonFields(payloadModelMatch);
538-
for (const field of fieldsToFix) {
539-
source = replacePrismaJson(source, field);
540-
}
541-
}
542-
543-
// fix input/output types, "[Model]CreateInput", etc.
544-
const inputOutputModelMatch = modelsWithTypeField.find((m) => typeName.startsWith(m.name));
545-
if (inputOutputModelMatch) {
546-
const relevantTypePatterns = [
547-
'GroupByOutputType',
548-
'(Unchecked)?Create(\\S+?)?Input',
549-
'(Unchecked)?Update(\\S+?)?Input',
550-
'CreateManyInput',
551-
'(Unchecked)?UpdateMany(Mutation)?Input',
552-
];
553-
const typeRegex = modelsWithTypeField.map(
554-
(m) => new RegExp(`^(${m.name})(${relevantTypePatterns.join('|')})$`)
555-
);
556-
if (typeRegex.some((r) => r.test(typeName))) {
557-
const fieldsToFix = getTypedJsonFields(inputOutputModelMatch);
558-
for (const field of fieldsToFix) {
559-
source = replacePrismaJson(source, field);
560-
}
561-
}
562-
}
563-
564-
return source;
565-
}
566-
567521
private fixDelegatePayloadType(typeAlias: TypeAliasDeclaration, delegateInfo: DelegateInfo, source: string) {
568522
// change the type of `$<DelegateModel>Payload` type of delegate model to a union of concrete types
569523
const typeName = typeAlias.getName();
@@ -595,7 +549,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
595549
.getDescendantsOfKind(SyntaxKind.PropertySignature)
596550
.filter((p) => ['create', 'createMany', 'connectOrCreate', 'upsert'].includes(p.getName()));
597551
toRemove.forEach((r) => {
598-
source = source.replace(r.getText(), '');
552+
this.removeFromSource(source, r.getText());
599553
});
600554
}
601555
return source;
@@ -633,7 +587,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
633587
if (isDiscriminatorField(field)) {
634588
const fieldDef = this.findNamedProperty(typeAlias, field.name);
635589
if (fieldDef) {
636-
source = source.replace(fieldDef.getText(), '');
590+
source = this.removeFromSource(source, fieldDef.getText());
637591
}
638592
}
639593
}
@@ -646,7 +600,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
646600
const auxDecls = this.findAuxDecls(typeAlias);
647601
if (auxDecls.length > 0) {
648602
auxDecls.forEach((d) => {
649-
source = source.replace(d.getText(), '');
603+
source = this.removeFromSource(source, d.getText());
650604
});
651605
}
652606
return source;
@@ -677,7 +631,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
677631
const fieldDef = this.findNamedProperty(typeAlias, relationFieldName);
678632
if (fieldDef) {
679633
// remove relation field of delegate type, e.g., `asset`
680-
source = source.replace(fieldDef.getText(), '');
634+
source = this.removeFromSource(source, fieldDef.getText());
681635
}
682636

683637
// remove fk fields related to the delegate type relation, e.g., `assetId`
@@ -709,13 +663,103 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
709663
fkFields.forEach((fkField) => {
710664
const fieldDef = this.findNamedProperty(typeAlias, fkField);
711665
if (fieldDef) {
712-
source = source.replace(fieldDef.getText(), '');
666+
source = this.removeFromSource(source, fieldDef.getText());
713667
}
714668
});
715669

716670
return source;
717671
}
718672

673+
private fixDefaultAuthType(typeAlias: TypeAliasDeclaration, source: string) {
674+
const match = typeAlias.getName().match(this.modelsWithAuthInDefaultCreateInputPattern);
675+
if (!match) {
676+
return source;
677+
}
678+
679+
const modelName = match[1];
680+
const dataModel = this.model.declarations.find((d): d is DataModel => isDataModel(d) && d.name === modelName);
681+
if (dataModel) {
682+
for (const fkField of dataModel.fields.filter((f) => f.attributes.some(isDefaultWithAuth))) {
683+
// change fk field to optional since it has a default
684+
source = source.replace(new RegExp(`^(\\s*${fkField.name}\\s*):`, 'm'), `$1?:`);
685+
686+
const relationField = getRelationField(fkField);
687+
if (relationField) {
688+
// change relation field to optional since its fk has a default
689+
source = source.replace(new RegExp(`^(\\s*${relationField.name}\\s*):`, 'm'), `$1?:`);
690+
}
691+
}
692+
}
693+
return source;
694+
}
695+
696+
private fixJsonFieldType(typeAlias: TypeAliasDeclaration, source: string) {
697+
const modelsWithTypeField = this.model.declarations.filter(
698+
(d): d is DataModel => isDataModel(d) && d.fields.some((f) => isTypeDef(f.type.reference?.ref))
699+
);
700+
const typeName = typeAlias.getName();
701+
702+
const getTypedJsonFields = (model: DataModel) => {
703+
return model.fields.filter((f) => isTypeDef(f.type.reference?.ref));
704+
};
705+
706+
const replacePrismaJson = (source: string, field: DataModelField) => {
707+
return source.replace(
708+
new RegExp(`(${field.name}\\??\\s*):[^\\n]+`),
709+
`$1: ${field.type.reference!.$refText}${field.type.array ? '[]' : ''}${
710+
field.type.optional ? ' | null' : ''
711+
}`
712+
);
713+
};
714+
715+
// fix "$[Model]Payload" type
716+
const payloadModelMatch = modelsWithTypeField.find((m) => `$${m.name}Payload` === typeName);
717+
if (payloadModelMatch) {
718+
const scalars = typeAlias
719+
.getDescendantsOfKind(SyntaxKind.PropertySignature)
720+
.find((p) => p.getName() === 'scalars');
721+
if (!scalars) {
722+
return source;
723+
}
724+
725+
const fieldsToFix = getTypedJsonFields(payloadModelMatch);
726+
for (const field of fieldsToFix) {
727+
source = replacePrismaJson(source, field);
728+
}
729+
}
730+
731+
// fix input/output types, "[Model]CreateInput", etc.
732+
const inputOutputModelMatch = modelsWithTypeField.find((m) => typeName.startsWith(m.name));
733+
if (inputOutputModelMatch) {
734+
const relevantTypePatterns = [
735+
'GroupByOutputType',
736+
'(Unchecked)?Create(\\S+?)?Input',
737+
'(Unchecked)?Update(\\S+?)?Input',
738+
'CreateManyInput',
739+
'(Unchecked)?UpdateMany(Mutation)?Input',
740+
];
741+
const typeRegex = modelsWithTypeField.map(
742+
(m) => new RegExp(`^(${m.name})(${relevantTypePatterns.join('|')})$`)
743+
);
744+
if (typeRegex.some((r) => r.test(typeName))) {
745+
const fieldsToFix = getTypedJsonFields(inputOutputModelMatch);
746+
for (const field of fieldsToFix) {
747+
source = replacePrismaJson(source, field);
748+
}
749+
}
750+
}
751+
752+
return source;
753+
}
754+
755+
private async generateExtraTypes(sf: SourceFile) {
756+
for (const decl of this.model.declarations) {
757+
if (isTypeDef(decl)) {
758+
generateTypeDefType(sf, decl);
759+
}
760+
}
761+
}
762+
719763
private findNamedProperty(typeAlias: TypeAliasDeclaration, name: string) {
720764
return typeAlias.getFirstDescendant((d) => d.isKind(SyntaxKind.PropertySignature) && d.getName() === name);
721765
}
@@ -745,11 +789,12 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
745789
return this.options.generatePermissionChecker === true;
746790
}
747791

748-
private async generateExtraTypes(sf: SourceFile) {
749-
for (const decl of this.model.declarations) {
750-
if (isTypeDef(decl)) {
751-
generateTypeDefType(sf, decl);
752-
}
753-
}
792+
private removeFromSource(source: string, text: string) {
793+
source = source.replace(text, '');
794+
return this.trimEmptyLines(source);
795+
}
796+
797+
private trimEmptyLines(source: string): string {
798+
return source.replace(/^\s*[\r\n]/gm, '');
754799
}
755800
}

0 commit comments

Comments
 (0)