package dbcon

import (
	"context"
	"fmt"
	"reflect"
	"strings"
	"time"

	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
	"gorm.io/gorm/logger"

	log "gitlab.ursabyte.com/faizal.aziz/ulfssar-go/logger"
)

// Keys maps a constraint kind, as it may appear in a field's sql-keys tag, to
// the suffix used when generateSQLFromStorage synthesizes a constraint name.
//
// "Foreign Key" is the correct spelling. "Foregin Key" is a historical
// misspelling kept so that tags written against it continue to match.
var Keys = map[string]string{
	"Primary Key": "pk",
	"Foreign Key": "fk",
	"Foregin Key": "fk",
	"Unique":      "unique",
}

// postgresqldb adapts a *gorm.DB to the ORM interface. err holds the error
// taken from the wrapped gorm instance when this value was built by a chained
// call; Error returns it.
type postgresqldb struct {
	db  *gorm.DB
	err error
}

// PostgreSqlOption configures NewPostgreSql.
type PostgreSqlOption struct {
	// ConnectionString is handed to the gorm postgres driver (postgres.Open).
	ConnectionString string
	// MaxLifeTimeConnection is applied via sql.DB.SetConnMaxLifetime.
	MaxLifeTimeConnection time.Duration
	// MaxIdleConnection is applied via sql.DB.SetMaxIdleConns.
	MaxIdleConnection int
	// MaxOpenConnection is applied via sql.DB.SetMaxOpenConns.
	MaxOpenConnection int
	// Logger, when non-nil, is wrapped with gorm's logger.New to receive SQL logs.
	Logger log.Logger
}

// Error returns the error captured when this ORM value was created, or nil.
func (d *postgresqldb) Error() (err error) {
	err = d.err
	return err
}

// Close closes the *sql.DB connection pool underlying the gorm instance.
// It returns the error from retrieving the pool or from closing it.
func (d *postgresqldb) Close() error {
	sqlDB, err := d.db.DB()
	if err != nil {
		return err
	}
	return sqlDB.Close()
}

// Begin starts a transaction and returns it as an ORM. The returned value's
// Error reports any error gorm recorded while beginning the transaction.
func (d *postgresqldb) Begin() ORM {
	tx := d.db.Begin()
	return &postgresqldb{db: tx, err: tx.Error}
}

// Commit commits the transaction wrapped by d (presumably one returned by
// Begin) and returns gorm's error, if any.
func (d *postgresqldb) Commit() error {
	committed := d.db.Commit()
	return committed.Error
}

// Rollback rolls back the transaction wrapped by d (presumably one returned
// by Begin) and returns gorm's error, if any.
func (d *postgresqldb) Rollback() error {
	rolledBack := d.db.Rollback()
	return rolledBack.Error
}

// Offset skips the first offset rows of the query result.
//
// The error stored in the returned ORM is read from the derived *gorm.DB (the
// instance actually returned), matching the other chain methods. Previously it
// was read from the parent instance d.db.
func (d *postgresqldb) Offset(offset int64) ORM {
	next := d.db.Offset(int(offset))
	return &postgresqldb{db: next, err: next.Error}
}

// Limit caps the number of rows returned by the query.
//
// The error stored in the returned ORM is read from the derived *gorm.DB (the
// instance actually returned), matching the other chain methods. Previously it
// was read from the parent instance d.db.
func (d *postgresqldb) Limit(limit int64) ORM {
	next := d.db.Limit(int(limit))
	return &postgresqldb{db: next, err: next.Error}
}

// First delegates to gorm's First, loading one record into object, and
// returns the resulting error (nil on success).
func (d *postgresqldb) First(object interface{}) error {
	return d.db.First(object).Error
}

// Last delegates to gorm's Last, loading one record into object, and returns
// the resulting error (nil on success).
func (d *postgresqldb) Last(object interface{}) error {
	return d.db.Last(object).Error
}

// Find delegates to gorm's Find, loading matching records into object, and
// returns the resulting error (nil on success).
func (d *postgresqldb) Find(object interface{}) error {
	return d.db.Find(object).Error
}

// Model scopes subsequent operations to value's model via gorm's Model. The
// returned ORM carries the derived instance and its current error.
func (d *postgresqldb) Model(value interface{}) ORM {
	next := d.db.Model(value)
	return &postgresqldb{db: next, err: next.Error}
}

// OmitAssoc excludes all associations (clause.Associations) via gorm's Omit.
// The returned ORM carries the derived instance and its current error.
func (d *postgresqldb) OmitAssoc() ORM {
	next := d.db.Omit(clause.Associations)
	return &postgresqldb{db: next, err: next.Error}
}

// Select forwards query and args to gorm's Select to choose the selected
// fields. The returned ORM carries the derived instance and its current error.
func (d *postgresqldb) Select(query interface{}, args ...interface{}) ORM {
	next := d.db.Select(query, args...)
	return &postgresqldb{db: next, err: next.Error}
}

// Table targets the table (or expression) name, with optional args, via
// gorm's Table. The returned ORM carries the derived instance and its error.
func (d *postgresqldb) Table(name string, args ...interface{}) ORM {
	next := d.db.Table(name, args...)
	return &postgresqldb{db: next, err: next.Error}
}

// Where adds the condition query with its args via gorm's Where. The returned
// ORM carries the derived instance and its current error.
func (d *postgresqldb) Where(query interface{}, args ...interface{}) ORM {
	next := d.db.Where(query, args...)
	return &postgresqldb{db: next, err: next.Error}
}

// Order forwards value to gorm's Order to set the result ordering.
//
// The error stored in the returned ORM is read from the derived *gorm.DB (the
// instance actually returned), matching the other chain methods. Previously it
// was read from the parent instance d.db.
func (d *postgresqldb) Order(value interface{}) ORM {
	next := d.db.Order(value)
	return &postgresqldb{db: next, err: next.Error}
}

// Create inserts args via gorm's Create and returns the resulting error.
func (d *postgresqldb) Create(args interface{}) error {
	inserted := d.db.Create(args)
	return inserted.Error
}

// CreateTable executes the CREATE TABLE IF NOT EXISTS statement that
// generateSQLFromStorage builds for tableName from the fields of structs, and
// returns the execution error.
func (d *postgresqldb) CreateTable(tableName string, structs interface{}) error {
	ddl := generateSQLFromStorage(tableName, structs)
	return d.db.Exec(ddl).Error
}

// Update delegates to gorm's Updates with args and returns the resulting error.
func (d *postgresqldb) Update(args interface{}) error {
	updated := d.db.Updates(args)
	return updated.Error
}

// UpdateColumns delegates to gorm's UpdateColumns with args and returns the
// resulting error.
func (d *postgresqldb) UpdateColumns(args interface{}) error {
	updated := d.db.UpdateColumns(args)
	return updated.Error
}

// Delete delegates to gorm's Delete with model and the optional conditions in
// args, and returns the resulting error.
func (d *postgresqldb) Delete(model interface{}, args ...interface{}) error {
	deleted := d.db.Delete(model, args...)
	return deleted.Error
}

// WithContext returns an ORM whose subsequent operations run with ctx, via
// gorm's WithContext.
//
// Like the other chain methods, the returned value carries the derived
// instance's error. Previously it was hard-coded to nil, which hid any error
// already recorded on the chain.
func (d *postgresqldb) WithContext(ctx context.Context) ORM {
	next := d.db.WithContext(ctx)
	return &postgresqldb{db: next, err: next.Error}
}

// Raw sets a raw SQL query with its args via gorm's Raw. The returned ORM
// carries the derived instance and its current error; chain Scan to read rows.
func (d *postgresqldb) Raw(query string, args ...interface{}) ORM {
	next := d.db.Raw(query, args...)
	return &postgresqldb{db: next, err: next.Error}
}

// Exec runs query with args via gorm's Exec. The returned ORM's Error reports
// the error recorded on the resulting *gorm.DB.
func (d *postgresqldb) Exec(query string, args ...interface{}) ORM {
	result := d.db.Exec(query, args...)
	return &postgresqldb{db: result, err: result.Error}
}

