STIR 6.4.0
GibbsRelativeDifferencePenalty.h
Go to the documentation of this file.
1//
2//
3/*
4 Copyright (C) 2019 -2025, University College London
5 Copyright (C) 2025, University of Milano-Bicocca
6 This file is part of STIR.
7
8 SPDX-License-Identifier: Apache-2.0
9
10 See STIR/LICENSE.txt for details
11*/
24
25#ifndef __stir_recon_buildblock_GibbsRelativeDifferencePenalty_H__
26#define __stir_recon_buildblock_GibbsRelativeDifferencePenalty_H__
27
30#include <cmath>
31#include "stir/cuda_utilities.h"
32
33#ifdef STIR_WITH_CUDA
35#endif
36
37START_NAMESPACE_STIR
38
61template <typename elemT>
63{
64public:
65 float gamma;
66 float epsilon;
//! Method for computing the potential value.
/*!
  Implemented formula:
    0.5 * (val_center - val_neigh)^2 / (val_center + val_neigh + gamma * |val_center - val_neigh| + epsilon)

  The voxel indices (z, y, x) are not used by this potential.
  If the denominator is exactly zero (e.g. epsilon == 0 and val_center == val_neigh == 0),
  0 is returned (the limit of the potential there) instead of the NaN produced by 0 * (1/0).
*/
__host__ __device__ inline double value(const elemT val_center, const elemT val_neigh, int z, int y, int x) const
{
  const elemT diff = val_center - val_neigh;
  const elemT denom = val_center + val_neigh + gamma * fabs(diff) + epsilon;
  if (denom == 0)
    return 0.0;
  // diff / denom first keeps the intermediate bounded instead of forming diff^2 * (1/denom);
  // elemT constants avoid an implicit promotion to double on the device.
  return static_cast<elemT>(0.5) * diff * (diff / denom);
}
79
//! Method for computing the first derivative with respect to val_center.
/*!
  Implemented formula:
    0.5 * (val_center - val_neigh) * (val_center + 3*val_neigh + gamma * |val_center - val_neigh| + 2*epsilon)
      / (val_center + val_neigh + gamma * |val_center - val_neigh| + epsilon)^2

  The voxel indices (z, y, x) are not used by this potential.
  If the denominator is exactly zero (e.g. epsilon == 0 and val_center == val_neigh == 0), the
  derivative is undefined there; 0 is returned instead of the NaN produced by 0 * (1/0).
*/
__host__ __device__ inline double derivative_10(const elemT val_center, const elemT val_neigh, int z, int y, int x) const
{
  const elemT diff = val_center - val_neigh;
  const elemT denom = val_center + val_neigh + gamma * fabs(diff) + epsilon;
  if (denom == 0)
    return 0.0;
  // denom + 2*val_neigh + epsilon == val_center + 3*val_neigh + gamma*|diff| + 2*epsilon
  const elemT numer = denom + 2 * val_neigh + epsilon;
  // Dividing each factor by denom separately avoids squaring a possibly tiny denominator.
  return static_cast<elemT>(0.5) * (diff / denom) * (numer / denom);
}
93
//! Method for computing the second derivative with respect to val_center.
/*!
  Implemented formula:
    (2*val_neigh + epsilon)^2 / (val_center + val_neigh + gamma * |val_center - val_neigh| + epsilon)^3

  Note that the numerator depends on the neighbour value, not on val_center.
  The voxel indices (z, y, x) are not used by this potential.
  If the denominator is exactly zero (e.g. epsilon == 0 and val_center == val_neigh == 0), the
  derivative is undefined there; 0 is returned instead of the NaN produced by 0 * (1/0)^3.
*/
__host__ __device__ inline double derivative_20(const elemT val_center, const elemT val_neigh, int z, int y, int x) const
{
  const elemT denom = val_center + val_neigh + gamma * fabs(val_center - val_neigh) + epsilon;
  if (denom == 0)
    return 0.0;
  // Form the ratio first: cubing 1/denom can overflow (and squaring the numerator underflow)
  // in single precision when epsilon is very small.
  const elemT ratio = (2 * val_neigh + epsilon) / denom;
  return ratio * ratio / denom;
}
104
//! Method for computing the mixed second derivative.
/*!
  Implemented formula:
    -(2*val_center + epsilon) * (2*val_neigh + epsilon)
      / (val_center + val_neigh + gamma * |val_center - val_neigh| + epsilon)^3

  The voxel indices (z, y, x) are not used by this potential.
  If the denominator is exactly zero (e.g. epsilon == 0 and val_center == val_neigh == 0), the
  derivative is undefined there; 0 is returned instead of the NaN produced by 0 * (1/0)^3.
*/
__host__ __device__ inline double derivative_11(const elemT val_center, const elemT val_neigh, int z, int y, int x) const
{
  const elemT denom = val_center + val_neigh + gamma * fabs(val_center - val_neigh) + epsilon;
  if (denom == 0)
    return 0.0;
  // Divide each numerator factor by denom separately instead of cubing 1/denom (overflow-prone in float).
  const elemT center_ratio = (2 * val_center + epsilon) / denom;
  const elemT neigh_ratio = (2 * val_neigh + epsilon) / denom;
  return -(center_ratio * neigh_ratio) / denom;
}
115
//! Method to indicate whether the prior defined by this potential is convex.
/*! The Relative Difference potential is reported as convex unconditionally. */
static inline bool is_convex()
{
  const bool convex = true;
  return convex;
}
118
121 {
// Register the potential-specific keywords. When parsing, "gamma value" and "epsilon value"
// are read as floats and written directly into this->gamma and this->epsilon.
122 parser.add_key("gamma value", &this->gamma);
123 parser.add_key("epsilon value", &this->epsilon);
124 }
125
128 {
// Default potential parameters: gamma = 2, epsilon = 1e-7.
// A positive epsilon keeps the denominators of value()/derivative_*() non-zero when
// val_center == val_neigh == 0 (the denominator then reduces to epsilon).
129 this->gamma = 2;
130 this->epsilon = 1e-7;
131 }
132};
133
149template <typename elemT>
150class GibbsRelativeDifferencePenalty : public RegisteredParsingObject<GibbsRelativeDifferencePenalty<elemT>,
151 GeneralisedPrior<DiscretisedDensity<3, elemT>>,
152 GibbsPenalty<elemT, RelativeDifferencePotential<elemT>>>
153{
154private:
158 base_type;
159
160public:
161 GibbsRelativeDifferencePenalty() { this->set_defaults(); }
162 GibbsRelativeDifferencePenalty(const bool only_2D, float penalisation_factor, float gamma_v, float epsilon_v)
163 : base_type(only_2D, penalisation_factor)
164 {
165 this->potential.gamma = gamma_v;
166 this->potential.epsilon = epsilon_v;
167 }
168
169 static constexpr const char* registered_name = "Gibbs Relative Difference";
170 float get_gamma() const { return this->potential.gamma; }
171 float get_epsilon() const { return this->potential.epsilon; }
172 void set_gamma(float gamma_v) { this->potential.gamma = gamma_v; }
173 void set_epsilon(float epsilon_v) { this->potential.epsilon = epsilon_v; }
174};
175
176#ifdef STIR_WITH_CUDA
191template <typename elemT>
192class CudaGibbsRelativeDifferencePenalty
193 : public RegisteredParsingObject<CudaGibbsRelativeDifferencePenalty<elemT>,
194 GeneralisedPrior<DiscretisedDensity<3, elemT>>,
195 CudaGibbsPenalty<elemT, RelativeDifferencePotential<elemT>>>
196{
197private:
201 base_type;
202
203public:
204 CudaGibbsRelativeDifferencePenalty() { this->set_defaults(); }
205 CudaGibbsRelativeDifferencePenalty(const bool only_2D, float penalisation_factor, float gamma_v, float epsilon_v)
206 : base_type(only_2D, penalisation_factor)
207 {
208 this->potential.gamma = gamma_v;
209 this->potential.epsilon = epsilon_v;
210 }
211
212 static constexpr const char* registered_name = "Cuda Gibbs Relative Difference";
213 float get_gamma() const { return this->potential.gamma; }
214 float get_epsilon() const { return this->potential.epsilon; }
215 void set_gamma(float gamma_v) { this->potential.gamma = gamma_v; }
216 void set_epsilon(float epsilon_v) { this->potential.epsilon = epsilon_v; }
217};
218#endif
219
220END_NAMESPACE_STIR
221
222#endif
Declaration of the stir::CudaGibbsPenalty class.
Declaration of the stir::GibbsPenalty class.
Declaration of class stir::RegisteredParsingObject.
A base class with CUDA-accelerated implementation of the GibbsPenalty class.
Definition CudaGibbsPenalty.h:52
A base class for 'generalised' priors, i.e. priors for which at least a 'gradient' is defined.
Definition GeneralisedPrior.h:44
A base class for Gibbs type penalties in the GeneralisedPrior hierarchy.
Definition GibbsPenalty.h:99
potentialT potential
Gibbs Potential Function.
Definition GibbsPenalty.h:186
void set_defaults() override
sets value for penalisation factor
Definition GibbsPenalty.inl:158
bool only_2D
can be set during parsing to restrict the weights to the 2D case
Definition GibbsPenalty.h:173
A class to parse Interfile headers.
Definition KeyParser.h:162
void add_key(const std::string &keyword, float *variable_ptr)
add a keyword. When parsing, parse its value as a float and put it in *variable_ptr
Definition KeyParser.cxx:343
Parent class for all leaves in a RegisteredObject hierarchy that do parsing of parameter files.
Definition RegisteredParsingObject.h:78
Implementation of the Relative Difference penalty potential.
Definition GibbsRelativeDifferencePenalty.h:63
static bool is_convex()
Method to indicate whether the prior defined by this potential is convex.
Definition GibbsRelativeDifferencePenalty.h:117
__host__ __device__ double value(const elemT val_center, const elemT val_neigh, int z, int y, int x) const
Method for computing the potential value.
Definition GibbsRelativeDifferencePenalty.h:68
__host__ __device__ double derivative_20(const elemT val_center, const elemT val_neigh, int z, int y, int x) const
Method for computing the second derivative with respect to val_center.
Definition GibbsRelativeDifferencePenalty.h:95
__host__ __device__ double derivative_11(const elemT val_center, const elemT val_neigh, int z, int y, int x) const
Method for computing the mixed second derivative.
Definition GibbsRelativeDifferencePenalty.h:105
void set_defaults()
Set default values for potential-specific parameters.
Definition GibbsRelativeDifferencePenalty.h:127
__host__ __device__ double derivative_10(const elemT val_center, const elemT val_neigh, int z, int y, int x) const
Method for computing the first derivative with respect to val_center.
Definition GibbsRelativeDifferencePenalty.h:81
void initialise_keymap(KeyParser &parser)
Method for setting up parsing additional parameters.
Definition GibbsRelativeDifferencePenalty.h:120
some utilities for STIR and CUDA