-
Notifications
You must be signed in to change notification settings - Fork 4
Add Models and Project for RAT API #6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
c277ed0
Adds "models.py" with initial draft of pydantic models for API classes
DrPaulSharp 98851d5
Adds validators for pydantic models
DrPaulSharp 5ffb29b
Fixes parameter names in background model
DrPaulSharp 75be6f6
Adds new model ProtectedParameter
DrPaulSharp 4e66327
Adds "project.py" with initial draft of the high level "Project" class
DrPaulSharp b1f4574
Adds "model_post_init" routine for the "project" model
DrPaulSharp a05b35a
Adds "__repr__" routine for the "project" model
DrPaulSharp 2ede5a1
Adds code to work with updated ClassList
DrPaulSharp 760e686
Moves validators to cross-check project fields from "models.py" to "p…
DrPaulSharp d216c8c
Replaces annotated validators with single field validator in "project…
DrPaulSharp 5e98ba0
Add contrasts to cross-checking model validator in "project.py"
DrPaulSharp 035511c
Changes data model to accept numpy array
DrPaulSharp e39a252
Adds docs and modifies the Project class's "model_post_init" to ensur…
DrPaulSharp a5aeb48
Adds tests "test_models.py"
DrPaulSharp 19051e6
Removes unused routine "get_all_names" from "project.py"
DrPaulSharp 2ca3aa3
Adds tests "test_project.py"
DrPaulSharp 4b69089
Tidies up model and project classes and tests
DrPaulSharp 2257bbb
Adds code to fix enums for all python versions
DrPaulSharp c6ec305
Adds code to stop "test_repr" in "test_project.py" writing to console
DrPaulSharp 55e07a1
Addresses review comments
DrPaulSharp File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,166 @@ | ||
| """The models module. Contains the pydantic models used by RAT to store project parameters.""" | ||
|
|
||
| import numpy as np | ||
| from pydantic import BaseModel, Field, FieldValidationInfo, field_validator, model_validator | ||
|
|
||
| try: | ||
| from enum import StrEnum | ||
| except ImportError: | ||
| from strenum import StrEnum | ||
|
|
||
|
|
||
| def int_sequence(): | ||
| """Iterate through integers for use as model counters.""" | ||
| num = 1 | ||
| while True: | ||
| yield str(num) | ||
| num += 1 | ||
|
|
||
|
|
||
| # Create a counter for each model | ||
| background_number = int_sequence() | ||
| contrast_number = int_sequence() | ||
| custom_file_number = int_sequence() | ||
| data_number = int_sequence() | ||
| domain_contrast_number = int_sequence() | ||
| layer_number = int_sequence() | ||
| parameter_number = int_sequence() | ||
| resolution_number = int_sequence() | ||
|
|
||
|
|
||
| class Hydration(StrEnum): | ||
| None_ = 'none' | ||
| BulkIn = 'bulk in' | ||
| BulkOut = 'bulk out' | ||
| Oil = 'oil' | ||
|
|
||
|
|
||
| class Languages(StrEnum): | ||
| Python = 'python' | ||
| Matlab = 'matlab' | ||
|
|
||
|
|
||
| class Priors(StrEnum): | ||
| Uniform = 'uniform' | ||
| Gaussian = 'gaussian' | ||
| Jeffreys = 'jeffreys' | ||
|
|
||
|
|
||
| class Types(StrEnum): | ||
| Constant = 'constant' | ||
| Data = 'data' | ||
| Function = 'function' | ||
|
|
||
|
|
||
| class Background(BaseModel, validate_assignment=True, extra='forbid'): | ||
| """Defines the Backgrounds in RAT.""" | ||
| name: str = Field(default_factory=lambda: 'New Background ' + next(background_number)) | ||
| type: Types = Types.Constant | ||
| value_1: str = '' | ||
| value_2: str = '' | ||
| value_3: str = '' | ||
| value_4: str = '' | ||
| value_5: str = '' | ||
|
|
||
|
|
||
| class Contrast(BaseModel, validate_assignment=True, extra='forbid'): | ||
| """Groups together all of the components of the model.""" | ||
| name: str = Field(default_factory=lambda: 'New Contrast ' + next(contrast_number)) | ||
| data: str = '' | ||
| background: str = '' | ||
| nba: str = '' | ||
| nbs: str = '' | ||
| scalefactor: str = '' | ||
| resolution: str = '' | ||
| resample: bool = False | ||
| model: list[str] = [] # But how many strings? How to deal with this? | ||
|
|
||
|
|
||
| class CustomFile(BaseModel, validate_assignment=True, extra='forbid'): | ||
| """Defines the files containing functions to run when using custom models.""" | ||
| name: str = Field(default_factory=lambda: 'New Custom File ' + next(custom_file_number)) | ||
| filename: str = '' | ||
| language: Languages = Languages.Python | ||
| path: str = 'pwd' # Should later expand to find current file path | ||
|
|
||
|
|
||
| class Data(BaseModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True): | ||
| """Defines the dataset required for each contrast.""" | ||
| name: str = Field(default_factory=lambda: 'New Data ' + next(data_number)) | ||
| data: np.ndarray[float] = np.empty([0, 3]) | ||
| data_range: list[float] = [] | ||
| simulation_range: list[float] = [0.005, 0.7] | ||
|
|
||
| @field_validator('data') | ||
| @classmethod | ||
| def check_data_dimension(cls, data: np.ndarray[float]) -> np.ndarray[float]: | ||
| """The data must be a two-dimensional array containing at least three columns.""" | ||
| try: | ||
| data.shape[1] | ||
| except IndexError: | ||
| raise ValueError('"data" must have at least two dimensions') | ||
| else: | ||
| if data.shape[1] < 3: | ||
| raise ValueError('"data" must have at least three columns') | ||
| return data | ||
DrPaulSharp marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| @field_validator('data_range', 'simulation_range') | ||
| @classmethod | ||
| def check_list_elements(cls, limits: list[float], info: FieldValidationInfo) -> list[float]: | ||
| """The data range and simulation range must contain exactly two parameters.""" | ||
| if len(limits) != 2: | ||
| raise ValueError(f'{info.field_name} must contain exactly two values') | ||
| return limits | ||
|
|
||
| # Also need model validators for data range compared to data etc -- need more details. | ||
|
|
||
|
|
||
| class DomainContrast(BaseModel, validate_assignment=True, extra='forbid'): | ||
| """Groups together the layers required for each domain.""" | ||
| name: str = Field(default_factory=lambda: 'New Domain Contrast ' + next(domain_contrast_number)) | ||
| model: list[str] = [] | ||
|
|
||
|
|
||
| class Layer(BaseModel, validate_assignment=True, extra='forbid'): | ||
| """Combines parameters into defined layers.""" | ||
| name: str = Field(default_factory=lambda: 'New Layer ' + next(layer_number)) | ||
| thickness: str = '' | ||
| SLD: str = '' | ||
| roughness: str = '' | ||
| hydration: str = '' | ||
| hydrate_with: Hydration = Hydration.BulkOut | ||
|
|
||
|
|
||
| class Parameter(BaseModel, validate_assignment=True, extra='forbid'): | ||
| """Defines parameters needed to specify the model""" | ||
| name: str = Field(default_factory=lambda: 'New Parameter ' + next(parameter_number)) | ||
| min: float = 0.0 | ||
| value: float = 0.0 | ||
| max: float = 0.0 | ||
| fit: bool = False | ||
| prior_type: Priors = Priors.Uniform | ||
| mu: float = 0.0 | ||
| sigma: float = np.inf | ||
|
|
||
| @model_validator(mode='after') | ||
| def check_value_in_range(self) -> 'Parameter': | ||
| """The value of a parameter must lie within its defined bounds.""" | ||
| if self.value < self.min or self.value > self.max: | ||
| raise ValueError(f'value {self.value} is not within the defined range: {self.min} <= value <= {self.max}') | ||
| return self | ||
|
|
||
|
|
||
| class ProtectedParameter(Parameter, validate_assignment=True, extra='forbid'): | ||
| """A Parameter with a fixed name.""" | ||
| name: str = Field(frozen=True) | ||
|
|
||
|
|
||
| class Resolution(BaseModel, validate_assignment=True, extra='forbid'): | ||
| """Defines Resolutions in RAT.""" | ||
| name: str = Field(default_factory=lambda: 'New Resolution ' + next(resolution_number)) | ||
| type: Types = Types.Constant | ||
| value_1: str = '' | ||
| value_2: str = '' | ||
| value_3: str = '' | ||
| value_4: str = '' | ||
| value_5: str = '' | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,174 @@ | ||
| """The project module. Defines and stores all the input data required for reflectivity calculations in RAT.""" | ||
|
|
||
| import numpy as np | ||
| from pydantic import BaseModel, FieldValidationInfo, field_validator, model_validator | ||
| from typing import Any | ||
|
|
||
| from RAT.classlist import ClassList | ||
| import RAT.models | ||
|
|
||
| try: | ||
| from enum import StrEnum | ||
| except ImportError: | ||
| from strenum import StrEnum | ||
|
|
||
|
|
||
| class CalcTypes(StrEnum): | ||
| NonPolarised = 'non polarised' | ||
| Domains = 'domains' | ||
| OilWater = 'oil water' | ||
|
|
||
|
|
||
| class ModelTypes(StrEnum): | ||
| CustomLayers = 'custom layers' | ||
| CustomXY = 'custom xy' | ||
| StandardLayers = 'standard layers' | ||
|
|
||
|
|
||
| class Geometries(StrEnum): | ||
| AirSubstrate = 'air/substrate' | ||
| SubstrateLiquid = 'substrate/liquid' | ||
|
|
||
|
|
||
| # Map project fields to pydantic models | ||
| model_in_classlist = {'parameters': 'Parameter', | ||
| 'bulk_in': 'Parameter', | ||
| 'bulk_out': 'Parameter', | ||
| 'qz_shifts': 'Parameter', | ||
| 'scalefactors': 'Parameter', | ||
| 'background_parameters': 'Parameter', | ||
| 'resolution_parameters': 'Parameter', | ||
| 'backgrounds': 'Background', | ||
| 'resolutions': 'Resolution', | ||
| 'custom_files': 'CustomFile', | ||
| 'data': 'Data', | ||
| 'layers': 'Layer', | ||
| 'contrasts': 'Contrast' | ||
| } | ||
|
|
||
|
|
||
| class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True): | ||
| """Defines the input data for a reflectivity calculation in RAT. | ||
|
|
||
| This class combines the data defined in each of the pydantic models included in "models.py" into the full set of | ||
| inputs required for a reflectivity calculation. | ||
| """ | ||
| name: str = '' | ||
| calc_type: CalcTypes = CalcTypes.NonPolarised | ||
| model: ModelTypes = ModelTypes.StandardLayers | ||
| geometry: Geometries = Geometries.AirSubstrate | ||
| absorption: bool = False | ||
|
|
||
| parameters: ClassList = ClassList() | ||
|
|
||
| bulk_in: ClassList = ClassList(RAT.models.Parameter(name='SLD Air', min=0, value=0, max=0, fit=False, | ||
| prior_type=RAT.models.Priors.Uniform, mu=0, sigma=np.inf)) | ||
|
|
||
| bulk_out: ClassList = ClassList(RAT.models.Parameter(name='SLD D2O', min=6.2e-6, value=6.35e-6, max=6.35e-6, | ||
| fit=False, prior_type=RAT.models.Priors.Uniform, mu=0, | ||
| sigma=np.inf)) | ||
|
|
||
| qz_shifts: ClassList = ClassList(RAT.models.Parameter(name='Qz shift 1', min=-1e-4, value=0, max=1e-4, fit=False, | ||
| prior_type=RAT.models.Priors.Uniform, mu=0, sigma=np.inf)) | ||
|
|
||
| scalefactors: ClassList = ClassList(RAT.models.Parameter(name='Scalefactor 1', min=0.02, value=0.23, max=0.25, | ||
| fit=False, prior_type=RAT.models.Priors.Uniform, mu=0, | ||
| sigma=np.inf)) | ||
|
|
||
| background_parameters: ClassList = ClassList(RAT.models.Parameter(name='Background Param 1', min=1e-7, value=1e-6, | ||
| max=1e-5, fit=False, | ||
| prior_type=RAT.models.Priors.Uniform, mu=0, | ||
| sigma=np.inf)) | ||
|
|
||
| backgrounds: ClassList = ClassList(RAT.models.Background(name='Background 1', type=RAT.models.Types.Constant.value, | ||
| value_1='Background Param 1')) | ||
|
|
||
| resolution_parameters: ClassList = ClassList(RAT.models.Parameter(name='Resolution Param 1', min=0.01, value=0.03, | ||
| max=0.05, fit=False, | ||
| prior_type=RAT.models.Priors.Uniform, mu=0, | ||
| sigma=np.inf)) | ||
|
|
||
| resolutions: ClassList = ClassList(RAT.models.Resolution(name='Resolution 1', type=RAT.models.Types.Constant.value, | ||
| value_1='Resolution Param 1')) | ||
|
|
||
| custom_files: ClassList = ClassList() | ||
| data: ClassList = ClassList(RAT.models.Data(name='Simulation')) | ||
| layers: ClassList = ClassList() | ||
| contrasts: ClassList = ClassList() | ||
|
|
||
| @field_validator('parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters', | ||
| 'backgrounds', 'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers', | ||
| 'contrasts') | ||
| @classmethod | ||
| def check_class(cls, value: ClassList, info: FieldValidationInfo) -> ClassList: | ||
| """Each of the data fields should be a ClassList of the appropriate model.""" | ||
| model_name = model_in_classlist[info.field_name] | ||
| model = getattr(RAT.models, model_name) | ||
| assert all(isinstance(element, model) for element in value), \ | ||
| f'"{info.field_name}" ClassList contains objects other than "{model_name}"' | ||
| return value | ||
|
|
||
| def model_post_init(self, __context: Any) -> None: | ||
| """Initialises the class in the ClassLists for empty data fields, and sets protected parameters.""" | ||
| for field_name, model in model_in_classlist.items(): | ||
| field = getattr(self, field_name) | ||
| if not hasattr(field, "_class_handle"): | ||
| setattr(field, "_class_handle", getattr(RAT.models, model)) | ||
|
|
||
| self.parameters.insert(0, RAT.models.ProtectedParameter(name='Substrate Roughness', min=1, value=3, max=5, | ||
| fit=True, prior_type=RAT.models.Priors.Uniform, mu=0, | ||
| sigma=np.inf)) | ||
|
|
||
| @model_validator(mode='after') | ||
| def cross_check_model_values(self) -> 'Project': | ||
| """Certain model fields should contain values defined elsewhere in the project.""" | ||
| value_fields = ['value_1', 'value_2', 'value_3', 'value_4', 'value_5'] | ||
| self.check_allowed_values('backgrounds', value_fields, self.background_parameters.get_names()) | ||
| self.check_allowed_values('resolutions', value_fields, self.resolution_parameters.get_names()) | ||
| self.check_allowed_values('layers', ['thickness', 'SLD', 'roughness'], self.parameters.get_names()) | ||
|
|
||
| self.check_allowed_values('contrasts', ['data'], self.data.get_names()) | ||
| self.check_allowed_values('contrasts', ['background'], self.backgrounds.get_names()) | ||
| self.check_allowed_values('contrasts', ['nba'], self.bulk_in.get_names()) | ||
| self.check_allowed_values('contrasts', ['nbs'], self.bulk_out.get_names()) | ||
| self.check_allowed_values('contrasts', ['scalefactor'], self.scalefactors.get_names()) | ||
| self.check_allowed_values('contrasts', ['resolution'], self.resolutions.get_names()) | ||
| return self | ||
|
|
||
| def __repr__(self): | ||
| output = '' | ||
| for key, value in self.__dict__.items(): | ||
| if value: | ||
| output += f'{key.replace("_", " ").title() + ": " :-<100}\n\n' | ||
| try: | ||
| value.value # For enums | ||
| except AttributeError: | ||
| output += repr(value) + '\n\n' | ||
| else: | ||
| output += value.value + '\n\n' | ||
| return output | ||
|
|
||
| def check_allowed_values(self, attribute: str, field_list: list[str], allowed_values: list[str]) -> None: | ||
| """Check the values of the given fields in the given model are in the supplied list of allowed values. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| attribute : str | ||
| The attribute of Project being validated. | ||
| field_list : list [str] | ||
| The fields of the attribute to be checked for valid values. | ||
| allowed_values : list [str] | ||
| The list of allowed values for the fields given in field_list. | ||
|
|
||
| Raises | ||
| ------ | ||
| ValueError | ||
| Raised if any field in field_list has a value not specified in allowed_values. | ||
| """ | ||
| class_list = getattr(self, attribute) | ||
| for model in class_list: | ||
| for field in field_list: | ||
| value = getattr(model, field) | ||
| if value and value not in allowed_values: | ||
| setattr(model, field, '') | ||
| raise ValueError(f'The parameter "{value}" has not been defined in the list of allowed values.') |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,6 @@ | ||
| numpy >= 1.20 | ||
| pydantic >= 2.0.3 | ||
| pytest >= 7.4.0 | ||
| pytest-cov >= 4.1.0 | ||
| StrEnum >= 0.4.15 | ||
| tabulate >= 0.9.0 |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.