From 76cff6f6f0bc39115b5ea692643652454019865c Mon Sep 17 00:00:00 2001
From: Marcus Herrmann <marcus.herrmann@sed.ethz.ch>
Date: Thu, 10 Nov 2016 18:36:19 +0100
Subject: [PATCH] Improve amp-mag regression

- using statsmodels for regression
     --> better statistical analysis
     --> enable Weighted LS (if desired)
            (plotted by scaling scatter plot size)
- plotting confidence + prediction intervals
- plot a line with slope = 1
- improve residual plot
---
 TM/eventops.py |  41 +++++++-----
 TM/plotting.py | 178 ++++++++++++++++++++++++++++++++++---------------
 version        |   2 +-
 3 files changed, 151 insertions(+), 70 deletions(-)

diff --git a/TM/eventops.py b/TM/eventops.py
index 4ad165a..61e5d1b 100644
--- a/TM/eventops.py
+++ b/TM/eventops.py
@@ -24,6 +24,7 @@ from itertools import repeat
 import math
 import scipy
 import numpy as np
+import statsmodels.api as sm
 
 # ObsPy
 from obspy import UTCDateTime
@@ -316,7 +317,7 @@ def correct_amplitude_attenuation(events, eventType):
 #    return logLike
 
 
-def regressMagnitude(foundEvents, catalog, magRange):
+def regressMagnitude(foundEvents, catalog, magRange, w_base=1):
     """
     Determines the magnitde (ML) for the detections (foundEvents) based on a
     regression b/w the catalogs' amplitudes and magnitudes. This relation is
@@ -344,7 +345,6 @@ def regressMagnitude(foundEvents, catalog, magRange):
     if len(catalog[0]) > 4:
         raise Warning("Catalog has to much values. It should only have ID, time, Mag, amp.")
 
-
     # Deepcopy catalog + events to not touch original
     catalog = copy.deepcopy(catalog)
     foundEvents = copy.deepcopy(foundEvents)
@@ -362,21 +362,28 @@ def regressMagnitude(foundEvents, catalog, magRange):
     catalog_ML = np.asarray([x[2] for x in catalog_minMag])
 
     # Do regression using Linear Least Squares of log10-amp
-    b, a, r_value, p_value, std_err = scipy.stats.linregress(np.log10(catalog_A), catalog_ML)
-    print("\nLS-regression: ML = %.2f * log(A) + %.2f" % (b, a))
-
-#    # Do regression using maximum likelihood
-#    nll = lambda *args: -loglikelihoodRegress(*args)
-#    initParams = [b, a, std_err]
-#    initParams = [1, 1, 1]
-##    catalog_XY = np.sort(np.column_stack((catalog_logA, catalog_ML)), axis=0)
-#    result = scipy.optimize.minimize(nll, initParams, args=(catalog_XY[:,0], catalog_XY[:,1]),
-#                                     method='nelder-mead')
-#    print(result)
-#    b_maxL, a_maxL, sd_maxL = result["x"]
-
-    plotting.plot_amp_regression(catalog_ML, catalog_A,
-                                 regLineInfos=[b, a, r_value, p_value, std_err, ''])
+#    b, a, r2, _, std_err = scipy.stats.linregress(np.log10(catalog_A), catalog_ML)
+
+    # Regression using statsmodels - weighting possible
+    # 1. Define the weights
+    # w_pot = catalog_A
+    w_pot = catalog_ML
+    weights = np.ones(len(catalog_A)) * w_base**w_pot
+    weights /= max(weights)
+    # 2. Fit a linear line
+    model = sm.WLS(catalog_ML, sm.add_constant(np.log10(catalog_A)), weights=weights)
+    results = model.fit()
+    # Get results & statistics
+    a, b = results.params
+    std_errs = results.bse
+    r2 = results.rsquared
+
+    LStype = 'OLS' if w_base == 1 else 'WLS'
+    print("\n%s-regression: ML = %.2f * log(A) + %.2f" % (LStype, b, a))
+
+    # Plot the regression line
+    regLineInfos = [b, a, r2, std_errs, w_base, '']
+    plotting.plot_amp_regression(catalog_ML, catalog_A, regLineInfos, weights)
 
     # Update foundEvents with REGRESSION ML
     [x.insert(2, float("{0:.2f}".format(b*math.log10(x[1]) + a))) for x in foundEvents]
diff --git a/TM/plotting.py b/TM/plotting.py
index 6fc7330..dae965a 100644
--- a/TM/plotting.py
+++ b/TM/plotting.py
@@ -723,7 +723,8 @@ def plotCHlocationOverview(directory):
     subprocess.call([command] + strArgs, env=environ, stderr=subprocess.STDOUT)
 
 
-def plot_amp_regression(magSet, ampSet, regLineInfos=None, legendPos='upper left'):
+def plot_amp_regression(magSet, ampSet, regLineInfos=None,
+                        weights=None, legendPos='upper left'):
     """
     Plots the regression for given magnitude <-> amplitude pairs.
 
