| |
| |
| |
| |
| |
|
|
| import errno |
| import functools |
| import hashlib |
| import inspect |
| import io |
| import os |
| import random |
| import socket |
| import tempfile |
| import warnings |
| import zlib |
| from contextlib import contextmanager |
|
|
| from diffq import UniformQuantizer, DiffQuantizer |
| import torch as th |
| import tqdm |
| from torch import distributed |
| from torch.nn import functional as F |
|
|
|
|
def center_trim(tensor, reference):
    """
    Crop `tensor` symmetrically along its last dimension.

    Args:
        tensor: tensor to crop.
        reference: either an object with a `size` method (its last dimension is
            the target length) or a plain number giving the target length.

    Returns:
        `tensor` itself when no cropping is needed, otherwise a slice of it whose
        last dimension equals the target length. When the excess is odd, the
        extra sample is removed from the right side.

    Raises:
        ValueError: if `tensor` is shorter than the target length.
    """
    target = reference.size(-1) if hasattr(reference, "size") else reference
    delta = tensor.size(-1) - target
    if delta < 0:
        raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
    if not delta:
        return tensor
    left = delta // 2
    right = delta - left  # right side takes the extra sample when delta is odd
    return tensor[..., left:-right]
|
|
|
|
def average_metric(metric, count=1.):
    """
    Compute the count-weighted average of a scalar `metric` over all hosts.

    Each host contributes `count` (its weight, e.g. number of examples) and
    `count * metric`; both are summed with `all_reduce` and the ratio of the
    sums is returned.

    Args:
        metric (float): this host's value of the metric.
        count (float): weight of this host.

    Returns:
        float: sum(count * metric) / sum(count) across the process group.
    """
    totals = th.tensor([count, count * metric], dtype=th.float32, device='cuda')
    distributed.all_reduce(totals, op=distributed.ReduceOp.SUM)
    total_count, weighted_sum = totals.tolist()
    return weighted_sum / total_count
|
|
|
|
def free_port(host='', low=20000, high=40000):
    """
    Return a port number that is most likely free.

    Random ports in `[low, high]` (inclusive) are probed by binding a throwaway
    socket to them; ports reported as already in use are skipped. The probing
    socket is closed before returning so the port is released for the caller.
    This could suffer from a race condition (another process may take the port
    between this probe and its actual use), although it should be quite rare.

    Args:
        host (str): interface to probe, '' meaning all interfaces.
        low (int): smallest candidate port (inclusive).
        high (int): largest candidate port (inclusive).

    Returns:
        int: a port that could be bound.

    Raises:
        OSError: for any bind failure other than `EADDRINUSE`.
    """
    # Context manager guarantees the probe socket is closed on every exit path
    # (return or raise) instead of relying on garbage collection.
    with socket.socket() as sock:
        while True:
            port = random.randint(low, high)
            try:
                sock.bind((host, port))
            except OSError as error:
                if error.errno == errno.EADDRINUSE:
                    continue
                raise
            return port
|
|
|
|
def sizeof_fmt(num, suffix='B'):
    """
    Format a byte count as a human readable string with binary prefixes.

    `num` is divided by 1024 until its magnitude is below 1024, then printed
    with one decimal (e.g. 1536 -> "1.5KiB"). Anything that is still at least
    1024 after the "Zi" step is expressed in "Yi" units.
    Taken from https://stackoverflow.com/a/1094933

    Args:
        num (int or float): number of bytes, may be negative.
        suffix (str): unit symbol appended after the prefix.

    Returns:
        str: the formatted size.
    """
    scaled = num
    for prefix in ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi'):
        if not abs(scaled) < 1024.0:
            scaled /= 1024.0
            continue
        return "%3.1f%s%s" % (scaled, prefix, suffix)
    return "%.1f%s%s" % (scaled, 'Yi', suffix)
