"""
Compute diffusion-limited oxidation and ageing in rigid medium.
"""
from __future__ import print_function, absolute_import
import os
import sys
import logging

# Verbose output for the whole run; the material function logs every call.
logging.basicConfig(level=logging.DEBUG)

# This directory is appended to ``sys.path`` before the ``sfepy`` imports
# below, so it is presumably a local SfePy checkout. The mesh is also read
# from its ``meshes`` subdirectory. The developer-specific default can be
# overridden with the SFEPY_DIR environment variable.
DATA_DIR = os.environ.get('SFEPY_DIR', '/home/jan/Documents/PYTHON/sfepy3')
sys.path.append(DATA_DIR)

import numpy as np

from sfepy.base.base import IndexedStruct
from sfepy.discrete import (FieldVariable, Material, Integral, Function,
                            Equation, Equations, Problem)
from sfepy.discrete.conditions import Conditions, EssentialBC, InitialCondition
from sfepy.discrete.fem import Mesh, FEDomain, Field
from sfepy.homogenization.utils import define_box_regions
from sfepy.solvers.ls import ScipyDirect
from sfepy.solvers.nls import Newton
from sfepy.solvers.ts_solvers import SimpleTimeSteppingSolver
from sfepy.terms import Term

# --- Discretisation ---------------------------------------------------------
# Approximation order of the scalar field; the integral uses 2 * ORDER.
ORDER = 1

# --- Transport and material properties --------------------------------------
DIFFUSIVITY = 1.0   # scales the isotropic diffusion tensor DIFFUSIVITY * I
DENSITY = 2000.0    # divides ALPHA_COEF in the k_1 coefficient
FLUX_COEF = 1e-3    # NOTE(review): not referenced by the code shown

# --- Thermodynamic state entering the exponential factors -------------------
GAS_CONSTANT = 8.314  # presumably the molar gas constant R [J/(mol K)]
TEMPERATURE = 300.0   # presumably the absolute temperature [K]

# --- k_1 = ALPHA_COEF / DENSITY * (1 - q) * exp(F_EXPONENT / (R T)) ---------
ALPHA_COEF = 1.0
F_EXPONENT = 1.0

# --- k_2 = V_COEF * (1 - q) * exp(E_EXPONENT / (R T)) -----------------------
V_COEF = 0.1
E_EXPONENT = 1.0

def get_reaction_coefs(ts, coors, problem, mode=None, **kwargs):
    """
    Material function returning the reaction coefficients ``k_1`` and ``k_2``.

    Both coefficients are proportional to ``1 - q``, where ``q`` is the
    current value of the ageing variable in the quadrature points.

    Parameters
    ----------
    ts : time stepper
        Current time step. At step 0 the initial conditions are first applied
        to the problem variables, so that ``q`` can be evaluated.
    coors : array
        Coordinates of the quadrature points in which values are requested.
        One value is returned per row.
    problem : Problem
        The problem whose variable ``q`` is evaluated.
    mode : str, optional
        Only ``'qp'`` produces values; any other mode returns None.
    **kwargs
        Further arguments passed by the caller; ignored.

    Returns
    -------
    dict or None
        ``{'k_1': k_1, 'k_2': k_2}`` with arrays of shape
        ``(coors.shape[0], 1, 1)`` in ``'qp'`` mode, otherwise None.

    Raises
    ------
    ValueError
        If the number of evaluated values differs from the number of
        requested quadrature points.
    """
    logging.debug('Entered get_reaction_coefs, mode = %s', mode)
    out = None
    if mode == 'qp':
        logging.debug('mode = qp')
        if ts.step == 0:
            # Seed the variables with the initial conditions before the
            # first evaluation of q (presumably they hold no data yet).
            state = problem.create_state()
            state.apply_ic()
            problem.equations.variables.set_data(state())
        # Evaluate q with the same integral order as the terms that use this
        # material (Integral('i', order=2*ORDER)). Another order yields a
        # different number of quadrature points than ``coors``.
        q_vals = problem.evaluate(
            'ev_volume_integrate.%d.Omega(q)' % (2 * ORDER),
            mode='qp', verbose=False)

        one_minus_q = 1 - q_vals
        # NOTE(review): Arrhenius-type factors are usually exp(-E / (R T));
        # the exponents below are positive. Confirm the intended sign.
        k_1 = (ALPHA_COEF / DENSITY * one_minus_q
               * np.exp(F_EXPONENT / GAS_CONSTANT / TEMPERATURE))
        k_2 = (V_COEF * one_minus_q
               * np.exp(E_EXPONENT / GAS_CONSTANT / TEMPERATURE))

        # Flatten (elements, points per element, ...) to one row per point.
        n_vals = k_1.shape[0] * k_1.shape[1]
        if n_vals != coors.shape[0]:
            raise ValueError(
                'reaction coefficients: %d values for %d quadrature points'
                % (n_vals, coors.shape[0]))
        k_1.shape = (n_vals, 1, 1)
        k_2.shape = (n_vals, 1, 1)
        out = {'k_1' : k_1, 'k_2' : k_2}

    logging.debug('End of get_reaction_coefs')
    return out

