Skip to content
Snippets Groups Projects
Select Git revision
  • d175df45c8d5308d32896a1ec9b2891b493f4fb4
  • master default protected
2 results

rMtS_multi.py

Blame
  • tiborauer's avatar
    Tibor Auer authored
    5d1a0a89
    History
    Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    rMtS_multi.py 22.05 KiB
    """
    Repetitive match-to-sample test
    A combination of Alekseichuk et al., 2016 (https://doi.org/10.1016/j.cub.2016.04.035) and Berger et al., 2019 (https://doi.org/10.1038/s41467-019-12057-0)
    """
    from psychopy import visual, data, logging, gui, core, clock, monitors #, parallel
    import os
    from psychopy.constants import (NOT_STARTED, STARTED, PLAYING, PAUSED,
                                    STOPPED, FINISHED, PRESSED, RELEASED, FOREVER)
    
    from numpy import ceil, array, concatenate, linspace, ones, inf, isin, array2string
    from numpy.random import shuffle, permutation, randint, choice
    from itertools import product
    from collections import OrderedDict
    from pyniexp.scannersynch import scanner_synch
    from utils import generate_jitter, generate_sample, get_neighbor
    
    if __name__ == '__main__':
        # Title of the experiment; shown on the dialog and used in output file names.
        expName = 'repetitive Match-to-Sample'

        # Session parameters offered to the operator in the start-up dialog.
        # The entries are inserted one at a time so that the dialog
        # (sortKeys=False) lists them in exactly this order.
        expInfo = OrderedDict()
        expInfo['participant'] = ''
        expInfo['session'] = '1'
        expInfo['grid size'] = 6  # number of cells per axis -> number of cells = gridXY**2
        expInfo['sample size'] = 4  # number of circles (cannot be more than half of the number of cells)
        expInfo['sample retention time'] = '[2, 2]'  # delay after sample
        expInfo['match number'] = 2  # number of matches per sample
        expInfo['scanner mode'] = False  # Is it inside the scanner
        expInfo['stimulation'] = False  # Run tES
        expInfo['stimulation intensity [mA]'] = 1  # desired intensity (mA)

        # The dialog edits expInfo in place; leave the experiment if it is cancelled.
        dlg = gui.DlgFromDict(dictionary=expInfo, title=expName, sortKeys=False)
        if dlg.OK == False:
            core.quit()  # user pressed cancel

        # Simple timestamp of the session.
        expInfo['date'] = data.getDateStr()
    
        ######## CONFIG ########
        # Settings
        Monitor = 'testMonitor'  # monitor name handed to PsychoPy when the window is created
    
        EMUL = not(expInfo['scanner mode']) # Is it outside the scanner (emulate synch pulses and buttons)
        doSTIMULATION = expInfo['stimulation'] # Run tES
        # - Parallel port settings (currently disabled)
        #parallel.setPortAddress(0x378) 
        #parallel.setData(0) # sets all pins to 0
        #parallel.setPin(2,1) # din 1
        #parallel.setPin(3,1) # din 2
        #parallel.setPin(4,1) # din 4
        #parallel.setPin(5,1) # din 8
        #parallel.setPin(6,1) # din 16
        #parallel.setPin(7,1) # din 32
        #parallel.setPin(8,1) # din 64
        #parallel.setPin(9,1) # din 128
    
        # - timings
        nDummies = 3 # number of synch pulses counted down on screen before the task starts
        restDuration = 5 # duration of rest before the first block and after each block
        nBlock = 1 # number of blocks
    
        # Schedule of trial: sampleJitter -> sample -> matchJitter + filler -> nMatch x [ match (same duration as sample) -> responseJitter -> response ]
        #     Example with a retention delay of 1: 0.5 + 0.5 + 1 + 2 * ( 0.5 + 0 + 1 ) = 5
        #     With the dialog default retention of [2, 2] the same sum gives 6.
    
        nSample = 60 # number of samples per block
        sampleJitterRange = [0.25, 0.75] # jitter before sample
        sampleDuration = 0.5 # duration of the sample (also used for each match display)
    
        nMatch = expInfo['match number'] # number of matches per sample
        # NOTE(review): eval() executes whatever text was typed into the dialog field;
        # ast.literal_eval would parse a '[min, max]' list without running arbitrary code.
        matchJitterRange = eval(expInfo['sample retention time']) # delay after sample
        fillerDuration = 2 # the mask (all cells filled) is shown while t < fillerDuration during the retention delay
    
        responseJitterRange = [0, 0] # jitter between match and response prompt
        responseDuration = 1 # maximum duration of response
        # - brain stimulation
        # Default waveform parameters; only consumed by the (currently commented-out) stimulator code.
        defWave = {
            'amplitude': expInfo['stimulation intensity [mA]'], # desired intensity (mA)
            'frequency': 10,
            'phase': 0,
            'duration': 
                nSample*(
                    array(sampleJitterRange).mean()+
                    sampleDuration+
                    array(matchJitterRange).mean()+
                    nMatch*(
                        sampleDuration + 
                        array(responseJitterRange).mean()+
                        responseDuration
                    )
                ), # nSample x mean length of a trial (see the schedule above), i.e. the expected length of one block
            'rampUp': 3,
            'rampDown': 3,
            'samplingRate': 1000,
        }
        phaseDiff = 0 # phase of the second channel
        frequencies = [0,5,10,20,60] # the 5 candidate frequencies; one is assigned to each block
    
        gridSize = 600 # width and height of the grid (stimuli are created with units='pix')
        gridXY = expInfo['grid size'] # number of cells per axis -> number of cells = gridXY**2
        colour = array([-0.5,-0.5,-0.5]) # colour of the grid lines and sample circles; the mask uses -colour
        sampleNum = expInfo['sample size'] # number of circles (cannot be more than half of the number of cells)
        sampleSize = 0.9 # circle size relative to cell size
    
        tolJitter = 1e-3 # third argument of generate_jitter -- presumably its tolerance; confirm in utils
    
        # Jitters
        # One onset jitter per sample (before the sample and before the first match)
        # and one per match (before the response prompt).
        sampleJitter = generate_jitter(sampleJitterRange, nSample, tolJitter)
        matchJitter = generate_jitter(matchJitterRange, nSample, tolJitter)
        responseJitter = generate_jitter(responseJitterRange, nSample*nMatch, tolJitter)
        # Chain as many independently shuffled copies of the frequency list as
        # are needed to provide one frequency for every block.
        nFrequencySets = int(ceil(nBlock/len(frequencies)))
        frequencies = concatenate([permutation(frequencies) for _ in range(nFrequencySets)])

        # Grid
        cellPitch = gridSize / gridXY  # width/height of a single cell
        sampleSize = cellPitch * sampleSize  # relative circle size -> absolute size
        # Cell centres: every (x, y) combination of the per-axis centre positions.
        outermostCentre = cellPitch*(gridXY-1)/2
        axisCentres = linspace(-outermostCentre, outermostCentre, gridXY)
        cellCoordinates = array(list(product(axisCentres, repeat=2)))
        # Grid lines: gridXY+1 lines rotated by 90 placed along x, followed by
        # gridXY+1 unrotated lines placed along y.
        lineOffsets = linspace(-gridSize/2, gridSize/2, gridXY+1)
        linesAlongX = [(offset, 0) for offset in lineOffsets]
        linesAlongY = [(0, offset) for offset in lineOffsets]
        gridCoordinates = array(linesAlongX + linesAlongY)
        gridAngles = [angle for angle in (90, 0) for _ in range(gridXY+1)]
    
        # Buttons
        # - button 6 and 7 (zero-indexed) corresponding to NATA right index and middle finger
        bYes = 6 # button for 'Yes'
        bNo = 7 # button for 'No'
    
        ######## LOGGING ########
        # Make the script's own folder the working directory.
        _thisDir = os.path.dirname(os.path.abspath(__file__))
        os.chdir(_thisDir)
    
        # Common stem of all output files: data/<participant>_<expName>_<date>
        filename = _thisDir + os.sep + u'data/%s_%s_%s' % (expInfo['participant'], expName, expInfo['date'])
    
        # NOTE(review): originPath is a hard-coded absolute path on one particular machine.
        thisExp = data.ExperimentHandler(name=expName, version='',
            extraInfo=expInfo, runtimeInfo=None,
            originPath='D:\\Private\\University of Surrey\\Violante, Ines Dr (Psychology) - NeuroModulationLab\\BayesianOptimisation\\Tasks\\rMTS.py',
            savePickle=True, saveWideText=True,
            dataFileName=filename)
        # The .log file written here is parsed at the end of the script to build the events table.
        logFile = logging.LogFile(filename+'.log', level=logging.EXP)
        logging.console.setLevel(logging.WARNING)  # this outputs to the screen, not a file
    
        ######## PREPARE ########
        # Build the nested trial structure up front:
        #   blockTrials[b]['sampleTrials'][s] -> 'onsetSample', 'sample', 'onsetMatch', 'matchTrials'
        #   ...['matchTrials'][m]             -> 'match', 'onsetResponse'
        #   blockTrials[b]['frequency']       -> frequency assigned to the block
        blockTrials = []
        for b in range(nBlock):
            blockTrials += [OrderedDict([])]
            sampleTrials = []
            # Re-use the same jitter values in every block, in a new random order.
            shuffle(sampleJitter); shuffle(matchJitter); shuffle(responseJitter)
            for s in range(nSample):
                # Only the first returned value is used below.
                [sampleSelection, allSelection] = generate_sample(gridXY,sampleNum,1)
                sampleTrials += [OrderedDict([
                    ('onsetSample',sampleJitter[s]),
                    ('sample',sampleSelection),
                    ('onsetMatch',matchJitter[s])
                    ])]
                matchTrials = []
                for m in range(nMatch):
                    # Each match starts as a copy of the sample; one randomly picked
                    # item is then either kept or moved to a free neighbouring cell.
                    match = sampleSelection.copy()
                    indReplace = randint(0,sampleNum)
                    poolReplace = get_neighbor(gridXY,match[indReplace])
                    poolReplace = [i for i in poolReplace if i not in match] # non-inclusive neighbors
                    if poolReplace:
                        poolReplace += [match[indReplace]]*len(poolReplace) # pool with equal number of original value and its neighbors (i.e. p=0.5)
                        # Draw a scalar (not a 1-element array) so it can be stored in a single slot.
                        match[indReplace] = choice(poolReplace)
                    # else: every neighbour is already occupied, so there is nothing to
                    # draw from (choice() raises on an empty pool); the match is left
                    # identical to the sample.
    
                    matchTrials += [OrderedDict([     
                        ('match', match),
                        ('onsetResponse',responseJitter[s*nMatch+m])
                        ])]
                sampleTrials[-1]['matchTrials'] = matchTrials
            blockTrials[-1]['sampleTrials'] = sampleTrials
            blockTrials[-1]['frequency'] = frequencies[b]
    
        # Blocks are run in the order in which they were generated.
        loopBlock = data.TrialHandler(nReps=1, method='sequential', trialList=blockTrials, autoLog=False)
        thisExp.addLoop(loopBlock)
    
        # Scanner and buttons
        # Synch pulses and buttons are both emulated when not in scanner mode (EMUL).
        SSO = scanner_synch(config='config_scanner.json',emul_synch=EMUL,emul_buttons=EMUL)
        SSO.set_synch_readout_time(0.5)
        SSO.TR = 1.8
    
        SSO.set_buttonbox_readout_time(0.5)
        # NOTE(review): the original comment said "2 sec" but the value is 1 (same as responseDuration); confirm the intended timeout.
        SSO.buttonbox_timeout = 1 # wait for response
        # buttons: no - yes
        if not(SSO.emul_buttons): SSO.add_buttonbox('Nata')
        else: 
            # Emulated buttons: one entry per button index up to the highest one used;
            # '2' stands in for 'No' and '1' for 'Yes' (presumably keyboard keys -- confirm in pyniexp).
            SSO.buttons = ['0']*(max([bNo, bYes])+1)
            SSO.buttons[bNo] = '2'
            SSO.buttons[bYes] = '1'
        SSO.start_process()
    
        # Stimulator
        # if doSTIMULATION:
        #     from pyniexp.stimulation import Waveform, Stimulator
        #     BSO = Stimulator(configFile='config_stimulation.json')
        #     defWave['frequency'] = frequencies[0]
        #     wave1 = Waveform(**defWave)
        #     defWave['phase'] = phaseDiff
        #     wave2 = Waveform(**defWave)
        #     BSO.loadWaveform([wave1, wave2])
    
        # Visual
        # NOTE(review): 'mon' is not referenced again in this script; the window receives the monitor name instead.
        mon = monitors.Monitor(Monitor)
        win = visual.Window([1280,1024],winType='pyglet',screen=1,monitor=Monitor,units='pix',fullscr = True, autoLog=False, gammaErrorPolicy='ignore')
        win.mouseVisible = False
        # Grid lines: (gridXY+1)*2 elements of size gridSize x 2, oriented by gridAngles.
        gridForm = visual.ElementArrayStim(win=win, name='gridForm', nElements=(gridXY+1)*2, sizes=[gridSize,2], xys = gridCoordinates, oris=gridAngles, units='pix', 
                elementTex=ones([16,16]), elementMask=ones([16,16]), colors=colour, colorSpace='rgb', autoLog=False)
        # Fixation cross shown during rest.
        restStim = visual.TextStim(win=win, name='Rest',
            text="+", height=gridSize*0.1, wrapWidth=win.size[0], autoLog=False)
        # A circle in every cell, drawn in the inverted colour; logged as 'Mask' during the retention delay.
        fullForm = visual.ElementArrayStim(win, nElements=gridXY**2, sizes=sampleSize, xys = cellCoordinates, units='pix', 
            elementTex=None, elementMask="circle", colors=-colour, colorSpace='rgb', autoLog=False)
        # Question mark prompting for the response.
        responseStim = visual.TextStim(win=win, name='Response',
            text="?", height=gridSize*0.9, wrapWidth=win.size[0], autoLog=False)
    
        # Timers
        expClock = core.Clock() # experiment time; default clock of the log
        trialClock = core.Clock() # time within the current phase (rest, sample, match, response)
        interimClock = core.Clock() # hands the timing overshoot of one phase over to the next
        logging.setDefaultClock(expClock)
    
        ######## WAIT FOR BUTTON ########
        # Block until any button is pressed.
        msg = visual.TextStim(win, text="Press a button to start...", height=gridSize*0.1, wrapWidth=win.size[0], autoLog=False)
        msg.draw()
        win.flip()
        SSO.wait_for_button(timeout=inf)
    
        ######## WAIT FOR SYNCH ########
        # Block until the first synch pulse arrives.
        msg = visual.TextStim(win, text="Wait for scanner...", height=gridSize*0.1, wrapWidth=win.size[0], autoLog=False)
        msg.draw()
        win.flip()
        SSO.wait_for_synch()
        #parallel.setPin(2, 1)
        #
        #parallel.setData(0)
    
        # Count down the dummies on screen, one synch pulse per step, logging each pulse.
        while nDummies:
            msg.text="{}".format(nDummies)
            msg.draw()
            win.flip()
            SSO.wait_for_synch()
            logging.log(level=logging.DATA, msg='Pulse - {:.3f} - {}'.format(SSO.time_of_last_pulse,SSO.synch_count))
            nDummies -= 1
    
        # Start the scanner-synch clock and the experiment clock together.
        SSO.reset_clock()
        expClock.reset()
        frameN = -1
    
        ######## RUN ########
        # Rest
        # Initial rest: show the fixation cross for restDuration before the first block.
        trialClock.reset() 
        restStim.status = NOT_STARTED
        #parallel.setPin(2, 0)
        #parallel.setPin(9, 1)
    
        while trialClock.getTime() < restDuration:
            # get current time
            t = trialClock.getTime()
            frameN = frameN + 1  # number of completed frames (so 0 is the first frame)
    
            # *restStim* updates
            if restStim.status == NOT_STARTED:
                # keep track of start time/frame for later
                restStim.tStart = t
                restStim.frameNStart = frameN  # exact frame index
                restStim.status = STARTED
                win.logOnFlip(level=logging.EXP, msg='Rest - STARTED')
            
            if restStim.status == STARTED:
                restStim.draw()
    
            win.flip()
        # Seed interimClock with the planned duration minus the time actually spent;
        # the next phase reads it back when it resets trialClock.
        interimClock.reset(restDuration-trialClock.getTime())
            
        #parallel.setData(0)
            
        if restStim.status == STARTED:
            restStim.status = STOPPED
            # The message is written at the next window flip.
            win.logOnFlip(level=logging.EXP, msg='Rest - STOPPED')
            #parallel.setPin(9, 0)
            #parallel.setPin(9, 1)
            #parallel.setData(0)
    
        # Main loop
        # One iteration per block: all samples (each followed by its matches), then a rest period.
        for thisBlock in loopBlock:
    
            # The samples of a block are presented in the order in which they were generated.
            loopSample = data.TrialHandler(nReps=1, method='sequential', trialList=thisBlock['sampleTrials'], autoLog=False)
            thisExp.addLoop(loopSample)
    
            # Stimulation at the start of the block (currently disabled)
            # if doSTIMULATION:
            #    if thisBlock['frequency']: 
            #         BSO.stimulate()
            #     else:
            #         BSO.initialize()
            #     logging.log(level=logging.DATA, msg='Stimulation - {:.3f} - Frequency: {}'.format(SSO.clock,BSO.waves[0].frequency))
    
            for thisSample in loopSample:
                # Sample
                # Start the phase clock from the value handed over by the previous phase via interimClock.
                trialClock.reset(-interimClock.getTime()) 
                jitter = thisSample['onsetSample']
    
                # Circles at the cell centres selected for this sample.
                sampleCoordinates = cellCoordinates[thisSample['sample'],:]
                sampleForm = visual.ElementArrayStim(win, nElements=sampleCoordinates.shape[0], sizes=sampleSize, xys = sampleCoordinates, units='pix', 
                    elementTex=None, elementMask="circle", colors=colour, colorSpace='rgb', autoLog=False)
                sampleForm.status = NOT_STARTED   
    
                # Empty grid for 'jitter', then grid + sample for sampleDuration.
                while trialClock.getTime() < (jitter + sampleDuration):
                    # get current time
                    t = trialClock.getTime()
                    frameN = frameN + 1  # number of completed frames (so 0 is the first frame)
    
                    # update/draw components on each frame
                    gridForm.draw()
                    
                    # *sampleForm* updates
                    if t >= jitter and sampleForm.status == NOT_STARTED:
                        # keep track of start time/frame for later
                        sampleForm.tStart = t
                        sampleForm.frameNStart = frameN  # exact frame index
                        sampleForm.status = STARTED
                        # The selected cells are part of the message; the events export parses them from the log.
                        win.logOnFlip(level=logging.EXP, msg='Sample - STARTED - ' + array2string(thisSample['sample']))
                        #parallel.setPin(9, 0)
                        #parallel.setPin(4, 1)
                    
                    if sampleForm.status == STARTED:
                        sampleForm.draw()
    
                    win.flip()
                # Hand the difference between planned and actual phase length over to the next phase.
                interimClock.reset((jitter + sampleDuration) - trialClock.getTime()) 
    
                #parallel.setData(0)
    
                if sampleForm.status == STARTED:
                    sampleForm.status = STOPPED
                    win.logOnFlip(level=logging.EXP, msg='Sample - STOPPED')
                
                # Match
                # The matches of a sample are presented in the order in which they were generated.
                loopMatch = data.TrialHandler(nReps=1, method='sequential', trialList=thisSample['matchTrials'], autoLog=False)
                thisExp.addLoop(loopMatch)
                for thisMatch in loopMatch:
                    # Recall
                    # Start the phase clock from the value handed over by the previous phase via interimClock.
                    trialClock.reset(-interimClock.getTime()) 
                    if loopMatch.thisTrialN == 0: 
                        # Only the first match of a sample is preceded by the retention
                        # delay, during which the mask (fullForm) may be shown.
                        jitter = thisSample['onsetMatch']
                        fullForm.status = NOT_STARTED
                    else: jitter = 0
    
                    # Circles at the cell centres selected for this match display.
                    sampleCoordinates = cellCoordinates[thisMatch['match'],:]
                    recallForm = visual.ElementArrayStim(win, nElements=sampleCoordinates.shape[0], sizes=sampleSize, xys = sampleCoordinates, units='pix', 
                        elementTex=None, elementMask="circle", colors=colour, colorSpace='rgb', autoLog=False)
                    # Arm the stimulus explicitly, as is done for sampleForm; this used
                    # to be a comparison ('==') whose result was discarded.
                    recallForm.status = NOT_STARTED
    
                    # Grid (+ mask) for 'jitter', then grid + match for sampleDuration.
                    while trialClock.getTime() < (jitter + sampleDuration):
                        # get current time
                        t = trialClock.getTime()
                        frameN = frameN + 1  # number of completed frames (so 0 is the first frame)
    
                        # update/draw components on each frame
                        gridForm.draw()
    
                        # *fullForm* (mask) updates -- only when there is a retention delay
                        # NOTE(review): if fillerDuration >= jitter + sampleDuration the mask is
                        # never stopped in this loop, and its status stays STARTED for the
                        # following matches of the same sample; confirm that the
                        # configured ranges exclude this.
                        if jitter:
                            if t < fillerDuration and fullForm.status == NOT_STARTED:
                                fullForm.tStart = t
                                fullForm.frameNStart = frameN  # exact frame index
                                fullForm.status = STARTED
                                win.logOnFlip(level=logging.EXP, msg='Mask - STARTED - ' + array2string(thisMatch['match']))
                            
                            
                            if t >= fillerDuration and fullForm.status == STARTED:
                                fullForm.status = STOPPED
                                win.logOnFlip(level=logging.EXP, msg='Mask - STOPPED')
    
                        # *recallForm* updates
                        if t >= jitter and recallForm.status == NOT_STARTED:
                            # keep track of start time/frame for later
                            recallForm.tStart = t
                            recallForm.frameNStart = frameN  # exact frame index
                            recallForm.status = STARTED
                            # The selected cells are part of the message; the events export parses them from the log.
                            win.logOnFlip(level=logging.EXP, msg='Match - STARTED - ' + array2string(thisMatch['match']))
                            #parallel.setPin(4, 0)
                            #parallel.setPin(5, 1)
                        if fullForm.status == STARTED:
                            fullForm.draw()
    
                        if recallForm.status == STARTED:
                            recallForm.draw()
    
                        win.flip()
                    # Hand the difference between planned and actual phase length over to the next phase.
                    interimClock.reset((jitter + sampleDuration) - trialClock.getTime()) 
    
                    if recallForm.status == STARTED:
                        recallForm.status = STOPPED
                        win.logOnFlip(level=logging.EXP, msg='Match - STOPPED')
                    
                    
                    # Response 
                    # Start the phase clock from the value handed over by the previous phase via interimClock.
                    trialClock.reset(-interimClock.getTime()) 
                    jitter = thisMatch['onsetResponse']
    
                    responseStim.status = NOT_STARTED
                    SSO.reset_buttons()
                    
                    # Grid only for 'jitter', then grid + response prompt for responseDuration.
                    while trialClock.getTime() < (jitter + responseDuration):
                        # get current time
                        t = trialClock.getTime()
                        frameN = frameN + 1  # number of completed frames (so 0 is the first frame)
    
                        # update/draw components on each frame
                        gridForm.draw()
                        
                        # *responseStim* updates
                        # NOTE(review): status is never set to STARTED explicitly here; this
                        # condition and the STARTED check after the loop presumably rely on
                        # setAutoDraw(True) updating the status -- confirm.
                        if t >= jitter and responseStim.status == NOT_STARTED:
                            # keep track of start time/frame for later
                            responseStim.tStart = t
                            responseStim.frameNStart = frameN  # exact frame index
                            responseStim.setAutoDraw(True)
                            win.logOnFlip(level=logging.EXP, msg='Response - STARTED')
                            #parallel.setPin(6, 0)
                            #parallel.setPin(5, 0)
                            #parallel.setPin(7, 0)
                            #parallel.setPin(6, 1)
                            # Start listening for a button without blocking the frame loop.
                            SSO.wait_for_button(no_block=True) 
    
                        win.flip()
                        
                    # Hand the difference between planned and actual phase length over to the next phase.
                    interimClock.reset((jitter + responseDuration)-trialClock.getTime())
    
                    if responseStim.status == STARTED:
                        responseStim.status = STOPPED
                        responseStim.setAutoDraw(False)
                        win.logOnFlip(level=logging.EXP, msg='Response - STOPPED')
                    
                    # Score the last recorded button press, if there was one.
                    if len(SSO.buttonpresses): # no - SSO.buttonpresses[-1][0] = bNo; yes - SSO.buttonpresses[-1][0] = bYes
                        #parallel.setPin(6, 0)
                        #parallel.setPin(7, 1)
                        # The events export picks this line up from the log ('Button - <time> - <button>').
                        logging.log(level=logging.EXP, msg='Button - {:.3f} - {}'.format(SSO.buttonpresses[-1][1],SSO.buttonpresses[-1][0]))
                        thisExp.addData('resp.key',SSO.buttonpresses[-1][0])
                        # Press time minus (SSO.clock - trialClock time) minus the onset of the prompt.
                        thisExp.addData('resp.rt',SSO.buttonpresses[-1][1]-(SSO.clock-trialClock.getTime())-responseStim.tStart)
                        # 'hit': identical display answered 'Yes'; 'cr': changed display answered 'No'; 'false': any other press.
                        if all(thisMatch['match'] == thisSample['sample']) and (SSO.buttonpresses[-1][0] == bYes): 
                            thisExp.addData('resp.code','hit')
                        elif any(thisMatch['match'] != thisSample['sample']) and (SSO.buttonpresses[-1][0] == bNo): 
                            thisExp.addData('resp.code','cr')
                        else: thisExp.addData('resp.code','false')
                    else: thisExp.addData('resp.code','miss')
                    thisExp.nextEntry()
        
            # Rest
            # Rest after the block: show the fixation cross for restDuration.
            trialClock.reset(-interimClock.getTime()) 
            restStim.status = NOT_STARTED
    
            while trialClock.getTime() < restDuration:
                # get current time
                t = trialClock.getTime()
                frameN = frameN + 1  # number of completed frames (so 0 is the first frame)
    
                # *restStim* updates
                if restStim.status == NOT_STARTED:
                    # keep track of start time/frame for later
                    restStim.tStart = t
                    restStim.frameNStart = frameN  # exact frame index
                    restStim.status = STARTED
                    win.logOnFlip(level=logging.EXP, msg='Rest - STARTED')
                    #parallel.setPin(7, 0)
                    #parallel.setPin(9, 0)
                    #parallel.setPin(9, 1)
                
                if restStim.status == STARTED:
                    restStim.draw()
    
                win.flip()
            # Hand the difference between planned and actual rest length over to the next block.
            interimClock.reset(restDuration-trialClock.getTime())
    
            if restStim.status == STARTED:
                restStim.status = STOPPED
                win.logOnFlip(level=logging.EXP, msg='Rest - STOPPED')
            # Load the frequency of the next block into the stimulator (currently disabled)
            # if doSTIMULATION & (loopBlock.nRemaining>0): 
            #     for ch in range(BSO.nChannels):
            #         BSO.waves[ch].frequency = loopBlock.getFutureTrial()['frequency']
            #     BSO.loadWaveform()
    
        # Final flip, so that a message queued with logOnFlip (the last 'Rest - STOPPED') is written.
        win.flip()
        # Drop the references to the hardware interfaces.
        SSO = None
        if doSTIMULATION: BSO = None
    
        ######## EVENTS (BIDS) ########
        # Convert the PsychoPy log into a tab-separated events table with one row
        # per Sample / Match<n> / Response<n> event:
        #   onset, duration  - taken from the '... - STARTED' / '... - STOPPED' log lines
        #   response_time    - button press time relative to the onset of the response prompt
        #   value            - 'hit' (answer agrees with the display), 'false' or 'miss'
        logging.flush()
    
        from numpy import fromstring
        import csv
    
        eventFile = filename + '_events.tsv'
        # Context managers close both files even if parsing fails half-way.
        with open(eventFile,'w',newline='') as fOut, open(filename+'.log') as fIn:
            ev = csv.writer(fOut, delimiter='\t')
            ev.writerow(['onset','duration','trial_type','response_time','value'])
    
            log = csv.reader(fIn, delimiter='\t')
            sample = [] # cells of the most recent sample
            nMatch = 0 # index of the current match within its sample
            match = [] # cells of the most recent match
            itemToWrite = [None]*5 # 5 columns
            button = [] # [press time, code] of a press waiting for its Response row
            for item in log:
                # Only EXP-level lines with all three fields (time, level, message) are relevant.
                if len(item) < 3 or item[1].find('EXP') == -1: continue
                if any(item[2].find(evs) >= 0 for evs in ['Sample', 'Match', 'Response']):
                    if itemToWrite[0] is None:
                        # '... - STARTED': open a new row.
                        itemToWrite[0:3] = [round(float(item[0]),4), None, item[2].split(' - ')[0]]
                        if item[2].find('Sample') >= 0:
                            # The cell indices are logged as '[a b c ...]'; strip the brackets.
                            sample = fromstring(item[2].split(' - ')[2][1:-1],sep=' ')
                            nMatch = 0
                        elif item[2].find('Match') >= 0: 
                            match = fromstring(item[2].split(' - ')[2][1:-1],sep=' ')
                            nMatch += 1
                            itemToWrite[2] += str(nMatch)
                    else:
                        # '... - STOPPED': close the row and write it.
                        itemToWrite[1] = round(float(item[0]) - itemToWrite[0],4)
                        if item[2].find('Response') >= 0:
                            itemToWrite[2] += str(nMatch)
                            if len(button):
                                button[0] = round(button[0]-itemToWrite[0],4)
                            else:
                                button = ['n/a','miss']
                            itemToWrite[3:5] = button
                            button = []
                        ev.writerow(itemToWrite)
                        itemToWrite = [None]*5 # 5 columns
                elif item[2].find('Button') >= 0:
                    button += [float(item[2].split(' - ')[1])]
                    # The display is a true match when every cell of the match is also in
                    # the sample. isin() is element-wise, so it must be reduced with
                    # all() before it can be used as a single truth value.
                    isMatch = bool(isin(match,sample).all())
                    pressedYes = item[2].split(' - ')[2] == str(bYes)
                    if isMatch == pressedYes: 
                        button += ['hit']
                    else: button += ['false']