STIR  6.2.0
stir::GeneralisedPriorTests Class Reference

Test class for QuadraticPrior, RelativeDifferencePrior, CudaRelativeDifferencePrior and LogcoshPrior.

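Inheritance diagram for stir::GeneralisedPriorTests (graph not reproduced). Base classes, as listed in the inherited-member sections below: stir::ObjectiveFunctionTests< GeneralisedPrior< DiscretisedDensity< 3, float > >, DiscretisedDensity< 3, float > > and stir::RunTests.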

Public Types

typedef DiscretisedDensity< 3, float > target_type
 
- Public Types inherited from stir::ObjectiveFunctionTests< GeneralisedPrior< DiscretisedDensity< 3, float > >, DiscretisedDensity< 3, float > >
typedef GeneralisedPrior< DiscretisedDensity< 3, float > > objective_function_type
 
typedef DiscretisedDensity< 3, float > target_type
 

Public Member Functions

 GeneralisedPriorTests (char const *density_filename=nullptr)
 Constructor that can take some input data to run the test with.
 
void construct_input_data (shared_ptr< target_type > &density_sptr, shared_ptr< target_type > &kappa_sptr)
 
void configure_prior_tests (bool gradient, bool Hessian_convexity, bool Hessian_numerical)
 Set which tests are run (see do_test_gradient, do_test_Hessian_convexity and do_test_Hessian_against_numerical).
 
- Public Member Functions inherited from stir::ObjectiveFunctionTests< GeneralisedPrior< DiscretisedDensity< 3, float > >, DiscretisedDensity< 3, float > >
virtual Succeeded test_gradient (const std::string &test_name, GeneralisedPrior< DiscretisedDensity< 3, float > > &objective_function, DiscretisedDensity< 3, float > &target, const float eps, const bool full_gradient=true)
 Test the gradient of the objective function by comparing to the numerical gradient via perturbation.
 
virtual Succeeded test_Hessian (const std::string &test_name, GeneralisedPrior< DiscretisedDensity< 3, float > > &objective_function, const DiscretisedDensity< 3, float > &target, const float eps)
 Test accumulate_Hessian_times_input() of the objective function by comparing it to a numerical result obtained via perturbation.
 
virtual Succeeded test_Hessian_concavity (const std::string &test_name, GeneralisedPrior< DiscretisedDensity< 3, float > > &objective_function, const DiscretisedDensity< 3, float > &target, const float mult_factor=1.F)
 Test the Hessian of the objective function by checking the condition mult_factor * x^T H x > 0.
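The perturbation idea behind test_Hessian() can be sketched outside STIR: for a small eps, (grad f(x + eps * v) - grad f(x)) / eps approximates H(x) * v. The standalone example below is an illustration under assumptions, not STIR's implementation. The potential beta * sum_j log(cosh(x[j+1] - x[j])) and all function names are hypothetical stand-ins; logcosh_Hessian_times() plays the role of accumulate_Hessian_times_input().

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Gradient of the hypothetical potential
//   f(x) = beta * sum_j log(cosh(x[j+1] - x[j]))
std::vector<double> logcosh_gradient(const std::vector<double>& x, double beta)
{
  std::vector<double> g(x.size(), 0.0);
  for (std::size_t j = 0; j + 1 < x.size(); ++j)
    {
      const double t = std::tanh(x[j + 1] - x[j]); // d/dd log(cosh(d)) = tanh(d)
      g[j] -= beta * t;
      g[j + 1] += beta * t;
    }
  return g;
}

// Analytic H(x) * v (stand-in for accumulate_Hessian_times_input()).
std::vector<double> logcosh_Hessian_times(const std::vector<double>& x,
                                          const std::vector<double>& v, double beta)
{
  assert(x.size() == v.size());
  std::vector<double> hv(x.size(), 0.0);
  for (std::size_t j = 0; j + 1 < x.size(); ++j)
    {
      const double t = std::tanh(x[j + 1] - x[j]);
      const double c = beta * (1.0 - t * t); // second derivative: sech^2(d) = 1 - tanh^2(d)
      hv[j] += c * (v[j] - v[j + 1]);
      hv[j + 1] += c * (v[j + 1] - v[j]);
    }
  return hv;
}

// Numerical H(x) * v from a forward difference of the gradient along v,
// compared component-wise with the analytic result.
bool Hessian_times_input_matches_numerical(const std::vector<double>& x,
                                           const std::vector<double>& v,
                                           double beta, double eps, double tolerance)
{
  assert(x.size() == v.size() && eps > 0.0);
  std::vector<double> x_plus = x;
  for (std::size_t j = 0; j < x.size(); ++j)
    x_plus[j] += eps * v[j];
  const std::vector<double> g0 = logcosh_gradient(x, beta);
  const std::vector<double> g1 = logcosh_gradient(x_plus, beta);
  const std::vector<double> analytic = logcosh_Hessian_times(x, v, beta);
  for (std::size_t j = 0; j < x.size(); ++j)
    {
      const double numerical = (g1[j] - g0[j]) / eps;
      if (std::abs(numerical - analytic[j]) > tolerance * std::max(1.0, std::abs(analytic[j])))
        return false;
    }
  return true;
}
```

Because the gradient is not linear in x for this potential, the check is not trivially exact; the forward difference carries an O(eps) error, so eps must be small compared with the scale of the image differences.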
 
- Public Member Functions inherited from stir::RunTests
 RunTests (const double tolerance=1E-4)
 Default constructor.
 
virtual ~RunTests ()
 Destructor, outputs a diagnostic message.
 
virtual void run_tests ()=0
 Function (to be overridden) that does the actual tests.
 
bool is_everything_ok () const
 Returns whether all checks have passed up to now.
 
int main_return_value () const
 Handy return value for a main() function.
 
void set_tolerance (const double tolerance)
 Set value used in floating point comparisons (see check_* functions)
 
double get_tolerance () const
 Get value used in floating point comparisons (see check_* functions)
 
bool check (const bool, const std::string &str="")
 Checks that the argument is true; str can be used to describe what is being tested.
 
template<class T1 , class T2 >
bool check_if_less (T1 a, T2 b, const std::string &str="")
 check if a<b
 
bool check_if_equal (const std::string &a, const std::string &b, const std::string &str="")
 
bool check_if_equal (const double a, const double b, const std::string &str="")
 
bool check_if_equal (const short a, const short b, const std::string &str="")
 
bool check_if_equal (const unsigned short a, const unsigned short b, const std::string &str="")
 
bool check_if_equal (const int a, const int b, const std::string &str="")
 
bool check_if_equal (const unsigned int a, const unsigned int b, const std::string &str="")
 
bool check_if_equal (const long a, const long b, const std::string &str="")
 
bool check_if_equal (const unsigned long a, const unsigned long b, const std::string &str="")
 
bool check_if_equal (const Bin &a, const Bin &b, const std::string &str="")
 
template<class T >
bool check_if_equal (const DetectionPosition< T > &a, const DetectionPosition< T > &b, const std::string &str="")
 
template<class T >
bool check_if_equal (const std::complex< T > a, const std::complex< T > b, const std::string &str="")
 check equality by calling check_if_equal on real and imaginary parts
 
template<class T >
bool check_if_equal (const VectorWithOffset< T > &t1, const VectorWithOffset< T > &t2, const std::string &str="")
 check equality by comparing ranges and calling check_if_equal on all elements
 
template<class T >
bool check_if_equal (const std::vector< T > &t1, const std::vector< T > &t2, const std::string &str="")
 check equality by comparing size and calling check_if_equal on all elements
 
bool check_if_equal (const ProjDataInMemory &t1, const ProjDataInMemory &t2, const std::string &str="")
 
template<int n>
bool check_if_equal (const IndexRange< n > &t1, const IndexRange< n > &t2, const std::string &str="")
 
template<int num_dimensions, class coordT >
bool check_if_equal (const BasicCoordinate< num_dimensions, coordT > &a, const BasicCoordinate< num_dimensions, coordT > &b, const std::string &str="")
 check equality by comparing norm(a-b) with tolerance
 
bool check_if_zero (const double a, const std::string &str="")
 
bool check_if_zero (const short a, const std::string &str="")
 
bool check_if_zero (const unsigned short a, const std::string &str="")
 
bool check_if_zero (const int a, const std::string &str="")
 
bool check_if_zero (const unsigned int a, const std::string &str="")
 
bool check_if_zero (const long a, const std::string &str="")
 
bool check_if_zero (const unsigned long a, const std::string &str="")
 
template<class T >
bool check_if_zero (const VectorWithOffset< T > &t, const std::string &str="")
 use check_if_zero on all elements
 
template<int num_dimensions, class coordT >
bool check_if_zero (const BasicCoordinate< num_dimensions, coordT > &a, const std::string &str="")
 compare norm with tolerance
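Judging from the members listed above, a RunTests check reports a failure, records it in everything_ok, and lets the remaining checks run. The class below is a hypothetical, heavily simplified imitation of that pattern (MiniRunTests is not part of STIR). The relative-tolerance formula in its check_if_equal() is an assumption; STIR's own comparison may differ.

```cpp
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <string>

// Hypothetical, simplified imitation of the RunTests pattern (not the STIR class):
// a failed check is reported and remembered, but testing continues.
class MiniRunTests
{
public:
  explicit MiniRunTests(double tolerance = 1E-4)
      : tolerance(tolerance)
  {
    assert(tolerance >= 0.0);
  }

  // Report a failure (with an optional description) without aborting.
  bool check(bool result, const std::string& str = "")
  {
    if (!result)
      {
        std::cerr << "Error: check failed. " << str << '\n';
        everything_ok = false;
      }
    return result;
  }

  // Assumed relative comparison; STIR's actual formula may differ.
  bool check_if_equal(double a, double b, const std::string& str = "")
  {
    return check(std::abs(a - b) <= tolerance * (std::abs(a) + std::abs(b)) / 2, str);
  }

  bool check_if_zero(double a, const std::string& str = "")
  {
    return check(std::abs(a) <= tolerance, str);
  }

  bool is_everything_ok() const { return everything_ok; }

  // Intended as the return value of main().
  int main_return_value() const { return everything_ok ? EXIT_SUCCESS : EXIT_FAILURE; }

private:
  double tolerance;
  bool everything_ok = true;
};
```

A test program built on this pattern would end with return tests.main_return_value(); so that a single failed check anywhere makes the whole program report failure.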
 

Protected Member Functions

virtual void run_tests_for_objective_function (const std::string &test_name, GeneralisedPrior< target_type > &objective_function, const shared_ptr< target_type > &target_sptr)
 Run the tests for the given objective function.
 
virtual void test_Hessian_convexity (const std::string &test_name, GeneralisedPrior< GeneralisedPriorTests::target_type > &objective_function, const shared_ptr< GeneralisedPriorTests::target_type > &target_sptr)
 Test various configurations of the Hessian of the prior via accumulate_Hessian_times_input() for convexity.
 
virtual void test_Hessian_against_numerical (const std::string &test_name, GeneralisedPrior< GeneralisedPriorTests::target_type > &objective_function, const shared_ptr< GeneralisedPriorTests::target_type > &target_sptr)
 Tests the compute_Hessian() method implemented in convex priors.
 
virtual bool test_Hessian_convexity_configuration (const std::string &test_name, GeneralisedPrior< GeneralisedPriorTests::target_type > &objective_function, const shared_ptr< GeneralisedPriorTests::target_type > &target_sptr, float beta, float input_multiplication, float input_addition, float current_image_multiplication, float current_image_addition)
 Hessian test for a particular configuration of the Hessian convexity condition.
 
- Protected Member Functions inherited from stir::ObjectiveFunctionTests< GeneralisedPrior< DiscretisedDensity< 3, float > >, DiscretisedDensity< 3, float > >
virtual shared_ptr< const DiscretisedDensity< 3, float > > construct_increment (const DiscretisedDensity< 3, float > &target, const float eps) const
 Construct small increment for target.
 
- Protected Member Functions inherited from stir::RunTests
template<class T >
bool check_if_equal_generic (const T &a, const T &b, const std::string &str)
 function that is called by some check_if_equal implementations. It just uses operator!=
 
template<class T >
bool check_if_zero_generic (T a, const std::string &str)
 function that is called by some check_if_zero implementations. It just uses operator!=
 

Protected Attributes

char const * density_filename
 
shared_ptr< GeneralisedPrior< target_type > > objective_function_sptr
 
bool do_test_gradient = false
 Variables that control which tests are run; see configure_prior_tests().
 
bool do_test_Hessian_convexity = false
 
bool do_test_Hessian_against_numerical = false
 
- Protected Attributes inherited from stir::RunTests
double tolerance
 tolerance for comparisons with real values
 
bool everything_ok
 variable storing current status
 

Detailed Description

Test class for QuadraticPrior, RelativeDifferencePrior, CudaRelativeDifferencePrior and LogcoshPrior.

This test compares the result of GeneralisedPrior::compute_gradient() with a numerical gradient computed by using the GeneralisedPrior::compute_value() function. Additionally, the Hessian's convexity is tested, via GeneralisedPrior::accumulate_Hessian_times_input(), by evaluating the x^T Hx > 0 constraint.
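The gradient comparison can be illustrated without STIR. The following is a minimal, self-contained sketch, not STIR code: it assumes a toy 1D quadratic penalty (toy_value() and toy_gradient(), both hypothetical) in place of a GeneralisedPrior, and compares its analytic gradient with a forward-difference gradient built only from value evaluations, which is the role compute_value() plays in the real test.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical toy penalty standing in for a GeneralisedPrior:
//   value(x) = beta/2 * sum_j (x[j+1] - x[j])^2
double toy_value(const std::vector<double>& x, double beta)
{
  double sum = 0.0;
  for (std::size_t j = 0; j + 1 < x.size(); ++j)
    {
      const double d = x[j + 1] - x[j];
      sum += d * d;
    }
  return 0.5 * beta * sum;
}

// Analytic gradient of toy_value() (the role of compute_gradient()).
std::vector<double> toy_gradient(const std::vector<double>& x, double beta)
{
  std::vector<double> g(x.size(), 0.0);
  for (std::size_t j = 0; j < x.size(); ++j)
    {
      if (j > 0)
        g[j] += beta * (x[j] - x[j - 1]);
      if (j + 1 < x.size())
        g[j] += beta * (x[j] - x[j + 1]);
    }
  return g;
}

// Forward-difference gradient from value evaluations only (the role of
// compute_value()), compared component-wise with the analytic gradient.
bool gradient_matches_numerical(const std::vector<double>& x, double beta,
                                double eps, double tolerance)
{
  assert(eps > 0.0);
  const std::vector<double> analytic = toy_gradient(x, beta);
  const double value_at_x = toy_value(x, beta);
  for (std::size_t j = 0; j < x.size(); ++j)
    {
      std::vector<double> perturbed = x;
      perturbed[j] += eps;
      const double numerical = (toy_value(perturbed, beta) - value_at_x) / eps;
      const double scale = std::max(1.0, std::abs(analytic[j]));
      if (std::abs(numerical - analytic[j]) > tolerance * scale)
        return false;
    }
  return true;
}
```

With a small eps the two gradients agree; with a large eps the O(eps) truncation error of the forward difference makes the comparison fail, which is why the perturbation has to be small relative to the image values.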

Constructor & Destructor Documentation

◆ GeneralisedPriorTests()

stir::GeneralisedPriorTests::GeneralisedPriorTests(char const *density_filename = nullptr)
explicit

Constructor that can take some input data to run the test with.

This makes it possible to run the test with your own data. However, beware that it is very easy to set up a very long computation.

Todo:
It would be better to parse an objective function, as that would allow all parameters to be set from the command line.

Member Function Documentation

◆ run_tests_for_objective_function()

void stir::GeneralisedPriorTests::run_tests_for_objective_function(const std::string &test_name,
    GeneralisedPrior< target_type > &objective_function,
    const shared_ptr< target_type > &target_sptr)
protected, virtual

◆ test_Hessian_convexity()

void stir::GeneralisedPriorTests::test_Hessian_convexity(const std::string &test_name,
    GeneralisedPrior< GeneralisedPriorTests::target_type > &objective_function,
    const shared_ptr< GeneralisedPriorTests::target_type > &target_sptr)
protected, virtual

Test various configurations of the Hessian of the prior via accumulate_Hessian_times_input() for convexity.

Tests the convexity condition:

\[ x^T \cdot H_{\lambda} x \geq 0 \]

for all non-negative x and non-zero λ (Relative Difference Prior conditions). This function constructs an array of configurations to test this condition and calls test_Hessian_convexity_configuration() for each of them.

The function first constructs these configurations and, once they have been tested, resets beta (the penalisation factor) to its original value.
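The configuration sweep can be sketched outside STIR. In this hypothetical example each configuration forms the input x and the current image from a base image by multiplication followed by addition, assuming that this is what the parameters of test_Hessian_convexity_configuration() mean, and then checks x^T H x >= 0. The potential beta * sum_j log(cosh(.)) and the parameter grid are illustrative assumptions, not the values STIR uses.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical configuration; field names mirror the parameters of
// test_Hessian_convexity_configuration(), their meaning here is an assumption.
struct ConvexityConfig
{
  float beta;
  float input_multiplication;
  float input_addition;
  float current_image_multiplication;
  float current_image_addition;
};

// H(current) * v for the toy potential beta * sum_j log(cosh(current[j+1] - current[j])).
std::vector<double> Hessian_times_input(const std::vector<double>& current,
                                        const std::vector<double>& v, double beta)
{
  assert(current.size() == v.size());
  std::vector<double> hv(v.size(), 0.0);
  for (std::size_t j = 0; j + 1 < v.size(); ++j)
    {
      const double t = std::tanh(current[j + 1] - current[j]);
      const double c = beta * (1.0 - t * t); // sech^2 >= 0, so c >= 0 whenever beta >= 0
      hv[j] += c * (v[j] - v[j + 1]);
      hv[j + 1] += c * (v[j + 1] - v[j]);
    }
  return hv;
}

// Forms x and the current image from `target`, then checks x^T H x >= 0,
// allowing only for floating-point rounding.
bool convexity_holds(const std::vector<double>& target, const ConvexityConfig& c)
{
  std::vector<double> x(target.size()), current(target.size());
  for (std::size_t j = 0; j < target.size(); ++j)
    {
      x[j] = target[j] * c.input_multiplication + c.input_addition;
      current[j] = target[j] * c.current_image_multiplication + c.current_image_addition;
    }
  const std::vector<double> hx = Hessian_times_input(current, x, c.beta);
  double xHx = 0.0, magnitude = 0.0;
  for (std::size_t j = 0; j < x.size(); ++j)
    {
      xHx += x[j] * hx[j];
      magnitude += std::abs(x[j] * hx[j]);
    }
  return xHx >= -1e-12 * magnitude;
}

// Sweep over a grid of (hypothetical) parameter values; all must pass.
bool all_configurations_convex(const std::vector<double>& target)
{
  const float betas[] = {0.01F, 1.F, 100.F};
  const float multiplications[] = {0.01F, 1.F, 100.F};
  const float additions[] = {0.F, 1.F};
  bool ok = true;
  for (float beta : betas)
    for (float in_mult : multiplications)
      for (float in_add : additions)
        for (float cur_mult : multiplications)
          for (float cur_add : additions)
            ok = convexity_holds(target, {beta, in_mult, in_add, cur_mult, cur_add}) && ok;
  return ok;
}
```

A negative beta makes x^T H x negative for any non-constant input, so the check fails; this is a quick sanity test that the sweep can actually detect a violation.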

References stir::GeneralisedPrior< DataT >::is_convex(), stir::GeneralisedPrior< DataT >::set_penalisation_factor(), and test_Hessian_convexity_configuration().

Referenced by run_tests_for_objective_function().

◆ test_Hessian_against_numerical()

void stir::GeneralisedPriorTests::test_Hessian_against_numerical(const std::string &test_name,
    GeneralisedPrior< GeneralisedPriorTests::target_type > &objective_function,
    const shared_ptr< GeneralisedPriorTests::target_type > &target_sptr)
protected, virtual

Tests the compute_Hessian() method implemented in convex priors.

Perturbs the input and uses the resulting change in compute_gradient() to check whether compute_Hessian() (for a single densel) agrees within tolerance.
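For a single densel j, perturbing only x[j] by a small eps gives (grad f(x + eps * e_j) - grad f(x)) / eps, an approximation to column j of the Hessian, which equals row j because the Hessian is symmetric. The sketch below compares that with an analytic row. It is a standalone illustration, not STIR's compute_Hessian(); it assumes a Charbonnier-type potential beta * sum_j (sqrt(1 + d^2) - 1), chosen only because its derivatives are short, and all names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical potential f(x) = beta * sum_j phi(x[j+1] - x[j]),
// with phi(d) = sqrt(1 + d^2) - 1 (Charbonnier type, convex).
double phi_first(double d) { return d / std::sqrt(1.0 + d * d); }
double phi_second(double d) { return 1.0 / std::pow(1.0 + d * d, 1.5); }

std::vector<double> charbonnier_gradient(const std::vector<double>& x, double beta)
{
  std::vector<double> g(x.size(), 0.0);
  for (std::size_t j = 0; j + 1 < x.size(); ++j)
    {
      const double t = beta * phi_first(x[j + 1] - x[j]);
      g[j] -= t;
      g[j + 1] += t;
    }
  return g;
}

// Analytic Hessian row for densel j, stored densely: row[k] = d2f / (dx[j] dx[k]).
// Only j and its neighbours are non-zero (stand-in for compute_Hessian() on one densel).
std::vector<double> Hessian_row(const std::vector<double>& x, std::size_t j, double beta)
{
  assert(j < x.size());
  std::vector<double> row(x.size(), 0.0);
  if (j > 0)
    {
      const double c = beta * phi_second(x[j] - x[j - 1]);
      row[j] += c;
      row[j - 1] -= c;
    }
  if (j + 1 < x.size())
    {
      const double c = beta * phi_second(x[j + 1] - x[j]);
      row[j] += c;
      row[j + 1] -= c;
    }
  return row;
}

// Perturb only densel j; the change in the gradient approximates column j of H,
// which equals row j because H is symmetric.
bool Hessian_row_matches_numerical(const std::vector<double>& x, std::size_t j,
                                   double beta, double eps, double tolerance)
{
  assert(j < x.size() && eps > 0.0);
  std::vector<double> x_plus = x;
  x_plus[j] += eps;
  const std::vector<double> g0 = charbonnier_gradient(x, beta);
  const std::vector<double> g1 = charbonnier_gradient(x_plus, beta);
  const std::vector<double> row = Hessian_row(x, j, beta);
  for (std::size_t k = 0; k < x.size(); ++k)
    {
      const double numerical = (g1[k] - g0[k]) / eps;
      if (std::abs(numerical - row[k]) > tolerance * std::max(1.0, std::abs(row[k])))
        return false;
    }
  return true;
}
```

Comparing the whole row, not only the diagonal entry, also checks the off-diagonal coupling to the neighbours of densel j.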


References stir::RunTests::check_if_less(), stir::GeneralisedPrior< DataT >::compute_gradient(), stir::GeneralisedPrior< DataT >::compute_Hessian(), stir::RunTests::everything_ok, stir::info(), and stir::GeneralisedPrior< DataT >::is_convex().

Referenced by run_tests_for_objective_function().

◆ test_Hessian_convexity_configuration()

bool stir::GeneralisedPriorTests::test_Hessian_convexity_configuration(const std::string &test_name,
    GeneralisedPrior< GeneralisedPriorTests::target_type > &objective_function,
    const shared_ptr< GeneralisedPriorTests::target_type > &target_sptr,
    float beta,
    float input_multiplication,
    float input_addition,
    float current_image_multiplication,
    float current_image_addition)
protected, virtual

The documentation for this class was generated from the following file: