1+ # Copyright 2024 Google LLC
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
15+
116from __future__ import annotations
217
318from collections .abc import Iterable , Mapping , Sequence
@@ -300,7 +315,12 @@ def to_contents(contents: ContentsType) -> list[glm.Content]:
300315 return contents
301316
302317
303- def _generate_schema (
318+ def _schema_for_class (cls : TypedDict ) -> dict [str , Any ]:
319+ schema = _build_schema ("dummy" , {"dummy" : (cls , pydantic .Field ())})
320+ return schema ["properties" ]["dummy" ]
321+
322+
323+ def _schema_for_function (
304324 f : Callable [..., Any ],
305325 * ,
306326 descriptions : Mapping [str , str ] | None = None ,
@@ -323,52 +343,36 @@ def _generate_schema(
323343 """
324344 if descriptions is None :
325345 descriptions = {}
326- if required is None :
327- required = []
328346 defaults = dict (inspect .signature (f ).parameters )
329- fields_dict = {
330- name : (
331- # 1. We infer the argument type here: use Any rather than None so
332- # it will not try to auto-infer the type based on the default value.
333- (param .annotation if param .annotation != inspect .Parameter .empty else Any ),
334- pydantic .Field (
335- # 2. We do not support default values for now.
336- # default=(
337- # param.default if param.default != inspect.Parameter.empty
338- # else None
339- # ),
340- # 3. We support user-provided descriptions.
341- description = descriptions .get (name , None ),
342- ),
343- )
344- for name , param in defaults .items ()
345- # We do not support *args or **kwargs
346- if param .kind
347- in (
347+
348+ fields_dict = {}
349+ for name , param in defaults .items ():
350+ if param .kind in (
348351 inspect .Parameter .POSITIONAL_OR_KEYWORD ,
349352 inspect .Parameter .KEYWORD_ONLY ,
350353 inspect .Parameter .POSITIONAL_ONLY ,
351- )
352- }
353- parameters = pydantic .create_model (f .__name__ , ** fields_dict ).schema ()
354- # Postprocessing
355- # 4. Suppress unnecessary title generation:
356- # * https://github.com/pydantic/pydantic/issues/1051
357- # * http://cl/586221780
358- parameters .pop ("title" , None )
359- for name , function_arg in parameters .get ("properties" , {}).items ():
360- function_arg .pop ("title" , None )
361- annotation = defaults [name ].annotation
362- # 5. Nullable fields:
363- # * https://github.com/pydantic/pydantic/issues/1270
364- # * https://stackoverflow.com/a/58841311
365- # * https://github.com/pydantic/pydantic/discussions/4872
366- if typing .get_origin (annotation ) is typing .Union and type (None ) in typing .get_args (
367- annotation
368354 ):
369- function_arg ["nullable" ] = True
355+ # We do not support default values for now.
356+ # default=(
357+ # param.default if param.default != inspect.Parameter.empty
358+ # else None
359+ # ),
360+ field = pydantic .Field (
361+ # We support user-provided descriptions.
362+ description = descriptions .get (name , None )
363+ )
364+
365+ # 1. We infer the argument type here: use Any rather than None so
366+ # it will not try to auto-infer the type based on the default value.
367+ if param .annotation != inspect .Parameter .empty :
368+ fields_dict [name ] = param .annotation , field
369+ else :
370+ fields_dict [name ] = Any , field
371+
372+ parameters = _build_schema (f .__name__ , fields_dict )
373+
370374 # 6. Annotate required fields.
371- if required :
375+ if required is not None :
372376 # We use the user-provided "required" fields if specified.
373377 parameters ["required" ] = required
374378 else :
@@ -387,9 +391,112 @@ def _generate_schema(
387391 )
388392 ]
389393 schema = dict (name = f .__name__ , description = f .__doc__ , parameters = parameters )
394+
390395 return schema
391396
392397
398+ def _build_schema (fname , fields_dict ):
399+ parameters = pydantic .create_model (fname , ** fields_dict ).schema ()
400+ defs = parameters .pop ("$defs" , {})
401+ # flatten the defs
402+ for name , value in defs .items ():
403+ unpack_defs (value , defs )
404+ unpack_defs (parameters , defs )
405+
406+ # 5. Nullable fields:
407+ # * https://github.com/pydantic/pydantic/issues/1270
408+ # * https://stackoverflow.com/a/58841311
409+ # * https://github.com/pydantic/pydantic/discussions/4872
410+ convert_to_nullable (parameters )
411+ add_object_type (parameters )
412+ # Postprocessing
413+ # 4. Suppress unnecessary title generation:
414+ # * https://github.com/pydantic/pydantic/issues/1051
415+ # * http://cl/586221780
416+ strip_titles (parameters )
417+ return parameters
418+
419+
420+ def unpack_defs (schema , defs ):
421+ properties = schema ["properties" ]
422+ for name , value in properties .items ():
423+ ref_key = value .get ("$ref" , None )
424+ if ref_key is not None :
425+ ref = defs [ref_key .split ("defs/" )[- 1 ]]
426+ unpack_defs (ref , defs )
427+ properties [name ] = ref
428+ continue
429+
430+ anyof = value .get ("anyOf" , None )
431+ if anyof is not None :
432+ for i , atype in enumerate (anyof ):
433+ ref_key = atype .get ("$ref" , None )
434+ if ref_key is not None :
435+ ref = defs [ref_key .split ("defs/" )[- 1 ]]
436+ unpack_defs (ref , defs )
437+ anyof [i ] = ref
438+ continue
439+
440+ items = value .get ("items" , None )
441+ if items is not None :
442+ ref_key = items .get ("$ref" , None )
443+ if ref_key is not None :
444+ ref = defs [ref_key .split ("defs/" )[- 1 ]]
445+ unpack_defs (ref , defs )
446+ value ["items" ] = ref
447+ continue
448+
449+
450+ def strip_titles (schema ):
451+ title = schema .pop ("title" , None )
452+
453+ properties = schema .get ("properties" , None )
454+ if properties is not None :
455+ for name , value in properties .items ():
456+ strip_titles (value )
457+
458+ items = schema .get ("items" , None )
459+ if items is not None :
460+ strip_titles (items )
461+
462+
463+ def add_object_type (schema ):
464+ properties = schema .get ("properties" , None )
465+ if properties is not None :
466+ schema .pop ("required" , None )
467+ schema ["type" ] = "object"
468+ for name , value in properties .items ():
469+ add_object_type (value )
470+
471+ items = schema .get ("items" , None )
472+ if items is not None :
473+ add_object_type (items )
474+
475+
476+ def convert_to_nullable (schema ):
477+ anyof = schema .pop ("anyOf" , None )
478+ if anyof is not None :
479+ if len (anyof ) != 2 :
480+ raise ValueError ("Type Unions are not supported (except for Optional)" )
481+ a , b = anyof
482+ if a == {"type" : "null" }:
483+ schema .update (b )
484+ elif b == {"type" : "null" }:
485+ schema .update (a )
486+ else :
487+ raise ValueError ("Type Unions are not supported (except for Optional)" )
488+ schema ["nullable" ] = True
489+
490+ properties = schema .get ("properties" , None )
491+ if properties is not None :
492+ for name , value in properties .items ():
493+ convert_to_nullable (value )
494+
495+ items = schema .get ("items" , None )
496+ if items is not None :
497+ convert_to_nullable (items )
498+
499+
393500def _rename_schema_fields (schema ):
394501 if schema is None :
395502 return schema
@@ -460,7 +567,7 @@ def from_function(function: Callable[..., Any], descriptions: dict[str, str] | N
460567 if descriptions is None :
461568 descriptions = {}
462569
463- schema = _generate_schema (function , descriptions = descriptions )
570+ schema = _schema_for_function (function , descriptions = descriptions )
464571
465572 return CallableFunctionDeclaration (** schema , function = function )
466573
0 commit comments