|
|
|
|
def human_seconds(seconds, display='.2f'):
    """
    Format a duration given in seconds as a human readable string.

    The duration is converted to microseconds, then promoted through ms, s,
    min, hrs and days for as long as the value in the next unit is >= 0.3.

    Args:
        seconds (float): duration in seconds.
        display (str): format spec applied to the number with `format`.

    Returns:
        str: the number followed by its unit, e.g. "1.00 s" or "0.50 min".
    """
    value = seconds * 1e6
    last = 'us'
    for name, ratio in (('ms', 1e3), ('s', 1e3), ('min', 60), ('hrs', 60), ('days', 24)):
        scaled = value / ratio
        if scaled < 0.3:
            break
        value, last = scaled, name
    return f"{format(value, display)} {last}"
|
|
|
|
class TensorChunk:
    """
    Window `[offset, offset + length)` over the last dimension of `tensor`.

    The window is not materialized; `padded` extracts it, possibly widened
    with surrounding samples of the underlying tensor and zero padded where
    the widened window falls outside that tensor.
    """

    def __init__(self, tensor, offset=0, length=None):
        available = tensor.shape[-1]
        assert offset >= 0
        assert offset < available
        remaining = available - offset
        self.tensor = tensor
        self.offset = offset
        # Clamp the requested length so the window never runs past the end.
        self.length = remaining if length is None else min(remaining, length)
        self.device = tensor.device

    @property
    def shape(self):
        """Shape of the underlying tensor, with the last dim replaced by the window length."""
        return list(self.tensor.shape[:-1]) + [self.length]

    def padded(self, target_length):
        """
        Return a tensor of `target_length` samples centered on this window.

        The window is extended by `target_length - self.length` samples, split
        evenly on both sides (the extra one going to the right). Samples taken
        from the underlying tensor are real context; positions outside it are
        filled with zeros by `F.pad`.
        """
        extra = target_length - self.length
        size = self.tensor.shape[-1]
        assert extra >= 0

        begin = self.offset - extra // 2
        finish = begin + target_length
        lo = max(0, begin)
        hi = min(size, finish)

        out = F.pad(self.tensor[..., lo:hi], (lo - begin, finish - hi))
        assert out.shape[-1] == target_length
        return out
|
|
|
|
def tensor_chunk(tensor_or_chunk):
    """
    Normalize the argument to a `TensorChunk`.

    A `TensorChunk` is returned unchanged; a `torch.Tensor` is wrapped in a
    chunk covering its whole last dimension. Anything else fails an assertion.
    """
    if not isinstance(tensor_or_chunk, TensorChunk):
        assert isinstance(tensor_or_chunk, th.Tensor)
        tensor_or_chunk = TensorChunk(tensor_or_chunk)
    return tensor_or_chunk
