# Source code for Tools.inclusion_updater

"""
CLI tool to place protein inclusions in membrane.
Reads input.str for protein definitions and places new inclusions in the membrane.

Usage:
    TS2CG INU -p point -t 0 -r 2 -n 10 -c 0.1 -o point_new -l both
"""

import argparse
from pathlib import Path
import numpy as np
import logging
from dataclasses import dataclass
from typing import List, Optional, Set, Dict, Sequence, Literal
from scipy.special import logsumexp
from scipy.spatial.distance import cdist

from PointClass.point import Point

logger = logging.getLogger(__name__)

def pbc_wrap(membrane):
    """Fold the leaflet coordinates back by whole box lengths, in place.

    Every coordinate ``x`` is replaced by ``x - box * round(x / box)``, i.e.
    the nearest multiple of the box length is subtracted along each axis.
    The outer leaflet is always processed; the inner leaflet is processed
    only when ``membrane.monolayer`` is falsy.

    Args:
        membrane: Object exposing ``box``, ``monolayer`` and ``outer`` /
            ``inner`` leaflets, each with a ``coordinates`` array.

    Returns:
        The same ``membrane`` object, with its coordinates reassigned.
    """
    box = membrane.box

    leaflets = [membrane.outer]
    if not membrane.monolayer:
        leaflets.append(membrane.inner)

    for layer in leaflets:
        xyz = layer.coordinates
        # Subtract the nearest integer number of box lengths per axis
        layer.coordinates = xyz - box * np.round(xyz / box)

    return membrane
def get_nearby_points_both_leaflets(membrane: Point, leaflet: str, point_idx: int, radius: float) -> Dict[str, np.ndarray]:
    """
    Collect the points of every leaflet that lie within ``radius`` of a reference point.

    Distances are measured with the minimum image convention using
    ``membrane.box``. The inner leaflet is skipped when
    ``membrane.monolayer`` is truthy.

    Args:
        membrane: Membrane Point object
        leaflet: Leaflet holding the reference point; ``'outer'`` selects the
            outer leaflet, any other value selects the inner one
        point_idx: Row of the reference point in that leaflet's coordinates
        radius: Cut-off distance (inclusive)

    Returns:
        Dict mapping each scanned leaflet name to the array of row indices
        whose distance to the reference point is ``<= radius``
    """
    origin_layer = membrane.outer if leaflet == 'outer' else membrane.inner
    origin = origin_layer.coordinates[point_idx]

    hits = {}
    for name in ('outer', 'inner'):
        # A monolayer has no inner leaflet to scan
        if name == 'inner' and membrane.monolayer:
            continue

        layer = membrane.inner if name == 'inner' else membrane.outer

        # Separation vectors, folded onto their nearest periodic image
        delta = layer.coordinates - origin
        delta = delta - membrane.box * np.round(delta / membrane.box)

        separation = np.sqrt(np.square(delta).sum(axis=1))
        hits[name] = np.flatnonzero(separation <= radius)

    return hits
def calculate_curvature_weights(curvatures: np.ndarray, target_curvature: Optional[float], k_factor: float) -> np.ndarray:
    """Calculate normalized Boltzmann weights based on curvature preference.

    Each weight is proportional to
    ``exp(-k_factor * (curvature - target_curvature) ** 2)`` and the weights
    sum to one. Without a target curvature all points are equally likely.

    Args:
        curvatures: Curvature value of every candidate point (any array-like
            of numbers; it is converted to a float array).
        target_curvature: Preferred curvature, or ``None`` for no preference.
        k_factor: Strength of the preference; larger values concentrate the
            weight on points close to the target.

    Returns:
        Float array of the same shape as ``curvatures`` holding the
        normalized weights. An empty input yields an empty array.
    """
    # Work on floats: integer input would otherwise make the normalization
    # fail on dtype casting, and plain lists do not support the arithmetic.
    values = np.asarray(curvatures, dtype=float)

    if values.size == 0:
        # Nothing to weight; also keeps logsumexp away from an empty reduction
        return values.copy()

    if target_curvature is None:
        # If no curvature preference, return uniform weights
        return np.ones_like(values) / len(values)

    # Normalize in log space (log-sum-exp trick) for numerical stability
    log_weights = -k_factor * (values - target_curvature) ** 2
    return np.exp(log_weights - logsumexp(log_weights))
