Source code for amstrax.plugins.pulse_processing_alt_baseline

import numba
import numpy as np

import strax

# `export` is applied as a decorator to the public plugins and helpers below.
# NOTE(review): presumably strax.exporter() returns a decorator that appends
# the decorated object's name to the returned __all__ list -- confirm.
export, __all__ = strax.exporter()

# Number of TPC PMTs. Hardcoded for now...
# Used as the channel cut-off in channel_split and as the per-channel array
# size of the pulse_counts dtype.
n_tpc = 8


@export
@strax.takes_config(
    strax.Option(
        'save_outside_hits',
        default=(3, 3),
        help='Save (left, right) samples besides hits; cut the rest'),
    strax.Option('trigger_threshold', default=50),
)
class PulseProcessing(strax.Plugin):
    """Reduce raw_records to hit-only TPC records and count pulses.

    For every chunk of raw_records this plugin:

    1. Copies all fields into a zero-initialised array of the
       ``records_alt_bl`` dtype (which adds a ``baseline_std`` field).
    2. Fills ``baseline_std`` and zeroes the out-of-bounds samples.
    3. Keeps only records with ``channel < n_tpc`` and counts their pulses.
    4. Finds hits above ``trigger_threshold`` and cuts away the samples
       outside the hits (plus the ``save_outside_hits`` margins).

    Provides ``records_alt_bl`` and ``pulse_counts``.
    """
    __version__ = '0.0.3'

    parallel = 'process'
    rechunk_on_save = False
    compressor = 'zstd'

    depends_on = 'raw_records'

    provides = ('records_alt_bl', 'pulse_counts')
    data_kind = {k: k for k in provides}

    def infer_dtype(self):
        """Return a dict mapping each provided data type to its dtype."""
        # The record length is taken from the plugin that makes raw_records:
        # build one empty raw record and measure its 'data' field.
        raw_dtype = self.deps['raw_records'].dtype_for('raw_records')
        template = np.zeros(1, raw_dtype)
        record_length = len(template[0]['data'])

        dtype = {
            provided: record_dtype(record_length)
            for provided in self.provides
            if provided.endswith('records_alt_bl')
        }
        dtype['pulse_counts'] = pulse_count_dtype(n_tpc)
        return dtype

    def compute(self, raw_records):
        """Return ``records_alt_bl`` and ``pulse_counts`` for one chunk."""
        # Work on a copy with the extended dtype; fields that raw_records
        # does not have (baseline_std) start out as zero.
        records = np.zeros(len(raw_records),
                           dtype=self.dtype['records_alt_bl'])
        for field in raw_records.dtype.names:
            records[field] = raw_records[field]

        baseline_std(records)
        # Do not rely on the DAQ / baselining to have left the
        # out-of-bounds samples at zero.
        strax.zero_out_of_bounds(records)

        # Split off the non-TPC records and count the TPC pulses
        # (perhaps this should migrate to the DAQ reader in the future).
        tpc_records, _non_tpc = channel_split(records, n_tpc)
        pulse_counts = count_pulses(tpc_records, n_tpc)

        # Hits are found on the unfiltered waveform.
        hits = find_hits(tpc_records,
                         threshold=self.config['trigger_threshold'])

        left_extension, right_extension = self.config['save_outside_hits']
        tpc_records = strax.cut_outside_hits(
            tpc_records, hits,
            left_extension=left_extension,
            right_extension=right_extension)

        # Probably overkill, but just to be sure...
        strax.zero_out_of_bounds(tpc_records)

        return {'records_alt_bl': tpc_records,
                'pulse_counts': pulse_counts}
