From 71ce657a5a940db71cafaa3f9b683bdeb8ab48f3 Mon Sep 17 00:00:00 2001
From: Sergei Zharinov <zharinov@users.noreply.github.com>
Date: Wed, 24 May 2023 13:36:19 +0300
Subject: [PATCH] feat(schema-utils): Support `LooseRecord` key validation
 (#22404)

---
 lib/util/schema-utils.spec.ts |  36 +++++++++++
 lib/util/schema-utils.ts      | 109 +++++++++++++++++++++++++++-------
 2 files changed, 124 insertions(+), 21 deletions(-)

diff --git a/lib/util/schema-utils.spec.ts b/lib/util/schema-utils.spec.ts
index a10869342d..1a9528d07c 100644
--- a/lib/util/schema-utils.spec.ts
+++ b/lib/util/schema-utils.spec.ts
@@ -49,6 +49,42 @@ describe('util/schema-utils', () => {
       expect(s.parse({ foo: 'foo', bar: 123 })).toEqual({ foo: 'foo' });
     });
 
+    it('supports key schema', () => {
+      const s = LooseRecord(
+        z.string().refine((x) => x === 'bar'),
+        z.string()
+      );
+      expect(s.parse({ foo: 'foo', bar: 'bar' })).toEqual({ bar: 'bar' });
+    });
+
+    it('reports key schema errors', () => {
+      let errorData: unknown = null;
+      const s = LooseRecord(
+        z.string().refine((x) => x === 'bar'),
+        z.string(),
+        {
+          onError: (x) => {
+            errorData = x;
+          },
+        }
+      );
+
+      s.parse({ foo: 'foo', bar: 'bar' });
+
+      expect(errorData).toMatchObject({
+        error: {
+          issues: [
+            {
+              code: 'custom',
+              message: 'Invalid input',
+              path: ['foo'],
+            },
+          ],
+        },
+        input: { bar: 'bar', foo: 'foo' },
+      });
+    });
+
     it('runs callback for wrong elements', () => {
       let err: z.ZodError | undefined = undefined;
       const Schema = LooseRecord(
diff --git a/lib/util/schema-utils.ts b/lib/util/schema-utils.ts
index 752ab2effe..bc0b9e0cf1 100644
--- a/lib/util/schema-utils.ts
+++ b/lib/util/schema-utils.ts
@@ -67,32 +67,90 @@ export function LooseArray<Schema extends z.ZodTypeAny>(
   });
 }
 
+type LooseRecordResult<
+  KeySchema extends z.ZodTypeAny,
+  ValueSchema extends z.ZodTypeAny
+> = z.ZodEffects<
+  z.ZodRecord<z.ZodString, z.ZodAny>,
+  Record<z.TypeOf<KeySchema>, z.TypeOf<ValueSchema>>,
+  Record<z.TypeOf<KeySchema>, any>
+>;
+
+type LooseRecordOpts<
+  KeySchema extends z.ZodTypeAny,
+  ValueSchema extends z.ZodTypeAny
+> = LooseOpts<Record<z.TypeOf<KeySchema> | z.TypeOf<ValueSchema>, unknown>>;
+
 /**
  * Works like `z.record()`, but drops wrong elements instead of invalidating the whole record.
  *
  * **Important**: non-record inputs other are still invalid.
  * Use `LooseRecord(...).catch({})` to handle it.
  *
- * @param Elem Schema for record values
+ * @param KeyValue Schema for record keys
+ * @param ValueValue Schema for record values
  * @param onError Callback for errors
  * @returns Schema for record
  */
-export function LooseRecord<Schema extends z.ZodTypeAny>(
-  Elem: Schema,
-  { onError }: LooseOpts<Record<string, unknown>> = {}
-): z.ZodEffects<
-  z.ZodRecord<z.ZodString, z.ZodAny>,
-  Record<string, z.TypeOf<Schema>>,
-  Record<string, any>
-> {
+export function LooseRecord<ValueSchema extends z.ZodTypeAny>(
+  Value: ValueSchema
+): LooseRecordResult<z.ZodString, ValueSchema>;
+export function LooseRecord<
+  KeySchema extends z.ZodTypeAny,
+  ValueSchema extends z.ZodTypeAny
+>(
+  Key: KeySchema,
+  Value: ValueSchema
+): LooseRecordResult<KeySchema, ValueSchema>;
+export function LooseRecord<ValueSchema extends z.ZodTypeAny>(
+  Value: ValueSchema,
+  { onError }: LooseRecordOpts<z.ZodString, ValueSchema>
+): LooseRecordResult<z.ZodString, ValueSchema>;
+export function LooseRecord<
+  KeySchema extends z.ZodTypeAny,
+  ValueSchema extends z.ZodTypeAny
+>(
+  Key: KeySchema,
+  Value: ValueSchema,
+  { onError }: LooseRecordOpts<KeySchema, ValueSchema>
+): LooseRecordResult<KeySchema, ValueSchema>;
+export function LooseRecord<
+  KeySchema extends z.ZodTypeAny,
+  ValueSchema extends z.ZodTypeAny
+>(
+  arg1: ValueSchema | KeySchema,
+  arg2?: ValueSchema | LooseOpts<Record<string, unknown>>,
+  arg3?: LooseRecordOpts<KeySchema, ValueSchema>
+): LooseRecordResult<KeySchema, ValueSchema> {
+  let Key: z.ZodSchema = z.any();
+  let Value: ValueSchema;
+  let opts: LooseRecordOpts<KeySchema, ValueSchema> = {};
+  if (arg2 && arg3) {
+    Key = arg1 as KeySchema;
+    Value = arg2 as ValueSchema;
+    opts = arg3;
+  } else if (arg2) {
+    if (arg2 instanceof z.ZodType) {
+      Key = arg1 as KeySchema;
+      Value = arg2;
+    } else {
+      Value = arg1 as ValueSchema;
+      opts = arg2;
+    }
+  } else {
+    Value = arg1 as ValueSchema;
+  }
+
+  const { onError } = opts;
   if (!onError) {
     // Avoid error-related computations inside the loop
     return z.record(z.any()).transform((input) => {
-      const output: Record<string, z.infer<Schema>> = {};
+      const output: Record<string, z.infer<ValueSchema>> = {};
       for (const [key, val] of Object.entries(input)) {
-        const parsed = Elem.safeParse(val);
-        if (parsed.success) {
-          output[key] = parsed.data;
+        const parsedKey = Key.safeParse(key);
+        const parsedValue = Value.safeParse(val);
+        if (parsedKey.success && parsedValue.success) {
+          output[key] = parsedValue.data;
         }
       }
       return output;
@@ -100,21 +158,30 @@ export function LooseRecord<Schema extends z.ZodTypeAny>(
   }
 
   return z.record(z.any()).transform((input) => {
-    const output: Record<string, z.infer<Schema>> = {};
+    const output: Record<string, z.infer<ValueSchema>> = {};
     const issues: z.ZodIssue[] = [];
 
     for (const [key, val] of Object.entries(input)) {
-      const parsed = Elem.safeParse(val);
-
-      if (parsed.success) {
-        output[key] = parsed.data;
+      const parsedKey = Key.safeParse(key);
+      if (!parsedKey.success) {
+        for (const issue of parsedKey.error.issues) {
+          issue.path.unshift(key);
+          issues.push(issue);
+        }
         continue;
       }
 
-      for (const issue of parsed.error.issues) {
-        issue.path.unshift(key);
-        issues.push(issue);
+      const parsedValue = Value.safeParse(val);
+      if (!parsedValue.success) {
+        for (const issue of parsedValue.error.issues) {
+          issue.path.unshift(key);
+          issues.push(issue);
+        }
+        continue;
       }
+
+      output[key] = parsedValue.data;
+      continue;
     }
 
     if (issues.length) {
-- 
GitLab