STIR  6.2.0

A base class for 'generalised' priors, i.e. priors for which at least a 'gradient' is defined.

#include "stir/recon_buildblock/GeneralisedPrior.h"


Public Member Functions

virtual double compute_value (const DataT &current_estimate)=0
 compute the value of the function
 
virtual void compute_gradient (DataT &prior_gradient, const DataT &current_estimate)=0
 This should compute the gradient of the log of the prior function at the current_estimate.
 
virtual void compute_Hessian (DataT &prior_Hessian_for_single_densel, const BasicCoordinate< 3, int > &coords, const DataT &current_image_estimate) const
 This computes a single row of the Hessian.
 
virtual void add_multiplication_with_approximate_Hessian (DataT &output, const DataT &input) const
 This should compute the multiplication of the approximate Hessian with a vector and add it to output.
 
virtual void accumulate_Hessian_times_input (DataT &output, const DataT &current_estimate, const DataT &input) const
 This should compute the multiplication of the Hessian with a vector and add it to output.
 
float get_penalisation_factor () const
 
void set_penalisation_factor (float new_penalisation_factor)
 
virtual Succeeded set_up (shared_ptr< const DataT > const &target_sptr)
 Has to be called before using this object.
 
virtual bool is_convex () const =0
 Indicates if the prior is a smooth convex function.
 
- Public Member Functions inherited from stir::RegisteredObjectBase
virtual std::string get_registered_name () const =0
 Returns the name of the type of the object.
 
- Public Member Functions inherited from stir::ParsingObject
 ParsingObject (const ParsingObject &)
 
ParsingObject & operator= (const ParsingObject &)
 
void ask_parameters ()
 
virtual std::string parameter_info ()
 
bool parse (std::istream &f)
 
bool parse (const char *const filename)
 

Protected Member Functions

void set_defaults () override
 sets value for penalisation factor
 
void initialise_keymap () override
 sets key for penalisation factor
 
virtual void check (DataT const &current_estimate) const
 Check that the prior is ready to be used.
 
- Protected Member Functions inherited from stir::ParsingObject
virtual bool post_processing ()
 This will be called at the end of the parsing.
 
virtual void set_key_values ()
 This will be called before parsing or parameter_info is called.
 

Protected Attributes

float penalisation_factor
 
bool _already_set_up
 
- Protected Attributes inherited from stir::ParsingObject
KeyParser parser
 

Additional Inherited Members

- Static Public Member Functions inherited from stir::RegisteredObject< GeneralisedPrior< DataT > >
static GeneralisedPrior< DataT > * read_registered_object (std::istream *in, const std::string &registered_name)
 Construct a new object (of a type derived from Root, its actual type determined by the registered_name parameter) by parsing the istream.
 
static GeneralisedPrior< DataT > * ask_type_and_parameters ()
 Asks the user for the type, and then calls read_registered_object(0, type).
 
static void list_registered_names (std::ostream &stream)
 List all possible registered names to the stream.
 
- Protected Types inherited from stir::RegisteredObject< GeneralisedPrior< DataT > >
typedef GeneralisedPrior< DataT > *(* RootFactory) (std::istream *)
 The type of a root factory is a function, taking an istream* as argument, and returning a Root*.
 
typedef FactoryRegistry< std::string, RootFactory, interfile_less > RegistryType
 The type of the registry.
 
- Static Protected Member Functions inherited from stir::RegisteredObject< GeneralisedPrior< DataT > >
static RegistryType & registry ()
 Static function returning the registry.
 

Detailed Description

template<typename DataT>
class stir::GeneralisedPrior< DataT >

A base class for 'generalised' priors, i.e. priors for which at least a 'gradient' is defined.

This class exists to accommodate FilterRootPrior. Otherwise we could just live with Prior as a base class.
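A minimal, hypothetical C++ sketch of the interface described above may help. It mirrors the documented names (compute_value, compute_gradient, is_convex, set_up, check, penalisation_factor, _already_set_up) but is not STIR code: DataT is fixed to std::vector<double>, set_up() takes a plain reference and returns bool instead of a shared_ptr and stir::Succeeded, the Hessian methods and parsing base classes are omitted, and SquaredNormPrior is an invented example prior.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical, simplified stand-in for the documented interface (not STIR code).
template <typename DataT>
class GeneralisedPriorSketch
{
public:
  virtual ~GeneralisedPriorSketch() = default;

  virtual double compute_value(const DataT& current_estimate) = 0;
  virtual void compute_gradient(DataT& prior_gradient, const DataT& current_estimate) = 0;
  virtual bool is_convex() const = 0;

  float get_penalisation_factor() const { return penalisation_factor; }
  void set_penalisation_factor(float new_penalisation_factor) { penalisation_factor = new_penalisation_factor; }

  // Must be called before the prior is used; records that set-up happened.
  virtual bool set_up(const DataT& /*target*/)
  {
    _already_set_up = true;
    return true;
  }

protected:
  // Throws if set_up() has not been called yet.
  virtual void check(const DataT& /*current_estimate*/) const
  {
    if (!_already_set_up)
      throw std::runtime_error("prior used before set_up()");
  }

  float penalisation_factor = 0.F;
  bool _already_set_up = false;
};

// Invented example prior: V(x) = penalisation_factor/2 * sum_j x[j]^2.
class SquaredNormPrior : public GeneralisedPriorSketch<std::vector<double>>
{
public:
  double compute_value(const std::vector<double>& x) override
  {
    check(x);
    double sum = 0.0;
    for (double v : x)
      sum += v * v;
    return 0.5 * penalisation_factor * sum;
  }

  void compute_gradient(std::vector<double>& prior_gradient, const std::vector<double>& x) override
  {
    check(x);
    assert(prior_gradient.size() == x.size());
    // Overwrite (do not accumulate into) the output; already scaled by the penalisation factor.
    for (std::size_t j = 0; j < x.size(); ++j)
      prior_gradient[j] = penalisation_factor * x[j];
  }

  bool is_convex() const override { return true; }
};
```

In this sketch, calling compute_value() before set_up() throws; after set_up(x) with a penalisation factor of 2 and x = {1, -2, 3}, it returns 14.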

Member Function Documentation

◆ compute_value()

template<typename DataT>
virtual double stir::GeneralisedPrior< DataT >::compute_value ( const DataT &  current_estimate)
pure virtual

◆ compute_gradient()

template<typename DataT>
virtual void stir::GeneralisedPrior< DataT >::compute_gradient ( DataT &  prior_gradient,
const DataT &  current_estimate 
)
pure virtual

This should compute the gradient of the log of the prior function at the current_estimate.

The gradient is already multiplied by the penalisation_factor.

Warning
The derived class should overwrite any data in prior_gradient.

Implemented in stir::PLSPrior< elemT >, stir::RelativeDifferencePrior< elemT >, stir::LogcoshPrior< elemT >, stir::ParametricQuadraticPrior< TargetT >, stir::QuadraticPrior< elemT >, stir::QuadraticPrior< float >, stir::FilterRootPrior< DataT >, and stir::CudaRelativeDifferencePrior< elemT >.

Referenced by stir::GeneralisedPriorTests::test_Hessian_against_numerical().
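The contract for compute_gradient() (result already scaled by the penalisation factor, output overwritten rather than accumulated) can be illustrated with a hypothetical 1D nearest-neighbour quadratic penalty V(x) = beta/2 * sum_j (x[j+1] - x[j])^2. This is an assumed toy penalty, not STIR's QuadraticPrior, and free functions with invented names replace the class methods for brevity.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy penalty: V(x) = beta/2 * sum_j (x[j+1] - x[j])^2
double quadratic_value(const std::vector<double>& x, double beta)
{
  double sum = 0.0;
  for (std::size_t j = 0; j + 1 < x.size(); ++j)
  {
    const double d = x[j + 1] - x[j];
    sum += d * d;
  }
  return 0.5 * beta * sum;
}

// Gradient of V, already multiplied by the penalisation factor beta.
// The output is overwritten, matching the warning for compute_gradient().
void quadratic_gradient(std::vector<double>& prior_gradient, const std::vector<double>& x, double beta)
{
  assert(prior_gradient.size() == x.size());
  const std::size_t n = x.size();
  for (std::size_t j = 0; j < n; ++j)
  {
    double g = 0.0;
    if (j > 0)
      g += x[j] - x[j - 1];
    if (j + 1 < n)
      g += x[j] - x[j + 1];
    prior_gradient[j] = beta * g;  // assignment, not +=
  }
}
```

A central finite difference of quadratic_value() around each x[j] should reproduce prior_gradient[j], which is a quick way to check a gradient implementation. For x = {1, 3, 2} and beta = 2 the gradient is {-4, 6, -2}.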

◆ compute_Hessian()

template<typename TargetT>
void stir::GeneralisedPrior< TargetT >::compute_Hessian ( TargetT &  prior_Hessian_for_single_densel,
const BasicCoordinate< 3, int > &  coords,
const TargetT &  current_image_estimate 
) const
virtual

This computes a single row of the Hessian.

Default implementation just calls error(). This function needs to be overridden by the derived class.

The method computes a row (i.e. at a densel/voxel, indicated by coords) of the Hessian at current_image_estimate. Note that a row corresponds to an object of DataT. The method (as implemented in derived classes) should store the result in prior_Hessian_for_single_densel.

Referenced by stir::GeneralisedPriorTests::test_Hessian_against_numerical().
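Using an assumed 1D penalty V(x) = beta/2 * sum_j (x[j+1] - x[j])^2, the hypothetical sketch below fills a single Hessian row. A plain index stands in for the BasicCoordinate<3, int> used by STIR, and the row is stored in a container with the same shape as the estimate, reflecting the note that a row corresponds to an object of DataT. The function name is invented.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical 1D analogue of compute_Hessian() for
//   V(x) = beta/2 * sum_j (x[j+1] - x[j])^2
// Fills the Hessian row for densel `index` into `row` (same size as the estimate).
void quadratic_hessian_row(std::vector<double>& row, std::size_t index,
                           const std::vector<double>& current_estimate, double beta)
{
  const std::size_t n = current_estimate.size();
  assert(index < n && row.size() == n);
  // For a quadratic penalty the Hessian does not depend on the estimate values.
  for (double& r : row)
    r = 0.0;
  if (index > 0)  // pair (index-1, index)
  {
    row[index] += beta;
    row[index - 1] -= beta;
  }
  if (index + 1 < n)  // pair (index, index+1)
  {
    row[index] += beta;
    row[index + 1] -= beta;
  }
}
```

For an interior densel the row holds 2*beta on the diagonal and -beta for each neighbour; for an estimate of length 4 and beta = 2, the row for index 1 is {-2, 4, -2, 0}.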

◆ add_multiplication_with_approximate_Hessian()

template<typename TargetT>
void stir::GeneralisedPrior< TargetT >::add_multiplication_with_approximate_Hessian ( TargetT &  output,
const TargetT &  input 
) const
virtual

This should compute the multiplication of the Hessian with a vector and add it to output.

Default implementation just calls error(). This function needs to be overridden by the derived class. This method assumes that the Hessian of the prior is constant (independent of the current estimate), and hence that the prior is quadratic; accumulate_Hessian_times_input() should be used instead. This method remains for backwards compatibility.

Warning
The derived class should accumulate in output.

Reimplemented in stir::ParametricQuadraticPrior< TargetT >.
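For a quadratic penalty the Hessian does not depend on the estimate, which is why this method takes no current_estimate argument. A hypothetical sketch for an assumed 1D nearest-neighbour penalty V(x) = beta/2 * sum_j (x[j+1] - x[j])^2, showing that the product is added to output rather than written over it (the function name is invented):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical analogue of add_multiplication_with_approximate_Hessian():
// output += H * input, where H is the constant Hessian of
//   V(x) = beta/2 * sum_j (x[j+1] - x[j])^2
void add_quadratic_hessian_times(std::vector<double>& output, const std::vector<double>& input, double beta)
{
  assert(output.size() == input.size());
  for (std::size_t j = 0; j + 1 < input.size(); ++j)
  {
    // Each neighbour pair (j, j+1) contributes beta * [[1, -1], [-1, 1]].
    const double d = beta * (input[j] - input[j + 1]);
    output[j] += d;      // accumulate, never overwrite
    output[j + 1] -= d;
  }
}
```

Calling it twice with input {1, 0, 0}, beta = 1 and output initialised to {10, 10, 10} gives {11, 9, 10} and then {12, 8, 10}.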

◆ accumulate_Hessian_times_input()

template<typename TargetT>
void stir::GeneralisedPrior< TargetT >::accumulate_Hessian_times_input ( TargetT &  output,
const TargetT &  current_estimate,
const TargetT &  input 
) const
virtual

This should compute the multiplication of the Hessian with a vector and add it to output.

Default implementation just calls error(). This function needs to be overridden by the derived class.

Warning
The derived class should accumulate in output.

Referenced by stir::GeneralisedPriorTests::test_Hessian_convexity_configuration().
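When the Hessian depends on the estimate, the current_estimate argument becomes essential. The hypothetical sketch below uses an assumed 1D log-cosh penalty V(x) = beta * sum_j log(cosh(x[j+1] - x[j])), whose curvature 1 - tanh(t)^2 shrinks for large neighbour differences t. It illustrates only the accumulate-in-output contract; it is not STIR's LogcoshPrior, and the function name is invented.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical analogue of accumulate_Hessian_times_input() for
//   V(x) = beta * sum_j log(cosh(x[j+1] - x[j]))
// Since d^2/dt^2 log(cosh(t)) = 1 - tanh(t)^2, each neighbour pair contributes
// beta * (1 - tanh(t)^2) * [[1, -1], [-1, 1]] with t = x[j+1] - x[j].
void accumulate_logcosh_hessian_times_input(std::vector<double>& output,
                                            const std::vector<double>& current_estimate,
                                            const std::vector<double>& input, double beta)
{
  assert(output.size() == input.size() && input.size() == current_estimate.size());
  for (std::size_t j = 0; j + 1 < input.size(); ++j)
  {
    const double th = std::tanh(current_estimate[j + 1] - current_estimate[j]);
    const double curvature = beta * (1.0 - th * th);
    const double d = curvature * (input[j] - input[j + 1]);
    output[j] += d;      // accumulate, never overwrite
    output[j + 1] -= d;
  }
}
```

With a flat estimate the result matches the quadratic case (curvature beta), while across a large jump in the estimate the corresponding pair contributes almost nothing.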

◆ set_penalisation_factor()

template<typename elemT >
void stir::GeneralisedPrior< elemT >::set_penalisation_factor ( float  new_penalisation_factor)
inline
Warning
Currently we allow the penalisation factor to be set after calling set_up().

Referenced by stir::GeneralisedPriorTests::test_Hessian_convexity(), and stir::GeneralisedPriorTests::test_Hessian_convexity_configuration().

◆ is_convex()

template<typename DataT>
virtual bool stir::GeneralisedPrior< DataT >::is_convex ( ) const
pure virtual
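For a twice-differentiable function on a convex domain, convexity is equivalent to the Hessian being positive semi-definite at every point, so one possible sanity check on an is_convex() claim is to evaluate v^T H(x) v through a Hessian-vector product with the argument order of accumulate_Hessian_times_input(). The sketch is hypothetical (names invented, not the procedure used by GeneralisedPriorTests), and passing it only shows that no violation was found on the sampled points.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Vec = std::vector<double>;
// Hessian-vector product callback (output, current_estimate, input): output += H(x) * input.
using HessTimes = std::function<void(Vec&, const Vec&, const Vec&)>;

// Necessary (not sufficient) numerical check of convexity for a smooth function:
// v^T H(x) v must be >= 0 for every sampled estimate x and direction v.
bool convex_on_samples(const HessTimes& hess_times, const std::vector<Vec>& estimates,
                       const std::vector<Vec>& directions, double tolerance = 1e-12)
{
  for (const Vec& x : estimates)
    for (const Vec& v : directions)
    {
      assert(x.size() == v.size());
      Vec hv(v.size(), 0.0);  // start from zero because hess_times accumulates
      hess_times(hv, x, v);
      double quad_form = 0.0;
      for (std::size_t j = 0; j < v.size(); ++j)
        quad_form += v[j] * hv[j];
      if (quad_form < -tolerance)
        return false;
    }
  return true;
}
```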

◆ set_defaults()

template<typename TargetT >
void stir::GeneralisedPrior< TargetT >::set_defaults ( )
override protected virtual

sets value for penalisation factor

Has to be called by the set_defaults() of the leaf class.

Reimplemented from stir::ParsingObject.

Reimplemented in stir::QuadraticPrior< float >.

Referenced by stir::FilterRootPrior< DataT >::compute_gradient().
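The requirement that the leaf class call this method means an override should chain to the base implementation before setting its own defaults. A hypothetical sketch with invented class names, in which the leaf constructor calls set_defaults(); the default values are chosen for illustration only:

```cpp
#include <cassert>

// Hypothetical stand-ins showing the chaining rule for set_defaults().
class PriorBaseSketch
{
public:
  virtual ~PriorBaseSketch() = default;
  float penalisation_factor = -1.F;  // "unset" marker until set_defaults() runs

protected:
  virtual void set_defaults() { penalisation_factor = 0.F; }
};

class LeafPriorSketch : public PriorBaseSketch
{
public:
  LeafPriorSketch() { set_defaults(); }
  int num_neighbours = -1;  // invented leaf-specific parameter

protected:
  void set_defaults() override
  {
    PriorBaseSketch::set_defaults();  // base part: penalisation factor
    num_neighbours = 1;               // leaf-specific defaults
  }
};
```

If the leaf override omitted the call to PriorBaseSketch::set_defaults(), penalisation_factor would keep its unset value of -1.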

◆ initialise_keymap()

template<typename TargetT >
void stir::GeneralisedPrior< TargetT >::initialise_keymap ( )
override protected virtual

sets key for penalisation factor

Has to be called by the initialise_keymap() of the leaf class.

Reimplemented from stir::ParsingObject.

Reimplemented in stir::QuadraticPrior< float >.

Referenced by stir::FilterRootPrior< DataT >::compute_gradient().
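initialise_keymap() follows the same chaining rule. In the hypothetical sketch below, a std::map from key string to member pointer stands in for STIR's KeyParser, and the class names and key strings are assumptions. Copying is disabled because the map stores pointers to the object's own members.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical mock of a key parser: maps a key string to a float member.
using KeymapSketch = std::map<std::string, float*>;

class PriorKeysBase
{
public:
  PriorKeysBase() = default;
  PriorKeysBase(const PriorKeysBase&) = delete;             // keymap holds member pointers
  PriorKeysBase& operator=(const PriorKeysBase&) = delete;
  virtual ~PriorKeysBase() = default;

  float penalisation_factor = 0.F;
  KeymapSketch keymap;

  // Assigns `value` to the member registered under `key`; false if the key is unknown.
  bool set(const std::string& key, float value)
  {
    auto it = keymap.find(key);
    if (it == keymap.end())
      return false;
    *it->second = value;
    return true;
  }

protected:
  virtual void initialise_keymap() { keymap["penalisation factor"] = &penalisation_factor; }
};

class LeafPriorKeys : public PriorKeysBase
{
public:
  LeafPriorKeys() { initialise_keymap(); }
  float smoothing = 1.F;  // invented leaf-specific parameter

protected:
  void initialise_keymap() override
  {
    PriorKeysBase::initialise_keymap();  // registers the penalisation factor key
    keymap["smoothing"] = &smoothing;    // leaf-specific key
  }
};
```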


The documentation for this class was generated from the following files: