Source code for amstrax.common

from collections import defaultdict
from platform import python_version
from importlib import import_module
import socket

from collections import OrderedDict, deque
from importlib import import_module
from git import Repo, InvalidGitRepositoryError
from configparser import NoSectionError
import pandas as pd
import inspect
import io
import os
import tarfile
import sys
import numpy as np
import strax


export, __all__ = strax.exporter()
__all__ += [
    "amstrax_dir",
    "to_pe",
    "n_tpc_pmts",
    "n_xamsl_channel",
    "tpc_r",
    "tpc_z",
]

# Current values
n_tpc_pmts = 8
n_xamsl_channel = 4
to_pe = 1

tpc_r = 6  # TODO check this value
tpc_z = 10  # TODO check this value

# Current values
n_tpc_pmts = 8
n_xamsl_channel = 4
to_pe = 1

_is_jupyter = any("jupyter" in arg for arg in sys.argv)

amstrax_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))


def open_test_data(file_name):
    """Downloads amstrax test data to strax_test_data in the current directory"""
    with open(file_name, mode="rb") as f:
        result = f.read()
    f = io.BytesIO(result)
    tf = tarfile.open(fileobj=f)
    tf.extractall()


first_sr1_run = "1"

cache_dict = dict()


# Placeholder for resource management system in the future?
# @export
# def get_resource(x, fmt='text'):
#     return straxen.get_resource(x, fmt=fmt)


[docs]@export def get_elife(run_id): """Return electron lifetime for run_id in ns""" return 642e3
[docs]@export def select_channels(arr, channel_list): """Select only the values in arr that have arr['channel'] in channel_list""" sel = np.sum([arr["channel"] == channel for channel in channel_list], axis=0) > 0 return arr[sel]
[docs]@export def get_config_defaults(st, exclude=["raw_records", "records"], include=[]): configs = [] for plugin in st._plugin_class_registry.values(): if len(include) > 0: skip = True for t in include: if t in plugin.provides: skip = False if skip: continue # if raw_records or records in plugin.provides: skip = False for t in exclude: if t in plugin.provides: print(f"{plugin.__name__} {t} skipped") skip = True if not skip: takes_config = dict(plugin.takes_config) for name, config in plugin.takes_config.items(): configs.append((name, config.default)) configs = set(configs) df = pd.DataFrame(configs, columns=["name", "default"]) return df
def _version_info_for_module(module_name, include_git): try: mod = import_module(module_name) except (ModuleNotFoundError, ImportError): print(f"{module_name} is not installed") return git = None version = mod.__dict__.get("__version__", None) module_path = mod.__dict__.get("__path__", [None])[0] if include_git: try: repo = Repo(module_path, search_parent_directories=True) except InvalidGitRepositoryError: # not a git repo pass else: try: branch = repo.active_branch except TypeError: branch = "unknown" try: commit_hash = repo.head.object.hexsha except TypeError: commit_hash = "unknown" git = f"branch:{branch} | {commit_hash[:7]}" return version, module_path, git
[docs]@export def run_ids_from_file(filename): run_ids = [] with open(filename, "r") as f: for line in f: run_ids.append(f"{int(line):06d}") print(f"Found {len(run_ids)} runs in {filename}") return run_ids