Source code for whisker_serial_order.models

#!/usr/bin/env python
# whisker_serial_order/models.py

"""
===============================================================================

    Copyright © 2016-2018 Rudolf Cardinal (rudolf@pobox.com).

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software
    distributed under the License is distributed on an "AS IS" BASIS,
    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    See the License for the specific language governing permissions and
    limitations under the License.

===============================================================================

SQLAlchemy models and other data storage classes for the serial order task.

"""

from argparse import ArgumentTypeError
import logging
from typing import Any, List, Iterable, Optional, Set, Tuple

import arrow
from cardinal_pythonlib.sqlalchemy.alembic_func import (
    ALEMBIC_NAMING_CONVENTION,
)
from cardinal_pythonlib.sqlalchemy.arrow_types import ArrowMicrosecondType
from cardinal_pythonlib.sqlalchemy.orm_inspect import (
    deepcopy_sqla_object,
    SqlAlchemyAttrDictMixin,
)
from sqlalchemy import (
    BigInteger,
    Boolean,
    Column,
    Float,
    ForeignKey,
    Integer,
    MetaData,
    String,  # variable length in PostgreSQL; specify length for MySQL
    Text,  # variable length
)
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, Session
from sqlalchemy.sql.type_api import TypeDecorator
from sqlalchemy_utils import ScalarListType

from whisker_serial_order.constants import (
    DATETIME_FORMAT_PRETTY,
    MAX_EVENT_LENGTH,
    MAX_HOLE_NUMBER,
    MIN_HOLE_NUMBER,
    MIN_SERIAL_ORDER_POSITION,
    MAX_SERIAL_ORDER_POSITION,
)
from whisker_serial_order.extra import latency_s
from whisker_serial_order.version import (
    MAX_VERSION_LENGTH,
    SERIAL_ORDER_VERSION,
)

log = logging.getLogger(__name__)


# =============================================================================
# Constants
# =============================================================================

MAX_GENERIC_STRING_LENGTH = 255
MAX_HOLE_OR_SERIALPOS_PAIR_DEFINITION_STRING_LENGTH = 255  # more than enough!
N_HOLES_FOR_CHOICE = 2


# =============================================================================
# SQLAlchemy base.
# =============================================================================
# Derived classes will share the specified metadata.

MASTER_META = MetaData(naming_convention=ALEMBIC_NAMING_CONVENTION)
Base = declarative_base(metadata=MASTER_META)


# =============================================================================
# Helper functions/classes
# =============================================================================

def spatial_to_serial_order(hole_sequence: List[int],
                            holes: List[int]) -> List[int]:
    """
    Converts spatial hole numbers into their serial order positions within a
    first-phase hole sequence.

    Each hole in ``holes`` is looked up in ``hole_sequence``; its 1-based
    index there is its serial order position. If a hole number occurs more
    than once in ``hole_sequence``, its first occurrence is used.

    Args:
        hole_sequence:
            ordered list of spatial hole numbers to be presented in the first
            phase of the task, e.g. ``[3, 1, 4]``.
        holes:
            spatial hole numbers to be enquired about: "what was the temporal
            order of these holes in the first phase?"; e.g. ``[4, 3]``.

    Returns:
        list of serial order positions, in the same order as ``holes`` (in
        this example: ``[3, 1]``).

    Raises:
        ValueError: if a hole in ``holes`` is absent from ``hole_sequence``.
    """
    serial_positions = []  # type: List[int]
    for hole in holes:
        zero_based_index = hole_sequence.index(hole)
        serial_positions.append(zero_based_index + 1)  # make it 1-based
    return serial_positions
def serial_order_to_spatial(hole_sequence: List[int],
                            seq_positions: List[int]) -> List[int]:
    """
    Converts a first-phase hole sequence and a list of serial order positions
    (at the choice phase) into a list of spatial holes at test.

    Args:
        hole_sequence:
            ordered list of spatial hole numbers to be presented in the first
            phase of the task, e.g. ``[3, 1, 4]``.
        seq_positions:
            list of 1-based serial order positions, e.g ``[1, 3]`` for the
            first and third in the sequence.

    Returns:
        list of spatial hole numbers, in the same order as ``seq_positions``
        (e.g. ``[3, 4]`` in this example).

    Raises:
        IndexError: if a position is not in the range 1 to
            ``len(hole_sequence)`` inclusive.
    """
    n_in_sequence = len(hole_sequence)
    spatial_holes = []  # type: List[int]
    for position in seq_positions:
        # Positions are 1-based. Without this check, a position of 0 (or a
        # negative one) would silently wrap round via Python's negative
        # indexing and return a hole from the END of the sequence.
        if not 1 <= position <= n_in_sequence:
            raise IndexError(
                "Serial order position {!r} is out of range (must be in "
                "range 1-{})".format(position, n_in_sequence))
        spatial_holes.append(hole_sequence[position - 1])
    return spatial_holes