// Scan delegates to gorm's Scan, writing results into object, and returns the
// resulting error.
func (d *postgresqldb) Scan(object interface{}) error {
	return d.db.Scan(object).Error
}

// Preload eager-loads the association assoc via gorm's Preload, forwarding
// args (conditions, or a scope function) unchanged.
//
// args must be spread with "...". Passing the slice itself gave gorm a single
// []interface{} condition, so callers' conditions and scopes were never
// applied, and a bogus condition was sent even when args was empty.
func (d *postgresqldb) Preload(assoc string, args ...interface{}) ORM {
	next := d.db.Preload(assoc, args...)
	return &postgresqldb{db: next, err: next.Error}
}

// Joins forwards assoc to gorm's Joins. The returned ORM carries the derived
// instance and its current error.
func (d *postgresqldb) Joins(assoc string) ORM {
	next := d.db.Joins(assoc)
	return &postgresqldb{db: next, err: next.Error}
}

// GetGormInstance exposes the wrapped *gorm.DB for callers that need gorm
// functionality the ORM interface does not cover.
func (d *postgresqldb) GetGormInstance() *gorm.DB {
	instance := d.db
	return instance
}

// Count stores the number of matching records in *count via gorm's Count and
// returns the resulting error.
func (d *postgresqldb) Count(count *int64) error {
	return d.db.Count(count).Error
}

// Association returns gorm's association handle for column, exposed through
// the ORMAssociation interface.
func (d *postgresqldb) Association(column string) ORMAssociation {
	var assoc ORMAssociation = d.db.Association(column)
	return assoc
}

// Or adds an OR condition (query with args) via gorm's Or. The returned ORM
// carries the derived instance and its current error.
func (d *postgresqldb) Or(query interface{}, args ...interface{}) ORM {
	next := d.db.Or(query, args...)
	return &postgresqldb{db: next, err: next.Error}
}

// Save delegates to gorm's Save with data and returns the resulting error.
func (d *postgresqldb) Save(data interface{}) error {
	return d.db.Save(data).Error
}

// NewPostgreSql opens a PostgreSQL connection through gorm using
// option.ConnectionString (with QueryFields enabled) and applies the pool
// settings from option to the underlying *sql.DB.
//
// When option.Logger is non-nil, gorm's SQL logging is routed through it at
// Info level, with a one-second slow-query threshold and no colour.
//
// It returns an error if option is nil, if the connection cannot be opened,
// or if the underlying *sql.DB cannot be obtained.
//
// NOTE(review): the pool setters are applied unconditionally. Per database/sql,
// a zero MaxIdleConnection retains no idle connections, and a zero
// MaxOpenConnection or MaxLifeTimeConnection means "no limit".
func NewPostgreSql(option *PostgreSqlOption) (ORM, error) {
	if option == nil {
		// Previously this dereferenced nil and panicked.
		return nil, fmt.Errorf("dbcon: postgres option must not be nil")
	}

	cfg := &gorm.Config{QueryFields: true}
	if option.Logger != nil {
		cfg.Logger = logger.New(option.Logger, logger.Config{
			SlowThreshold:             time.Second,
			LogLevel:                  logger.Info,
			Colorful:                  false,
			IgnoreRecordNotFoundError: false,
		})
	}

	db, err := gorm.Open(postgres.Open(option.ConnectionString), cfg)
	if err != nil {
		return nil, fmt.Errorf("dbcon: opening postgres connection: %w", err)
	}

	sqlDB, err := db.DB()
	if err != nil {
		return nil, fmt.Errorf("dbcon: obtaining underlying sql.DB: %w", err)
	}

	// Order kept from the original: SetMaxOpenConns may lower MaxIdleConns, so
	// the idle limit is applied last.
	sqlDB.SetConnMaxLifetime(option.MaxLifeTimeConnection)
	sqlDB.SetMaxOpenConns(option.MaxOpenConnection)
	sqlDB.SetMaxIdleConns(option.MaxIdleConnection)

	return &postgresqldb{db: db}, nil
}

