Utils

The utils module provides several helper functions for training and evaluation.

ddpm_schedule

Function DMLP.utils.ddpm_schedule.ddpm_schedule(beta1: float, beta2: float, T: int) -> Dict[str, torch.Tensor]
Generates the parameters for the DDPM diffusion/denoising process.

Args
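beta1: hyperparameter of the DDPM noise schedule (conventionally the starting value of beta_t)
beta2: hyperparameter of the DDPM noise schedule (conventionally the final value of beta_t)
T: number of steps in the diffusion process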

Return
A dictionary mapping parameter names to the tensors used in the diffusion/denoising process
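
For illustration, a linear noise schedule of the kind this signature suggests can be sketched as follows. This is an assumption, not DMLP's implementation: beta_t is taken to grow linearly from beta1 to beta2 over T + 1 steps, and both the function name `linear_ddpm_schedule` and the dictionary keys (`beta_t`, `alpha_t`, `alphabar_t`, `sqrtab`, `sqrtmab`, `oneover_sqrta`, `sqrt_beta_t`) are hypothetical.

```python
import torch
from typing import Dict


def linear_ddpm_schedule(beta1: float, beta2: float, T: int) -> Dict[str, torch.Tensor]:
    """Hypothetical sketch of a linear DDPM noise schedule (not DMLP's own code)."""
    assert 0.0 < beta1 < beta2 < 1.0, "expected 0 < beta1 < beta2 < 1"

    # Assumption: beta_t rises linearly from beta1 (t = 0) to beta2 (t = T), giving T + 1 values.
    beta_t = beta1 + (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T
    alpha_t = 1.0 - beta_t
    # Cumulative product: alphabar_t = alpha_0 * alpha_1 * ... * alpha_t.
    alphabar_t = torch.cumprod(alpha_t, dim=0)

    return {
        "beta_t": beta_t,
        "alpha_t": alpha_t,
        "alphabar_t": alphabar_t,
        "sqrtab": torch.sqrt(alphabar_t),         # scales the clean input x_0
        "sqrtmab": torch.sqrt(1.0 - alphabar_t),  # scales the Gaussian noise
        "oneover_sqrta": 1.0 / torch.sqrt(alpha_t),
        "sqrt_beta_t": torch.sqrt(beta_t),
    }
```

Under these assumptions, a noisy sample at step t is `x_t = sqrtab[t] * x_0 + sqrtmab[t] * noise`, which is why the square roots are precomputed once instead of at every training step.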

random_init

Function DMLP.utils.random_init.random_init(model)
Randomly initializes the model weights.

Args
model: model whose weights are initialized

Return
None
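
One common way to implement a random initializer is to visit every submodule with `torch.nn.Module.apply`. The sketch below draws `Linear` and `Embedding` weights from a normal distribution and zeroes `Linear` biases; the function name `random_init_sketch`, the choice of distribution, and `std=0.02` are assumptions, since the source does not say which scheme `random_init` uses.

```python
import torch
import torch.nn as nn


def random_init_sketch(model: nn.Module, std: float = 0.02) -> None:
    """Hypothetical re-initialization: N(0, std^2) weights, zero biases.

    The distribution and std are assumptions; DMLP's random_init may differ.
    """
    def _init(module: nn.Module) -> None:
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=std)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    # Module.apply calls _init on every submodule recursively and on the model itself.
    model.apply(_init)
```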

sample_sequence_conditional

Function DMLP.utils.sample.sample_sequence_conditional(model, length, context, past=None, num_samples=1, temperature=1, top_k=0, top_p=0.0, device='cpu', decoder_tokenizer=None, eos_id=50259, loss=False)
Generates text conditioned on past. If past is None, the function generates a new sentence, sampling the latent representation with DDPM. If past is not None, the function reconstructs the input tokens from the latent representation given in past.
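
The generation loop described above can be sketched as a plain autoregressive sampler conditioned on a latent vector. Here `latent` plays the role of past (whether it comes from the encoder or from DDPM sampling), and `logits_fn` is a hypothetical stand-in for the decoder; the function name and signature are illustrative only, and top-k/top-p filtering is omitted.

```python
import torch
from typing import Callable


def sample_loop_sketch(
    logits_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    latent: torch.Tensor,
    context: torch.Tensor,
    length: int,
    eos_id: int,
    temperature: float = 1.0,
) -> torch.Tensor:
    """Hypothetical latent-conditioned decoding loop (not DMLP's actual code).

    logits_fn(generated, latent) must return next-token logits of shape
    (num_samples, vocab_size); it stands in for the decoder.
    """
    generated = context.clone()  # (num_samples, context_len), e.g. BOS ids
    finished = torch.zeros(generated.size(0), dtype=torch.bool)
    for _ in range(length):
        # Temperature rescales the logits before the softmax normalizes them.
        logits = logits_fn(generated, latent) / max(temperature, 1e-8)
        next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
        # Sequences that already produced EOS keep emitting EOS as padding.
        next_token[finished] = eos_id
        generated = torch.cat([generated, next_token], dim=1)
        finished |= next_token.squeeze(1) == eos_id
        if finished.all():
            break
    return generated
```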

Args
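model: model used to generate text
length: maximum sentence length
context: context tokens, usually the BOS token
past: latent representation of the input text. If past is None, a new sentence is generated with DDPM.
num_samples: number of sentences to generate
temperature: temperature used to scale the logits before the conditional probabilities are normalized
top_k: if greater than 0, sample only from the top_k most probable tokens
top_p: if greater than 0.0, sample only from the smallest set of tokens whose cumulative probability exceeds top_p (nucleus sampling)
device: device used for computation
decoder_tokenizer: tokenizer of the selected decoder
eos_id: token id of the end-of-sentence token
loss: whether to compute the reconstruction loss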

Return
The generated token ids and, if loss=True, the reconstruction loss.
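
The top_k and top_p parameters in the signature conventionally refer to top-k and nucleus (top-p) filtering of the next-token logits, where 0 disables each filter. A minimal sketch of that common recipe follows; the helper name `top_k_top_p_filter` is hypothetical, and whether sample_sequence_conditional filters exactly this way is an assumption.

```python
import torch


def top_k_top_p_filter(logits: torch.Tensor, top_k: int = 0, top_p: float = 0.0,
                       filter_value: float = -float("inf")) -> torch.Tensor:
    """Hypothetical helper: mask logits outside the top-k / nucleus set.

    logits: (batch, vocab). A value of 0 disables the corresponding filter.
    Returns a filtered copy; the input tensor is left unchanged.
    """
    logits = logits.clone()
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        # k-th largest logit per row; every logit strictly below it is masked.
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits[logits < kth] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        # Drop tokens once the cumulative probability exceeds top_p,
        # shifting right so the token that crosses the threshold is kept.
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order.
        mask = remove.scatter(-1, sorted_idx, remove)
        logits[mask] = filter_value
    return logits
```

A sampling step would then compute `probs = torch.softmax(top_k_top_p_filter(logits / temperature, top_k=50, top_p=0.9), dim=-1)` and draw the next token with `torch.multinomial(probs, num_samples=1)`.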

save_checkpoint

Function DMLP.utils.save_checkpoint.save_checkpoint(model_vae, optimizer, global_step, parameter_name, output_dir, logger, ppl=False, ddpm=None, use_philly=False)
Saves the model checkpoint with the best BLEU score to the output directory.

Args
model_vae: variational autoencoder (VAE) model to save
optimizer: optimizer used to train the model
global_step: number of training iterations completed
parameter_name: name of the parameters to save
output_dir: directory to save checkpoints to
logger: training logger
ddpm: DDPM model to save (optional, defaults to None)

Return
None
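
A checkpoint writer of this shape typically bundles the state dicts into one file with `torch.save`. The sketch below is an assumption about the layout: the function name `save_checkpoint_sketch`, the directory name `checkpoint-<global_step>`, the file name `training.bin`, and the dictionary keys are hypothetical, and parameter_name, logger, ppl, and use_philly are left out. Unlike the documented function, it returns the saved path so the caller can log it.

```python
import os
from typing import Optional

import torch
import torch.nn as nn


def save_checkpoint_sketch(model_vae: nn.Module, optimizer: torch.optim.Optimizer,
                           global_step: int, output_dir: str,
                           ddpm: Optional[nn.Module] = None) -> str:
    """Hypothetical checkpoint writer; file layout and keys are assumptions."""
    ckpt_dir = os.path.join(output_dir, f"checkpoint-{global_step}")
    os.makedirs(ckpt_dir, exist_ok=True)
    state = {
        "iter": global_step,
        "model_state_dict": model_vae.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    # The DDPM is optional (ddpm=None), so its weights are stored only if given.
    if ddpm is not None:
        state["ddpm_state_dict"] = ddpm.state_dict()
    path = os.path.join(ckpt_dir, "training.bin")
    torch.save(state, path)
    return path
```

To keep only the best model, the caller would compare each evaluation's BLEU score with the best score seen so far and call the function only when it improves.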