class ChoiceHoleRestriction(object):
    """
    Class to describe choice hole restrictions.

    :ivar permissible_combinations: variable of type ``Set[Tuple[int]]``,
        where the tuples are sorted sequences of hole numbers. If the set is
        not empty, then only such combinations are allowed.
    """
    DEFAULT_HOLE_SEPARATOR = ","
    DEFAULT_GROUP_SEPARATOR = ";"  # NB ";" trickier from Bash command line

    def __init__(
            self,
            # String-based init:
            description: str = "",
            hole_separator: str = DEFAULT_HOLE_SEPARATOR,
            group_separator: str = DEFAULT_GROUP_SEPARATOR,
            # Hole-based init:
            permissible_combinations: List[List[int]] = None) -> None:
        """
        Args:
            description:
                textual description like "1,3; 2,4" to restrict to the
                combinations of "hole 1 versus hole 3" and "hole 2 versus
                hole 4".
            hole_separator:
                string used to separate holes in a group (usually ",").
                Only used when parsing ``description``.
            group_separator:
                string used to separate groups (usually ";").
                Only used when parsing ``description``.
            permissible_combinations:
                list of lists of spatial hole numbers, as an alternative to
                using ``description``. Use one or the other.

        Raises:
            argparse.ArgumentTypeError: if its arguments are invalid (both
                initialization methods used at once; a hole that is not an
                integer or is out of range; a group that is not of length
                ``N_HOLES_FOR_CHOICE``; a group containing duplicates).
        """
        if description and permissible_combinations:
            raise ArgumentTypeError(
                "Specify description or permissible_combinations, "
                "but not both"
            )
        # NOTE: can't add lists to a set (TypeError: unhashable type: 'list'),
        # so each group is stored as a sorted tuple.
        self.permissible_combinations = set()  # type: Set[Tuple[int]]

        if description:
            # Initialize from string
            for group_string in description.split(group_separator):
                holes = []  # type: List[int]
                for hole_string in group_string.split(hole_separator):
                    try:
                        hole = int(hole_string.strip())
                    except (ValueError, TypeError) as err:
                        # Chain explicitly so the underlying conversion
                        # failure is reported as the direct cause.
                        raise ArgumentTypeError(
                            "Not an integer: {!r}".format(hole_string)
                        ) from err
                    self._assert_hole_ok(hole)
                    holes.append(hole)
                if len(holes) != N_HOLES_FOR_CHOICE:
                    raise ArgumentTypeError(
                        "In description {!r}, hole group {!r} must be of "
                        "length {}, but isn't".format(
                            description, group_string, N_HOLES_FOR_CHOICE)
                    )
                self.permissible_combinations.add(tuple(sorted(holes)))
        elif permissible_combinations:
            # Initialize from list of lists of holes
            for group in permissible_combinations:
                holes = list(group)  # materialize once; we scan it twice
                for hole in holes:
                    if not isinstance(hole, int):
                        raise ArgumentTypeError(
                            "Not an integer: {!r}".format(hole))
                    self._assert_hole_ok(hole)
                # Apply the same group-length rule as the string-based path;
                # a group of any other length could never match a choice of
                # N_HOLES_FOR_CHOICE holes in permissible().
                if len(holes) != N_HOLES_FOR_CHOICE:
                    raise ArgumentTypeError(
                        "Hole group {!r} must be of length {}, but "
                        "isn't".format(holes, N_HOLES_FOR_CHOICE)
                    )
                self.permissible_combinations.add(tuple(sorted(holes)))

        # Check values are sensible:
        for holes in self.permissible_combinations:
            if len(holes) != len(set(holes)):
                raise ArgumentTypeError("No duplicates permitted; problem was "
                                        "{!r}".format(holes))

    @staticmethod
    def _assert_hole_ok(hole_: int) -> None:
        """
        Raises :exc:`argparse.ArgumentTypeError` unless ``hole_`` is within
        ``MIN_HOLE_NUMBER`` to ``MAX_HOLE_NUMBER`` inclusive.
        """
        if not (MIN_HOLE_NUMBER <= hole_ <= MAX_HOLE_NUMBER):
            raise ArgumentTypeError(
                "Bad hole number {} (must be in range {}-{})".format(
                    hole_, MIN_HOLE_NUMBER, MAX_HOLE_NUMBER))

    def description(self) -> str:
        """
        Returns the description that can be used to recreate this object.

        Always uses the default separators (groups joined by "; ", holes by
        ","), with groups in sorted order. Returns an empty string if there
        are no restrictions.
        """
        groupsep = self.DEFAULT_GROUP_SEPARATOR + " "
        holesep = self.DEFAULT_HOLE_SEPARATOR
        if self.permissible_combinations:
            return groupsep.join(
                holesep.join(str(h) for h in holes)
                for holes in sorted(self.permissible_combinations)
            )
        return ""

    def __str__(self) -> str:
        return "ChoiceHoleRestriction({!r})".format(self.description())

    def permissible(self, choice_holes: Iterable[int]) -> bool:
        """
        Is the supplied list of choice holes compatible with the restrictions?

        Args:
            choice_holes: list of spatial holes (in any order).

        Returns:
            ``True`` if there are no restrictions, or if the sorted choice
            holes match one of the permissible combinations.
        """
        if not self.permissible_combinations:
            # No restrictions; OK
            return True
        sorted_holes = tuple(sorted(choice_holes))
        return sorted_holes in self.permissible_combinations
