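"""Shared utilities: rank-aware logging for distributed runs, config
printing, and factories for activation and normalization functions."""
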
import logging
import os
import sys

import torch
from torch import nn
import torch.distributed as dist
import torch.nn.functional as F

from .norm import SimpleRMSNorm as SimpleRMSNormTorch
from .srmsnorm_triton import SimpleRMSNorm as SimpleRMSNormTriton

# Runtime switches read from the environment; the values are the strings
# "True" / "False", which eval turns into booleans.
use_triton = eval(os.environ.get("use_triton", default="True"))
debug = eval(os.environ.get("debug", default="False"))

# Prefer the fused Triton kernel when enabled, otherwise fall back to the
# pure-PyTorch implementation.
if use_triton:
    SimpleRMSNorm = SimpleRMSNormTriton
else:
    SimpleRMSNorm = SimpleRMSNormTorch

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("print_config")

# Feature dimensions are rounded up to a multiple of this base
# (see convert_to_multiple_of_base below).
BASE_DIM = 256


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def logging_info(string):
    # Log only on rank 0 so multi-process runs do not duplicate messages.
    if is_main_process():
        logger.info(string)


def print_params(**kwargs):
    if is_main_process():
        logger.info(f"start print config of {kwargs['__class__']}")
        for key in kwargs:
            if key in ["__class__", "self"]:
                continue
            logger.info(f"{key}: {kwargs[key]}")
        logger.info(f"end print config of {kwargs['__class__']}")

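# Illustrative call pattern: print_params is meant to be invoked from a module
# constructor as print_params(**locals()). With zero-argument super().__init__()
# the frame's locals() include __class__, which supplies the banner lines above.

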
def print_config(config):
    if is_main_process():
        logger.info(f"start print config of {config['__class__']}")
        for key in config:
            if key in ["__class__", "self"]:
                continue
            logger.info(f"{key}: {config[key]}")
        logger.info(f"end print config of {config['__class__']}")


def print_module(module):
    # Collect the names of all submodules so that only parameters owned
    # directly by `module` (not by a submodule) are reported below.
    named_modules = set()
    for p in module.named_modules():
        named_modules.update([p[0]])
    named_modules = list(named_modules)

    string_repr = ""
    for p in module.named_parameters():
        name = p[0].split(".")[0]
        if name not in named_modules:
            string_repr = (string_repr + "(" + name + "): " + "Tensor(" +
                           str(tuple(p[1].shape)) + ", requires_grad=" +
                           str(p[1].requires_grad) + ")\n")

    return string_repr.rstrip("\n")


def get_activation_fn(activation):
    if debug:
        logger.info(f"activation: {activation}")
    if activation == "gelu":
        return F.gelu
    elif activation == "relu":
        return F.relu
    elif activation == "elu":
        return F.elu
    elif activation == "sigmoid":
        return torch.sigmoid
    elif activation == "exp":

        def f(x):
            # Subtract the row-wise max (detached, so gradients flow only
            # through x) for numerical stability before exponentiating.
            with torch.no_grad():
                x_max = torch.max(x, dim=-1, keepdim=True).values
            y = torch.exp(x - x_max)

            return y

        return f
    elif activation == "leak":
        return F.leaky_relu
    elif activation == "1+elu":

        def f(x):
            # Shifted ELU is strictly positive, a common kernel feature map
            # for linear attention.
            return 1 + F.elu(x)

        return f
    elif activation == "2+elu":

        def f(x):
            return 2 + F.elu(x)

        return f
    elif activation == "silu" or activation == "swish":
        return F.silu
    elif activation == "sine":
        return torch.sin
    else:
        logger.info(
            f"activation: {activation} is not supported, falling back to identity")
        return lambda x: x


def get_norm_fn(norm_type):
    if norm_type == "simplermsnorm":
        return SimpleRMSNorm
    else:
        return nn.LayerNorm


def convert_to_multiple_of_base(x):
    # Round x up to the nearest multiple of BASE_DIM, e.g. 300 -> 512.
    return BASE_DIM * ((x + BASE_DIM - 1) // BASE_DIM)
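

if __name__ == "__main__":
    # Usage sketch. Assumptions: this runs as a module (e.g. `python -m
    # <package>.utils`) so the relative imports resolve; SimpleRMSNorm takes
    # the feature dimension as its argument; on CPU-only machines set
    # use_triton=False in the environment. `Toy` is a made-up example class.

    class Toy(nn.Module):

        def __init__(self, dim=512, act="1+elu", norm_type="simplermsnorm"):
            super().__init__()
            print_params(**locals())  # logs dim / act / norm_type on rank 0
            self.scale = nn.Parameter(torch.ones(dim))
            self.act = get_activation_fn(act)
            self.norm = get_norm_fn(norm_type)(dim)

        def forward(self, x):
            return self.norm(self.act(x) * self.scale)

    dim = convert_to_multiple_of_base(300)  # 300 -> 512
    model = Toy(dim=dim)
    logging_info(print_module(model))  # reports the bare (scale) parameter
    y = model(torch.randn(2, 8, dim))
    logging_info(f"output shape: {tuple(y.shape)}")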