Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

plot_HG_and_stats individual electrode Z-score trace plotting is broken #14

Open
jimzhang629 opened this issue Aug 16, 2024 · 0 comments

Comments

@jimzhang629
Copy link
Contributor

def plot_channels_on_grid_time_perm_cluster(evoke_data, std_err_data, channels_subset, mat, sample_rate=2048, dec_factor=8, plot_x_dim=6, plot_y_dim=6):
    """
    Plots evoked EEG/MEG data for a subset of channels on a grid, overlaying significance markers for specified time windows.

    Parameters:
    - evoke_data: mne.Evoked object
        The evoked data to be plotted. This object contains the averaged EEG/MEG data over epochs.
    - std_err_data: 
        The standard error of the evoked data to be plotted
    - channels_subset: list of str
        A list of channel names to be plotted. Each channel name must correspond to a channel in `evoke_data`.
    - mat: numpy.array
        A binary matrix (same shape as evoke_data) indicating significant data points (1 for significant, 0 for non-significant).
    - sample_rate: float
        The sampling rate of the data, in Hz. Used to convert sample indices in `time_windows` to time in seconds.
    - dec_factor: int
        the decimation factor by which to downsample the sampling rate.
    - plot_x_dim: int, optional (default=6)
        The number of columns in the grid layout for plotting the channels.
    - plot_y_dim: int, optional (default=6)
        The number of rows in the grid layout for plotting the channels.

    Returns:
    - fig: matplotlib.figure.Figure object
        The figure object containing the grid of plots. Each plot shows the evoked data for a channel, with significance
        markers overlaid for the specified time windows.
    """
    fig, axes = plt.subplots(plot_x_dim, plot_y_dim, figsize=(20, 12))
    fig.suptitle("Channels with Significance Overlay")
    axes_flat = axes.flatten()

    for channel, ax in zip(channels_subset, axes_flat):
        stderr = std_err_data.data[channel_to_index[channel], :]
        time_in_seconds = np.arange(0, len(mat[channel_to_index[channel]])) / (sample_rate / dec_factor)  # Should be 2048 Hz sample rate
        sig_data_in_seconds = np.array(mat[channel_to_index[channel]])
        ax.plot(evoke_data.times, evoke_data.data[channel_to_index[channel], :])
         # Add the standard error shading
        ax.fill_between(evoke_data.times, evoke_data.data[channel_to_index[channel], :] - stderr, evoke_data.data[channel_to_index[channel], :] + stderr, alpha=0.2)

        # Find the maximum y-value for the current channel
        max_y_value = np.max(evoke_data.data[channel_to_index[channel], :])

        # Overlay significance as a horizontal line at the max y-value
        significant_points = np.where(sig_data_in_seconds == 1)[0]
        for point in significant_points:
            ax.hlines(y=max_y_value, xmin=time_in_seconds[point]-1, xmax=time_in_seconds[point] + 0.005 - 1, color='red', linewidth=1) # subtract 1 cuz the sig time is from 0 to 2.5, while the high gamma time is from -1 to 1.5

        ax.set_title(channel)

    plt.tight_layout()
    plt.subplots_adjust(top=0.95)
    return fig

The above code plots the same thing for raw as z-score.

And, the below code plots completely wrong indexed channels for the z-score:

def plot_channels_across_subjects(electrode_dict, data_dict, std_err_dict, mat_dict, channel_to_index_dict, plot_x_dim=6, plot_y_dim=6, sample_rate=2048, dec_factor=8, y_label="Amplitude"):
    """
    Plots evoked EEG/MEG data across multiple subjects for a set of electrodes, organized into subplots.

    Parameters:
    - electrode_dict: dict
        Dictionary where keys are subjects and values are lists of electrodes to plot for each subject.
    - data_dict: dict
        Dictionary where each key is a subject and each value is the evoked data for that subject.
    - std_err_dict: dict
        Dictionary where each key is a subject and each value is the standard error data for that subject.
    - mat_dict: dict
        Dictionary where each key is a subject and each value is the significance matrix for that subject.
    - channel_to_index_dict: dict
        Dictionary where each key is a subject and each value is a dictionary mapping channel names to their indices for that subject.
    - plot_x_dim: int, optional
        Number of columns in the grid layout for plotting the channels.
    - plot_y_dim: int, optional
        Number of rows in the grid layout for plotting the channels.
    - sample_rate: float
        Sampling rate of the data in Hz.
    - dec_factor: int
        Decimation factor by which to downsample the sampling rate.
    - y_label: str, optional
        Label for the y-axis.

    Returns:
    - fig: matplotlib.figure.Figure object
        The figure object containing the grid of plots.
    """
    channels_per_fig = plot_x_dim * plot_y_dim
    plot_index = 0
    fig_num = 1

    fig, axes = plt.subplots(plot_y_dim, plot_x_dim, figsize=(20, 12))
    fig.suptitle("Channels Across Subjects with Significance Overlay")
    axes_flat = axes.flatten()

    for subject, electrodes in electrode_dict.items():
        for electrode in electrodes:
            if electrode in channel_to_index_dict[subject]:
                if plot_index >= channels_per_fig:
                    plt.tight_layout()
                    plt.subplots_adjust(top=0.95)
                    yield fig, fig_num

                    # Start a new figure if the previous one is full
                    fig, axes = plt.subplots(plot_y_dim, plot_x_dim, figsize=(20, 12))
                    fig.suptitle("Channels Across Subjects with Significance Overlay")
                    axes_flat = axes.flatten()
                    plot_index = 0
                    fig_num += 1

                ax = axes_flat[plot_index]
                ch_idx = channel_to_index_dict[subject][electrode]
                stderr = std_err_dict[subject].data[ch_idx, :]
                time_in_seconds = np.arange(0, len(mat_dict[subject][ch_idx])) / (sample_rate / dec_factor)
                sig_data_in_seconds = np.array(mat_dict[subject][ch_idx])

                ax.plot(data_dict[subject].times, data_dict[subject].data[ch_idx, :])
                # Add the standard error shading
                ax.fill_between(data_dict[subject].times, data_dict[subject].data[ch_idx, :] - stderr, data_dict[subject].data[ch_idx, :] + stderr, alpha=0.2)

                # Find the maximum y-value for the current channel
                max_y_value = np.max(data_dict[subject].data[ch_idx, :])

                # Overlay significance as a horizontal line at the max y-value
                significant_points = np.where(sig_data_in_seconds == 1)[0]
                for point in significant_points:
                    ax.hlines(y=max_y_value, xmin=time_in_seconds[point]-1, xmax=time_in_seconds[point] + 0.005 - 1, color='red', linewidth=1)

                ax.set_title(f"{subject}: {electrode}")
                ax.set_ylabel(y_label)

                plot_index += 1

    plt.tight_layout()
    plt.subplots_adjust(top=0.95)
    yield fig, fig_num

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant