|
| 1 | +import json |
1 | 2 | import os |
2 | 3 | import platform |
3 | 4 | import random |
@@ -146,6 +147,7 @@ def __init__(self, model: str, model_store_path: str): |
146 | 147 | self._model_type: str |
147 | 148 | self._model_name, self._model_tag, self._model_organization = self.extract_model_identifiers() |
148 | 149 | self._model_type = type(self).__name__.lower() |
| 150 | + self.artifact = False |
149 | 151 |
|
150 | 152 | self._model_store_path: str = model_store_path |
151 | 153 | self._model_store: Optional[ModelStore] = None |
@@ -201,6 +203,8 @@ def _get_entry_model_path(self, use_container: bool, should_generate: bool, dry_ |
201 | 203 |
|
202 | 204 | if self.model_type == 'oci': |
203 | 205 | if use_container or should_generate: |
| 206 | + if self.artifact: |
| 207 | + return os.path.join(MNT_DIR, self.artifact_name()) |
204 | 208 | return os.path.join(MNT_DIR, 'model.file') |
205 | 209 | else: |
206 | 210 | return f"oci://{self.model}" |
@@ -347,9 +351,10 @@ def exec_model_in_container(self, cmd_args, args): |
347 | 351 | def setup_mounts(self, args): |
348 | 352 | if args.dryrun: |
349 | 353 | return |
| 354 | + |
350 | 355 | if self.model_type == 'oci': |
351 | 356 | if self.engine.use_podman: |
352 | | - mount_cmd = f"--mount=type=image,src={self.model},destination={MNT_DIR},subpath=/models,rw=false" |
| 357 | + mount_cmd = self.mount_cmd() |
353 | 358 | elif self.engine.use_docker: |
354 | 359 | output_filename = self._get_entry_model_path(args.container, True, args.dryrun) |
355 | 360 | volume = populate_volume_from_image(self, os.path.basename(output_filename)) |
@@ -655,40 +660,52 @@ def inspect( |
655 | 660 | as_json: bool = False, |
656 | 661 | dryrun: bool = False, |
657 | 662 | ) -> None: |
| 663 | + json_out = self.get_inspect(show_all, show_all_metadata, get_field, dryrun) |
| 664 | + if as_json: |
| 665 | + print(json_out) |
| 666 | + else: |
| 667 | + print(json.loads(json_out)) |
| 668 | + |
| 669 | + def get_inspect( |
| 670 | + self, |
| 671 | + show_all: bool = False, |
| 672 | + show_all_metadata: bool = False, |
| 673 | + get_field: str = "", |
| 674 | + dryrun: bool = False, |
| 675 | + ) -> None: |
| 676 | + as_json = True |
658 | 677 | model_name = self.filename |
659 | 678 | model_registry = self.type.lower() |
660 | | - model_path = self._get_inspect_model_path(dryrun) |
661 | | - |
| 679 | + model_path = self._get_entry_model_path(False, False, dryrun) |
662 | 680 | if GGUFInfoParser.is_model_gguf(model_path): |
663 | 681 | if not show_all_metadata and get_field == "": |
664 | 682 | gguf_info: GGUFModelInfo = GGUFInfoParser.parse(model_name, model_registry, model_path) |
665 | | - print(gguf_info.serialize(json=as_json, all=show_all)) |
666 | | - return |
| 683 | + return gguf_info.serialize(json=as_json, all=show_all) |
667 | 684 |
|
668 | 685 | metadata = GGUFInfoParser.parse_metadata(model_path) |
669 | 686 | if show_all_metadata: |
670 | | - print(metadata.serialize(json=as_json)) |
671 | | - return |
| 687 | + return metadata.serialize(json=as_json) |
672 | 688 | elif get_field != "": # If a specific field is requested, print only that field |
673 | 689 | field_value = metadata.get(get_field) |
674 | 690 | if field_value is None: |
675 | 691 | raise KeyError(f"Field '{get_field}' not found in GGUF model metadata") |
676 | | - print(field_value) |
677 | | - return |
| 692 | + return field_value |
678 | 693 |
|
679 | 694 | if SafetensorInfoParser.is_model_safetensor(model_name): |
680 | 695 | safetensor_info: SafetensorModelInfo = SafetensorInfoParser.parse(model_name, model_registry, model_path) |
681 | | - print(safetensor_info.serialize(json=as_json, all=show_all)) |
682 | | - return |
| 696 | + return safetensor_info.serialize(json=as_json, all=show_all) |
683 | 697 |
|
684 | | - print(ModelInfoBase(model_name, model_registry, model_path).serialize(json=as_json)) |
| 698 | + return ModelInfoBase(model_name, model_registry, model_path).serialize(json=as_json) |
685 | 699 |
|
686 | 700 | def print_pull_message(self, model_name): |
687 | 701 | model_name = trim_model_name(model_name) |
688 | 702 | # Write messages to stderr |
689 | 703 | perror(f"Downloading {model_name} ...") |
690 | 704 | perror(f"Trying to pull {model_name} ...") |
691 | 705 |
|
| 706 | + def is_artifact(self) -> bool: |
| 707 | + return False |
| 708 | + |
692 | 709 |
|
693 | 710 | def compute_ports() -> list: |
694 | 711 | first_port = DEFAULT_PORT_RANGE[0] |
|
0 commit comments