Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Matlab edits #130

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
49 changes: 45 additions & 4 deletions MATLAB/decomposition/esnr.m
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
% Syntax: evokeSnr = esnr(chanSignal, chanNoise)
%
% Inputs:
% chanSignal - Matrix of channel signal samples (rows represent observations, columns represent features)
% chanNoise - Matrix of channel noise samples (rows represent observations, columns represent features)
% chanSignal - Vector of channel signal samples (rows represent observations, columns represent features)
% chanNoise - Vector of channel noise samples (rows represent observations, columns represent features)
%
% Output:
% evokeSnr - Scalar value representing the eSNR in decibels (dB)
%
% Example:
% chanSignal = [1 2 3; 4 5 6; 7 8 9];
% chanNoise = [0.5 1 1.5; 2 2.5 3; 3.5 4 4.5];
% chanSignal = [1 2 3 4 5 6 7 8 9];
% chanNoise = [0.5 1 1.5 2 2.5 3 3.5 4 4.5];
% evokeSnr = esnr(chanSignal, chanNoise);
%

Expand All @@ -31,3 +31,44 @@
% Calculate the eSNR in decibels (dB)
evokeSnr = 10 .* log10(mean(mDistSignal.^2 / varNoise));
end



function dist = mahalUpdate(noise, signal, sigma)
% mahalUpdate - Calculate Mahalanobis distance between signal and noise.
%
% Syntax: dist = mahalUpdate(noise, signal, sigma)
%
% Inputs:
% noise - Matrix of noise samples (rows represent trials, columns represent features)
% signal - Matrix of signal samples (rows represent trials, columns represent features)
% sigma - Covariance matrix (assumed to be positive definite)
%
% Output:
% dist - Row vector of Mahalanobis distances for each trial
%
% Example:
% noise = [1 2; 3 4; 5 6];
% signal = [2 3; 4 5; 6 7];
% sigma = [1 0.5; 0.5 1];
% dist = mahalUpdate(noise, signal, sigma);
%

% Calculate the mean of the noise samples
meanN = mean(noise, 1)';

% Calculate the number of trials
nTrials = size(noise, 1);

% Initialize the array to store Mahalanobis distances
dist = zeros(1, nTrials);

% Calculate the Mahalanobis distance for each trial
for tr = 1:nTrials
% Calculate the difference between the signal and mean of noise
diff = signal(tr, :)' - meanN;

% Calculate the Mahalanobis distance
dist(tr) = sqrt(diff' / sigma * diff);
end
end
80 changes: 80 additions & 0 deletions MATLAB/finger flexion decoding/featureExtract.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
% Place the following folders under Software tool in path
% a)"EMD",
% b)"wavelet-master"
% load day2run2_Isometric_Left.mat in UM-01-wednesday
% load day2run2_Isometric_Left.mat in UM-01-Thursday
% load z_iso_right_updated.mat in UM-03
%z = z_right; % only in UM-03 - brutal code

fsTrial = 1000; % sampling frequency of the sensor
fsECoG = 10000; % sampling frequency of ECoG
fsECoGd = fsECoG; % downsampled sampling frequency
tw = [-0.25 3.5]; % time-window across finger movement onset in seconds
% Place the following folders under Software tool in path
% a)"EMD",
% b)"wavelet-master"
% load day2run2_Isometric_Left.mat in UM-01-wednesday
% load day2run2_Isometric_Left.mat in UM-01-Thursday
% load z_iso_right_updated.mat in UM-03
z = z_right; % only in UM-03 - brutal code
% should work in all z struct files

% UM -01 channels
% c_channel1=[56:65]; % cable 1
% c_channel2=[66:75]; % cable 2
% UM -03 channels
c_channel1=[1:35]; % cable 1
c_channel2=[36:41]; % cable 2

trialInt = find(ismember([z.TaskNumber],[1 2 5])); % extracting finger trials only
zInt = z(trialInt);
trialId = [zInt.TaskNumber]; % finger ids
trialOn = [zInt.MoveOnset]; % flexion onsets
%trialOff = [zInt.MoveOffset]; % flexion offset

ieegSplit = []; % channels x trials x timepoints
tc = 1;
emptyTrials = [];
% Processing each trial information
% 1. Downsampling to 2 kHz
% 2. Filtering 60 Hz and its harmonics
% 3. Creating ieeg trial structure
for iTrial = 1:size(zInt,2)
iTrial

ieegTemp = zInt(iTrial).ECoG;
if(isempty(ieegTemp))
emptyTrials = [emptyTrials iTrial];
continue;
end
ieegTemp = cell2mat(squeeze(struct2cell(ieegTemp)));

ieegTemp = resample(ieegTemp',fsECoGd,fsECoG)'; % downsampling
ieegTempCar1 = ieegTemp(c_channel1,:)-mean(ieegTemp(c_channel1,:),1);
ieegTempCar2 = ieegTemp(c_channel2,:)-mean(ieegTemp(c_channel2,:),1);
ieegCar = [ieegTempCar1; ieegTempCar2];
ieegFilt = ieegCar; % no 60 Hz filtering
timeTrial = linspace(0,(size(ieegCar,2)-1)./fsECoGd,size(ieegCar,2));
timeId = find(timeTrial>=trialOn(iTrial)/fsTrial,1);
ieegTempArrange = ieegFilt(:,timeId+tw(1)*fsECoGd:timeId+tw(2)*fsECoGd);
ieegSplit(:,tc,:) = ieegTempArrange;
tc = tc+1;
end
trialId(emptyTrials) = [];
sigChan = [7,8,12,13]; %UM - 01
sigChan = [1,9]; % UM - 03;
ieegSplitSig = ieegSplit(sigChan,:,:);
sigPowerFFT = getFFTPower(ieegSplitSig,fsECoGd,tw,[0 1],[66 114]);
sigPowerEMD = getEmdPower(ieegSplitSig,tw,[0 1],9,3);
sigPowerWav = getWaveletPower(ieegSplitSig,fsECoGd,tw,[0 1],[66 114]);
sigPowerPsd = getPsd(ieegSplitSig,fsECoGd,tw,[0 1],[2 3]);
%% Naive - Bayes Classification

ypredFFT = nbClassify(sigPowerFFT',double(trialId),1);
accFFT = sum(ypredFFT==trialId)/length(trialId)
ypredEMD = nbClassify(sigPowerEMD',double(trialId),1);
accEMD = sum(ypredEMD==trialId)/length(trialId)
ypredWav = nbClassify(sigPowerWav',double(trialId),1);
accWav = sum(ypredWav==trialId)/length(trialId)
ypredPca = nbClassifyPSCA(sigPowerPsd(:,1:length(trialId),:),fsECoGd,double(trialId),1,1);
accPca = sum(ypredPca==trialId)/length(trialId)
28 changes: 23 additions & 5 deletions MATLAB/ieegClassDefinition/decoderClass.m
Original file line number Diff line number Diff line change
Expand Up @@ -97,16 +97,16 @@
% Initialize variables
CmatAll = zeros(length(decodeUnitUnique), length(decodeUnitUnique));
ytestall = [];

aucAll = zeros(1,length(decodeUnitUnique));
% Performclassification for nIter iterations
for iTer = 1:obj.nIter
if(trainTestDiff == 0)
% Call pcaLinearDecoderWrap function for classification
[~, ytest, ypred, optimVarAll, ~, modelWeightsAll] = pcaLinearDecoderWrap(ieegInput, decoderUnit, ieegStruct.tw, d_time_window, obj.varExplained, obj.numFold, isAuc);
[~, ytest, ypred, optimVarAll, aucMod, modelWeightsAll] = pcaLinearDecoderWrap(ieegInput, decoderUnit, ieegStruct.tw, d_time_window, obj.varExplained, obj.numFold, isAuc);
%[~, ytest, ypred] = stmfDecodeWrap(ieegInput, decoderUnit, ieegStruct.tw, d_time_window, obj.numFold, isauc);
else
% Call pcaLinearDecoderWrapTrainTest function for classification with separate train and test time windows
[~, ytest, ypred, optimVarAll, ~, modelWeightsAll] = pcaLinearDecoderWrapTrainTest(ieegInput, decoderUnit, ieegStruct.tw, d_time_window(1,:), d_time_window(2,:), obj.varExplained, obj.numFold, isAuc);
[~, ytest, ypred, optimVarAll, aucMod, modelWeightsAll] = pcaLinearDecoderWrapTrainTest(ieegInput, decoderUnit, ieegStruct.tw, d_time_window(1,:), d_time_window(2,:), obj.varExplained, obj.numFold, isAuc);
end

% Accumulate test labels and predictions
Expand All @@ -115,8 +115,10 @@
% Compute confusion matrix and accumulate
Cmat = confusionmat(ytest, ypred);
CmatAll = CmatAll + Cmat;
aucAll = aucAll + mean(aucMod);

end

aucAll = aucAll./obj.nIter;
% Compute normalized confusion matrix
CmatCatNorm = CmatAll ./ sum(CmatAll, 2);

Expand All @@ -129,6 +131,7 @@
decodeResultStruct.p = StatThInv(ytestall, decodeResultStruct.accPhoneme * 100);
decodeResultStruct.modelWeights = modelWeightsAll;
decodeResultStruct.optimVarAll = optimVarAll;
decodeResultStruct.aucAll = aucAll;
end

function decodeResultStruct = baseRegress(obj, ieegStruct, decoderUnit, d_time_window, selectChannel, selectTrial)
Expand Down Expand Up @@ -198,7 +201,7 @@
ieegStruct {mustBeA(ieegStruct, 'ieegStructClass')} % ieeg class object
decoderUnit double {mustBeVector} % Decoder labels
options.timeRes double = 0.02; % Decoder time resolution; Defaults to 0.02
options.timeWin double; % Time window size for analysis
options.timeWin double = 0.25; % Time window size for analysis
options.selectChannels double = 1:size(ieegStruct.data,1); % Select number of electrodes for analysis; Defaults to all
options.selectTrials double = 1:size(ieegStruct.data,2); % Select number of trials for analysis; Defaults to all
options.isModelWeight logical = 1 % Extract model weights if true;
Expand Down Expand Up @@ -384,6 +387,21 @@
decodeTimeStruct.pValTime = pValTime;
decodeTimeStruct.timeRange = timeRange;
end


function decoderChanStruct = indChanClassify(obj,ieegStruct,decoderUnit,options)
arguments
obj {mustBeA(obj, 'decoderClass')} % Decoder class object
ieegStruct {mustBeA(ieegStruct, 'ieegStructClass')} % ieeg class object
decoderUnit double {mustBeVector} % Decoder labels
options.d_time_window double = ieegStruct.tw; % Decoder time window; Defaults to epoch time-window
end

parfor iChan = 1:size(ieegStruct.data,1)
iChan
decoderChanStruct{iChan} = baseClassify(obj,ieegStruct,decoderUnit,d_time_window=options.d_time_window,selectChannel=iChan,isAuc=1);
end
end

end
end
74 changes: 42 additions & 32 deletions MATLAB/ieegClassDefinition/extractHGDataWithROI.m
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,15 @@
options.normFactor = [];
options.normType = 1; % 1 - z-score, 2 - mean normalization
options.fDown double = 200;
options.baseTimeRange = [-0.5 0];
options.baseTimeRange = [-0.6 -0.1];
options.baseName = 'Start'
options.respTimeThresh = -1;
options.respTimeThresh = 0.1;
options.respDurThresh = 2;
options.subsetElec cell = '' % subset of electrodes to select from stats
options.remNoiseTrials logical = true; % true to remove all noisy trials
options.remNoResponseTrials logical = true; % true to remove all no-response trials
options.remWMchannels logical = true;
options.remNoiseThreshold double = 10;
end

% Extract normalization type and time padding
Expand All @@ -73,22 +75,21 @@
disp('No normalization factors provided');

% Selecting all trials with no noise for baseline
ieegBaseStruct = extractRawDataWithROI(Subject, Epoch = options.baseName, ...
Time = [options.baseTimeRange(1)-timePad options.baseTimeRange(2)+timePad], ...
roi = options.roi, remFastResponseTimeTrials = -1, ...
remNoiseTrials = options.remNoiseTrials, remNoResponseTrials = false, ...
subsetElec = options.subsetElec, remWMchannels = options.remWMchannels);
ieegBaseStruct = extractRawDataWithROI(Subject, 'Epoch', options.baseName, ...
'Time', [options.baseTimeRange(1)-timePad options.baseTimeRange(2)+timePad], ...
'roi', options.roi, 'remFastResponseTimeTrials', -1, ...
'remNoiseTrials', false, 'remNoResponseTrials', false, ...
'subsetElec', options.subsetElec, 'remWMchannels', options.remWMchannels);

% Extracting normalization parameters for each subject
normFactorSubject = cell(length(Subject), 1);
parfor iSubject = 1:length(Subject)
if(isempty(ieegBaseStruct(iSubject).ieegStruct))
normFactorSubject{iSubject} = [];
continue;
if(~isempty(ieegBaseStruct(iSubject).ieegStruct))
ieegBaseHG = extractHiGamma(ieegBaseStruct(iSubject).ieegStruct, ...
options.fDown, options.baseTimeRange);
normFactorBase = extractHGnormFactor(ieegBaseHG);
normFactorSubject{iSubject} = normFactorBase;
end
ieegBaseHG = extractHiGamma(ieegBaseStruct(iSubject).ieegStruct, ...
options.fDown, options.baseTimeRange);
normFactorBase = extractHGnormFactor(ieegBaseHG);
normFactorSubject{iSubject} = normFactorBase;
end
else
normFactorSubject = options.normFactor;
Expand All @@ -97,29 +98,38 @@
clear ieegBaseStruct;

% Extracting field epochs for the fixed parameters
ieegFieldStruct = extractRawDataWithROI(Subject, Epoch = options.Epoch, ...
Time = [options.Time(1)-timePad options.Time(2)+timePad], ...
roi = options.roi, remFastResponseTimeTrials = options.respTimeThresh, ...
remNoiseTrials = options.remNoiseTrials, remNoResponseTrials = options.remNoResponseTrials, ...
subsetElec = options.subsetElec, remWMchannels = options.remWMchannels);
ieegFieldStruct = extractRawDataWithROI(Subject, 'Epoch', options.Epoch, ...
'Time', [options.Time(1)-timePad options.Time(2)+timePad], ...
'roi', options.roi, 'remFastResponseTimeTrials', options.respTimeThresh, ...
'remNoiseTrials', options.remNoiseTrials, 'remNoResponseTrials', options.remNoResponseTrials, ...
'subsetElec', options.subsetElec, 'remWMchannels', options.remWMchannels,'remlongDurationTrials',options.respDurThresh);

ieegHGAll = [];
ieegHGAll = repmat(struct('ieegHGNorm', [], 'channelName', [], 'normFactor', [], 'trialInfo', []), length(Subject), 1);

% Filtering signal in the high-gamma band for each subject
parfor iSubject = 1:length(Subject)
if(isempty(ieegFieldStruct(iSubject).ieegStruct))
ieegHGAll(iSubject).ieegHGNorm = [];
ieegHGAll(iSubject).channelName = [];
ieegHGAll(iSubject).normFactor = [];
ieegHGAll(iSubject).trialInfo = [];
continue;
if(~isempty(ieegFieldStruct(iSubject).ieegStruct))
ieegFieldHG = extractHiGamma(ieegFieldStruct(iSubject).ieegStruct, ...
options.fDown, options.Time, normFactorSubject{iSubject}, normType);
% Removing noisy trials
% if options.remNoiseThreshold > 0
% [~, goodtrialIds] = remove_bad_trials(ieegFieldHG.data, threshold = options.remNoiseThreshold, method=2);
% %goodTrialsCommon = extractCommonTrials(goodtrials);
% else
% goodtrialIds = ones(size(ieegEpoch,1),size(ieegEpoch,2));
% end
%
% for iChan = 1:size(ieegFieldHG.data,1)
% % Assigning noisy trials to values of 0
% ieegFieldHG.data(iChan,~goodtrialIds(iChan,:),:) = nan;
% end
ieegHGAll(iSubject).ieegHGNorm = ieegFieldHG;
ieegHGAll(iSubject).channelName = ieegFieldStruct(iSubject).channelName;
ieegHGAll(iSubject).trialInfo = ieegFieldStruct(iSubject).trialInfo;
ieegHGAll(iSubject).normFactor = normFactorSubject{iSubject};
ieegHGAll(iSubject).responseTime = ieegFieldStruct(iSubject).responseTime;
ieegHGAll(iSubject).responseDuration = ieegFieldStruct(iSubject).responseDuration;
end
ieegFieldHG = extractHiGamma(ieegFieldStruct(iSubject).ieegStruct, ...
options.fDown, options.Time, normFactorSubject{iSubject}, normType);
ieegHGAll(iSubject).ieegHGNorm = ieegFieldHG;
ieegHGAll(iSubject).channelName = ieegFieldStruct(iSubject).channelName;
ieegHGAll(iSubject).trialInfo = ieegFieldStruct(iSubject).trialInfo;
ieegHGAll(iSubject).normFactor = normFactorSubject{iSubject};
end

end
Loading
Loading