2 Copyright (C) 2024, 2025 University College London
3 Copyright (C) 2020, 2022, University of Pennsylvania
4 Copyright (C) 2025, Commonwealth Scientific and Industrial Research Organisation
5 This file is part of STIR.
7 SPDX-License-Identifier: Apache-2.0 AND License-ref-PARAPET-license
9 See STIR/LICENSE.txt for details
14 \ingroup distributable
16 \brief Implementation of stir::LM_distributable_computation() and related functions
18 \author Nikos Efthimiou
19 \author Kris Thielemans
20 \author Ashley Gillman
22#include "stir/shared_ptr.h"
23#include "stir/recon_buildblock/distributable.h"
24#include "stir/DiscretisedDensity.h"
25#include "stir/CPUTimer.h"
26#include "stir/HighResWallClockTimer.h"
27#include "stir/is_null_ptr.h"
29#include "stir/format.h"
31#include "stir/recon_buildblock/ProjMatrixByBin.h"
32#include "stir/recon_buildblock/ProjMatrixElemsForOneBin.h"
35#include "stir/num_threads.h"
39template <typename CallBackT>
41LM_distributable_computation(const shared_ptr<ProjMatrixByBin> PM_sptr,
42 const shared_ptr<ProjDataInfo>& proj_data_info_sptr,
43 DiscretisedDensity<3, float>* output_image_ptr,
44 const DiscretisedDensity<3, float>* input_image_ptr,
45 const std::vector<BinAndCorr>& record_ptr,
47 const int num_subsets,
49 const bool accumulate,
50 double* double_out_ptr,
51 CallBackT&& call_back)
56 HighResWallClockTimer wall_clock_timer;
57 wall_clock_timer.start();
59 assert(!record_ptr.empty());
61 if (output_image_ptr != NULL && !accumulate)
62 output_image_ptr->fill(0.F);
64 std::vector<shared_ptr<DiscretisedDensity<3, float>>> local_output_image_sptrs;
65 std::vector<double> local_double_outs;
66 std::vector<double*> local_double_out_ptrs;
67 std::vector<int> local_counts, local_count2s;
68 std::vector<ProjMatrixElemsForOneBin> local_row;
70# pragma omp parallel shared(local_output_image_sptrs, local_row, local_double_outs, local_counts, local_count2s)
72 // start of threaded section if openmp
77 // allocate "local" vectors
80 const auto num_threads = omp_get_num_threads();
82 const int num_threads = 1;
84 info("Listmode gradient calculation: starting loop with " + std::to_string(num_threads) + " threads", 2);
85 local_output_image_sptrs.resize(get_max_num_threads(), shared_ptr<DiscretisedDensity<3, float>>());
86 local_double_out_ptrs.resize(get_max_num_threads(), 0);
89 local_double_outs.resize(get_max_num_threads(), 0.);
90 for (int t = 0; t < get_max_num_threads(); ++t)
91 local_double_out_ptrs[t] = &local_double_outs[t];
93 local_counts.resize(get_max_num_threads(), 0);
94 local_count2s.resize(get_max_num_threads(), 0);
95 local_row.resize(get_max_num_threads(), ProjMatrixElemsForOneBin());
98# pragma omp for schedule(dynamic)
100 // note: VC uses OpenMP 2.0, so need signed integer for loop
101 for (long int ievent = 0; ievent < static_cast<long>(record_ptr.size()); ++ievent)
103 auto& record = record_ptr.at(ievent);
104 if (record.my_bin.get_bin_value() == 0.0f) // shouldn't happen really, but a check probably doesn't hurt
108 const int thread_num = omp_get_thread_num();
110 const int thread_num = 0;
113 if (output_image_ptr != NULL)
115 if (is_null_ptr(local_output_image_sptrs[thread_num]))
116 local_output_image_sptrs[thread_num].reset(output_image_ptr->get_empty_copy());
119 const Bin& measured_bin = record.my_bin;
123 Bin basic_bin = measured_bin;
124 if (!PM_sptr->get_symmetries_ptr()->is_basic(measured_bin))
125 PM_sptr->get_symmetries_ptr()->find_basic_bin(basic_bin);
127 if (subset_num != static_cast<int>(basic_bin.view_num() % num_subsets))
133 PM_sptr->get_proj_matrix_elems_for_one_bin(local_row[thread_num], measured_bin);
134 call_back(*local_output_image_sptrs[thread_num],
135 local_row[thread_num],
136 has_add ? record.my_corr : 0.F,
139 local_double_out_ptrs[thread_num]);
142 // flatten data constructed by threads (or collapse unitary dim if no threading)
144 if (double_out_ptr != NULL)
146 for (int i = 0; i < static_cast<int>(local_double_outs.size()); ++i)
147 *double_out_ptr += local_double_outs[i]; // accumulate all (as they were initialised to zero)
149 // count += std::accumulate(local_counts.begin(), local_counts.end(), 0);
150 // count2 += std::accumulate(local_count2s.begin(), local_count2s.end(), 0);
152 if (output_image_ptr != NULL)
154 for (int i = 0; i < static_cast<int>(local_output_image_sptrs.size()); ++i)
155 if (!is_null_ptr(local_output_image_sptrs[i])) // only accumulate if a thread filled something in
156 *output_image_ptr += *(local_output_image_sptrs[i]);
160 wall_clock_timer.stop();
162 "Computation times for distributable_computation, CPU {}s, wall-clock {}s", CPU_timer.value(), wall_clock_timer.value()));