def get_points_near_existing_proteins(membrane: Point, radius: float) -> Dict[str, Set[int]]:
    """Gather, per leaflet, the points lying within ``radius`` of an inclusion already present.

    Args:
        membrane: Membrane Point object; ``membrane.inclusions.points`` is
            iterated and each entry's ``'point_id'`` is used as the centre.
        radius: Exclusion radius passed on to the neighbour search.

    Returns:
        Dict with keys ``'outer'`` and ``'inner'`` mapping to the set of
        excluded point indices (both empty when there are no inclusions).
    """
    blocked = {'outer': set(), 'inner': set()}

    existing = membrane.inclusions.points
    if not existing:
        # No proteins yet, so nothing is excluded
        return blocked

    for entry in existing:
        pid = entry['point_id']

        # An id present in the outer leaflet is attributed to the outer
        # leaflet, even if the inner leaflet carries the same id
        host = 'outer' if pid in membrane.outer.ids else 'inner'

        close = get_nearby_points_both_leaflets(membrane, host, pid, radius)
        for name, indices in close.items():
            blocked[name].update(indices)

    return blocked
def place_proteins(membrane: Point, type_id: int, radius: float, num_proteins: Optional[int] = None,
                   target_curvature: Optional[float] = None, k_factor: float = 1.0,
                   leaflet: str = 'both',
                   rng: Optional[np.random.Generator] = None) -> Dict[str, int]:
    """Place proteins in membrane with given constraints.

    Points are drawn one at a time. After each placement every point within
    ``radius`` of the chosen one (in either leaflet) is removed from the pool,
    so no two new proteins, and no new protein and a pre-existing inclusion,
    end up closer than ``radius``.

    Args:
        membrane: Membrane Point object; new inclusions are added to it via
            ``membrane.inclusions.add_protein``.
        type_id: Protein type to place. A falsy value triggers automatic
            assignment of ``max(existing type ids) + 1``.
        radius: Exclusion radius around every protein.
        num_proteins: Maximum number of proteins to place. ``None`` means
            "up to the number of outer-leaflet points"; ``0`` places nothing.
        target_curvature: Preferred curvature, or ``None`` for uniform
            sampling. Inner-leaflet curvatures are sign-flipped before
            weighting.
        k_factor: Strength of the curvature preference.
        leaflet: ``'both'``, ``'outer'`` or ``'inner'``.
        rng: Random generator to draw from. When omitted, the module-level
            ``rng`` published by ``INU()`` is used if it exists, otherwise a
            fresh unseeded generator is created.

    Returns:
        Dict with the number of proteins placed in the ``'outer'`` and
        ``'inner'`` leaflet.
    """
    if rng is None:
        # INU() stores its seeded generator as a module global; fall back to
        # an unseeded one so the function is usable as a plain library call.
        rng = globals().get('rng')
        if rng is None:
            rng = np.random.default_rng()

    # Get next available type_id
    # NOTE(review): any falsy type_id (None *or* 0) is replaced here -- confirm
    # whether 0 is meant to be a valid explicit type before tightening this.
    existing_type_ids = [inc['type_id'] for inc in membrane.inclusions.points]
    type_id = type_id or max(existing_type_ids, default=0) + 1

    # Get points excluded by existing proteins
    excluded_by_existing = get_points_near_existing_proteins(membrane, radius)
    logger.info("Excluding points near existing proteins:")
    for name, points in excluded_by_existing.items():
        if points:
            logger.info(f" {name} leaflet: {len(points)} points excluded")

    # Initialize available points for each leaflet
    use_outer = leaflet in ('both', 'outer')
    use_inner = leaflet in ('both', 'inner') and not membrane.monolayer
    available_points = {
        'outer': set(membrane.outer.ids) - excluded_by_existing['outer'] if use_outer else set(),
        'inner': set(membrane.inner.ids) - excluded_by_existing['inner'] if use_inner else set(),
    }

    # Only a missing count means "fill the membrane"; an explicit 0 places nothing
    total_proteins = len(membrane.outer.ids) if num_proteins is None else num_proteins
    results = {'outer': 0, 'inner': 0}

    logger.info(f"Attempting to place {total_proteins} of protein type {type_id}")
    logger.info("Available points for placement:")
    for name, points in available_points.items():
        if points:
            logger.info(f" {name} leaflet: {len(points)} points")
    logger.info(f"Radius: {radius:.1f} nm" +
                (f", Target curvature: {target_curvature:.3f}" if target_curvature is not None else ""))

    placed = 0
    while placed < total_proteins:
        # Leaflets not selected by `leaflet` start with an empty pool, so a
        # non-empty pool is the only condition that needs checking here.
        valid_leaflets = [name for name in ('outer', 'inner') if available_points[name]]

        if not valid_leaflets:
            logger.warning("No more valid points available for placement")
            break

        # Randomly choose leaflet (converted from NumPy's str scalar)
        current_leaflet = str(rng.choice(valid_leaflets))
        membrane_layer = membrane.outer if current_leaflet == 'outer' else membrane.inner

        # Get valid indices for current leaflet
        valid_indices = np.array(list(available_points[current_leaflet]))

        # Calculate weights based on curvature preference
        curvatures = membrane_layer.mean_curvature[valid_indices]
        if current_leaflet == 'inner':
            curvatures = -curvatures  # Flip curvature for inner leaflet

        weights = calculate_curvature_weights(curvatures, target_curvature, k_factor)

        # Choose placement point; hand a plain Python int to the inclusion
        # store rather than a NumPy scalar
        chosen_idx = int(rng.choice(valid_indices, p=weights))

        # Add protein inclusion
        membrane.inclusions.add_protein(type_id=type_id, point_id=chosen_idx)

        # Remove everything inside the exclusion radius, in both leaflets
        nearby = get_nearby_points_both_leaflets(membrane, current_leaflet, chosen_idx, radius)
        for leaflet_name, excluded in nearby.items():
            available_points[leaflet_name].difference_update(excluded)

        placed += 1
        results[current_leaflet] += 1

        logger.debug(f"Placed protein at point {chosen_idx} in {current_leaflet} leaflet")

    return results
