tptt package

Submodules

tptt.configuration_tptt module

Author : Fabien FURFARO

class tptt.configuration_tptt.TpttConfig(base_model_config: dict | PretrainedConfig | None = None, base_model_name: str = 'google/gemma-3-270m', base_model_subfolder: str | None = None, name_or_path: str | None = None, model_task: str = 'causal_lm', target_modules_names: List[str] | None = None, operator_mode: str | None = None, order: int = 1, alpha_gate: str = '1', beta_gate: str = 'k', linear: bool = True, trick: str = 'derivative', use_linear_checkpoint: bool | None = None, max_self_attn_length: int | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: str | dtype = 'float32', lora_config: dict | None = None, padding_side: str | None = None, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None, **kwargs)

Bases: PretrainedConfig

Configuration class for the TPTT model. This class merges the backbone config (e.g., Llama) with custom TPTT parameters.

RECURRENT_MODES = {'delta_product': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': True, 'order': 2, 'trick': 'dt'}, 'delta_product_c': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': True, 'order': 2, 'trick': 'rdt'}, 'delta_product_r': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': True, 'order': 2, 'trick': 'rot'}, 'delta_rule': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': True, 'order': 1, 'trick': 'dt'}, 'delta_rule_gelu': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': False, 'order': 1, 'trick': 'dt'}, 'delta_rule_kv': {'alpha_gate': 'c', 'beta_gate': 'kv', 'linear': True, 'order': 1, 'trick': 'dt'}, 'delta_rule_v': {'alpha_gate': 'c', 'beta_gate': 'v', 'linear': True, 'order': 1, 'trick': 'dt'}, 'gated_delta_product': {'alpha_gate': 'k', 'beta_gate': 'k', 'linear': True, 'order': 2, 'trick': 'dt'}, 'gated_delta_rule': {'alpha_gate': 'k', 'beta_gate': 'k', 'linear': True, 'order': 1, 'trick': 'dt'}}
architectures = ['TpttModel']
auto_map = {'AutoConfig': 'configuration_tptt.TpttConfig', 'AutoModelForCausalLM': 'modeling_tptt.TpttModel'}
model_type: str = 'tptt'
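
A minimal construction sketch (hedged): the import path assumes the package is installed as tptt, and operator_mode presumably selects the matching preset from RECURRENT_MODES above; the values are illustrative, not recommendations.

    from tptt import TpttConfig

    config = TpttConfig(
        base_model_name="google/gemma-3-270m",  # default backbone from the signature
        operator_mode="delta_rule",             # presumably applies the RECURRENT_MODES preset
        mag_weight=0.5,                         # mixing weight between linear and vanilla attention
        max_chunk_size=64,
    )
    print(config.model_type)  # 'tptt'
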
tptt.configuration_tptt.convert_sets_to_lists(obj)

Convert sets to lists for the serialized LoRA config.

tptt.configuration_tptt.generate_model_card(output_path: str, config: dict | object, template: str | None, extra_variables: Dict | None = None)

Generate a README.md file from a Jinja2 template and a configuration.

  • template can be either:
    • a full path to a template file

    • a short name (e.g., “model_card”), which is looked up inside default_templates_dir (a usage sketch follows)
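
A hedged usage sketch; the output path is illustrative, and whether output_path is a directory or the README path itself is not specified here.

    from tptt.configuration_tptt import TpttConfig, generate_model_card

    config = TpttConfig()
    generate_model_card(
        output_path="./my_model",   # illustrative; see note above on output_path semantics
        config=config,
        template="model_card",      # short name, resolved inside default_templates_dir
        extra_variables={"author": "Fabien FURFARO"},
    )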

tptt.configuration_tptt.get_model_name(lora: bool = True, cross_gate: bool = False, bidirectional: bool = False, order: int = 1, alpha_gate: str = 'c', beta_gate: str = 'k', linear: bool = True, trick: str = 'dt', prefix: str = 'liza', add_date: bool = True) str

Generate a compact, explicit model folder name with parameters and optional date. Example output: liza_lora_a-c_b-k_o-1_lin_trick-d_2025-09-10
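
A call sketch matching the example output above (the date suffix depends on the current day, and the exact abbreviations are taken from the example string):

    from tptt.configuration_tptt import get_model_name

    name = get_model_name(lora=True, alpha_gate="c", beta_gate="k",
                          order=1, linear=True, trick="dt", add_date=True)
    # e.g. "liza_lora_a-c_b-k_o-1_lin_trick-d_<YYYY-MM-DD>"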

tptt.configuration_tptt.render_template(template_path: str, variables: dict) str

Load and render a Jinja2 template from any file path.

tptt.configuration_tptt.write_model_card(output_path: str, content: str)

Write the generated content into README.md.
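
render_template and write_model_card can be chained when a custom template is used; a minimal sketch (the template path and variables are illustrative):

    from tptt.configuration_tptt import render_template, write_model_card

    content = render_template(
        template_path="templates/model_card.md.j2",  # illustrative template file
        variables={"model_name": "liza_demo"},
    )
    write_model_card(output_path="./my_model", content=content)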

tptt.modeling_tptt module

This module implements the TPTT model with linear attention (LiZA) and LoRA support. Author: Fabien FURFARO. TPTT: Transforming Pretrained Transformers into Titans (https://arxiv.org/abs/2506.17671).

class tptt.modeling_tptt.CausalAvgPool1d(output_size: int, offsets: tuple[int] = (0, 1, 2), mode: str = 'replicate')

Bases: Module

Causal sliding-window average (uniform weights; the sequence length is preserved)

forward(x: Tensor) Tensor

x: [B, S, F] → [B, S, output_size]
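
A shape-only sketch (sizes are illustrative):

    import torch
    from tptt import CausalAvgPool1d

    pool = CausalAvgPool1d(output_size=32, offsets=(0, 1, 2))
    x = torch.randn(2, 16, 128)  # [B, S, F]
    y = pool(x)                  # expected [B, S, output_size] -> torch.Size([2, 16, 32])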

class tptt.modeling_tptt.LCache

Bases: object

Cache for storing intermediate states of linear attention layers.

reset()

Clear all cached states and reset the token counter

update(layer_idx: int, **kwargs)

Update the cached states for the given layer; all tensors are detached to avoid retaining computation graphs.
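
A minimal sketch of the cache lifecycle (the keyword names accepted by update are not documented here, so they are only hinted at in a comment):

    from tptt import LCache

    cache = LCache()
    # cache.update(layer_idx=0, ...)  # hypothetical kwargs; tensors are detached on update
    cache.reset()                     # clear all cached states and reset the token counter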

class tptt.modeling_tptt.LiZAttention(base_attn: Module, layer_idx: int, base_config: PretrainedConfig, linear_cache: LCache | None = None, operator_mode: str = 'delta_rule', use_linear_checkpoint: bool = False, recurrent_config: Dict[str, Any] | None = None, max_self_attn_length: int | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: str | dtype = 'float32', padding_side: str = 'right', disable_linear_attn: bool = False, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None)

Bases: Module

LiZA Linear Attention module, mixing linear and vanilla attention.

forward(hidden_states: Tensor, attention_mask: Tensor | None = None, **kwargs) Tensor

Forward pass mixing linear and self-attention.

property is_sliding

Check whether the base attention uses sliding-window attention.

class tptt.modeling_tptt.LinearAttention(hidden_dim: int, num_heads: int, head_dim: int | None = None, num_key_value_heads: int | None = None, num_key_value_groups: int | None = None, bias: bool = True, dropout: float | None = None, linear_precision: dtype = torch.float32, padding_side: str = 'right', shared_attn: bool = False, layer_idx: int = 0, operator_mode: str | None = 'linear', use_linear_checkpoint: bool = False, recurrent_config: Dict[str, Any] | None = None, linear_cache: LCache | None = None, max_chunk_size: int = 64, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None)

Bases: Module

Linear multi-head attention layer: [B, S, D] -> [B, S, D] Projections + gating + efficient linear attention mechanism (TPTT compatible).

compute_extended_householder(q, k, v, alpha, beta, seq_len)

Expand the Householder state (n_h > 1) with the correct sequence length after expansion.

compute_gate(k: Tensor, v: Tensor) Tuple[Tensor]

Compute the gating tensor according to the beta_gate.

forward(x: List[Tensor] | Tensor, attn_mask: Tensor | None = None, out_proj: Module | None = None, **kwargs: Any) Tensor

Forward pass for linear attention. Input shape: [B, S, D], output [B, S, D].

get_cache(use_cache: bool) Tuple[Tensor | None, Tuple[Tensor, Tensor, Tensor, Tensor] | None]

Retrieve the recurrent state and qkvg buffers from the cache (causal mode only).

merge_head_output(out, out_proj, dtype, device)

Merge heads and apply RMSNorm. Input shape: [B, H, S, d], output [B, S, D].

prepare_attention_input(q: Tensor, k: Tensor, v: Tensor) Tensor

Prepare the q, k, v inputs for linear attention. Input shape: [B, S, D], output [B, S, D].

save_cache(q: Tensor, k: Tensor, v: Tensor, alpha: Tensor, beta: Tensor, state: Tensor, n_orders: int) None

Save the recurrent state and qkv buffers to the cache (causal mode only).
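
A shape-only sketch, assuming the layer can be used standalone with its default projections (hyperparameters are illustrative):

    import torch
    from tptt import LinearAttention

    attn = LinearAttention(hidden_dim=512, num_heads=8, operator_mode="delta_rule")
    x = torch.randn(2, 16, 512)  # [B, S, D]
    y = attn(x)                  # expected output shape [B, S, D]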

class tptt.modeling_tptt.LinearAttentionOp(use_linear_checkpoint: bool = False, max_chunk_size: int = 64, linear_precision: dtype = torch.float32)

Bases: Module

Base class for linear attention operators.

static chunk_delta_product_forward(query: Tensor, key: Tensor, value: Tensor, alpha: Tensor, beta: Tensor, chunk_size: int, linear_activation: bool = True, initial_state: Tensor | None = None, use_checkpoint: bool = True, linear_precision: dtype = torch.float32) Tuple[Tensor, Tensor]

Chunkwise parallel implementation (https://arxiv.org/abs/2406.06484). For each chunk, processes chunk_size * n_orders steps (virtual tokens) in order.

forward(q: Tensor, k: Tensor, v: Tensor, alpha: Tensor, beta: Tensor, linear_activation: bool, recurrent_state: Tensor | None = None, **kwargs) Tensor

Forward pass for the attention operator.
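
For intuition, a naive sequential reference of the order-1 delta-rule recurrence that chunk_delta_product_forward parallelizes chunkwise; this is an illustrative sketch, not the library's kernel, and the alpha/beta gating convention is an assumption.

    import torch

    def delta_rule_reference(q, k, v, alpha, beta):
        """Naive per-token delta-rule scan.
        Shapes (assumed): q, k [B, H, S, Dk]; v [B, H, S, Dv]; alpha, beta [B, H, S, 1].
        Returns o [B, H, S, Dv]."""
        B, H, S, Dk = k.shape
        Dv = v.shape[-1]
        state = torch.zeros(B, H, Dk, Dv, dtype=q.dtype, device=q.device)
        outputs = []
        for t in range(S):
            kt, vt, qt = k[:, :, t], v[:, :, t], q[:, :, t]
            at, bt = alpha[:, :, t], beta[:, :, t]
            state = at.unsqueeze(-1) * state                          # assumed decay gate
            pred = torch.einsum("bhk,bhkv->bhv", kt, state)           # S_{t-1}^T k_t
            state = state - bt.unsqueeze(-1) * torch.einsum(
                "bhk,bhv->bhkv", kt, pred - vt)                       # delta-rule update
            outputs.append(torch.einsum("bhk,bhkv->bhv", qt, state))  # o_t = S_t^T q_t
        return torch.stack(outputs, dim=2)

The chunkwise kernel computes the same recurrence but processes chunk_size * n_orders virtual-token steps per chunk in parallel, as noted in chunk_delta_product_forward above.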

class tptt.modeling_tptt.TpttModel(config: TpttConfig, **kwargs)

Bases: PreTrainedModel

TPTT model wrapper with linear attention (LiZA) and LoRA support. Handles only architecture and weights.

config_class

alias of TpttConfig

forward(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, labels: LongTensor | None = None, **kwargs)

Forward pass. All arguments are passed to the underlying base model.

classmethod from_pretrained(pretrained_model_name_or_path=None, *model_args, **kwargs)

Custom from_pretrained that accepts the standard positional argument pretrained_model_name_or_path.

generate(*args, **kwargs)

Delegate the generate call to the backbone model, which supports generation

retie_lm_after_load(**kwargs)

Re-link lm_head after loading external weights.

save_pretrained(path: str, **kwargs)

Save model weights, config, and source code to the given path.
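
Given the auto_map entries above, a checkpoint produced by save_pretrained can presumably be reloaded through the transformers Auto classes with trust_remote_code; the repository id below is purely hypothetical.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo = "user/tptt-checkpoint"  # hypothetical repository id
    model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-270m")  # backbone tokenizer (assumption)

    inputs = tokenizer("Titans are", return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=20)  # delegated to the backbone model
    print(tokenizer.decode(out[0], skip_special_tokens=True))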

class tptt.modeling_tptt.VirtualTokenExpander(num_heads: int, head_dim: int, n: int, mode: Literal['dt', 'rot', 'rdt'] = 'rdt', flip: bool = True)

Bases: Module

Expands each input token into n virtual tokens using derivative-based and rotation-based tricks.

forward(x: Tensor, force: str | None = None) Tensor

Forward pass to expand input tokens. [B, H, S, D] -> [B, H, S, n, D]
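
A shape-only sketch (sizes are illustrative):

    import torch
    from tptt import VirtualTokenExpander

    expander = VirtualTokenExpander(num_heads=8, head_dim=64, n=2, mode="rdt")
    x = torch.randn(2, 8, 16, 64)  # [B, H, S, D]
    y = expander(x)                # expected [B, H, S, n, D] -> torch.Size([2, 8, 16, 2, 64])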

tptt.modeling_tptt.apply_linear_attention_mask(attention_mask: Tensor, v: Tensor, padding_side: str = 'right') Tensor

Apply the padding attention mask (shape [B, S]) to v, taking padding_side into account.

tptt.modeling_tptt.construct_causal_forward_solver(tri_tensor: Tensor, dtype: dtype = torch.float32) Tensor

Forward substitution for fast inversion during chunk propagation.

tptt.modeling_tptt.describe(x: Tensor, name='tensor') None

Prints the shape, min, max, mean, and std of a tensor.

tptt.modeling_tptt.ensure_stability(tensor: Tensor, min_val: float = -10000.0, max_val: float = 10000.0) Tensor

Force numerical stability by clamping tensor values to [min_val, max_val].

tptt.modeling_tptt.extract_layer_idx(module_name: str) int

Extract the layer index from a module name string.

tptt.modeling_tptt.find_embedding_lm(module: Module) Module | None

Find the embedding weight in a model module.

tptt.modeling_tptt.get_tptt_model(model: ~torch.nn.modules.module.Module, base_config: ~transformers.configuration_utils.PretrainedConfig, linear_cache: ~tptt.modeling_tptt.LCache | None = None, liza_attention: ~torch.nn.modules.module.Module = <class 'tptt.modeling_tptt.LiZAttention'>, target_modules_names: list[str] | None = None, operator_mode: str = 'delta_rule', use_linear_checkpoint: bool = False, recurrent_config: ~typing.Dict[str, ~typing.Any] | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: ~torch.dtype = torch.float32, max_self_attn_length: int | None = None, padding_side: str = 'right', bidirectional: bool = False, pooling_config: ~typing.Dict[str, ~typing.Any] | None = None, **kwargs) Tuple[PreTrainedModel, LCache]

Replace target modules in a model with LiZAttention.
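
A minimal injection sketch; the backbone matches the TpttConfig default, and the remaining arguments are illustrative.

    from transformers import AutoModelForCausalLM
    from tptt import get_tptt_model

    base = AutoModelForCausalLM.from_pretrained("google/gemma-3-270m")
    model, linear_cache = get_tptt_model(
        base,
        base_config=base.config,
        operator_mode="delta_rule",
        mag_weight=0.5,
        max_chunk_size=64,
    )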

tptt.modeling_tptt.get_valid_chunk_size(total_l: int, chunk_size: int) int

Return the largest value <= chunk_size that evenly divides total_l.
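
For example, the divisors of 100 are 1, 2, 4, 5, 10, 20, 25, 50, 100, so the largest divisor not exceeding 64 is 50 (per the contract above):

    from tptt.modeling_tptt import get_valid_chunk_size

    assert get_valid_chunk_size(100, 64) == 50
    assert get_valid_chunk_size(128, 64) == 64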

tptt.modeling_tptt.load_tptt_safetensors(repo_or_path: str, model: PreTrainedModel | PeftModel, subfolder: str | None = None, token: str | None = None) PreTrainedModel | PeftModel

Load TPTT safetensors (LoRA/PEFT weights) and adapt key names if needed.

tptt.modeling_tptt.match_dim(x: Tensor, dim: int, target_size: int) Tensor

Match the size of tensor x along dimension dim to target_size by interpolation

tptt.modeling_tptt.save_tptt_safetensors(model, path: str, name: str = 'adapter_model.safetensors')

Save the trainable LoRA/TPTT-specific weights, adapting key names.
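
A hedged save/load round-trip sketch (the local path is illustrative):

    from transformers import AutoModelForCausalLM
    from tptt import get_tptt_model
    from tptt.modeling_tptt import load_tptt_safetensors, save_tptt_safetensors

    base = AutoModelForCausalLM.from_pretrained("google/gemma-3-270m")
    model, _ = get_tptt_model(base, base_config=base.config)

    save_tptt_safetensors(model, path="./my_model")     # writes adapter_model.safetensors by default
    model = load_tptt_safetensors("./my_model", model)  # adapts key names if needed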

tptt.modeling_tptt.sequential_delta_product_scan(q_chunks: Tensor, w: Tensor, u: Tensor, alpha_chunks: Tensor, linear_activation: bool, initial_recurrent_state: Tensor, linear_precision: dtype, use_checkpoint: bool) Tuple[Tensor, Tensor]

DeltaProduct implementation (https://arxiv.org/abs/2502.10297). Implements the per-token Householder state updates.

tptt.modeling_tptt.set_trainable_parameters(model: PreTrainedModel, trainable_patterns: List[str] = None) PreTrainedModel

Freeze all model parameters except those matching trainable_patterns.

tptt.modeling_tptt.soft_clamp(x: Tensor, min_val: float = 1e-06, max_val: float = 0.999999) Tensor

Differentiable clamping for stability

tptt.modeling_tptt.split_qkv(base_attn: Module, qkv: Tensor) tuple[Tensor, Tensor, Tensor]

Split the QKV tensor into separate Q, K, and V tensors.

tptt.modeling_tptt.truncate_attention_mask(hidden_states: Tensor, attention_mask: Tensor, max_length: int) tuple[Tensor, Tensor]

Truncate hidden_states and attention_mask to the last window of size max_length

tptt.modeling_tptt.unlinear_activation(x: Tensor, scale: float = 2.0) Tensor

Non-linear activation applied between chunks.

tptt.train_tptt module

Author : Fabien FURFARO

class tptt.train_tptt.LiZACallback(model: PreTrainedModel, mode: str = 'gradual', initial_weight: float = 0.0, final_weight: float = 0.5, transition_step: int | tuple | list = 100, weight_list: list | None = None, switch_period: int = 1)

Bases: TrainerCallback

TrainerCallback to schedule mag_weight or enable/disable linear attention during training.

Modes:
  • “gradual”: linear interpolation from initial_weight to final_weight.

  • “cyclic”: alternate between values in weight_list at each step.

  • “switch”: alternately enable/disable linear attention at each step (a usage sketch follows the list).
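
A hedged sketch of wiring the callback into a transformers Trainer; model, training_args, and train_dataset are placeholders assumed to exist.

    from transformers import Trainer
    from tptt import LiZACallback, SaveBestModelCallback

    mag_schedule = LiZACallback(
        model,                 # the wrapped TPTT/LiZA model (placeholder)
        mode="gradual",
        initial_weight=0.0,
        final_weight=0.5,
        transition_step=100,
    )
    trainer = Trainer(
        model=model,
        args=training_args,            # TrainingArguments (placeholder)
        train_dataset=train_dataset,   # dataset (placeholder)
        callbacks=[mag_schedule, SaveBestModelCallback()],
    )
    trainer.train()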

on_log(args, state, control, logs=None, **kwargs)

Event called after logging the last logs.

on_step_end(args, state, control, **kwargs)

Event called at the end of a training step. If using gradient accumulation, one training step might take several inputs.

class tptt.train_tptt.SaveBestModelCallback

Bases: TrainerCallback

TrainerCallback to save the best model based on evaluation loss.

on_evaluate(args, state, control, metrics=None, **kwargs)

Event called after an evaluation phase.

tptt.train_tptt.ensure_int(value: int | tuple | list) int

Ensure the value is a plain integer.

Module contents

This module implements the TPTT model with linear attention (LiZA) and LoRA support.
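
The main entry points are re-exported at the package level, so a typical workflow only needs top-level imports; a hedged sketch (constructing TpttModel may download the configured backbone):

    from tptt import (
        TpttConfig,      # merged backbone + TPTT configuration
        TpttModel,       # PreTrainedModel wrapper with LiZA and LoRA support
        get_tptt_model,  # inject LiZAttention into an existing backbone
        LiZACallback,    # schedule mag_weight during training
    )

    config = TpttConfig(operator_mode="delta_rule")  # illustrative preset
    model = TpttModel(config)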

class tptt.CausalAvgPool1d(output_size: int, offsets: tuple[int] = (0, 1, 2), mode: str = 'replicate')

Bases: Module

Causal sliding-window average (uniform weights; the sequence length is preserved)

forward(x: Tensor) Tensor

x: [B, S, F] → [B, S, output_size]

class tptt.LCache

Bases: object

Cache for storing intermediate states of linear attention layers.

reset()

Clear all cached states and reset the token counter

update(layer_idx: int, **kwargs)

Update the cached states for the given layer; all tensors are detached to avoid retaining computation graphs.

class tptt.LiZACallback(model: PreTrainedModel, mode: str = 'gradual', initial_weight: float = 0.0, final_weight: float = 0.5, transition_step: int | tuple | list = 100, weight_list: list | None = None, switch_period: int = 1)

Bases: TrainerCallback

TrainerCallback to schedule mag_weight or enable/disable linear attention during training.

Modes:
  • “gradual”: linear interpolation from initial_weight to final_weight.

  • “cyclic”: alternate between values in weight_list at each step.

  • “switch”: alternately enable/disable linear attention at each step.

on_log(args, state, control, logs=None, **kwargs)

Event called after logging the last logs.

on_step_end(args, state, control, **kwargs)

Event called at the end of a training step. If using gradient accumulation, one training step might take several inputs.

class tptt.LiZAttention(base_attn: Module, layer_idx: int, base_config: PretrainedConfig, linear_cache: LCache | None = None, operator_mode: str = 'delta_rule', use_linear_checkpoint: bool = False, recurrent_config: Dict[str, Any] | None = None, max_self_attn_length: int | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: str | dtype = 'float32', padding_side: str = 'right', disable_linear_attn: bool = False, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None)

Bases: Module

LiZA Linear Attention module, mixing linear and vanilla attention.

forward(hidden_states: Tensor, attention_mask: Tensor | None = None, **kwargs) Tensor

Forward pass mixing linear and self-attention.

property is_sliding

Check whether the base attention uses sliding-window attention.

class tptt.LinearAttention(hidden_dim: int, num_heads: int, head_dim: int | None = None, num_key_value_heads: int | None = None, num_key_value_groups: int | None = None, bias: bool = True, dropout: float | None = None, linear_precision: dtype = torch.float32, padding_side: str = 'right', shared_attn: bool = False, layer_idx: int = 0, operator_mode: str | None = 'linear', use_linear_checkpoint: bool = False, recurrent_config: Dict[str, Any] | None = None, linear_cache: LCache | None = None, max_chunk_size: int = 64, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None)

Bases: Module

Linear multi-head attention layer: [B, S, D] -> [B, S, D] Projections + gating + efficient linear attention mechanism (TPTT compatible).

compute_extended_householder(q, k, v, alpha, beta, seq_len)

Expand the Householder state (n_h > 1) with the correct sequence length after expansion.

compute_gate(k: Tensor, v: Tensor) Tuple[Tensor]

Compute the gating tensor according to the beta_gate.

forward(x: List[Tensor] | Tensor, attn_mask: Tensor | None = None, out_proj: Module | None = None, **kwargs: Any) Tensor

Forward pass for linear attention. Input shape: [B, S, D], output [B, S, D].

get_cache(use_cache: bool) Tuple[Tensor | None, Tuple[Tensor, Tensor, Tensor, Tensor] | None]

Retrieve the recurrent state and qkvg buffers from the cache (causal mode only).

merge_head_output(out, out_proj, dtype, device)

Merge heads and apply RMSNorm. Input shape: [B, H, S, d], output [B, S, D].

prepare_attention_input(q: Tensor, k: Tensor, v: Tensor) Tensor

Prepare the q, k, v inputs for linear attention. Input shape: [B, S, D], output [B, S, D].

save_cache(q: Tensor, k: Tensor, v: Tensor, alpha: Tensor, beta: Tensor, state: Tensor, n_orders: int) None

Save the recurrent state and qkv buffers to the cache (causal mode only).

class tptt.LinearAttentionOp(use_linear_checkpoint: bool = False, max_chunk_size: int = 64, linear_precision: dtype = torch.float32)

Bases: Module

Base class for linear attention operators.

static chunk_delta_product_forward(query: Tensor, key: Tensor, value: Tensor, alpha: Tensor, beta: Tensor, chunk_size: int, linear_activation: bool = True, initial_state: Tensor | None = None, use_checkpoint: bool = True, linear_precision: dtype = torch.float32) Tuple[Tensor, Tensor]

Chunkwise parallel implementation (https://arxiv.org/abs/2406.06484). For each chunk, processes chunk_size * n_orders steps (virtual tokens) in order.

forward(q: Tensor, k: Tensor, v: Tensor, alpha: Tensor, beta: Tensor, linear_activation: bool, recurrent_state: Tensor | None = None, **kwargs) Tensor

Forward pass for the attention operator.

class tptt.SaveBestModelCallback

Bases: TrainerCallback

TrainerCallback to save the best model based on evaluation loss.

on_evaluate(args, state, control, metrics=None, **kwargs)

Event called after an evaluation phase.

class tptt.TpttConfig(base_model_config: dict | PretrainedConfig | None = None, base_model_name: str = 'google/gemma-3-270m', base_model_subfolder: str | None = None, name_or_path: str | None = None, model_task: str = 'causal_lm', target_modules_names: List[str] | None = None, operator_mode: str | None = None, order: int = 1, alpha_gate: str = '1', beta_gate: str = 'k', linear: bool = True, trick: str = 'derivative', use_linear_checkpoint: bool | None = None, max_self_attn_length: int | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: str | dtype = 'float32', lora_config: dict | None = None, padding_side: str | None = None, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None, **kwargs)

Bases: PretrainedConfig

Configuration class for the TPTT model. This class merges the backbone config (e.g., Llama) with custom TPTT parameters.

RECURRENT_MODES = {'delta_product': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': True, 'order': 2, 'trick': 'dt'}, 'delta_product_c': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': True, 'order': 2, 'trick': 'rdt'}, 'delta_product_r': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': True, 'order': 2, 'trick': 'rot'}, 'delta_rule': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': True, 'order': 1, 'trick': 'dt'}, 'delta_rule_gelu': {'alpha_gate': 'c', 'beta_gate': 'k', 'linear': False, 'order': 1, 'trick': 'dt'}, 'delta_rule_kv': {'alpha_gate': 'c', 'beta_gate': 'kv', 'linear': True, 'order': 1, 'trick': 'dt'}, 'delta_rule_v': {'alpha_gate': 'c', 'beta_gate': 'v', 'linear': True, 'order': 1, 'trick': 'dt'}, 'gated_delta_product': {'alpha_gate': 'k', 'beta_gate': 'k', 'linear': True, 'order': 2, 'trick': 'dt'}, 'gated_delta_rule': {'alpha_gate': 'k', 'beta_gate': 'k', 'linear': True, 'order': 1, 'trick': 'dt'}}
architectures = ['TpttModel']
auto_map = {'AutoConfig': 'configuration_tptt.TpttConfig', 'AutoModelForCausalLM': 'modeling_tptt.TpttModel'}
model_type: str = 'tptt'
class tptt.TpttModel(config: TpttConfig, **kwargs)

Bases: PreTrainedModel

TPTT model wrapper with linear attention (LiZA) and LoRA support. Handles only architecture and weights.

config_class

alias of TpttConfig

forward(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, labels: LongTensor | None = None, **kwargs)

Forward pass. All arguments are passed to the underlying base model.

classmethod from_pretrained(pretrained_model_name_or_path=None, *model_args, **kwargs)

Custom from_pretrained that accepts the standard positional argument pretrained_model_name_or_path.

generate(*args, **kwargs)

Delegate the generate call to the backbone model, which supports generation

retie_lm_after_load(**kwargs)

Re-link lm_head after loading external weights.

save_pretrained(path: str, **kwargs)

Save model weights, config, and source code to the given path.

class tptt.VirtualTokenExpander(num_heads: int, head_dim: int, n: int, mode: Literal['dt', 'rot', 'rdt'] = 'rdt', flip: bool = True)

Bases: Module

Expands each input token into n virtual tokens using derivative-based and rotation-based tricks.

forward(x: Tensor, force: str | None = None) Tensor

Forward pass to expand input tokens. [B, H, S, D] -> [B, H, S, n, D]

tptt.generate_model_card(output_path: str, config: dict | object, template: str | None, extra_variables: Dict | None = None)

Generate a README.md file from a Jinja2 template and a configuration.

  • template can be either:
    • a full path to a template file

    • a short name (e.g., “model_card”), which is looked up inside default_templates_dir

tptt.get_tptt_model(model: ~torch.nn.modules.module.Module, base_config: ~transformers.configuration_utils.PretrainedConfig, linear_cache: ~tptt.modeling_tptt.LCache | None = None, liza_attention: ~torch.nn.modules.module.Module = <class 'tptt.modeling_tptt.LiZAttention'>, target_modules_names: list[str] | None = None, operator_mode: str = 'delta_rule', use_linear_checkpoint: bool = False, recurrent_config: ~typing.Dict[str, ~typing.Any] | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: ~torch.dtype = torch.float32, max_self_attn_length: int | None = None, padding_side: str = 'right', bidirectional: bool = False, pooling_config: ~typing.Dict[str, ~typing.Any] | None = None, **kwargs) Tuple[PreTrainedModel, LCache]

Replace target modules in a model with LiZAttention.

tptt.load_tptt_safetensors(repo_or_path: str, model: PreTrainedModel | PeftModel, subfolder: str | None = None, token: str | None = None) PreTrainedModel | PeftModel

Load TPTT safetensors (LoRA/PEFT weights) and adapt key names if needed.

tptt.save_tptt_safetensors(model, path: str, name: str = 'adapter_model.safetensors')

Save the trainable LoRA/TPTT-specific weights, adapting key names.