Train
Train
The training pipe line is writted in the function train_vae_ddpm
.
The function can be imported by from DMLP.train.train_function import train_vae_ddpm
.
The training pipeline will output a tensorboard logging with all the evaluation results.
train_vae_ddpm
__Args__
local_rank: GPU device. This will be passed to the function automatically by CUDA_VISIBLE_DEVICE
world_size: the number of GPUs used
model: vae_ddpm model
optimizer: optimizer for training
train_dataloader: train_dataloader
output_dir: directory to save checkpoints, example outputs, and logs
batch_size: batch_size for logging; This will NOT affect dataloader
condition_f=lambda x: False: A function selecting model parameters
logging_steps = -1: logging step for logging and evaluation
train_epoch = 20: Number of training epoch
gradient_accumulation_steps = 1: gradient accumulation during training
device = 'cpu': model device
fp16=False: useless now; keep it false
fp16_opt_level=None: useless now; leave it None
learning_rate=9e-5: learning rate
adam_epsilon=1e-5: eps for optimizer
lr_end_multiplier= 0.01: An argument for `transformers.get_polynomial_decay_schedule_with_warmup`; please refer to their documentation
power=3.0: An argument for `transformers.get_polynomial_decay_schedule_with_warmup`; please refer to their documentation
warmup_steps=0: An argument for `transformers.get_polynomial_decay_schedule_with_warmup`; please refer to their documentation
disable_bar=True: turn on or off tqdm bar
max_grad_norm=1: paramter for `torch.nn.utils.clip_grad_norm_` to save sapce
save=True: save checkpoint or not, True only if `evaluate_during_training=True`
evaluate_during_training=False: evaluate model if True; False if no evaluation data
eval_dataloader=None: evaluation dataloader
sent_length=32: sentence length for generation
model_id='gpt2': model to evaluate sentence generation perplexity; gpt by default
ppl_eval=True: evaluate perplexity or not
__Return__
global_step: global iteration
tr_loss / global_step: average loss
optimizer: final optimizer
Evaluation_Functions
The evaluation process contains two parts: reconstruction by calc_rec_lgy
and generation by calc_ppl_lgy_ddpm
.
calc_rec_lgy
Reconstruction evaluation: from DMLP.train.reconstruction import calc_rec_lgy
__Args__
model_vae: VAE module of model
encoder_tokenizer: encoder tokenizer
decoder_tokenizer: decoder tokenizer
eval_dataloader: evaluation data
device: device for model
disable_bar: diplay tqdm bar or not
__Return__
BLEU: bleu score
calc_ppl_lgy_ddpm
Generation evaluation: from DMLP.train.generation import calc_ppl_lgy_ddpm
__Args__
model_vae: VAE module
decoder_tokenizer: decoder tokenizer
ns=1: number of iterations
sent_length=32: sentence length
ddpm=None: DDPM module
device='cpu': device
output_dir = "output.txt": output directory to save example output
disable_bar=True: display tqdm bar or not
fp16=False: useless now
model_id='gpt2': model id to calculate perplexity
ppl_eval=True: calculate perplexity or not
__Return__
A dictionary for different evaluations:
'ppl': perplexity (calculated only if ppl_eval=True)
'sbleu': bleu score refering other sentences (better low)
'length': mean sentence length
'norm_z': mean of normalized latent z
'ppl_sbleu': perplexity + bleu score (calculated only if ppl_eval=True)
evaluation
Reconstruction + Generation all together from DMLP.train.evaluation import evaluation
__Args__
model: VAE_DDPM model
eval_dataloader: evaluation data
device: cpu or gpu
disable_bar: display tqdm bar or not
ns=1: number of iterations in generation
sent_length=32: sentence length for generation
output_dir="output.txt": output directory to save example output
fp16=False: useless now
model_id='gpt2': model id to calculate perplexity
ppl_eval=True: calcualte perplexity or not
__Return__
A dictionary for different evaluations:
'bleu': bleu score for generation
'ppl': perplexity (calculated only if ppl_eval=True)
'sbleu': bleu score refering other sentences (better low)
'length': mean sentence length
'norm_z': mean of normalized latent z
'ppl_sbleu': perplexity + bleu score (calculated only if ppl_eval=True)