def INU(args: List[str]) -> None:
    """Main entry point for protein inclusion tool.

    Parses the command line, loads the membrane from the point directory,
    wraps its coordinates with ``pbc_wrap``, places the requested proteins
    and saves the result.

    Args:
        args: Command-line arguments (without the program name).

    Raises:
        Exception: Any error raised while loading, placing or saving is
            logged and re-raised unchanged.
    """
    parser = argparse.ArgumentParser(description=__doc__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument('-p', '--point-dir', type=Path, default="point",
                        help='Path to membrane point directory (default: point/)')
    parser.add_argument('-t', '--type-id', type=int, required=False,
                        help='Specify which protein type to add into the membrane ')
    parser.add_argument('-r', '--radius', type=float, required=True,
                        help='Exclusion radius for protein placement (point to point distance)')
    parser.add_argument('-c', '--curvature', type=float,
                        help='Target curvature for placement (optional)')
    # '-N' is accepted as an alias because the usage text and the warning
    # below both advertise it.
    parser.add_argument('-n', '-N', '--num-proteins', type=int,
                        help='Number of proteins to place (optional)')
    parser.add_argument('-k', '--k-factor', type=float, default=1.0,
                        help='Scaling factor for curvature preference strength (default: 1.0)')
    parser.add_argument('-l', '--leaflet', choices=['both', 'inner', 'outer'], default='both',
                        help='Which membrane leaflet(s) to modify (default: both)')
    parser.add_argument('-o', '--output', type=Path,
                        help='Output directory (defaults to input directory)')
    parser.add_argument('--seed', type=int,
                        help='Random seed for reproducibility')

    opts = parser.parse_args(args)
    logging.basicConfig(level=logging.INFO)

    if opts.num_proteins is None and opts.curvature is not None:
        logger.warning(
            "Curvature preference specified without number of proteins to place. "
            "This will attempt to place proteins at all valid points, which may not be desired. "
            "Consider specifying -N/--num-proteins to limit the number of proteins placed."
        )

    # setup numpy random number generator; place_proteins reads this
    # module-level name, so it has to stay a global
    global rng
    rng = np.random.default_rng(opts.seed)

    try:
        # Load membrane
        membrane = Point(opts.point_dir)

        # wrap membrane inside box
        membrane = pbc_wrap(membrane)

        # Place proteins
        results = place_proteins(
            membrane=membrane,
            type_id=opts.type_id,
            radius=opts.radius,
            num_proteins=opts.num_proteins,
            target_curvature=opts.curvature,
            k_factor=opts.k_factor,
            leaflet=opts.leaflet
        )

        # Log results
        total_placed = sum(results.values())
        logger.info(f"Successfully placed {total_placed} of protein type {opts.type_id}:")
        for leaflet, count in results.items():
            if count > 0:
                logger.info(f" {leaflet} leaflet: {count} proteins")

        # Save membrane
        output_dir = opts.output or opts.point_dir
        membrane.save(output_dir)
        logger.info(f"Updated membrane in {output_dir}")

    except Exception as e:
        logger.error(f"Error: {e}")
        raise