tptt package
Submodules
tptt.configuration_tptt module
Author: Fabien FURFARO
- class tptt.configuration_tptt.TpttConfig(base_model_config: dict | PretrainedConfig | None = None, base_model_name: str = 'google/gemma-3-270m', base_model_subfolder: str | None = None, name_or_path: str | None = None, model_task: str = 'causal_lm', target_modules_names: List[str] | None = None, operator_mode: str | None = None, order: int = 1, alpha_gate: str = '1', beta_gate: str = 'k', linear: bool = True, trick: str = 'derivative', use_linear_checkpoint: bool | None = None, max_self_attn_length: int | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: str | dtype = 'float32', lora_config: dict | None = None, padding_side: str | None = None, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None, **kwargs)
Bases:
PretrainedConfig
Configuration class for the TPTT model. This class merges the backbone config (e.g., Llama) with custom TPTT parameters.
- RECURRENT_MODES = {'delta_product': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': True, 'order': 2, 'trick': 'dt'}, 'delta_product_c': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': True, 'order': 2, 'trick': 'rdt'}, 'delta_product_r': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': True, 'order': 2, 'trick': 'rot'}, 'delta_rule': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': True, 'order': 1, 'trick': 'dt'}, 'delta_rule_gelu': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': False, 'order': 1, 'trick': 'dt'}, 'delta_rule_kv': {'alpha_gate': 'c', 'beta_gate': 'kv', 'linear': True, 'order': 1, 'trick': 'dt'}, 'delta_rule_v': {'alpha_gate': 'c', 'beta_gate': 'v', 'linear': True, 'order': 1, 'trick': 'dt'}, 'gated_delta_product': {'alpha_gate': 'k', 'beta_gate': 'k', 'linear': True, 'order': 2, 'trick': 'dt'}, 'gated_delta_rule': {'alpha_gate': 'k', 'beta_gate': 'k', 'linear': True, 'order': 1, 'trick': 'dt'}}
- architectures = ['TpttModel']
- auto_map = {'AutoConfig': 'configuration_tptt.TpttConfig', 'AutoModelForCausalLM': 'modeling_tptt.TpttModel'}
- model_type: str = 'tptt'
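A minimal configuration sketch (assuming the tptt package is importable and the default backbone config can be resolved; the keyword values below are illustrative, taken from the signature above):

    from tptt.configuration_tptt import TpttConfig

    # Illustrative values only; every keyword appears in the TpttConfig signature above.
    config = TpttConfig(
        base_model_name="google/gemma-3-270m",  # backbone checkpoint (package default)
        operator_mode="delta_rule",             # one of the RECURRENT_MODES keys
        mag_weight=0.5,                         # linear / vanilla attention mixing weight
        max_chunk_size=64,
    )
    print(config.model_type)  # "tptt"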
- tptt.configuration_tptt.convert_sets_to_lists(obj)
Convert sets to lists for the serialized LoRA config.
- tptt.configuration_tptt.generate_model_card(output_path: str, config: dict | object, template: str | None, extra_variables: Dict | None = None)
Generate a README.md file from a Jinja2 template and a configuration.
- template can be either:
a full path to a template file
a short name (e.g., “model_card”) that will be looked up inside default_templates_dir
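A hedged sketch of card generation (assuming output_path is the directory that receives README.md and that the bundled “model_card” template accepts a plain dict of variables):

    from tptt.configuration_tptt import generate_model_card

    config = {"base_model_name": "google/gemma-3-270m", "operator_mode": "delta_rule"}
    generate_model_card(
        "./my_model",               # directory where README.md is written (assumed)
        config,
        template="model_card",      # short name resolved inside default_templates_dir
        extra_variables={"author": "Fabien FURFARO"},
    )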
- tptt.configuration_tptt.get_model_name(lora: bool = True, cross_gate: bool = False, bidirectional: bool = False, order: int = 1, alpha_gate: str = 'c', beta_gate: str = 'k', linear: bool = True, trick: str = 'dt', prefix: str = 'liza', add_date: bool = True) str
Generate a compact, explicit model folder name with parameters and optional date. Example output: liza_lora_a-c_b-k_o-1_lin_trick-d_2025-09-10
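For illustration, a call that mirrors the docstring example above (the exact name tokens follow that example; the argument values here are illustrative):

    from tptt.configuration_tptt import get_model_name

    name = get_model_name(lora=True, alpha_gate="c", beta_gate="k",
                          order=1, linear=True, trick="dt", add_date=True)
    print(name)  # e.g. liza_lora_a-c_b-k_o-1_lin_trick-d_2025-09-10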
- tptt.configuration_tptt.render_template(template_path: str, variables: dict) str
Load and render a Jinja2 template from any file path.
- tptt.configuration_tptt.write_model_card(output_path: str, content: str)
Write the generated content into README.md.
tptt.modeling_tptt module
This module implements the TPTT model with linear attention (LiZA) and LoRA support. Author: Fabien FURFARO. TPTT: Transforming Pretrained Transformers into Titans (https://arxiv.org/abs/2506.17671).
- class tptt.modeling_tptt.CausalAvgPool1d(output_size: int, offsets: tuple[int] = (0, 1, 2), mode: str = 'replicate')
Bases:
Module
Causal sliding-window average (uniform; preserves sequence length).
- forward(x: Tensor) Tensor
x: [B, S, F] → [B, S, output_size]
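A minimal shape sketch for the pooling described above (the dimensions are illustrative and assume the defaults offsets=(0, 1, 2), mode="replicate" suffice):

    import torch
    from tptt.modeling_tptt import CausalAvgPool1d

    pool = CausalAvgPool1d(output_size=16)   # maps the feature dim F down to 16
    x = torch.randn(2, 128, 64)              # [B, S, F]
    y = pool(x)                              # [B, S, 16]; sequence length preserved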
- class tptt.modeling_tptt.LCache
Bases:
object
Cache for storing intermediate states of linear attention layers.
- reset()
Clear all cached states and reset the token counter
- update(layer_idx: int, **kwargs)
Detach all tensors to avoid retaining computation graphs
- class tptt.modeling_tptt.LiZAttention(base_attn: Module, layer_idx: int, base_config: PretrainedConfig, linear_cache: LCache | None = None, operator_mode: str = 'delta_rule', use_linear_checkpoint: bool = False, recurrent_config: Dict[str, Any] | None = None, max_self_attn_length: int | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: str | dtype = 'float32', padding_side: str = 'right', disable_linear_attn: bool = False, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None)
Bases:
Module
LiZA Linear Attention module, mixing linear and vanilla attention.
- forward(hidden_states: Tensor, attention_mask: Tensor | None = None, **kwargs) Tensor
Forward pass mixing linear and self-attention.
- property is_sliding
Check if the base attention contains sliding-window attention.
- class tptt.modeling_tptt.LinearAttention(hidden_dim: int, num_heads: int, head_dim: int | None = None, num_key_value_heads: int | None = None, num_key_value_groups: int | None = None, bias: bool = True, dropout: float | None = None, linear_precision: dtype = torch.float32, padding_side: str = 'right', shared_attn: bool = False, layer_idx: int = 0, operator_mode: str | None = 'linear', use_linear_checkpoint: bool = False, recurrent_config: Dict[str, Any] | None = None, linear_cache: LCache | None = None, max_chunk_size: int = 64, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None)
Bases:
Module
Linear multi-head attention layer: [B, S, D] -> [B, S, D]. Projections + gating + efficient linear attention mechanism (TPTT-compatible).
- compute_extended_householder(q, k, v, alpha, beta, seq_len)
Expand the Householder state (n_h > 1) with the correct sequence length after expansion.
- compute_gate(k: Tensor, v: Tensor) Tuple[Tensor]
Compute the gating tensor according to the beta_gate.
- forward(x: List[Tensor] | Tensor, attn_mask: Tensor | None = None, out_proj: Module | None = None, **kwargs: Any) Tensor
Forward pass for linear attention. Input shape: [B, S, D], output [B, S, D].
- get_cache(use_cache: bool) Tuple[Tensor | None, Tuple[Tensor, Tensor, Tensor, Tensor] | None]
Retrieve the recurrent state and qkvg buffers from the cache (causal mode only).
- merge_head_output(out, out_proj, dtype, device)
Merge heads and apply RMSNorm. Input shape: [B, H, S, d], output [B, S, D].
- prepare_attention_input(q: Tensor, k: Tensor, v: Tensor) Tensor
Prepare the q, k, v inputs for linear attention. Input shape: [B, S, D], output [B, S, D].
- save_cache(q: Tensor, k: Tensor, v: Tensor, alpha: Tensor, beta: Tensor, state: Tensor, n_orders: int) None
Save the recurrent state and qkv buffers to the cache (causal mode only).
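A standalone usage sketch of the layer's input/output contract (default operator_mode="linear"; the dimensions are illustrative and assume the defaults are sufficient outside a full TPTT model):

    import torch
    from tptt.modeling_tptt import LinearAttention

    attn = LinearAttention(hidden_dim=256, num_heads=8)
    x = torch.randn(2, 32, 256)              # [B, S, D]
    out = attn(x)                            # [B, S, D]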
- class tptt.modeling_tptt.LinearAttentionOp(use_linear_checkpoint: bool = False, max_chunk_size: int = 64, linear_precision: dtype = torch.float32)
Bases:
Module
Base class for linear attention operators.
- static chunk_delta_product_forward(query: Tensor, key: Tensor, value: Tensor, alpha: Tensor, beta: Tensor, chunk_size: int, linear_activation: bool = True, initial_state: Tensor | None = None, use_checkpoint: bool = True, linear_precision: dtype = torch.float32) Tuple[Tensor, Tensor]
Chunkwise parallel implementation (https://arxiv.org/abs/2406.06484). For each chunk, processes chunk_size * n_orders steps (virtual tokens) in order.
- forward(q: Tensor, k: Tensor, v: Tensor, alpha: Tensor, beta: Tensor, linear_activation: bool, recurrent_state: Tensor | None = None, **kwargs) Tensor
Forward pass for the attention operator.
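For reference only, a sketch of the per-token (ungated) delta-rule recurrence that the chunked operator parallelizes, written in the notation of the cited papers rather than this package's API:

    import torch

    def delta_rule_step(S, q, k, v, beta):
        """One step of S_t = (I - beta_t k_t k_t^T) S_{t-1} + beta_t k_t v_t^T,
        with read-out o_t = S_t^T q_t.  Shapes: S [d_k, d_v], q/k [d_k], v [d_v]."""
        S = S - beta * torch.outer(k, k @ S) + beta * torch.outer(k, v)
        return S, S.T @ q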
- class tptt.modeling_tptt.TpttModel(config: TpttConfig, **kwargs)
Bases:
PreTrainedModel
TPTT model wrapper with linear attention (LiZA) and LoRA support. Handles only architecture and weights.
- config_class
alias of
TpttConfig
- forward(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, labels: LongTensor | None = None, **kwargs)
Forward pass. All arguments are passed to the underlying base model.
- classmethod from_pretrained(pretrained_model_name_or_path=None, *model_args, **kwargs)
Custom from_pretrained that accepts the standard positional argument pretrained_model_name_or_path.
- generate(*args, **kwargs)
Delegate the generate call to the backbone model, which supports generation
- retie_lm_after_load(**kwargs)
Re-link lm_head after loading external weights.
- save_pretrained(path: str, **kwargs)
Save model weights, config, and source code to the given path.
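An end-to-end sketch (assumes network access to download the default backbone; the prompt and generation settings are illustrative):

    from transformers import AutoTokenizer
    from tptt.configuration_tptt import TpttConfig
    from tptt.modeling_tptt import TpttModel

    config = TpttConfig(base_model_name="google/gemma-3-270m", operator_mode="delta_rule")
    model = TpttModel(config)
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-270m")
    inputs = tokenizer("Hello, Titans!", return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=8)  # delegated to the backbone
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))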
- class tptt.modeling_tptt.VirtualTokenExpander(num_heads: int, head_dim: int, n: int, mode: Literal['dt', 'rot', 'rdt'] = 'rdt', flip: bool = True)
Bases:
Module
Expands input tokens into n virtual tokens using derivative- and rotation-based methods.
- forward(x: Tensor, force: str | None = None) Tensor
Forward pass to expand input tokens. [B, H, S, D] -> [B, H, S, n, D]
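A shape-only sketch of the expansion (illustrative dimensions):

    import torch
    from tptt.modeling_tptt import VirtualTokenExpander

    expander = VirtualTokenExpander(num_heads=8, head_dim=32, n=2, mode="dt")
    x = torch.randn(2, 8, 16, 32)            # [B, H, S, D]
    y = expander(x)                          # [B, H, S, n, D] per the docstring above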
- tptt.modeling_tptt.apply_linear_attention_mask(attention_mask: Tensor, v: Tensor, padding_side: str = 'right') Tensor
If padding is present, extract the [B, S] padding mask and apply it to v.
- tptt.modeling_tptt.construct_causal_forward_solver(tri_tensor: Tensor, dtype: dtype = torch.float32) Tensor
Forward substitution for fast inversion during chunk propagation.
- tptt.modeling_tptt.describe(x: Tensor, name='tensor') None
Prints the shape, min, max, mean, and std of a tensor.
- tptt.modeling_tptt.ensure_stability(tensor: Tensor, min_val: float = -10000.0, max_val: float = 10000.0) Tensor
Clamp tensor values to [min_val, max_val] to enforce numerical stability.
- tptt.modeling_tptt.extract_layer_idx(module_name: str) int
Extract the layer index from a module name string.
- tptt.modeling_tptt.find_embedding_lm(module: Module) Module | None
Find the embedding weight in a model module.
- tptt.modeling_tptt.get_tptt_model(model: ~torch.nn.modules.module.Module, base_config: ~transformers.configuration_utils.PretrainedConfig, linear_cache: ~tptt.modeling_tptt.LCache | None = None, liza_attention: ~torch.nn.modules.module.Module = <class 'tptt.modeling_tptt.LiZAttention'>, target_modules_names: list[str] | None = None, operator_mode: str = 'delta_rule', use_linear_checkpoint: bool = False, recurrent_config: ~typing.Dict[str, ~typing.Any] | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: ~torch.dtype = torch.float32, max_self_attn_length: int | None = None, padding_side: str = 'right', bidirectional: bool = False, pooling_config: ~typing.Dict[str, ~typing.Any] | None = None, **kwargs) Tuple[PreTrainedModel, LCache]
Replace target modules in a model with LiZAttention.
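A typical conversion sketch (assumes network access for the backbone checkpoint; the keyword values are illustrative):

    from transformers import AutoModelForCausalLM
    from tptt.modeling_tptt import get_tptt_model

    base = AutoModelForCausalLM.from_pretrained("google/gemma-3-270m")
    model, linear_cache = get_tptt_model(
        base,
        base_config=base.config,
        operator_mode="delta_rule",
        mag_weight=0.5,
    )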
- tptt.modeling_tptt.get_valid_chunk_size(total_l: int, chunk_size: int) int
Return the largest value <= chunk_size that evenly divides total_l.
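For example, per the description above, a sequence length of 100 with a requested chunk size of 64 should fall back to the largest divisor not exceeding 64:

    from tptt.modeling_tptt import get_valid_chunk_size

    get_valid_chunk_size(100, 64)   # expected 50: the largest divisor of 100 that is <= 64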
- tptt.modeling_tptt.load_tptt_safetensors(repo_or_path: str, model: PreTrainedModel | PeftModel, subfolder: str | None = None, token: str | None = None) PreTrainedModel | PeftModel
Load TPTT safetensors (LoRA/PEFT weights) and adapt key names if needed.
- tptt.modeling_tptt.match_dim(x: Tensor, dim: int, target_size: int) Tensor
Match the size of tensor x along dimension dim to target_size by interpolation
- tptt.modeling_tptt.save_tptt_safetensors(model, path: str, name: str = 'adapter_model.safetensors')
Save trainable LoRA/TPTT-specific weights, adapting key names as needed.
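A round-trip sketch for the two safetensors helpers above (the paths are placeholders; model is assumed to be a TPTT-wrapped PreTrainedModel or PeftModel):

    from tptt.modeling_tptt import load_tptt_safetensors, save_tptt_safetensors

    save_tptt_safetensors(model, "./liza_adapter")          # writes adapter_model.safetensors
    model = load_tptt_safetensors("./liza_adapter", model)  # reload, adapting key names if needed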
- tptt.modeling_tptt.sequential_delta_product_scan(q_chunks: Tensor, w: Tensor, u: Tensor, alpha_chunks: Tensor, linear_activation: bool, initial_recurrent_state: Tensor, linear_precision: dtype, use_checkpoint: bool) Tuple[Tensor, Tensor]
DeltaProduct implementation (https://arxiv.org/abs/2502.10297). Implements the per-token Householder state updates.
- tptt.modeling_tptt.set_trainable_parameters(model: PreTrainedModel, trainable_patterns: List[str] = None) PreTrainedModel
Freeze model parameters except trainable_patterns.
- tptt.modeling_tptt.soft_clamp(x: Tensor, min_val: float = 1e-06, max_val: float = 0.999999) Tensor
Differentiable clamping for stability
- tptt.modeling_tptt.split_qkv(base_attn: Module, qkv: Tensor) tuple[Tensor, Tensor, Tensor]
Split the QKV tensor into separate Q, K, and V tensors.
- tptt.modeling_tptt.truncate_attention_mask(hidden_states: Tensor, attention_mask: Tensor, max_length: int) tuple[Tensor, Tensor]
Truncate hidden_states and attention_mask to the last window of size max_length
- tptt.modeling_tptt.unlinear_activation(x: Tensor, scale: float = 2.0) Tensor
Non-linear activation applied between chunks.
tptt.train_tptt module
Author: Fabien FURFARO
- class tptt.train_tptt.LiZACallback(model: PreTrainedModel, mode: str = 'gradual', initial_weight: float = 0.0, final_weight: float = 0.5, transition_step: int | tuple | list = 100, weight_list: list | None = None, switch_period: int = 1)
Bases:
TrainerCallback
TrainerCallback to schedule mag_weight or enable/disable linear attention during training.
- Modes:
“gradual”: linear interpolation from initial_weight to final_weight.
“cyclic”: alternate between values in weight_list at each step.
“switch”: alternately enable/disable linear attention at each step.
- on_log(args, state, control, logs=None, **kwargs)
Event called after logging the last logs.
- on_step_end(args, state, control, **kwargs)
Event called at the end of a training step. If using gradient accumulation, one training step might take several inputs.
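A training-loop sketch (model and train_dataset are placeholders; the Trainer wiring assumes a standard transformers setup):

    from transformers import Trainer, TrainingArguments
    from tptt.train_tptt import LiZACallback

    callback = LiZACallback(model, mode="gradual", initial_weight=0.0,
                            final_weight=0.5, transition_step=100)
    trainer = Trainer(
        model=model,                                   # placeholder: a TPTT-wrapped model
        args=TrainingArguments(output_dir="./out"),
        train_dataset=train_dataset,                   # placeholder dataset
        callbacks=[callback],
    )
    trainer.train()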
- class tptt.train_tptt.SaveBestModelCallback
Bases:
TrainerCallback
TrainerCallback to save the best model based on evaluation loss.
- on_evaluate(args, state, control, metrics=None, **kwargs)
Event called after an evaluation phase.
- tptt.train_tptt.ensure_int(value: int | tuple | list) int
Ensure the value is a plain integer.
Module contents
This module implements the TPTT model with linear attention (LiZA) and LoRA support.
- class tptt.CausalAvgPool1d(output_size: int, offsets: tuple[int] = (0, 1, 2), mode: str = 'replicate')
Bases:
Module
Causal sliding-window average (uniform; preserves sequence length).
- forward(x: Tensor) Tensor
x: [B, S, F] → [B, S, output_size]
- class tptt.LCache
Bases:
object
Cache for storing intermediate states of linear attention layers.
- reset()
Clear all cached states and reset the token counter
- update(layer_idx: int, **kwargs)
Detach all tensors to avoid retaining computation graphs
- class tptt.LiZACallback(model: PreTrainedModel, mode: str = 'gradual', initial_weight: float = 0.0, final_weight: float = 0.5, transition_step: int | tuple | list = 100, weight_list: list | None = None, switch_period: int = 1)
Bases:
TrainerCallback
TrainerCallback to schedule mag_weight or enable/disable linear attention during training.
- Modes:
“gradual”: linear interpolation from initial_weight to final_weight.
“cyclic”: alternate between values in weight_list at each step.
“switch”: alternately enable/disable linear attention at each step.
- on_log(args, state, control, logs=None, **kwargs)
Event called after logging the last logs.
- on_step_end(args, state, control, **kwargs)
Event called at the end of a training step. If using gradient accumulation, one training step might take several inputs.
- class tptt.LiZAttention(base_attn: Module, layer_idx: int, base_config: PretrainedConfig, linear_cache: LCache | None = None, operator_mode: str = 'delta_rule', use_linear_checkpoint: bool = False, recurrent_config: Dict[str, Any] | None = None, max_self_attn_length: int | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: str | dtype = 'float32', padding_side: str = 'right', disable_linear_attn: bool = False, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None)
Bases:
Module
LiZA Linear Attention module, mixing linear and vanilla attention.
- forward(hidden_states: Tensor, attention_mask: Tensor | None = None, **kwargs) Tensor
Forward pass mixing linear and self-attention.
- property is_sliding
Check if the base attention contains sliding-window attention.
- class tptt.LinearAttention(hidden_dim: int, num_heads: int, head_dim: int | None = None, num_key_value_heads: int | None = None, num_key_value_groups: int | None = None, bias: bool = True, dropout: float | None = None, linear_precision: dtype = torch.float32, padding_side: str = 'right', shared_attn: bool = False, layer_idx: int = 0, operator_mode: str | None = 'linear', use_linear_checkpoint: bool = False, recurrent_config: Dict[str, Any] | None = None, linear_cache: LCache | None = None, max_chunk_size: int = 64, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None)
Bases:
Module
Linear multi-head attention layer: [B, S, D] -> [B, S, D]. Projections + gating + efficient linear attention mechanism (TPTT-compatible).
- compute_extended_householder(q, k, v, alpha, beta, seq_len)
Expand the Householder state (n_h > 1) with the correct sequence length after expansion.
- compute_gate(k: Tensor, v: Tensor) Tuple[Tensor]
Compute the gating tensor according to the beta_gate.
- forward(x: List[Tensor] | Tensor, attn_mask: Tensor | None = None, out_proj: Module | None = None, **kwargs: Any) Tensor
Forward pass for linear attention. Input shape: [B, S, D], output [B, S, D].
- get_cache(use_cache: bool) Tuple[Tensor | None, Tuple[Tensor, Tensor, Tensor, Tensor] | None]
Retrieve the recurrent state and qkvg buffers from the cache (causal mode only).
- merge_head_output(out, out_proj, dtype, device)
Merge heads and apply RMSNorm. Input shape: [B, H, S, d], output [B, S, D].
- prepare_attention_input(q: Tensor, k: Tensor, v: Tensor) Tensor
Prepare the q, k, v inputs for linear attention. Input shape: [B, S, D], output [B, S, D].
- save_cache(q: Tensor, k: Tensor, v: Tensor, alpha: Tensor, beta: Tensor, state: Tensor, n_orders: int) None
Save the recurrent state and qkv buffers to the cache (causal mode only).
- class tptt.LinearAttentionOp(use_linear_checkpoint: bool = False, max_chunk_size: int = 64, linear_precision: dtype = torch.float32)
Bases:
Module
Base class for linear attention operators.
- static chunk_delta_product_forward(query: Tensor, key: Tensor, value: Tensor, alpha: Tensor, beta: Tensor, chunk_size: int, linear_activation: bool = True, initial_state: Tensor | None = None, use_checkpoint: bool = True, linear_precision: dtype = torch.float32) Tuple[Tensor, Tensor]
Chunkwise parallel implementation (https://arxiv.org/abs/2406.06484). For each chunk, processes chunk_size * n_orders steps (virtual tokens) in order.
- forward(q: Tensor, k: Tensor, v: Tensor, alpha: Tensor, beta: Tensor, linear_activation: bool, recurrent_state: Tensor | None = None, **kwargs) Tensor
Forward pass for the attention operator.
- class tptt.SaveBestModelCallback
Bases:
TrainerCallback
TrainerCallback to save the best model based on evaluation loss.
- on_evaluate(args, state, control, metrics=None, **kwargs)
Event called after an evaluation phase.
- class tptt.TpttConfig(base_model_config: dict | PretrainedConfig | None = None, base_model_name: str = 'google/gemma-3-270m', base_model_subfolder: str | None = None, name_or_path: str | None = None, model_task: str = 'causal_lm', target_modules_names: List[str] | None = None, operator_mode: str | None = None, order: int = 1, alpha_gate: str = '1', beta_gate: str = 'k', linear: bool = True, trick: str = 'derivative', use_linear_checkpoint: bool | None = None, max_self_attn_length: int | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: str | dtype = 'float32', lora_config: dict | None = None, padding_side: str | None = None, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None, **kwargs)
Bases:
PretrainedConfig
Configuration class for the TPTT model. This class merges the backbone config (e.g., Llama) with custom TPTT parameters.
- RECURRENT_MODES = {'delta_product': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': True, 'order': 2, 'trick': 'dt'}, 'delta_product_c': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': True, 'order': 2, 'trick': 'rdt'}, 'delta_product_r': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': True, 'order': 2, 'trick': 'rot'}, 'delta_rule': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': True, 'order': 1, 'trick': 'dt'}, 'delta_rule_gelu': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': False, 'order': 1, 'trick': 'dt'}, 'delta_rule_kv': {'alpha_gate': 'c', 'beta_gate': 'kv', 'linear': True, 'order': 1, 'trick': 'dt'}, 'delta_rule_v': {'alpha_gate': 'c', 'beta_gate': 'v', 'linear': True, 'order': 1, 'trick': 'dt'}, 'gated_delta_product': {'alpha_gate': 'k', 'beta_gate': 'k', 'linear': True, 'order': 2, 'trick': 'dt'}, 'gated_delta_rule': {'alpha_gate': 'k', 'beta_gate': 'k', 'linear': True, 'order': 1, 'trick': 'dt'}}
- architectures = ['TpttModel']
- auto_map = {'AutoConfig': 'configuration_tptt.TpttConfig', 'AutoModelForCausalLM': 'modeling_tptt.TpttModel'}
- model_type: str = 'tptt'
- class tptt.TpttModel(config: TpttConfig, **kwargs)
Bases:
PreTrainedModel
TPTT model wrapper with linear attention (LiZA) and LoRA support. Handles only architecture and weights.
- config_class
alias of
TpttConfig
- forward(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, labels: LongTensor | None = None, **kwargs)
Forward pass. All arguments are passed to the underlying base model.
- classmethod from_pretrained(pretrained_model_name_or_path=None, *model_args, **kwargs)
Custom from_pretrained that accepts the standard positional argument pretrained_model_name_or_path.
- generate(*args, **kwargs)
Delegate the generate call to the backbone model, which supports generation
- retie_lm_after_load(**kwargs)
Re-link lm_head after loading external weights.
- save_pretrained(path: str, **kwargs)
Save model weights, config, and source code to the given path.
- class tptt.VirtualTokenExpander(num_heads: int, head_dim: int, n: int, mode: Literal['dt', 'rot', 'rdt'] = 'rdt', flip: bool = True)
Bases:
Module
Expands input tokens into n virtual tokens using derivative- and rotation-based methods.
- forward(x: Tensor, force: str | None = None) Tensor
Forward pass to expand input tokens. [B, H, S, D] -> [B, H, S, n, D]
- tptt.generate_model_card(output_path: str, config: dict | object, template: str | None, extra_variables: Dict | None = None)
Generate a README.md file from a Jinja2 template and a configuration.
- template can be either:
a full path to a template file
a short name (e.g., “model_card”) that will be looked up inside default_templates_dir
- tptt.get_tptt_model(model: ~torch.nn.modules.module.Module, base_config: ~transformers.configuration_utils.PretrainedConfig, linear_cache: ~tptt.modeling_tptt.LCache | None = None, liza_attention: ~torch.nn.modules.module.Module = <class 'tptt.modeling_tptt.LiZAttention'>, target_modules_names: list[str] | None = None, operator_mode: str = 'delta_rule', use_linear_checkpoint: bool = False, recurrent_config: ~typing.Dict[str, ~typing.Any] | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: ~torch.dtype = torch.float32, max_self_attn_length: int | None = None, padding_side: str = 'right', bidirectional: bool = False, pooling_config: ~typing.Dict[str, ~typing.Any] | None = None, **kwargs) Tuple[PreTrainedModel, LCache]
Replace target modules in a model with LiZAttention.
- tptt.load_tptt_safetensors(repo_or_path: str, model: PreTrainedModel | PeftModel, subfolder: str | None = None, token: str | None = None) PreTrainedModel | PeftModel
Load TPTT safetensors (LoRA/PEFT weights) and adapt key names if needed.
- tptt.save_tptt_safetensors(model, path: str, name: str = 'adapter_model.safetensors')
Save trainable LoRA/TPTT-specific weights, adapting key names as needed.