# File: Jiale/test2_ort/ROC_calculate.py
# (listing metadata: 79 lines, 2.5 KiB, Python)

# -*- coding: utf-8 -*-
# ref: https://github.com/ilyajob05/ROC_calculation
import numpy as np
import matplotlib.pyplot as plt
import sys
class FaceROCTooler:
    """Compute ROC statistics from similarity scores and binary labels.

    ``sims`` holds one score per sample and ``ids`` the matching label,
    where 1 marks a positive sample and 0 a negative one. A sample is
    predicted positive when its score is strictly greater than the threshold.

    After :meth:`getCurve` has run, ``curveROC`` is a 2-D float array with one
    row per evaluated threshold and the columns:
    index 0 - threshold, 1 - TPR, 2 - FPR, 3 - accuracy.
    """

    def __init__(self, sims, ids):
        """Store scores/labels and precompute the ground-truth masks.

        Args:
            sims: array-like of similarity scores.
            ids: array-like of 0/1 labels, same shape as ``sims``.

        Raises:
            ValueError: if ``sims`` and ``ids`` differ in shape, or ``ids``
                contains values other than 0 and 1.
        """
        # np.asarray lets plain Python lists work too; with a list,
        # ``ids == 1`` would be a single bool rather than an element-wise mask.
        self.sims = np.asarray(sims)
        self.ids = np.asarray(ids)
        if self.sims.shape != self.ids.shape:
            raise ValueError(
                f"sims and ids must have the same shape, got "
                f"{self.sims.shape} and {self.ids.shape}"
            )
        # index 0 - threshold, 1 - TPR, 2 - FPR, 3 - accuracy
        self.curveROC = np.empty([0, 4])
        self.real_posi = self.ids == 1
        self.real_nega = self.ids == 0
        self.P = np.count_nonzero(self.real_posi)
        self.N = np.count_nonzero(self.real_nega)
        # Every label must be either positive (1) or negative (0).
        if self.P + self.N != self.ids.size:
            raise ValueError("ids must contain only 0 (negative) and 1 (positive) labels")

    def getTPRFPR_acc(self, threshold):
        """Evaluate the binary decision ``sims > threshold``.

        Args:
            threshold: score above which a sample is predicted positive.

        Returns:
            Tuple ``(threshold, TPR, FPR, accuracy)``. TPR is NaN when there
            are no positive labels, FPR is NaN when there are no negative
            labels, and accuracy is NaN for empty input.
        """
        pre_posi = self.sims > threshold
        pre_nega = self.sims <= threshold
        TP = np.count_nonzero(pre_posi & self.real_posi)
        FP = np.count_nonzero(pre_posi & self.real_nega)
        TN = np.count_nonzero(pre_nega & self.real_nega)
        total = self.P + self.N
        # Explicit NaN instead of a 0/0 division that emits RuntimeWarnings.
        tpr = TP / self.P if self.P else float("nan")
        fpr = FP / self.N if self.N else float("nan")
        acc = (TP + TN) / total if total else float("nan")
        return threshold, tpr, fpr, acc

    def getCurve(self, step, start=0.0, stop=1.0):
        """Sample the ROC curve at evenly spaced thresholds.

        Thresholds are ``start + step * i`` for ``i`` in
        ``range(int((stop - start) / step))``, followed by ``stop`` itself.
        The result is stored in ``self.curveROC``.

        Args:
            step: spacing between thresholds; must be positive.
            start: first threshold (default 0.0).
            stop: last threshold, always evaluated (default 1.0).

        Raises:
            ValueError: if ``step`` is not positive or ``stop < start``.
        """
        if step <= 0:
            raise ValueError("step must be positive")
        if stop < start:
            raise ValueError("stop must not be smaller than start")
        n_steps = int((stop - start) / step)
        rows = [self.getTPRFPR_acc(start + step * i) for i in range(n_steps)]
        # The loop stops short of ``stop``, so evaluate the upper end explicitly.
        rows.append(self.getTPRFPR_acc(stop))
        self.curveROC = np.array(rows)

    def _require_curve(self):
        """Raise if :meth:`getCurve` has not populated ``curveROC`` yet."""
        if self.curveROC.shape[0] == 0:
            raise RuntimeError("call getCurve() before querying the ROC curve")

    def findFPR(self, fpr):
        """Return the curve row whose FPR is closest to ``fpr``.

        Ties resolve to the first such row, i.e. the lowest threshold.

        Raises:
            RuntimeError: if :meth:`getCurve` has not been called.
        """
        self._require_curve()
        idx = np.argmin(np.abs(self.curveROC[:, 2] - fpr))
        return self.curveROC[idx]

    def getbalance(self):
        """Return ``(index, row)`` of the threshold maximizing TPR - FPR.

        TPR - FPR is Youden's J statistic; ties resolve to the first row.

        Raises:
            RuntimeError: if :meth:`getCurve` has not been called.
        """
        self._require_curve()
        gaps = self.curveROC[:, 1] - self.curveROC[:, 2]  # TPR - FPR
        maxidx = np.argmax(gaps)
        return maxidx, self.curveROC[maxidx]

    def getMaxAcc(self):
        """Return ``(index, row)`` of the threshold with the highest accuracy.

        Ties resolve to the first row.

        Raises:
            RuntimeError: if :meth:`getCurve` has not been called.
        """
        self._require_curve()
        accs = self.curveROC[:, 3]
        maxidx = np.argmax(accs)
        return maxidx, self.curveROC[maxidx]
if __name__ == '__main__':
    # usage: python ROC_calculate.py SIM_FILE ID_FILE
    # Both files are read with np.loadtxt (text): scores as float32, labels as int.
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} SIM_FILE ID_FILE")
    # load data
    sims = np.loadtxt(sys.argv[1], dtype=np.float32)
    ids = np.loadtxt(sys.argv[2], dtype=int)
    roc = FaceROCTooler(sims, ids)
    roc.getCurve(1e-3)
    results = (
        roc.findFPR(0.01),      # operating point whose FPR is closest to 1%
        roc.getMaxAcc()[1],     # threshold with the highest accuracy
        roc.getbalance()[1],    # threshold maximizing TPR - FPR
    )
    for row in results:
        print(f"threshold:{row[0]} , TPR: {row[1]} , FPR: {row[2]}, acc: {row[3]}")