Module truthnet.util

Expand source code
import argparse
from .truthnet import *
from .truthfinder import *
import gzip
from concurrent.futures import ThreadPoolExecutor, as_completed
import pylab as plt
import seaborn as sns
from sklearn import metrics
from zedstat import zedstat

def calculate_response_parallel(responsedictpath,
                                Veritas_model_path,
                                outfilepath,
                                verbose=False,
                                numworkers=10):
    '''Score every subject's responses with a Veritas model and write a CSV.

    Parameters
    ----------
    responsedictpath : str or path-like
        Gzip-compressed pickle holding the raw response data. It is passed
        through ``make_str_format`` and ``extract_ptsd_items``; every item
        produced must be a dict with ``'subject_id'`` and ``'responses'`` keys.
    Veritas_model_path : str or path-like
        Path handed to ``load_veritas_model``; the loaded object must provide
        ``'model'`` and ``'model_neg'`` entries.
    outfilepath : str or path-like
        Destination of the CSV written with ``DataFrame.to_csv``.
    verbose : bool
        If True, print the loaded model object after the CSV is written.
    numworkers : int
        Number of threads used by the ``ThreadPoolExecutor``.

    Output
    ------
    The CSV has one column per subject id and three rows (0, 1, 2) holding,
    in order, the ``funcm`` (lower threshold), ``dissonance_distr_median``
    (veritas) and ``funcw`` (score) results -- assuming those helpers return
    scalars. Columns follow the order in which items were submitted, so
    repeated runs give the same layout; if a subject id occurs twice, the
    later item wins. Items whose scoring raises are reported with ``print``
    and left out.
    '''
    # SECURITY: pickle.load can execute arbitrary code -- only load files
    # that come from a trusted source.
    with gzip.open(responsedictpath, 'rb') as filepath:
        responsedata = pickle.load(filepath)

    vmodel = load_veritas_model(Veritas_model_path)
    model = vmodel['model']
    model_neg = vmodel['model_neg']

    def process_item(i):
        '''Return (subject_id, (lower_threshold, veritas, score)) for one item.'''
        subjectid = i['subject_id']
        resp = i['responses']
        # Concatenating onto an empty frame whose columns are the model's
        # feature names aligns the answers to that column order; missing
        # answers become '' after fillna. Row 0 is the single response row.
        s = (pd.concat([pd.DataFrame(columns=model.feature_names),
                        pd.DataFrame(resp, index=['response'])])
             .fillna('').values[0].astype(str))

        veritas = dissonance_distr_median(s, model)
        score = funcw(s, model, model_neg)
        lowert = funcm(s, model)
        return subjectid, (lowert, veritas, score)

    patients_responses = make_str_format(responsedata)
    list_response_dict = extract_ptsd_items(patients_responses)

    # NOTE(review): threads only run concurrently while the scoring helpers
    # release the GIL; if they are pure-Python CPU work a process pool may
    # scale better -- confirm before tuning numworkers.
    outcomes = {}
    with ThreadPoolExecutor(max_workers=numworkers) as executor:
        future_to_item = {executor.submit(process_item, i): i
                          for i in list_response_dict}

        for future in tqdm(as_completed(future_to_item), total=len(future_to_item)):
            item = future_to_item[future]
            try:
                outcomes[future] = future.result()
            except Exception as exc:
                # Best-effort: report the failing subject and keep going.
                print(f'Item {item} generated an exception: {exc}')

    # as_completed yields futures in completion order, which varies between
    # runs; rebuild the mapping in submission order so the CSV is stable.
    H = {}
    for future in future_to_item:
        if future in outcomes:
            subjectid, result = outcomes[future]
            H[subjectid] = result

    hf = pd.DataFrame(H)
    hf.to_csv(outfilepath)
    if verbose:
        print(vmodel)


def get_malinger_func(C0, C1, C2, score=True):
    '''Build a row classifier implementing the malingering rule.

    The returned callable takes a row exposing ``lower_threshold``,
    ``veritas`` and ``score`` attributes (for example a DataFrame row passed
    by ``apply(..., axis=1)``) and returns:

    -1 is malingering
    0 is no dx
    1 is true dx

    A row is flagged -1 when either
      * lower_threshold < C0 -- and, if ``score`` is truthy, also score > C2;
      * veritas > C1 and score > C2.
    Otherwise the result is 0 when score < C2 and 1 in every other case.

    With ``score`` falsy, a row caught by the lower_threshold rule is
    returned without reading its ``score`` attribute at all.

    NOTE(review): score == C2 (and a NaN score) therefore yields 1, even
    though such a row fails the ``score > C2`` test used by both
    malingering rules and by validate's derived dx column -- confirm that
    this boundary is intended.
    '''
    # Decided once, when the classifier is built.
    low_rule_needs_score = bool(score)

    def malinger(row):
        # Rule 1: a low lower_threshold flags malingering, gated on the
        # score cutoff only when the score is considered trustworthy.
        if row.lower_threshold < C0:
            if not low_rule_needs_score or row.score > C2:
                return -1
        # Rule 2: high veritas together with an above-cutoff score.
        if row.veritas > C1 and row.score > C2:
            return -1
        return 0 if row.score < C2 else 1

    return malinger


def _plot_threshold_panels(response_dataframe, C0, C1, C2):
    '''Draw the four panels shared by the "withdx" and "fnrexpt" reports.

    Fills subplots 231-234 of the current matplotlib figure. Expects the
    lower_threshold, veritas, score, dx and mg columns; red lines mark the
    cutoffs C0 (lower_threshold), C1 (veritas) and C2 (score).
    '''
    plt.subplot(231)
    sns.scatterplot(data=response_dataframe, x='lower_threshold', y='veritas',
                    hue='mg', size='dx')
    # The y-axis is veritas, so its cutoff line is drawn at C1 (it used to be
    # a fixed 0.76 that ignored the C1 actually passed in).
    plt.plot([.5, 2.5], [C1, C1], '-r')
    plt.plot([C0, C0], [.5, .95], '-r')

    plt.subplot(232)
    sns.scatterplot(data=response_dataframe, x='score', y='veritas', hue='dx')
    plt.plot([.2, 2.5], [C1, C1], '-r')
    plt.plot([C2, C2], [.5, .95], '-r')

    plt.subplot(233)
    sns.scatterplot(data=response_dataframe, x='score', y='lower_threshold',
                    hue='dx')
    plt.subplots_adjust(wspace=0.23)

    plt.subplot(234)
    sns.heatmap(response_dataframe.corr(), cmap='jet', alpha=.5)


def _hide_axes(ax):
    '''Strip ticks and frame so ``ax`` serves as a plain text panel.'''
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_frame_on(False)


def validate(response_dataframe, C0, C1, C2,
             DX=True, score=True, plots=True, verbose=True,
             validation_type='withdx', outfile='report.png',
             total_samples=None, positive_samples=None):
    '''Classify subjects with the malingering rule and summarise the result.

    ``response_dataframe`` should have one row per subject and the columns
    lower_threshold, veritas and score, e.g.::

               lower_threshold veritas score
        sub1         x            x       x
        sub2         x            x       x

    The frame is modified in place. An ``mg`` column is always added, holding
    the output of ``get_malinger_func(C0, C1, C2, score=score)``: -1
    malingering, 0 no dx, 1 true dx. A ``dx`` column is also added when
    ``DX`` is False.

    Parameters
    ----------
    C0, C1, C2 : float
        Cutoffs on lower_threshold, veritas and score (see get_malinger_func).
    DX : bool
        True if the input already has a ground-truth ``dx`` column. When
        False, ``dx`` is derived as ``score > C2``. When True, a ROC curve of
        score against dx is built and processed with zedstat.
    score : bool
        False means the score is faked or absent (no dx subgroup was
        available in training); forwarded to get_malinger_func.
    plots : bool
        If True, draw the report figure and save it to ``outfile``.
    verbose : bool
        If True (and DX is True), print the last 10 rows of the zedstat
        table with ppv > 0.875.
    validation_type : {'withdx', 'fnrexpt', 'noscore'}
        - 'withdx': mratio (fraction of dx == 1 rows with mg == -1) and the
          zedstat AUC. Requires DX=True.
        - 'fnrexpt': fnr, the fraction of all rows with mg == 1.
        - 'noscore': mrate, the fraction of all rows with mg == -1.
    outfile : str or path-like
        Where the figure is saved (300 dpi, transparent) when ``plots``.
    total_samples, positive_samples : int, optional
        Sample sizes passed to ``zedstat.processRoc``. By default they are
        the number of rows and the number of dx == 1 rows of the input.

    Returns
    -------
    'withdx':  ({'auc': zt.auc(), 'mratio': mratio}, response_dataframe, zt)
    'fnrexpt': ({'fnr': fnr}, response_dataframe)
    'noscore': ({'mrate': mrate}, response_dataframe)

    Raises
    ------
    ValueError
        If ``validation_type`` is not one of the supported values, or if
        'withdx' is requested with DX=False (no ROC is built in that case).
    '''
    if validation_type not in ('withdx', 'fnrexpt', 'noscore'):
        raise ValueError("validation_type must be 'withdx', 'fnrexpt' or "
                         f"'noscore', got {validation_type!r}")
    if validation_type == 'withdx' and not DX:
        raise ValueError("validation_type='withdx' requires a real dx column "
                         "(DX=True)")

    malinger = get_malinger_func(C0, C1, C2, score=score)
    if not DX:
        response_dataframe['dx'] = [x > C2 for x in response_dataframe.score.values]

    response_dataframe['mg'] = response_dataframe.apply(malinger, axis=1)

    if DX:
        y_true = response_dataframe.dx.values.astype(int)
        fpr, tpr, thresholds = metrics.roc_curve(y_true,
                                                 response_dataframe.score.values.astype(float),
                                                 pos_label=1)
        ff = pd.DataFrame(tpr, fpr, columns=['tpr']).assign(threshold=thresholds)
        ff.index.name = 'fpr'
        # Sample sizes come from the data actually scored (they were
        # hard-coded to 304 / 86 before) unless the caller overrides them.
        n_total = len(y_true) if total_samples is None else total_samples
        n_positive = int(y_true.sum()) if positive_samples is None else positive_samples
        zt = zedstat.processRoc(df=ff.reset_index(),
                                order=3,
                                total_samples=n_total,
                                positive_samples=n_positive,
                                alpha=0.01,
                                prevalence=0.5)
        zt.smooth(STEP=0.001)
        zt.allmeasures(interpolate=True)
        zt.usample(precision=3)
        Z = zt.get()
        if verbose:
            print(Z[Z.ppv > .875].tail(10))

    if validation_type == 'withdx':
        flagged_dx = (response_dataframe.mg == -1) & (response_dataframe.dx == 1)
        mratio = response_dataframe[flagged_dx].index.size / response_dataframe.dx.sum()
        fullauc = zt.auc()
        result = ({'auc': fullauc, 'mratio': mratio}, response_dataframe, zt)

        if plots:
            plt.figure(figsize=[20, 12])
            _plot_threshold_panels(response_dataframe, C0, C1, C2)

            plt.subplot(235)
            plt.plot(fpr, tpr, 'g', lw=2)
            plt.gca().legend(['R20'])
            zt.get().tpr.plot(style='-b', lw=2)

            ax = plt.subplot(236)
            ax.text(0.5, 0.6, f'malinger prevalence in DX: {mratio:.2f}',
                    fontsize=16, ha='center')
            # Raw f-string: '\p' in a plain literal is an invalid escape.
            ax.text(0.5, 0.4,
                    rf'AUC: {fullauc[0]:.2f} $\pm$ {fullauc[1]-fullauc[0]:.2f}',
                    fontsize=16, ha='center')
            _hide_axes(ax)

    elif validation_type == 'fnrexpt':
        fnr = response_dataframe[response_dataframe.mg == 1].index.size / response_dataframe.index.size
        result = ({'fnr': fnr}, response_dataframe)

        if plots:
            plt.figure(figsize=[20, 12])
            _plot_threshold_panels(response_dataframe, C0, C1, C2)

            ax = plt.subplot(236)
            ax.text(0.5, 0.6, f'FNR in EXPT: {fnr:.2f}', fontsize=16, ha='center')
            _hide_axes(ax)

    else:  # 'noscore'
        mrate = response_dataframe[response_dataframe.mg == -1].index.size / response_dataframe.index.size
        result = ({'mrate': mrate}, response_dataframe)

        if plots:
            plt.figure(figsize=[8, 8])
            plt.subplot(111)
            sns.scatterplot(data=response_dataframe, x='lower_threshold',
                            y='veritas', hue='mg')
            plt.plot([.1, 2.5], [C1, C1], '-r')
            plt.plot([C0, C0], [.1, .95], '-r')
            plt.plot([C2, C2], [.1, .95], '-r')

            ax = plt.gca()
            ax.text(0.65, 0.8, f'mrate: {mrate:.2f}', fontsize=16, ha='center')

    # Save once, after whichever branch drew the figure. In the old code this
    # call sat after every return, so outfile was never written.
    if plots:
        plt.savefig(outfile, dpi=300, bbox_inches='tight', transparent=True)
    return result




def drop_empty_string_keys(input_dict):
    '''Return a new dict without the entries whose value is the empty string.

    Despite the name, the test is applied to the *values*: any key mapped to
    '' is left out. Other falsy values (0, None, ' ') are kept, insertion
    order is preserved, and ``input_dict`` itself is not modified.
    '''
    kept = {}
    for key, value in input_dict.items():
        if value != '':
            kept[key] = value
    return kept

Functions

def calculate_response_parallel(responsedictpath, Veritas_model_path, outfilepath, verbose=False, numworkers=10)
Expand source code
def calculate_response_parallel(responsedictpath,
                                Veritas_model_path,
                                outfilepath,
                                verbose=False,
                                numworkers=10):
    '''Score every subject's responses with a Veritas model and write a CSV.

    Parameters
    ----------
    responsedictpath : str or path-like
        Gzip-compressed pickle holding the raw response data. It is passed
        through ``make_str_format`` and ``extract_ptsd_items``; every item
        produced must be a dict with ``'subject_id'`` and ``'responses'`` keys.
    Veritas_model_path : str or path-like
        Path handed to ``load_veritas_model``; the loaded object must provide
        ``'model'`` and ``'model_neg'`` entries.
    outfilepath : str or path-like
        Destination of the CSV written with ``DataFrame.to_csv``.
    verbose : bool
        If True, print the loaded model object after the CSV is written.
    numworkers : int
        Number of threads used by the ``ThreadPoolExecutor``.

    Output
    ------
    The CSV has one column per subject id and three rows (0, 1, 2) holding,
    in order, the ``funcm`` (lower threshold), ``dissonance_distr_median``
    (veritas) and ``funcw`` (score) results -- assuming those helpers return
    scalars. Columns follow the order in which items were submitted, so
    repeated runs give the same layout; if a subject id occurs twice, the
    later item wins. Items whose scoring raises are reported with ``print``
    and left out.
    '''
    # SECURITY: pickle.load can execute arbitrary code -- only load files
    # that come from a trusted source.
    with gzip.open(responsedictpath, 'rb') as filepath:
        responsedata = pickle.load(filepath)

    vmodel = load_veritas_model(Veritas_model_path)
    model = vmodel['model']
    model_neg = vmodel['model_neg']

    def process_item(i):
        '''Return (subject_id, (lower_threshold, veritas, score)) for one item.'''
        subjectid = i['subject_id']
        resp = i['responses']
        # Concatenating onto an empty frame whose columns are the model's
        # feature names aligns the answers to that column order; missing
        # answers become '' after fillna. Row 0 is the single response row.
        s = (pd.concat([pd.DataFrame(columns=model.feature_names),
                        pd.DataFrame(resp, index=['response'])])
             .fillna('').values[0].astype(str))

        veritas = dissonance_distr_median(s, model)
        score = funcw(s, model, model_neg)
        lowert = funcm(s, model)
        return subjectid, (lowert, veritas, score)

    patients_responses = make_str_format(responsedata)
    list_response_dict = extract_ptsd_items(patients_responses)

    # NOTE(review): threads only run concurrently while the scoring helpers
    # release the GIL; if they are pure-Python CPU work a process pool may
    # scale better -- confirm before tuning numworkers.
    outcomes = {}
    with ThreadPoolExecutor(max_workers=numworkers) as executor:
        future_to_item = {executor.submit(process_item, i): i
                          for i in list_response_dict}

        for future in tqdm(as_completed(future_to_item), total=len(future_to_item)):
            item = future_to_item[future]
            try:
                outcomes[future] = future.result()
            except Exception as exc:
                # Best-effort: report the failing subject and keep going.
                print(f'Item {item} generated an exception: {exc}')

    # as_completed yields futures in completion order, which varies between
    # runs; rebuild the mapping in submission order so the CSV is stable.
    H = {}
    for future in future_to_item:
        if future in outcomes:
            subjectid, result = outcomes[future]
            H[subjectid] = result

    hf = pd.DataFrame(H)
    hf.to_csv(outfilepath)
    if verbose:
        print(vmodel)
def drop_empty_string_keys(input_dict)
Expand source code
def drop_empty_string_keys(input_dict):
    '''Return a new dict without the entries whose value is the empty string.

    Despite the name, the test is applied to the *values*: any key mapped to
    '' is left out. Other falsy values (0, None, ' ') are kept, insertion
    order is preserved, and ``input_dict`` itself is not modified.
    '''
    kept = {}
    for key, value in input_dict.items():
        if value != '':
            kept[key] = value
    return kept
def get_malinger_func(C0, C1, C2, score=True)

-1 is malingering; 0 is no dx; 1 is true dx.

Expand source code
def get_malinger_func(C0, C1, C2, score=True):
    '''Build a row classifier implementing the malingering rule.

    The returned callable takes a row exposing ``lower_threshold``,
    ``veritas`` and ``score`` attributes (for example a DataFrame row passed
    by ``apply(..., axis=1)``) and returns:

    -1 is malingering
    0 is no dx
    1 is true dx

    A row is flagged -1 when either
      * lower_threshold < C0 -- and, if ``score`` is truthy, also score > C2;
      * veritas > C1 and score > C2.
    Otherwise the result is 0 when score < C2 and 1 in every other case.

    With ``score`` falsy, a row caught by the lower_threshold rule is
    returned without reading its ``score`` attribute at all.

    NOTE(review): score == C2 (and a NaN score) therefore yields 1, even
    though such a row fails the ``score > C2`` test used by both
    malingering rules and by validate's derived dx column -- confirm that
    this boundary is intended.
    '''
    # Decided once, when the classifier is built.
    low_rule_needs_score = bool(score)

    def malinger(row):
        # Rule 1: a low lower_threshold flags malingering, gated on the
        # score cutoff only when the score is considered trustworthy.
        if row.lower_threshold < C0:
            if not low_rule_needs_score or row.score > C2:
                return -1
        # Rule 2: high veritas together with an above-cutoff score.
        if row.veritas > C1 and row.score > C2:
            return -1
        return 0 if row.score < C2 else 1

    return malinger
def validate(response_dataframe, C0, C1, C2, DX=True, score=True, plots=True, verbose=True, validation_type='withdx', outfile='report.png')

The response dataframe should have one row per subject (e.g. sub1, sub2) and the columns lower_threshold, veritas and score. DX=True implies a dx column is present in the input; score=False implies the score is faked or absent (no dx subgroup was available in training).

