Skip to content
Merged
16 changes: 16 additions & 0 deletions RAT/classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,28 @@ def __repr__(self):

def __setitem__(self, index: int, set_dict: dict[str, Any]) -> None:
"""Assign the values of an existing object's attributes using a dictionary."""
self._setitem(index, set_dict)

def _setitem(self, index: int, set_dict: dict[str, Any]) -> None:
"""Auxiliary routine of "__setitem__" used to enable wrapping."""
self._validate_name_field(set_dict)
for key, value in set_dict.items():
setattr(self.data[index], key, value)

def __delitem__(self, index: int) -> None:
"""Delete an object from the list by index."""
self._delitem(index)

def _delitem(self, index: int) -> None:
"""Auxiliary routine of "__delitem__" used to enable wrapping."""
del self.data[index]

def __iadd__(self, other: Sequence[object]) -> 'ClassList':
"""Define in-place addition using the "+=" operator."""
return self._iadd(other)

def _iadd(self, other: Sequence[object]) -> 'ClassList':
"""Auxiliary routine of "__iadd__" used to enable wrapping."""
if not hasattr(self, '_class_handle'):
self._class_handle = type(other[0])
self._check_classes(self + other)
Expand Down
85 changes: 80 additions & 5 deletions RAT/project.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""The project module. Defines and stores all the input data required for reflectivity calculations in RAT."""

import copy
import functools
import numpy as np
from pydantic import BaseModel, FieldValidationInfo, field_validator, model_validator
from typing import Any
from pydantic import BaseModel, FieldValidationInfo, field_validator, model_validator, ValidationError
from typing import Any, Callable

from RAT.classlist import ClassList
import RAT.models
from RAT.utils.custom_errors import formatted_pydantic_error

try:
from enum import StrEnum
Expand Down Expand Up @@ -46,6 +49,27 @@ class Geometries(StrEnum):
'contrasts': 'Contrast'
}

values_defined_in = {'backgrounds.value_1': 'background_parameters',
'backgrounds.value_2': 'background_parameters',
'backgrounds.value_3': 'background_parameters',
'backgrounds.value_4': 'background_parameters',
'backgrounds.value_5': 'background_parameters',
'resolutions.value_1': 'resolution_parameters',
'resolutions.value_2': 'resolution_parameters',
'resolutions.value_3': 'resolution_parameters',
'resolutions.value_4': 'resolution_parameters',
'resolutions.value_5': 'resolution_parameters',
'layers.thickness': 'parameters',
'layers.SLD': 'parameters',
'layers.roughness': 'parameters',
'contrasts.data': 'data',
'contrasts.background': 'backgrounds',
'contrasts.nba': 'bulk_in',
'contrasts.nbs': 'bulk_out',
'contrasts.scalefactor': 'scalefactors',
'contrasts.resolution': 'resolutions',
}


class Project(BaseModel, validate_assignment=True, extra='forbid', arbitrary_types_allowed=True):
"""Defines the input data for a reflectivity calculation in RAT.
Expand Down Expand Up @@ -109,7 +133,9 @@ def check_class(cls, value: ClassList, info: FieldValidationInfo) -> ClassList:
return value

def model_post_init(self, __context: Any) -> None:
"""Initialises the class in the ClassLists for empty data fields, and sets protected parameters."""
"""Initialises the class in the ClassLists for empty data fields, sets protected parameters, and wraps
ClassList routines to control revalidation.
"""
for field_name, model in model_in_classlist.items():
field = getattr(self, field_name)
if not hasattr(field, "_class_handle"):
Expand All @@ -119,6 +145,17 @@ def model_post_init(self, __context: Any) -> None:
fit=True, prior_type=RAT.models.Priors.Uniform, mu=0,
sigma=np.inf))

# Wrap ClassList routines - when any of these routines are called, the wrapper will force revalidation of the
# model, handle errors and reset previous values if necessary.
class_lists = ['parameters', 'bulk_in', 'bulk_out', 'qz_shifts', 'scalefactors', 'background_parameters',
'backgrounds', 'resolution_parameters', 'resolutions', 'custom_files', 'data', 'layers',
'contrasts']
methods_to_wrap = ['_setitem', '_delitem', '_iadd', 'append', 'insert', 'pop', 'remove', 'clear', 'extend']
for class_list in class_lists:
attribute = getattr(self, class_list)
for methodName in methods_to_wrap:
setattr(attribute, methodName, self._classlist_wrapper(attribute, getattr(attribute, methodName)))

@model_validator(mode='after')
def cross_check_model_values(self) -> 'Project':
"""Certain model fields should contain values defined elsewhere in the project."""
Expand Down Expand Up @@ -170,5 +207,43 @@ def check_allowed_values(self, attribute: str, field_list: list[str], allowed_va
for field in field_list:
value = getattr(model, field)
if value and value not in allowed_values:
setattr(model, field, '')
raise ValueError(f'The parameter "{value}" has not been defined in the list of allowed values.')
raise ValueError(f'The value "{value}" in the "{field}" field of "{attribute}" must be defined in '
f'"{values_defined_in[attribute + "." + field]}".')

def _classlist_wrapper(self, class_list: 'ClassList', func: Callable):
"""Defines the function used to wrap around ClassList routines to force revalidation.

Parameters
----------
class_list : ClassList
The ClassList defined in the "Project" model that is being modified.
func : Callable
The routine being wrapped.

Returns
-------
wrapped_func : Callable
The wrapped routine.
"""
@functools.wraps(func)
def wrapped_func(*args, **kwargs):
"""Run the given function and then revalidate the "Project" model. If any exception is raised, restore
the previous state of the given ClassList and report details of the exception.
"""
previous_state = copy.deepcopy(getattr(class_list, 'data'))
return_value = None
try:
return_value = func(*args, **kwargs)
Project.model_validate(self)
except ValidationError as e:
setattr(class_list, 'data', previous_state)
error_string = formatted_pydantic_error(e)
# Use ANSI escape sequences to print error text in red
print('\033[31m' + error_string + '\033[0m')
except (TypeError, ValueError):
setattr(class_list, 'data', previous_state)
raise
finally:
del previous_state
return return_value
return wrapped_func
26 changes: 26 additions & 0 deletions RAT/utils/custom_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Defines routines for custom error handling in RAT."""

from pydantic import ValidationError


def formatted_pydantic_error(error: ValidationError) -> str:
"""Write a custom string format for pydantic validation errors.

Parameters
----------
error : pydantic.ValidationError
A ValidationError produced by a pydantic model

Returns
-------
error_str : str
A string giving details of the ValidationError in a custom format.
"""
num_errors = error.error_count()
error_str = f'{num_errors} validation error{"s"[:num_errors!=1]} for {error.title}'
for this_error in error.errors():
error_str += '\n'
if this_error['loc']:
error_str += ' '.join(this_error['loc']) + '\n'
error_str += ' ' + this_error['msg']
return error_str
14 changes: 14 additions & 0 deletions tests/test_classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,20 @@ def test_setitem_same_name_field(two_name_class_list: 'ClassList', new_values: d
two_name_class_list[0] = new_values


def test_delitem(two_name_class_list: 'ClassList', one_name_class_list: 'ClassList') -> None:
"""We should be able to delete elements from a ClassList with the del operator."""
class_list = two_name_class_list
del class_list[1]
assert class_list == one_name_class_list


def test_delitem_not_present(two_name_class_list: 'ClassList') -> None:
"""If we use the del operator to delete an index out of range, we should raise an IndexError."""
class_list = two_name_class_list
with pytest.raises(IndexError, match=re.escape("list assignment index out of range")):
del class_list[2]


@pytest.mark.parametrize("added_list", [
(ClassList(InputAttributes(name='Eve'))),
([InputAttributes(name='Eve')]),
Expand Down
20 changes: 20 additions & 0 deletions tests/test_custom_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Test the utils.custom_errors module."""

from pydantic import create_model, ValidationError
import pytest

import RAT.utils.custom_errors


def test_formatted_pydantic_error() -> None:
"""When a pytest ValidationError is raised we should be able to take it and construct a formatted string."""

# Create a custom pydantic model for the test
TestModel = create_model('TestModel', int_field=(int, 1), str_field=(str, 'a'))

with pytest.raises(ValidationError) as exc_info:
TestModel(int_field='string', str_field=5)

error_str = RAT.utils.custom_errors.formatted_pydantic_error(exc_info.value)
assert error_str == ('2 validation errors for TestModel\nint_field\n Input should be a valid integer, unable to '
'parse string as an integer\nstr_field\n Input should be a valid string')
Loading