class ChoiceHoleRestrictionType(TypeDecorator):
    """
    SQLAlchemy data type to store :class:`.ChoiceHoleRestriction` in a
    database, as its string description.

    See http://docs.sqlalchemy.org/en/latest/core/custom_types.html.
    """
    impl = String(length=MAX_HOLE_OR_SERIALPOS_PAIR_DEFINITION_STRING_LENGTH)
    # This type takes no per-instance arguments, so it is safe to use as part
    # of a SQLAlchemy statement cache key (attribute is ignored by SQLAlchemy
    # versions that predate it).
    cache_ok = True

    def process_bind_param(
            self, value: Any, dialect: DefaultDialect) -> Optional[str]:
        """
        Converts a bound Python parameter to the database value.

        Args:
            value: should be a :class:`.ChoiceHoleRestriction` or None
            dialect: SQLAlchemy database dialect.

        Returns:
            string (outbound to database); falsy values (e.g. ``None``) are
            passed through unchanged.

        Raises:
            ValueError: if ``value`` is truthy but is not a
                :class:`.ChoiceHoleRestriction`.
        """
        if not value:
            return value
        if not isinstance(value, ChoiceHoleRestriction):
            raise ValueError("Bad object arriving at "
                             "ChoiceHoleRestrictionType.process_bind_param: "
                             "{!r}".format(value))
        return value.description()

    def process_result_value(
            self, value: Any,
            dialect: DefaultDialect) -> Optional[ChoiceHoleRestriction]:
        """
        Receive a result-row column value to be converted.

        Args:
            value: data fetched from the database (will be a string).
            dialect: SQLAlchemy database dialect.

        Returns:
            a :class:`.ChoiceHoleRestriction` object if the string is valid;
            ``None`` if the value is empty or cannot be parsed (the latter is
            logged at debug level).
        """
        if not value:
            return None
        try:
            return ChoiceHoleRestriction(description=value)
        except ArgumentTypeError:
            # Best effort: a bad stored value is treated as "no restriction
            # object" rather than breaking the load.
            log.debug("Bad value received from database to "
                      "ChoiceHoleRestrictionType.process_result_value: "
                      "%r", value)
            return None

    def process_literal_param(self, value: Any,
                              dialect: DefaultDialect) -> str:
        """
        Receive a literal parameter value to be rendered inline within a
        statement.

        (An abstract method of ``TypeDecorator``, so we should implement it.)

        Args:
            value: a Python value
            dialect: SQLAlchemy database dialect.

        Returns:
            a string to be baked into some SQL. For a
            :class:`.ChoiceHoleRestriction` this is its description (the same
            text that :meth:`process_bind_param` stores), not its ``str()``
            form, which is wrapped as ``ChoiceHoleRestriction('...')`` and
            would never match a stored value.
        """
        if isinstance(value, ChoiceHoleRestriction):
            return value.description()
        return str(value)

    @property
    def python_type(self) -> type:
        """
        Returns the Python type object expected to be returned by instances of
        this type, if known. It's :class:`.ChoiceHoleRestriction`.
        """
        return ChoiceHoleRestriction
class SerialPosRestriction(object):
    """
    Class to describe restrictions on the serial order positions offered at
    the choice phase.

    :ivar permissible_combinations: variable of type ``Set[Tuple[int]]``,
        where the tuples are sorted sequences of serial order position numbers
        (1 being the first). If the set is not empty, then only such
        combinations are allowed.
    """
    DEFAULT_POS_SEPARATOR = ","
    DEFAULT_GROUP_SEPARATOR = ";"  # NB ";" trickier from Bash command line

    def __init__(
            self,
            # String-based init:
            description: str = "",
            position_separator: str = DEFAULT_POS_SEPARATOR,
            group_separator: str = DEFAULT_GROUP_SEPARATOR,
            # Position-based init:
            permissible_combinations: List[List[int]] = None) -> None:
        """
        Args:
            description:
                textual description like "1,3; 2,3" to restrict to the
                combinations of "serial position 1 versus 3" and "serial
                position 2 versus 3".
            position_separator:
                string used to separate positions (usually ",").
                Only used when parsing ``description``.
            group_separator:
                string used to separate groups (usually ";").
                Only used when parsing ``description``.
            permissible_combinations:
                list of lists of serial order positions, as an alternative to
                using ``description``. Use one or the other.

        Raises:
            argparse.ArgumentTypeError: if its arguments are invalid (both
                initialization methods used at once; a position that is not
                an integer or is out of range; a group that is not of length
                ``N_HOLES_FOR_CHOICE``; a group containing duplicates).
        """
        if description and permissible_combinations:
            raise ArgumentTypeError(
                "Specify description or permissible_combinations, "
                "but not both"
            )
        # NOTE: can't add lists to a set (TypeError: unhashable type: 'list'),
        # so each group is stored as a sorted tuple.
        self.permissible_combinations = set()  # type: Set[Tuple[int]]

        if description:
            # Initialize from string
            for group_string in description.split(group_separator):
                positions = []  # type: List[int]
                for pos_string in group_string.split(position_separator):
                    try:
                        pos = int(pos_string.strip())
                    except (ValueError, TypeError) as err:
                        # Chain explicitly so the underlying conversion
                        # failure is reported as the direct cause.
                        raise ArgumentTypeError(
                            "Not an integer: {!r}".format(pos_string)
                        ) from err
                    self._assert_position_ok(pos)
                    positions.append(pos)
                if len(positions) != N_HOLES_FOR_CHOICE:
                    raise ArgumentTypeError(
                        "In description {!r}, position group {!r} must be of "
                        "length {}, but isn't".format(
                            description, group_string, N_HOLES_FOR_CHOICE)
                    )
                self.permissible_combinations.add(tuple(sorted(positions)))
        elif permissible_combinations:
            # Initialize from list of lists of serial order positions
            for group in permissible_combinations:
                positions = list(group)  # materialize once; we scan it twice
                for pos in positions:
                    if not isinstance(pos, int):
                        raise ArgumentTypeError(
                            "Not an integer: {!r}".format(pos))
                    self._assert_position_ok(pos)
                # Apply the same group-length rule as the string-based path;
                # a group of any other length could never match a choice of
                # N_HOLES_FOR_CHOICE positions in permissible().
                if len(positions) != N_HOLES_FOR_CHOICE:
                    raise ArgumentTypeError(
                        "Position group {!r} must be of length {}, but "
                        "isn't".format(positions, N_HOLES_FOR_CHOICE)
                    )
                self.permissible_combinations.add(tuple(sorted(positions)))

        # Check values are sensible:
        for positions in self.permissible_combinations:
            if len(positions) != len(set(positions)):
                raise ArgumentTypeError("No duplicates permitted; problem was "
                                        "{!r}".format(positions))

    @staticmethod
    def _assert_position_ok(pos_: int) -> None:
        """
        Raises :exc:`argparse.ArgumentTypeError` unless ``pos_`` is within
        ``MIN_SERIAL_ORDER_POSITION`` to ``MAX_SERIAL_ORDER_POSITION``
        inclusive.
        """
        if not (MIN_SERIAL_ORDER_POSITION <= pos_ <=
                MAX_SERIAL_ORDER_POSITION):
            raise ArgumentTypeError(
                "Bad serial order position {} (must be in range "
                "{}-{})".format(pos_, MIN_SERIAL_ORDER_POSITION,
                                MAX_SERIAL_ORDER_POSITION))

    def description(self) -> str:
        """
        Returns the description that can be used to recreate this object.

        Always uses the default separators (groups joined by "; ", positions
        by ","), with groups in sorted order. Returns an empty string if there
        are no restrictions.
        """
        groupsep = self.DEFAULT_GROUP_SEPARATOR + " "
        pos_sep = self.DEFAULT_POS_SEPARATOR
        if self.permissible_combinations:
            return groupsep.join(
                pos_sep.join(str(pos) for pos in positions)
                for positions in sorted(self.permissible_combinations)
            )
        return ""

    def __str__(self) -> str:
        return "SerialPosRestriction({!r})".format(self.description())

    def permissible(self, serial_positions: Iterable[int]) -> bool:
        """
        Is the supplied list of serial order position to be tested compatible
        with the restrictions?

        Args:
            serial_positions: the serial order positions to be presented in
                the choice phase (in any order).

        Returns:
            ``True`` if there are no restrictions, or if the sorted positions
            match one of the permissible combinations.
        """
        if not self.permissible_combinations:
            # No restrictions; OK
            return True
        sorted_positions = tuple(sorted(serial_positions))
        return sorted_positions in self.permissible_combinations