Expand source code
def _plot_threshold_panels(response_dataframe, C0, C1, C2):
    '''Draw the four panels shared by the "withdx" and "fnrexpt" reports.

    Fills subplots 231-234 of the current matplotlib figure. Expects the
    lower_threshold, veritas, score, dx and mg columns; red lines mark the
    cutoffs C0 (lower_threshold), C1 (veritas) and C2 (score).
    '''
    plt.subplot(231)
    sns.scatterplot(data=response_dataframe, x='lower_threshold', y='veritas',
                    hue='mg', size='dx')
    # The y-axis is veritas, so its cutoff line is drawn at C1 (it used to be
    # a fixed 0.76 that ignored the C1 actually passed in).
    plt.plot([.5, 2.5], [C1, C1], '-r')
    plt.plot([C0, C0], [.5, .95], '-r')

    plt.subplot(232)
    sns.scatterplot(data=response_dataframe, x='score', y='veritas', hue='dx')
    plt.plot([.2, 2.5], [C1, C1], '-r')
    plt.plot([C2, C2], [.5, .95], '-r')

    plt.subplot(233)
    sns.scatterplot(data=response_dataframe, x='score', y='lower_threshold',
                    hue='dx')
    plt.subplots_adjust(wspace=0.23)

    plt.subplot(234)
    sns.heatmap(response_dataframe.corr(), cmap='jet', alpha=.5)


def _hide_axes(ax):
    '''Strip ticks and frame so ``ax`` serves as a plain text panel.'''
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_frame_on(False)


