// Copyright 2023 The frp Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package transport import ( "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "encoding/pem" "math/big" "os" ) func newCustomTLSKeyPair(certfile, keyfile string) (*tls.Certificate, error) { tlsCert, err := tls.LoadX509KeyPair(certfile, keyfile) if err != nil { return nil, err } return &tlsCert, nil } func newRandomTLSKeyPair() *tls.Certificate { key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { panic(err) } template := x509.Certificate{SerialNumber: big.NewInt(1)} certDER, err := x509.CreateCertificate( rand.Reader, &template, &template, &key.PublicKey, key) if err != nil { panic(err) } keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) if err != nil { panic(err) } return &tlsCert } // Only support one ca file to add func newCertPool(caPath string) (*x509.CertPool, error) { pool := x509.NewCertPool() caCrt, err := os.ReadFile(caPath) if err != nil { return nil, err } pool.AppendCertsFromPEM(caCrt) return pool, nil } func NewServerTLSConfig(certPath, keyPath, caPath string) (*tls.Config, error) { base := &tls.Config{} if certPath == "" || keyPath == "" { // server will generate tls conf by itself cert := newRandomTLSKeyPair() base.Certificates = []tls.Certificate{*cert} } else { cert, err := newCustomTLSKeyPair(certPath, keyPath) if err != nil { return nil, err } base.Certificates = []tls.Certificate{*cert} } if caPath != "" { pool, err := newCertPool(caPath) if err != nil { return nil, err } base.ClientAuth = tls.RequireAndVerifyClientCert base.ClientCAs = pool } return base, nil } func NewClientTLSConfig(certPath, keyPath, caPath, serverName string) (*tls.Config, error) { base := &tls.Config{} if certPath != "" && keyPath != "" { cert, err := newCustomTLSKeyPair(certPath, keyPath) if err != nil { return nil, err } base.Certificates = []tls.Certificate{*cert} } base.ServerName = serverName if caPath != "" { pool, err := newCertPool(caPath) if err != nil { return nil, err } base.RootCAs = pool base.InsecureSkipVerify = false } else { base.InsecureSkipVerify = true } return base, nil } func NewRandomPrivateKey() ([]byte, error) { key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return nil, err } keyPEM := pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key), }) return keyPEM, nil }