Gradient and Optimization with JAX and Optax


import jax
import optax

import jax.numpy as jnp

import meent
from meent.on_jax.optimizer.loss import LossDeflector
backend = 1  # JAX

# common
pol = 0  # 0: TE, 1: TM

n_top = 1  # refractive index of the incidence (top) medium
n_bot = 1  # refractive index of the transmission (bottom) medium

theta = 0 * jnp.pi / 180  # angle of incidence, in radians
phi = 0 * jnp.pi / 180  # rotation (azimuthal) angle, in radians

wavelength = 900  # same length unit as period and thickness

thickness = [500., 1000.]  # thickness of each layer, from top to bottom.
period = [1000.]  # length of the unit cell. Here it's 1D.

fto = [10]  # Fourier truncation order

type_complex = jnp.complex128
ucell_1d_m = jnp.array([
    [[0, 0, 0, 1, 1, 1, 1, 0, 0, 0, ]],
    [[1, 1, 1, 1, 0, 1, 1, 1, 1, 1, ]],
    ]) * 4. + 1.  # refractive index: 0 -> 1.0, 1 -> 5.0
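The scaling above turns a binary pattern into refractive indices: entries equal to 0 become 1.0 and entries equal to 1 become 5.0. The short check below reuses the same pattern to confirm that mapping and the array shape. It assumes, as the two entries of thickness suggest, that the first axis indexes the layers from top to bottom.

```python
import jax.numpy as jnp

thickness = [500., 1000.]  # one entry per layer, as in the setup above

# Same binary pattern as ucell_1d_m: 1 marks the high-index material
pattern = jnp.array([
    [[0, 0, 0, 1, 1, 1, 1, 0, 0, 0]],
    [[1, 1, 1, 1, 0, 1, 1, 1, 1, 1]],
])
ucell = pattern * 4. + 1.  # 0 -> 1.0, 1 -> 5.0

print(ucell.shape)        # (2, 1, 10)
print(jnp.unique(ucell))  # [1. 5.]
```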

2.1 Gradient

Gradients can be calculated with JAX's jax.value_and_grad function, which returns both the value of a function and its gradient. For further information, see the JAX tutorial: Automatic Differentiation
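As a minimal illustration independent of meent, jax.value_and_grad wraps a scalar-valued function and returns a new function that computes the value and the gradient in one call. The toy loss here is a sum of squares, whose gradient is 2 * w.

```python
import jax
import jax.numpy as jnp


def loss(w):
    # Toy scalar loss: sum of squares, gradient is 2 * w
    return jnp.sum(w ** 2)


w = jnp.array([1.0, -2.0, 3.0])

# value_and_grad(loss) returns a function giving (loss(w), d loss / d w)
value, grad = jax.value_and_grad(loss)(w)
print(value)  # 14.0
print(grad)   # [ 2. -4.  6.]
```

The same call also accepts parameters packed in a dictionary (a JAX pytree), in which case the gradient comes back as a dictionary with the same keys.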

Optax is used for optimization. Like PyTorch, Optax provides a variety of loss functions and optimizers, so users can easily rely on well-established implementations. Refer to this tutorial: Learn Optax

Gradients can be used to solve optimization problems. The following examples show a couple of ways to obtain gradients or optimized values, with or without meent's predefined functions.

2.1.1 Examples

mee = meent.call_mee(backend=backend, pol=pol, n_top=n_top, n_bot=n_bot, theta=theta, phi=phi,
                     fto=fto, wavelength=wavelength, period=period, ucell=ucell_1d_m,
                     thickness=thickness, type_complex=type_complex)

pois = ['ucell', 'thickness']
forward = mee.conv_solve
loss_fn = LossDeflector(x_order=1, y_order=0)

# case 1: Gradient
grad = mee.grad(pois, forward, loss_fn)

print('ucell gradient:')
print(grad['ucell'])
print('thickness gradient:')
print(grad['thickness'])
ucell gradient:
[[[-0.05114874 -0.02533636 -0.00729883  0.07873582 -0.01841166
    0.09447967  0.08779338 -0.0012304  -0.03640632 -0.04779842]]

 [[-0.17959986 -0.08614187 -0.22233491 -0.19389416  0.08978906
    0.05564021 -0.04575985 -0.13595162 -0.29835993  0.12867445]]]
thickness gradient:
[ 0.00222043 -0.00671415]
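The result of mee.grad is a dictionary keyed by the names in pois, with one gradient array per parameter. The sketch below is not meent's implementation: forward and loss_fn are hypothetical toy stand-ins for mee.conv_solve and LossDeflector. It only illustrates how applying jax.value_and_grad to a dictionary of named parameters produces gradients with the same keys and shapes.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-ins, NOT meent's conv_solve or LossDeflector
def forward(ucell, thickness):
    # Toy response: each layer's mean index weighted by its thickness
    return jnp.sum(jnp.mean(ucell, axis=(1, 2)) * thickness) / 1000.0


def loss_fn(response):
    # Toy scalar loss built from the response
    return -response


params = {
    "ucell": jnp.array([[[1., 5., 5., 1.]], [[5., 1., 1., 5.]]]),
    "thickness": jnp.array([500., 1000.]),
}


def objective(p):
    return loss_fn(forward(p["ucell"], p["thickness"]))


value, grads = jax.value_and_grad(objective)(params)
# grads has the same keys and array shapes as params
print(grads["ucell"].shape)      # (2, 1, 4)
print(grads["thickness"].shape)  # (2,)
```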

2.2 Optimization

2.2.1 Examples

Example 1

mee = meent.call_mee(backend=backend, pol=pol, n_top=n_top, n_bot=n_bot, theta=theta, phi=phi,
                     fto=fto, wavelength=wavelength, period=period, ucell=ucell_1d_m,
                     thickness=thickness, type_complex=type_complex)

pois = ['ucell', 'thickness']
forward = mee.conv_solve
loss_fn = LossDeflector(x_order=1, y_order=0)

# case 2: SGD
optimizer = optax.sgd(learning_rate=1e-2, momentum=0.9)
res = mee.fit(pois, forward, loss_fn, optimizer, iteration=3)

print('ucell final:')
print(res['ucell'])
print('thickness final:')
print(res['thickness'])
100%|██████████| 3/3 [00:05<00:00,  1.88s/it]
ucell final:
[[[1.00286423 1.00145549 1.00050169 4.99666797 5.00175318 4.99580863
   4.99617526 1.00015109 1.00214635 1.00275083]]

 [[5.0054235  4.99990456 5.00824621 5.0065062  0.99325253 4.99254125
   4.99835018 5.00367578 5.01333396 4.9885967 ]]]
thickness final:
[ 499.99989253 1000.00039487]