class SerialPosRestrictionType(TypeDecorator):
    """
    SQLAlchemy data type to store :class:`.SerialPosRestriction` in a
    database, as its string description.

    See http://docs.sqlalchemy.org/en/latest/core/custom_types.html.
    """
    impl = String(length=MAX_HOLE_OR_SERIALPOS_PAIR_DEFINITION_STRING_LENGTH)
    # This type takes no per-instance arguments, so it is safe to use as part
    # of a SQLAlchemy statement cache key (attribute is ignored by SQLAlchemy
    # versions that predate it).
    cache_ok = True

    def process_bind_param(
            self, value: Any, dialect: DefaultDialect) -> Optional[str]:
        """
        Converts a bound Python parameter to the database value.

        Args:
            value: should be a :class:`SerialPosRestriction` or None
            dialect: SQLAlchemy database dialect.

        Returns:
            string (outbound to database); falsy values (e.g. ``None``) are
            passed through unchanged.

        Raises:
            ValueError: if ``value`` is truthy but is not a
                :class:`.SerialPosRestriction`.
        """
        if not value:
            return value
        if not isinstance(value, SerialPosRestriction):
            raise ValueError(
                "Bad object arriving at "
                "SerialPosRestrictionType.process_bind_param: "
                "{!r}".format(value))
        return value.description()

    def process_result_value(
            self, value: Any,
            dialect: DefaultDialect) -> Optional[SerialPosRestriction]:
        """
        Receive a result-row column value to be converted.

        Args:
            value: data fetched from the database (will be a string).
            dialect: SQLAlchemy database dialect.

        Returns:
            a :class:`.SerialPosRestriction` object if the string is valid;
            ``None`` if the value is empty or cannot be parsed (the latter is
            logged at debug level).
        """
        if not value:
            return None
        try:
            return SerialPosRestriction(description=value)
        except ArgumentTypeError:
            # Best effort: a bad stored value is treated as "no restriction
            # object" rather than breaking the load.
            log.debug(
                "Bad value received from database to "
                "SerialPosRestrictionType.process_result_value: "
                "%r", value)
            return None

    def process_literal_param(self, value: Any,
                              dialect: DefaultDialect) -> str:
        """
        Receive a literal parameter value to be rendered inline within a
        statement.

        (An abstract method of ``TypeDecorator``, so we should implement it.)

        Args:
            value: a Python value
            dialect: SQLAlchemy database dialect.

        Returns:
            a string to be baked into some SQL. For a
            :class:`.SerialPosRestriction` this is its description (the same
            text that :meth:`process_bind_param` stores), not its ``str()``
            form, which is wrapped as ``SerialPosRestriction('...')`` and
            would never match a stored value.
        """
        if isinstance(value, SerialPosRestriction):
            return value.description()
        return str(value)

    @property
    def python_type(self) -> type:
        """
        Returns the Python type object expected to be returned by instances of
        this type, if known. It's :class:`.SerialPosRestriction`.
        """
        return SerialPosRestriction
