Skip to content
55 changes: 40 additions & 15 deletions RAT/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""The models module. Contains the pydantic models used by RAT to store project parameters."""

import numpy as np
from pydantic import BaseModel, Field, FieldValidationInfo, field_validator, model_validator
from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator

try:
from enum import StrEnum
Expand Down Expand Up @@ -54,7 +54,7 @@ class Types(StrEnum):

class Background(BaseModel, validate_assignment=True, extra='forbid'):
"""Defines the Backgrounds in RAT."""
name: str = Field(default_factory=lambda: 'New Background ' + next(background_number))
name: str = Field(default_factory=lambda: 'New Background ' + next(background_number), min_length=1)
type: Types = Types.Constant
value_1: str = ''
value_2: str = ''
Expand All @@ -65,28 +65,42 @@ class Background(BaseModel, validate_assignment=True, extra='forbid'):

class Contrast(BaseModel, validate_assignment=True, extra='forbid'):
"""Groups together all of the components of the model."""
name: str = Field(default_factory=lambda: 'New Contrast ' + next(contrast_number))
name: str = Field(default_factory=lambda: 'New Contrast ' + next(contrast_number), min_length=1)
data: str = ''
background: str = ''
nba: str = ''
nbs: str = ''
scalefactor: str = ''
resolution: str = ''
resample: bool = False
model: list[str] = [] # But how many strings? How to deal with this?
model: list[str] = []


class ContrastWithRatio(BaseModel, validate_assignment=True, extra='forbid'):
"""Groups together all of the components of the model including domain terms."""
name: str = Field(default_factory=lambda: 'New Contrast ' + next(contrast_number), min_length=1)
data: str = ''
background: str = ''
nba: str = ''
nbs: str = ''
scalefactor: str = ''
resolution: str = ''
resample: bool = False
domain_ratio: str = ''
model: list[str] = []


class CustomFile(BaseModel, validate_assignment=True, extra='forbid'):
"""Defines the files containing functions to run when using custom models."""
name: str = Field(default_factory=lambda: 'New Custom File ' + next(custom_file_number))
name: str = Field(default_factory=lambda: 'New Custom File ' + next(custom_file_number), min_length=1)
filename: str = ''
language: Languages = Languages.Python
path: str = 'pwd' # Should later expand to find current file path


class Data(BaseModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True):
"""Defines the dataset required for each contrast."""
name: str = Field(default_factory=lambda: 'New Data ' + next(data_number))
name: str = Field(default_factory=lambda: 'New Data ' + next(data_number), min_length=1)
data: np.ndarray[float] = np.empty([0, 3])
data_range: list[float] = []
simulation_range: list[float] = [0.005, 0.7]
Expand All @@ -106,7 +120,7 @@ def check_data_dimension(cls, data: np.ndarray[float]) -> np.ndarray[float]:

@field_validator('data_range', 'simulation_range')
@classmethod
def check_list_elements(cls, limits: list[float], info: FieldValidationInfo) -> list[float]:
def check_list_elements(cls, limits: list[float], info: ValidationInfo) -> list[float]:
"""The data range and simulation range must contain exactly two parameters."""
if len(limits) != 2:
raise ValueError(f'{info.field_name} must contain exactly two values')
Expand All @@ -117,23 +131,34 @@ def check_list_elements(cls, limits: list[float], info: FieldValidationInfo) ->

class DomainContrast(BaseModel, validate_assignment=True, extra='forbid'):
"""Groups together the layers required for each domain."""
name: str = Field(default_factory=lambda: 'New Domain Contrast ' + next(domain_contrast_number))
name: str = Field(default_factory=lambda: 'New Domain Contrast ' + next(domain_contrast_number), min_length=1)
model: list[str] = []


class Layer(BaseModel, validate_assignment=True, extra='forbid'):
class Layer(BaseModel, validate_assignment=True, extra='forbid', populate_by_name=True):
"""Combines parameters into defined layers."""
name: str = Field(default_factory=lambda: 'New Layer ' + next(layer_number))
name: str = Field(default_factory=lambda: 'New Layer ' + next(layer_number), min_length=1)
thickness: str = ''
SLD: str = Field('', validation_alias='SLD_real')
roughness: str = ''
hydration: str = ''
hydrate_with: Hydration = Hydration.BulkOut


class AbsorptionLayer(BaseModel, validate_assignment=True, extra='forbid', populate_by_name=True):
"""Combines parameters into defined layers including absorption terms."""
name: str = Field(default_factory=lambda: 'New Layer ' + next(layer_number), min_length=1)
thickness: str = ''
SLD: str = ''
SLD_real: str = Field('', validation_alias='SLD')
SLD_imaginary: str = ''
roughness: str = ''
hydration: str = ''
hydrate_with: Hydration = Hydration.BulkOut


class Parameter(BaseModel, validate_assignment=True, extra='forbid'):
"""Defines parameters needed to specify the model"""
name: str = Field(default_factory=lambda: 'New Parameter ' + next(parameter_number))
"""Defines parameters needed to specify the model."""
name: str = Field(default_factory=lambda: 'New Parameter ' + next(parameter_number), min_length=1)
min: float = 0.0
value: float = 0.0
max: float = 0.0
Expand All @@ -152,12 +177,12 @@ def check_value_in_range(self) -> 'Parameter':

class ProtectedParameter(Parameter, validate_assignment=True, extra='forbid'):
"""A Parameter with a fixed name."""
name: str = Field(frozen=True)
name: str = Field(frozen=True, min_length=1)


class Resolution(BaseModel, validate_assignment=True, extra='forbid'):
"""Defines Resolutions in RAT."""
name: str = Field(default_factory=lambda: 'New Resolution ' + next(resolution_number))
name: str = Field(default_factory=lambda: 'New Resolution ' + next(resolution_number), min_length=1)
type: Types = Types.Constant
value_1: str = ''
value_2: str = ''
Expand Down
Loading