# Public API of this module: the model class, its weight enums, and the
# model-builder functions for the tiny/small/base Swin variants.
__all__ = [
    "SwinTransformer",
    "Swin_T_Weights",
    "Swin_S_Weights",
    "Swin_B_Weights",
    "swin_t",
    "swin_s",
    "swin_b",
]


@@ -408,9 +412,9 @@ def _swin_transformer(
408412
409413class Swin_T_Weights (WeightsEnum ):
410414 IMAGENET1K_V1 = Weights (
411- url = "https://download.pytorch.org/models/swin_t-81486767 .pth" ,
415+ url = "https://download.pytorch.org/models/swin_t-4c37bd06 .pth" ,
412416 transforms = partial (
413- ImageClassification , crop_size = 224 , resize_size = 238 , interpolation = InterpolationMode .BICUBIC
417+ ImageClassification , crop_size = 224 , resize_size = 232 , interpolation = InterpolationMode .BICUBIC
414418 ),
415419 meta = {
416420 ** _COMMON_META ,
@@ -419,11 +423,57 @@ class Swin_T_Weights(WeightsEnum):
419423 "recipe" : "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer" ,
420424 "_metrics" : {
421425 "ImageNet-1K" : {
422- "acc@1" : 81.358 ,
423- "acc@5" : 95.526 ,
426+ "acc@1" : 81.474 ,
427+ "acc@5" : 95.776 ,
428+ }
429+ },
430+ "_docs" : """These weights reproduce closely the results of the paper using a similar training recipe.""" ,
431+ },
432+ )
433+ DEFAULT = IMAGENET1K_V1
434+
435+
class Swin_S_Weights(WeightsEnum):
    # Pretrained ImageNet-1K weights for the Swin-S (small) variant.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/swin_s-30134662.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,
            "num_params": 49606258,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.196,
                    "acc@5": 96.360,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1

458+
class Swin_B_Weights(WeightsEnum):
    # Pretrained ImageNet-1K weights for the Swin-B (base) variant.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/swin_b-1f1feb5c.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META,
            "num_params": 87768224,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.582,
                    "acc@5": 96.640,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
@@ -463,3 +513,75 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, *
463513 progress = progress ,
464514 ** kwargs ,
465515 )
516+
517+
def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
    """
    Constructs a swin_small architecture from
    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/pdf/2103.14030>`_.

    Args:
        weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Swin_S_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Swin_S_Weights
        :members:
    """
    # Validate the weights argument (raises on an incompatible enum value).
    weights = Swin_S_Weights.verify(weights)

    # Swin-S hyper-parameters: same widths/heads as Swin-T but with a deeper
    # third stage (18 blocks) and a higher stochastic-depth rate.
    return _swin_transformer(
        patch_size=4,
        embed_dim=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        stochastic_depth_prob=0.3,
        weights=weights,
        progress=progress,
        **kwargs,
    )
552+
553+
def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
    """
    Constructs a swin_base architecture from
    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/pdf/2103.14030>`_.

    Args:
        weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Swin_B_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Swin_B_Weights
        :members:
    """
    # Validate the weights argument (raises on an incompatible enum value).
    weights = Swin_B_Weights.verify(weights)

    # Swin-B hyper-parameters: wider embedding (128) and more heads than
    # Swin-S, with the highest stochastic-depth rate of the three variants.
    return _swin_transformer(
        patch_size=4,
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        window_size=7,
        stochastic_depth_prob=0.5,
        weights=weights,
        progress=progress,
        **kwargs,
    )
# (scraping artifact: GitHub page footer "0 commit comments" — not part of the source file)