class TrialPlan(object):
    """
    Describes the planned sequence of holes to be offered, and then holes to
    be tested, for a single trial.

    :ivar sequence: sequence of 1-based hole numbers to be offered
    :ivar serial_order_choice: serial positions within the offered sequence
        to offer as choices (will be SORTED as they are offered simultaneously)
    :ivar hole_choice: hole positions to offer for the choice (will be SORTED
        as they are offered simultaneously)

    What's independent?

    - ``sequence`` and ``serial_order_choice`` are independent
    - ``sequence`` and ``correct_is_on_right`` are not independent (they are
      mediated by ``serial_order_choice``)
    - ``serial_order_choice`` and ``correct_is_on_right`` are not independent
      (they are mediated by ``sequence``)
    """

    def __init__(self, sequence: List[int],
                 serial_order_choice: List[int]) -> None:
        """
        Args:
            sequence: the sequence of hole numbers to be offered
                (e.g. [3, 4, 1] to present hole 3, hole 4, and hole 1 in that
                order).
            serial_order_choice: the serial order positions to be tested
                (e.g. [1, 3] for the first and third).
        """
        self.sequence = sequence  # type: List[int]
        self.serial_order_choice = sorted(serial_order_choice)
        # Translate the chosen serial positions into spatial holes, then sort
        # them spatially.
        holes_for_choice = serial_order_to_spatial(self.sequence,
                                                   self.serial_order_choice)
        self.hole_choice = sorted(holes_for_choice)

    @property  # for debugging
    def correct_incorrect_holes(self) -> Tuple[int, int]:
        """
        Returns:
            tuple: ``(correct_hole, incorrect_hole)`` for the test phase; the
            correct hole is whichever choice hole came earlier in the
            sequence.
        """
        positions = spatial_to_serial_order(self.sequence, self.hole_choice)
        first_is_earlier = positions[0] < positions[1]
        if first_is_earlier:
            return self.hole_choice[0], self.hole_choice[1]
        return self.hole_choice[1], self.hole_choice[0]

    @property  # for debugging
    def correct_hole(self) -> int:
        """
        Returns:
            The correct hole number, from the test phase.
        """
        right_answer, _ = self.correct_incorrect_holes
        return right_answer

    @property  # for debugging
    def incorrect_hole(self) -> int:
        """
        Returns:
            The incorrect hole number, for the test phase.
        """
        _, wrong_answer = self.correct_incorrect_holes
        return wrong_answer

    @property  # for debugging
    def correct_is_on_right(self) -> bool:
        """
        Returns:
            Is the correct hole on the right (i.e. does it have the higher
            hole number)?
        """
        right_answer, wrong_answer = self.correct_incorrect_holes
        return right_answer > wrong_answer

    @property  # for debugging
    def sequence_length(self) -> int:
        """
        Returns:
            The length of the sequence presented.
        """
        return len(self.sequence)

    def __repr__(self) -> str:
        template = (
            "TrialPlan(sequence={}, serial_order_choice={}, "
            "hole_choice={}; correct_hole={}, correct_is_on_right={})"
        )
        return template.format(
            self.sequence,
            self.serial_order_choice,
            self.hole_choice,
            self.correct_hole,
            self.correct_is_on_right,
        )

    # @property
    # def hole_serial_order_combo(self) -> List[int]:
    #     return self.serial_order_choice + self.hole_choice

    def meets_restrictions(
            self,
            choice_hole_restriction: ChoiceHoleRestriction = None,
            serial_pos_restriction: SerialPosRestriction = None) \
            -> bool:
        """
        Does the trial plan meet the specified restrictions?

        Args:
            choice_hole_restriction: optional restriction on the spatial
                holes offered at the choice phase.
            serial_pos_restriction: optional restriction on the serial order
                positions offered at the choice phase.

        Returns:
            ``False`` if any supplied restriction rejects this plan; otherwise
            ``True``.
        """
        if (choice_hole_restriction and
                not choice_hole_restriction.permissible(self.hole_choice)):
            return False
        if (serial_pos_restriction and
                not serial_pos_restriction.permissible(
                    self.serial_order_choice)):
            return False
        return True
# =============================================================================
# Program configuration
# =============================================================================
class Config(SqlAlchemyAttrDictMixin, Base):
    """
    SQLAlchemy model for the ``config`` table.
    """
    __tablename__ = 'config'

    config_id = Column(Integer, primary_key=True)
    modified_at = Column(ArrowMicrosecondType,
                         default=arrow.now, onupdate=arrow.now)
    read_only = Column(Boolean)  # used for a live task, therefore can't edit
    stages = relationship("ConfigStage", order_by="ConfigStage.stagenum",
                          cascade="save-update, merge, delete")
    # No explicit relationship to Session.
    # This means that deepcopy() won't copy any non-config stuff, which is
    # helpful, but means that we have to use the session as the starting point
    # for the write-to-disk walk.
    # If we wanted to improve this, the other way would be to extend the
    # deepcopy() function to limit the classes it will traverse.

    # Whisker
    server = Column(String(MAX_GENERIC_STRING_LENGTH))
    port = Column(Integer)
    devicegroup = Column(String(MAX_GENERIC_STRING_LENGTH))
    # Subject
    subject = Column(String(MAX_GENERIC_STRING_LENGTH))
    # Reinforcement
    reinf_n_pellets = Column(Integer)
    reinf_pellet_pulse_ms = Column(Integer)
    reinf_interpellet_gap_ms = Column(Integer)
    # ITI
    iti_duration_ms = Column(Integer)
    # Failed trials
    repeat_incomplete_trials = Column(Boolean)
    # Overall limits
    session_time_limit_min = Column(Float)

    def __init__(self, **kwargs) -> None:
        """
        Applies default values for any settings not supplied in ``kwargs``;
        remaining keyword arguments are passed to the base class.

        Must be clonable by deepcopy_sqla_object(), so must accept empty
        kwargs.
        """
        # NOTE(review): repeat_incomplete_trials gets no default here; it is
        # only set if passed through kwargs -- confirm that is intended.
        attr_defaults = (
            ('read_only', False),
            ('server', 'localhost'),
            ('port', 3233),
            ('devicegroup', 'box0'),
            ('subject', ''),
            ('reinf_n_pellets', 2),
            ('reinf_pellet_pulse_ms', 45),
            ('reinf_interpellet_gap_ms', 250),
            ('iti_duration_ms', 2000),
            ('session_time_limit_min', 60),
        )
        for attrname, default in attr_defaults:
            setattr(self, attrname, kwargs.pop(attrname, default))
        super().__init__(**kwargs)

    def __str__(self) -> str:
        template = (
            "Config {config_id}: subject = {subject}, server = {server}, "
            "devicegroup = {devicegroup}"
        )
        return template.format(
            config_id=self.config_id,
            subject=self.subject,
            server=self.server,
            devicegroup=self.devicegroup,
        )

    def get_modified_at_pretty(self) -> Optional[str]:
        """
        Gets the ``modified_at`` time as a human-readable string (formatted
        with ``DATETIME_FORMAT_PRETTY``), or ``None`` if it is not set.
        """
        when = self.modified_at
        return None if when is None else when.strftime(DATETIME_FORMAT_PRETTY)

    def clone(self, session: Session, read_only: bool = False) -> 'Config':
        """
        Makes a copy of itself and adds it to the specified SQLAlchemy
        session.

        Args:
            session: the SQLAlchemy session into which to insert the copy.
            read_only: sets the ``read_only`` property of the copy.

        Returns:
            the copy.
        """
        # deepcopy_sqla_object() will add the copy to the session; we defer
        # the flush until read_only has been set on the copy.
        duplicate = deepcopy_sqla_object(self, session, flush=False)  # type: Config  # noqa
        duplicate.read_only = read_only
        session.flush()  # but not necessarily commit
        return duplicate

    def get_n_stages(self) -> int:
        """
        Returns the number of stages.
        """
        return len(self.stages)

    def has_stages(self) -> bool:
        """
        Does the config have at least one stage?
        """
        n_stages = self.get_n_stages()
        return n_stages > 0
class ConfigStage(SqlAlchemyAttrDictMixin, Base):
    """
    SQLAlchemy model for the ``config_stage`` table.
    """
    __tablename__ = 'config_stage'

    config_stage_id = Column(Integer, primary_key=True)
    modified_at = Column(ArrowMicrosecondType,
                         default=arrow.now, onupdate=arrow.now)
    config_id = Column(Integer, ForeignKey('config.config_id'), nullable=False)
    stagenum = Column(Integer, nullable=False)  # consecutive, 1-based

    # Sequence
    sequence_length = Column(Integer)
    choice_hole_restriction = Column(ChoiceHoleRestrictionType, nullable=True)
    serial_pos_restriction = Column(SerialPosRestrictionType, nullable=True)
    side_dwor_multiplier = Column(Integer)
    # Limited hold
    limited_hold_s = Column(Float)
    # Progress to next stage when X of last Y correct, or total trials complete
    progression_criterion_x = Column(Integer)
    progression_criterion_y = Column(Integer)
    stop_after_n_trials = Column(Integer)

    def __init__(self, **kwargs) -> None:
        """
        Applies default values for any settings not supplied in ``kwargs``;
        remaining keyword arguments are passed to the base class.

        Must be clonable by deepcopy_sqla_object(), so must accept empty
        kwargs.
        """
        # In R: use binom.test(x, y) to get the p value for the progression
        # criteria. Here, the defaults (x = 10, y = 12) are such that
        # progression requires p = 0.03857.
        attr_defaults = (
            ('config_id', None),  # int
            ('stagenum', None),  # int
            ('choice_hole_restriction', None),  # ChoiceHoleRestriction
            ('side_dwor_multiplier', 1),
            ('sequence_length', None),  # int
            ('limited_hold_s', 10),  # float
            ('progression_criterion_x', 10),  # int
            ('progression_criterion_y', 12),  # int
            ('stop_after_n_trials', 100),  # int
        )
        for attrname, default in attr_defaults:
            setattr(self, attrname, kwargs.pop(attrname, default))
        super().__init__(**kwargs)

    @property
    def choice_hole_restriction_desc(self) -> str:
        """
        Returns the description of any choice_hole_restriction, or an empty
        string if there is none.
        """
        restriction = self.choice_hole_restriction
        return restriction.description() if restriction else ""

    @property
    def serial_pos_restriction_desc(self) -> str:
        """
        Returns the description of any serial_pos_restriction, or an empty
        string if there is none.
        """
        restriction = self.serial_pos_restriction
        return restriction.description() if restriction else ""
# =============================================================================
# Session summary details
# =============================================================================
class TaskSession(SqlAlchemyAttrDictMixin, Base):
    """
    SQLAlchemy model for the ``session`` table (renamed from ``Session`` to
    ``TaskSession`` to avoid confusion with SQLAlchemy ``Session``).
    """
    __tablename__ = 'session'

    session_id = Column(Integer, primary_key=True)
    config_id = Column(Integer, ForeignKey('config.config_id'), nullable=False)
    config = relationship("Config")
    events = relationship("Event")
    trials = relationship("Trial")

    started_at = Column(ArrowMicrosecondType, nullable=False)
    software_version = Column(String(MAX_VERSION_LENGTH))
    filename = Column(Text)

    trials_responded = Column(Integer, nullable=False, default=0)
    trials_correct = Column(Integer, nullable=False, default=0)

    def __init__(self, **kwargs) -> None:
        """
        Requires ``config_id`` and ``started_at`` keyword arguments (a
        ``KeyError`` is raised if either is missing). Counters start at zero
        and the software version is stamped with ``SERIAL_ORDER_VERSION``.
        Remaining keyword arguments are passed to the base class.
        """
        # Mandatory settings:
        for required in ('config_id', 'started_at'):
            setattr(self, required, kwargs.pop(required))
        # Fixed starting values:
        self.trials_responded = 0
        self.trials_correct = 0
        self.software_version = SERIAL_ORDER_VERSION
        super().__init__(**kwargs)