@export
@strax.takes_config(
    strax.Option('peak_gap_threshold', default=3000,
                 help="No hits for this many ns triggers a new peak"),
    strax.Option('peak_left_extension', default=80,
                 help="Include this many ns left of hits in peaks"),
    strax.Option('peak_right_extension', default=80,
                 help="Include this many ns right of hits in peaks"),
    # The help text used to be a copy of the peak_min_pmts help.
    strax.Option('peak_min_area', default=1,
                 help="Minimum peak area needed to define a peak"),
    strax.Option('peak_min_pmts', default=1,
                 help="Minimum contributing PMTs needed to define a peak"),
    strax.Option('single_channel_peaks', default=False,
                 help='Whether single-channel peaks should be reported'),
    strax.Option('peak_split_min_height', default=25,
                 help="Minimum height in PE above a local sum waveform "
                      "minimum, on either side, to trigger a split"),
    strax.Option('peak_split_min_ratio', default=4,
                 help="Minimum ratio between local sum waveform "
                      "minimum and maxima on either side, to trigger a split"),
    strax.Option('diagnose_sorting', track=False, default=False,
                 help="Enable runtime checks for sorting and disjointness"),
    strax.Option('pmt_channel', default=0,
                 help="PMT channel for splitting pmt and sipms"),
    strax.Option('trigger_threshold', default=50),
)
class PeaksAltBl(strax.Plugin):
    """Build peaks from records_alt_bl, separately for two channel groups.

    Hits in the channel equal to the ``pmt_channel`` option form the
    "bottom" peaks; hits in every other channel form the "top" peaks.
    Both outputs share the ``peaks`` data kind and carry an extra
    ``peak_max`` field holding the maximum of the summed waveform.
    """
    depends_on = 'records_alt_bl'
    data_kind = dict(peaks_top_alt_bl='peaks',
                     peaks_bottom_alt_bl='peaks')
    parallel = 'process'
    provides = ('peaks_top_alt_bl', 'peaks_bottom_alt_bl')
    rechunk_on_save = True

    __version__ = '0.1.12'

    dtype = dict(
        peaks_top_alt_bl=strax.peak_dtype(n_channels=8) + [
            (('Maximum height of the peak', 'peak_max'), np.int16)],
        peaks_bottom_alt_bl=strax.peak_dtype(n_channels=8) + [
            (('Maximum height of the peak', 'peak_max'), np.int16)],
    )

    def compute(self, records_alt_bl):
        """Return the top and bottom peaks found in one chunk of records.

        :param records_alt_bl: records produced by PulseProcessing
        :return: dict with ``peaks_top_alt_bl`` and ``peaks_bottom_alt_bl``
        """
        r = records_alt_bl
        # No gain calibration is applied: every channel has to_pe == 1.
        self.to_pe = np.ones(16)

        hits = find_hits(r, threshold=self.config['trigger_threshold'])
        # find_hits does not return its hits sorted.
        hits = strax.sort_by_time(hits)

        pmt_channel = self.config['pmt_channel']
        hits_bottom = hits[hits['channel'] == pmt_channel]
        hits_top = hits[hits['channel'] != pmt_channel]
        r_bottom = r[r['channel'] == pmt_channel]
        r_top = r[r['channel'] != pmt_channel]

        # Bottom peaks: a single channel, so min_channels is fixed to 1
        # and no min_area is passed.
        peaks_bottom = strax.find_peaks(
            hits_bottom, self.to_pe,
            gap_threshold=self.config['peak_gap_threshold'],
            left_extension=self.config['peak_left_extension'],
            right_extension=self.config['peak_right_extension'],
            min_channels=1,
            result_dtype=self.dtype['peaks_bottom_alt_bl'])
        strax.sum_waveform(peaks_bottom, r_bottom, self.to_pe)
        # Peak splitting (strax.split_peaks with peak_split_min_height and
        # peak_split_min_ratio) is currently disabled for both groups.
        strax.compute_widths(peaks_bottom)

        # Top peaks: area and channel-count cuts come from the config.
        peaks_top = strax.find_peaks(
            hits_top, self.to_pe,
            gap_threshold=self.config['peak_gap_threshold'],
            left_extension=self.config['peak_left_extension'],
            right_extension=self.config['peak_right_extension'],
            min_area=self.config['peak_min_area'],
            min_channels=self.config['peak_min_pmts'],
            result_dtype=self.dtype['peaks_top_alt_bl'])
        strax.sum_waveform(peaks_top, r_top, self.to_pe)
        strax.compute_widths(peaks_top)

        peaks_top['peak_max'] = np.max(peaks_top['data'], axis=1)
        peaks_bottom['peak_max'] = np.max(peaks_bottom['data'], axis=1)

        return dict(peaks_top_alt_bl=peaks_top,
                    peaks_bottom_alt_bl=peaks_bottom)
@export
class PeakBasicsTopAltBl(strax.Plugin):
    """Compute summary properties of the peaks_top_alt_bl peaks."""
    provides = 'peak_basics_top_alt_bl'
    depends_on = 'peaks_top_alt_bl'
    data_kind = 'peaks'
    # Was the string 'False', which is truthy and therefore did not
    # disable parallelisation as intended.
    parallel = False
    rechunk_on_save = True
    __version__ = '0.1.0'
    dtype = [
        (('Start time of the peak (ns since unix epoch)',
          'time'), np.int64),
        (('End time of the peak (ns since unix epoch)',
          'endtime'), np.int64),
        (('Peak integral in PE',
          'area'), np.float32),
        (('Number of PMTs contributing to the peak',
          'n_channels'), np.int16),
        (('PMT number which contributes the most PE',
          'max_pmt'), np.int16),
        (('Area of signal in the largest-contributing PMT (PE)',
          'max_pmt_area'), np.int32),
        (('Width (in ns) of the central 50% area of the peak',
          'range_50p_area'), np.float32),
        (('Fraction of area seen by the top array',
          'area_fraction_top'), np.float32),
        (('Length of the peak waveform in samples',
          'length'), np.int32),
        (('Time resolution of the peak waveform in ns',
          'dt'), np.int16),
    ]

    def compute(self, peaks):
        """Return one row of basic properties per input peak.

        :param peaks: array of peaks (the ``peaks`` data kind)
        :return: array of ``self.dtype``, sorted by time
        """
        p = strax.sort_by_time(peaks)
        r = np.zeros(len(p), self.dtype)

        # Fields copied over unchanged.
        for q in 'time length dt area'.split():
            r[q] = p[q]

        r['endtime'] = p['time'] + p['dt'] * p['length']
        r['n_channels'] = (p['area_per_channel'] > 0).sum(axis=1)
        r['range_50p_area'] = p['width'][:, 5]
        r['max_pmt'] = np.argmax(p['area_per_channel'], axis=1)
        r['max_pmt_area'] = np.max(p['area_per_channel'], axis=1)

        # The first 8 channels are counted as the top array.
        area_top = p['area_per_channel'][:, :8].sum(axis=1)
        # Non-positive-area peaks keep 0 AFT - TODO why not NaN?
        m = p['area'] > 0
        r['area_fraction_top'][m] = area_top[m] / p['area'][m]
        return r
@export
class PeakBasicsBottomAltBl(strax.Plugin):
    """Compute summary properties of the peaks_bottom_alt_bl peaks."""
    provides = 'peak_basics_bottom_alt_bl'
    depends_on = 'peaks_bottom_alt_bl'
    data_kind = 'peaks'
    # Was the string 'False', which is truthy and therefore did not
    # disable parallelisation as intended.
    parallel = False
    rechunk_on_save = True
    __version__ = '0.1.0'
    dtype = [
        (('Start time of the peak (ns since unix epoch)',
          'time'), np.int64),
        (('End time of the peak (ns since unix epoch)',
          'endtime'), np.int64),
        (('Peak integral in PE',
          'area'), np.float32),
        (('Number of PMTs contributing to the peak',
          'n_channels'), np.int16),
        (('PMT number which contributes the most PE',
          'max_pmt'), np.int16),
        (('Area of signal in the largest-contributing PMT (PE)',
          'max_pmt_area'), np.int32),
        (('Width (in ns) of the central 50% area of the peak',
          'range_50p_area'), np.float32),
        (('Fraction of area seen by the top array',
          'area_fraction_top'), np.float32),
        (('Length of the peak waveform in samples',
          'length'), np.int32),
        (('Time resolution of the peak waveform in ns',
          'dt'), np.int16),
    ]

    def compute(self, peaks):
        """Return one row of basic properties per input peak.

        :param peaks: array of peaks (the ``peaks`` data kind)
        :return: array of ``self.dtype``, sorted by time
        """
        p = strax.sort_by_time(peaks)
        r = np.zeros(len(p), self.dtype)

        # Fields copied over unchanged.
        for q in 'time length dt area'.split():
            r[q] = p[q]

        r['endtime'] = p['time'] + p['dt'] * p['length']
        r['n_channels'] = (p['area_per_channel'] > 0).sum(axis=1)
        r['range_50p_area'] = p['width'][:, 5]
        r['max_pmt'] = np.argmax(p['area_per_channel'], axis=1)
        r['max_pmt_area'] = np.max(p['area_per_channel'], axis=1)

        # The first 8 channels are counted as the top array.
        area_top = p['area_per_channel'][:, :8].sum(axis=1)
        # Non-positive-area peaks keep 0 AFT - TODO why not NaN?
        m = p['area'] > 0
        r['area_fraction_top'][m] = area_top[m] / p['area'][m]
        return r