def validate(response_dataframe, C0, C1, C2,
             DX=True, score=True, plots=True, verbose=True,
             validation_type='withdx', outfile='report.png',
             total_samples=None, positive_samples=None):
    '''Classify subjects with the malingering rule and summarise the result.

    ``response_dataframe`` should have one row per subject and the columns
    lower_threshold, veritas and score, e.g.::

               lower_threshold veritas score
        sub1         x            x       x
        sub2         x            x       x

    The frame is modified in place. An ``mg`` column is always added, holding
    the output of ``get_malinger_func(C0, C1, C2, score=score)``: -1
    malingering, 0 no dx, 1 true dx. A ``dx`` column is also added when
    ``DX`` is False.

    Parameters
    ----------
    C0, C1, C2 : float
        Cutoffs on lower_threshold, veritas and score (see get_malinger_func).
    DX : bool
        True if the input already has a ground-truth ``dx`` column. When
        False, ``dx`` is derived as ``score > C2``. When True, a ROC curve of
        score against dx is built and processed with zedstat.
    score : bool
        False means the score is faked or absent (no dx subgroup was
        available in training); forwarded to get_malinger_func.
    plots : bool
        If True, draw the report figure and save it to ``outfile``.
    verbose : bool
        If True (and DX is True), print the last 10 rows of the zedstat
        table with ppv > 0.875.
    validation_type : {'withdx', 'fnrexpt', 'noscore'}
        - 'withdx': mratio (fraction of dx == 1 rows with mg == -1) and the
          zedstat AUC. Requires DX=True.
        - 'fnrexpt': fnr, the fraction of all rows with mg == 1.
        - 'noscore': mrate, the fraction of all rows with mg == -1.
    outfile : str or path-like
        Where the figure is saved (300 dpi, transparent) when ``plots``.
    total_samples, positive_samples : int, optional
        Sample sizes passed to ``zedstat.processRoc``. By default they are
        the number of rows and the number of dx == 1 rows of the input.

    Returns
    -------
    'withdx':  ({'auc': zt.auc(), 'mratio': mratio}, response_dataframe, zt)
    'fnrexpt': ({'fnr': fnr}, response_dataframe)
    'noscore': ({'mrate': mrate}, response_dataframe)

    Raises
    ------
    ValueError
        If ``validation_type`` is not one of the supported values, or if
        'withdx' is requested with DX=False (no ROC is built in that case).
    '''
    if validation_type not in ('withdx', 'fnrexpt', 'noscore'):
        raise ValueError("validation_type must be 'withdx', 'fnrexpt' or "
                         f"'noscore', got {validation_type!r}")
    if validation_type == 'withdx' and not DX:
        raise ValueError("validation_type='withdx' requires a real dx column "
                         "(DX=True)")

    malinger = get_malinger_func(C0, C1, C2, score=score)
    if not DX:
        response_dataframe['dx'] = [x > C2 for x in response_dataframe.score.values]

    response_dataframe['mg'] = response_dataframe.apply(malinger, axis=1)

    if DX:
        y_true = response_dataframe.dx.values.astype(int)
        fpr, tpr, thresholds = metrics.roc_curve(y_true,
                                                 response_dataframe.score.values.astype(float),
                                                 pos_label=1)
        ff = pd.DataFrame(tpr, fpr, columns=['tpr']).assign(threshold=thresholds)
        ff.index.name = 'fpr'
        # Sample sizes come from the data actually scored (they were
        # hard-coded to 304 / 86 before) unless the caller overrides them.
        n_total = len(y_true) if total_samples is None else total_samples
        n_positive = int(y_true.sum()) if positive_samples is None else positive_samples
        zt = zedstat.processRoc(df=ff.reset_index(),
                                order=3,
                                total_samples=n_total,
                                positive_samples=n_positive,
                                alpha=0.01,
                                prevalence=0.5)
        zt.smooth(STEP=0.001)
        zt.allmeasures(interpolate=True)
        zt.usample(precision=3)
        Z = zt.get()
        if verbose:
            print(Z[Z.ppv > .875].tail(10))

    if validation_type == 'withdx':
        flagged_dx = (response_dataframe.mg == -1) & (response_dataframe.dx == 1)
        mratio = response_dataframe[flagged_dx].index.size / response_dataframe.dx.sum()
        fullauc = zt.auc()
        result = ({'auc': fullauc, 'mratio': mratio}, response_dataframe, zt)

        if plots:
            plt.figure(figsize=[20, 12])
            _plot_threshold_panels(response_dataframe, C0, C1, C2)

            plt.subplot(235)
            plt.plot(fpr, tpr, 'g', lw=2)
            plt.gca().legend(['R20'])
            zt.get().tpr.plot(style='-b', lw=2)

            ax = plt.subplot(236)
            ax.text(0.5, 0.6, f'malinger prevalence in DX: {mratio:.2f}',
                    fontsize=16, ha='center')
            # Raw f-string: '\p' in a plain literal is an invalid escape.
            ax.text(0.5, 0.4,
                    rf'AUC: {fullauc[0]:.2f} $\pm$ {fullauc[1]-fullauc[0]:.2f}',
                    fontsize=16, ha='center')
            _hide_axes(ax)

    elif validation_type == 'fnrexpt':
        fnr = response_dataframe[response_dataframe.mg == 1].index.size / response_dataframe.index.size
        result = ({'fnr': fnr}, response_dataframe)

        if plots:
            plt.figure(figsize=[20, 12])
            _plot_threshold_panels(response_dataframe, C0, C1, C2)

            ax = plt.subplot(236)
            ax.text(0.5, 0.6, f'FNR in EXPT: {fnr:.2f}', fontsize=16, ha='center')
            _hide_axes(ax)

    else:  # 'noscore'
        mrate = response_dataframe[response_dataframe.mg == -1].index.size / response_dataframe.index.size
        result = ({'mrate': mrate}, response_dataframe)

        if plots:
            plt.figure(figsize=[8, 8])
            plt.subplot(111)
            sns.scatterplot(data=response_dataframe, x='lower_threshold',
                            y='veritas', hue='mg')
            plt.plot([.1, 2.5], [C1, C1], '-r')
            plt.plot([C0, C0], [.1, .95], '-r')
            plt.plot([C2, C2], [.1, .95], '-r')

            ax = plt.gca()
            ax.text(0.65, 0.8, f'mrate: {mrate:.2f}', fontsize=16, ha='center')

    # Save once, after whichever branch drew the figure. In the old code this
    # call sat after every return, so outfile was never written.
    if plots:
        plt.savefig(outfile, dpi=300, bbox_inches='tight', transparent=True)
    return result