# -*- coding: utf-8 -*-
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import time
from typing import Dict, List, Optional, Sequence, Tuple, Union
import abc
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base
from google.cloud.aiplatform.constants import base as constants
from google.cloud.aiplatform import datasets
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import models
from google.cloud.aiplatform import jobs
from google.cloud.aiplatform import schema
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.utils import console_utils
from google.cloud.aiplatform.compat.types import (
env_var as gca_env_var,
io as gca_io,
model as gca_model,
pipeline_state as gca_pipeline_state,
training_pipeline as gca_training_pipeline,
)
from google.cloud.aiplatform.utils import _timestamped_gcs_dir
from google.cloud.aiplatform.utils import source_utils
from google.cloud.aiplatform.utils import worker_spec_utils
from google.cloud.aiplatform.utils import column_transformations_utils
from google.cloud.aiplatform.v1.schema.trainingjob import (
definition_v1 as training_job_inputs,
)
from google.rpc import code_pb2
from google.rpc import status_pb2
import proto
_LOGGER = base.Logger(__name__)
_PIPELINE_COMPLETE_STATES = set(
[
gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED,
gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED,
gca_pipeline_state.PipelineState.PIPELINE_STATE_CANCELLED,
gca_pipeline_state.PipelineState.PIPELINE_STATE_PAUSED,
]
)
# _block_until_complete wait times
_JOB_WAIT_TIME = 5 # start at five seconds
_LOG_WAIT_TIME = 5
_MAX_WAIT_TIME = 60 * 5 # 5 minute wait
_WAIT_TIME_MULTIPLIER = 2 # scale wait by 2 every iteration
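# Note: _block_until_complete polls the pipeline state every _JOB_WAIT_TIME
# seconds, while the interval between state log messages starts at
# _LOG_WAIT_TIME and doubles each time (capped at _MAX_WAIT_TIME),
# i.e. roughly 5s, 10s, 20s, ..., 300s.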
class _TrainingJob(base.VertexAiStatefulResource):
client_class = utils.PipelineClientWithOverride
_resource_noun = "trainingPipelines"
_getter_method = "get_training_pipeline"
_list_method = "list_training_pipelines"
_delete_method = "delete_training_pipeline"
_parse_resource_name_method = "parse_training_pipeline_path"
_format_resource_name_method = "training_pipeline_path"
# Required by the done() method
_valid_done_states = _PIPELINE_COMPLETE_STATES
def __init__(
self,
display_name: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
labels: Optional[Dict[str, str]] = None,
training_encryption_spec_key_name: Optional[str] = None,
model_encryption_spec_key_name: Optional[str] = None,
):
"""Constructs a Training Job.
Args:
display_name (str):
Optional. The user-defined name of this TrainingPipeline.
project (str):
Optional project to retrieve model from. If not set, project set in
aiplatform.init will be used.
location (str):
Optional location to retrieve model from. If not set, location set in
aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional credentials to use to retrieve the model.
labels (Dict[str, str]):
Optional. The labels with user-defined metadata to
organize TrainingPipelines.
Label keys and values can be no longer than 64
characters (Unicode codepoints), can only
contain lowercase letters, numeric characters,
underscores and dashes. International characters
are allowed.
See https://goo.gl/xmQnxf for more information
and examples of labels.
training_encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the training pipeline. Has the
form:
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
The key needs to be in the same region as where the compute
resource is created.
If set, this TrainingPipeline will be secured by this key.
Note: Model trained by this TrainingPipeline is also secured
by this key if ``model_to_upload`` is not set separately.
Overrides encryption_spec_key_name set in aiplatform.init.
model_encryption_spec_key_name (Optional[str]):
Optional. The Cloud KMS resource identifier of the customer
managed encryption key used to protect the model. Has the
form:
``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
The key needs to be in the same region as where the compute
resource is created.
If set, the trained Model will be secured by this key.
Overrides encryption_spec_key_name set in aiplatform.init.
"""
if not display_name:
display_name = self.__class__._generate_display_name()
utils.validate_display_name(display_name)
if labels:
utils.validate_labels(labels)
super().__init__(project=project, location=location, credentials=credentials)
self._display_name = display_name
self._labels = labels
self._training_encryption_spec = initializer.global_config.get_encryption_spec(
encryption_spec_key_name=training_encryption_spec_key_name
)
self._model_encryption_spec = initializer.global_config.get_encryption_spec(
encryption_spec_key_name=model_encryption_spec_key_name
)
self._gca_resource = None
@property
@classmethod
@abc.abstractmethod
def _supported_training_schemas(cls) -> Tuple[str]:
"""List of supported schemas for this training job."""
pass
@property
def start_time(self) -> Optional[datetime.datetime]:
"""Time when the TrainingJob entered the `PIPELINE_STATE_RUNNING` for
the first time."""
self._sync_gca_resource()
return getattr(self._gca_resource, "start_time")
@property
def end_time(self) -> Optional[datetime.datetime]:
"""Time when the TrainingJob resource entered the `PIPELINE_STATE_SUCCEEDED`,
        `PIPELINE_STATE_FAILED`, or `PIPELINE_STATE_CANCELLED` state."""
self._sync_gca_resource()
return getattr(self._gca_resource, "end_time")
@property
def error(self) -> Optional[status_pb2.Status]:
"""Detailed error info for this TrainingJob resource. Only populated when
the TrainingJob's state is `PIPELINE_STATE_FAILED` or
`PIPELINE_STATE_CANCELLED`."""
self._sync_gca_resource()
return getattr(self._gca_resource, "error")
@classmethod
def get(
cls,
resource_name: str,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "_TrainingJob":
"""Get Training Job for the given resource_name.
Args:
resource_name (str):
Required. A fully-qualified resource name or ID.
project (str):
Optional project to retrieve training job from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional location to retrieve training job from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
                Custom credentials to use to retrieve this training job. Overrides
credentials set in aiplatform.init.
Raises:
ValueError: If the retrieved training job's training task definition
doesn't match the custom training task definition.
Returns:
A Vertex AI Training Job
"""
# Create job with dummy parameters
        # These parameters won't be used, as the user cannot run the job again.
# If they try, an exception will be raised.
self = cls._empty_constructor(
project=project,
location=location,
credentials=credentials,
resource_name=resource_name,
)
self._gca_resource = self._get_gca_resource(resource_name=resource_name)
if (
self._gca_resource.training_task_definition
not in cls._supported_training_schemas
):
raise ValueError(
f"The retrieved job's training task definition "
f"is {self._gca_resource.training_task_definition}, "
f"which is not compatible with {cls.__name__}."
)
return self
@classmethod
def _get_and_return_subclass(
cls,
resource_name: str,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> "_TrainingJob":
"""Retrieve Training Job subclass for the given resource_name without
knowing the training_task_definition.
Example usage:
```
aiplatform.training_jobs._TrainingJob._get_and_return_subclass(
'projects/.../locations/.../trainingPipelines/12345'
)
# Returns: <google.cloud.aiplatform.training_jobs.AutoMLImageTrainingJob>
```
Args:
resource_name (str):
Required. A fully-qualified resource name or ID.
project (str):
                Optional project to retrieve the training job from. If not set, project
                set in aiplatform.init will be used.
            location (str):
                Optional location to retrieve the training job from. If not set, location
                set in aiplatform.init will be used.
            credentials (auth_credentials.Credentials):
                Optional. Custom credentials to use to retrieve this training job. Overrides
credentials set in aiplatform.init.
Returns:
A Vertex AI Training Job
"""
# Retrieve training pipeline resource before class construction
client = cls._instantiate_client(location=location, credentials=credentials)
gca_training_pipeline = getattr(client, cls._getter_method)(name=resource_name)
schema_uri = gca_training_pipeline.training_task_definition
# Collect all AutoML training job classes and CustomTrainingJob
class_list = [
c for c in cls.__subclasses__() if c.__name__.startswith("AutoML")
] + [CustomTrainingJob]
# Identify correct training job subclass, construct and return object
for c in class_list:
if schema_uri in c._supported_training_schemas:
return c._empty_constructor(
project=project,
location=location,
credentials=credentials,
resource_name=resource_name,
)
@property
@abc.abstractmethod
def _model_upload_fail_string(self) -> str:
"""Helper property for model upload failure."""
pass
@abc.abstractmethod
def run(self) -> Optional[models.Model]:
"""Runs the training job.
Should call _run_job internally
"""
pass
@staticmethod
def _create_input_data_config(
dataset: Optional[datasets._Dataset] = None,
annotation_schema_uri: Optional[str] = None,
training_fraction_split: Optional[float] = None,
validation_fraction_split: Optional[float] = None,
test_fraction_split: Optional[float] = None,
training_filter_split: Optional[str] = None,
validation_filter_split: Optional[str] = None,
test_filter_split: Optional[str] = None,
predefined_split_column_name: Optional[str] = None,
timestamp_split_column_name: Optional[str] = None,
gcs_destination_uri_prefix: Optional[str] = None,
bigquery_destination: Optional[str] = None,
) -> Optional[gca_training_pipeline.InputDataConfig]:
"""Constructs a input data config to pass to the training pipeline.
Args:
dataset (datasets._Dataset):
The dataset within the same Project from which data will be used to train the Model. The
                Dataset must use a schema compatible with the Model being trained,
and what is compatible should be described in the used
TrainingPipeline's [training_task_definition]
[google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
For tabular Datasets, all their data is exported to
training, to pick and choose from.
annotation_schema_uri (str):
                Google Cloud Storage URI that points to a YAML file describing the
                annotation schema. The schema is defined as an OpenAPI 3.0.2
                [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object). The schema files
that can be used here are found in
gs://google-cloud-aiplatform/schema/dataset/annotation/,
note that the chosen schema must be consistent with
``metadata``
of the Dataset specified by
``dataset_id``.
Only Annotations that both match this schema and belong to
DataItems not ignored by the split method are used in
respectively training, validation or test role, depending on
the role of the DataItem they are on.
When used in conjunction with
``annotations_filter``,
the Annotations used for training are filtered by both
``annotations_filter``
and
``annotation_schema_uri``.
training_fraction_split (float):
Optional. The fraction of the input data that is to be used to train
the Model. This is ignored if Dataset is not provided.
validation_fraction_split (float):
Optional. The fraction of the input data that is to be used to validate
the Model. This is ignored if Dataset is not provided.
test_fraction_split (float):
Optional. The fraction of the input data that is to be used to evaluate
the Model. This is ignored if Dataset is not provided.
training_filter_split (str):
Optional. A filter on DataItems of the Dataset. DataItems that match
this filter are used to train the Model. A filter with same syntax
as the one used in DatasetService.ListDataItems may be used. If a
single DataItem is matched by more than one of the FilterSplit filters,
then it is assigned to the first set that applies to it in the training,
validation, test order. This is ignored if Dataset is not provided.
validation_filter_split (str):
Optional. A filter on DataItems of the Dataset. DataItems that match
this filter are used to validate the Model. A filter with same syntax
as the one used in DatasetService.ListDataItems may be used. If a
single DataItem is matched by more than one of the FilterSplit filters,
then it is assigned to the first set that applies to it in the training,
validation, test order. This is ignored if Dataset is not provided.
test_filter_split (str):
Optional. A filter on DataItems of the Dataset. DataItems that match
this filter are used to test the Model. A filter with same syntax
as the one used in DatasetService.ListDataItems may be used. If a
single DataItem is matched by more than one of the FilterSplit filters,
then it is assigned to the first set that applies to it in the training,
validation, test order. This is ignored if Dataset is not provided.
predefined_split_column_name (str):
Optional. The key is a name of one of the Dataset's data
columns. The value of the key (either the label's value or
value in the column) must be one of {``training``,
``validation``, ``test``}, and it defines to which set the
given piece of data is assigned. If for a piece of data the
key is not present or has an invalid value, that piece is
ignored by the pipeline.
Supported only for tabular and time series Datasets.
timestamp_split_column_name (str):
                Optional. The key is the name of one of the Dataset's data
                columns. The values of the key (i.e., the values in the column)
                must be in RFC 3339 `date-time` format, where
`time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
piece of data the key is not present or has an invalid value,
that piece is ignored by the pipeline.
Supported only for tabular and time series Datasets.
This parameter must be used with training_fraction_split,
validation_fraction_split, and test_fraction_split.
gcs_destination_uri_prefix (str):
Optional. The Google Cloud Storage location.
The Vertex AI environment variables representing Google
Cloud Storage data URIs will always be represented in the
Google Cloud Storage wildcard format to support sharded
data.
- AIP_DATA_FORMAT = "jsonl".
- AIP_TRAINING_DATA_URI = "gcs_destination/training-*"
- AIP_VALIDATION_DATA_URI = "gcs_destination/validation-*"
- AIP_TEST_DATA_URI = "gcs_destination/test-*".
bigquery_destination (str):
The BigQuery project location where the training data is to
be written to. In the given project a new dataset is created
with name
``dataset_<dataset-id>_<annotation-type>_<timestamp-of-training-call>``
where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
training input data will be written into that dataset. In
the dataset three tables will be created, ``training``,
``validation`` and ``test``.
- AIP_DATA_FORMAT = "bigquery".
- AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training"
- AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
- AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
Raises:
ValueError: When more than 1 type of split configuration is passed or when
the split configuration passed is incompatible with the dataset schema.
"""
input_data_config = None
if dataset:
# Initialize all possible splits
filter_split = None
predefined_split = None
timestamp_split = None
fraction_split = None
# Create filter split
if any(
[
training_filter_split is not None,
validation_filter_split is not None,
test_filter_split is not None,
]
):
if all(
[
training_filter_split is not None,
validation_filter_split is not None,
test_filter_split is not None,
]
):
filter_split = gca_training_pipeline.FilterSplit(
training_filter=training_filter_split,
validation_filter=validation_filter_split,
test_filter=test_filter_split,
)
else:
raise ValueError(
"All filter splits must be passed together or not at all"
)
# Create predefined split
if predefined_split_column_name:
predefined_split = gca_training_pipeline.PredefinedSplit(
key=predefined_split_column_name
)
# Create timestamp split or fraction split
if timestamp_split_column_name:
timestamp_split = gca_training_pipeline.TimestampSplit(
training_fraction=training_fraction_split,
validation_fraction=validation_fraction_split,
test_fraction=test_fraction_split,
key=timestamp_split_column_name,
)
elif any(
[
training_fraction_split is not None,
validation_fraction_split is not None,
test_fraction_split is not None,
]
):
fraction_split = gca_training_pipeline.FractionSplit(
training_fraction=training_fraction_split,
validation_fraction=validation_fraction_split,
test_fraction=test_fraction_split,
)
splits = [
split
for split in [
filter_split,
predefined_split,
                    timestamp_split,
fraction_split,
]
if split is not None
]
# Fallback to fraction split if nothing else is specified
if len(splits) == 0:
_LOGGER.info(
"No dataset split provided. The service will use a default split."
)
elif len(splits) > 1:
raise ValueError(
"""Can only specify one of:
1. training_filter_split, validation_filter_split, test_filter_split
2. predefined_split_column_name
3. timestamp_split_column_name, training_fraction_split, validation_fraction_split, test_fraction_split
4. training_fraction_split, validation_fraction_split, test_fraction_split"""
)
# create GCS destination
gcs_destination = None
if gcs_destination_uri_prefix:
gcs_destination = gca_io.GcsDestination(
output_uri_prefix=gcs_destination_uri_prefix
)
# TODO(b/177416223) validate managed BQ dataset is passed in
bigquery_destination_proto = None
if bigquery_destination:
bigquery_destination_proto = gca_io.BigQueryDestination(
output_uri=bigquery_destination
)
# create input data config
input_data_config = gca_training_pipeline.InputDataConfig(
fraction_split=fraction_split,
filter_split=filter_split,
predefined_split=predefined_split,
timestamp_split=timestamp_split,
dataset_id=dataset.name,
annotation_schema_uri=annotation_schema_uri,
gcs_destination=gcs_destination,
bigquery_destination=bigquery_destination_proto,
)
return input_data_config
def _run_job(
self,
training_task_definition: str,
training_task_inputs: Union[dict, proto.Message],
dataset: Optional[datasets._Dataset],
training_fraction_split: Optional[float] = None,
validation_fraction_split: Optional[float] = None,
test_fraction_split: Optional[float] = None,
training_filter_split: Optional[str] = None,
validation_filter_split: Optional[str] = None,
test_filter_split: Optional[str] = None,
predefined_split_column_name: Optional[str] = None,
timestamp_split_column_name: Optional[str] = None,
annotation_schema_uri: Optional[str] = None,
model: Optional[gca_model.Model] = None,
gcs_destination_uri_prefix: Optional[str] = None,
bigquery_destination: Optional[str] = None,
create_request_timeout: Optional[float] = None,
) -> Optional[models.Model]:
"""Runs the training job.
Args:
training_task_definition (str):
Required. A Google Cloud Storage path to the
YAML file that defines the training task which
is responsible for producing the model artifact,
and may also include additional auxiliary work.
The definition files that can be used here are
                found in
                gs://google-cloud-aiplatform/schema/trainingjob/definition/. Note:
The URI given on output will be immutable and
probably different, including the URI scheme,
than the one given on input. The output URI will
                point to a location where the user only has
                read access.
training_task_inputs (Union[dict, proto.Message]):
Required. The training task's input that corresponds to the training_task_definition parameter.
dataset (datasets._Dataset):
The dataset within the same Project from which data will be used to train the Model. The
                Dataset must use a schema compatible with the Model being trained,
and what is compatible should be described in the used
TrainingPipeline's [training_task_definition]
[google.cloud.aiplatform.v1beta1.TrainingPipeline.training_task_definition].
For tabular Datasets, all their data is exported to
training, to pick and choose from.
annotation_schema_uri (str):
                Google Cloud Storage URI that points to a YAML file describing the
                annotation schema. The schema is defined as an OpenAPI 3.0.2
                [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object). The schema files
that can be used here are found in
gs://google-cloud-aiplatform/schema/dataset/annotation/,
note that the chosen schema must be consistent with
``metadata``
of the Dataset specified by
``dataset_id``.
Only Annotations that both match this schema and belong to
DataItems not ignored by the split method are used in
respectively training, validation or test role, depending on
the role of the DataItem they are on.
When used in conjunction with
``annotations_filter``,
the Annotations used for training are filtered by both
``annotations_filter``
and
``annotation_schema_uri``.
training_fraction_split (float):
Optional. The fraction of the input data that is to be used to train
the Model. This is ignored if Dataset is not provided.
validation_fraction_split (float):
Optional. The fraction of the input data that is to be used to validate
the Model. This is ignored if Dataset is not provided.
test_fraction_split (float):
Optional. The fraction of the input data that is to be used to evaluate
the Model. This is ignored if Dataset is not provided.
training_filter_split (str):
Optional. A filter on DataItems of the Dataset. DataItems that match
this filter are used to train the Model. A filter with same syntax
as the one used in DatasetService.ListDataItems may be used. If a
single DataItem is matched by more than one of the FilterSplit filters,
then it is assigned to the first set that applies to it in the training,
validation, test order. This is ignored if Dataset is not provided.
validation_filter_split (str):
Optional. A filter on DataItems of the Dataset. DataItems that match
this filter are used to validate the Model. A filter with same syntax
as the one used in DatasetService.ListDataItems may be used. If a
single DataItem is matched by more than one of the FilterSplit filters,
then it is assigned to the first set that applies to it in the training,
validation, test order. This is ignored if Dataset is not provided.
test_filter_split (str):
Optional. A filter on DataItems of the Dataset. DataItems that match
this filter are used to test the Model. A filter with same syntax
as the one used in DatasetService.ListDataItems may be used. If a
single DataItem is matched by more than one of the FilterSplit filters,
then it is assigned to the first set that applies to it in the training,
validation, test order. This is ignored if Dataset is not provided.
predefined_split_column_name (str):
Optional. The key is a name of one of the Dataset's data
columns. The value of the key (either the label's value or
value in the column) must be one of {``training``,
``validation``, ``test``}, and it defines to which set the
given piece of data is assigned. If for a piece of data the
key is not present or has an invalid value, that piece is
ignored by the pipeline.
Supported only for tabular and time series Datasets.
timestamp_split_column_name (str):
                Optional. The key is the name of one of the Dataset's data
                columns. The values of the key (i.e., the values in the column)
                must be in RFC 3339 `date-time` format, where
`time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a
piece of data the key is not present or has an invalid value,
that piece is ignored by the pipeline.
Supported only for tabular and time series Datasets.
This parameter must be used with training_fraction_split,
validation_fraction_split, and test_fraction_split.
model (~.model.Model):
Optional. Describes the Model that may be uploaded (via
                [ModelService.UploadModel][]) by this TrainingPipeline. The
TrainingPipeline's
``training_task_definition``
should make clear whether this Model description should be
populated, and if there are any special requirements
regarding how it should be filled. If nothing is mentioned
in the
``training_task_definition``,
then it should be assumed that this field should not be
filled and the training task either uploads the Model
without a need of this information, or that training task
does not support uploading a Model as part of the pipeline.
When the Pipeline's state becomes
``PIPELINE_STATE_SUCCEEDED`` and the trained Model had been
uploaded into Vertex AI, then the model_to_upload's
resource ``name``
is populated. The Model is always uploaded into the Project
and Location in which this pipeline is.
gcs_destination_uri_prefix (str):
Optional. The Google Cloud Storage location.
The Vertex AI environment variables representing Google
Cloud Storage data URIs will always be represented in the
Google Cloud Storage wildcard format to support sharded
data.
- AIP_DATA_FORMAT = "jsonl".
- AIP_TRAINING_DATA_URI = "gcs_destination/training-*"
- AIP_VALIDATION_DATA_URI = "gcs_destination/validation-*"
- AIP_TEST_DATA_URI = "gcs_destination/test-*".
bigquery_destination (str):
The BigQuery project location where the training data is to
be written to. In the given project a new dataset is created
with name
``dataset_<dataset-id>_<annotation-type>_<timestamp-of-training-call>``
where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
training input data will be written into that dataset. In
the dataset three tables will be created, ``training``,
``validation`` and ``test``.
- AIP_DATA_FORMAT = "bigquery".
- AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training"
- AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation"
- AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test"
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
"""
input_data_config = self._create_input_data_config(
dataset=dataset,
annotation_schema_uri=annotation_schema_uri,
training_fraction_split=training_fraction_split,
validation_fraction_split=validation_fraction_split,
test_fraction_split=test_fraction_split,
training_filter_split=training_filter_split,
validation_filter_split=validation_filter_split,
test_filter_split=test_filter_split,
predefined_split_column_name=predefined_split_column_name,
timestamp_split_column_name=timestamp_split_column_name,
gcs_destination_uri_prefix=gcs_destination_uri_prefix,
bigquery_destination=bigquery_destination,
)
# create training pipeline
training_pipeline = gca_training_pipeline.TrainingPipeline(
display_name=self._display_name,
training_task_definition=training_task_definition,
training_task_inputs=training_task_inputs,
model_to_upload=model,
input_data_config=input_data_config,
labels=self._labels,
encryption_spec=self._training_encryption_spec,
)
training_pipeline = self.api_client.create_training_pipeline(
parent=initializer.global_config.common_location_path(
self.project, self.location
),
training_pipeline=training_pipeline,
timeout=create_request_timeout,
)
self._gca_resource = training_pipeline
_LOGGER.info("View Training:\n%s" % self._dashboard_uri())
model = self._get_model()
if model is None:
_LOGGER.warning(
"Training did not produce a Managed Model returning None. "
+ self._model_upload_fail_string
)
return model
def _is_waiting_to_run(self) -> bool:
"""Returns True if the Job is pending on upstream tasks False
otherwise."""
self._raise_future_exception()
if self._latest_future:
_LOGGER.info(
"Training Job is waiting for upstream SDK tasks to complete before"
" launching."
)
return True
return False
@property
def state(self) -> Optional[gca_pipeline_state.PipelineState]:
"""Current training state."""
if self._assert_has_run():
return
self._sync_gca_resource()
return self._gca_resource.state
def get_model(self, sync=True) -> models.Model:
"""Vertex AI Model produced by this training, if one was produced.
Args:
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
Returns:
model: Vertex AI Model produced by this training
Raises:
RuntimeError: If training failed or if a model was not produced by this training.
"""
self._assert_has_run()
if not self._gca_resource.model_to_upload:
raise RuntimeError(self._model_upload_fail_string)
return self._force_get_model(sync=sync)
@base.optional_sync()
def _force_get_model(self, sync: bool = True) -> models.Model:
"""Vertex AI Model produced by this training, if one was produced.
Args:
sync (bool):
Whether to execute this method synchronously. If False, this method
will be executed in concurrent Future and any downstream object will
be immediately returned and synced when the Future has completed.
Returns:
model: Vertex AI Model produced by this training
Raises:
RuntimeError: If training failed or if a model was not produced by this training.
"""
model = self._get_model()
if model is None:
raise RuntimeError(self._model_upload_fail_string)
return model
def _get_model(self) -> Optional[models.Model]:
"""Helper method to get and instantiate the Model to Upload.
Returns:
model: Vertex AI Model if training succeeded and produced a Vertex AI
Model. None otherwise.
Raises:
RuntimeError: If Training failed.
"""
self._block_until_complete()
if self.has_failed:
raise RuntimeError(
f"Training Pipeline {self.resource_name} failed. No model available."
)
if not self._gca_resource.model_to_upload:
return None
if self._gca_resource.model_to_upload.name:
return models.Model(model_name=self._gca_resource.model_to_upload.name)
def _wait_callback(self):
"""Callback performs custom logging during _block_until_complete. Override in subclass."""
pass
def _block_until_complete(self):
"""Helper method to block and check on job until complete."""
log_wait = _LOG_WAIT_TIME
previous_time = time.time()
while self.state not in _PIPELINE_COMPLETE_STATES:
current_time = time.time()
if current_time - previous_time >= log_wait:
_LOGGER.info(
"%s %s current state:\n%s"
% (
self.__class__.__name__,
self._gca_resource.name,
self._gca_resource.state,
)
)
log_wait = min(log_wait * _WAIT_TIME_MULTIPLIER, _MAX_WAIT_TIME)
previous_time = current_time
self._wait_callback()
time.sleep(_JOB_WAIT_TIME)
self._raise_failure()
_LOGGER.log_action_completed_against_resource("run", "completed", self)
if self._gca_resource.model_to_upload and not self.has_failed:
_LOGGER.info(
"Model available at %s" % self._gca_resource.model_to_upload.name
)
def _raise_failure(self):
"""Helper method to raise failure if TrainingPipeline fails.
Raises:
RuntimeError: If training failed.
"""
if self._gca_resource.error.code != code_pb2.OK:
raise RuntimeError("Training failed with:\n%s" % self._gca_resource.error)
@property
def has_failed(self) -> bool:
"""Returns True if training has failed.
False otherwise.
"""
self._assert_has_run()
return self.state == gca_pipeline_state.PipelineState.PIPELINE_STATE_FAILED
def _dashboard_uri(self) -> str:
"""Helper method to compose the dashboard uri where training can be
viewed."""
fields = self._parse_resource_name(self.resource_name)
url = f"https://console.cloud.google.com/ai/platform/locations/{fields['location']}/training/{fields['training_pipeline']}?project={fields['project']}"
return url
@property
def _has_run(self) -> bool:
"""Helper property to check if this training job has been run."""
return self._gca_resource is not None
def _assert_has_run(self) -> bool:
"""Helper method to assert that this training has run."""
if not self._has_run:
if self._is_waiting_to_run():
return True
raise RuntimeError(
"TrainingPipeline has not been launched. You must run this"
" TrainingPipeline using TrainingPipeline.run. "
)
return False
@classmethod
def list(
cls,
filter: Optional[str] = None,
order_by: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> List["base.VertexAiResourceNoun"]:
"""List all instances of this TrainingJob resource.
Example Usage:
aiplatform.CustomTrainingJob.list(
filter='display_name="experiment_a27"',
order_by='create_time desc'
)
Args:
filter (str):
Optional. An expression for filtering the results of the request.
For field names both snake_case and camelCase are supported.
order_by (str):
Optional. A comma-separated list of fields to order by, sorted in
ascending order. Use "desc" after a field name for descending.
Supported fields: `display_name`, `create_time`, `update_time`
project (str):
Optional. Project to retrieve list from. If not set, project
set in aiplatform.init will be used.
location (str):
Optional. Location to retrieve list from. If not set, location
set in aiplatform.init will be used.
credentials (auth_credentials.Credentials):
Optional. Custom credentials to use to retrieve list. Overrides
credentials set in aiplatform.init.
Returns:
List[VertexAiResourceNoun] - A list of TrainingJob resource objects
"""
training_job_subclass_filter = (
lambda gapic_obj: gapic_obj.training_task_definition
in cls._supported_training_schemas
)
return cls._list_with_local_order(
cls_filter=training_job_subclass_filter,
filter=filter,
order_by=order_by,
project=project,
location=location,
credentials=credentials,
)