Implement Automatic Mixed Precision with GradScaler to Address NaN Loss Issues#13
Gsunshine merged 1 commit into locuslab:main from
Conversation
Gsunshine
left a comment
Merge AMP via GradScaler into ECT.
```python
@click.option('--fp16', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--tf32', help='Enable tf32 for A100/H100 training speed', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--ls', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True)
@click.option('--enable_gradscaler', help='Enable torch.cuda.amp.GradScaler, NOTE overwriting loss_scale set by --ls', metavar='BOOL', type=bool, default=False, show_default=True)
```
Hi Zixiang @aiihn ,
Thanks for your neat PR!
Would it be better to use a short abbreviation like amp as the option name? AMP already stands for Automatic Mixed Precision.
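For context, the suggested rename could look like the following sketch (hypothetical option name and command, not the merged code):

```python
import click
from click.testing import CliRunner

# Hypothetical rename of --enable_gradscaler to --amp, as suggested in the review.
@click.command()
@click.option('--amp', help='Enable torch.cuda.amp.GradScaler; overrides the loss scale set by --ls',
              metavar='BOOL', type=bool, default=False, show_default=True)
def train(amp):
    click.echo(f'amp={amp}')

# Exercise the option the same way the training script would be invoked.
result = CliRunner().invoke(train, ['--amp', 'True'])
print(result.output)  # prints "amp=True"
```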
```python
if enable_gradscaler:
    if 'gradscaler_state' in data:
        dist.print0(f'Loading GradScaler state from "{resume_state_dump}"...')
        # Although not loading the state_dict of the GradScaler works well, loading it can improve reproducibility.
```
Gotcha. Thanks for the comments!
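The reproducibility point above follows the standard GradScaler checkpointing pattern; a minimal self-contained sketch (toy in-memory checkpoint, not ECT's actual save/resume code):

```python
import io
import torch

# Sketch: persisting the GradScaler state alongside the model improves
# reproducibility when resuming mixed-precision training.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# Saving: include the scaler's state_dict in the checkpoint dict.
buffer = io.BytesIO()
torch.save({'gradscaler_state': scaler.state_dict()}, buffer)

# Resuming: restore the state if present. Training also works without this
# step, but the dynamic loss scale then restarts from its initial value.
buffer.seek(0)
data = torch.load(buffer)
if 'gradscaler_state' in data:
    scaler.load_state_dict(data['gradscaler_state'])
```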
```python
    scaler.step(optimizer)
    scaler.update()
else:
    # Update weights.
```
The TODO is also unclear to me. It seems still useful and compatible, per Claude.
It's fine to remove my commented code for lr rampup.
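The two branches quoted above follow the standard GradScaler training-step pattern; here is a minimal self-contained sketch (toy model and optimizer, not ECT's training loop):

```python
import torch

# Minimal sketch of the scaler-vs-plain update branch (toy model, not ECT's).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
use_amp = device == 'cuda'

model = torch.nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

x, y = torch.randn(8, 4, device=device), torch.randn(8, 1, device=device)
optimizer.zero_grad(set_to_none=True)
with torch.autocast(device_type=device, enabled=use_amp):
    loss = torch.nn.functional.mse_loss(model(x), y)

if use_amp:
    # Scale the loss so small fp16 gradients don't underflow; step() unscales
    # and skips the update if inf/nan gradients are found, and update()
    # grows or shrinks the loss scale dynamically.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
else:
    # Update weights without loss scaling.
    loss.backward()
    optimizer.step()
```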
Hi @aiihn,

Thank you again for your PR! I had another AMP implementation that could also be helpful for ECT. I'll check it out later and test.

Links for reference:

Cheers,
Description
This pull request addresses the issue of NaN losses occurring during mixed-precision training with `--fp16` enabled (#12).

Key Changes
- Introduced `torch.cuda.amp.GradScaler` to dynamically adjust loss scaling.
- `GradScaler` will override the `loss_scale` set manually by `--ls`.

Usage
Use `--fp16=True` along with `--enable_gradscaler=True`. For example, below is the mixed-precision training command modified from run_ecm_1hour.sh.

The FID records obtained using the above command are shown in the following images:

