# Import all the packages needed to build a classifier or a regressor
# that inherits the methods of scikit-learn classifiers / regressors.
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from Helper_SVM_students import *



class SVM_custom(ClassifierMixin, BaseEstimator):
    """Binary soft-margin SVM trained through the helpers in Helper_SVM_students.

    Labels are assumed to be -1 / +1: ``y`` is forwarded unchanged to the
    optimization helpers and ``predict`` always returns -1.0 or +1.0.

    Parameters
    ----------
    C : float
        Soft-margin penalty, forwarded to the optimization helpers.
    opt_strat : str
        ``'primal'`` solves the primal problem (the kernel hyper-parameters are
        then ignored); any other value solves the dual problem.
    kernel : str
        Kernel of the dual problem: ``'linear'`` or ``'poly'`` / ``'polynomial'``
        (whichever spelling ``svm_dual_optimization`` expects).
    degree : int
        Degree of the polynomial kernel.
    coef0 : float
        Independent term of the polynomial kernel.
    """

    # Kernels this class can evaluate at prediction time (see _kernel).
    _SUPPORTED_KERNELS = ('linear', 'poly', 'polynomial')

    def __init__(self, C=1.0, opt_strat='primal', kernel='linear', degree=2, coef0=0.0):
        # scikit-learn convention: only store the hyper-parameters here, so that
        # get_params / set_params / clone keep working.
        self.C = C
        self.opt_strat = opt_strat
        self.kernel = kernel
        self.degree = degree
        self.coef0 = coef0

    def fit(self, X, y):
        """Train the model and store every attribute needed by ``predict``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        y : array-like of shape (n_samples,)
            Labels, assumed to be -1 / +1.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            With the dual strategy, if the kernel is not one this class can
            evaluate, or if the dual solution contains no support vector.
        """
        X = np.asarray(X, dtype=float)
        y = np.asarray(y)

        if self.opt_strat == 'primal':
            w, b, _ = svm_primal_optimization(X, y, self.C)
            # Flatten in case the solver returns a column vector / 1x1 array.
            self.w = np.asarray(w, dtype=float).ravel()
            self.b = float(np.squeeze(b))
        else:
            self._fit_dual(X, y)

        self.X_ = X
        self.y_ = y
        self.is_fitted_ = True
        self.classes_ = np.unique(y)
        return self

    def _fit_dual(self, X, y):
        """Solve the dual problem and set alpha, w, the support vectors and b."""
        # Fail before the (possibly expensive) optimization if prediction
        # would be impossible anyway.
        if self.kernel not in self._SUPPORTED_KERNELS:
            raise ValueError(
                f"Unsupported kernel {self.kernel!r} for the dual strategy; "
                f"expected one of {self._SUPPORTED_KERNELS}."
            )

        alpha = np.asarray(
            svm_dual_optimization(X, y, self.C, self.kernel, self.degree, self.coef0),
            dtype=float,
        ).ravel()
        y_signed = y.astype(float)

        # A numerical QP solution returns "zero" multipliers as tiny non-zero
        # values, so an exact `alpha != 0` test would select almost every sample.
        tol = 1e-6 * max(1.0, alpha.max(initial=0.0))
        is_sv = alpha > tol
        if not is_sv.any():
            raise ValueError(
                "The dual solution has no support vectors (all alpha are ~0); "
                "cannot compute the bias b."
            )

        # Only margin support vectors (0 < alpha < C) satisfy y_i * f(x_i) = 1;
        # bounded ones (alpha = C) may violate the margin and would bias b.
        upper = np.inf if self.C is None else self.C
        on_margin = is_sv & (alpha < upper - tol)
        if not on_margin.any():
            # Every support vector sits at the box bound: best effort, use all.
            on_margin = is_sv

        self.alpha = alpha
        # Primal weight vector sum_i alpha_i y_i x_i; only meaningful for the
        # linear kernel, kept for backward compatibility.
        self.w = (alpha * y_signed) @ X
        self.support_ = np.flatnonzero(is_sv)
        self.support_vectors_ = X[is_sv]
        self.dual_coef_ = alpha[is_sv] * y_signed[is_sv]
        # Average b over all margin support vectors instead of trusting one.
        self.b = float(np.mean(y_signed[on_margin] - self._dual_score(X[on_margin])))

    def _kernel(self, A, B):
        """Return the Gram matrix with entries K[i, j] = k(A[i], B[j]).

        NOTE(review): the polynomial formula (a . b + coef0) ** degree is assumed
        to match the kernel used inside svm_dual_optimization -- confirm.
        """
        gram = A @ B.T
        if self.kernel == 'linear':
            return gram
        if self.kernel in ('poly', 'polynomial'):
            return (gram + self.coef0) ** self.degree
        raise ValueError(f"Unsupported kernel {self.kernel!r}.")

    def _dual_score(self, X):
        """Return sum_j alpha_j y_j K(x_j, x) for every row x of X, without the bias."""
        if self.kernel == 'linear':
            # Equivalent to the kernel expansion, but a single matrix-vector product.
            return X @ self.w
        return self._kernel(X, self.support_vectors_) @ self.dual_coef_

    def decision_function(self, X):
        """Return the signed score f(x) of every row of X.

        The score is ``w . x + b`` for the primal strategy, and its kernel
        expansion ``sum_j alpha_j y_j K(x_j, x) + b`` for the dual one.
        Non-negative scores are labelled +1, negative ones -1.

        Returns
        -------
        numpy.ndarray of shape (n_samples,)
        """
        X = np.asarray(X, dtype=float)
        if self.opt_strat == 'primal':
            # The primal problem is always linear: the kernel is not used.
            return X @ self.w + self.b
        return self._dual_score(X) + self.b

    def predict(self, X):
        """Return the label (+1.0 or -1.0) of every row of X.

        A score of exactly 0 is labelled +1.
        """
        scores = self.decision_function(X)
        return np.where(scores >= 0, 1.0, -1.0)