def main():
    """
    Set up and solve the coupled diffusion-reaction / ageing problem.

    The mesh ``meshes/3d/block.mesh`` is read from ``DATA_DIR``. The unknowns
    ``c`` and ``q`` share one scalar field of order ``ORDER``. They satisfy,
    in weak form::

        balance:   dc/dt - div(D grad c) + k_1 c = 0
        evolution: dq/dt - k_2 c = 0

    Here ``D = DIFFUSIVITY * I``, and ``k_1``, ``k_2`` come from
    ``get_reaction_coefs``. The initial conditions are ``c = 0.1`` and
    ``q = 0.2``. Time runs from 0 to 4 in 5 steps.

    Returns
    -------
    state
        The state yielded by the last time step, or None if no step ran.
    """
    mesh = Mesh.from_file(os.path.join(DATA_DIR, 'meshes', '3d', 'block.mesh'))
    domain = FEDomain('domain', mesh)

    omega = domain.create_region('Omega', 'all')
    # NOTE(review): Gamma and Gamma_n are created but no condition uses them.
    # FLUX_COEF is also unused, so a flux condition on Gamma_n was presumably
    # planned. Without one, the boundary keeps the natural (zero-flux)
    # condition of the weak form.
    domain.create_region('Gamma', 'vertices of surface', 'facet')

    # The box regions must be registered in the domain before the Gamma_n
    # expression can refer to them as r.Left, r.Bottom and r.Near.
    lbn, rtf = domain.get_mesh_bounding_box()
    box_regions = define_box_regions(3, lbn, rtf)
    for name, spec in box_regions.items():
        domain.create_region(name, spec[0], spec[1])
    domain.create_region('Gamma_n', 'r.Left +v r.Bottom +v r.Near', 'facet')

    field = Field.from_args(
        'fu', np.float64, 'scalar', omega, approx_order=ORDER)

    c = FieldVariable('c', 'unknown', field, history=True)
    q = FieldVariable('q', 'unknown', field, history=True)
    s = FieldVariable('s', 'test', field, primary_var_name='c')
    r = FieldVariable('r', 'test', field, primary_var_name='q')

    ### Materials ###
    m = Material('m', diffusivity=DIFFUSIVITY*np.eye(3))

    m_2_fun = Function('m_2_fun', get_reaction_coefs)
    m_2 = Material('m_2', function=m_2_fun)

    ### Integral ###
    # get_reaction_coefs evaluates q with this same order (2 * ORDER).
    integral = Integral('i', order=2*ORDER)

    ### Terms ###
    term_time = Term.new(
        'dw_volume_dot(s, dc/dt)', integral, omega, s=s, c=c)
    term_vol = Term.new(
        'dw_diffusion(m.diffusivity, s, c)', integral, omega, s=s, c=c, m=m)
    term_reaction = Term.new(
        'dw_volume_dot(m_2.k_1, s, c)', integral, omega, s=s, c=c, m_2=m_2)

    term_q_time = Term.new(
        'dw_volume_dot(r, dq/dt)', integral, omega, r=r, q=q)
    term_q_reaction = Term.new(
        'dw_volume_dot(m_2.k_2, r, c)', integral, omega, r=r, c=c, m_2=m_2)

    ### Equations ###
    eq_balance = Equation(
        'balance', term_time+term_vol+term_reaction)
    eq_evolution = Equation(
        'evolution', term_q_time-term_q_reaction)

    equations = Equations([eq_balance, eq_evolution])

    ### Initial condition ###
    ic = InitialCondition('ic', omega, {'c.0' : .1, 'q.0' : .2})

    ### Solvers ###
    ls = ScipyDirect({})
    nls_status = IndexedStruct()
    # NOTE(review): k_1 and k_2 depend on q, so each step is linear only if
    # the material is frozen at the start of the step. Confirm that
    # is_linear=True does not make SfePy reuse a matrix assembled with
    # stale coefficients.
    nls = Newton({'is_linear' : True}, lin_solver=ls, status=nls_status)

    ### Problem definition ###
    pb = Problem(
        'diffusion', equations=equations, nls=nls, ls=ls,
    )
    pb.set_ics(Conditions([ic,]))

    logging.debug('Init tss')
    tss = SimpleTimeSteppingSolver({'t0' : 0., 't1' : 4., 'n_step' : 5},
                                   problem=pb)
    logging.debug('done.')
    logging.debug('tss.init_time()')
    tss.init_time()
    logging.debug('done.')

    ### Solution ###
    # Drive the time stepper to completion and keep the last state.
    state = None
    for _step, _time, state in tss():
        pass

    return state

if __name__ == '__main__':
    # Run the simulation only when executed as a script, not on import.
    main()
