Skip to content

Commit

Permalink
Merge branch 'master' into carlosg/zarr
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgjs committed Jul 18, 2023
2 parents b140734 + e9cd804 commit c53d4c4
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 31 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ dependencies = [
"diskcache>=5.5",
"fsspec>=2023.4.0",
"s3fs>=2023.4.0",
"pydantic>=1.10.0",
"pydantic>=2.0.0",
"PyYAML>=6.0",
"pydantic-yaml>=0.11",
"pydantic-yaml>=1.0",
"zarr>=2.14.2",
]

Expand Down
53 changes: 30 additions & 23 deletions src/noisepy/seis/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import List, Optional

import numpy as np
import obspy
from pydantic import Field, root_validator
from pydantic_yaml import YamlModel
from pydantic import BaseModel, ConfigDict, Field
from pydantic.functional_validators import model_validator
from pydantic_yaml import parse_yaml_raw_as, to_yaml_str

INVALID_COORD = -sys.float_info.max

Expand Down Expand Up @@ -99,22 +101,24 @@ class StackMethod(Enum):
ALL = "all"


class ConfigParameters(YamlModel):
class ConfigParameters(BaseModel):
model_config = ConfigDict(validate_default=True)

client_url_key: str = "SCEDC"
start_date: datetime = Field(default=datetime(2019, 1, 1))
end_date: datetime = Field(default=datetime(2019, 1, 2))
samp_freq: float = Field(default=20) # TODO: change this samp_freq for the obspy "sampling_rate"
samp_freq: float = Field(default=20.0) # TODO: change this samp_freq for the obspy "sampling_rate"
cc_len: float = Field(default=1800.0, description="basic unit of data length for fft (sec)")
# download params.
# Targeted region/station information: only needed when down_list is False
lamin: float = Field(default=31, description="Download: minimum latitude")
lamax: float = Field(default=36, description="Download: maximum latitude")
lomin: float = Field(default=-122, description="Download: minimum longitude")
lomax: float = Field(default=-115, description="Download: maximum longitude")
down_list = Field(default=False, description="download stations from a pre-compiled list or not")
net_list = Field(default=["CI"], description="network list")
stations = Field(default=["*"], description="station list")
channels = Field(default=["BHE", "BHN", "BHZ"], description="channel list")
lamin: float = Field(default=31.0, description="Download: minimum latitude")
lamax: float = Field(default=36.0, description="Download: maximum latitude")
lomin: float = Field(default=-122.0, description="Download: minimum longitude")
lomax: float = Field(default=-115.0, description="Download: maximum longitude")
down_list: bool = Field(default=False, description="download stations from a pre-compiled list or not")
net_list: List[str] = Field(default=["CI"], description="network list")
stations: List[str] = Field(default=["*"], description="station list")
channels: List[str] = Field(default=["BHE", "BHN", "BHZ"], description="channel list")
# pre-processing parameters
step: float = Field(default=450.0, description="overlapping between each cc_len (sec)")
freqmin: float = Field(default=0.05)
Expand Down Expand Up @@ -158,7 +162,7 @@ class ConfigParameters(YamlModel):
)
rm_resp: str = Field(default="no", description="select 'no' to not remove response and use 'inv','spectrum',")
rm_resp_out: str = Field(default="VEL", description="output location from response removal")
respdir: str = Field(default=None, description="response directory")
respdir: Optional[str] = Field(default=None, description="response directory")
# some control parameters
acorr_only: bool = Field(default=False, description="only perform auto-correlation")
xcorr_only: bool = Field(default=True, description="only perform cross-correlation or not")
Expand All @@ -168,20 +172,18 @@ class ConfigParameters(YamlModel):
# new rotation para
rotation: bool = Field(default=True, description="rotation from E-N-Z to R-T-Z")
correction: bool = Field(default=False, description="angle correction due to mis-orientation")
correction_csv: str = Field(default=None, description="Path to e.g. meso_angles.csv")
correction_csv: Optional[str] = Field(default=None, description="Path to e.g. meso_angles.csv")
# 'RESP', or 'polozeros' to remove response

class Config:
use_enum_values = True

@property
def dt(self) -> float:
return 1.0 / self.samp_freq

@root_validator
def validate(cld, values) -> dict:
assert values.get("substack_len") % values.get("cc_len") == 0
return values
@model_validator(mode="after")
def validate(cls, m: ConfigParameters) -> ConfigParameters:
if m.substack_len % m.cc_len != 0:
raise ValueError(f"substack_len ({m.substack_len}) must be a multiple of cc_len ({m.cc_len})")
return m

# TODO: Remove once all uses of ConfigParameters have been converted to use strongly typed access
def __getitem__(self, key):
Expand All @@ -191,11 +193,16 @@ def __getitem__(self, key):
return self.__dict__[key]

def save_yaml(self, filename: str):
# yaml_str = yaml.dump(self.__dict__)
yaml_str = self.yaml()
yaml_str = to_yaml_str(self)
with open(filename, "w") as f:
f.write(yaml_str)

def load_yaml(filename: str) -> ConfigParameters:
with open(filename, "r") as f:
yaml_str = f.read()
config = parse_yaml_raw_as(ConfigParameters, yaml_str)
return config


@dataclass
class Channel:
Expand Down
10 changes: 5 additions & 5 deletions src/noisepy/seis/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def parse_bool(bstr: str) -> bool:


def get_arg_type(arg_type):
if arg_type == list:
if arg_type == List[str]:
return list_str
if arg_type == datetime:
return dateutil.parser.isoparse
Expand All @@ -72,14 +72,14 @@ def get_arg_type(arg_type):

def add_model(parser: argparse.ArgumentParser, model: ConfigParameters):
# Add config model to the parser
fields = model.__fields__
fields = model.model_fields
for name, field in fields.items():
parser.add_argument(
f"--{name}",
dest=name,
type=get_arg_type(field.type_),
type=get_arg_type(field.annotation),
default=argparse.SUPPRESS,
help=field.field_info.description,
help=field.description,
)


Expand All @@ -97,7 +97,7 @@ def initialize_params(args, data_dir: str) -> ConfigParameters:
config_path = fs_join(data_dir, CONFIG_FILE)
if config_path is not None and os.path.isfile(config_path):
logger.info(f"Loading parameters from {config_path}")
params = ConfigParameters.parse_file(config_path)
params = ConfigParameters.load_yaml(config_path)
else:
logger.warning(f"Config file {config_path if config_path else ''} not found. Using default parameters.")
params = ConfigParameters()
Expand Down
3 changes: 2 additions & 1 deletion tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ def test_orientation(ch, orien):
def test_config_yaml(tmp_path: Path):
file = str(tmp_path.joinpath("config.yaml"))
c1 = ConfigParameters()
ConfigParameters.validate(c1)
# change a couple of properties
c1.step = 800
c1.stack_method = StackMethod.ROBUST
c1.save_yaml(file)
c2 = ConfigParameters.parse_file(file)
c2 = ConfigParameters.load_yaml(file)
assert c1 == c2


Expand Down

0 comments on commit c53d4c4

Please sign in to comment.