# Base dtype for interval-like objects (pulse, peak, hit)
interval_dtype = [
    (('Channel/PMT number',
      'channel'), np.int16),
    (('Time resolution in ns',
      'dt'), np.int16),
    (('Start time of the interval (ns since unix epoch)',
      'time'), np.int64),
    # Don't try to make O(second) long intervals!
    (('Length of the interval in samples',
      'length'), np.int32),
    # Sub-dtypes MUST contain an area field
    # However, the type varies: float for sum waveforms (area in PE)
    # and int32 for per-channel waveforms (area in ADC x samples)
]


def record_dtype(samples_per_record=strax.DEFAULT_RECORD_LENGTH):
    """Data type for a waveform record.

    Length can be shorter than the number of samples in data,
    this indicates a record with zero-padding at the end.

    Compared to the fields listed in interval_dtype, this adds the record
    fields below, ending with a ``baseline_std`` field.

    :param samples_per_record: number of samples in the ``data`` field
    :return: list of numpy dtype field descriptions
    """
    return interval_dtype + [
        (("Integral in ADC x samples",
          'area'), np.int32),
        # np.int16 is not enough for some PMT flashes...
        (('Length of pulse to which the record belongs (without zero-padding)',
          'pulse_length'), np.int32),
        (('Fragment number in the pulse',
          'record_i'), np.int16),
        (('Baseline in ADC counts. data = int(baseline) - data_orig',
          'baseline'), np.float32),
        (('Level of data reduction applied (strax.ReductionLevel enum)',
          'reduction_level'), np.uint8),
        # Note this is defined as a SIGNED integer, so we can
        # still represent negative values after subtracting baselines
        (('Waveform data in ADC counts above baseline',
          'data'), np.int16, samples_per_record),
        (('Baseline standard deviation',
          'baseline_std'), np.float32),
    ]


@numba.jit(nopython=True, nogil=True, cache=True)
def baseline_std(records, baseline_samples=40):
    """Store the standard deviation of the first samples of each pulse
    in the baseline_std field of its records (in place).

    The waveform data itself is not modified here.

    :param records: records to update; must have record_i, channel, data
        and baseline_std fields
    :param baseline_samples: number of samples at start of pulse to use

    Assumes records are sorted in time (or at least by channel, then time)

    Assumes record_i information is accurate (so don't cut pulses before
    running this!)

    Returns the (empty) input when there are no records; otherwise nothing
    is returned explicitly.
    """
    if len(records) == 0:
        return records

    # Array for looking up last baseline std seen in channel
    # We only care about the channels in this set of records; a single .max()
    # is worth avoiding the hassle of passing n_channels around
    last_bl_std_in = np.zeros(records['channel'].max() + 1, dtype=np.float32)

    for d_i, d in enumerate(records):

        # Compute the baseline std if we're the first record of the pulse,
        # otherwise take the last value we've seen in the channel
        if d.record_i == 0:
            bl_std = last_bl_std_in[d.channel] = d.data[:baseline_samples].std()
        else:
            bl_std = last_bl_std_in[d.channel]

        d.baseline_std = bl_std


@strax.growing_result(strax.hit_dtype, chunk_size=int(1e4))
@numba.jit(nopython=True, nogil=True, cache=True)
def find_hits(records, threshold=70, _result_buffer=None):
    """Return hits (intervals above threshold) found in records.
    Hits that straddle record boundaries are split (TODO: fix this?)

    NB: returned hits are NOT sorted yet!

    :param records: records to search
    :param threshold: a sample belongs to a hit when it is strictly
        greater than this value
    :param _result_buffer: output buffer; NOTE(review): presumably supplied
        by the strax.growing_result decorator -- confirm.
    """
    buffer = _result_buffer
    if len(records) == 0:
        return
    samples_per_record = len(records[0]['data'])
    # Number of hits written to the buffer since the last yield
    offset = 0

    for record_i, r in enumerate(records):
        # print("Starting record ', record_i)
        in_interval = False
        hit_start = -1
        area = 0

        for i in range(samples_per_record):
            # We can't use enumerate over r['data'],
            # numba gives errors if we do.
            x = r['data'][i]
            above_threshold = x > threshold
            # print(r['data'][i], above_threshold, in_interval, hit_start)

            if not in_interval and above_threshold:
                # Start of a hit
                in_interval = True
                hit_start = i

            if in_interval:
                if not above_threshold:
                    # Hit ends at the start of this sample
                    hit_end = i
                    in_interval = False

                elif i == samples_per_record - 1:
                    # Hit ends at the *end* of this sample
                    # (because the record ends)
                    hit_end = i + 1
                    area += x
                    in_interval = False

                else:
                    area += x

                if not in_interval:
                    # print('saving hit')
                    # Hit is done, add it to the result
                    if hit_end == hit_start:
                        print(r['time'], r['channel'], hit_start)
                        raise ValueError(
                            "Caught attempt to save zero-length hit!")
                    res = buffer[offset]
                    res['left'] = hit_start
                    res['right'] = hit_end
                    res['time'] = r['time'] + hit_start * r['dt']
                    # Note right bound is exclusive, no + 1 here:
                    res['length'] = hit_end - hit_start
                    res['dt'] = r['dt']
                    res['channel'] = r['channel']
                    res['record_i'] = record_i
                    # Add the fractional part of the baseline for every
                    # sample in the hit (the record dtype documents data as
                    # int(baseline) - data_orig).
                    area += int(round(
                        res['length'] * (r['baseline'] % 1)))
                    res['area'] = area
                    area = 0

                    # Yield buffer to caller if needed
                    offset += 1
                    if offset == len(buffer):
                        yield offset
                        offset = 0

                    # Clear stuff, just for easier debugging
                    # hit_start = 0
                    # hit_end = 0
    # Report the hits remaining in the partially filled buffer
    yield offset


