from typing import Union

import torch
from torch.optim.lr_scheduler import _LRScheduler

from typeguard import check_argument_types


class WarmupLR(_LRScheduler):
    """The WarmupLR scheduler.

    This scheduler is almost the same as the NoamLR scheduler except for the
    following difference:

    NoamLR:
        lr = optimizer.lr * model_size ** -0.5
             * min(step ** -0.5, step * warmup_step ** -1.5)
    WarmupLR:
        lr = optimizer.lr * warmup_step ** 0.5
             * min(step ** -0.5, step * warmup_step ** -1.5)

    Note that the maximum lr equals optimizer.lr in this scheduler.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: Union[int, float] = 25000,
        last_epoch: int = -1,
    ):
        assert check_argument_types()
        self.warmup_steps = warmup_steps

        # self.warmup_steps must be set before super().__init__() is called,
        # because super().__init__() invokes step(), which calls get_lr()
        super().__init__(optimizer, last_epoch)

    def __repr__(self):
        return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"

    def get_lr(self):
        # _LRScheduler tracks the step count in `last_epoch`; use it as a
        # 1-based step number so the first step does not divide by zero.
        step_num = self.last_epoch + 1
        return [
            lr
            * self.warmup_steps ** 0.5
            * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5)
            for lr in self.base_lrs
        ]

    def set_step(self, step: int):
        # Set the step counter explicitly, e.g. when resuming from a checkpoint.
        self.last_epoch = step
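

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original
# module): the toy Linear model, Adam optimizer, and warmup_steps=1000 below
# are assumed values chosen only for demonstration. The learning rate ramps
# up to the optimizer's base lr at `warmup_steps` and then decays as
# step ** -0.5.
if __name__ == "__main__":
    model = torch.nn.Linear(16, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = WarmupLR(optimizer, warmup_steps=1000)
    for _ in range(2000):
        # forward / backward would run here before optimizer.step()
        optimizer.step()
        scheduler.step()
    # After 2000 steps (past warmup), the lr has decayed below the base lr.
    print(scheduler)           # WarmupLR(warmup_steps=1000)
    print(scheduler.get_lr())  # current per-group learning rates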