# =============================================================================
# Trial details
# =============================================================================
[docs]class Trial(SqlAlchemyAttrDictMixin, Base): """ SQLAlchemy model for the ``trial`` table. """ __tablename__ = 'trial' trial_id = Column(Integer, primary_key=True) session_id = Column(Integer, ForeignKey('session.session_id'), nullable=False) events = relationship("Event") sequence_timings = relationship("SequenceTiming") trialnum = Column(Integer, nullable=False) config_stage_id = Column(Integer, ForeignKey('config_stage.config_stage_id'), nullable=False) stagenum = Column(Integer, nullable=False) started_at = Column(ArrowMicrosecondType) initiated_at = Column(ArrowMicrosecondType) initiation_latency_s = Column(Float) sequence_holes = Column(ScalarListType(int)) # in order of presentation sequence_length = Column(Integer) # for convenience # Various ways of reporting the holes offered, for convenience: choice_holes = Column(ScalarListType(int)) # in order of sequence choice_seq_positions = Column(ScalarListType(int)) # in order of sequence choice_hole_left = Column(Integer) # hole number, leftmost offered choice_hole_right = Column(Integer) # hole number, rightmost offered choice_hole_earliest = Column(Integer) # hole number, earlist in sequence choice_hole_latest = Column(Integer) # hole number, latest in sequence choice_seqpos_earliest = Column(Integer) # earliest sequence pos offered (1-based) # noqa choice_seqpos_latest = Column(Integer) # latest sequence pos offered (1-based) # noqa sequence_n_offered = Column(Integer, nullable=False, default=0) choice_offered = Column(Boolean, nullable=False, default=False) choice_offered_at = Column(ArrowMicrosecondType) responded = Column(Boolean, nullable=False, default=False) responded_at = Column(ArrowMicrosecondType) responded_hole = Column(Integer) # which hole was chosen? 
response_correct = Column(Boolean) response_latency_s = Column(Float) reinforced_at = Column(ArrowMicrosecondType) reinf_collected_at = Column(ArrowMicrosecondType) reinf_collect_latency_s = Column(Float) n_premature = Column(Integer, nullable=False, default=0) iti_started_at = Column(ArrowMicrosecondType) def __init__(self, **kwargs) -> None: self.session_id = kwargs.pop('session_id', None) # may be set later self.trialnum = kwargs.pop('trialnum') self.started_at = kwargs.pop('started_at') self.config_stage_id = kwargs.pop('config_stage_id') self.stagenum = kwargs.pop('stagenum') self.n_premature = 0 self.sequence_n_offered = 0 self.sequence_info = None # current sequence info super().__init__(**kwargs)
[docs] def set_sequence(self, sequence_holes: List[int]) -> None: """ Sets the sequence for the first phase of the trial. Args: sequence_holes: ordered list of hole numbers. """ self.sequence_holes = list(sequence_holes) # make a copy self.sequence_length = len(sequence_holes)
[docs] def set_choice(self, choice_holes: List[int]) -> None: """ Sets the choice holes offered in the second phase of the trial. Args: choice_holes: a list, of length 2, of the hole numbers. """ assert len(choice_holes) == 2 assert all(x in self.sequence_holes for x in choice_holes) # Order choice_holes by sequence_holes: self.choice_holes = sorted(choice_holes, key=lambda x: self.sequence_holes.index(x)) self.choice_seq_positions = spatial_to_serial_order( self.sequence_holes, self.choice_holes) self.choice_hole_left = min(self.choice_holes) self.choice_hole_right = max(self.choice_holes) self.choice_hole_earliest = self.choice_holes[0] self.choice_hole_latest = self.choice_holes[-1] self.choice_seqpos_earliest = self.sequence_holes.index( self.choice_hole_earliest) + 1 # 1-based self.choice_seqpos_latest = self.sequence_holes.index( self.choice_hole_latest) + 1 # 1-based
[docs] def get_sequence_holes_as_str(self) -> str: """ Returns a CSV string of the sequence holes. """ return ",".join(str(x) for x in self.sequence_holes)
[docs] def get_choice_holes_as_str(self) -> str: """ Returns a CSV string of the choice holes. """ return ",".join(str(x) for x in self.choice_holes)
[docs] def record_initiation(self, timestamp: arrow.Arrow) -> None: """ Records the time of trial initiation. """ self.initiated_at = timestamp self.initiation_latency_s = latency_s(self.started_at, self.initiated_at)
[docs] def record_sequence_hole_lit(self, timestamp: arrow.Arrow, holenum: int) -> None: """ Records the time and hole number that a sequence hole was illuminated. """ self.sequence_n_offered += 1 self.sequence_info = SequenceTiming( trial_id=self.trial_id, seq_pos=self.sequence_n_offered, hole_num=holenum, ) self.sequence_info.record_hole_lit(timestamp) self.sequence_timings.append(self.sequence_info)
[docs] def record_sequence_hole_response(self, timestamp: arrow.Arrow) -> None: """ Records a response to a sequence hole. """ if self.sequence_info is None: return self.sequence_info.record_hole_response(timestamp)
[docs] def record_sequence_mag_lit(self, timestamp: arrow.Arrow) -> None: """ Records illumination of the food magazine during the initial sequence. """ if self.sequence_info is None: return self.sequence_info.record_mag_lit(timestamp)
[docs] def record_sequence_mag_response(self, timestamp: arrow.Arrow) -> None: """ Records a response to the food magazine during the initial sequence. """ if self.sequence_info is None: return self.sequence_info.record_mag_response(timestamp)
[docs] def record_choice_offered(self, timestamp: arrow.Arrow) -> None: """ Records the time that the choice was offered. """ self.choice_offered = True self.choice_offered_at = timestamp
[docs] def record_response(self, response_hole: int, timestamp: arrow.Arrow) -> bool: """ Records the response during the choice phase. IMPLEMENTS THE KEY TASK RULE: "Which came first?" Args: response_hole: the hole that the subject responded to timestamp: when the response occurred was the response correct? """ self.responded = True self.responded_at = timestamp self.responded_hole = response_hole self.response_latency_s = latency_s(self.choice_offered_at, self.responded_at) self.response_correct = response_hole == self.choice_hole_earliest return self.response_correct
# noinspection PyUnusedLocal
[docs] def record_premature(self, timestamp: arrow.Arrow) -> None: """ Records a premature response. """ self.n_premature += 1
def record_reinforcement(self, timestamp: arrow.Arrow) -> None:
    """
    Stores the time at which reinforcement was delivered.

    Once this has been set, :meth:`was_reinforced` reports ``True``.

    Args:
        timestamp: when reinforcement was delivered
    """
    self.reinforced_at = timestamp