@numba.njit
def rough_sum(regions, records, to_pe, n, dt):
    """Compute ultra-rough sum waveforms for regions, assuming:
     - every record is a single peak at its first sample
     - all regions have the same length and dt
    and probably not caring too much about boundaries

    The result is accumulated in place into regions['data'].

    :param regions: regions to fill; need time and data fields
    :param records: records contributing their area at their start time
    :param to_pe: per-channel factor the record area is multiplied with
    :param n: number of samples in each region
    :param dt: sample duration used for every region
    """
    if len(regions) == 0 or len(records) == 0:
        return

    # dt and n are passed explicitly to avoid overflows/wraparounds
    # related to the small dt integer type

    peak_i = 0
    r_i = 0
    while (peak_i <= len(regions) - 1) and (r_i <= len(records) - 1):

        p = regions[peak_i]
        # [l, r) is the time range covered by this region
        l = p['time']
        r = l + n * dt

        while True:
            if r_i > len(records) - 1:
                # Scan ahead until records contribute
                break
            t = records[r_i]['time']
            if t >= r:
                # Record starts after this region; go to the next region
                break
            if t >= l:
                index = int((t - l) // dt)
                regions[peak_i]['data'][index] += (
                    records[r_i]['area'] * to_pe[records[r_i]['channel']])
            r_i += 1
        peak_i += 1


##
# Pulse counting
##
@export
def pulse_count_dtype(n_channels):
    """Return the dtype of the per-chunk pulse count summary.

    :param n_channels: length of each per-channel counter array
    :return: list of numpy dtype field descriptions
    """
    # NB: don't use the dt/length interval dtype, integer types are too small
    # to contain these huge chunk-wide intervals
    chunk_bounds = [
        (('Lowest start time observed in the chunk', 'time'), np.int64),
        (('Highest endt ime observed in the chunk', 'endtime'), np.int64),
    ]

    # Every counter is an int64 array with one entry per channel.
    per_channel = (np.int64, n_channels)
    counters = [
        ((title, name), per_channel)
        for title, name in (
            ('Number of pulses', 'pulse_count'),
            ('Number of lone pulses', 'lone_pulse_count'),
            ('Integral of all pulses in ADC_count x samples', 'pulse_area'),
            ('Integral of lone pulses in ADC_count x samples',
             'lone_pulse_area'),
        )
    ]
    return chunk_bounds + counters
def count_pulses(records, n_channels):
    """Return array with one element, with pulse count info from records.

    :param records: records to summarise
    :param n_channels: number of channels to keep counters for
    :return: array of pulse_count_dtype(n_channels). It has one element,
        or zero elements when ``records`` is empty: there is then no start
        time to report, and the compiled helper would otherwise read
        records[0] of an empty array.
    """
    if len(records) == 0:
        return np.zeros(0, dtype=pulse_count_dtype(n_channels))
    result = np.zeros(1, dtype=pulse_count_dtype(n_channels))
    _count_pulses(records, n_channels, result)
    return result


@numba.njit
def _count_pulses(records, n_channels, result):
    """Fill result[0] with per-channel pulse counts and areas.

    Must be called with at least one record: records[0] is read
    unconditionally at the end.

    :raises RuntimeError: if a record has channel >= n_channels
    """
    count = np.zeros(n_channels, dtype=np.int64)
    lone_count = np.zeros(n_channels, dtype=np.int64)
    area = np.zeros(n_channels, dtype=np.int64)
    lone_area = np.zeros(n_channels, dtype=np.int64)

    last_end_seen = 0
    next_start = 0

    for r_i, r in enumerate(records):
        # For the last record next_start keeps its previous value
        if r_i != len(records) - 1:
            next_start = records[r_i + 1]['time']

        ch = r['channel']
        if ch >= n_channels:
            print(ch)
            raise RuntimeError("Out of bounds channel in get_counts!")

        # Only the first fragment of a pulse is counted
        if r['record_i'] == 0:
            count[ch] += 1
            area[ch] += r['area']

            # A pulse is "lone" if it starts after everything seen so far
            # has ended and ends before the next record starts.
            # NOTE(review): pulse_length is added to a time without being
            # multiplied by dt -- confirm this is intended.
            if (r['time'] > last_end_seen
                    and r['time'] + r['pulse_length'] < next_start):
                lone_count[ch] += 1
                lone_area[ch] += r['area']

        last_end_seen = max(last_end_seen,
                            r['time'] + r['pulse_length'])

    res = result[0]
    res['pulse_count'][:] = count[:]
    res['lone_pulse_count'][:] = lone_count[:]
    res['pulse_area'][:] = area[:]
    res['lone_pulse_area'][:] = lone_area[:]
    res['time'] = records[0]['time']
    res['endtime'] = last_end_seen


##
# Misc
##

@numba.njit
def _mask_and_not(x, mask):
    """Return (x[mask], x[~mask])."""
    return x[mask], x[~mask]
@export
@numba.njit
def channel_split(rr, first_other_ch):
    """Split records by channel number.

    :param rr: records with a channel field
    :param first_other_ch: lowest channel number of the second group
    :return: (records with channel < first_other_ch, all other records)
    """
    below_cut = rr['channel'] < first_other_ch
    return _mask_and_not(rr, below_cut)
@numba.jit(nopython=True, nogil=True, cache=True)
def get_record_index(raw_records, channel, direction):
    """Find the nearest record of a channel, scanning in one direction.

    :param raw_records: records to scan; need a channel field
    :param channel: channel number to look for
    :param direction: -1 scans backwards from the last record and returns
        a negative index, or 0 if there is no match (the first record is
        never inspected). +1 scans forwards starting at index 1 and
        returns the index, or len(raw_records) if there is no match.
        For any other value nothing is returned.
    """
    n_records = len(raw_records)

    if direction == -1:
        for idx in range(-1, -n_records, -1):
            if raw_records[idx]['channel'] == channel:
                return idx
        return 0

    if direction == +1:
        for idx in range(1, n_records, 1):
            if raw_records[idx]['channel'] == channel:
                return idx
        return n_records
@export
@strax.growing_result(strax.record_dtype(), chunk_size=int(1e6))
@numba.jit(nopython=False, nogil=False, cache=True)
def fill_records(raw_records, hits, trigger_window, _result_buffer=None):
    """Build new records around groups of nearby hits.

    Per channel, hits longer than one sample are grouped: a hit joins the
    current group when it starts less than trigger_window samples after the
    previous hit in the group. For each group a pulse is cut from the raw
    record data, starting trigger_window samples before the first hit, and
    written to the result buffer as one or more records.

    :param raw_records: records the hits were found in; hits['record_i']
        is used as an index into this array
    :param hits: hits with channel, time, dt, length, left, right and
        record_i fields
    :param trigger_window: margin, in samples, used both for grouping hits
        and for padding the pulse on either side
    :param _result_buffer: output buffer; NOTE(review): presumably supplied
        by the strax.growing_result decorator -- confirm.
    """
    samples_per_record = strax.DEFAULT_RECORD_LENGTH
    tw = trigger_window
    buffer = _result_buffer
    # Number of records written to the buffer since the last yield
    offset = 0
    # Number of upcoming hits to skip because they were already grouped
    skipper = 0
    for ch in np.unique(hits['channel']):
        hit_ch = hits[(hits['channel'] == ch) & (hits['length'] > 1)]
        for h in hit_ch:
            if skipper != 0:
                skipper -= 1
                continue
            dt = h['dt']
            # Collect this hit and the later hits that chain onto it
            hit = []
            hit.append(h)
            h_c = hit_ch[(hit_ch['time'] > h['time'])]
            max_t = h['time']
            for h_ in h_c:
                if h_['time'] > max_t and h_['time'] < max_t + dt * tw:
                    hit.append(h_)
                    max_t = h_['time']
                elif h_['time'] > max_t + dt * tw:
                    break
            hit_buffer = np.zeros(len(hit), dtype=strax.hit_dtype)
            for i in np.arange(len(hit)):
                # print(hit[i])
                hit_buffer[i] = hit[i]
            dt = hit_buffer[0]['dt']
            # Pulse time range: tw samples around the start times of the
            # first and last hit of the group
            start = hit_buffer[0]['time'] - dt * tw
            end = hit_buffer[-1]['time'] + dt * tw
            p_length = int((end - start) / dt)
            # records_needed = int(np.ceil(p_length / (samples_per_record)))
            # Sample positions of the pulse edges within the input record
            p_offset = hit[0]['left'] - tw
            p_end = hit[-1]['right'] + tw
            input_record_index = [np.unique(hit_buffer['record_i']).tolist()][0]
            assert input_record_index != [0]
            if p_offset < 0:
                # The pulse starts before the record of the first hit:
                # look for an earlier record in the same channel
                previous = hit_buffer[0]['record_i'] + get_record_index(
                    raw_records[:hit_buffer[0]['record_i']],
                    hit_buffer[0]['channel'], -1)
                if previous < 0 or previous == hit_buffer[0]['record_i']:
                    # No earlier record: shorten the pulse instead
                    p_length += p_offset
                    p_offset = 0
                if previous > 0:
                    p_offset += samples_per_record
                    input_record_index.append(previous)
            if p_end > samples_per_record:
                # The pulse ends after the record of the last hit:
                # look for a later record in the same channel
                next_idx = hit_buffer[-1]['record_i'] + get_record_index(
                    raw_records[hit_buffer[-1]['record_i']:],
                    hit_buffer[-1]['channel'], +1)
                if next_idx < len(raw_records):
                    input_record_index.append(next_idx)
                if next_idx > len(raw_records):
                    p_length -= (p_end % samples_per_record)
            input_record_index.sort()
            # Concatenate the samples of all contributing input records
            record_buffer = []
            for i in input_record_index:
                if i > len(raw_records) - 1:
                    p_length -= tw * dt
                    continue
                record_buffer.extend(list(raw_records[i]['data']))
            records_needed = int(np.ceil(p_length / samples_per_record))
            n_store = 0
            for rec_i in range(records_needed):
                r_ = buffer[offset + rec_i]
                r_['dt'] = dt
                r_['channel'] = hit[0]['channel']
                r_['pulse_length'] = p_length
                r_['record_i'] = rec_i
                r_['time'] = start + rec_i * samples_per_record * dt
                r_['baseline'] = raw_records[input_record_index[0]]['baseline']
                # Advance past the samples stored in the previous record
                p_offset += n_store
                if rec_i != records_needed - 1:
                    n_store = samples_per_record
                else:
                    # Last record holds whatever is left of the pulse
                    n_store = p_length - samples_per_record * rec_i
                # print(p_length , records_needed, rec_i, n_store, p_offset , len(record_buffer))
                r_['data'][:n_store] = record_buffer[p_offset:p_offset + n_store]
                r_['length'] = n_store
            skipper = len(hit) - 1
            offset += records_needed
            # Hand the buffer to the caller once 750 or more records are
            # pending. NOTE(review): the buffer has chunk_size 1e6 entries,
            # so 750 looks like an arbitrary flush threshold -- confirm.
            if offset >= 750:
                yield offset
                offset = 0
    print('Almost done!')
    # Report the records remaining in the partially filled buffer
    yield offset