import numpy as np
import pandas as pd
import matplotlib.pylab as plt
from matplotlib.pylab import rcParams
rcParams['figure.figsize'] = (10, 8)
from statsmodels.tsa.stattools import adfuller
from statsmodels.tsa.stattools import acf, pacf
from statsmodels.tsa.arima_model import ARIMA


def load_dataset(path='data/air-passengers.csv'):
    """Load a CSV time series and index it by its 'Month' column.

    Args:
        path: Location of the CSV file. Defaults to the air-passengers
            file this script has always read, so existing zero-argument
            calls behave as before.

    Returns:
        A DataFrame whose index is the 'Month' column parsed to datetime;
        all remaining columns are kept as read by pandas.

    Raises:
        KeyError: If the file has no 'Month' column.
    """
    dataset = pd.read_csv(path)
    # Parse the month strings to datetime. The former
    # infer_datetime_format=True flag is deprecated in pandas >= 2.0 (format
    # inference is the default there) and only affected parsing speed on
    # older versions, so it is omitted.
    dataset['Month'] = pd.to_datetime(dataset['Month'])
    return dataset.set_index(['Month'])

def plot_acf_dcf(series=None, nlags=20):
    """Plot the ACF and PACF of a series side by side.

    Each panel shows the correlation values together with a zero line and
    the +/-1.96/sqrt(N) bounds (the usual approximate 95% band), where N is
    the number of observations in the series.

    Args:
        series: Data to analyse. When omitted, the module-level
            ``dataset_log_first_order_diff`` is used, which keeps the
            original zero-argument call working.
        nlags: Number of lags passed to ``acf`` and ``pacf``.
    """
    if series is None:
        # Backward-compatible fallback to the global prepared by the script.
        series = dataset_log_first_order_diff

    lag_acf = acf(series, nlags=nlags)
    lag_pacf = pacf(series, nlags=nlags, method='ols')

    # The same bound is drawn on both panels, so compute it once.
    bound = 1.96 / np.sqrt(len(series))

    panels = ((121, lag_acf, 'ACF'), (122, lag_pacf, 'PACF'))
    for position, values, title in panels:
        plt.subplot(position)
        plt.plot(values)
        plt.axhline(y=0, linestyle='--', color='gray')
        plt.axhline(y=-bound, linestyle='--', color='gray')
        plt.axhline(y=bound, linestyle='--', color='gray')
        plt.title(title)

    plt.tight_layout()
    plt.show()

def model_ARIMA(indexed_data_log_scale, dataset_log_first_order_diff):
    """Fit an ARIMA(2, 1, 2) model and plot its fit against the differences.

    Args:
        indexed_data_log_scale: Series the model is fitted on.
        dataset_log_first_order_diff: Differenced series drawn alongside the
            fitted values; its '#Passengers' column is used for the RSS.

    Returns:
        The fitted results object returned by ``ARIMA.fit``.
    """
    arima = ARIMA(indexed_data_log_scale, order=(2, 1, 2))
    results_ARIMA = arima.fit(disp=-1)
    fitted = results_ARIMA.fittedvalues

    # Differenced data in the default colour, fitted values in red.
    plt.plot(dataset_log_first_order_diff)
    plt.plot(fitted, color='red')

    # Residual sum of squares between the fit and the differenced series.
    squared_residuals = (fitted - dataset_log_first_order_diff['#Passengers'])**2
    plt.title('RSS: %.4f'% sum(squared_residuals))
    plt.show()
    return results_ARIMA

def test_stationarity(timeseries):
    """Plot rolling statistics and print an augmented Dickey-Fuller report.

    Draws the series together with its 12-observation rolling mean and
    rolling standard deviation, then runs ``adfuller`` on the
    '#Passengers' column and prints the labelled results.

    Args:
        timeseries: DataFrame with a '#Passengers' column.
    """
    # Determining rolling statistics
    rolling_mean = timeseries.rolling(window=12).mean()
    rolling_std = timeseries.rolling(window=12).std()

    # Plot rolling statistics
    plt.plot(timeseries, color='blue', label='Original')
    plt.plot(rolling_mean, color='red', label='Rolling Mean')
    plt.plot(rolling_std, color='black', label='Rolling Std')
    plt.legend(loc='best')
    plt.title('Rolling Mean & Standard Deviation')
    plt.show()

    # Single-argument print(...) is valid on both Python 2 and Python 3.
    print('Result of Dickey-Fuller Test')
    dftest = adfuller(timeseries['#Passengers'], autolag='AIC')
    # The first four entries are scalars; the fifth is a dict of critical
    # values keyed by significance level.
    dfoutput = pd.Series(
        dftest[0:4],
        index=['Test Statistic', 'p-value', '#Lags Used', '#Observations Used'],
    )
    for key, value in dftest[4].items():
        dfoutput['Critical Value (%s)' % key] = value
    print(dfoutput)


# Load the series and move to a log scale before differencing.
indexed_data = load_dataset()
indexed_data_log_scale = np.log(indexed_data)

# First-order difference of the log series; the first row has no predecessor
# and becomes NaN, so it is dropped.
dataset_log_first_order_diff = indexed_data_log_scale - indexed_data_log_scale.shift()
dataset_log_first_order_diff.dropna(inplace=True)

test_stationarity(dataset_log_first_order_diff)

plot_acf_dcf()

results_ARIMA = model_ARIMA(indexed_data_log_scale, dataset_log_first_order_diff)
results_ARIMA.plot_predict(1,204)
plt.show()

future = results_ARIMA.forecast(steps=60)
# The legacy ARIMAResults.forecast returns (forecast, stderr, conf_int).
# Index 0 holds the point forecasts on the log scale; the previous code
# exponentiated index 1, i.e. the standard errors.
# NOTE(review): confirm against the installed statsmodels version.
print(np.exp(future[0]))
