# for gradient checkpointing
def create_custom_forward(module, **kwargs):
    """Return a positional-only wrapper around ``module`` with ``kwargs`` bound.

    The returned callable accepts only positional arguments and forwards
    them to ``module`` together with the keyword arguments captured here.

    NOTE(review): presumably needed because the checkpointing call site
    passes inputs positionally -- confirm against the caller.

    Args:
        module: Any callable; it is invoked as ``module(*inputs, **kwargs)``.
        **kwargs: Keyword arguments supplied to ``module`` on every call.

    Returns:
        A function ``custom_forward(*inputs)`` that returns whatever
        ``module(*inputs, **kwargs)`` returns.
    """

    def custom_forward(*inputs):
        """Call the wrapped module with ``inputs`` plus the bound kwargs."""
        return module(*inputs, **kwargs)

    return custom_forward