Distill Whisper with a VQ bottleneck

from whisperspeech import wh_transcribe
import IPython

Prepare the dataset

shards = [str(x) for x in Path('/data/whisperspeech-wds/').glob('librilight-*.tar')]
ds = wds.WebDataset(shards, shardshuffle=True)
ds2 = ds.compose(
    wds.decode(wds.torch_audio),
    utils.find_audio,
    # merge in the VAD segmentation computed in a separate preprocessing pass
    merge_in(derived_dataset('/data/whisperspeech-processed-wds/', 'vad')),
    # merge the VAD segments and split each file into training chunks
    wds.map_dict(**{"vad.npy":wh_transcribe.chunk_merger}),
    wh_transcribe.split_to_chunks,
    # merge in the base.en transcripts for every chunk
    merge_in(derived_dataset('/data/whisperspeech-processed-wds/', 'base.en-txt')),
    wds.shuffle(),
    # drop the first and last chunk of each file (dataset cleaning)
    wds.select(lambda x: x['i'] != 0 and x['i'] != x['imax']),
)
vad_shards = [str(x) for x in Path('/data/whisperspeech-processed-wds/').glob('*-large-6454-vad-*.tar.gz')]
ds = wds.WebDataset(vad_shards).decode().map_dict(**{'vad.npy':wh_transcribe.chunk_merger})
chunks = [len(x['vad.npy'][1:-1]) for x in progress_bar(ds, total='noinfer')]
100.00% [3411/3411 00:01<00:00]
sum(chunks)
203078
for x in progress_bar(ds2, total=5):
    IPython.display.display(IPython.display.Markdown(f"## {x['__key__']} from {x['__url__']}\n{x['txt']}"))
    IPython.display.display(IPython.display.Audio(x['samples'], rate=16000))
100.00% [5/5 00:01<00:00]

large/6454/kaffirkangarooklondiketales_1611_librivox_64kb_mp3/kaffirkangaroo_03_leavitt_64kb_006 from /data/whisperspeech-wds/librilight-large-6454-flac-000007.tar

Physically I was incapable of complying with the command, and mentally I had not the slightest intention of departing. In an outhouse devoted to storing melees, sheepskins, and harness, an old man was sitting on the doorstep, compounding a mixture which I recognized as a sheep remedy.

large/6454/kaffirkangarooklondiketales_1611_librivox_64kb_mp3/kaffirkangaroo_03_leavitt_64kb_009 from /data/whisperspeech-wds/librilight-large-6454-flac-000007.tar

The following day I was the most surprised man in South Africa when I learned that my preparation was working a marvelous cure. I was invited to remain with the bore the balance of the season as an honoured guest. Day after day I tramped the hills, returning at night as wise and as rich as when I set out. There were unmistakable indications that gold should be found in the vicinity, but the stubborn fact remained that I could not find it.

large/6454/kaffirkangarooklondiketales_1611_librivox_64kb_mp3/kaffirkangaroo_03_leavitt_64kb_001 from /data/whisperspeech-wds/librilight-large-6454-flac-000007.tar

I was one of the first prospectors in the Transvaal to search for gold and a precious dance it led me. At that time, but few Englishmen had ventured into the Boer country, and such was the jealousy with which they were regarded that it was impossible to secure any information which would assist in the search. Footsoir and weary, I tramped from farm to farm, content

large/6454/kaffirkangarooklondiketales_1611_librivox_64kb_mp3/kaffirkangaroo_03_leavitt_64kb_032 from /data/whisperspeech-wds/librilight-large-6454-flac-000007.tar

Dead, more than twenty years. In fact, before I was married and came to live here, for he was my husband’s father. Did you know him? Yes, but I was only a little girl at the time. Why have the clothes been kept?

large/6454/kaffirkangarooklondiketales_1611_librivox_64kb_mp3/kaffirkangaroo_03_leavitt_64kb_004 from /data/whisperspeech-wds/librilight-large-6454-flac-000007.tar

Fortunately, I had acquired some knowledge of sheep in Australia else I believe that I should have starved. When all else failed, I became a sheep doctor and then did a compound whose virtues would have done credit to the most widely advertised path and medicine nostrum.

ds3 = ds2.compose(
    add_masks,
    tokenize_text,
    wds.to_tuple('samples', 'mask', 'in_ttoks', 'out_ttoks')
)
for x in ds3: break
x
(tensor([0.0043, 0.0102, 0.0163,  ..., 0.0000, 0.0000, 0.0000]),
 tensor([ True,  True,  True,  ..., False, False, False]),
 tensor([50257,  3152,   257, 44823,  3154,  1589,    11,   484,   673,  1144,
           572,   503,   286,  2837,   290,   706,  2063,   281,  1711,   338,
          1057,    11,   262, 39535, 21067,   373,   625,   262,  2318,   290,
           287,  5897, 10150,    13,  1119,  2582, 40424,   510,   262, 27913,
          4608,   284, 47251,   290,  1043,   257,  1588,  1426,   325,   286,
          4684, 13384,  3492,   284, 17655,   511, 15892,    13, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256]),
 tensor([ 3152,   257, 44823,  3154,  1589,    11,   484,   673,  1144,   572,
           503,   286,  2837,   290,   706,  2063,   281,  1711,   338,  1057,
            11,   262, 39535, 21067,   373,   625,   262,  2318,   290,   287,
          5897, 10150,    13,  1119,  2582, 40424,   510,   262, 27913,  4608,
           284, 47251,   290,  1043,   257,  1588,  1426,   325,   286,  4684,
         13384,  3492,   284, 17655,   511, 15892,    13, 50256,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100]))
ds3 = ds2.compose(
    add_masks,
    lambda x: tokenize_text(x, model='medium', language='en'),
    wds.to_tuple('samples', 'mask', 'in_ttoks', 'out_ttoks')
)
for x in ds3: break
x
(tensor([0.0013, 0.0010, 0.0011,  ..., 0.0000, 0.0000, 0.0000]),
 tensor([ True,  True,  True,  ..., False, False, False]),
 tensor([50258, 50259, 50359,    32,  1326,  1270,  3931,   382,   613,    11,
         11672,   293, 37632, 13809,    11,   576,  1319,   264,  1851,   295,
           264,  1002,    11,   293,  1939,   576,   572,   544,  1643,   281,
         18071,   264,  1164,   295,  3687,    11,   420,  1497,   554,  1952,
          6018,    11,   813,   264,  1974,  5010,   295,   721,    11,   689,
           264,  7700,   366,  4054,   293,  7006,   293, 14154,   292,    13,
          2188,  1359, 17431,  2212,   281,  3511,   328,  3780,   311,  3567,
           294,   702,  1536,  6717,  1062,   362, 16424,   796,   666,   257,
          5403, 14763,    13, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257]),
 tensor([50259, 50359,    32,  1326,  1270,  3931,   382,   613,    11, 11672,
           293, 37632, 13809,    11,   576,  1319,   264,  1851,   295,   264,
          1002,    11,   293,  1939,   576,   572,   544,  1643,   281, 18071,
           264,  1164,   295,  3687,    11,   420,  1497,   554,  1952,  6018,
            11,   813,   264,  1974,  5010,   295,   721,    11,   689,   264,
          7700,   366,  4054,   293,  7006,   293, 14154,   292,    13,  2188,
          1359, 17431,  2212,   281,  3511,   328,  3780,   311,  3567,   294,
           702,  1536,  6717,  1062,   362, 16424,   796,   666,   257,  5403,
         14763,    13, 50257,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100]))
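
The special tokens in the two dumps above come straight from the Whisper tokenizers: the English-only run starts with <|startoftranscript|> (50257) and pads with the GPT-2 <|endoftext|> token (50256), while the multilingual medium run starts with the <|startoftranscript|><|en|><|transcribe|> prefix (50258, 50259, 50359) and pads with its own <|endoftext|> (50257). A quick cross-check, assuming the openai-whisper tokenizer API:

from whisper.tokenizer import get_tokenizer

multi = get_tokenizer(multilingual=True, language='en', task='transcribe')
print(multi.sot_sequence)  # (50258, 50259, 50359) – the prefix in the medium run above
print(multi.eot)           # 50257 – the padding token in the medium run

en = get_tokenizer(multilingual=False)
print(en.sot)              # 50257 – the first token in the English-only run
print(en.eot)              # 50256 – the padding token in the English-only run
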
train_ds = load_dataset('librilight-wds/librilight-small-flac-000000-s0*.tar', 'librilight-preproc-wds/', samples=2500 * 32)
val_ds = load_dataset('librilight-wds/librilight-small-flac-000000-s11.tar', 'librilight-preproc-wds/', samples=500)
for x in progress_bar(wds.WebLoader(train_ds, num_workers=16, batch_size=None), total='noinfer'): pass
[245/? 00:09<?]
KeyboardInterrupt
for x in train_ds:
    print(x[3])
    break
tensor([[  464,  7664,   286,  ...,  -100,  -100,  -100],
        [ 2953,   717,   612,  ...,  -100,  -100,  -100],
        [25383,   339,   587,  ...,  -100,  -100,  -100],
        ...,
        [  392,   340,   880,  ...,  -100,  -100,  -100],
        [  464, 31526, 11416,  ...,  -100,  -100,  -100],
        [ 2202,   262, 16720,  ...,  -100,  -100,  -100]])
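
In the batch above, the input tokens are padded with the end-of-text token while the target tokens (x[3]) are padded with -100, PyTorch's default ignore_index for cross-entropy, so the padded positions can be excluded from the loss. A minimal illustration; the logits and vocabulary size below are made up:

import torch
import torch.nn.functional as F

vocab = 51865                                        # hypothetical vocabulary size
logits = torch.randn(5, vocab)                       # per-position logits
targets = torch.tensor([264, 1851, 13, -100, -100])  # real tokens followed by padding
loss = F.cross_entropy(logits, targets)              # -100 positions are ignored by default
print(loss)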

Training code


source

RQBottleneckTransformer

 RQBottleneckTransformer (vq_codes=512, q_depth=12, depth=1, n_head=2,
                          head_width=64, ffn_mult=4, codebook_dim=2,
                          threshold_ema_dead_code=2, use_cosine_sim=False,
                          kl_loss_mul=1, downsample=1,
                          whisper_model_name='tiny.en',
                          tunables=Tunables(init_std=1.5,
                          embeddings_std=0.045, embeddings_lr_scale=1,
                          output_mult=1, query_mult=2, rope=True,
                          mask_embs=True, downsample_conv=False,
                          downsample_mean=True, codebook_dim=32,
                          codebook_decay=0.9, lr0=0.0009,
                          clip_gradient_norm=2, weight_decay=0.001,
                          warmup_steps=850, random=False))

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing you to nest them in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call .to(), etc.

Note: as per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): whether this module is in training or evaluation mode.
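
To make the signature above concrete, here is a minimal instantiation sketch mirroring the configurations trained below; only the constructor arguments documented above and the standard nn.Module parameters() API are assumed:

# construct the VQ bottleneck on top of Whisper base.en and count its parameters
vqmodel = RQBottleneckTransformer(codebook_dim=32, vq_codes=4096, q_depth=1, n_head=8, depth=1,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
                                  whisper_model_name="base.en")
print(f"{sum(p.numel() for p in vqmodel.parameters())/1e6:.1f}M parameters")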

import math
import time
import pylab as plt
import fastprogress
import IPython
import numpy as np
import torch

class RQVisual:
    def __init__(self, model, masterbar, total_steps):
        self.model = model
        self.masterbar = masterbar
        self.total_steps = total_steps
        self.epochs = total_steps // masterbar.main_bar.total
        
        gs = plt.GridSpec(3, 1, height_ratios=[2,2,1])
        graph_fig = plt.figure(figsize=(10,6))
        self.graph_fig = graph_fig
        self.loss_p = graph_fig.add_subplot(gs[0])
        self.acc_p = graph_fig.add_subplot(gs[1], sharex=self.loss_p)
        self.acc_p.tick_params('x', labelbottom=False)
        self.lr_p = graph_fig.add_subplot(gs[2], sharex=self.loss_p)
        self.lr_p.tick_params('x', labelbottom=False)
        self.graph_out = None
        
        self.its = []
        self.train_losses = []
        self.val_losses = []
        self.lr_history = []
        self.entropy = np.nan
        self.entropy_history = []
            
    def show(self):
        self.start_t = time.time()
        self.masterbar.write(["samples", "train", "val", "codebook entropy", "time"], table=True)
        self.graph_out = display(self.graph_fig, display_id=True)
        self.entropy_out = display(IPython.display.HTML(''), display_id=True)
    
    def hide(self):
        if self.graph_out is not None:
            self.graph_out.update(IPython.display.HTML(''))
    
    def plot(self):
        loss_p, acc_p, lr_p = self.loss_p, self.acc_p, self.lr_p
        loss_p.clear()
        loss_p.plot(self.its, self.train_losses)
        loss_p.plot(self.its, self.val_losses)
        loss_p.set_xlim(10000, self.total_steps)
        loss_p.set_xscale('log')
        loss_p.set_yscale('log')
        acc_p.clear()
        acc_p.plot(self.its, np.stack(self.entropy_history), ':')
        lr_p.clear()
        lrs = np.array(self.lr_history)
        lr_p.plot(self.its, lrs)
        self.graph_out.update(self.graph_fig)
    
    def add_data(self, it, lr, train_loss, val_loss):
        self.its.append(it)
        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)
        self.lr_history.append(lr)
        with torch.no_grad():
            # Shannon entropy (in bits) of the codebook usage distribution
            cls = self.model.rq.layers[0]._codebook.cluster_size
            pdf = cls / cls.sum()
            entropy = -torch.nansum(pdf * pdf.log2())
        self.entropy_history.append(entropy.cpu().numpy())
        self.entropy_out.update(f"Entropy: {self.entropy_history[-1]:.2f}")
        self.model.reset_stats()
        self.plot()

    def add_table_row(self, it, avg_train_loss, val_loss):
        elapsed_t = time.time() - self.start_t
        self.masterbar.write([it, f"{avg_train_loss:.5f}", f"{val_loss:.5f}", f"{self.entropy_history[-1]:.2f}", fastprogress.core.format_time(elapsed_t)], table=True)
    
    def on_iter(self, bar, it, avg_train_loss, val_loss):
        epoch = math.ceil(it / self.total_steps * self.epochs)
        bar.comment = f"#{epoch}/{self.epochs} loss: {avg_train_loss:.3f} / {val_loss:.3f}"
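
The codebook entropy reported by RQVisual is the Shannon entropy (in bits) of the codebook's cluster-size distribution (the same quantity add_data computes), so uniform codebook usage tops out at log2(vq_codes): 9 bits for 512 codes, 12 bits for 4096. The entropy columns in the experiment tables below can be read against those upper bounds. A standalone sketch of the same computation:

import torch

def codebook_entropy(cluster_size: torch.Tensor) -> float:
    "Shannon entropy (bits) of the codebook usage distribution."
    pdf = cluster_size / cluster_size.sum()
    return float(-torch.nansum(pdf * pdf.log2()))

print(codebook_entropy(torch.ones(512)))   # 9.0 – uniform usage of 512 codes
print(codebook_entropy(torch.ones(4096)))  # 12.0 – uniform usage of 4096 codes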

source

make_model

 make_model (size:str, tunables:__main__.Tunables=Tunables(init_std=1.5,
             embeddings_std=0.045, embeddings_lr_scale=1, output_mult=1,
             query_mult=2, rope=True, mask_embs=True,
             downsample_conv=False, downsample_mean=True, codebook_dim=32,
             codebook_decay=0.9, lr0=0.0009, clip_gradient_norm=2,
             weight_decay=0.001, warmup_steps=850, random=False),
             dataset:torch.utils.data.dataset.Dataset=None)
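
A quick usage sketch: since Tunables appears to be a dataclass with the defaults shown above, individual fields can be overridden when building a model (the size string below follows the pattern used in the conversion cells that follow):

tun = Tunables(lr0=1e-3, warmup_steps=500)   # override two fields, keep the other defaults
model = make_model('base.en-2d-512c-dim64', tunables=tun, dataset=train_ds)
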
# convert the final checkpoint
model = make_model('base.en-2d-512c-dim64').load_checkpoint('vq_stoks-epoch=3-step=28582-val_loss=11.42.ckpt')
model.save_model(f'vqmodel-512c-dim64-4e-hyptuned-32gpu.model')
tunables: Tunables(init_std=1.5, embeddings_std=0.045, embeddings_lr_scale=1, output_mult=1, query_mult=2, rope=True, mask_embs=True, downsample_conv=False, downsample_mean=True, codebook_dim=32, codebook_decay=0.9, lr0=0.0009, clip_gradient_norm=2, weight_decay=0.001, warmup_steps=850, random=False)
Tunables(init_std=1.5, embeddings_std=0.045, embeddings_lr_scale=1, output_mult=1, query_mult=2, rope=True, mask_embs=True, downsample_conv=False, downsample_mean=True, codebook_dim=32, codebook_decay=0.9, lr0=0.0009, clip_gradient_norm=2, weight_decay=0.001, warmup_steps=850, random=False)
# convert the final checkpoint
model = make_model('medium-2d-512c-dim64').load_checkpoint('../vq_stoks-epoch=0-step=9776-val_loss=0.00.ckpt')
model.save_model(f'vqmodel-medium-en+pl-512c-dim64.model')
# convert the final checkpoint
model = make_model('base-2d-512c-dim64').load_checkpoint('../vq_stoks--2-24696-acc=0.91.ckpt')
model.save_model(f'vqmodel-base-en+pl-512c-dim64.model')
# convert the final checkpoint
model = make_model('medium-2d-1024c-dim64').load_checkpoint('../vq_stoks-chad_gold-3-32813-acc=0.96.ckpt')
model.save_model(f'vqmodel-medium-v2-en+pl-1024c-dim64.model')

Architectural experiments

# with learned positional embeddings, no out_blocks
vqmodel = RQBottleneckTransformer(codebook_dim=16, vq_codes=512, q_depth=1, n_head=6, depth=1,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True).cuda()
train("svq", vqmodel, train_ds, val_ds, bs=32, epochs=1, lr=3e-3, warmup_steps=1000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
OneCycle: 6290 1
'Entropy: 8.71'
0.00% [0/1 00:00<?]
samples train val codebook entropy time
50016 107.56952 157.32113 8.71 05:24
100000 85.44750 101.79171 8.70 10:37
126688 81.44776 104.25017 8.71 13:27

62.94% [3959/6290 13:26<07:54 #126688/201280 loss: 81.448 / 104.250]

# with learned positional embeddings, out_blocks before positional
vqmodel = RQBottleneckTransformer(codebook_dim=16, vq_codes=512, q_depth=1, n_head=6, depth=1,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True).cuda()
train("svq", vqmodel, train_ds, val_ds, bs=32, epochs=1, lr=3e-3, warmup_steps=1000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
OneCycle: 6290 1
'Entropy: 8.70'
100.00% [1/1 22:57<00:00]
samples train val codebook entropy time
50016 23.45991 42.24113 8.80 05:48
100000 16.19686 23.67809 8.78 11:27
150016 11.99028 17.22306 8.74 17:07
200000 11.68037 16.67605 8.70 22:46
201280 11.92631 16.65236 8.70 22:57

100.00% [6290/6290 22:57<00:00 #201280/201280 loss: 11.926 / 16.652]

# with learned positional embeddings, out_blocks before positional, mlp before vq
vqmodel = RQBottleneckTransformer(codebook_dim=16, vq_codes=512, q_depth=1, n_head=6, depth=1,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True).cuda()
train("svq", vqmodel, train_ds, val_ds, bs=32, epochs=1, lr=3e-3, warmup_steps=1000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
OneCycle: 6290 1
'Entropy: 8.57'
100.00% [1/1 23:09<00:00]
samples train val codebook entropy time
50016 24.63220 44.67238 8.74 05:53
100000 14.69983 19.67298 8.67 11:35
150016 11.50774 17.75203 8.58 17:16
200000 11.33895 15.66892 8.55 22:58
201280 10.87422 15.81362 8.57 23:09

100.00% [6290/6290 23:08<00:00 #201280/201280 loss: 10.874 / 15.814]

# with learned positional embeddings, out_blocks after positional, mlp before vq
vqmodel = RQBottleneckTransformer(codebook_dim=16, vq_codes=512, q_depth=1, n_head=6, depth=1,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True).cuda()
train("svq", vqmodel, train_ds, val_ds, bs=32, epochs=1, lr=3e-3, warmup_steps=1000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
OneCycle: 6290 1
'Entropy: 8.54'
100.00% [1/1 23:11<00:00]
samples train val codebook entropy time
50016 18.37899 27.54997 8.65 05:53
100000 13.13329 17.32240 8.60 11:35
150016 10.83435 13.55371 8.56 17:18
200000 9.69492 12.35855 8.51 23:00
201280 10.54271 12.43994 8.54 23:11

100.00% [6290/6290 23:11<00:00 #201280/201280 loss: 10.543 / 12.440]

# with learned positional embeddings, out_blocks after positional, mlp before vq
vqmodel = RQBottleneckTransformer(codebook_dim=16, vq_codes=512, q_depth=1, n_head=6, depth=1,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True).cuda()
train("svq", vqmodel, train_ds, val_ds, bs=32, epochs=5, lr=3e-3, warmup_steps=2000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
vqmodel.save_model('vq-2d-512c-cosine-padfix-premlp-learnpos-5e.model')
OneCycle: 6290 5
'Entropy: 8.40'
100.00% [5/5 1:55:58<00:00]
samples train val codebook entropy time
50016 24.24790 47.61960 8.62 05:53
100000 14.35983 18.50102 8.55 11:35
150016 12.35634 16.84217 8.56 17:18
200000 11.74603 16.10603 8.52 23:00
250016 10.85323 14.83014 8.49 28:56
300000 10.78046 14.04290 8.47 34:38
350016 10.05354 12.98133 8.40 40:21
400000 9.59631 13.78049 8.50 46:03
450016 9.22316 12.76403 8.40 51:57
500000 9.38958 11.96084 8.46 57:40
550016 8.36034 12.59843 8.35 1:03:22
600000 9.39242 11.55411 8.43 1:09:05
650016 8.30749 10.80241 8.42 1:15:02
700000 8.20436 10.39852 8.48 1:20:45
750016 8.21392 10.36367 8.41 1:26:27
800000 7.73189 11.21438 8.48 1:32:10
850016 7.64852 10.93893 8.47 1:38:06
900000 7.72010 10.49391 8.39 1:43:48
950016 7.58901 9.85925 8.42 1:49:31
1000000 7.14871 10.67987 8.40 1:55:14
1006400 6.73056 10.67323 8.40 1:55:58

100.00% [6290/6290 23:12<00:00 #201280/201280 loss: 6.731 / 10.673]

# with learned positional embeddings, out_blocks after positional, mlp before vq
vqmodel = RQBottleneckTransformer(codebook_dim=32, vq_codes=4096, q_depth=1, n_head=6, depth=1,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True).cuda()
train("svq", vqmodel, train_ds, val_ds, bs=32, epochs=5, lr=3e-3, warmup_steps=2000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
vqmodel.save_model('vq-2d-4096c-cosine32-padfix-premlp-learnpos-5e.model')
OneCycle: 6290 5
'Entropy: 11.07'
100.00% [5/5 1:57:58<00:00]
samples train val codebook entropy time
50016 15.49718 26.42581 11.23 06:00
100000 11.36006 14.78076 11.25 11:48
150016 10.29752 13.68974 11.19 17:36
200000 9.22019 12.14817 11.26 23:24
250016 9.09067 13.16928 11.17 29:26
300000 8.56113 12.38342 11.13 35:14
350016 8.30965 12.02589 11.15 41:02
400000 7.76135 10.97900 11.14 46:50
450016 7.34585 10.10667 11.11 52:53
500000 7.65255 11.02440 11.10 58:41
550016 7.47726 10.73619 11.10 1:04:29
600000 6.96974 9.63206 11.14 1:10:17
650016 6.93395 9.97940 11.08 1:16:19
700000 6.64507 8.91945 11.13 1:22:07
750016 6.53036 9.27800 11.01 1:27:55
800000 6.50427 8.30845 11.07 1:33:44
850016 6.51113 9.09502 11.12 1:39:48
900000 6.05660 8.44461 10.99 1:45:36
950016 6.20974 8.88156 11.06 1:51:25
1000000 5.95045 8.69922 11.08 1:57:13
1006400 6.18939 8.88604 11.07 1:57:58

100.00% [6290/6290 23:37<00:00 #201280/201280 loss: 6.189 / 8.886]

# base.en Whisper with learned positional embeddings, out_blocks after positional, mlp before vq
vqmodel = RQBottleneckTransformer(codebook_dim=32, vq_codes=4096, q_depth=1, n_head=8, depth=1,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True, whisper_model_name="base.en").cuda()
train("svq", vqmodel, train_ds, val_ds, bs=14, epochs=5, lr=3e-3, warmup_steps=2000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
vqmodel.save_model('vq-base.en-2d-4096c-cosine32-padfix-premlp-learnpos-5e.model')
OneCycle: 6280 5
'Entropy: 10.86'
100.00% [5/5 3:05:51<00:00]
samples train val codebook entropy time
50016 18.17899 27.83681 11.11 09:23
100000 13.50658 17.32206 11.06 18:34
150016 12.10491 15.49411 11.08 27:47
200000 11.84169 15.30570 10.95 36:58
250016 11.19514 14.05272 10.99 46:23
300000 10.98578 13.69234 10.86 55:34
350016 10.58517 13.25610 10.99 1:04:46
400000 9.87159 12.88844 10.91 1:13:57
450016 9.76353 12.50161 10.92 1:23:22
500000 10.08099 12.71940 10.94 1:32:33
550016 9.85388 12.70232 10.89 1:41:45
600000 10.50843 11.94505 10.93 1:50:57
650016 9.29321 12.16166 10.96 2:00:20
700000 9.24717 11.35387 10.93 2:09:32
750016 8.80798 11.78821 10.95 2:18:43
800000 9.14499 10.97496 10.93 2:27:55
850016 8.75328 11.08632 10.96 2:37:21
900000 8.40084 10.79851 10.88 2:46:33
950016 8.73481 11.27116 10.96 2:55:45
1000000 8.55846 11.28967 10.86 3:04:57
1004800 8.09170 11.12924 10.86 3:05:51

100.00% [6280/6280 37:12<00:00 #200960/200960 loss: 8.092 / 11.129]

# base.en whisper with learned positional embeddings, out_blocks after positional, mlp before vq
# cleaned dataset (removed 1st and last segments)
vqmodel = RQBottleneckTransformer(codebook_dim=32, vq_codes=4096, q_depth=1, n_head=8, depth=1,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True, whisper_model_name="base.en").cuda()
train("svq", vqmodel, train_ds, val_ds, bs=14, epochs=5, lr=3e-3, warmup_steps=2000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
vqmodel.save_model('vq-base.en-2d-4096c-cosine32-padfix-premlp-learnpos-5e-cleaned.model')
OneCycle: 6132 5
'Entropy: 10.79'
100.00% [5/5 3:09:42<00:00]
samples train val codebook entropy time
50016 19.44056 22.67257 11.13 09:46
100000 13.55178 14.58443 11.26 19:23
150016 11.96837 13.18968 11.09 29:00
200000 11.43871 12.44640 11.05 38:49
250016 11.28360 11.70081 11.10 48:26
300000 10.83751 11.31110 11.09 58:03
350016 10.69315 11.17086 11.12 1:07:40
400000 9.98770 10.92539 11.05 1:17:30
450016 9.83174 10.69181 11.05 1:27:07
500000 9.77236 10.48352 11.14 1:36:44
550016 9.66632 10.36597 11.09 1:46:21
600000 9.40930 10.08656 11.02 1:56:09
650016 9.44357 9.92484 11.04 2:05:46
700000 8.96556 9.79054 11.06 2:15:23
750016 8.83601 9.65099 11.01 2:25:00
800000 8.66107 9.39148 11.12 2:34:48
850016 8.44581 9.40969 11.00 2:44:26
900000 8.56439 9.22455 11.05 2:54:03
950016 8.52489 9.30351 11.03 3:03:40
981120 8.84632 9.33108 10.79 3:09:42

100.00% [6132/6132 37:57<00:00 #196224/196224 loss: 8.846 / 9.331]

# base.en whisper with learned positional embeddings, out_blocks after positional, mlp before vq
# cleaned dataset
vqmodel = RQBottleneckTransformer(codebook_dim=32, vq_codes=1024, q_depth=1, n_head=8, depth=1,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True, whisper_model_name="base.en").cuda()
train("svq", vqmodel, train_ds, val_ds, bs=12, epochs=5, lr=3e-3, warmup_steps=2000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
vqmodel.save_model('vq-base.en-2d-1024c-cosine32-padfix-premlp-learnpos-5e-cleaned.model')
OneCycle: 6132 5
'Entropy: 9.36'
100.00% [5/5 3:08:14<00:00]
samples train val codebook entropy time
50016 21.66206 27.27091 9.59 09:41
100000 15.25066 16.20915 9.53 19:13
150016 13.21848 14.25581 9.54 28:45
200000 11.82871 13.98582 9.49 38:30
250016 11.85884 13.12596 9.42 48:02
300000 11.54107 12.60187 9.43 57:34
350016 11.45310 12.29700 9.46 1:07:07
400000 11.08207 11.98462 9.38 1:16:51
450016 10.65160 11.61482 9.44 1:26:24
500000 10.69448 11.57619 9.34 1:35:56
550016 10.25768 11.15084 9.38 1:45:29
600000 9.86860 10.86430 9.48 1:55:14
650016 9.90988 10.71315 9.44 2:04:47
700000 9.53233 10.52028 9.42 2:14:19
750016 9.89578 10.26827 9.36 2:23:52
800000 9.15078 10.15152 9.42 2:33:36
850016 9.16481 9.96554 9.34 2:43:09
900000 9.14512 9.90501 9.40 2:52:42
950016 9.18524 9.92719 9.36 3:02:15
981120 8.97033 9.95517 9.36 3:08:14

100.00% [6132/6132 37:41<00:00 #196224/196224 loss: 8.970 / 9.955]

# base.en whisper with learned positional embeddings, out_blocks after positional, mlp before vq
# cleaned dataset
vqmodel = RQBottleneckTransformer(codebook_dim=32, vq_codes=64, q_depth=1, n_head=8, depth=1,
                                  downsample=1, threshold_ema_dead_code=0, use_cosine_sim=True, whisper_model_name="base.en").cuda()
train("svq", vqmodel, train_ds, val_ds, bs=14, epochs=5, lr=3e-3, warmup_steps=2000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
vqmodel.save_model('vq-base.en-64c-cosine32-padfix-premlp-learnpos-5e-cleaned.model')
OneCycle: 6132 5
'Entropy: 5.64'
100.00% [5/5 3:09:51<00:00]
samples train val codebook entropy time
50016 76.17780 192.67165 5.82 09:48
100000 27.85803 31.11143 5.71 19:25
150016 19.38920 22.02595 5.75 29:02
200000 16.75521 18.75611 5.68 38:51
250016 16.22832 17.68415 5.60 48:29
300000 15.28871 16.20028 5.68 58:06
350016 14.91663 16.24565 5.63 1:07:43
400000 14.08824 15.30097 5.64 1:17:32
450016 13.53690 15.08575 5.61 1:27:10
500000 13.62558 14.45319 5.65 1:36:47
550016 12.45450 13.74045 5.66 1:46:25
600000 12.25172 14.05763 5.68 1:56:14
650016 12.76195 13.71730 5.69 2:05:51
700000 12.19483 13.02070 5.61 2:15:28
750016 11.83110 12.79714 5.62 2:25:06
800000 12.23673 12.70706 5.73 2:34:56
850016 11.69901 12.50606 5.64 2:44:34
900000 12.03180 12.29434 5.71 2:54:11
950016 12.06521 12.22985 5.67 3:03:49
981120 13.17802 12.70389 5.64 3:09:51

100.00% [6132/6132 38:00<00:00 #196224/196224 loss: 13.178 / 12.704]

# base.en! with learned positional embeddings, out_blocks after positional, mlp before vq
# cleaned dataset
vqmodel = RQBottleneckTransformer(codebook_dim=32, vq_codes=512, q_depth=1, n_head=8, depth=1,
                                  downsample=1, threshold_ema_dead_code=0, use_cosine_sim=True, whisper_model_name="base.en").cuda()
train("svq", vqmodel, train_ds, val_ds, bs=12, epochs=5, lr=3e-3, warmup_steps=2000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
vqmodel.save_model('vq-base.en-512c-cosine32-padfix-premlp-learnpos-5e-cleaned.model')
OneCycle: 6132 5
'Entropy: 8.44'
100.00% [5/5 3:10:13<00:00]
samples train val codebook entropy time
50016 21.94018 27.54010 8.70 09:48
100000 15.30265 16.38729 8.72 19:26
150016 13.55491 14.22489 8.67 29:04
200000 12.27958 13.59388 8.53 38:54
250016 11.48394 12.79483 8.59 48:33
300000 11.45791 12.34518 8.52 58:11
350016 11.51288 11.73254 8.54 1:07:49
400000 11.04880 11.61340 8.44 1:17:41
450016 10.74074 11.15114 8.51 1:27:20
500000 10.22759 11.11760 8.52 1:36:59
550016 10.23485 10.82111 8.45 1:46:38
600000 9.62602 10.52901 8.48 1:56:30
650016 9.54247 10.39591 8.40 2:06:08
700000 9.27610 10.17579 8.41 2:15:47
750016 9.39848 10.03072 8.46 2:25:25
800000 8.95939 9.87603 8.49 2:35:15
850016 9.08446 9.74571 8.47 2:44:54
900000 8.76172 9.79162 8.43 2:54:32
950016 9.12931 9.58630 8.47 3:04:10
981120 9.33700 9.72177 8.44 3:10:13

100.00% [6132/6132 38:02<00:00 #196224/196224 loss: 9.337 / 9.722]

# base.en! with learned positional embeddings, out_blocks after positional, mlp before vq
# cleaned dataset
vqmodel = RQBottleneckTransformer(codebook_dim=64, vq_codes=512, q_depth=1, n_head=8, depth=1,
                                  downsample=1, threshold_ema_dead_code=0, use_cosine_sim=True, whisper_model_name="base.en").cuda()
train("svq", vqmodel, train_ds, val_ds, bs=14, epochs=1, lr=3e-3, warmup_steps=2000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
OneCycle: 6132 1
'Entropy: 8.55'
100.00% [1/1 38:00<00:00]
samples train val codebook entropy time
50016 24.54137 31.36435 8.57 09:47
100000 15.90889 17.09020 8.58 19:26
150016 13.30405 13.95759 8.51 29:05
196224 14.19891 12.88708 8.55 38:00

100.00% [6132/6132 38:00<00:00 #196224/196224 loss: 14.199 / 12.887]

# base.en! with learned positional embeddings, out_blocks after positional, mlp before vq
# cleaned dataset
vqmodel = RQBottleneckTransformer(codebook_dim=32, vq_codes=4096, q_depth=1, n_head=8, depth=1,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True, whisper_model_name="base.en").cuda()
train("svq", vqmodel, train_ds, val_ds, bs=14, epochs=1, lr=3e-3, warmup_steps=2000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
OneCycle: 6132 1
'Entropy: 11.28'
100.00% [1/1 37:54<00:00]
samples train val codebook entropy time
50016 17.26417 22.29299 11.24 09:45
100000 12.41381 14.22859 11.25 19:22
150016 11.16801 11.97096 11.21 29:00
196224 10.49819 10.57301 11.28 37:54

100.00% [6132/6132 37:54<00:00 #196224/196224 loss: 10.498 / 10.573]

# base.en! with learned positional embeddings, out_blocks after positional, mlp before vq
# cleaned dataset
vqmodel = RQBottleneckTransformer(codebook_dim=32, vq_codes=4096, q_depth=1, n_head=8, depth=1,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True, whisper_model_name="base.en").cuda()
train("svq", vqmodel, train_ds, val_ds, bs=14, epochs=5, lr=3e-3, warmup_steps=2000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
vqmodel.save_model('vq-base.en-2d-4096c-cosine32-padfix-premlp-preconv-learnpos-5e-cleaned.model')
OneCycle: 6132 5
'Entropy: 10.75'
100.00% [5/5 3:11:21<00:00]
samples train val codebook entropy time
50016 18.85334 22.89696 10.80 09:51
100000 13.86454 16.37101 10.73 19:33
150016 12.85605 13.55042 10.70 29:15
200000 11.59676 12.87997 10.70 39:09
250016 11.12804 12.39809 10.76 48:52
300000 11.10460 11.67927 10.78 58:33
350016 11.11719 11.55583 10.77 1:08:16
400000 10.57183 11.07552 10.69 1:18:09
450016 10.49243 10.82820 10.79 1:27:51
500000 10.20853 10.77793 10.81 1:37:33
550016 10.11812 10.54805 10.73 1:47:15
600000 9.56493 10.22062 10.77 1:57:10
650016 9.40594 10.19217 10.68 2:06:52
700000 9.17259 9.85726 10.74 2:16:34
750016 9.18224 9.74915 10.68 2:26:17
800000 8.92105 9.47104 10.70 2:36:09
850016 8.61280 9.39290 10.71 2:45:51
900000 8.43418 9.33166 10.72 2:55:33
950016 8.57911 9.33823 10.71 3:05:16
981120 8.63924 9.37749 10.75 3:11:21

100.00% [6132/6132 38:16<00:00 #196224/196224 loss: 8.639 / 9.377]

# base.en! with learned positional embeddings, out_blocks after positional, mlp before vq
# cleaned dataset, mean downsampling
vqmodel = RQBottleneckTransformer(codebook_dim=32, vq_codes=4096, q_depth=1, n_head=8, depth=1,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True, whisper_model_name="base.en").cuda()
train("svq", vqmodel, train_ds, val_ds, bs=14, epochs=5, lr=3e-3, warmup_steps=2000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
vqmodel.save_model('vq-base.en-2d-4096c-cosine32-padfix-premlp-premean-learnpos-5e-cleaned.model')
OneCycle: 6132 5
'Entropy: 10.87'
100.00% [5/5 3:09:50<00:00]
samples train val codebook entropy time
50016 17.48580 22.87051 10.93 09:49
100000 13.30088 14.67394 11.07 19:26
150016 12.26683 12.99752 10.98 29:04
200000 11.53840 12.33599 10.96 38:53
250016 10.86994 12.00824 11.01 48:30
300000 10.59976 11.63654 11.01 58:08
350016 10.76181 11.29659 10.93 1:07:45
400000 9.99428 10.90412 10.98 1:17:35
450016 9.78972 10.65274 10.92 1:27:13
500000 9.70262 10.54080 10.93 1:36:50
550016 9.86663 10.32896 10.96 1:46:28
600000 9.41082 10.16734 10.97 1:56:16
650016 9.54473 9.94173 10.96 2:05:53
700000 9.06406 9.71947 10.93 2:15:30
750016 9.10101 9.46919 10.93 2:25:08
800000 8.60536 9.40041 10.94 2:34:56
850016 8.50216 9.23997 10.89 2:44:34
900000 8.29970 9.23626 10.90 2:54:11
950016 8.52151 9.20892 10.93 3:03:48
981120 8.69804 9.14721 10.87 3:09:50

100.00% [6132/6132 37:58<00:00 #196224/196224 loss: 8.698 / 9.147]
/tmp/ipykernel_129075/774804256.py:43: UserWarning: Attempt to set non-positive xlim on a log-scaled axis will be ignored.
  loss_p.clear()
/tmp/ipykernel_129075/774804256.py:46: UserWarning: Attempt to set non-positive xlim on a log-scaled axis will be ignored.
  loss_p.set_xlim(10000, self.total_steps)
/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:645: UserWarning: Length of IterableDataset Dataset: 6132 batches x 32 samples, 1277.7 hours) was reported to be 6132 (when accessing len(dataloader)), but 6133 samples have been fetched. For multiprocessing data-loading, this could be caused by not properly configuring the IterableDataset replica at each worker. Please see https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.
  warnings.warn(warn_msg)
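
The run above uses mean downsampling with downsample=2, i.e. the Whisper base.en encoder's 1500 frames per 30-second window (50 Hz) get averaged in pairs down to 750 frames (25 Hz). A minimal sketch of that operation on a (batch, time, channels) tensor; the actual layer inside RQBottleneckTransformer may differ in details such as padding:

import torch

def mean_downsample(x, factor=2):
    # x: (batch, time, channels) encoder features; time must be divisible by factor
    b, t, c = x.shape
    return x.reshape(b, t // factor, factor, c).mean(dim=2)

# base.en encoder features: (1, 1500, 512) -> (1, 750, 512)
y = mean_downsample(torch.randn(1, 1500, 512))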

# base.en! with learned positional embeddings, out_blocks after positional, mlp before vq
# cleaned dataset, mean downsampling
vqmodel = RQBottleneckTransformer(codebook_dim=32, vq_codes=4096, q_depth=1, n_head=8, depth=1,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True, whisper_model_name="base.en").cuda()
vqmodel.ensure_whisper()
train("svq", vqmodel, train_ds, val_ds, bs=14, epochs=5, lr=3e-3, warmup_steps=2000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=16, visual_class=RQVisual)
'Entropy: 10.91'
0.00% [0/5 00:00<?]
samples  train loss  val loss  codebook entropy  time
50008 15.93577 18.26651 10.88 31:51
71736 14.07252 15.22314 10.91 57:51

35.23% [5124/14546 57:50<1:46:20 #14348/203648 loss: 14.073 / 15.223]

# base.en! with learned positional embeddings, out_blocks after positional, mlp before vq
# cleaned dataset, mean downsampling
vqmodel = RQBottleneckTransformer(codebook_dim=32, vq_codes=4096, q_depth=1, n_head=8, depth=1,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True, whisper_model_name="base.en").cuda()
vqmodel.ensure_whisper()
train("svq", vqmodel, train_ds, val_ds, bs=14, epochs=5, lr=3e-3, warmup_steps=2000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=8, visual_class=RQVisual)
#vqmodel.save_model('vq-base.en-2d-4096c-cosine32-padfix-premlp-premean-learnpos-5e-cleaned.model')
'Entropy: 10.75'
20.00% [1/5 30:53<2:03:32]
samples  train loss  val loss  codebook entropy  time
50008 17.99252 21.13446 10.86 07:13
100002 14.73851 15.26074 10.74 14:30
150010 12.67679 13.50757 10.61 22:25
200004 11.98636 12.63929 10.72 30:13
248374 12.14378 12.26164 10.75 37:45

22.25% [3236/14546 06:51<23:57 #49675/203648 loss: 12.144 / 12.262]

# base.en! with learned positional embeddings, out_blocks after positional, mlp before vq
# cleaned dataset, mean downsampling, eqvad
vqmodel = RQBottleneckTransformer(codebook_dim=32, vq_codes=4096, q_depth=1, n_head=8, depth=1,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True, whisper_model_name="base.en").cuda()
train("svq", vqmodel, train_ds, val_ds, bs=14, epochs=5, lr=3e-3, warmup_steps=2000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
vqmodel.save_model('vq-base.en-2d-4096c-cosine32-padfix-premlp-premean-learnpos-5e-cleaned-eqvad.model')
OneCycle: 9933 5
'Entropy: 9.83'
100.00% [5/5 5:07:42<00:00]
samples  train loss  val loss  codebook entropy  time
50016 18.06458 19.45549 10.27 09:48
100000 13.27705 13.06077 10.36 19:27
150016 11.91958 12.15395 10.17 29:05
200000 11.59404 11.67862 10.28 38:44
250016 11.44242 11.32514 10.16 48:22
300000 10.80200 11.16721 10.17 58:01
350016 10.78535 10.94168 10.32 1:07:53
400000 10.66275 10.93297 10.21 1:17:32
450016 11.32866 10.82697 10.23 1:27:11
500000 10.40007 10.87806 10.05 1:36:50
550016 10.74838 10.63030 10.02 1:46:30
600000 10.57567 10.58560 9.97 1:56:08
650016 10.26159 10.44148 10.19 2:06:01
700000 10.08803 10.51371 10.12 2:15:40
750016 10.02600 10.39278 9.97 2:25:19
800000 10.27624 10.39350 10.06 2:34:58
850016 10.19159 10.25763 9.81 2:44:37
900000 10.08171 10.23527 10.00 2:54:16
950016 9.88339 10.25396 9.92 3:03:55
1000000 9.62146 10.11803 10.06 3:13:46
1050016 9.46334 10.04561 9.84 3:23:25
1100000 9.51465 10.11484 9.79 3:33:04
1150016 9.50131 9.95828 9.79 3:42:43
1200000 9.53149 9.94314 9.89 3:52:22
1250016 9.33688 9.85693 9.80 4:02:01
1300000 9.26627 9.81014 9.75 4:11:53
1350016 9.37144 9.76661 9.77 4:21:32
1400000 9.06240 9.80434 9.76 4:31:11
1450016 9.10573 9.80284 9.77 4:40:50
1500000 9.01136 9.71748 9.74 4:50:29
1550016 9.15775 9.71512 9.85 5:00:08
1589280 9.26362 9.71802 9.83 5:07:42

100.00% [9933/9933 1:01:29<00:00 #317856/317856 loss: 9.264 / 9.718]
/tmp/ipykernel_133489/774804256.py:43: UserWarning: Attempt to set non-positive xlim on a log-scaled axis will be ignored.
  loss_p.clear()
/tmp/ipykernel_133489/774804256.py:46: UserWarning: Attempt to set non-positive xlim on a log-scaled axis will be ignored.
  loss_p.set_xlim(10000, self.total_steps)
/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:645: UserWarning: Length of IterableDataset (Dataset: 9933 batches x 32 samples, 1275.0 hours) was reported to be 9933 (when accessing len(dataloader)), but 9934 samples have been fetched. For multiprocessing data-loading, this could be caused by not properly configuring the IterableDataset replica at each worker. Please see https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.
  warnings.warn(warn_msg)
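
The "OneCycle: 9933 5" line above reports the steps per epoch and epoch count picked up by the trainer. A rough sketch of the equivalent schedule with torch.optim.lr_scheduler.OneCycleLR, assuming the train helper maps warmup_steps onto the warm-up fraction (the exact wiring inside train is not shown here):

import torch

model = torch.nn.Linear(8, 8)                 # stand-in for the VQ bottleneck model
opt = torch.optim.AdamW(model.parameters(), lr=3e-3)
steps_per_epoch, epochs, warmup_steps = 9933, 5, 2000
sched = torch.optim.lr_scheduler.OneCycleLR(
    opt, max_lr=3e-3, epochs=epochs, steps_per_epoch=steps_per_epoch,
    pct_start=warmup_steps / (steps_per_epoch * epochs))
# sched.step() is then called once per optimizer step, 9933 * 5 times in total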

# base.en! with learned positional embeddings, out_blocks after positional, mlp before vq
# cleaned dataset
# downsample conv
vqmodel = RQBottleneckTransformer(codebook_dim=32, vq_codes=4096, q_depth=1, n_head=8, depth=1,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True, whisper_model_name="base.en").cuda()
train("svq", vqmodel, train_ds, val_ds, bs=14, epochs=1, lr=3e-3, warmup_steps=2000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
#vqmodel.save_model('vq-base.en-512c-cosine32-padfix-premlp-learnpos-5e-cleaned.model')
OneCycle: 6132 1
'Entropy: 10.70'
100.00% [1/1 38:13<00:00]
samples  train loss  val loss  codebook entropy  time
50016 18.56527 21.86226 10.70 09:50
100000 14.16297 14.83381 10.66 19:32
150016 11.57994 12.28649 10.68 29:14
196224 10.27239 10.96855 10.70 38:13

100.00% [6132/6132 38:13<00:00 #196224/196224 loss: 10.272 / 10.969]
/tmp/ipykernel_100642/1747892456.py:43: UserWarning: Attempt to set non-positive xlim on a log-scaled axis will be ignored.
  loss_p.clear()
/tmp/ipykernel_100642/1747892456.py:46: UserWarning: Attempt to set non-positive xlim on a log-scaled axis will be ignored.
  loss_p.set_xlim(10000, self.total_steps)
/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:645: UserWarning: Length of IterableDataset (Dataset: 6132 batches x 32 samples, 1277.7 hours) was reported to be 6132 (when accessing len(dataloader)), but 6133 samples have been fetched. For multiprocessing data-loading, this could be caused by not properly configuring the IterableDataset replica at each worker. Please see https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.
  warnings.warn(warn_msg)
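
The note on this cell mentions a convolutional downsampling variant, although the constructor call shown is identical to the mean-downsampling one, so the switch presumably happened inside the model code. Purely as an illustration (not the actual RQBottleneckTransformer layer), a strided Conv1d that halves the frame rate could look like this:

import torch

downsample_conv = torch.nn.Conv1d(512, 512, kernel_size=2, stride=2)
x = torch.randn(1, 1500, 512)                            # (batch, time, channels)
y = downsample_conv(x.transpose(1, 2)).transpose(1, 2)   # -> (1, 750, 512)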

# base.en! with learned positional embeddings, out_blocks after positional, mlp before vq
# cleaned dataset
vqmodel = RQBottleneckTransformer(codebook_dim=64, vq_codes=4096, q_depth=1, n_head=8, depth=1,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True, whisper_model_name="base.en").cuda()
train("svq", vqmodel, train_ds, val_ds, bs=14, epochs=1, lr=3e-3, warmup_steps=2000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
#vqmodel.save_model('vq-base.en-512c-cosine32-padfix-premlp-learnpos-5e-cleaned.model')
OneCycle: 6132 1
'Entropy: 10.14'
0.00% [0/1 00:00<?]
samples  train loss  val loss  codebook entropy  time
50016 19.88679 26.18120 10.21 09:49
100000 14.04911 15.88962 10.19 19:26
107520 13.98125 15.41472 10.14 20:55

54.79% [3360/6132 20:54<17:14 #107520/196224 loss: 13.981 / 15.415]
/tmp/ipykernel_94907/1747892456.py:43: UserWarning: Attempt to set non-positive xlim on a log-scaled axis will be ignored.
  loss_p.clear()
/tmp/ipykernel_94907/1747892456.py:46: UserWarning: Attempt to set non-positive xlim on a log-scaled axis will be ignored.
  loss_p.set_xlim(10000, self.total_steps)

# base.en! with learned positional embeddings, out_blocks after positional, mlp before vq
# cleaned dataset
vqmodel = RQBottleneckTransformer(codebook_dim=32, vq_codes=4096, q_depth=1, n_head=8, depth=2,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True, whisper_model_name="base.en").cuda()
train("svq", vqmodel, train_ds, val_ds, bs=14, epochs=1, lr=3e-3, warmup_steps=2000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
#vqmodel.save_model('vq-base.en-512c-cosine32-padfix-premlp-learnpos-5e-cleaned.model')
OneCycle: 6132 1
'Entropy: 11.10'
100.00% [1/1 40:03<00:00]
samples  train loss  val loss  codebook entropy  time
50016 18.68695 25.23358 11.06 10:18
100000 13.17344 14.20349 11.11 20:28
150016 10.66736 11.51643 11.02 30:39
196224 9.68099 10.36363 11.10 40:03

100.00% [6132/6132 40:03<00:00 #196224/196224 loss: 9.681 / 10.364]
/tmp/ipykernel_94907/1747892456.py:43: UserWarning: Attempt to set non-positive xlim on a log-scaled axis will be ignored.
  loss_p.clear()
/tmp/ipykernel_94907/1747892456.py:46: UserWarning: Attempt to set non-positive xlim on a log-scaled axis will be ignored.
  loss_p.set_xlim(10000, self.total_steps)

# base.en! with learned positional embeddings, out_blocks after positional, mlp before vq
# cleaned dataset
vqmodel = RQBottleneckTransformer(codebook_dim=32, vq_codes=64, q_depth=2, n_head=8, depth=1,
                                  downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True, whisper_model_name="base.en").cuda()
train("svq", vqmodel, train_ds, val_ds, bs=14, epochs=1, lr=3e-3, warmup_steps=2000,
      run_valid_every_iters=10000, table_row_every_iters=50000, dl_workers=4, visual_class=RQVisual)
#vqmodel.save_model('vq-base.en-512c-cosine32-padfix-premlp-learnpos-5e-cleaned.model')
OneCycle: 6132 1
'Entropy: 5.65'
100.00% [1/1 37:35<00:00]
samples  train loss  val loss  codebook entropy  time
50016 82.99027 173.42301 5.91 09:42
100000 31.85972 36.78515 5.81 19:14
150016 23.16688 25.48340 5.76 28:46
196224 20.68511 23.00216 5.65 37:36

100.00% [6132/6132 37:35<00:00 #196224/196224 loss: 20.685 / 23.002]
/tmp/ipykernel_94907/1747892456.py:43: UserWarning: Attempt to set non-positive xlim on a log-scaled axis will be ignored.
  loss_p.clear()
/tmp/ipykernel_94907/1747892456.py:46: UserWarning: Attempt to set non-positive xlim on a log-scaled axis will be ignored.
  loss_p.set_xlim(10000, self.total_steps)
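
This last run stacks two quantizers (q_depth=2) with only 64 codes each, so the per-codebook entropy is capped at log2(64) = 6 bits, consistent with the ~5.7 values in the table. Assuming q_depth maps onto the number of residual quantizer stages in the vector-quantize-pytorch package these models appear to build on, a minimal sketch of that configuration (not the model's own construction code):

import torch
from vector_quantize_pytorch import ResidualVQ

rvq = ResidualVQ(dim=512, num_quantizers=2, codebook_size=64, codebook_dim=32,
                 use_cosine_sim=True, threshold_ema_dead_code=0)
x = torch.randn(1, 750, 512)                      # downsampled encoder features
quantized, indices, commit_loss = rvq(x)          # indices: (1, 750, 2), one code per stage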