@@ -734,7 +735,7 @@ def plot_amp_regression(magSet, ampSet, regLineInfos=None, legendPos='upper left
     if isinstance(magSet, list) or isinstance(ampSet, list):
         numSets = len(magSet)
         if len(ampSet) != numSets:
-            raise Warning("Number of magnitude and amplitde sets doesn't agree.")
+            raise Warning("Number of magnitude and amplitude sets doesn't agree.")
     else:
         numSets = 1
         # Pack the arrays into a list (with one entry)
@@ -764,15 +765,28 @@ def plot_amp_regression(magSet, ampSet, regLineInfos=None, legendPos='upper left
         plt.ylabel("ML (SED)")
 
         # Variant 1: scatter with density coloring (only useful for ONE set)
-        if numSets == 1 and len(mags) > 100:
-            # Calculate the point density
-            xy = np.vstack([np.log10(amps), mags])
-            density = scipy.stats.gaussian_kde(xy)(xy)
-            # Sort points by density; densest points are plotted last
-            idx = density.argsort()
+        if numSets == 1:
+
+            if len(mags) > 100:
+                # Calculate the point density
+                xy = np.vstack([np.log10(amps), mags])
+                density = scipy.stats.gaussian_kde(xy)(xy)
+                # Sort points by density; densest points are plotted last
+                idx = density.argsort()
+            else:
+                density = np.ones(len(amps))
+                idx = np.arange(0, len(amps))
+
+            if weights is None:
+                weights = np.ones(len(amps))  # (no weigths)
+
+            weights /= np.max(weights)  # Make sure that it ranges b/w [0...1]
+            weights *= 25/np.mean(weights)  # Scale for plot
+
             # Plot
-            plt.scatter(amps[idx], mags[idx], s=25, alpha=0.5,
-                        c=density[idx], cmap=plt.get_cmap('gist_heat'), edgecolor='')
+            plt.scatter(amps[idx], mags[idx], s=weights[idx],
+                        alpha=0.5, c=density[idx], cmap=plt.get_cmap('gist_heat'),
+                        edgecolor='')
             plt.gca().set_xscale('log')
 
         # otherwise Variant 2: simple one-colored scatter (per set)
@@ -780,12 +794,14 @@ def plot_amp_regression(magSet, ampSet, regLineInfos=None, legendPos='upper left
             plt.semilogx(amps, mags, '.', c=colors[iSet])
 
     # min/max among all sets
-    minAmp_log10 = math.log10(min([min(x) for x in ampSet]))
-    maxAmp_log10 = math.log10(max([max(x) for x in ampSet]))
-#    minAmp_log10 = math.log10(1e4)
-#    maxAmp_log10 = math.log10(2e7)
+    minAmp_log10 = np.log10(min([min(x) for x in ampSet]))
+    maxAmp_log10 = np.log10(max([max(x) for x in ampSet]))
+#    minAmp_log10 = np.log10(5e3)
+#    maxAmp_log10 = np.log10(5e6)
     minMag = min([min(x) for x in magSet])
     maxMag = max([max(x) for x in magSet])
+#    minMag = 0.7
+#    maxMag = 3.3
 
     # Plot (possible) regression line(s) + infos
     if regLineInfos:
@@ -793,66 +809,124 @@ def plot_amp_regression(magSet, ampSet, regLineInfos=None, legendPos='upper left
         if not isinstance(regLineInfos[0], list):
             regLineInfos = [regLineInfos]  # Pack into a list
 
+        # b=1 line
+        x_b1 = np.linspace(minAmp_log10*0.9, maxAmp_log10*1.5, 3)
+        a_n_b = np.mean(np.array([x[:2] for x in regLineInfos]), 0)
+        x_mean = np.mean([np.mean(np.log10(x)) for x in ampSet])
+        plt.plot(10**x_b1, a_n_b[1] + (a_n_b[0]-1)*x_mean + 1*x_b1, ':', c='k', lw=1,
+                 label='slope = 1')
+
         for idx, regLine in enumerate(regLineInfos):
             # Unpack values
-            b, a, r_value, p_value, std_err, label = regLine
+            b, a, r2, std_errs, w_base, label = regLine
+            x = np.log10(ampSet[idx])
+
             # Create x vector
-            if len(regLineInfos) == numSets:
-                x_plot = np.array([math.log10(min(ampSet[idx]))-0.1,
-                                   math.log10(max(ampSet[idx]))+0.1])
-            else:
-                x_plot = np.array([minAmp_log10-0.1, maxAmp_log10+0.1])
-            # Plot
-            plt.plot(10**x_plot, a + b*x_plot, '-', c=colors[idx], lw=1,
+            x_plot = np.linspace(min(x), max(x), 100)
+
+            y_plot = a + b*x_plot
+
+            # Plot the regression line
+            w_label = '\nevent_weight: %d$^{ML}$' % w_base if w_base != 1 else ''
+            plt.plot(10**x_plot, y_plot, '-', c=colors[idx], lw=1,
                      label=('%s\n' % label +
                             'ML = %.2f * log(A) + %.2f' % (b, a) +
                             '\n$R^2$: %.4f    $\sigma_{slope}$: %.4f' %
-                            (r_value**2, std_err)))
-            label += '_'
-        #    plt.plot(10**x_plot, a_maxL + b_maxL*x_plot, 'r-', lw=1)
-
-    #    plt.axis('scaled')  # Doesn't work with semilog axis
-    plt.axis([10**(minAmp_log10-0.2), 10**(maxAmp_log10+0.2),
-              minMag-0.2, maxMag+0.2])
-#    plt.axis([10**(4.9893-0.2), 10**(7.0858+0.2), 0.7-0.2, 3.4+0.2])  # SED events for OTER2
-#    plt.axis([10**(3.8516-0.2), 10**(6.7256+0.2), 0.7-0.2, 3.4+0.2])  # SED events for MATTE
-
-    # Legend
-    l = plt.legend(loc=legendPos, fontsize=legFontSize, fancybox=True)
-    # Color even the text of each entry with the appropriate color
-    for idx, legTexts in enumerate(l.get_texts()):
-        legTexts.set_color(colors[idx])
-
-    plt.savefig(config.directory + "/regression_mag-amp_%s%devents.png" %
-                (label, sum(len(x) for x in magSet)), dpi=150, bbox_inches='tight')
-#    plt.savefig(config.directory + "/regression_mag-amp_%devents.pdf" % len(catalog_ML),
-#                dpi=150, bbox_inches='tight')
-    # Determine current time as ID... to not overwrite possible previous file
-#    crtTime = time.strftime("%Y-%m-%dT%H%M%S", time.localtime())
-#    plt.savefig("plots/regression_" + crtTime + ".png", dpi=150)
-    plt.close()
+                            (r2, std_errs[1]) + w_label))
+            label += '_' if label != '' else ''
+
+            # Plot confidence + prediction interval
+            resids = magSet[idx] - (a + b*x)
+            ssr = np.sum(resids**2)
+            x_mean = np.log10(ampSet[idx]).mean()
+            n = len(ampSet[idx])
+            dof = n - 2
+            s = np.sqrt(ssr/dof)
+            t = scipy.stats.t.ppf(1-0.05/2, df=dof)  # Using t-student's distribution
+
+            conf = t * s * np.sqrt(1.0/n +
+                                   (x_plot-x_mean)**2 / (np.sum((x_plot-x_mean)**2)))
+            upper_confI = y_plot + abs(conf)
+            lower_confI = y_plot - abs(conf)
+
+            pred = t * s * np.sqrt(1 + 1.0/n +
+                                   (x_plot-x_mean)**2 / (sum(x**2) - 1/n * sum(x**2)))
+
+            upper_predI = y_plot + abs(pred)
+            lower_predI = y_plot - abs(pred)
+
+            if idx == len(regLineInfos)-1:
+                label_CI = '95% confidence interval'
+                label_PI = '95% prediction interval'
+            else:
+                label_CI = ''
+                label_PI = ''
+            plt.fill_between(10**x_plot, lower_confI, upper_confI,
+                             color=colors[idx], alpha=0.3, label=label_CI)
+            plt.fill_between(10**x_plot, lower_predI, upper_predI,
+                             color=colors[idx], alpha=0.1, label=label_PI)
+
+        #    plt.axis('scaled')  # Doesn't work with semilog axis
+        plt.axis([10**minAmp_log10*0.8, 10**maxAmp_log10*1.5,
+                  minMag-0.2, maxMag+0.2])
+
+        # Legend
+        # Put the b=1 line at the end
+        handles, labels = plt.gca().get_legend_handles_labels()
+        handles.append(handles.pop(0))
+        labels.append(labels.pop(0))
+
+        l = plt.legend(handles, labels, loc=legendPos,
+                       fontsize=legFontSize, fancybox=True)
+        # Color even the text of each entry with the appropriate color
+        for idx, legTexts in enumerate(l.get_texts()):
+            if idx < len(regLineInfos):  # To allow plotting label of CI & PI
+                legTexts.set_color(colors[idx])
+
+        if numSets > 1:
+            label = 'ALLclusters_'
+
+        plt.savefig(config.directory + "/regression_mag-amp_%s%devents_w%d.png" %
+                    (label, sum([len(x) for x in magSet]), w_base),
+                    dpi=150, bbox_inches='tight')
+#        plt.savefig(config.directory + "/regression_mag-amp_%devents.pdf" % len(catalog_ML),
+#                    dpi=150, bbox_inches='tight')
+        # Determine current time as ID... to not overwrite possible previous file
+#        crtTime = time.strftime("%Y-%m-%dT%H%M%S", time.localtime())
+#        plt.savefig("plots/regression_" + crtTime + ".png", dpi=150)
+        plt.close()
 
     # Plot residuals
     if numSets == 1:
         amps = ampSet[0]
         mags = magSet[0]
 
-        resis = mags - a - b*np.log10(amps)
-        plt.plot(amps, resis, 'k.')
+        resids = mags - (a + b*np.log10(amps))
+        resid_ssr = np.sum(resids**2)
+        resids_MSE = np.sum(resids**2)/len(amps)
+
+        plt.scatter(amps, resids, s=weights, alpha=0.5, color='gray')
         plt.axhline(color='k', linestyle='--')
+        plt.fill_between(10**x_plot, abs(pred), -abs(pred),
+                         color=colors[0], alpha=0.1, label=label_PI)
+
         plt.gca().set_xscale('log')
 
-        plt.annotate('sum(residuals²): %.2f' % sum(resis**2),
+        plt.annotate('sum(residuals²): %.2f' % resid_ssr,
                      xy=(0, 1), xycoords='axes fraction', fontsize=12,
                      xytext=(4, 4), textcoords='offset points',
                      ha='left', va='bottom', bbox=dict(boxstyle="square", fc="w"))
+        plt.annotate('MSE: %.3f' % resids_MSE,
+                     xy=(1, 1), xycoords='axes fraction', fontsize=12,
+                     xytext=(-4, 4), textcoords='offset points',
+                     ha='right', va='bottom', bbox=dict(boxstyle="square", fc="w"))
 
         plt.xlabel("max. amplitude @%s" % (config.station))
         plt.ylabel("residuals")
-        plt.axis([10**(minAmp_log10-0.2), 10**(maxAmp_log10+0.2), -1, 1])
+        plt.axis([10**minAmp_log10*0.8, 10**maxAmp_log10*1.5, -1, 1])
 
-        plt.savefig(config.directory + "/regression_mag-amp_resid_%s%devents.png" %
-                    (label, len(mags)), dpi=150, bbox_inches='tight')
+        plt.savefig(config.directory + "/regression_mag-amp_resid_%s%devents_w%d.png" %
+                    (label, len(mags), w_base), dpi=150, bbox_inches='tight')
         plt.close()
 
 
diff --git a/version b/version
index 17e51c3..d917d3e 100644
--- a/version
+++ b/version
@@ -1 +1 @@
-0.1.1
+0.1.2
-- 
GitLab