def record_reinf_collection(self, timestamp: arrow.Arrow) -> None:
    """
    Stores when the subject collected reinforcement.

    Only the first collection counts: if a collection time is already
    stored, this call changes nothing. The collection latency is whatever
    ``latency_s()`` gives for the response time and the collection time.

    Args:
        timestamp: when reinforcement was collected
    """
    if not self.was_reinf_collected():
        self.reinf_collected_at = timestamp
        self.reinf_collect_latency_s = latency_s(self.responded_at,
                                                 timestamp)
def was_reinforced(self) -> bool:
    """
    Reports whether a reinforcement time has been stored for this trial.

    Returns:
        ``True`` if ``reinforced_at`` is set, ``False`` if it is ``None``
    """
    return not (self.reinforced_at is None)
def was_reinf_collected(self) -> bool:
    """
    Reports whether a reinforcement collection time has been stored.

    Returns:
        ``True`` if ``reinf_collected_at`` is set, ``False`` if it is
        ``None``
    """
    return not (self.reinf_collected_at is None)
def record_iti_start(self, timestamp: arrow.Arrow) -> None:
    """
    Stores when the intertrial interval (ITI) began.

    Also clears the current ``sequence_info``, so later sequence-related
    calls have no sequence element to write to.

    Args:
        timestamp: when the ITI started
    """
    # The sequence phase is over for this trial...
    self.sequence_info = None
    self.iti_started_at = timestamp
# =============================================================================
# Event details
# =============================================================================
class Event(SqlAlchemyAttrDictMixin, Base):
    """
    SQLAlchemy model for the ``event`` table.

    Each row is one event, numbered within its session and optionally
    linked to a trial.
    """
    __tablename__ = 'event'

    event_id = Column(Integer, primary_key=True)
    session_id = Column(Integer, ForeignKey('session.session_id'),
                        nullable=False)
    eventnum_in_session = Column(Integer, nullable=False, index=True)
    trial_id = Column(Integer, ForeignKey('trial.trial_id'))  # may be NULL
    trialnum = Column(Integer)  # violates DRY for convenience
    eventnum_in_trial = Column(Integer)
    event = Column(String(MAX_EVENT_LENGTH), nullable=False)
    timestamp = Column(ArrowMicrosecondType, nullable=False)
    whisker_timestamp_ms = Column(BigInteger)
    from_server = Column(Boolean)

    def __init__(self, **kwargs) -> None:
        """
        Builds an event from keyword arguments.

        ``eventnum_in_session``, ``event`` and ``timestamp`` are required
        (``KeyError`` if absent). ``from_server`` defaults to ``False``;
        the other recognised keys default to ``None``. Whatever is left
        in ``kwargs`` is handed to the superclass constructor.
        """
        take = kwargs.pop
        self.session_id = take('session_id', None)  # may be set later
        self.eventnum_in_session = take('eventnum_in_session')
        self.trial_id = take('trial_id', None)
        self.trialnum = take('trialnum', None)
        self.eventnum_in_trial = take('eventnum_in_trial', None)
        self.event = take('event')
        self.timestamp = take('timestamp')
        self.whisker_timestamp_ms = take('whisker_timestamp_ms', None)
        self.from_server = take('from_server', False)
        super().__init__(**kwargs)
# =============================================================================
# Info/timings of the sequences, including response latencies
# =============================================================================
class SequenceTiming(SqlAlchemyAttrDictMixin, Base):
    """
    SQLAlchemy model for the ``sequence_timing`` table.

    Each row describes one position in a trial's sequence: which hole it
    was, when the hole and the food magazine were lit and responded to,
    and the corresponding latencies as computed by ``latency_s()``.
    """
    __tablename__ = 'sequence_timing'

    sequence_timing_id = Column(Integer, primary_key=True)
    trial_id = Column(Integer, ForeignKey('trial.trial_id'), nullable=False)
    seq_pos = Column(Integer, nullable=False)
    hole_num = Column(Integer, nullable=False)
    hole_lit_at = Column(ArrowMicrosecondType)
    hole_response_at = Column(ArrowMicrosecondType)
    hole_response_latency_s = Column(Float)
    mag_lit_at = Column(ArrowMicrosecondType)
    mag_response_at = Column(ArrowMicrosecondType)
    mag_response_latency_s = Column(Float)

    def __init__(self, **kwargs) -> None:
        """
        Builds a sequence-timing record from keyword arguments.

        ``trial_id``, ``seq_pos`` and ``hole_num`` are all required
        (``KeyError`` if absent). Whatever is left in ``kwargs`` is handed
        to the superclass constructor.
        """
        take = kwargs.pop
        self.trial_id = take('trial_id')
        self.seq_pos = take('seq_pos')
        self.hole_num = take('hole_num')
        super().__init__(**kwargs)

    def record_hole_lit(self, timestamp: arrow.Arrow) -> None:
        """
        Stores when the hole was illuminated.

        Args:
            timestamp: when the hole was illuminated
        """
        self.hole_lit_at = timestamp

    def record_hole_response(self, timestamp: arrow.Arrow) -> None:
        """
        Stores when the hole was responded to, plus the latency from
        ``hole_lit_at`` to that response as given by ``latency_s()``.

        Args:
            timestamp: when the response occurred
        """
        self.hole_response_at = timestamp
        lit_at = self.hole_lit_at
        self.hole_response_latency_s = latency_s(lit_at, timestamp)

    def record_mag_lit(self, timestamp: arrow.Arrow) -> None:
        """
        Stores when the food magazine was illuminated.

        Args:
            timestamp: when the magazine was illuminated
        """
        self.mag_lit_at = timestamp

    def record_mag_response(self, timestamp: arrow.Arrow) -> None:
        """
        Stores when the food magazine was responded to, plus the latency
        from ``mag_lit_at`` to that response as given by ``latency_s()``.

        Args:
            timestamp: when the response occurred
        """
        self.mag_response_at = timestamp
        lit_at = self.mag_lit_at
        self.mag_response_latency_s = latency_s(lit_at, timestamp)