// generateSQLFromStorage builds a "CREATE TABLE IF NOT EXISTS" statement for
// tableName with one column per field of columnAndValues' struct type.
// columnAndValues may be a struct value or a pointer to one. Each field is read
// through these tags:
//
//   - sql-column: column name (defaults to the Go field name)
//   - sql-type: column type (defaults to "text")
//   - sql-default: emitted as "default <value>"
//   - sql-nullable: the value "true" emits "not null" (see columnDefinition)
//   - sql-constraint: emitted as "constraint <value>"
//   - sql-keys: written verbatim after the column (e.g. "Primary Key"). When
//     sql-constraint is empty, a constraint named
//     "<table>_<column>[_<suffix>]" is generated. <suffix> comes from Keys.
//
// Generated names include the table and column. PostgreSQL PK/UNIQUE
// constraint names become index names, which must be unique per schema. The
// old table-independent names (e.g. "Primary_Key_pk") therefore collided
// across tables.
//
// tableName and all tag values are interpolated unquoted, so only trusted
// input may be passed. It panics if columnAndValues is not a struct or a
// pointer to a struct, which is a programming error.
func generateSQLFromStorage(tableName string, columnAndValues interface{}) string {
	st := reflect.TypeOf(columnAndValues)
	for st != nil && st.Kind() == reflect.Ptr {
		st = st.Elem()
	}
	if st == nil || st.Kind() != reflect.Struct {
		panic(fmt.Sprintf("dbcon: generateSQLFromStorage expects a struct or pointer to struct, got %T", columnAndValues))
	}

	columns := make([]string, 0, st.NumField())
	for i := 0; i < st.NumField(); i++ {
		columns = append(columns, columnDefinition(tableName, st.Field(i)))
	}

	// NOTE(review): stripping the first "DROP DATABASE" is kept for
	// compatibility, but it is not an injection defence (case-sensitive, one
	// occurrence only). Safety relies on tags and tableName being trusted.
	columnList := strings.Replace(strings.Join(columns, ", "), "DROP DATABASE", "", 1)

	return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s);", tableName, columnList)
}

// columnDefinition renders one column definition of tableName from field's
// sql-* tags, joining the non-empty parts with single spaces.
func columnDefinition(tableName string, field reflect.StructField) string {
	columnName := field.Tag.Get("sql-column")
	if columnName == "" {
		columnName = field.Name
	}
	typeName := field.Tag.Get("sql-type")
	if typeName == "" {
		typeName = "text"
	}
	parts := []string{columnName, typeName}

	if def := field.Tag.Get("sql-default"); def != "" {
		parts = append(parts, "default "+def)
	}

	// NOTE(review): the tag name suggests "true" means NULL is allowed, but it
	// has always emitted NOT NULL. This is kept as-is because existing tags
	// depend on it.
	if field.Tag.Get("sql-nullable") == "true" {
		parts = append(parts, "not null")
	}

	keys := field.Tag.Get("sql-keys")
	if constraint := field.Tag.Get("sql-constraint"); constraint != "" {
		parts = append(parts, "constraint "+constraint)
	} else if keys != "" {
		name := tableName + "_" + columnName
		if suffix := keySuffix(keys); suffix != "" {
			name += "_" + suffix
		}
		parts = append(parts, "constraint "+sqlIdentFragment(name))
	}
	if keys != "" {
		parts = append(parts, keys)
	}

	return strings.Join(parts, " ")
}

// keySuffix returns the Keys value whose key occurs in keys, compared
// case-insensitively, or "" when none does. If several keys match, the
// lexicographically smallest key wins, so the result does not depend on map
// iteration order.
func keySuffix(keys string) string {
	lowered := strings.ToLower(keys)
	var (
		bestKey, suffix string
		found           bool
	)
	for key, val := range Keys {
		if !strings.Contains(lowered, strings.ToLower(key)) {
			continue
		}
		if !found || key < bestKey {
			bestKey, suffix, found = key, val, true
		}
	}
	return suffix
}

// sqlIdentFragment turns s into a bare identifier by replacing every character
// other than ASCII letters, digits and '_' with '_' (e.g. "public.users" ->
// "public_users").
func sqlIdentFragment(s string) string {
	return strings.Map(func(r rune) rune {
		switch {
		case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9', r == '_':
			return r
		}
		return '_'
	}, s)
}