|
|
|
|
def apply_model(model, mix, shifts=None, split=False,
                overlap=0.25, transition_power=1., progress=False):
    """
    Apply model to a given mixture.

    Args:
        model: separation model. This function reads its `sources`,
            `segment_length`, `samplerate` and `valid_length` attributes and calls
            it on a batch containing a single example.
        mix (Tensor or TensorChunk): mixture of shape `(channels, length)`.
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` times and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down into extracts of
            `model.segment_length` samples, predictions will be performed individually
            on each, and the results blended back together with a weighted overlap-add.
            Useful for model with large memory footprint like Tasnet.
        overlap (float): fraction of an extract shared with the next one when `split=True`.
        transition_power (float): exponent (>= 1) applied to the triangular blending
            window when `split=True`; larger values make the cross-fade sharper.
        progress (bool): if True, show a progress bar (requires split=True)

    Returns:
        Tensor: the model prediction trimmed to `length` samples. With `split=True`
        it has shape `(len(model.sources), channels, length)`; otherwise its leading
        dimensions are whatever `model` outputs (presumably the same - TODO confirm).
    """
    assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
    device = mix.device
    channels, length = mix.shape
    if split:
        out = th.zeros(len(model.sources), channels, length, device=device)
        sum_weight = th.zeros(length, device=device)
        segment = model.segment_length
        stride = int((1 - overlap) * segment)
        offsets = range(0, length, stride)
        scale = stride / model.samplerate
        if progress:
            offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
        # Triangular window (1, 2, ..., peak, ..., 2, 1) of `segment` samples so that
        # overlapping extracts cross-fade smoothly instead of switching abruptly.
        weight = th.cat([th.arange(1, segment // 2 + 1),
                         th.arange(segment - segment // 2, 0, -1)]).to(device)
        assert len(weight) == segment
        # Normalize to a peak of 1, then sharpen the transition for power > 1.
        weight = (weight / weight.max())**transition_power
        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment)
            chunk_out = apply_model(model, chunk, shifts=shifts)
            # The last extract may be shorter than `segment`.
            chunk_weight = weight[:chunk_out.shape[-1]]
            out[..., offset:offset + segment] += chunk_weight * chunk_out
            sum_weight[offset:offset + segment] += chunk_weight
        # Every sample must be covered by at least one extract before normalizing.
        assert sum_weight.min() > 0
        out /= sum_weight
        return out
    elif shifts:
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        # Context of `max_shift` samples on each side (zeros outside the signal):
        # index `max_shift` of `padded_mix` is sample 0 of `mix`.
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0
        for _ in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
            shifted_out = apply_model(model, shifted)
            # Undo the shift: drop the `max_shift - offset` leading samples so the
            # output realigns with sample 0 of `mix`.
            out += shifted_out[..., max_shift - offset:]
        out /= shifts
        return out
    else:
        # Pad to a length the model can process, run it without autograd, and
        # crop the output back to `length` samples.
        valid_length = model.valid_length(length)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(valid_length)
        with th.no_grad():
            out = model(padded_mix.unsqueeze(0))[0]
        return center_trim(out, length)
|
|
|
|
@contextmanager
def temp_filenames(count, delete=True):
    """
    Context manager yielding the paths of `count` freshly created, empty temporary files.

    Args:
        count (int): number of temporary files to create.
        delete (bool): if True, the files are removed on exit, even when the body
            raised. Files the caller already removed are silently skipped.

    Yields:
        list[str]: the file paths.
    """
    names = []
    try:
        for _ in range(count):
            # mkstemp creates the file and returns an open descriptor; close it
            # right away so only the name is kept and no handle is leaked.
            fd, name = tempfile.mkstemp()
            os.close(fd)
            names.append(name)
        yield names
    finally:
        if delete:
            for name in names:
                try:
                    os.unlink(name)
                except FileNotFoundError:
                    # Already gone (e.g. removed by the caller): nothing to clean,
                    # and raising here could mask an exception from the body.
                    pass
|
|
|
|
def get_quantizer(model, args, optimizer=None):
    """
    Build the quantizer selected by the training arguments, if any.

    Args:
        model: model to quantize.
        args: namespace read for `diffq`, `qat` and `q_min_size`.
        optimizer: if given and DiffQ is selected, passed to
            `quantizer.setup_optimizer`.

    Returns:
        A `DiffQuantizer` (group_size=8) when `args.diffq` is truthy, otherwise a
        `UniformQuantizer` with `bits=args.qat` when `args.qat` is truthy,
        otherwise None.
    """
    if args.diffq:
        quantizer = DiffQuantizer(model, min_size=args.q_min_size, group_size=8)
        if optimizer is not None:
            quantizer.setup_optimizer(optimizer)
        return quantizer
    if args.qat:
        return UniformQuantizer(model, bits=args.qat, min_size=args.q_min_size)
    return None
|
|
|
|
def load_model(path, strict=False):
    """
    Rebuild a model from a package in the format written by `save_model`.

    The package is a dict with keys "klass", "args", "kwargs", "state" and
    "training_args". The model is built as `klass(*args, **kwargs)`, a quantizer
    is recreated from "training_args" with `get_quantizer`, and the weights are
    restored with `set_state`.

    Args:
        path: file passed to `torch.load` (tensors mapped to CPU). Warnings
            emitted while loading are suppressed.
        strict (bool): if False, keyword arguments not present in `klass`'s
            signature are dropped with a warning instead of being passed on.

    Returns:
        the instantiated model with its state loaded.
    """
    # NOTE(review): `torch.load` unpickles arbitrary objects (the package stores
    # a class), so only load files from trusted sources.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        package = th.load(path, 'cpu')

    klass = package["klass"]
    args = package["args"]
    kwargs = package["kwargs"]
    if not strict:
        accepted = inspect.signature(klass).parameters
        for name in [key for key in kwargs if key not in accepted]:
            warnings.warn("Dropping inexistant parameter " + name)
            del kwargs[name]
    model = klass(*args, **kwargs)

    state = package["state"]
    quantizer = get_quantizer(model, package["training_args"])
    set_state(model, quantizer, state)
    return model
|
|
|
|
def get_state(model, quantizer):
    """
    Snapshot the model weights in a form suitable for saving.

    Args:
        model: module whose `state_dict` is copied when there is no quantizer.
        quantizer: quantizer or None.

    Returns:
        dict: without a quantizer, the state dict with every entry moved to CPU.
        With a quantizer, `{'compressed': bytes}` where the bytes are the
        zlib-compressed `torch.save` serialization of
        `quantizer.get_quantized_state()`.
    """
    if quantizer is not None:
        serialized = io.BytesIO()
        th.save(quantizer.get_quantized_state(), serialized)
        return {'compressed': zlib.compress(serialized.getvalue())}
    return {name: value.data.to('cpu') for name, value in model.state_dict().items()}
|
|
|
|
def set_state(model, quantizer, state):
    """
    Restore weights produced by `get_state`.

    Args:
        model: module receiving `state` via `load_state_dict` when there is no quantizer.
        quantizer: quantizer or None. When given, `state["compressed"]` is
            decompressed with zlib, deserialized with `torch.load` on CPU and
            handed to `quantizer.restore_quantized_state`.
        state (dict): value returned by `get_state`.

    Returns:
        the state that was applied (the decompressed one when a quantizer is used).
    """
    if quantizer is not None:
        raw = zlib.decompress(state["compressed"])
        state = th.load(io.BytesIO(raw), "cpu")
        quantizer.restore_quantized_state(state)
    else:
        model.load_state_dict(state)
    return state
|
|
|
|
def save_state(state, path):
    """
    Serialize `state` with `torch.save` and write it next to `path`, tagged with a hash.

    The first 8 hex digits of the SHA-256 of the serialized bytes are appended
    to the file stem, e.g. `model.th` -> `model-1a2b3c4d.th`.

    Args:
        state: any object `torch.save` can serialize.
        path (str or pathlib.Path): template path; only its parent directory,
            stem and suffix are used.

    Returns:
        pathlib.Path: the path actually written (callers cannot otherwise know
        the hash-suffixed name).
    """
    from pathlib import Path

    buf = io.BytesIO()
    th.save(state, buf)
    payload = buf.getvalue()  # fetch once: getvalue() copies the whole buffer
    sig = hashlib.sha256(payload).hexdigest()[:8]

    path = Path(path)
    path = path.parent / (path.stem + "-" + sig + path.suffix)
    path.write_bytes(payload)
    return path
|
|
|
|
def save_model(model, quantizer, training_args, path):
    """
    Save everything needed to rebuild `model` with `load_model`.

    The package written with `torch.save` holds the model class, the
    constructor arguments recorded by `capture_init` (`model._init_args_kwargs`),
    the weights from `get_state(model, quantizer)` and `training_args`.

    Args:
        model: model whose `__init__` was decorated with `capture_init`.
        quantizer: quantizer or None, forwarded to `get_state`.
        training_args: stored as is (read back by `load_model` to rebuild the quantizer).
        path: destination passed to `torch.save`.
    """
    init_args, init_kwargs = model._init_args_kwargs
    th.save({
        'klass': model.__class__,
        'args': init_args,
        'kwargs': init_kwargs,
        'state': get_state(model, quantizer),
        'training_args': training_args,
    }, path)
|
|
|
|
def capture_init(init):
    """
    Decorate an `__init__` so the instance remembers how it was constructed.

    Before calling `init`, the positional and keyword arguments are stored on
    the instance as `self._init_args_kwargs = (args, kwargs)`, which
    `save_model` reads to rebuild the object later.
    """
    @functools.wraps(init)
    def wrapper(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return wrapper
|
|